Compare commits

...

53 Commits

Author SHA1 Message Date
Ruslan Talpa
ee7d8e4512 revert pg-14 submodule changes 2025-07-04 13:55:35 +03:00
Ruslan Talpa
6549708b44 change subzero dep sha 2025-07-04 13:39:49 +03:00
Ruslan Talpa
45631bf2e5 add line to remove from diff 2025-07-04 13:28:52 +03:00
Ruslan Talpa
5dbca8c756 revert changes from original hack branch 2025-07-04 13:27:59 +03:00
Ruslan Talpa
9f46ca5eb1 Merge branch 'main' into ruslan/subzero-integration 2025-07-04 13:03:55 +03:00
Ruslan Talpa
54da030a2d place the entire rest_broker code under a feature flag 2025-07-04 12:46:48 +03:00
Ruslan Talpa
afa4e48071 put subzero dependency under a feature flag 2025-07-04 11:27:36 +03:00
Alex Chi Z.
cc699f6f85 fix(pageserver): do not log no-route-to-host errors (#12468)
## Problem

close https://github.com/neondatabase/neon/issues/12344

## Summary of changes

Add `HostUnreachable` and `NetworkUnreachable` to expected I/O error.
This was new in Rust 1.83.

Signed-off-by: Alex Chi Z <chi@neon.tech>
2025-07-03 21:57:42 +00:00
Konstantin Knizhnik
495112ca50 Add GUC for dynamically enable compare local mode (#12424)
## Problem

DEBUG_LOCAL_COMPARE mode allows to detect data corruption.
But it requires rebuild of neon extension (and so requires special
image) and significantly slowdown execution because always fetch pages
from page server.

## Summary of changes

Introduce new GUC `neon.debug_compare_local`, accepting the following
values: " none", "prefetch", "lfc", "all" (by default it is definitely
disabled).
In mode less than "all", neon SMGR will not fetch page from PS if it is
found in local caches.

Co-authored-by: Konstantin Knizhnik <knizhnik@neon.tech>
2025-07-03 17:37:05 +00:00
Suhas Thalanki
46158ee63f fix(compute): background installed extensions worker would collect data without waiting for interval (#12465)
## Problem

The background installed extensions worker relied on `interval.tick()`
to go to sleep for a period of time. This can lead to bugs due to the
interval being updated at the end of the loop as the first tick is
[instantaneous](https://docs.rs/tokio/latest/tokio/time/struct.Interval.html#method.tick).

## Summary of changes

Changed it to a `tokio::time::sleep` to prevent this issue. Now it puts
the thread to sleep and only wakes up after the specified duration
2025-07-03 17:10:30 +00:00
Alex Chi Z.
305fe61ac1 fix(pageserver): also print open layer size in backpressure (#12440)
## Problem

Better investigate memory usage during backpressure

## Summary of changes

Print open layer size if backpressure is activated

Signed-off-by: Alex Chi Z <chi@neon.tech>
2025-07-03 16:37:11 +00:00
Vlad Lazar
f95fdf5b44 pageserver: fix duplicate tombstones in ancestor detach (#12460)
## Problem

Ancestor detach from a previously detached parent when there were no
writes panics since it tries to upload the tombstone layer twice.

## Summary of Changes

If we're gonna copy the tombstone from the ancestor, don't bother
creating it.

Fixes https://github.com/neondatabase/neon/issues/12458
2025-07-03 16:35:46 +00:00
Arpad Müller
a852bc5e39 Add new activating scheduling policy for safekeepers (#12441)
When deploying new safekeepers, we don't immediately want to send
traffic to them. Maybe they are not ready yet by the time the deploy
script is registering them with the storage controller.

For pageservers, the storcon solves the problem by not scheduling stuff
to them unless there has been a positive heartbeat response. We can't do
the same for safekeepers though, otherwise a single down safekeeper
would mean we can't create new timelines in smaller regions where there
is only three safekeepers in total.

So far we have created safekeepers as `pause` but this adds a manual
step to safekeeper deployment which is prone to oversight. We want
things to be automatted. So we introduce a new state `activating` that
acts just like `pause`, except that we automatically transition the
policy to `active` once we get a positive heartbeat from the safekeeper.
For `pause`, we always keep the safekeeper paused.
2025-07-03 16:27:43 +00:00
Aleksandr Sarantsev
b96983a31c storcon: Ignore keep-failing reconciles (#12391)
## Problem

Currently, if `storcon` (storage controller) reconciliations repeatedly
fail, the system will indefinitely freeze optimizations. This can result
in optimization starvation for several days until the reconciliation
issues are manually resolved. To mitigate this, we should detect
persistently failing reconciliations and exclude them from influencing
the optimization decision.

## Summary of Changes

- A tenant shard reconciliation is now considered "keep-failing" if it
fails 5 consecutive times. These failures are excluded from the
optimization readiness check.
- Added a new metric: `storage_controller_keep_failing_reconciles` to
monitor such cases.
- Added a warning log message when a reconciliation is marked as
"keep-failing".

---------

Co-authored-by: Aleksandr Sarantsev <aleksandr.sarantsev@databricks.com>
2025-07-03 16:21:36 +00:00
Dmitrii Kovalkov
3ed28661b1 storcon: remote feature testing safekeeper quorum checks (#12459)
## Problem
Previous PR didn't fix the creation of timeline in neon_local with <3
safekeepers because there is one more check down the stack.

- Closes: https://github.com/neondatabase/neon/issues/12298
- Follow up on https://github.com/neondatabase/neon/pull/12378

## Summary of changes
- Remove feature `testing` safekeeper quorum checks from storcon

---------

Co-authored-by: Arpad Müller <arpad-m@users.noreply.github.com>
2025-07-03 15:02:30 +00:00
Conrad Ludgate
03e604e432 Nightly lints and small tweaks (#12456)
Let chains available in 1.88 :D new clippy lints coming up in future
releases.
2025-07-03 14:47:12 +00:00
HaoyuHuang
4db934407a SK changes #1 (#12448)
## TLDR
This PR is a no-op. The changes are disabled by default. 

## Problem
I. Currently we don't have a way to detect disk I/O failures from WAL
operations.

II.
We observe that the offloader fails to upload a segment due to race
conditions on XLOG SWITCH and PG start streaming WALs. wal_backup task
continously failing to upload a full segment while the segment remains
partial on the disk.

The consequence is that commit_lsn for all SKs move forward but
backup_lsn stays the same. Then, all SKs run out of disk space.

III.
We have discovered SK bugs where the WAL offload owner cannot keep up
with WAL backup/upload to S3, which results in an unbounded accumulation
of WAL segment files on the Safekeeper's disk until the disk becomes
full. This is a somewhat dangerous operation that is hard to recover
from because the Safekeeper cannot write its control files when it is
out of disk space. There are actually 2 problems here:

1. A single problematic timeline can take over the entire disk for the
SK
2. Once out of disk, it's difficult to recover SK


IV. 
Neon reports certain storage errors as "critical" errors using a marco,
which will increment a counter/metric that can be used to raise alerts.
However, this metric isn't sliced by tenant and/or timeline today. We
need the tenant/timeline dimension to better respond to incidents and
for blast radius analysis.

## Summary of changes
I. 
The PR adds a `safekeeper_wal_disk_io_errors ` which is incremented when
SK fails to create or flush WALs.

II. 
To mitigate this issue, we will re-elect a new offloader if the current
offloader is lagging behind too much.
Each SK makes the decision locally but they are aware of each other's
commit and backup lsns.

The new algorithm is
- determine_offloader will pick a SK. say SK-1.
- Each SK checks
-- if commit_lsn - back_lsn > threshold,
-- -- remove SK-1 from the candidate and call determine_offloader again.

SK-1 will step down and all SKs will elect the same leader again.
After the backup is caught up, the leader will become SK-1 again.

This also helps when SK-1 is slow to backup. 

I'll set the reelect backup lag to 4 GB later. Setting to 128 MB in dev
to trigger the code more frequently.

III. 
This change addresses problem no. 1 by having the Safekeeper perform a
timeline disk utilization check check when processing WAL proposal
messages from Postgres/compute. The Safekeeper now rejects the WAL
proposal message, effectively stops writing more WAL for the timeline to
disk, if the existing WAL files for the timeline on the SK disk exceeds
a certain size (the default threshold is 100GB). The disk utilization is
calculated based on a `last_removed_segno` variable tracked by the
background task removing WAL files, which produces an accurate and
conservative estimate (>= than actual disk usage) of the actual disk
usage.


IV.
* Add a new metric `hadron_critical_storage_event_count` that has the
`tenant_shard_id` and `timeline_id` as dimensions.
* Modified the `crtitical!` marco to include tenant_id and timeline_id
as additional arguments and adapted existing call sites to populate the
tenant shard and timeline ID fields. The `critical!` marco invocation
now increments the `hadron_critical_storage_event_count` with the extra
dimensions. (In SK there isn't the notion of a tenant-shard, so just the
tenant ID is recorded in lieu of tenant shard ID.)

I considered adding a separate marco to avoid merge conflicts, but I
think in this case (detecting critical errors) conflicts are probably
more desirable so that we can be aware whenever Neon adds another
`critical!` invocation in their code.

---------

Co-authored-by: Chen Luo <chen.luo@databricks.com>
Co-authored-by: Haoyu Huang <haoyu.huang@databricks.com>
Co-authored-by: William Huang <william.huang@databricks.com>
2025-07-03 14:32:53 +00:00
Ruslan Talpa
b54872a4dc fix error after merging latest master 2025-07-03 17:24:13 +03:00
Ruslan Talpa
486829f875 Merge branch 'main' into ruslan/subzero-integration 2025-07-03 17:10:43 +03:00
Ruslan Talpa
95e1011cd6 subzero pre-integration refactor (#12416)
## Problem
integrating subzero requires a bit of refactoring. To make the
integration PR a bit more manageable, the refactoring is done in this
separate PR.
 
## Summary of changes
* move common types/functions used in sql_over_http to errors.rs and
http_util.rs
* add the "Local" auth backend to proxy (similar to local_proxy), useful
in local testing
* change the Connect and Send type for the http client to allow for
custom body when making post requests to local_proxy from the proxy

---------

Co-authored-by: Ruslan Talpa <ruslan.talpa@databricks.com>
2025-07-03 11:04:08 +00:00
Conrad Ludgate
1bc1eae5e8 fix redis credentials check (#12455)
## Problem

`keep_connection` does not exit, so it was never setting
`credentials_refreshed`.

## Summary of changes

Set `credentials_refreshed` to true when we first establish a
connection, and after we re-authenticate the connection.
2025-07-03 09:51:35 +00:00
Matthias van de Meent
e12d4f356a Work around Clap's incorrect usage of Display for default_value_t (#12454)
## Problem

#12450 

## Summary of changes

Instead of `#[arg(default_value_t = typed_default_value)]`, we use
`#[arg(default_value = "str that deserializes into the value")]`,
because apparently you can't convince clap to _not_ deserialize from the
Display implementation of an imported enum.
2025-07-03 09:41:09 +00:00
Ruslan Talpa
4775aa3e01 Merge branch 'main' into ruslan/subzero-integration 2025-07-01 13:46:57 +03:00
Ruslan Talpa
1785f856b6 Move the local auth backend under the "testing" feature 2025-06-30 16:52:45 +03:00
Ruslan Talpa
69b22b05da add in readme the way to run auth/rest broker locally 2025-06-30 16:17:35 +03:00
Ruslan Talpa
bf0007fa96 add note about local confir read code 2025-06-30 16:12:04 +03:00
Ruslan Talpa
a9bbe7b00b remove unused imports 2025-06-30 16:02:30 +03:00
Ruslan Talpa
7e3f64b309 implement local auth backend for proxy and remove control plane hacks 2025-06-30 16:00:43 +03:00
Ruslan Talpa
9480d17de7 fix bug in pickcurrent_chema 2025-06-30 12:53:47 +03:00
Ruslan Talpa
424004ec95 apply cargo fmt 2025-06-30 12:32:47 +03:00
Ruslan Talpa
88d1a78260 cleanup the rest path code 2025-06-30 12:30:33 +03:00
Ruslan Talpa
8e544c7f99 import introspection queries instead of loading from files 2025-06-27 14:40:02 +03:00
Ruslan Talpa
4f49fc5b79 move common error types and http realted functions to error.rs and http_util.rs 2025-06-27 13:37:29 +03:00
Ruslan Talpa
5461039c3f implement remote config fetch from the db and cache introspected schema 2025-06-27 10:20:26 +03:00
Ruslan Talpa
d6c36d103e subzero integration WIP6
beginning work on introspection of config and schema shape from the database
2025-06-26 14:42:31 +03:00
Ruslan Talpa
fbb2416685 pg 14 vendor commit changed 2025-06-26 10:25:54 +03:00
Ruslan Talpa
8072fae2fe Merge branch 'main' into ruslan/subzero-integration 2025-06-26 10:19:16 +03:00
Ruslan Talpa
3869d680f9 use a global parsed/cached schema 2025-06-26 10:14:28 +03:00
Ruslan Talpa
d3fa228d92 move subzero local test files to a "dot" folder 2025-06-25 16:43:50 +03:00
Ruslan Talpa
be6a259b85 subzero integration WIP5
cleanup and postprocess the response and set the correct headers/status and handle errors
2025-06-25 14:33:45 +03:00
Ruslan Talpa
af3ca24a5e remove unused enum values 2025-06-25 14:32:46 +03:00
Ruslan Talpa
8b44f5b479 subzero integration WIP5
extract the response body from the local proxy response
2025-06-24 17:03:42 +03:00
Ruslan Talpa
d1445cf3eb subzero integration WIP4
queries generated by subzero reach database and execute succesfully
2025-06-24 15:33:51 +03:00
Ruslan Talpa
67d3026fc4 subzero integration WIP3
* query makes it to the database
2025-06-23 11:48:55 +03:00
Ruslan Talpa
09e62e9b98 subzero integration WIP2 2025-06-23 10:11:06 +03:00
Ruslan Talpa
e121da4bfc subzero integration WIP1 2025-06-20 15:10:45 +03:00
Ruslan Talpa
4a948c9781 add note about ICU lib missing on macs and the fix 2025-06-20 10:58:38 +03:00
Ruslan Talpa
b39f04ab99 add missing parts to make disable_pg_session_jwt flag work 2025-06-20 10:23:42 +03:00
Ruslan Talpa
6bd15908fb make pg_session_jwt instalation optional with a cli flag 2025-06-20 10:17:32 +03:00
Ruslan Talpa
3e36d516c2 vanilla pg dokcer image setup 2025-06-20 09:37:39 +03:00
Conrad Ludgate
cc3af6f7dd code for local setup of auth-broker 2025-06-20 09:37:39 +03:00
Conrad Ludgate
5badc7a3fb code for local setup of auth-broker 2025-06-19 10:34:09 +01:00
Conrad Ludgate
3a73644308 use cargo-chef for compute-tools 2025-06-19 09:24:53 +01:00
75 changed files with 5094 additions and 2276 deletions

5
.gitignore vendored
View File

@@ -25,6 +25,11 @@ compaction-suite-results.*
*.o
*.so
*.Po
*.pid
# pgindent typedef lists
*.list
# various files for local testing
/proxy/.subzero
local_proxy.json

3241
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -2371,24 +2371,23 @@ LIMIT 100",
installed_extensions_collection_interval
);
let handle = tokio::spawn(async move {
// An initial sleep is added to ensure that two collections don't happen at the same time.
// The first collection happens during compute startup.
tokio::time::sleep(tokio::time::Duration::from_secs(
installed_extensions_collection_interval,
))
.await;
let mut interval = tokio::time::interval(tokio::time::Duration::from_secs(
installed_extensions_collection_interval,
));
loop {
interval.tick().await;
info!(
"[NEON_EXT_INT_SLEEP]: Interval: {}",
installed_extensions_collection_interval
);
// Sleep at the start of the loop to ensure that two collections don't happen at the same time.
// The first collection happens during compute startup.
tokio::time::sleep(tokio::time::Duration::from_secs(
installed_extensions_collection_interval,
))
.await;
let _ = installed_extensions(conf.clone()).await;
// Acquire a read lock on the compute spec and then update the interval if necessary
interval = tokio::time::interval(tokio::time::Duration::from_secs(std::cmp::max(
installed_extensions_collection_interval = std::cmp::max(
installed_extensions_collection_interval,
2 * atomic_interval.load(std::sync::atomic::Ordering::SeqCst),
)));
installed_extensions_collection_interval = interval.period().as_secs();
);
}
});

View File

@@ -64,7 +64,9 @@ const DEFAULT_PAGESERVER_ID: NodeId = NodeId(1);
const DEFAULT_BRANCH_NAME: &str = "main";
project_git_version!(GIT_VERSION);
#[allow(dead_code)]
const DEFAULT_PG_VERSION: PgMajorVersion = PgMajorVersion::PG17;
const DEFAULT_PG_VERSION_NUM: &str = "17";
const DEFAULT_PAGESERVER_CONTROL_PLANE_API: &str = "http://127.0.0.1:1234/upcall/v1/";
@@ -167,7 +169,7 @@ struct TenantCreateCmdArgs {
#[clap(short = 'c')]
config: Vec<String>,
#[arg(default_value_t = DEFAULT_PG_VERSION)]
#[arg(default_value = DEFAULT_PG_VERSION_NUM)]
#[clap(long, help = "Postgres version to use for the initial timeline")]
pg_version: PgMajorVersion,
@@ -290,7 +292,7 @@ struct TimelineCreateCmdArgs {
#[clap(long, help = "Human-readable alias for the new timeline")]
branch_name: String,
#[arg(default_value_t = DEFAULT_PG_VERSION)]
#[arg(default_value = DEFAULT_PG_VERSION_NUM)]
#[clap(long, help = "Postgres version")]
pg_version: PgMajorVersion,
}
@@ -322,7 +324,7 @@ struct TimelineImportCmdArgs {
#[clap(long, help = "Lsn the basebackup ends at")]
end_lsn: Option<Lsn>,
#[arg(default_value_t = DEFAULT_PG_VERSION)]
#[arg(default_value = DEFAULT_PG_VERSION_NUM)]
#[clap(long, help = "Postgres version of the backup being imported")]
pg_version: PgMajorVersion,
}
@@ -601,7 +603,7 @@ struct EndpointCreateCmdArgs {
)]
config_only: bool,
#[arg(default_value_t = DEFAULT_PG_VERSION)]
#[arg(default_value = DEFAULT_PG_VERSION_NUM)]
#[clap(long, help = "Postgres version")]
pg_version: PgMajorVersion,

View File

@@ -420,6 +420,7 @@ impl From<NodeSchedulingPolicy> for String {
#[derive(Serialize, Deserialize, Clone, Copy, Eq, PartialEq, Debug)]
pub enum SkSchedulingPolicy {
Active,
Activating,
Pause,
Decomissioned,
}
@@ -430,6 +431,7 @@ impl FromStr for SkSchedulingPolicy {
fn from_str(s: &str) -> Result<Self, Self::Err> {
Ok(match s {
"active" => Self::Active,
"activating" => Self::Activating,
"pause" => Self::Pause,
"decomissioned" => Self::Decomissioned,
_ => {
@@ -446,6 +448,7 @@ impl From<SkSchedulingPolicy> for String {
use SkSchedulingPolicy::*;
match value {
Active => "active",
Activating => "activating",
Pause => "pause",
Decomissioned => "decomissioned",
}

View File

@@ -78,7 +78,13 @@ pub fn is_expected_io_error(e: &io::Error) -> bool {
use io::ErrorKind::*;
matches!(
e.kind(),
BrokenPipe | ConnectionRefused | ConnectionAborted | ConnectionReset | TimedOut
HostUnreachable
| NetworkUnreachable
| BrokenPipe
| ConnectionRefused
| ConnectionAborted
| ConnectionReset
| TimedOut,
)
}

View File

@@ -52,7 +52,7 @@ pub(crate) async fn hi(str: &[u8], salt: &[u8], iterations: u32) -> [u8; 32] {
}
// yield every ~250us
// hopefully reduces tail latencies
if i % 1024 == 0 {
if i.is_multiple_of(1024) {
yield_now().await
}
}

View File

@@ -90,7 +90,7 @@ pub struct InnerClient {
}
impl InnerClient {
pub fn start(&mut self) -> Result<PartialQuery, Error> {
pub fn start(&mut self) -> Result<PartialQuery<'_>, Error> {
self.responses.waiting += 1;
Ok(PartialQuery(Some(self)))
}
@@ -227,7 +227,7 @@ impl Client {
&mut self,
statement: &str,
params: I,
) -> Result<RowStream, Error>
) -> Result<RowStream<'_>, Error>
where
S: AsRef<str>,
I: IntoIterator<Item = Option<S>>,
@@ -262,7 +262,7 @@ impl Client {
pub(crate) async fn simple_query_raw(
&mut self,
query: &str,
) -> Result<SimpleQueryStream, Error> {
) -> Result<SimpleQueryStream<'_>, Error> {
simple_query::simple_query(self.inner_mut(), query).await
}

View File

@@ -12,7 +12,11 @@ mod private {
/// This trait is "sealed", and cannot be implemented outside of this crate.
pub trait GenericClient: private::Sealed {
/// Like `Client::query_raw_txt`.
async fn query_raw_txt<S, I>(&mut self, statement: &str, params: I) -> Result<RowStream, Error>
async fn query_raw_txt<S, I>(
&mut self,
statement: &str,
params: I,
) -> Result<RowStream<'_>, Error>
where
S: AsRef<str> + Sync + Send,
I: IntoIterator<Item = Option<S>> + Sync + Send,
@@ -22,7 +26,11 @@ pub trait GenericClient: private::Sealed {
impl private::Sealed for Client {}
impl GenericClient for Client {
async fn query_raw_txt<S, I>(&mut self, statement: &str, params: I) -> Result<RowStream, Error>
async fn query_raw_txt<S, I>(
&mut self,
statement: &str,
params: I,
) -> Result<RowStream<'_>, Error>
where
S: AsRef<str> + Sync + Send,
I: IntoIterator<Item = Option<S>> + Sync + Send,
@@ -35,7 +43,11 @@ impl GenericClient for Client {
impl private::Sealed for Transaction<'_> {}
impl GenericClient for Transaction<'_> {
async fn query_raw_txt<S, I>(&mut self, statement: &str, params: I) -> Result<RowStream, Error>
async fn query_raw_txt<S, I>(
&mut self,
statement: &str,
params: I,
) -> Result<RowStream<'_>, Error>
where
S: AsRef<str> + Sync + Send,
I: IntoIterator<Item = Option<S>> + Sync + Send,

View File

@@ -47,7 +47,7 @@ impl<'a> Transaction<'a> {
&mut self,
statement: &str,
params: I,
) -> Result<RowStream, Error>
) -> Result<RowStream<'_>, Error>
where
S: AsRef<str>,
I: IntoIterator<Item = Option<S>>,

View File

@@ -24,12 +24,28 @@ macro_rules! critical {
if cfg!(debug_assertions) {
panic!($($arg)*);
}
// Increment both metrics
$crate::logging::TRACING_EVENT_COUNT_METRIC.inc_critical();
let backtrace = std::backtrace::Backtrace::capture();
tracing::error!("CRITICAL: {}\n{backtrace}", format!($($arg)*));
}};
}
#[macro_export]
macro_rules! critical_timeline {
($tenant_shard_id:expr, $timeline_id:expr, $($arg:tt)*) => {{
if cfg!(debug_assertions) {
panic!($($arg)*);
}
// Increment both metrics
$crate::logging::TRACING_EVENT_COUNT_METRIC.inc_critical();
$crate::logging::HADRON_CRITICAL_STORAGE_EVENT_COUNT_METRIC.inc(&$tenant_shard_id.to_string(), &$timeline_id.to_string());
let backtrace = std::backtrace::Backtrace::capture();
tracing::error!("CRITICAL: [tenant_shard_id: {}, timeline_id: {}] {}\n{backtrace}",
$tenant_shard_id, $timeline_id, format!($($arg)*));
}};
}
#[derive(EnumString, strum_macros::Display, VariantNames, Eq, PartialEq, Debug, Clone, Copy)]
#[strum(serialize_all = "snake_case")]
pub enum LogFormat {
@@ -61,6 +77,36 @@ pub struct TracingEventCountMetric {
trace: IntCounter,
}
// Begin Hadron: Add a HadronCriticalStorageEventCountMetric metric that is sliced by tenant_id and timeline_id
pub struct HadronCriticalStorageEventCountMetric {
critical: IntCounterVec,
}
pub static HADRON_CRITICAL_STORAGE_EVENT_COUNT_METRIC: Lazy<HadronCriticalStorageEventCountMetric> =
Lazy::new(|| {
let vec = metrics::register_int_counter_vec!(
"hadron_critical_storage_event_count",
"Number of critical storage events, by tenant_id and timeline_id",
&["tenant_shard_id", "timeline_id"]
)
.expect("failed to define metric");
HadronCriticalStorageEventCountMetric::new(vec)
});
impl HadronCriticalStorageEventCountMetric {
fn new(vec: IntCounterVec) -> Self {
Self { critical: vec }
}
// Allow public access from `critical!` macro.
pub fn inc(&self, tenant_shard_id: &str, timeline_id: &str) {
self.critical
.with_label_values(&[tenant_shard_id, timeline_id])
.inc();
}
}
// End Hadron
pub static TRACING_EVENT_COUNT_METRIC: Lazy<TracingEventCountMetric> = Lazy::new(|| {
let vec = metrics::register_int_counter_vec!(
"libmetrics_tracing_event_count",

View File

@@ -78,7 +78,7 @@ use utils::rate_limit::RateLimit;
use utils::seqwait::SeqWait;
use utils::simple_rcu::{Rcu, RcuReadGuard};
use utils::sync::gate::{Gate, GateGuard};
use utils::{completion, critical, fs_ext, pausable_failpoint};
use utils::{completion, critical_timeline, fs_ext, pausable_failpoint};
#[cfg(test)]
use wal_decoder::models::value::Value;
use wal_decoder::serialized_batch::{SerializedValueBatch, ValueMeta};
@@ -4729,7 +4729,7 @@ impl Timeline {
}
// Fetch the next layer to flush, if any.
let (layer, l0_count, frozen_count, frozen_size) = {
let (layer, l0_count, frozen_count, frozen_size, open_layer_size) = {
let layers = self.layers.read(LayerManagerLockHolder::FlushLoop).await;
let Ok(lm) = layers.layer_map() else {
info!("dropping out of flush loop for timeline shutdown");
@@ -4742,8 +4742,13 @@ impl Timeline {
.iter()
.map(|l| l.estimated_in_mem_size())
.sum();
let open_layer_size: u64 = lm
.open_layer
.as_ref()
.map(|l| l.estimated_in_mem_size())
.unwrap_or(0);
let layer = lm.frozen_layers.front().cloned();
(layer, l0_count, frozen_count, frozen_size)
(layer, l0_count, frozen_count, frozen_size, open_layer_size)
// drop 'layers' lock
};
let Some(layer) = layer else {
@@ -4756,7 +4761,7 @@ impl Timeline {
if l0_count >= stall_threshold {
warn!(
"stalling layer flushes for compaction backpressure at {l0_count} \
L0 layers ({frozen_count} frozen layers with {frozen_size} bytes)"
L0 layers ({frozen_count} frozen layers with {frozen_size} bytes, {open_layer_size} bytes in open layer)"
);
let stall_timer = self
.metrics
@@ -4809,7 +4814,7 @@ impl Timeline {
let delay = flush_duration.as_secs_f64();
info!(
"delaying layer flush by {delay:.3}s for compaction backpressure at \
{l0_count} L0 layers ({frozen_count} frozen layers with {frozen_size} bytes)"
{l0_count} L0 layers ({frozen_count} frozen layers with {frozen_size} bytes, {open_layer_size} bytes in open layer)"
);
let _delay_timer = self
.metrics
@@ -6819,7 +6824,11 @@ impl Timeline {
Err(walredo::Error::Cancelled) => return Err(PageReconstructError::Cancelled),
Err(walredo::Error::Other(err)) => {
if fire_critical_error {
critical!("walredo failure during page reconstruction: {err:?}");
critical_timeline!(
self.tenant_shard_id,
self.timeline_id,
"walredo failure during page reconstruction: {err:?}"
);
}
return Err(PageReconstructError::WalRedo(
err.context("reconstruct a page image"),

View File

@@ -36,7 +36,7 @@ use serde::Serialize;
use tokio::sync::{OwnedSemaphorePermit, Semaphore};
use tokio_util::sync::CancellationToken;
use tracing::{Instrument, debug, error, info, info_span, trace, warn};
use utils::critical;
use utils::critical_timeline;
use utils::id::TimelineId;
use utils::lsn::Lsn;
use wal_decoder::models::record::NeonWalRecord;
@@ -1390,7 +1390,11 @@ impl Timeline {
GetVectoredError::MissingKey(_),
) = err
{
critical!("missing key during compaction: {err:?}");
critical_timeline!(
self.tenant_shard_id,
self.timeline_id,
"missing key during compaction: {err:?}"
);
}
})?;
@@ -1418,7 +1422,11 @@ impl Timeline {
// Alert on critical errors that indicate data corruption.
Err(err) if err.is_critical() => {
critical!("could not compact, repartitioning keyspace failed: {err:?}");
critical_timeline!(
self.tenant_shard_id,
self.timeline_id,
"could not compact, repartitioning keyspace failed: {err:?}"
);
}
// Log other errors. No partitioning? This is normal, if the timeline was just created

View File

@@ -182,6 +182,7 @@ pub(crate) async fn generate_tombstone_image_layer(
detached: &Arc<Timeline>,
ancestor: &Arc<Timeline>,
ancestor_lsn: Lsn,
historic_layers_to_copy: &Vec<Layer>,
ctx: &RequestContext,
) -> Result<Option<ResidentLayer>, Error> {
tracing::info!(
@@ -199,6 +200,20 @@ pub(crate) async fn generate_tombstone_image_layer(
let image_lsn = ancestor_lsn;
{
for layer in historic_layers_to_copy {
let desc = layer.layer_desc();
if !desc.is_delta
&& desc.lsn_range.start == image_lsn
&& overlaps_with(&key_range, &desc.key_range)
{
tracing::info!(
layer=%layer, "will copy tombstone from ancestor instead of creating a new one"
);
return Ok(None);
}
}
let layers = detached
.layers
.read(LayerManagerLockHolder::DetachAncestor)
@@ -450,7 +465,8 @@ pub(super) async fn prepare(
Vec::with_capacity(straddling_branchpoint.len() + rest_of_historic.len() + 1);
if let Some(tombstone_layer) =
generate_tombstone_image_layer(detached, &ancestor, ancestor_lsn, ctx).await?
generate_tombstone_image_layer(detached, &ancestor, ancestor_lsn, &rest_of_historic, ctx)
.await?
{
new_layers.push(tombstone_layer.into());
}

View File

@@ -25,7 +25,7 @@ use tokio_postgres::replication::ReplicationStream;
use tokio_postgres::{Client, SimpleQueryMessage, SimpleQueryRow};
use tokio_util::sync::CancellationToken;
use tracing::{Instrument, debug, error, info, trace, warn};
use utils::critical;
use utils::critical_timeline;
use utils::id::NodeId;
use utils::lsn::Lsn;
use utils::pageserver_feedback::PageserverFeedback;
@@ -368,9 +368,13 @@ pub(super) async fn handle_walreceiver_connection(
match raw_wal_start_lsn.cmp(&expected_wal_start) {
std::cmp::Ordering::Greater => {
let msg = format!(
"Gap in streamed WAL: [{expected_wal_start}, {raw_wal_start_lsn})"
"Gap in streamed WAL: [{expected_wal_start}, {raw_wal_start_lsn}"
);
critical_timeline!(
timeline.tenant_shard_id,
timeline.timeline_id,
"{msg}"
);
critical!("{msg}");
return Err(WalReceiverError::Other(anyhow!(msg)));
}
std::cmp::Ordering::Less => {
@@ -383,7 +387,11 @@ pub(super) async fn handle_walreceiver_connection(
"Received record with next_record_lsn multiple times ({} < {})",
first_rec.next_record_lsn, expected_wal_start
);
critical!("{msg}");
critical_timeline!(
timeline.tenant_shard_id,
timeline.timeline_id,
"{msg}"
);
return Err(WalReceiverError::Other(anyhow!(msg)));
}
}
@@ -452,7 +460,11 @@ pub(super) async fn handle_walreceiver_connection(
// TODO: we can't differentiate cancellation errors with
// anyhow::Error, so just ignore it if we're cancelled.
if !cancellation.is_cancelled() && !timeline.is_stopping() {
critical!("{err:?}")
critical_timeline!(
timeline.tenant_shard_id,
timeline.timeline_id,
"{err:?}"
);
}
})?;

View File

@@ -40,7 +40,7 @@ use tracing::*;
use utils::bin_ser::{DeserializeError, SerializeError};
use utils::lsn::Lsn;
use utils::rate_limit::RateLimit;
use utils::{critical, failpoint_support};
use utils::{critical_timeline, failpoint_support};
use wal_decoder::models::record::NeonWalRecord;
use wal_decoder::models::*;
@@ -418,18 +418,30 @@ impl WalIngest {
// as there has historically been cases where PostgreSQL has cleared spurious VM pages. See:
// https://github.com/neondatabase/neon/pull/10634.
let Some(vm_size) = get_relsize(modification, vm_rel, ctx).await? else {
critical!("clear_vm_bits for unknown VM relation {vm_rel}");
critical_timeline!(
modification.tline.tenant_shard_id,
modification.tline.timeline_id,
"clear_vm_bits for unknown VM relation {vm_rel}"
);
return Ok(());
};
if let Some(blknum) = new_vm_blk {
if blknum >= vm_size {
critical!("new_vm_blk {blknum} not in {vm_rel} of size {vm_size}");
critical_timeline!(
modification.tline.tenant_shard_id,
modification.tline.timeline_id,
"new_vm_blk {blknum} not in {vm_rel} of size {vm_size}"
);
new_vm_blk = None;
}
}
if let Some(blknum) = old_vm_blk {
if blknum >= vm_size {
critical!("old_vm_blk {blknum} not in {vm_rel} of size {vm_size}");
critical_timeline!(
modification.tline.tenant_shard_id,
modification.tline.timeline_id,
"old_vm_blk {blknum} not in {vm_rel} of size {vm_size}"
);
old_vm_blk = None;
}
}

View File

@@ -87,6 +87,14 @@ static const struct config_enum_entry running_xacts_overflow_policies[] = {
{NULL, 0, false}
};
static const struct config_enum_entry debug_compare_local_modes[] = {
{"none", DEBUG_COMPARE_LOCAL_NONE, false},
{"prefetch", DEBUG_COMPARE_LOCAL_PREFETCH, false},
{"lfc", DEBUG_COMPARE_LOCAL_LFC, false},
{"all", DEBUG_COMPARE_LOCAL_ALL, false},
{NULL, 0, false}
};
/*
* XXX: These private to procarray.c, but we need them here.
*/
@@ -519,6 +527,16 @@ _PG_init(void)
GUC_UNIT_KB,
NULL, NULL, NULL);
DefineCustomEnumVariable(
"neon.debug_compare_local",
"Debug mode for compaing content of pages in prefetch ring/LFC/PS and local disk",
NULL,
&debug_compare_local,
DEBUG_COMPARE_LOCAL_NONE,
debug_compare_local_modes,
PGC_POSTMASTER,
0,
NULL, NULL, NULL);
/*
* Important: This must happen after other parts of the extension are
* loaded, otherwise any settings to GUCs that were set before the

View File

@@ -177,6 +177,22 @@ extern StringInfoData nm_pack_request(NeonRequest *msg);
extern NeonResponse *nm_unpack_response(StringInfo s);
extern char *nm_to_string(NeonMessage *msg);
/*
* If debug_compare_local>DEBUG_COMPARE_LOCAL_NONE, we pass through all the SMGR API
* calls to md.c, and *also* do the calls to the Page Server. On every
* read, compare the versions we read from local disk and Page Server,
* and Assert that they are identical.
*/
typedef enum
{
DEBUG_COMPARE_LOCAL_NONE, /* normal mode - pages are storted locally only for unlogged relations */
DEBUG_COMPARE_LOCAL_PREFETCH, /* if page is found in prefetch ring, then compare it with local and return */
DEBUG_COMPARE_LOCAL_LFC, /* if page is found in LFC or prefetch ring, then compare it with local and return */
DEBUG_COMPARE_LOCAL_ALL /* always fetch page from PS and compare it with local */
} DebugCompareLocalMode;
extern int debug_compare_local;
/*
* API
*/

View File

@@ -76,21 +76,11 @@
typedef PGAlignedBlock PGIOAlignedBlock;
#endif
/*
* If DEBUG_COMPARE_LOCAL is defined, we pass through all the SMGR API
* calls to md.c, and *also* do the calls to the Page Server. On every
* read, compare the versions we read from local disk and Page Server,
* and Assert that they are identical.
*/
/* #define DEBUG_COMPARE_LOCAL */
#ifdef DEBUG_COMPARE_LOCAL
#include "access/nbtree.h"
#include "storage/bufpage.h"
#include "access/xlog_internal.h"
static char *hexdump_page(char *page);
#endif
#define IS_LOCAL_REL(reln) (\
NInfoGetDbOid(InfoFromSMgrRel(reln)) != 0 && \
@@ -108,6 +98,8 @@ typedef enum
UNLOGGED_BUILD_NOT_PERMANENT
} UnloggedBuildPhase;
int debug_compare_local;
static NRelFileInfo unlogged_build_rel_info;
static UnloggedBuildPhase unlogged_build_phase = UNLOGGED_BUILD_NOT_IN_PROGRESS;
@@ -478,9 +470,10 @@ neon_init(void)
old_redo_read_buffer_filter = redo_read_buffer_filter;
redo_read_buffer_filter = neon_redo_read_buffer_filter;
#ifdef DEBUG_COMPARE_LOCAL
mdinit();
#endif
if (debug_compare_local)
{
mdinit();
}
}
/*
@@ -803,13 +796,16 @@ neon_create(SMgrRelation reln, ForkNumber forkNum, bool isRedo)
case RELPERSISTENCE_TEMP:
case RELPERSISTENCE_UNLOGGED:
#ifdef DEBUG_COMPARE_LOCAL
mdcreate(reln, forkNum, forkNum == INIT_FORKNUM || isRedo);
if (forkNum == MAIN_FORKNUM)
mdcreate(reln, INIT_FORKNUM, true);
#else
mdcreate(reln, forkNum, isRedo);
#endif
if (debug_compare_local)
{
mdcreate(reln, forkNum, forkNum == INIT_FORKNUM || isRedo);
if (forkNum == MAIN_FORKNUM)
mdcreate(reln, INIT_FORKNUM, true);
}
else
{
mdcreate(reln, forkNum, isRedo);
}
return;
default:
@@ -848,10 +844,11 @@ neon_create(SMgrRelation reln, ForkNumber forkNum, bool isRedo)
else
set_cached_relsize(InfoFromSMgrRel(reln), forkNum, 0);
#ifdef DEBUG_COMPARE_LOCAL
if (IS_LOCAL_REL(reln))
mdcreate(reln, forkNum, isRedo);
#endif
if (debug_compare_local)
{
if (IS_LOCAL_REL(reln))
mdcreate(reln, forkNum, isRedo);
}
}
/*
@@ -877,7 +874,7 @@ neon_unlink(NRelFileInfoBackend rinfo, ForkNumber forkNum, bool isRedo)
{
/*
* Might or might not exist locally, depending on whether it's an unlogged
* or permanent relation (or if DEBUG_COMPARE_LOCAL is set). Try to
* or permanent relation (or if debug_compare_local is set). Try to
* unlink, it won't do any harm if the file doesn't exist.
*/
mdunlink(rinfo, forkNum, isRedo);
@@ -973,10 +970,11 @@ neon_extend(SMgrRelation reln, ForkNumber forkNum, BlockNumber blkno,
lfc_write(InfoFromSMgrRel(reln), forkNum, blkno, buffer);
#ifdef DEBUG_COMPARE_LOCAL
if (IS_LOCAL_REL(reln))
mdextend(reln, forkNum, blkno, buffer, skipFsync);
#endif
if (debug_compare_local)
{
if (IS_LOCAL_REL(reln))
mdextend(reln, forkNum, blkno, buffer, skipFsync);
}
/*
* smgr_extend is often called with an all-zeroes page, so
@@ -1051,10 +1049,11 @@ neon_zeroextend(SMgrRelation reln, ForkNumber forkNum, BlockNumber blocknum,
relpath(reln->smgr_rlocator, forkNum),
InvalidBlockNumber)));
#ifdef DEBUG_COMPARE_LOCAL
if (IS_LOCAL_REL(reln))
mdzeroextend(reln, forkNum, blocknum, nblocks, skipFsync);
#endif
if (debug_compare_local)
{
if (IS_LOCAL_REL(reln))
mdzeroextend(reln, forkNum, blocknum, nblocks, skipFsync);
}
/* Don't log any pages if we're not allowed to do so. */
if (!XLogInsertAllowed())
@@ -1265,10 +1264,11 @@ neon_writeback(SMgrRelation reln, ForkNumber forknum,
communicator_prefetch_pump_state();
#ifdef DEBUG_COMPARE_LOCAL
if (IS_LOCAL_REL(reln))
mdwriteback(reln, forknum, blocknum, nblocks);
#endif
if (debug_compare_local)
{
if (IS_LOCAL_REL(reln))
mdwriteback(reln, forknum, blocknum, nblocks);
}
}
/*
@@ -1282,7 +1282,6 @@ neon_read_at_lsn(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno,
communicator_read_at_lsnv(rinfo, forkNum, blkno, &request_lsns, &buffer, 1, NULL);
}
#ifdef DEBUG_COMPARE_LOCAL
static void
compare_with_local(SMgrRelation reln, ForkNumber forkNum, BlockNumber blkno, void* buffer, XLogRecPtr request_lsn)
{
@@ -1364,7 +1363,6 @@ compare_with_local(SMgrRelation reln, ForkNumber forkNum, BlockNumber blkno, voi
}
}
}
#endif
#if PG_MAJORVERSION_NUM < 17
@@ -1417,22 +1415,28 @@ neon_read(SMgrRelation reln, ForkNumber forkNum, BlockNumber blkno, void *buffer
if (communicator_prefetch_lookupv(InfoFromSMgrRel(reln), forkNum, blkno, &request_lsns, 1, &bufferp, &present))
{
/* Prefetch hit */
#ifdef DEBUG_COMPARE_LOCAL
compare_with_local(reln, forkNum, blkno, buffer, request_lsns.request_lsn);
#else
return;
#endif
if (debug_compare_local >= DEBUG_COMPARE_LOCAL_PREFETCH)
{
compare_with_local(reln, forkNum, blkno, buffer, request_lsns.request_lsn);
}
if (debug_compare_local <= DEBUG_COMPARE_LOCAL_PREFETCH)
{
return;
}
}
/* Try to read from local file cache */
if (lfc_read(InfoFromSMgrRel(reln), forkNum, blkno, buffer))
{
MyNeonCounters->file_cache_hits_total++;
#ifdef DEBUG_COMPARE_LOCAL
compare_with_local(reln, forkNum, blkno, buffer, request_lsns.request_lsn);
#else
return;
#endif
if (debug_compare_local >= DEBUG_COMPARE_LOCAL_LFC)
{
compare_with_local(reln, forkNum, blkno, buffer, request_lsns.request_lsn);
}
if (debug_compare_local <= DEBUG_COMPARE_LOCAL_LFC)
{
return;
}
}
neon_read_at_lsn(InfoFromSMgrRel(reln), forkNum, blkno, request_lsns, buffer);
@@ -1442,15 +1446,15 @@ neon_read(SMgrRelation reln, ForkNumber forkNum, BlockNumber blkno, void *buffer
*/
communicator_prefetch_pump_state();
#ifdef DEBUG_COMPARE_LOCAL
compare_with_local(reln, forkNum, blkno, buffer, request_lsns.request_lsn);
#endif
if (debug_compare_local)
{
compare_with_local(reln, forkNum, blkno, buffer, request_lsns.request_lsn);
}
}
#endif /* PG_MAJORVERSION_NUM <= 16 */
#if PG_MAJORVERSION_NUM >= 17
#ifdef DEBUG_COMPARE_LOCAL
static void
compare_with_localv(SMgrRelation reln, ForkNumber forkNum, BlockNumber blkno, void** buffers, BlockNumber nblocks, neon_request_lsns* request_lsns, bits8* read_pages)
{
@@ -1465,7 +1469,6 @@ compare_with_localv(SMgrRelation reln, ForkNumber forkNum, BlockNumber blkno, vo
}
}
}
#endif
static void
@@ -1516,13 +1519,19 @@ neon_readv(SMgrRelation reln, ForkNumber forknum, BlockNumber blocknum,
blocknum, request_lsns, nblocks,
buffers, read_pages);
#ifdef DEBUG_COMPARE_LOCAL
compare_with_localv(reln, forknum, blocknum, buffers, nblocks, request_lsns, read_pages);
memset(read_pages, 0, sizeof(read_pages));
#else
if (prefetch_result == nblocks)
if (debug_compare_local >= DEBUG_COMPARE_LOCAL_PREFETCH)
{
compare_with_localv(reln, forknum, blocknum, buffers, nblocks, request_lsns, read_pages);
}
if (debug_compare_local <= DEBUG_COMPARE_LOCAL_PREFETCH && prefetch_result == nblocks)
{
return;
#endif
}
if (debug_compare_local > DEBUG_COMPARE_LOCAL_PREFETCH)
{
memset(read_pages, 0, sizeof(read_pages));
}
/* Try to read from local file cache */
lfc_result = lfc_readv_select(InfoFromSMgrRel(reln), forknum, blocknum, buffers,
@@ -1531,14 +1540,19 @@ neon_readv(SMgrRelation reln, ForkNumber forknum, BlockNumber blocknum,
if (lfc_result > 0)
MyNeonCounters->file_cache_hits_total += lfc_result;
#ifdef DEBUG_COMPARE_LOCAL
compare_with_localv(reln, forknum, blocknum, buffers, nblocks, request_lsns, read_pages);
memset(read_pages, 0, sizeof(read_pages));
#else
/* Read all blocks from LFC, so we're done */
if (prefetch_result + lfc_result == nblocks)
if (debug_compare_local >= DEBUG_COMPARE_LOCAL_LFC)
{
compare_with_localv(reln, forknum, blocknum, buffers, nblocks, request_lsns, read_pages);
}
if (debug_compare_local <= DEBUG_COMPARE_LOCAL_LFC && prefetch_result + lfc_result == nblocks)
{
/* Read all blocks from LFC, so we're done */
return;
#endif
}
if (debug_compare_local > DEBUG_COMPARE_LOCAL_LFC)
{
memset(read_pages, 0, sizeof(read_pages));
}
communicator_read_at_lsnv(InfoFromSMgrRel(reln), forknum, blocknum, request_lsns,
buffers, nblocks, read_pages);
@@ -1548,14 +1562,14 @@ neon_readv(SMgrRelation reln, ForkNumber forknum, BlockNumber blocknum,
*/
communicator_prefetch_pump_state();
#ifdef DEBUG_COMPARE_LOCAL
memset(read_pages, 0xFF, sizeof(read_pages));
compare_with_localv(reln, forknum, blocknum, buffers, nblocks, request_lsns, read_pages);
#endif
if (debug_compare_local)
{
memset(read_pages, 0xFF, sizeof(read_pages));
compare_with_localv(reln, forknum, blocknum, buffers, nblocks, request_lsns, read_pages);
}
}
#endif
#ifdef DEBUG_COMPARE_LOCAL
static char *
hexdump_page(char *page)
{
@@ -1574,7 +1588,6 @@ hexdump_page(char *page)
return result.data;
}
#endif
#if PG_MAJORVERSION_NUM < 17
/*
@@ -1596,12 +1609,8 @@ neon_write(SMgrRelation reln, ForkNumber forknum, BlockNumber blocknum, const vo
switch (reln->smgr_relpersistence)
{
case 0:
#ifndef DEBUG_COMPARE_LOCAL
/* This is a bit tricky. Check if the relation exists locally */
if (mdexists(reln, forknum))
#else
if (mdexists(reln, INIT_FORKNUM))
#endif
if (mdexists(reln, debug_compare_local ? INIT_FORKNUM : forknum))
{
/* It exists locally. Guess it's unlogged then. */
#if PG_MAJORVERSION_NUM >= 17
@@ -1656,14 +1665,17 @@ neon_write(SMgrRelation reln, ForkNumber forknum, BlockNumber blocknum, const vo
communicator_prefetch_pump_state();
#ifdef DEBUG_COMPARE_LOCAL
if (IS_LOCAL_REL(reln))
if (debug_compare_local)
{
if (IS_LOCAL_REL(reln))
{
#if PG_MAJORVERSION_NUM >= 17
mdwritev(reln, forknum, blocknum, &buffer, 1, skipFsync);
mdwritev(reln, forknum, blocknum, &buffer, 1, skipFsync);
#else
mdwrite(reln, forknum, blocknum, buffer, skipFsync);
mdwrite(reln, forknum, blocknum, buffer, skipFsync);
#endif
#endif
}
}
}
#endif
@@ -1677,12 +1689,8 @@ neon_writev(SMgrRelation reln, ForkNumber forknum, BlockNumber blkno,
switch (reln->smgr_relpersistence)
{
case 0:
#ifndef DEBUG_COMPARE_LOCAL
/* This is a bit tricky. Check if the relation exists locally */
if (mdexists(reln, forknum))
#else
if (mdexists(reln, INIT_FORKNUM))
#endif
if (mdexists(reln, debug_compare_local ? INIT_FORKNUM : forknum))
{
/* It exists locally. Guess it's unlogged then. */
mdwritev(reln, forknum, blkno, buffers, nblocks, skipFsync);
@@ -1720,10 +1728,11 @@ neon_writev(SMgrRelation reln, ForkNumber forknum, BlockNumber blkno,
communicator_prefetch_pump_state();
#ifdef DEBUG_COMPARE_LOCAL
if (IS_LOCAL_REL(reln))
mdwritev(reln, forknum, blkno, buffers, nblocks, skipFsync);
#endif
if (debug_compare_local)
{
if (IS_LOCAL_REL(reln))
mdwritev(reln, forknum, blkno, buffers, nblocks, skipFsync);
}
}
#endif
@@ -1862,10 +1871,11 @@ neon_truncate(SMgrRelation reln, ForkNumber forknum, BlockNumber old_blocks, Blo
*/
neon_set_lwlsn_relation(lsn, InfoFromSMgrRel(reln), forknum);
#ifdef DEBUG_COMPARE_LOCAL
if (IS_LOCAL_REL(reln))
mdtruncate(reln, forknum, old_blocks, nblocks);
#endif
if (debug_compare_local)
{
if (IS_LOCAL_REL(reln))
mdtruncate(reln, forknum, old_blocks, nblocks);
}
}
/*
@@ -1904,10 +1914,11 @@ neon_immedsync(SMgrRelation reln, ForkNumber forknum)
communicator_prefetch_pump_state();
#ifdef DEBUG_COMPARE_LOCAL
if (IS_LOCAL_REL(reln))
mdimmedsync(reln, forknum);
#endif
if (debug_compare_local)
{
if (IS_LOCAL_REL(reln))
mdimmedsync(reln, forknum);
}
}
#if PG_MAJORVERSION_NUM >= 17
@@ -1934,10 +1945,11 @@ neon_registersync(SMgrRelation reln, ForkNumber forknum)
neon_log(SmgrTrace, "[NEON_SMGR] registersync noop");
#ifdef DEBUG_COMPARE_LOCAL
if (IS_LOCAL_REL(reln))
mdimmedsync(reln, forknum);
#endif
if (debug_compare_local)
{
if (IS_LOCAL_REL(reln))
mdimmedsync(reln, forknum);
}
}
#endif
@@ -1978,10 +1990,11 @@ neon_start_unlogged_build(SMgrRelation reln)
case RELPERSISTENCE_UNLOGGED:
unlogged_build_rel_info = InfoFromSMgrRel(reln);
unlogged_build_phase = UNLOGGED_BUILD_NOT_PERMANENT;
#ifdef DEBUG_COMPARE_LOCAL
if (!IsParallelWorker())
mdcreate(reln, INIT_FORKNUM, true);
#endif
if (debug_compare_local)
{
if (!IsParallelWorker())
mdcreate(reln, INIT_FORKNUM, true);
}
return;
default:
@@ -2009,11 +2022,7 @@ neon_start_unlogged_build(SMgrRelation reln)
*/
if (!IsParallelWorker())
{
#ifndef DEBUG_COMPARE_LOCAL
mdcreate(reln, MAIN_FORKNUM, false);
#else
mdcreate(reln, INIT_FORKNUM, true);
#endif
mdcreate(reln, debug_compare_local ? INIT_FORKNUM : MAIN_FORKNUM, false);
}
}
@@ -2107,14 +2116,14 @@ neon_end_unlogged_build(SMgrRelation reln)
lfc_invalidate(InfoFromNInfoB(rinfob), forknum, nblocks);
mdclose(reln, forknum);
#ifndef DEBUG_COMPARE_LOCAL
/* use isRedo == true, so that we drop it immediately */
mdunlink(rinfob, forknum, true);
#endif
if (!debug_compare_local)
{
/* use isRedo == true, so that we drop it immediately */
mdunlink(rinfob, forknum, true);
}
}
#ifdef DEBUG_COMPARE_LOCAL
mdunlink(rinfob, INIT_FORKNUM, true);
#endif
if (debug_compare_local)
mdunlink(rinfob, INIT_FORKNUM, true);
}
NRelFileInfoInvalidate(unlogged_build_rel_info);
unlogged_build_phase = UNLOGGED_BUILD_NOT_IN_PROGRESS;

View File

@@ -5,8 +5,9 @@ edition = "2024"
license.workspace = true
[features]
default = []
default = ["rest_broker"]
testing = ["dep:tokio-postgres"]
rest_broker = ["subzero-core", "jsonpath_lib", "ouroboros"]
[dependencies]
ahash.workspace = true
@@ -103,6 +104,9 @@ uuid.workspace = true
x509-cert.workspace = true
redis.workspace = true
zerocopy.workspace = true
subzero-core = { git = "https://github.com/neondatabase-labs/subzero", rev = "0b3d3278f5f9ac9311a7280cb1676de80e021f06", features = ["postgresql"], optional = true }
jsonpath_lib = { version = "0.3.0", optional = true }
ouroboros = { version = "0.18", optional = true }
# jwt stuff
jose-jwa = "0.1.2"

View File

@@ -138,3 +138,69 @@ Now from client you can start a new session:
```sh
PGSSLROOTCERT=./server.crt psql "postgresql://proxy:password@endpoint.local.neon.build:4432/postgres?sslmode=verify-full"
```
## auth broker setup:
Create a postgres instance:
```sh
docker run \
--detach \
--name proxy-postgres \
--env POSTGRES_HOST_AUTH_METHOD=trust \
--env POSTGRES_USER=authenticated \
--env POSTGRES_DB=database \
--publish 5432:5432 \
postgres:17-bookworm
```
Create a configuration file called `local_proxy.json` in the root of the repo (used also by the auth broker to validate JWTs)
```sh
{
"jwks": [
{
"id": "1",
"role_names": ["authenticator", "authenticated", "anon"],
"jwks_url": "https://climbing-minnow-11.clerk.accounts.dev/.well-known/jwks.json",
"provider_name": "foo",
"jwt_audience": null
}
]
}
```
Start the local proxy:
```sh
cargo run --bin local_proxy --features testing -- \
--disable-pg-session-jwt \
--http 0.0.0.0:7432
```
Start the auth broker:
```sh
LOGFMT=text OTEL_SDK_DISABLED=true cargo run --bin proxy --features testing -- \
-c server.crt -k server.key \
--is-auth-broker true \
--is-rest-broker true \
--wss 0.0.0.0:8080 \
--http 0.0.0.0:7002 \
--auth-backend local
```
Create a JWT in your auth provider (e.g. Clerk) and set it in the `NEON_JWT` environment variable.
```sh
export NEON_JWT="..."
```
Run a query against the auth broker:
```sh
curl -k "https://foo.local.neon.build:8080/sql" \
-H "Authorization: Bearer $NEON_JWT" \
-H "neon-connection-string: postgresql://authenticator@foo.local.neon.build/database" \
-d '{"query":"select 1","params":[]}'
```
Make a rest request against the auth broker (rest broker):
```sh
curl -k "https://foo.local.neon.build:8080/database/rest/v1/items?select=id,name&id=eq.1" \
-H "Authorization: Bearer $NEON_JWT"
```

View File

@@ -164,21 +164,20 @@ async fn authenticate(
})?
.map_err(ConsoleRedirectError::from)?;
if auth_config.ip_allowlist_check_enabled {
if let Some(allowed_ips) = &db_info.allowed_ips {
if !auth::check_peer_addr_is_in_list(&ctx.peer_addr(), allowed_ips) {
return Err(auth::AuthError::ip_address_not_allowed(ctx.peer_addr()));
}
}
if auth_config.ip_allowlist_check_enabled
&& let Some(allowed_ips) = &db_info.allowed_ips
&& !auth::check_peer_addr_is_in_list(&ctx.peer_addr(), allowed_ips)
{
return Err(auth::AuthError::ip_address_not_allowed(ctx.peer_addr()));
}
// Check if the access over the public internet is allowed, otherwise block. Note that
// the console redirect is not behind the VPC service endpoint, so we don't need to check
// the VPC endpoint ID.
if let Some(public_access_allowed) = db_info.public_access_allowed {
if !public_access_allowed {
return Err(auth::AuthError::NetworkNotAllowed);
}
if let Some(public_access_allowed) = db_info.public_access_allowed
&& !public_access_allowed
{
return Err(auth::AuthError::NetworkNotAllowed);
}
client.write_message(BeMessage::NoticeResponse("Connecting to database."));

View File

@@ -399,36 +399,36 @@ impl JwkCacheEntryLock {
tracing::debug!(?payload, "JWT signature valid with claims");
if let Some(aud) = expected_audience {
if payload.audience.0.iter().all(|s| s != aud) {
return Err(JwtError::InvalidClaims(
JwtClaimsError::InvalidJwtTokenAudience,
));
}
if let Some(aud) = expected_audience
&& payload.audience.0.iter().all(|s| s != aud)
{
return Err(JwtError::InvalidClaims(
JwtClaimsError::InvalidJwtTokenAudience,
));
}
let now = SystemTime::now();
if let Some(exp) = payload.expiration {
if now >= exp + CLOCK_SKEW_LEEWAY {
return Err(JwtError::InvalidClaims(JwtClaimsError::JwtTokenHasExpired(
exp.duration_since(SystemTime::UNIX_EPOCH)
.unwrap_or_default()
.as_secs(),
)));
}
if let Some(exp) = payload.expiration
&& now >= exp + CLOCK_SKEW_LEEWAY
{
return Err(JwtError::InvalidClaims(JwtClaimsError::JwtTokenHasExpired(
exp.duration_since(SystemTime::UNIX_EPOCH)
.unwrap_or_default()
.as_secs(),
)));
}
if let Some(nbf) = payload.not_before {
if nbf >= now + CLOCK_SKEW_LEEWAY {
return Err(JwtError::InvalidClaims(
JwtClaimsError::JwtTokenNotYetReadyToUse(
nbf.duration_since(SystemTime::UNIX_EPOCH)
.unwrap_or_default()
.as_secs(),
),
));
}
if let Some(nbf) = payload.not_before
&& nbf >= now + CLOCK_SKEW_LEEWAY
{
return Err(JwtError::InvalidClaims(
JwtClaimsError::JwtTokenNotYetReadyToUse(
nbf.duration_since(SystemTime::UNIX_EPOCH)
.unwrap_or_default()
.as_secs(),
),
));
}
Ok(ComputeCredentialKeys::JwtPayload(payloadb))

View File

@@ -171,7 +171,6 @@ impl ComputeUserInfo {
pub(crate) enum ComputeCredentialKeys {
AuthKeys(AuthKeys),
JwtPayload(Vec<u8>),
None,
}
impl TryFrom<ComputeUserInfoMaybeEndpoint> for ComputeUserInfo {
@@ -346,15 +345,13 @@ impl<'a> Backend<'a, ComputeUserInfoMaybeEndpoint> {
Err(e) => {
// The password could have been changed, so we invalidate the cache.
// We should only invalidate the cache if the TTL might have expired.
if e.is_password_failed() {
#[allow(irrefutable_let_patterns)]
if let ControlPlaneClient::ProxyV1(api) = &*api {
if let Some(ep) = &user_info.endpoint_id {
api.caches
.project_info
.maybe_invalidate_role_secret(ep, &user_info.user);
}
}
if e.is_password_failed()
&& let ControlPlaneClient::ProxyV1(api) = &*api
&& let Some(ep) = &user_info.endpoint_id
{
api.caches
.project_info
.maybe_invalidate_role_secret(ep, &user_info.user);
}
Err(e)

View File

@@ -1,43 +1,40 @@
use std::net::SocketAddr;
use std::pin::pin;
use std::str::FromStr;
use std::sync::Arc;
use std::time::Duration;
use anyhow::{Context, bail, ensure};
use anyhow::bail;
use arc_swap::ArcSwapOption;
use camino::{Utf8Path, Utf8PathBuf};
use camino::Utf8PathBuf;
use clap::Parser;
use compute_api::spec::LocalProxySpec;
use futures::future::Either;
use thiserror::Error;
use tokio::net::TcpListener;
use tokio::sync::Notify;
use tokio::task::JoinSet;
use tokio_util::sync::CancellationToken;
use tracing::{debug, error, info, warn};
use tracing::{debug, error, info};
use utils::sentry_init::init_sentry;
use utils::{pid_file, project_build_tag, project_git_version};
use crate::auth::backend::jwt::JwkCache;
use crate::auth::backend::local::{JWKS_ROLE_MAP, LocalBackend};
use crate::auth::backend::local::LocalBackend;
use crate::auth::{self};
use crate::cancellation::CancellationHandler;
#[cfg(feature = "rest_broker")]
use crate::config::RestConfig;
use crate::config::{
self, AuthenticationConfig, ComputeConfig, HttpConfig, ProxyConfig, RetryConfig,
refresh_config_loop,
};
use crate::control_plane::locks::ApiLocks;
use crate::control_plane::messages::{EndpointJwksResponse, JwksSettings};
use crate::ext::TaskExt;
use crate::http::health_server::AppMetrics;
use crate::intern::RoleNameInt;
use crate::metrics::{Metrics, ThreadPoolMetrics};
use crate::rate_limiter::{EndpointRateLimiter, LeakyBucketConfig, RateBucketInfo};
use crate::scram::threadpool::ThreadPool;
use crate::serverless::cancel_set::CancelSet;
use crate::serverless::{self, GlobalConnPoolOptions};
use crate::tls::client_config::compute_client_config_with_root_certs;
use crate::types::RoleName;
use crate::url::ApiUrl;
project_git_version!(GIT_VERSION);
@@ -82,6 +79,11 @@ struct LocalProxyCliArgs {
/// Path of the local proxy PID file
#[clap(long, default_value = "./local_proxy.pid")]
pid_path: Utf8PathBuf,
/// Disable pg_session_jwt extension installation
/// This is useful for testing the local proxy with vanilla postgres.
#[clap(long, default_value = "false")]
#[cfg(feature = "testing")]
disable_pg_session_jwt: bool,
}
#[derive(clap::Args, Clone, Copy, Debug)]
@@ -277,11 +279,18 @@ fn build_config(args: &LocalProxyCliArgs) -> anyhow::Result<&'static ProxyConfig
accept_jwts: true,
console_redirect_confirmation_timeout: Duration::ZERO,
},
#[cfg(feature = "rest_broker")]
rest_config: RestConfig {
is_rest_broker: false,
db_schema_cache: None,
},
proxy_protocol_v2: config::ProxyProtocolV2::Rejected,
handshake_timeout: Duration::from_secs(10),
wake_compute_retry_config: RetryConfig::parse(RetryConfig::WAKE_COMPUTE_DEFAULT_VALUES)?,
connect_compute_locks,
connect_to_compute: compute_config,
#[cfg(feature = "testing")]
disable_pg_session_jwt: args.disable_pg_session_jwt,
})))
}
@@ -293,132 +302,3 @@ fn build_auth_backend(args: &LocalProxyCliArgs) -> &'static auth::Backend<'stati
Box::leak(Box::new(auth_backend))
}
#[derive(Error, Debug)]
enum RefreshConfigError {
#[error(transparent)]
Read(#[from] std::io::Error),
#[error(transparent)]
Parse(#[from] serde_json::Error),
#[error(transparent)]
Validate(anyhow::Error),
#[error(transparent)]
Tls(anyhow::Error),
}
async fn refresh_config_loop(config: &ProxyConfig, path: Utf8PathBuf, rx: Arc<Notify>) {
let mut init = true;
loop {
rx.notified().await;
match refresh_config_inner(config, &path).await {
Ok(()) => {}
// don't log for file not found errors if this is the first time we are checking
// for computes that don't use local_proxy, this is not an error.
Err(RefreshConfigError::Read(e))
if init && e.kind() == std::io::ErrorKind::NotFound =>
{
debug!(error=?e, ?path, "could not read config file");
}
Err(RefreshConfigError::Tls(e)) => {
error!(error=?e, ?path, "could not read TLS certificates");
}
Err(e) => {
error!(error=?e, ?path, "could not read config file");
}
}
init = false;
}
}
async fn refresh_config_inner(
config: &ProxyConfig,
path: &Utf8Path,
) -> Result<(), RefreshConfigError> {
let bytes = tokio::fs::read(&path).await?;
let data: LocalProxySpec = serde_json::from_slice(&bytes)?;
let mut jwks_set = vec![];
fn parse_jwks_settings(jwks: compute_api::spec::JwksSettings) -> anyhow::Result<JwksSettings> {
let mut jwks_url = url::Url::from_str(&jwks.jwks_url).context("parsing JWKS url")?;
ensure!(
jwks_url.has_authority()
&& (jwks_url.scheme() == "http" || jwks_url.scheme() == "https"),
"Invalid JWKS url. Must be HTTP",
);
ensure!(
jwks_url.host().is_some_and(|h| h != url::Host::Domain("")),
"Invalid JWKS url. No domain listed",
);
// clear username, password and ports
jwks_url
.set_username("")
.expect("url can be a base and has a valid host and is not a file. should not error");
jwks_url
.set_password(None)
.expect("url can be a base and has a valid host and is not a file. should not error");
// local testing is hard if we need to have a specific restricted port
if cfg!(not(feature = "testing")) {
jwks_url.set_port(None).expect(
"url can be a base and has a valid host and is not a file. should not error",
);
}
// clear query params
jwks_url.set_fragment(None);
jwks_url.query_pairs_mut().clear().finish();
if jwks_url.scheme() != "https" {
// local testing is hard if we need to set up https support.
if cfg!(not(feature = "testing")) {
jwks_url
.set_scheme("https")
.expect("should not error to set the scheme to https if it was http");
} else {
warn!(scheme = jwks_url.scheme(), "JWKS url is not HTTPS");
}
}
Ok(JwksSettings {
id: jwks.id,
jwks_url,
_provider_name: jwks.provider_name,
jwt_audience: jwks.jwt_audience,
role_names: jwks
.role_names
.into_iter()
.map(RoleName::from)
.map(|s| RoleNameInt::from(&s))
.collect(),
})
}
for jwks in data.jwks.into_iter().flatten() {
jwks_set.push(parse_jwks_settings(jwks).map_err(RefreshConfigError::Validate)?);
}
info!("successfully loaded new config");
JWKS_ROLE_MAP.store(Some(Arc::new(EndpointJwksResponse { jwks: jwks_set })));
if let Some(tls_config) = data.tls {
let tls_config = tokio::task::spawn_blocking(move || {
crate::tls::server_config::configure_tls(
tls_config.key_path.as_ref(),
tls_config.cert_path.as_ref(),
None,
false,
)
})
.await
.propagate_task_panic()
.map_err(RefreshConfigError::Tls)?;
config.tls_config.store(Some(Arc::new(tls_config)));
}
Ok(())
}

View File

@@ -10,11 +10,15 @@ use std::time::Duration;
use anyhow::Context;
use anyhow::{bail, ensure};
use arc_swap::ArcSwapOption;
#[cfg(any(test, feature = "testing"))]
use camino::Utf8PathBuf;
use futures::future::Either;
use itertools::{Itertools, Position};
use rand::{Rng, thread_rng};
use remote_storage::RemoteStorageConfig;
use tokio::net::TcpListener;
#[cfg(any(test, feature = "testing"))]
use tokio::sync::Notify;
use tokio::task::JoinSet;
use tokio_util::sync::CancellationToken;
use tracing::{Instrument, error, info, warn};
@@ -22,9 +26,15 @@ use utils::sentry_init::init_sentry;
use utils::{project_build_tag, project_git_version};
use crate::auth::backend::jwt::JwkCache;
#[cfg(any(test, feature = "testing"))]
use crate::auth::backend::local::LocalBackend;
use crate::auth::backend::{ConsoleRedirectBackend, MaybeOwned};
use crate::batch::BatchQueue;
use crate::cancellation::{CancellationHandler, CancellationProcessor};
#[cfg(feature = "rest_broker")]
use crate::config::RestConfig;
#[cfg(any(test, feature = "testing"))]
use crate::config::refresh_config_loop;
use crate::config::{
self, AuthenticationConfig, CacheOptions, ComputeConfig, HttpConfig, ProjectInfoCacheOptions,
ProxyConfig, ProxyProtocolV2, remote_storage_from_toml,
@@ -39,6 +49,8 @@ use crate::redis::{elasticache, notifications};
use crate::scram::threadpool::ThreadPool;
use crate::serverless::GlobalConnPoolOptions;
use crate::serverless::cancel_set::CancelSet;
#[cfg(feature = "rest_broker")]
use crate::serverless::rest::DbSchemaCache;
use crate::tls::client_config::compute_client_config_with_root_certs;
#[cfg(any(test, feature = "testing"))]
use crate::url::ApiUrl;
@@ -60,6 +72,9 @@ enum AuthBackendType {
#[cfg(any(test, feature = "testing"))]
Postgres,
#[cfg(any(test, feature = "testing"))]
Local,
}
/// Neon proxy/router
@@ -74,6 +89,10 @@ struct ProxyCliArgs {
proxy: SocketAddr,
#[clap(value_enum, long, default_value_t = AuthBackendType::ConsoleRedirect)]
auth_backend: AuthBackendType,
/// Path of the local proxy config file (used for local-file auth backend)
#[clap(long, default_value = "./local_proxy.json")]
#[cfg(any(test, feature = "testing"))]
config_path: Utf8PathBuf,
/// listen for management callback connection on ip:port
#[clap(short, long, default_value = "127.0.0.1:7000")]
mgmt: SocketAddr,
@@ -226,6 +245,16 @@ struct ProxyCliArgs {
#[clap(flatten)]
pg_sni_router: PgSniRouterArgs,
/// if this is not local proxy, this toggles whether we accept Postgres REST requests
#[clap(long, default_value_t = false, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)]
#[cfg(feature = "rest_broker")]
is_rest_broker: bool,
/// cache for `db_schema_cache` introspection (use `size=0` to disable)
#[clap(long, default_value = "size=1000,ttl=1h")]
#[cfg(feature = "rest_broker")]
db_schema_cache: String,
}
#[derive(clap::Args, Clone, Copy, Debug)]
@@ -386,6 +415,8 @@ pub async fn run() -> anyhow::Result<()> {
64,
));
#[cfg(any(test, feature = "testing"))]
let refresh_config_notify = Arc::new(Notify::new());
// client facing tasks. these will exit on error or on cancellation
// cancellation returns Ok(())
let mut client_tasks = JoinSet::new();
@@ -412,6 +443,17 @@ pub async fn run() -> anyhow::Result<()> {
endpoint_rate_limiter.clone(),
));
}
// if auth backend is local, we need to load the config file
#[cfg(any(test, feature = "testing"))]
if let auth::Backend::Local(_) = &auth_backend {
refresh_config_notify.notify_one();
tokio::spawn(refresh_config_loop(
config,
args.config_path,
refresh_config_notify.clone(),
));
}
}
Either::Right(auth_backend) => {
if let Some(proxy_listener) = proxy_listener {
@@ -462,7 +504,13 @@ pub async fn run() -> anyhow::Result<()> {
// maintenance tasks. these never return unless there's an error
let mut maintenance_tasks = JoinSet::new();
maintenance_tasks.spawn(crate::signals::handle(cancellation_token.clone(), || {}));
maintenance_tasks.spawn(crate::signals::handle(cancellation_token.clone(), {
move || {
#[cfg(any(test, feature = "testing"))]
refresh_config_notify.notify_one();
}
}));
maintenance_tasks.spawn(http::health_server::task_main(
http_listener,
AppMetrics {
@@ -473,59 +521,67 @@ pub async fn run() -> anyhow::Result<()> {
));
maintenance_tasks.spawn(control_plane::mgmt::task_main(mgmt_listener));
// add a task to flush the db_schema cache every 10 minutes
#[cfg(feature = "rest_broker")]
if let Some(db_schema_cache) = &config.rest_config.db_schema_cache {
maintenance_tasks.spawn(async move {
loop {
tokio::time::sleep(Duration::from_secs(600)).await;
db_schema_cache.flush();
}
});
}
if let Some(metrics_config) = &config.metric_collection {
// TODO: Add gc regardles of the metric collection being enabled.
maintenance_tasks.spawn(usage_metrics::task_main(metrics_config));
}
#[cfg_attr(not(any(test, feature = "testing")), expect(irrefutable_let_patterns))]
if let Either::Left(auth::Backend::ControlPlane(api, ())) = &auth_backend {
if let crate::control_plane::client::ControlPlaneClient::ProxyV1(api) = &**api {
if let Some(client) = redis_client {
// project info cache and invalidation of that cache.
let cache = api.caches.project_info.clone();
maintenance_tasks.spawn(notifications::task_main(client.clone(), cache.clone()));
maintenance_tasks.spawn(async move { cache.clone().gc_worker().await });
if let Either::Left(auth::Backend::ControlPlane(api, ())) = &auth_backend
&& let crate::control_plane::client::ControlPlaneClient::ProxyV1(api) = &**api
&& let Some(client) = redis_client
{
// project info cache and invalidation of that cache.
let cache = api.caches.project_info.clone();
maintenance_tasks.spawn(notifications::task_main(client.clone(), cache.clone()));
maintenance_tasks.spawn(async move { cache.clone().gc_worker().await });
// Try to connect to Redis 3 times with 1 + (0..0.1) second interval.
// This prevents immediate exit and pod restart,
// which can cause hammering of the redis in case of connection issues.
// cancellation key management
let mut redis_kv_client = RedisKVClient::new(client.clone());
for attempt in (0..3).with_position() {
match redis_kv_client.try_connect().await {
Ok(()) => {
info!("Connected to Redis KV client");
cancellation_handler.init_tx(BatchQueue::new(CancellationProcessor {
client: redis_kv_client,
batch_size: args.cancellation_batch_size,
}));
// Try to connect to Redis 3 times with 1 + (0..0.1) second interval.
// This prevents immediate exit and pod restart,
// which can cause hammering of the redis in case of connection issues.
// cancellation key management
let mut redis_kv_client = RedisKVClient::new(client.clone());
for attempt in (0..3).with_position() {
match redis_kv_client.try_connect().await {
Ok(()) => {
info!("Connected to Redis KV client");
cancellation_handler.init_tx(BatchQueue::new(CancellationProcessor {
client: redis_kv_client,
batch_size: args.cancellation_batch_size,
}));
break;
}
Err(e) => {
error!("Failed to connect to Redis KV client: {e}");
if matches!(attempt, Position::Last(_)) {
bail!(
"Failed to connect to Redis KV client after {} attempts",
attempt.into_inner()
);
}
let jitter = thread_rng().gen_range(0..100);
tokio::time::sleep(Duration::from_millis(1000 + jitter)).await;
}
}
break;
}
Err(e) => {
error!("Failed to connect to Redis KV client: {e}");
if matches!(attempt, Position::Last(_)) {
bail!(
"Failed to connect to Redis KV client after {} attempts",
attempt.into_inner()
);
}
let jitter = thread_rng().gen_range(0..100);
tokio::time::sleep(Duration::from_millis(1000 + jitter)).await;
}
// listen for notifications of new projects/endpoints/branches
let cache = api.caches.endpoints_cache.clone();
let span = tracing::info_span!("endpoints_cache");
maintenance_tasks.spawn(
async move { cache.do_read(client, cancellation_token.clone()).await }
.instrument(span),
);
}
}
// listen for notifications of new projects/endpoints/branches
let cache = api.caches.endpoints_cache.clone();
let span = tracing::info_span!("endpoints_cache");
maintenance_tasks.spawn(
async move { cache.do_read(client, cancellation_token.clone()).await }.instrument(span),
);
}
let maintenance = loop {
@@ -643,6 +699,28 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
timeout: Duration::from_secs(2),
};
#[cfg(feature = "rest_broker")]
let rest_config = {
let db_schema_cache_config: CacheOptions = args.db_schema_cache.parse()?;
info!("Using DbSchemaCache with options={db_schema_cache_config:?}");
let db_schema_cache = if args.is_rest_broker {
Some(DbSchemaCache::new(
"db_schema_cache",
db_schema_cache_config.size,
db_schema_cache_config.ttl,
true,
))
} else {
None
};
RestConfig {
is_rest_broker: args.is_rest_broker,
db_schema_cache,
}
};
let config = ProxyConfig {
tls_config,
metric_collection,
@@ -653,6 +731,10 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
wake_compute_retry_config: config::RetryConfig::parse(&args.wake_compute_retry)?,
connect_compute_locks,
connect_to_compute: compute_config,
#[cfg(feature = "testing")]
disable_pg_session_jwt: false,
#[cfg(feature = "rest_broker")]
rest_config,
};
let config = Box::leak(Box::new(config));
@@ -806,6 +888,19 @@ fn build_auth_backend(
Ok(Either::Right(config))
}
#[cfg(any(test, feature = "testing"))]
AuthBackendType::Local => {
let postgres: SocketAddr = "127.0.0.1:7432".parse()?;
let compute_ctl: ApiUrl = "http://127.0.0.1:3081/".parse()?;
let auth_backend = crate::auth::Backend::Local(
crate::auth::backend::MaybeOwned::Owned(LocalBackend::new(postgres, compute_ctl)),
);
let config = Box::leak(Box::new(auth_backend));
Ok(Either::Left(config))
}
}
}

View File

@@ -204,6 +204,10 @@ impl<K: Hash + Eq + Clone, V: Clone> TimedLru<K, V> {
self.insert_raw_ttl(key, value, ttl, false);
}
pub(crate) fn insert(&self, key: K, value: V) {
self.insert_raw_ttl(key, value, self.ttl, self.update_ttl_on_retrieval);
}
pub(crate) fn insert_unit(&self, key: K, value: V) -> (Option<V>, Cached<&Self, ()>) {
let (_, old) = self.insert_raw(key.clone(), value);
@@ -214,6 +218,28 @@ impl<K: Hash + Eq + Clone, V: Clone> TimedLru<K, V> {
(old, cached)
}
pub(crate) fn flush(&self) {
let now = Instant::now();
let mut cache = self.cache.lock();
// Collect keys of expired entries first
let expired_keys: Vec<_> = cache
.iter()
.filter_map(|(key, entry)| {
if entry.expires_at <= now {
Some(key.clone())
} else {
None
}
})
.collect();
// Remove expired entries
for key in expired_keys {
cache.remove(&key);
}
}
}
impl<K: Hash + Eq, V: Clone> TimedLru<K, V> {

View File

@@ -165,7 +165,7 @@ impl AuthInfo {
ComputeCredentialKeys::AuthKeys(AuthKeys::ScramSha256(auth_keys)) => {
Some(Auth::Scram(Box::new(auth_keys)))
}
ComputeCredentialKeys::JwtPayload(_) | ComputeCredentialKeys::None => None,
ComputeCredentialKeys::JwtPayload(_) => None,
},
server_params: StartupMessageParams::default(),
skip_db_user: false,

View File

@@ -4,28 +4,43 @@ use std::time::Duration;
use anyhow::{Context, Ok, bail, ensure};
use arc_swap::ArcSwapOption;
use camino::{Utf8Path, Utf8PathBuf};
use clap::ValueEnum;
use compute_api::spec::LocalProxySpec;
use remote_storage::RemoteStorageConfig;
use thiserror::Error;
use tokio::sync::Notify;
use tracing::{debug, error, info, warn};
use crate::auth::backend::jwt::JwkCache;
use crate::auth::backend::local::JWKS_ROLE_MAP;
use crate::control_plane::locks::ApiLocks;
use crate::control_plane::messages::{EndpointJwksResponse, JwksSettings};
use crate::ext::TaskExt;
use crate::intern::RoleNameInt;
use crate::rate_limiter::{RateBucketInfo, RateLimitAlgorithm, RateLimiterConfig};
use crate::scram::threadpool::ThreadPool;
use crate::serverless::GlobalConnPoolOptions;
use crate::serverless::cancel_set::CancelSet;
#[cfg(feature = "rest_broker")]
use crate::serverless::rest::DbSchemaCache;
pub use crate::tls::server_config::{TlsConfig, configure_tls};
use crate::types::Host;
use crate::types::{Host, RoleName};
pub struct ProxyConfig {
pub tls_config: ArcSwapOption<TlsConfig>,
pub metric_collection: Option<MetricCollectionConfig>,
pub http_config: HttpConfig,
pub authentication_config: AuthenticationConfig,
#[cfg(feature = "rest_broker")]
pub rest_config: RestConfig,
pub proxy_protocol_v2: ProxyProtocolV2,
pub handshake_timeout: Duration,
pub wake_compute_retry_config: RetryConfig,
pub connect_compute_locks: ApiLocks<Host>,
pub connect_to_compute: ComputeConfig,
#[cfg(feature = "testing")]
pub disable_pg_session_jwt: bool,
}
pub struct ComputeConfig {
@@ -69,6 +84,12 @@ pub struct AuthenticationConfig {
pub console_redirect_confirmation_timeout: tokio::time::Duration,
}
#[cfg(feature = "rest_broker")]
pub struct RestConfig {
pub is_rest_broker: bool,
pub db_schema_cache: Option<DbSchemaCache>,
}
#[derive(Debug)]
pub struct EndpointCacheConfig {
/// Batch size to receive all endpoints on the startup.
@@ -409,6 +430,135 @@ impl FromStr for ConcurrencyLockOptions {
}
}
#[derive(Error, Debug)]
pub(crate) enum RefreshConfigError {
#[error(transparent)]
Read(#[from] std::io::Error),
#[error(transparent)]
Parse(#[from] serde_json::Error),
#[error(transparent)]
Validate(anyhow::Error),
#[error(transparent)]
Tls(anyhow::Error),
}
pub(crate) async fn refresh_config_loop(config: &ProxyConfig, path: Utf8PathBuf, rx: Arc<Notify>) {
let mut init = true;
loop {
rx.notified().await;
match refresh_config_inner(config, &path).await {
std::result::Result::Ok(()) => {}
// don't log for file not found errors if this is the first time we are checking
// for computes that don't use local_proxy, this is not an error.
Err(RefreshConfigError::Read(e))
if init && e.kind() == std::io::ErrorKind::NotFound =>
{
debug!(error=?e, ?path, "could not read config file");
}
Err(RefreshConfigError::Tls(e)) => {
error!(error=?e, ?path, "could not read TLS certificates");
}
Err(e) => {
error!(error=?e, ?path, "could not read config file");
}
}
init = false;
}
}
pub(crate) async fn refresh_config_inner(
config: &ProxyConfig,
path: &Utf8Path,
) -> Result<(), RefreshConfigError> {
let bytes = tokio::fs::read(&path).await?;
let data: LocalProxySpec = serde_json::from_slice(&bytes)?;
let mut jwks_set = vec![];
fn parse_jwks_settings(jwks: compute_api::spec::JwksSettings) -> anyhow::Result<JwksSettings> {
let mut jwks_url = url::Url::from_str(&jwks.jwks_url).context("parsing JWKS url")?;
ensure!(
jwks_url.has_authority()
&& (jwks_url.scheme() == "http" || jwks_url.scheme() == "https"),
"Invalid JWKS url. Must be HTTP",
);
ensure!(
jwks_url.host().is_some_and(|h| h != url::Host::Domain("")),
"Invalid JWKS url. No domain listed",
);
// clear username, password and ports
jwks_url
.set_username("")
.expect("url can be a base and has a valid host and is not a file. should not error");
jwks_url
.set_password(None)
.expect("url can be a base and has a valid host and is not a file. should not error");
// local testing is hard if we need to have a specific restricted port
if cfg!(not(feature = "testing")) {
jwks_url.set_port(None).expect(
"url can be a base and has a valid host and is not a file. should not error",
);
}
// clear query params
jwks_url.set_fragment(None);
jwks_url.query_pairs_mut().clear().finish();
if jwks_url.scheme() != "https" {
// local testing is hard if we need to set up https support.
if cfg!(not(feature = "testing")) {
jwks_url
.set_scheme("https")
.expect("should not error to set the scheme to https if it was http");
} else {
warn!(scheme = jwks_url.scheme(), "JWKS url is not HTTPS");
}
}
Ok(JwksSettings {
id: jwks.id,
jwks_url,
_provider_name: jwks.provider_name,
jwt_audience: jwks.jwt_audience,
role_names: jwks
.role_names
.into_iter()
.map(RoleName::from)
.map(|s| RoleNameInt::from(&s))
.collect(),
})
}
for jwks in data.jwks.into_iter().flatten() {
jwks_set.push(parse_jwks_settings(jwks).map_err(RefreshConfigError::Validate)?);
}
info!("successfully loaded new config");
JWKS_ROLE_MAP.store(Some(Arc::new(EndpointJwksResponse { jwks: jwks_set })));
if let Some(tls_config) = data.tls {
let tls_config = tokio::task::spawn_blocking(move || {
crate::tls::server_config::configure_tls(
tls_config.key_path.as_ref(),
tls_config.cert_path.as_ref(),
None,
false,
)
})
.await
.propagate_task_panic()
.map_err(RefreshConfigError::Tls)?;
config.tls_config.store(Some(Arc::new(tls_config)));
}
std::result::Result::Ok(())
}
#[cfg(test)]
mod tests {
use super::*;

View File

@@ -209,11 +209,9 @@ impl RequestContext {
if let Some(options_str) = options.get("options") {
// If not found directly, try to extract it from the options string
for option in options_str.split_whitespace() {
if option.starts_with("neon_query_id:") {
if let Some(value) = option.strip_prefix("neon_query_id:") {
this.set_testodrome_id(value.into());
break;
}
if let Some(value) = option.strip_prefix("neon_query_id:") {
this.set_testodrome_id(value.into());
break;
}
}
}

View File

@@ -250,10 +250,8 @@ impl NeonControlPlaneClient {
info!(duration = ?start.elapsed(), "received http response");
let body = parse_body::<WakeCompute>(response.status(), response.bytes().await?)?;
// Unfortunately, ownership won't let us use `Option::ok_or` here.
let (host, port) = match parse_host_port(&body.address) {
None => return Err(WakeComputeError::BadComputeAddress(body.address)),
Some(x) => x,
let Some((host, port)) = parse_host_port(&body.address) else {
return Err(WakeComputeError::BadComputeAddress(body.address));
};
let host_addr = IpAddr::from_str(host).ok();

View File

@@ -8,7 +8,7 @@ use std::time::Duration;
use clashmap::ClashMap;
use tokio::time::Instant;
use tracing::{debug, info};
use tracing::debug;
use super::{EndpointAccessControl, RoleAccessControl};
use crate::auth::backend::ComputeUserInfo;
@@ -234,7 +234,8 @@ impl<K: Hash + Eq + Clone> ApiLocks<K> {
// temporary lock a single shard and then clear any semaphores that aren't currently checked out
// race conditions: if strong_count == 1, there's no way that it can increase while the shard is locked
// therefore releasing it is safe from race conditions
info!(
debug!(
//FIXME: is anything depending on this being info?
name = self.name,
shard = i,
"performing epoch reclamation on api lock"

View File

@@ -271,18 +271,18 @@ where
});
// In case logging fails we generate a simpler JSON object.
if let Err(err) = res {
if let Ok(mut line) = serde_json::to_vec(&serde_json::json!( {
if let Err(err) = res
&& let Ok(mut line) = serde_json::to_vec(&serde_json::json!( {
"timestamp": now.to_rfc3339_opts(chrono::SecondsFormat::Micros, true),
"level": "ERROR",
"message": format_args!("cannot log event: {err:?}"),
"fields": {
"event": format_args!("{event:?}"),
},
})) {
line.push(b'\n');
self.writer.make_writer().write_all(&line).ok();
}
}))
{
line.push(b'\n');
self.writer.make_writer().write_all(&line).ok();
}
}
@@ -583,10 +583,11 @@ impl EventFormatter {
THREAD_ID.with(|tid| serializer.serialize_entry("thread_id", tid))?;
// TODO: tls cache? name could change
if let Some(thread_name) = std::thread::current().name() {
if !thread_name.is_empty() && thread_name != "tokio-runtime-worker" {
serializer.serialize_entry("thread_name", thread_name)?;
}
if let Some(thread_name) = std::thread::current().name()
&& !thread_name.is_empty()
&& thread_name != "tokio-runtime-worker"
{
serializer.serialize_entry("thread_name", thread_name)?;
}
if let Some(task_id) = tokio::task::try_id() {
@@ -596,10 +597,10 @@ impl EventFormatter {
serializer.serialize_entry("target", meta.target())?;
// Skip adding module if it's the same as target.
if let Some(module) = meta.module_path() {
if module != meta.target() {
serializer.serialize_entry("module", module)?;
}
if let Some(module) = meta.module_path()
&& module != meta.target()
{
serializer.serialize_entry("module", module)?;
}
if let Some(file) = meta.file() {

View File

@@ -236,13 +236,6 @@ pub enum Bool {
False,
}
#[derive(FixedCardinalityLabel, Copy, Clone)]
#[label(singleton = "outcome")]
pub enum Outcome {
Success,
Failed,
}
#[derive(FixedCardinalityLabel, Copy, Clone)]
#[label(singleton = "outcome")]
pub enum CacheOutcome {

View File

@@ -90,27 +90,27 @@ where
// TODO: 1 info log, with a enum label for close direction.
// Early termination checks from compute to client.
if let TransferState::Done(_) = compute_to_client {
if let TransferState::Running(buf) = &client_to_compute {
info!("Compute is done, terminate client");
// Initiate shutdown
client_to_compute = TransferState::ShuttingDown(buf.amt);
client_to_compute_result =
transfer_one_direction(cx, &mut client_to_compute, client, compute)
.map_err(ErrorSource::from_client)?;
}
if let TransferState::Done(_) = compute_to_client
&& let TransferState::Running(buf) = &client_to_compute
{
info!("Compute is done, terminate client");
// Initiate shutdown
client_to_compute = TransferState::ShuttingDown(buf.amt);
client_to_compute_result =
transfer_one_direction(cx, &mut client_to_compute, client, compute)
.map_err(ErrorSource::from_client)?;
}
// Early termination checks from client to compute.
if let TransferState::Done(_) = client_to_compute {
if let TransferState::Running(buf) = &compute_to_client {
info!("Client is done, terminate compute");
// Initiate shutdown
compute_to_client = TransferState::ShuttingDown(buf.amt);
compute_to_client_result =
transfer_one_direction(cx, &mut compute_to_client, compute, client)
.map_err(ErrorSource::from_compute)?;
}
if let TransferState::Done(_) = client_to_compute
&& let TransferState::Running(buf) = &compute_to_client
{
info!("Client is done, terminate compute");
// Initiate shutdown
compute_to_client = TransferState::ShuttingDown(buf.amt);
compute_to_client_result =
transfer_one_direction(cx, &mut compute_to_client, compute, client)
.map_err(ErrorSource::from_compute)?;
}
// It is not a problem if ready! returns early ... (comment remains the same)

View File

@@ -39,7 +39,11 @@ impl<K: Hash + Eq> LeakyBucketRateLimiter<K> {
let config = config.map_or(self.default_config, Into::into);
if self.access_count.fetch_add(1, Ordering::AcqRel) % 2048 == 0 {
if self
.access_count
.fetch_add(1, Ordering::AcqRel)
.is_multiple_of(2048)
{
self.do_gc(now);
}

View File

@@ -211,7 +211,11 @@ impl<K: Hash + Eq, R: Rng, S: BuildHasher + Clone> BucketRateLimiter<K, R, S> {
// worst case memory usage is about:
// = 2 * 2048 * 64 * (48B + 72B)
// = 30MB
if self.access_count.fetch_add(1, Ordering::AcqRel) % 2048 == 0 {
if self
.access_count
.fetch_add(1, Ordering::AcqRel)
.is_multiple_of(2048)
{
self.do_gc();
}

View File

@@ -1,79 +0,0 @@
use core::net::IpAddr;
use std::sync::Arc;
use tokio::sync::Mutex;
use uuid::Uuid;
use crate::pqproto::CancelKeyData;
pub trait CancellationPublisherMut: Send + Sync + 'static {
#[allow(async_fn_in_trait)]
async fn try_publish(
&mut self,
cancel_key_data: CancelKeyData,
session_id: Uuid,
peer_addr: IpAddr,
) -> anyhow::Result<()>;
}
pub trait CancellationPublisher: Send + Sync + 'static {
#[allow(async_fn_in_trait)]
async fn try_publish(
&self,
cancel_key_data: CancelKeyData,
session_id: Uuid,
peer_addr: IpAddr,
) -> anyhow::Result<()>;
}
impl CancellationPublisher for () {
async fn try_publish(
&self,
_cancel_key_data: CancelKeyData,
_session_id: Uuid,
_peer_addr: IpAddr,
) -> anyhow::Result<()> {
Ok(())
}
}
impl<P: CancellationPublisher> CancellationPublisherMut for P {
async fn try_publish(
&mut self,
cancel_key_data: CancelKeyData,
session_id: Uuid,
peer_addr: IpAddr,
) -> anyhow::Result<()> {
<P as CancellationPublisher>::try_publish(self, cancel_key_data, session_id, peer_addr)
.await
}
}
impl<P: CancellationPublisher> CancellationPublisher for Option<P> {
async fn try_publish(
&self,
cancel_key_data: CancelKeyData,
session_id: Uuid,
peer_addr: IpAddr,
) -> anyhow::Result<()> {
if let Some(p) = self {
p.try_publish(cancel_key_data, session_id, peer_addr).await
} else {
Ok(())
}
}
}
impl<P: CancellationPublisherMut> CancellationPublisher for Arc<Mutex<P>> {
async fn try_publish(
&self,
cancel_key_data: CancelKeyData,
session_id: Uuid,
peer_addr: IpAddr,
) -> anyhow::Result<()> {
self.lock()
.await
.try_publish(cancel_key_data, session_id, peer_addr)
.await
}
}

View File

@@ -1,11 +1,12 @@
use std::sync::{Arc, atomic::AtomicBool, atomic::Ordering};
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use std::time::Duration;
use futures::FutureExt;
use redis::aio::{ConnectionLike, MultiplexedConnection};
use redis::{ConnectionInfo, IntoConnectionInfo, RedisConnectionInfo, RedisResult};
use tokio::task::JoinHandle;
use tracing::{debug, error, info, warn};
use tokio::task::AbortHandle;
use tracing::{error, info, warn};
use super::elasticache::CredentialsProvider;
@@ -31,7 +32,7 @@ pub struct ConnectionWithCredentialsProvider {
credentials: Credentials,
// TODO: with more load on the connection, we should consider using a connection pool
con: Option<MultiplexedConnection>,
refresh_token_task: Option<JoinHandle<()>>,
refresh_token_task: Option<AbortHandle>,
mutex: tokio::sync::Mutex<()>,
credentials_refreshed: Arc<AtomicBool>,
}
@@ -121,16 +122,12 @@ impl ConnectionWithCredentialsProvider {
let credentials_provider = credentials_provider.clone();
let con2 = con.clone();
let credentials_refreshed = self.credentials_refreshed.clone();
let f = tokio::spawn(async move {
let result = Self::keep_connection(con2, credentials_provider).await;
if let Err(e) = result {
credentials_refreshed.store(false, Ordering::Release);
debug!("keep_connection failed: {e}");
} else {
credentials_refreshed.store(true, Ordering::Release);
}
});
self.refresh_token_task = Some(f);
let f = tokio::spawn(Self::keep_connection(
con2,
credentials_provider,
credentials_refreshed,
));
self.refresh_token_task = Some(f.abort_handle());
}
match Self::ping(&mut con).await {
Ok(()) => {
@@ -165,6 +162,7 @@ impl ConnectionWithCredentialsProvider {
async fn get_client(&self) -> anyhow::Result<redis::Client> {
let client = redis::Client::open(self.get_connection_info().await?)?;
self.credentials_refreshed.store(true, Ordering::Relaxed);
Ok(client)
}
@@ -180,16 +178,19 @@ impl ConnectionWithCredentialsProvider {
async fn keep_connection(
mut con: MultiplexedConnection,
credentials_provider: Arc<CredentialsProvider>,
) -> anyhow::Result<()> {
credentials_refreshed: Arc<AtomicBool>,
) -> ! {
loop {
// The connection lives for 12h, for the sanity check we refresh it every hour.
tokio::time::sleep(Duration::from_secs(60 * 60)).await;
match Self::refresh_token(&mut con, credentials_provider.clone()).await {
Ok(()) => {
info!("Token refreshed");
credentials_refreshed.store(true, Ordering::Relaxed);
}
Err(e) => {
error!("Error during token refresh: {e:?}");
credentials_refreshed.store(false, Ordering::Relaxed);
}
}
}
@@ -243,7 +244,7 @@ impl ConnectionLike for ConnectionWithCredentialsProvider {
&'a mut self,
cmd: &'a redis::Cmd,
) -> redis::RedisFuture<'a, redis::Value> {
(async move { self.send_packed_command(cmd).await }).boxed()
self.send_packed_command(cmd).boxed()
}
fn req_packed_commands<'a>(
@@ -252,10 +253,10 @@ impl ConnectionLike for ConnectionWithCredentialsProvider {
offset: usize,
count: usize,
) -> redis::RedisFuture<'a, Vec<redis::Value>> {
(async move { self.send_packed_commands(cmd, offset, count).await }).boxed()
self.send_packed_commands(cmd, offset, count).boxed()
}
fn get_db(&self) -> i64 {
0
self.con.as_ref().map_or(0, |c| c.get_db())
}
}

View File

@@ -1,4 +1,3 @@
pub mod cancellation_publisher;
pub mod connection_with_credentials_provider;
pub mod elasticache;
pub mod keys;

View File

@@ -54,9 +54,7 @@ impl<T: std::fmt::Display> ChannelBinding<T> {
"eSws".into()
}
Self::Required(mode) => {
use std::io::Write;
let mut cbind_input = vec![];
write!(&mut cbind_input, "p={mode},,",).unwrap();
let mut cbind_input = format!("p={mode},,",).into_bytes();
cbind_input.extend_from_slice(get_cbind_data(mode)?);
BASE64_STANDARD.encode(&cbind_input).into()
}

View File

@@ -107,7 +107,7 @@ pub(crate) async fn exchange(
secret: &ServerSecret,
password: &[u8],
) -> sasl::Result<sasl::Outcome<super::ScramKey>> {
let salt = BASE64_STANDARD.decode(&secret.salt_base64)?;
let salt = BASE64_STANDARD.decode(&*secret.salt_base64)?;
let client_key = derive_client_key(pool, endpoint, password, &salt, secret.iterations).await;
if secret.is_password_invalid(&client_key).into() {

View File

@@ -87,13 +87,20 @@ impl<'a> ClientFirstMessage<'a> {
salt_base64: &str,
iterations: u32,
) -> OwnedServerFirstMessage {
use std::fmt::Write;
let mut message = String::with_capacity(128);
message.push_str("r=");
let mut message = String::new();
write!(&mut message, "r={}", self.nonce).unwrap();
// write combined nonce
let combined_nonce_start = message.len();
message.push_str(self.nonce);
BASE64_STANDARD.encode_string(nonce, &mut message);
let combined_nonce = 2..message.len();
write!(&mut message, ",s={salt_base64},i={iterations}").unwrap();
let combined_nonce = combined_nonce_start..message.len();
// write salt and iterations
message.push_str(",s=");
message.push_str(salt_base64);
message.push_str(",i=");
message.push_str(itoa::Buffer::new().format(iterations));
// This design guarantees that it's impossible to create a
// server-first-message without receiving a client-first-message

View File

@@ -14,7 +14,7 @@ pub(crate) struct ServerSecret {
/// Number of iterations for `PBKDF2` function.
pub(crate) iterations: u32,
/// Salt used to hash user's password.
pub(crate) salt_base64: String,
pub(crate) salt_base64: Box<str>,
/// Hashed `ClientKey`.
pub(crate) stored_key: ScramKey,
/// Used by client to verify server's signature.
@@ -35,7 +35,7 @@ impl ServerSecret {
let secret = ServerSecret {
iterations: iterations.parse().ok()?,
salt_base64: salt.to_owned(),
salt_base64: salt.into(),
stored_key: base64_decode_array(stored_key)?.into(),
server_key: base64_decode_array(server_key)?.into(),
doomed: false,
@@ -58,7 +58,7 @@ impl ServerSecret {
// iteration count 1 for our generated passwords going forward.
// PG16 users can set iteration count=1 already today.
iterations: 1,
salt_base64: BASE64_STANDARD.encode(nonce),
salt_base64: BASE64_STANDARD.encode(nonce).into_boxed_str(),
stored_key: ScramKey::default(),
server_key: ScramKey::default(),
doomed: true,
@@ -88,7 +88,7 @@ mod tests {
let parsed = ServerSecret::parse(&secret).unwrap();
assert_eq!(parsed.iterations, iterations);
assert_eq!(parsed.salt_base64, salt);
assert_eq!(&*parsed.salt_base64, salt);
assert_eq!(BASE64_STANDARD.encode(parsed.stored_key), stored_key);
assert_eq!(BASE64_STANDARD.encode(parsed.server_key), server_key);

View File

@@ -137,7 +137,7 @@ impl Future for JobSpec {
let state = state.as_mut().expect("should be set on thread startup");
state.tick = state.tick.wrapping_add(1);
if state.tick % SKETCH_RESET_INTERVAL == 0 {
if state.tick.is_multiple_of(SKETCH_RESET_INTERVAL) {
state.countmin.reset();
}

View File

@@ -115,7 +115,8 @@ impl PoolingBackend {
match &self.auth_backend {
crate::auth::Backend::ControlPlane(console, ()) => {
self.config
let keys = self
.config
.authentication_config
.jwks_cache
.check_jwt(
@@ -129,7 +130,7 @@ impl PoolingBackend {
Ok(ComputeCredentials {
info: user_info.clone(),
keys: crate::auth::backend::ComputeCredentialKeys::None,
keys,
})
}
crate::auth::Backend::Local(_) => {
@@ -256,6 +257,7 @@ impl PoolingBackend {
&self,
ctx: &RequestContext,
conn_info: ConnInfo,
disable_pg_session_jwt: bool,
) -> Result<Client<postgres_client::Client>, HttpConnError> {
if let Some(client) = self.local_pool.get(ctx, &conn_info)? {
return Ok(client);
@@ -277,7 +279,7 @@ impl PoolingBackend {
.expect("semaphore should never be closed");
// check again for race
if !self.local_pool.initialized(&conn_info) {
if !self.local_pool.initialized(&conn_info) && !disable_pg_session_jwt {
local_backend
.compute_ctl
.install_extension(&ExtensionInstallRequest {
@@ -313,14 +315,16 @@ impl PoolingBackend {
.to_postgres_client_config();
config
.user(&conn_info.user_info.user)
.dbname(&conn_info.dbname)
.set_param(
.dbname(&conn_info.dbname);
if !disable_pg_session_jwt {
config.set_param(
"options",
&format!(
"-c pg_session_jwt.jwk={}",
serde_json::to_string(&jwk).expect("serializing jwk to json should not fail")
),
);
}
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
let (client, connection) = config.connect(&postgres_client::NoTls).await?;
@@ -345,7 +349,9 @@ impl PoolingBackend {
debug!("setting up backend session state");
// initiates the auth session
if let Err(e) = client.batch_execute("select auth.init();").await {
if !disable_pg_session_jwt
&& let Err(e) = client.batch_execute("select auth.init();").await
{
discard.discard();
return Err(e.into());
}

View File

@@ -148,11 +148,10 @@ pub(crate) fn poll_client<C: ClientInnerExt>(
}
// remove from connection pool
if let Some(pool) = pool.clone().upgrade() {
if pool.write().remove_client(db_user.clone(), conn_id) {
if let Some(pool) = pool.clone().upgrade()
&& pool.write().remove_client(db_user.clone(), conn_id) {
info!("closed connection removed");
}
}
Poll::Ready(())
}).await;

View File

@@ -1,5 +1,93 @@
use http::StatusCode;
use http::header::HeaderName;
use crate::auth::ComputeUserInfoParseError;
use crate::error::{ErrorKind, ReportableError, UserFacingError};
use crate::http::ReadBodyError;
pub trait HttpCodeError {
fn get_http_status_code(&self) -> StatusCode;
}
#[derive(Debug, thiserror::Error)]
pub(crate) enum ConnInfoError {
#[error("invalid header: {0}")]
InvalidHeader(&'static HeaderName),
#[error("invalid connection string: {0}")]
UrlParseError(#[from] url::ParseError),
#[error("incorrect scheme")]
IncorrectScheme,
#[error("missing database name")]
MissingDbName,
#[error("invalid database name")]
InvalidDbName,
#[error("missing username")]
MissingUsername,
#[error("invalid username: {0}")]
InvalidUsername(#[from] std::string::FromUtf8Error),
#[error("missing authentication credentials: {0}")]
MissingCredentials(Credentials),
#[error("missing hostname")]
MissingHostname,
#[error("invalid hostname: {0}")]
InvalidEndpoint(#[from] ComputeUserInfoParseError),
}
#[derive(Debug, thiserror::Error)]
pub(crate) enum Credentials {
#[error("required password")]
Password,
#[error("required authorization bearer token in JWT format")]
BearerJwt,
}
impl ReportableError for ConnInfoError {
fn get_error_kind(&self) -> ErrorKind {
ErrorKind::User
}
}
impl UserFacingError for ConnInfoError {
fn to_string_client(&self) -> String {
self.to_string()
}
}
#[derive(Debug, thiserror::Error)]
pub(crate) enum ReadPayloadError {
#[error("could not read the HTTP request body: {0}")]
Read(#[from] hyper::Error),
#[error("request is too large (max is {limit} bytes)")]
BodyTooLarge { limit: usize },
#[error("could not parse the HTTP request body: {0}")]
Parse(#[from] serde_json::Error),
}
impl From<ReadBodyError<hyper::Error>> for ReadPayloadError {
fn from(value: ReadBodyError<hyper::Error>) -> Self {
match value {
ReadBodyError::BodyTooLarge { limit } => Self::BodyTooLarge { limit },
ReadBodyError::Read(e) => Self::Read(e),
}
}
}
impl ReportableError for ReadPayloadError {
fn get_error_kind(&self) -> ErrorKind {
match self {
ReadPayloadError::Read(_) => ErrorKind::ClientDisconnect,
ReadPayloadError::BodyTooLarge { .. } => ErrorKind::User,
ReadPayloadError::Parse(_) => ErrorKind::User,
}
}
}
impl HttpCodeError for ReadPayloadError {
fn get_http_status_code(&self) -> StatusCode {
match self {
ReadPayloadError::Read(_) => StatusCode::BAD_REQUEST,
ReadPayloadError::BodyTooLarge { .. } => StatusCode::PAYLOAD_TOO_LARGE,
ReadPayloadError::Parse(_) => StatusCode::BAD_REQUEST,
}
}
}

View File

@@ -2,6 +2,8 @@ use std::collections::VecDeque;
use std::sync::atomic::{self, AtomicUsize};
use std::sync::{Arc, Weak};
use bytes::Bytes;
use http_body_util::combinators::BoxBody;
use hyper::client::conn::http2;
use hyper_util::rt::{TokioExecutor, TokioIo};
use parking_lot::RwLock;
@@ -21,8 +23,9 @@ use crate::protocol2::ConnectionInfoExtra;
use crate::types::EndpointCacheKey;
use crate::usage_metrics::{Ids, MetricCounter, USAGE_METRICS};
pub(crate) type Send = http2::SendRequest<hyper::body::Incoming>;
pub(crate) type Connect = http2::Connection<TokioIo<AsyncRW>, hyper::body::Incoming, TokioExecutor>;
pub(crate) type Send = http2::SendRequest<BoxBody<Bytes, hyper::Error>>;
pub(crate) type Connect =
http2::Connection<TokioIo<AsyncRW>, BoxBody<Bytes, hyper::Error>, TokioExecutor>;
#[derive(Clone)]
pub(crate) struct ClientDataHttp();
@@ -237,10 +240,10 @@ pub(crate) fn poll_http2_client(
}
// remove from connection pool
if let Some(pool) = pool.clone().upgrade() {
if pool.write().remove_conn(conn_id) {
info!("closed connection removed");
}
if let Some(pool) = pool.clone().upgrade()
&& pool.write().remove_conn(conn_id)
{
info!("closed connection removed");
}
}
.instrument(span),

View File

@@ -3,11 +3,42 @@
use anyhow::Context;
use bytes::Bytes;
use http::{Response, StatusCode};
use http::header::AUTHORIZATION;
use http::{HeaderMap, HeaderName, HeaderValue, Response, StatusCode};
use http_body_util::combinators::BoxBody;
use http_body_util::{BodyExt, Full};
use http_utils::error::ApiError;
use serde::Serialize;
use url::Url;
use uuid::Uuid;
use super::conn_pool::{AuthData, ConnInfoWithAuth};
use super::conn_pool_lib::ConnInfo;
use super::error::{ConnInfoError, Credentials};
use crate::auth::backend::ComputeUserInfo;
use crate::config::AuthenticationConfig;
use crate::context::RequestContext;
use crate::metrics::{Metrics, SniGroup, SniKind};
use crate::pqproto::StartupMessageParams;
use crate::proxy::NeonOptions;
use crate::types::{DbName, EndpointId, RoleName};
// Common header names used across serverless modules
pub(super) static NEON_REQUEST_ID: HeaderName = HeaderName::from_static("neon-request-id");
pub(super) static CONN_STRING: HeaderName = HeaderName::from_static("neon-connection-string");
pub(super) static RAW_TEXT_OUTPUT: HeaderName = HeaderName::from_static("neon-raw-text-output");
pub(super) static ARRAY_MODE: HeaderName = HeaderName::from_static("neon-array-mode");
pub(super) static ALLOW_POOL: HeaderName = HeaderName::from_static("neon-pool-opt-in");
pub(super) static TXN_ISOLATION_LEVEL: HeaderName =
HeaderName::from_static("neon-batch-isolation-level");
pub(super) static TXN_READ_ONLY: HeaderName = HeaderName::from_static("neon-batch-read-only");
pub(super) static TXN_DEFERRABLE: HeaderName = HeaderName::from_static("neon-batch-deferrable");
pub(crate) fn uuid_to_header_value(id: Uuid) -> HeaderValue {
let mut uuid = [0; uuid::fmt::Hyphenated::LENGTH];
HeaderValue::from_str(id.as_hyphenated().encode_lower(&mut uuid[..]))
.expect("uuid hyphenated format should be all valid header characters")
}
/// Like [`ApiError::into_response`]
pub(crate) fn api_error_into_response(this: ApiError) -> Response<BoxBody<Bytes, hyper::Error>> {
@@ -107,3 +138,136 @@ pub(crate) fn json_response<T: Serialize>(
.map_err(|e| ApiError::InternalServerError(e.into()))?;
Ok(response)
}
pub(crate) fn get_conn_info(
config: &'static AuthenticationConfig,
ctx: &RequestContext,
connection_string: Option<&str>,
headers: &HeaderMap,
) -> Result<ConnInfoWithAuth, ConnInfoError> {
let connection_url = match connection_string {
Some(connection_string) => Url::parse(connection_string)?,
None => {
let connection_string = headers
.get(&CONN_STRING)
.ok_or(ConnInfoError::InvalidHeader(&CONN_STRING))?
.to_str()
.map_err(|_| ConnInfoError::InvalidHeader(&CONN_STRING))?;
Url::parse(connection_string)?
}
};
let protocol = connection_url.scheme();
if protocol != "postgres" && protocol != "postgresql" {
return Err(ConnInfoError::IncorrectScheme);
}
let mut url_path = connection_url
.path_segments()
.ok_or(ConnInfoError::MissingDbName)?;
let dbname: DbName =
urlencoding::decode(url_path.next().ok_or(ConnInfoError::InvalidDbName)?)?.into();
ctx.set_dbname(dbname.clone());
let username = RoleName::from(urlencoding::decode(connection_url.username())?);
if username.is_empty() {
return Err(ConnInfoError::MissingUsername);
}
ctx.set_user(username.clone());
// TODO: make sure this is right in the context of rest broker
let auth = if let Some(auth) = headers.get(&AUTHORIZATION) {
if !config.accept_jwts {
return Err(ConnInfoError::MissingCredentials(Credentials::Password));
}
let auth = auth
.to_str()
.map_err(|_| ConnInfoError::InvalidHeader(&AUTHORIZATION))?;
AuthData::Jwt(
auth.strip_prefix("Bearer ")
.ok_or(ConnInfoError::MissingCredentials(Credentials::BearerJwt))?
.into(),
)
} else if let Some(pass) = connection_url.password() {
// wrong credentials provided
if config.accept_jwts {
return Err(ConnInfoError::MissingCredentials(Credentials::BearerJwt));
}
AuthData::Password(match urlencoding::decode_binary(pass.as_bytes()) {
std::borrow::Cow::Borrowed(b) => b.into(),
std::borrow::Cow::Owned(b) => b.into(),
})
} else if config.accept_jwts {
return Err(ConnInfoError::MissingCredentials(Credentials::BearerJwt));
} else {
return Err(ConnInfoError::MissingCredentials(Credentials::Password));
};
let endpoint: EndpointId = match connection_url.host() {
Some(url::Host::Domain(hostname)) => hostname
.split_once('.')
.map_or(hostname, |(prefix, _)| prefix)
.into(),
Some(url::Host::Ipv4(_) | url::Host::Ipv6(_)) | None => {
return Err(ConnInfoError::MissingHostname);
}
};
ctx.set_endpoint_id(endpoint.clone());
let pairs = connection_url.query_pairs();
let mut options = Option::None;
let mut params = StartupMessageParams::default();
params.insert("user", &username);
params.insert("database", &dbname);
for (key, value) in pairs {
params.insert(&key, &value);
if key == "options" {
options = Some(NeonOptions::parse_options_raw(&value));
}
}
// check the URL that was used, for metrics
{
let host_endpoint = headers
// get the host header
.get("host")
// extract the domain
.and_then(|h| {
let (host, _port) = h.to_str().ok()?.split_once(':')?;
Some(host)
})
// get the endpoint prefix
.map(|h| h.split_once('.').map_or(h, |(prefix, _)| prefix));
let kind = if host_endpoint == Some(&*endpoint) {
SniKind::Sni
} else {
SniKind::NoSni
};
let protocol = ctx.protocol();
Metrics::get()
.proxy
.accepted_connections_by_sni
.inc(SniGroup { protocol, kind });
}
ctx.set_user_agent(
headers
.get(hyper::header::USER_AGENT)
.and_then(|h| h.to_str().ok())
.map(Into::into),
);
let user_info = ComputeUserInfo {
endpoint,
user: username,
options: options.unwrap_or_default(),
};
let conn_info = ConnInfo { user_info, dbname };
Ok(ConnInfoWithAuth { conn_info, auth })
}

View File

@@ -249,11 +249,10 @@ pub(crate) fn poll_client<C: ClientInnerExt>(
}
// remove from connection pool
if let Some(pool) = pool.clone().upgrade() {
if pool.global_pool.write().remove_client(db_user.clone(), conn_id) {
if let Some(pool) = pool.clone().upgrade()
&& pool.global_pool.write().remove_client(db_user.clone(), conn_id) {
info!("closed connection removed");
}
}
Poll::Ready(())
}).await;

View File

@@ -11,6 +11,8 @@ mod http_conn_pool;
mod http_util;
mod json;
mod local_conn_pool;
#[cfg(feature = "rest_broker")]
pub mod rest;
mod sql_over_http;
mod websocket;
@@ -29,13 +31,13 @@ use futures::future::{Either, select};
use http::{Method, Response, StatusCode};
use http_body_util::combinators::BoxBody;
use http_body_util::{BodyExt, Empty};
use http_util::{NEON_REQUEST_ID, uuid_to_header_value};
use http_utils::error::ApiError;
use hyper::body::Incoming;
use hyper_util::rt::TokioExecutor;
use hyper_util::server::conn::auto::Builder;
use rand::SeedableRng;
use rand::rngs::StdRng;
use sql_over_http::{NEON_REQUEST_ID, uuid_to_header_value};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::{TcpListener, TcpStream};
use tokio::time::timeout;
@@ -487,6 +489,37 @@ async fn request_handler(
.body(Empty::new().map_err(|x| match x {}).boxed())
.map_err(|e| ApiError::InternalServerError(e.into()))
} else {
json_response(StatusCode::BAD_REQUEST, "query is not supported")
#[cfg(feature = "rest_broker")]
{
if config.rest_config.is_rest_broker && {
let path_parts: Vec<&str> = request.uri().path().split('/').collect();
path_parts.len() >= 3 && path_parts[2].starts_with("rest")
} {
let ctx =
RequestContext::new(session_id, conn_info, crate::metrics::Protocol::Http);
let span = ctx.span();
let testodrome_id = request
.headers()
.get("X-Neon-Query-ID")
.and_then(|value| value.to_str().ok())
.map(|s| s.to_string());
if let Some(query_id) = testodrome_id {
info!(parent: &ctx.span(), "testodrome query ID: {query_id}");
ctx.set_testodrome_id(query_id.into());
}
rest::handle(config, ctx, request, backend, http_cancellation_token)
.instrument(span)
.await
} else {
json_response(StatusCode::BAD_REQUEST, "query is not supported")
}
}
#[cfg(not(feature = "rest_broker"))]
{
json_response(StatusCode::BAD_REQUEST, "query is not supported")
}
}
}

1192
proxy/src/serverless/rest.rs Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -11,7 +11,7 @@ use http_body_util::{BodyExt, Full};
use http_utils::error::ApiError;
use hyper::body::Incoming;
use hyper::http::{HeaderName, HeaderValue};
use hyper::{HeaderMap, Request, Response, StatusCode, header};
use hyper::{Request, Response, StatusCode, header};
use indexmap::IndexMap;
use postgres_client::error::{DbError, ErrorPosition, SqlState};
use postgres_client::{
@@ -24,26 +24,23 @@ use tokio::time::{self, Instant};
use tokio_util::sync::CancellationToken;
use tracing::{Level, debug, error, info};
use typed_json::json;
use url::Url;
use uuid::Uuid;
use super::backend::{LocalProxyConnError, PoolingBackend};
use super::conn_pool::{AuthData, ConnInfoWithAuth};
use super::conn_pool::AuthData;
use super::conn_pool_lib::{self, ConnInfo};
use super::error::HttpCodeError;
use super::http_util::json_response;
use super::error::{ConnInfoError, HttpCodeError, ReadPayloadError};
use super::http_util::{
ALLOW_POOL, ARRAY_MODE, CONN_STRING, NEON_REQUEST_ID, RAW_TEXT_OUTPUT, TXN_DEFERRABLE,
TXN_ISOLATION_LEVEL, TXN_READ_ONLY, get_conn_info, json_response, uuid_to_header_value,
};
use super::json::{JsonConversionError, json_to_pg_text, pg_text_row_to_json};
use crate::auth::ComputeUserInfoParseError;
use crate::auth::backend::{ComputeCredentialKeys, ComputeUserInfo};
use crate::config::{AuthenticationConfig, HttpConfig, ProxyConfig};
use crate::auth::backend::ComputeCredentialKeys;
use crate::config::{HttpConfig, ProxyConfig};
use crate::context::RequestContext;
use crate::error::{ErrorKind, ReportableError, UserFacingError};
use crate::http::{ReadBodyError, read_body_with_limit};
use crate::metrics::{HttpDirection, Metrics, SniGroup, SniKind};
use crate::pqproto::StartupMessageParams;
use crate::proxy::NeonOptions;
use crate::http::read_body_with_limit;
use crate::metrics::{HttpDirection, Metrics};
use crate::serverless::backend::HttpConnError;
use crate::types::{DbName, EndpointId, RoleName};
use crate::usage_metrics::{MetricCounter, MetricCounterRecorder};
use crate::util::run_until_cancelled;
@@ -70,16 +67,6 @@ enum Payload {
Batch(BatchQueryData),
}
pub(super) static NEON_REQUEST_ID: HeaderName = HeaderName::from_static("neon-request-id");
static CONN_STRING: HeaderName = HeaderName::from_static("neon-connection-string");
static RAW_TEXT_OUTPUT: HeaderName = HeaderName::from_static("neon-raw-text-output");
static ARRAY_MODE: HeaderName = HeaderName::from_static("neon-array-mode");
static ALLOW_POOL: HeaderName = HeaderName::from_static("neon-pool-opt-in");
static TXN_ISOLATION_LEVEL: HeaderName = HeaderName::from_static("neon-batch-isolation-level");
static TXN_READ_ONLY: HeaderName = HeaderName::from_static("neon-batch-read-only");
static TXN_DEFERRABLE: HeaderName = HeaderName::from_static("neon-batch-deferrable");
static HEADER_VALUE_TRUE: HeaderValue = HeaderValue::from_static("true");
fn bytes_to_pg_text<'de, D>(deserializer: D) -> Result<Vec<Option<String>>, D::Error>
@@ -91,179 +78,6 @@ where
Ok(json_to_pg_text(json))
}
#[derive(Debug, thiserror::Error)]
pub(crate) enum ConnInfoError {
#[error("invalid header: {0}")]
InvalidHeader(&'static HeaderName),
#[error("invalid connection string: {0}")]
UrlParseError(#[from] url::ParseError),
#[error("incorrect scheme")]
IncorrectScheme,
#[error("missing database name")]
MissingDbName,
#[error("invalid database name")]
InvalidDbName,
#[error("missing username")]
MissingUsername,
#[error("invalid username: {0}")]
InvalidUsername(#[from] std::string::FromUtf8Error),
#[error("missing authentication credentials: {0}")]
MissingCredentials(Credentials),
#[error("missing hostname")]
MissingHostname,
#[error("invalid hostname: {0}")]
InvalidEndpoint(#[from] ComputeUserInfoParseError),
}
#[derive(Debug, thiserror::Error)]
pub(crate) enum Credentials {
#[error("required password")]
Password,
#[error("required authorization bearer token in JWT format")]
BearerJwt,
}
impl ReportableError for ConnInfoError {
fn get_error_kind(&self) -> ErrorKind {
ErrorKind::User
}
}
impl UserFacingError for ConnInfoError {
fn to_string_client(&self) -> String {
self.to_string()
}
}
fn get_conn_info(
config: &'static AuthenticationConfig,
ctx: &RequestContext,
headers: &HeaderMap,
) -> Result<ConnInfoWithAuth, ConnInfoError> {
let connection_string = headers
.get(&CONN_STRING)
.ok_or(ConnInfoError::InvalidHeader(&CONN_STRING))?
.to_str()
.map_err(|_| ConnInfoError::InvalidHeader(&CONN_STRING))?;
let connection_url = Url::parse(connection_string)?;
let protocol = connection_url.scheme();
if protocol != "postgres" && protocol != "postgresql" {
return Err(ConnInfoError::IncorrectScheme);
}
let mut url_path = connection_url
.path_segments()
.ok_or(ConnInfoError::MissingDbName)?;
let dbname: DbName =
urlencoding::decode(url_path.next().ok_or(ConnInfoError::InvalidDbName)?)?.into();
ctx.set_dbname(dbname.clone());
let username = RoleName::from(urlencoding::decode(connection_url.username())?);
if username.is_empty() {
return Err(ConnInfoError::MissingUsername);
}
ctx.set_user(username.clone());
let auth = if let Some(auth) = headers.get(&AUTHORIZATION) {
if !config.accept_jwts {
return Err(ConnInfoError::MissingCredentials(Credentials::Password));
}
let auth = auth
.to_str()
.map_err(|_| ConnInfoError::InvalidHeader(&AUTHORIZATION))?;
AuthData::Jwt(
auth.strip_prefix("Bearer ")
.ok_or(ConnInfoError::MissingCredentials(Credentials::BearerJwt))?
.into(),
)
} else if let Some(pass) = connection_url.password() {
// wrong credentials provided
if config.accept_jwts {
return Err(ConnInfoError::MissingCredentials(Credentials::BearerJwt));
}
AuthData::Password(match urlencoding::decode_binary(pass.as_bytes()) {
std::borrow::Cow::Borrowed(b) => b.into(),
std::borrow::Cow::Owned(b) => b.into(),
})
} else if config.accept_jwts {
return Err(ConnInfoError::MissingCredentials(Credentials::BearerJwt));
} else {
return Err(ConnInfoError::MissingCredentials(Credentials::Password));
};
let endpoint: EndpointId = match connection_url.host() {
Some(url::Host::Domain(hostname)) => hostname
.split_once('.')
.map_or(hostname, |(prefix, _)| prefix)
.into(),
Some(url::Host::Ipv4(_) | url::Host::Ipv6(_)) | None => {
return Err(ConnInfoError::MissingHostname);
}
};
ctx.set_endpoint_id(endpoint.clone());
let pairs = connection_url.query_pairs();
let mut options = Option::None;
let mut params = StartupMessageParams::default();
params.insert("user", &username);
params.insert("database", &dbname);
for (key, value) in pairs {
params.insert(&key, &value);
if key == "options" {
options = Some(NeonOptions::parse_options_raw(&value));
}
}
// check the URL that was used, for metrics
{
let host_endpoint = headers
// get the host header
.get("host")
// extract the domain
.and_then(|h| {
let (host, _port) = h.to_str().ok()?.split_once(':')?;
Some(host)
})
// get the endpoint prefix
.map(|h| h.split_once('.').map_or(h, |(prefix, _)| prefix));
let kind = if host_endpoint == Some(&*endpoint) {
SniKind::Sni
} else {
SniKind::NoSni
};
let protocol = ctx.protocol();
Metrics::get()
.proxy
.accepted_connections_by_sni
.inc(SniGroup { protocol, kind });
}
ctx.set_user_agent(
headers
.get(hyper::header::USER_AGENT)
.and_then(|h| h.to_str().ok())
.map(Into::into),
);
let user_info = ComputeUserInfo {
endpoint,
user: username,
options: options.unwrap_or_default(),
};
let conn_info = ConnInfo { user_info, dbname };
Ok(ConnInfoWithAuth { conn_info, auth })
}
pub(crate) async fn handle(
config: &'static ProxyConfig,
ctx: RequestContext,
@@ -532,45 +346,6 @@ impl HttpCodeError for SqlOverHttpError {
}
}
#[derive(Debug, thiserror::Error)]
pub(crate) enum ReadPayloadError {
#[error("could not read the HTTP request body: {0}")]
Read(#[from] hyper::Error),
#[error("request is too large (max is {limit} bytes)")]
BodyTooLarge { limit: usize },
#[error("could not parse the HTTP request body: {0}")]
Parse(#[from] serde_json::Error),
}
impl From<ReadBodyError<hyper::Error>> for ReadPayloadError {
fn from(value: ReadBodyError<hyper::Error>) -> Self {
match value {
ReadBodyError::BodyTooLarge { limit } => Self::BodyTooLarge { limit },
ReadBodyError::Read(e) => Self::Read(e),
}
}
}
impl ReportableError for ReadPayloadError {
fn get_error_kind(&self) -> ErrorKind {
match self {
ReadPayloadError::Read(_) => ErrorKind::ClientDisconnect,
ReadPayloadError::BodyTooLarge { .. } => ErrorKind::User,
ReadPayloadError::Parse(_) => ErrorKind::User,
}
}
}
impl HttpCodeError for ReadPayloadError {
fn get_http_status_code(&self) -> StatusCode {
match self {
ReadPayloadError::Read(_) => StatusCode::BAD_REQUEST,
ReadPayloadError::BodyTooLarge { .. } => StatusCode::PAYLOAD_TOO_LARGE,
ReadPayloadError::Parse(_) => StatusCode::BAD_REQUEST,
}
}
}
#[derive(Debug, thiserror::Error)]
pub(crate) enum SqlOverHttpCancel {
#[error("query was cancelled")]
@@ -661,7 +436,7 @@ async fn handle_inner(
"handling interactive connection from client"
);
let conn_info = get_conn_info(&config.authentication_config, ctx, request.headers())?;
let conn_info = get_conn_info(&config.authentication_config, ctx, None, request.headers())?;
info!(
user = conn_info.conn_info.user_info.user.as_str(),
"credentials"
@@ -747,9 +522,17 @@ async fn handle_db_inner(
ComputeCredentialKeys::JwtPayload(payload)
if backend.auth_backend.is_local_proxy() =>
{
let mut client = backend.connect_to_local_postgres(ctx, conn_info).await?;
let (cli_inner, _dsc) = client.client_inner();
cli_inner.set_jwt_session(&payload).await?;
#[cfg(feature = "testing")]
let disable_pg_session_jwt = config.disable_pg_session_jwt;
#[cfg(not(feature = "testing"))]
let disable_pg_session_jwt = false;
let mut client = backend
.connect_to_local_postgres(ctx, conn_info, disable_pg_session_jwt)
.await?;
if !disable_pg_session_jwt {
let (cli_inner, _dsc) = client.client_inner();
cli_inner.set_jwt_session(&payload).await?;
}
Client::Local(client)
}
_ => {
@@ -848,12 +631,6 @@ static HEADERS_TO_FORWARD: &[&HeaderName] = &[
&TXN_DEFERRABLE,
];
pub(crate) fn uuid_to_header_value(id: Uuid) -> HeaderValue {
let mut uuid = [0; uuid::fmt::Hyphenated::LENGTH];
HeaderValue::from_str(id.as_hyphenated().encode_lower(&mut uuid[..]))
.expect("uuid hyphenated format should be all valid header characters")
}
async fn handle_auth_broker_inner(
ctx: &RequestContext,
request: Request<Incoming>,
@@ -883,7 +660,7 @@ async fn handle_auth_broker_inner(
req = req.header(&NEON_REQUEST_ID, uuid_to_header_value(ctx.session_id()));
let req = req
.body(body)
.body(body.map_err(|e| e).boxed()) //TODO: is there a potential for a regression here?
.expect("all headers and params received via hyper should be valid for request");
// todo: map body to count egress

View File

@@ -199,27 +199,27 @@ impl<S: AsyncWrite + Unpin> PqStream<S> {
let probe_msg;
let mut msg = &*msg;
if let Some(ctx) = ctx {
if ctx.get_testodrome_id().is_some() {
let tag = match error_kind {
ErrorKind::User => "client",
ErrorKind::ClientDisconnect => "client",
ErrorKind::RateLimit => "proxy",
ErrorKind::ServiceRateLimit => "proxy",
ErrorKind::Quota => "proxy",
ErrorKind::Service => "proxy",
ErrorKind::ControlPlane => "controlplane",
ErrorKind::Postgres => "other",
ErrorKind::Compute => "compute",
};
probe_msg = typed_json::json!({
"tag": tag,
"msg": msg,
"cold_start_info": ctx.cold_start_info(),
})
.to_string();
msg = &probe_msg;
}
if let Some(ctx) = ctx
&& ctx.get_testodrome_id().is_some()
{
let tag = match error_kind {
ErrorKind::User => "client",
ErrorKind::ClientDisconnect => "client",
ErrorKind::RateLimit => "proxy",
ErrorKind::ServiceRateLimit => "proxy",
ErrorKind::Quota => "proxy",
ErrorKind::Service => "proxy",
ErrorKind::ControlPlane => "controlplane",
ErrorKind::Postgres => "other",
ErrorKind::Compute => "compute",
};
probe_msg = typed_json::json!({
"tag": tag,
"msg": msg,
"cold_start_info": ctx.cold_start_info(),
})
.to_string();
msg = &probe_msg;
}
// TODO: either preserve the error code from postgres, or assign error codes to proxy errors.

View File

@@ -18,9 +18,10 @@ use metrics::set_build_info_metric;
use remote_storage::RemoteStorageConfig;
use safekeeper::defaults::{
DEFAULT_CONTROL_FILE_SAVE_INTERVAL, DEFAULT_EVICTION_MIN_RESIDENT, DEFAULT_HEARTBEAT_TIMEOUT,
DEFAULT_HTTP_LISTEN_ADDR, DEFAULT_MAX_OFFLOADER_LAG_BYTES, DEFAULT_PARTIAL_BACKUP_CONCURRENCY,
DEFAULT_PARTIAL_BACKUP_TIMEOUT, DEFAULT_PG_LISTEN_ADDR, DEFAULT_SSL_CERT_FILE,
DEFAULT_SSL_CERT_RELOAD_PERIOD, DEFAULT_SSL_KEY_FILE,
DEFAULT_HTTP_LISTEN_ADDR, DEFAULT_MAX_OFFLOADER_LAG_BYTES,
DEFAULT_MAX_REELECT_OFFLOADER_LAG_BYTES, DEFAULT_MAX_TIMELINE_DISK_USAGE_BYTES,
DEFAULT_PARTIAL_BACKUP_CONCURRENCY, DEFAULT_PARTIAL_BACKUP_TIMEOUT, DEFAULT_PG_LISTEN_ADDR,
DEFAULT_SSL_CERT_FILE, DEFAULT_SSL_CERT_RELOAD_PERIOD, DEFAULT_SSL_KEY_FILE,
};
use safekeeper::wal_backup::WalBackup;
use safekeeper::{
@@ -138,6 +139,15 @@ struct Args {
/// Safekeeper won't be elected for WAL offloading if it is lagging for more than this value in bytes
#[arg(long, default_value_t = DEFAULT_MAX_OFFLOADER_LAG_BYTES)]
max_offloader_lag: u64,
/* BEGIN_HADRON */
/// Safekeeper will re-elect a new offloader if the current backup lagging for more than this value in bytes
#[arg(long, default_value_t = DEFAULT_MAX_REELECT_OFFLOADER_LAG_BYTES)]
max_reelect_offloader_lag_bytes: u64,
/// Safekeeper will stop accepting new WALs if the timeline disk usage exceeds this value in bytes.
/// Setting this value to 0 disables the limit.
#[arg(long, default_value_t = DEFAULT_MAX_TIMELINE_DISK_USAGE_BYTES)]
max_timeline_disk_usage_bytes: u64,
/* END_HADRON */
/// Number of max parallel WAL segments to be offloaded to remote storage.
#[arg(long, default_value = "5")]
wal_backup_parallel_jobs: usize,
@@ -391,6 +401,10 @@ async fn main() -> anyhow::Result<()> {
peer_recovery_enabled: args.peer_recovery,
remote_storage: args.remote_storage,
max_offloader_lag_bytes: args.max_offloader_lag,
/* BEGIN_HADRON */
max_reelect_offloader_lag_bytes: args.max_reelect_offloader_lag_bytes,
max_timeline_disk_usage_bytes: args.max_timeline_disk_usage_bytes,
/* END_HADRON */
wal_backup_enabled: !args.disable_wal_backup,
backup_parallel_jobs: args.wal_backup_parallel_jobs,
pg_auth,

View File

@@ -17,6 +17,7 @@ use utils::crashsafe::durable_rename;
use crate::control_file_upgrade::{downgrade_v10_to_v9, upgrade_control_file};
use crate::metrics::PERSIST_CONTROL_FILE_SECONDS;
use crate::metrics::WAL_DISK_IO_ERRORS;
use crate::state::{EvictionState, TimelinePersistentState};
pub const SK_MAGIC: u32 = 0xcafeceefu32;
@@ -192,11 +193,14 @@ impl TimelinePersistentState {
impl Storage for FileStorage {
/// Persists state durably to the underlying storage.
async fn persist(&mut self, s: &TimelinePersistentState) -> Result<()> {
// start timer for metrics
let _timer = PERSIST_CONTROL_FILE_SECONDS.start_timer();
// write data to safekeeper.control.partial
let control_partial_path = self.timeline_dir.join(CONTROL_FILE_NAME_PARTIAL);
let mut control_partial = File::create(&control_partial_path).await.with_context(|| {
/* BEGIN_HADRON */
WAL_DISK_IO_ERRORS.inc();
/*END_HADRON */
format!(
"failed to create partial control file at: {}",
&control_partial_path
@@ -206,14 +210,24 @@ impl Storage for FileStorage {
let buf: Vec<u8> = s.write_to_buf()?;
control_partial.write_all(&buf).await.with_context(|| {
/* BEGIN_HADRON */
WAL_DISK_IO_ERRORS.inc();
/*END_HADRON */
format!("failed to write safekeeper state into control file at: {control_partial_path}")
})?;
control_partial.flush().await.with_context(|| {
/* BEGIN_HADRON */
WAL_DISK_IO_ERRORS.inc();
/*END_HADRON */
format!("failed to flush safekeeper state into control file at: {control_partial_path}")
})?;
let control_path = self.timeline_dir.join(CONTROL_FILE_NAME);
durable_rename(&control_partial_path, &control_path, !self.no_sync).await?;
durable_rename(&control_partial_path, &control_path, !self.no_sync)
.await
/* BEGIN_HADRON */
.inspect_err(|_| WAL_DISK_IO_ERRORS.inc())?;
/* END_HADRON */
// update internal state
self.state = s.clone();

View File

@@ -61,6 +61,13 @@ pub mod defaults {
pub const DEFAULT_HEARTBEAT_TIMEOUT: &str = "5000ms";
pub const DEFAULT_MAX_OFFLOADER_LAG_BYTES: u64 = 128 * (1 << 20);
/* BEGIN_HADRON */
// Default leader re-elect is 0(disabled). SK will re-elect leader if the current leader is lagging this many bytes.
pub const DEFAULT_MAX_REELECT_OFFLOADER_LAG_BYTES: u64 = 0;
// Default disk usage limit is 0 (disabled). It means each timeline by default can use up to this many WAL
// disk space on this SK until SK begins to reject WALs.
pub const DEFAULT_MAX_TIMELINE_DISK_USAGE_BYTES: u64 = 0;
/* END_HADRON */
pub const DEFAULT_PARTIAL_BACKUP_TIMEOUT: &str = "15m";
pub const DEFAULT_CONTROL_FILE_SAVE_INTERVAL: &str = "300s";
pub const DEFAULT_PARTIAL_BACKUP_CONCURRENCY: &str = "5";
@@ -99,6 +106,10 @@ pub struct SafeKeeperConf {
pub peer_recovery_enabled: bool,
pub remote_storage: Option<RemoteStorageConfig>,
pub max_offloader_lag_bytes: u64,
/* BEGIN_HADRON */
pub max_reelect_offloader_lag_bytes: u64,
pub max_timeline_disk_usage_bytes: u64,
/* END_HADRON */
pub backup_parallel_jobs: usize,
pub wal_backup_enabled: bool,
pub pg_auth: Option<Arc<JwtAuth>>,
@@ -151,6 +162,10 @@ impl SafeKeeperConf {
sk_auth_token: None,
heartbeat_timeout: Duration::new(5, 0),
max_offloader_lag_bytes: defaults::DEFAULT_MAX_OFFLOADER_LAG_BYTES,
/* BEGIN_HADRON */
max_reelect_offloader_lag_bytes: defaults::DEFAULT_MAX_REELECT_OFFLOADER_LAG_BYTES,
max_timeline_disk_usage_bytes: defaults::DEFAULT_MAX_TIMELINE_DISK_USAGE_BYTES,
/* END_HADRON */
current_thread_runtime: false,
walsenders_keep_horizon: false,
partial_backup_timeout: Duration::from_secs(0),

View File

@@ -58,6 +58,25 @@ pub static FLUSH_WAL_SECONDS: Lazy<Histogram> = Lazy::new(|| {
)
.expect("Failed to register safekeeper_flush_wal_seconds histogram")
});
/* BEGIN_HADRON */
pub static WAL_DISK_IO_ERRORS: Lazy<IntCounter> = Lazy::new(|| {
register_int_counter!(
"safekeeper_wal_disk_io_errors",
"Number of disk I/O errors when creating and flushing WALs and control files"
)
.expect("Failed to register safekeeper_wal_disk_io_errors counter")
});
pub static WAL_STORAGE_LIMIT_ERRORS: Lazy<IntCounter> = Lazy::new(|| {
register_int_counter!(
"safekeeper_wal_storage_limit_errors",
concat!(
"Number of errors due to timeline WAL storage utilization exceeding configured limit. ",
"An increase in this metric indicates issues backing up or removing WALs."
)
)
.expect("Failed to register safekeeper_wal_storage_limit_errors counter")
});
/* END_HADRON */
pub static PERSIST_CONTROL_FILE_SECONDS: Lazy<Histogram> = Lazy::new(|| {
register_histogram!(
"safekeeper_persist_control_file_seconds",
@@ -138,6 +157,15 @@ pub static BACKUP_ERRORS: Lazy<IntCounter> = Lazy::new(|| {
)
.expect("Failed to register safekeeper_backup_errors_total counter")
});
/* BEGIN_HADRON */
pub static BACKUP_REELECT_LEADER_COUNT: Lazy<IntCounter> = Lazy::new(|| {
register_int_counter!(
"safekeeper_backup_reelect_leader_total",
"Number of times the backup leader was reelected"
)
.expect("Failed to register safekeeper_backup_reelect_leader_total counter")
});
/* END_HADRON */
pub static BROKER_PUSH_ALL_UPDATES_SECONDS: Lazy<Histogram> = Lazy::new(|| {
register_histogram!(
"safekeeper_broker_push_update_seconds",

View File

@@ -16,7 +16,7 @@ use tokio::sync::mpsc::error::SendError;
use tokio::task::JoinHandle;
use tokio::time::MissedTickBehavior;
use tracing::{Instrument, error, info, info_span};
use utils::critical;
use utils::critical_timeline;
use utils::lsn::Lsn;
use utils::postgres_client::{Compression, InterpretedFormat};
use wal_decoder::models::{InterpretedWalRecord, InterpretedWalRecords};
@@ -268,6 +268,8 @@ impl InterpretedWalReader {
let (shard_notification_tx, shard_notification_rx) = tokio::sync::mpsc::unbounded_channel();
let ttid = wal_stream.ttid;
let reader = InterpretedWalReader {
wal_stream,
shard_senders: HashMap::from([(
@@ -300,7 +302,11 @@ impl InterpretedWalReader {
.inspect_err(|err| match err {
// TODO: we may want to differentiate these errors further.
InterpretedWalReaderError::Decode(_) => {
critical!("failed to decode WAL record: {err:?}");
critical_timeline!(
ttid.tenant_id,
ttid.timeline_id,
"failed to read WAL record: {err:?}"
);
}
err => error!("failed to read WAL record: {err}"),
})
@@ -363,9 +369,14 @@ impl InterpretedWalReader {
metric.dec();
}
let ttid = self.wal_stream.ttid;
match self.run_impl(start_pos).await {
Err(err @ InterpretedWalReaderError::Decode(_)) => {
critical!("failed to decode WAL record: {err:?}");
critical_timeline!(
ttid.tenant_id,
ttid.timeline_id,
"failed to decode WAL record: {err:?}"
);
}
Err(err) => error!("failed to read WAL record: {err}"),
Ok(()) => info!("interpreted wal reader exiting"),

View File

@@ -26,7 +26,9 @@ use utils::id::{NodeId, TenantId, TenantTimelineId};
use utils::lsn::Lsn;
use utils::sync::gate::Gate;
use crate::metrics::{FullTimelineInfo, MISC_OPERATION_SECONDS, WalStorageMetrics};
use crate::metrics::{
FullTimelineInfo, MISC_OPERATION_SECONDS, WAL_STORAGE_LIMIT_ERRORS, WalStorageMetrics,
};
use crate::rate_limit::RateLimiter;
use crate::receive_wal::WalReceivers;
use crate::safekeeper::{AcceptorProposerMessage, ProposerAcceptorMessage, SafeKeeper, TermLsn};
@@ -1050,6 +1052,39 @@ impl WalResidentTimeline {
Ok(ss)
}
// BEGIN HADRON
// Check if disk usage by WAL segment files for this timeline exceeds the configured limit.
fn hadron_check_disk_usage(
&self,
shared_state_locked: &mut WriteGuardSharedState<'_>,
) -> Result<()> {
// The disk usage is calculated based on the number of segments between `last_removed_segno`
// and the current flush LSN segment number. `last_removed_segno` is advanced after
// unneeded WAL files are physically removed from disk (see `update_wal_removal_end()`
// in `timeline_manager.rs`).
let max_timeline_disk_usage_bytes = self.conf.max_timeline_disk_usage_bytes;
if max_timeline_disk_usage_bytes > 0 {
let last_removed_segno = self.last_removed_segno.load(Ordering::Relaxed);
let flush_lsn = shared_state_locked.sk.flush_lsn();
let wal_seg_size = shared_state_locked.sk.state().server.wal_seg_size as u64;
let current_segno = flush_lsn.segment_number(wal_seg_size as usize);
let segno_count = current_segno - last_removed_segno;
let disk_usage_bytes = segno_count * wal_seg_size;
if disk_usage_bytes > max_timeline_disk_usage_bytes {
WAL_STORAGE_LIMIT_ERRORS.inc();
bail!(
"WAL storage utilization exceeds configured limit of {} bytes: current disk usage: {} bytes",
max_timeline_disk_usage_bytes,
disk_usage_bytes
);
}
}
Ok(())
}
// END HADRON
/// Pass arrived message to the safekeeper.
pub async fn process_msg(
&self,
@@ -1062,6 +1097,13 @@ impl WalResidentTimeline {
let mut rmsg: Option<AcceptorProposerMessage>;
{
let mut shared_state = self.write_shared_state().await;
// BEGIN HADRON
// Errors from the `hadron_check_disk_usage()` function fail the process_msg() function, which
// gets propagated upward and terminates the entire WalAcceptor. This will cause postgres to
// disconnect from the safekeeper and reestablish another connection. Postgres will keep retrying
// safekeeper connections every second until it can successfully propose WAL to the SK again.
self.hadron_check_disk_usage(&mut shared_state)?;
// END HADRON
rmsg = shared_state.sk.safekeeper().process_msg(msg).await?;
// if this is AppendResponse, fill in proper hot standby feedback.

View File

@@ -26,7 +26,9 @@ use utils::id::{NodeId, TenantTimelineId};
use utils::lsn::Lsn;
use utils::{backoff, pausable_failpoint};
use crate::metrics::{BACKED_UP_SEGMENTS, BACKUP_ERRORS, WAL_BACKUP_TASKS};
use crate::metrics::{
BACKED_UP_SEGMENTS, BACKUP_ERRORS, BACKUP_REELECT_LEADER_COUNT, WAL_BACKUP_TASKS,
};
use crate::timeline::WalResidentTimeline;
use crate::timeline_manager::{Manager, StateSnapshot};
use crate::{SafeKeeperConf, WAL_BACKUP_RUNTIME};
@@ -70,8 +72,9 @@ pub(crate) async fn update_task(
need_backup: bool,
state: &StateSnapshot,
) {
let (offloader, election_dbg_str) =
determine_offloader(&state.peers, state.backup_lsn, mgr.tli.ttid, &mgr.conf);
/* BEGIN_HADRON */
let (offloader, election_dbg_str) = hadron_determine_offloader(mgr, state);
/* END_HADRON */
let elected_me = Some(mgr.conf.my_id) == offloader;
let should_task_run = need_backup && elected_me;
@@ -127,6 +130,70 @@ async fn shut_down_task(entry: &mut Option<WalBackupTaskHandle>) {
}
}
/* BEGIN_HADRON */
// On top of the neon determine_offloader, we also check if the current offloader is lagging behind too much.
// If it is, we re-elect a new offloader. This mitigates the below issue. It also helps distribute the load across SKs.
//
// We observe that the offloader fails to upload a segment due to race conditions on XLOG SWITCH and PG start streaming WALs.
// wal_backup task continously failing to upload a full segment while the segment remains partial on the disk.
// The consequence is that commit_lsn for all SKs move forward but backup_lsn stays the same. Then, all SKs run out of disk space.
// See go/sk-ood-xlog-switch for more details.
//
// To mitigate this issue, we will re-elect a new offloader if the current offloader is lagging behind too much.
// Each SK makes the decision locally but they are aware of each other's commit and backup lsns.
//
// determine_offloader will pick a SK. say SK-1.
// Each SK checks
// -- if commit_lsn - back_lsn > threshold,
// -- -- remove SK-1 from the candidate and call determine_offloader again.
// SK-1 will step down and all SKs will elect the same leader again.
// After the backup is caught up, the leader will become SK-1 again.
fn hadron_determine_offloader(mgr: &Manager, state: &StateSnapshot) -> (Option<NodeId>, String) {
let mut offloader: Option<NodeId>;
let mut election_dbg_str: String;
let caughtup_peers_count: usize;
(offloader, election_dbg_str, caughtup_peers_count) =
determine_offloader(&state.peers, state.backup_lsn, mgr.tli.ttid, &mgr.conf);
if offloader.is_none()
|| caughtup_peers_count <= 1
|| mgr.conf.max_reelect_offloader_lag_bytes == 0
{
return (offloader, election_dbg_str);
}
let offloader_sk_id = offloader.unwrap();
let backup_lag = state.commit_lsn.checked_sub(state.backup_lsn);
if backup_lag.is_none() {
info!("Backup lag is None. Skipping re-election.");
return (offloader, election_dbg_str);
}
let backup_lag = backup_lag.unwrap().0;
if backup_lag < mgr.conf.max_reelect_offloader_lag_bytes {
return (offloader, election_dbg_str);
}
info!(
"Electing a new leader: Backup lag is too high backup lsn lag {} threshold {}: {}",
backup_lag, mgr.conf.max_reelect_offloader_lag_bytes, election_dbg_str
);
BACKUP_REELECT_LEADER_COUNT.inc();
// Remove the current offloader if lag is too high.
let new_peers: Vec<_> = state
.peers
.iter()
.filter(|p| p.sk_id != offloader_sk_id)
.cloned()
.collect();
(offloader, election_dbg_str, _) =
determine_offloader(&new_peers, state.backup_lsn, mgr.tli.ttid, &mgr.conf);
(offloader, election_dbg_str)
}
/* END_HADRON */
/// The goal is to ensure that normally only one safekeepers offloads. However,
/// it is fine (and inevitable, as s3 doesn't provide CAS) that for some short
/// time we have several ones as they PUT the same files. Also,
@@ -141,13 +208,13 @@ fn determine_offloader(
wal_backup_lsn: Lsn,
ttid: TenantTimelineId,
conf: &SafeKeeperConf,
) -> (Option<NodeId>, String) {
) -> (Option<NodeId>, String, usize) {
// TODO: remove this once we fill newly joined safekeepers since backup_lsn.
let capable_peers = alive_peers
.iter()
.filter(|p| p.local_start_lsn <= wal_backup_lsn);
match capable_peers.clone().map(|p| p.commit_lsn).max() {
None => (None, "no connected peers to elect from".to_string()),
None => (None, "no connected peers to elect from".to_string(), 0),
Some(max_commit_lsn) => {
let threshold = max_commit_lsn
.checked_sub(conf.max_offloader_lag_bytes)
@@ -175,6 +242,7 @@ fn determine_offloader(
capable_peers_dbg,
caughtup_peers.len()
),
caughtup_peers.len(),
)
}
}
@@ -346,6 +414,8 @@ async fn backup_lsn_range(
anyhow::bail!("parallel_jobs must be >= 1");
}
pausable_failpoint!("backup-lsn-range-pausable");
let remote_timeline_path = &timeline.remote_path;
let start_lsn = *backup_lsn;
let segments = get_segments(start_lsn, end_lsn, wal_seg_size);

View File

@@ -1,15 +1,15 @@
use std::pin::Pin;
use std::task::{Context, Poll};
use bytes::Bytes;
use futures::stream::BoxStream;
use futures::{Stream, StreamExt};
use safekeeper_api::Term;
use utils::lsn::Lsn;
use crate::send_wal::EndWatch;
use crate::timeline::WalResidentTimeline;
use crate::wal_storage::WalReader;
use bytes::Bytes;
use futures::stream::BoxStream;
use futures::{Stream, StreamExt};
use safekeeper_api::Term;
use utils::id::TenantTimelineId;
use utils::lsn::Lsn;
#[derive(PartialEq, Eq, Debug)]
pub(crate) struct WalBytes {
@@ -37,6 +37,8 @@ struct PositionedWalReader {
pub(crate) struct StreamingWalReader {
stream: BoxStream<'static, WalOrReset>,
start_changed_tx: tokio::sync::watch::Sender<Lsn>,
// HADRON: Added TenantTimelineId for instrumentation purposes.
pub(crate) ttid: TenantTimelineId,
}
pub(crate) enum WalOrReset {
@@ -63,6 +65,7 @@ impl StreamingWalReader {
buffer_size: usize,
) -> Self {
let (start_changed_tx, start_changed_rx) = tokio::sync::watch::channel(start);
let ttid = tli.ttid;
let state = WalReaderStreamState {
tli,
@@ -107,6 +110,7 @@ impl StreamingWalReader {
Self {
stream,
start_changed_tx,
ttid,
}
}

View File

@@ -31,7 +31,8 @@ use utils::id::TenantTimelineId;
use utils::lsn::Lsn;
use crate::metrics::{
REMOVED_WAL_SEGMENTS, WAL_STORAGE_OPERATION_SECONDS, WalStorageMetrics, time_io_closure,
REMOVED_WAL_SEGMENTS, WAL_DISK_IO_ERRORS, WAL_STORAGE_OPERATION_SECONDS, WalStorageMetrics,
time_io_closure,
};
use crate::state::TimelinePersistentState;
use crate::wal_backup::{WalBackup, read_object, remote_timeline_path};
@@ -293,9 +294,12 @@ impl PhysicalStorage {
// half initialized segment, first bake it under tmp filename and
// then rename.
let tmp_path = self.timeline_dir.join("waltmp");
let file = File::create(&tmp_path)
.await
.with_context(|| format!("Failed to open tmp wal file {:?}", &tmp_path))?;
let file: File = File::create(&tmp_path).await.with_context(|| {
/* BEGIN_HADRON */
WAL_DISK_IO_ERRORS.inc();
/* END_HADRON */
format!("Failed to open tmp wal file {:?}", &tmp_path)
})?;
fail::fail_point!("sk-zero-segment", |_| {
info!("sk-zero-segment failpoint hit");
@@ -382,7 +386,11 @@ impl PhysicalStorage {
let flushed = self
.write_in_segment(segno, xlogoff, &buf[..bytes_write])
.await?;
.await
/* BEGIN_HADRON */
.inspect_err(|_| WAL_DISK_IO_ERRORS.inc())?;
/* END_HADRON */
self.write_lsn += bytes_write as u64;
if flushed {
self.flush_lsn = self.write_lsn;
@@ -491,7 +499,11 @@ impl Storage for PhysicalStorage {
}
if let Some(unflushed_file) = self.file.take() {
self.fdatasync_file(&unflushed_file).await?;
self.fdatasync_file(&unflushed_file)
.await
/* BEGIN_HADRON */
.inspect_err(|_| WAL_DISK_IO_ERRORS.inc())?;
/* END_HADRON */
self.file = Some(unflushed_file);
} else {
// We have unflushed data (write_lsn != flush_lsn), but no file. This

View File

@@ -159,6 +159,10 @@ pub fn run_server(os: NodeOs, disk: Arc<SafekeeperDisk>) -> Result<()> {
heartbeat_timeout: Duration::from_secs(0),
remote_storage: None,
max_offloader_lag_bytes: 0,
/* BEGIN_HADRON */
max_reelect_offloader_lag_bytes: 0,
max_timeline_disk_usage_bytes: 0,
/* END_HADRON */
wal_backup_enabled: false,
listen_pg_addr_tenant_only: None,
advertise_pg_addr: None,

View File

@@ -0,0 +1 @@
ALTER TABLE safekeepers ALTER COLUMN scheduling_policy SET DEFAULT 'pause';

View File

@@ -0,0 +1 @@
ALTER TABLE safekeepers ALTER COLUMN scheduling_policy SET DEFAULT 'activating';

View File

@@ -76,6 +76,9 @@ pub(crate) struct StorageControllerMetricGroup {
/// How many shards would like to reconcile but were blocked by concurrency limits
pub(crate) storage_controller_pending_reconciles: measured::Gauge,
/// How many shards are keep-failing and will be ignored when considering to run optimizations
pub(crate) storage_controller_keep_failing_reconciles: measured::Gauge,
/// HTTP request status counters for handled requests
pub(crate) storage_controller_http_request_status:
measured::CounterVec<HttpRequestStatusLabelGroupSet>,

View File

@@ -1388,6 +1388,48 @@ impl Persistence {
.await
}
/// Activate the given safekeeper, ensuring that there is no TOCTOU.
/// Returns `Some` if the safekeeper has indeed been activating (or already active). Other states return `None`.
pub(crate) async fn activate_safekeeper(&self, id_: i64) -> Result<Option<()>, DatabaseError> {
use crate::schema::safekeepers::dsl::*;
self.with_conn(move |conn| {
Box::pin(async move {
#[derive(Insertable, AsChangeset)]
#[diesel(table_name = crate::schema::safekeepers)]
struct UpdateSkSchedulingPolicy<'a> {
id: i64,
scheduling_policy: &'a str,
}
let scheduling_policy_active = String::from(SkSchedulingPolicy::Active);
let scheduling_policy_activating = String::from(SkSchedulingPolicy::Activating);
let rows_affected = diesel::update(
safekeepers.filter(id.eq(id_)).filter(
scheduling_policy
.eq(scheduling_policy_activating)
.or(scheduling_policy.eq(&scheduling_policy_active)),
),
)
.set(scheduling_policy.eq(&scheduling_policy_active))
.execute(conn)
.await?;
if rows_affected == 0 {
return Ok(Some(()));
}
if rows_affected != 1 {
return Err(DatabaseError::Logical(format!(
"unexpected number of rows ({rows_affected})",
)));
}
Ok(Some(()))
})
})
.await
}
/// Persist timeline. Returns if the timeline was newly inserted. If it wasn't, we haven't done any writes.
pub(crate) async fn insert_timeline(&self, entry: TimelinePersistence) -> DatabaseResult<bool> {
use crate::schema::timelines;

View File

@@ -31,8 +31,8 @@ use pageserver_api::controller_api::{
AvailabilityZone, MetadataHealthRecord, MetadataHealthUpdateRequest, NodeAvailability,
NodeRegisterRequest, NodeSchedulingPolicy, NodeShard, NodeShardResponse, PlacementPolicy,
ShardSchedulingPolicy, ShardsPreferredAzsRequest, ShardsPreferredAzsResponse,
TenantCreateRequest, TenantCreateResponse, TenantCreateResponseShard, TenantDescribeResponse,
TenantDescribeResponseShard, TenantLocateResponse, TenantPolicyRequest,
SkSchedulingPolicy, TenantCreateRequest, TenantCreateResponse, TenantCreateResponseShard,
TenantDescribeResponse, TenantDescribeResponseShard, TenantLocateResponse, TenantPolicyRequest,
TenantShardMigrateRequest, TenantShardMigrateResponse,
};
use pageserver_api::models::{
@@ -210,6 +210,10 @@ pub const RECONCILER_CONCURRENCY_DEFAULT: usize = 128;
pub const PRIORITY_RECONCILER_CONCURRENCY_DEFAULT: usize = 256;
pub const SAFEKEEPER_RECONCILER_CONCURRENCY_DEFAULT: usize = 32;
// Number of consecutive reconciliation errors, occured for one shard,
// after which the shard is ignored when considering to run optimizations.
const MAX_CONSECUTIVE_RECONCILIATION_ERRORS: usize = 5;
// Depth of the channel used to enqueue shards for reconciliation when they can't do it immediately.
// This channel is finite-size to avoid using excessive memory if we get into a state where reconciles are finishing more slowly
// than they're being pushed onto the queue.
@@ -702,6 +706,36 @@ struct ShardMutationLocations {
#[derive(Default, Clone)]
struct TenantMutationLocations(BTreeMap<TenantShardId, ShardMutationLocations>);
struct ReconcileAllResult {
spawned_reconciles: usize,
keep_failing_reconciles: usize,
has_delayed_reconciles: bool,
}
impl ReconcileAllResult {
fn new(
spawned_reconciles: usize,
keep_failing_reconciles: usize,
has_delayed_reconciles: bool,
) -> Self {
assert!(
spawned_reconciles >= keep_failing_reconciles,
"It is impossible to have more keep-failing reconciles than spawned reconciles"
);
Self {
spawned_reconciles,
keep_failing_reconciles,
has_delayed_reconciles,
}
}
/// We can run optimizations only if we don't have any delayed reconciles and
/// all spawned reconciles are also keep-failing reconciles.
fn can_run_optimizations(&self) -> bool {
!self.has_delayed_reconciles && self.spawned_reconciles == self.keep_failing_reconciles
}
}
impl Service {
pub fn get_config(&self) -> &Config {
&self.config
@@ -899,7 +933,7 @@ impl Service {
// which require it: under normal circumstances this should only include tenants that were in some
// transient state before we restarted, or any tenants whose compute hooks failed above.
tracing::info!("Checking for shards in need of reconciliation...");
let reconcile_tasks = self.reconcile_all();
let reconcile_all_result = self.reconcile_all();
// We will not wait for these reconciliation tasks to run here: we're now done with startup and
// normal operations may proceed.
@@ -947,8 +981,9 @@ impl Service {
}
}
let spawned_reconciles = reconcile_all_result.spawned_reconciles;
tracing::info!(
"Startup complete, spawned {reconcile_tasks} reconciliation tasks ({shard_count} shards total)"
"Startup complete, spawned {spawned_reconciles} reconciliation tasks ({shard_count} shards total)"
);
}
@@ -1199,8 +1234,8 @@ impl Service {
while !self.reconcilers_cancel.is_cancelled() {
tokio::select! {
_ = interval.tick() => {
let reconciles_spawned = self.reconcile_all();
if reconciles_spawned == 0 {
let reconcile_all_result = self.reconcile_all();
if reconcile_all_result.can_run_optimizations() {
// Run optimizer only when we didn't find any other work to do
self.optimize_all().await;
}
@@ -1214,7 +1249,7 @@ impl Service {
}
/// Heartbeat all storage nodes once in a while.
#[instrument(skip_all)]
async fn spawn_heartbeat_driver(&self) {
async fn spawn_heartbeat_driver(self: &Arc<Self>) {
self.startup_complete.clone().wait().await;
let mut interval = tokio::time::interval(self.config.heartbeat_interval);
@@ -1341,18 +1376,51 @@ impl Service {
}
}
if let Ok(deltas) = res_sk {
let mut locked = self.inner.write().unwrap();
let mut safekeepers = (*locked.safekeepers).clone();
for (id, state) in deltas.0 {
let Some(sk) = safekeepers.get_mut(&id) else {
tracing::info!(
"Couldn't update safekeeper safekeeper state for id {id} from heartbeat={state:?}"
);
continue;
};
sk.set_availability(state);
let mut to_activate = Vec::new();
{
let mut locked = self.inner.write().unwrap();
let mut safekeepers = (*locked.safekeepers).clone();
for (id, state) in deltas.0 {
let Some(sk) = safekeepers.get_mut(&id) else {
tracing::info!(
"Couldn't update safekeeper safekeeper state for id {id} from heartbeat={state:?}"
);
continue;
};
if sk.scheduling_policy() == SkSchedulingPolicy::Activating
&& let SafekeeperState::Available { .. } = state
{
to_activate.push(id);
}
sk.set_availability(state);
}
locked.safekeepers = Arc::new(safekeepers);
}
for sk_id in to_activate {
// TODO this can race with set_scheduling_policy (can create disjoint DB <-> in-memory state)
tracing::info!("Activating safekeeper {sk_id}");
match self.persistence.activate_safekeeper(sk_id.0 as i64).await {
Ok(Some(())) => {}
Ok(None) => {
tracing::info!(
"safekeeper {sk_id} has been removed from db or has different scheduling policy than active or activating"
);
}
Err(e) => {
tracing::warn!("couldn't apply activation of {sk_id} to db: {e}");
continue;
}
}
if let Err(e) = self
.set_safekeeper_scheduling_policy_in_mem(sk_id, SkSchedulingPolicy::Active)
.await
{
tracing::info!("couldn't activate safekeeper {sk_id} in memory: {e}");
continue;
}
tracing::info!("Activation of safekeeper {sk_id} done");
}
locked.safekeepers = Arc::new(safekeepers);
}
}
}
@@ -1408,6 +1476,7 @@ impl Service {
match result.result {
Ok(()) => {
tenant.consecutive_errors_count = 0;
tenant.apply_observed_deltas(deltas);
tenant.waiter.advance(result.sequence);
}
@@ -1426,6 +1495,8 @@ impl Service {
}
}
tenant.consecutive_errors_count = tenant.consecutive_errors_count.saturating_add(1);
// Ordering: populate last_error before advancing error_seq,
// so that waiters will see the correct error after waiting.
tenant.set_last_error(result.sequence, e);
@@ -8026,7 +8097,7 @@ impl Service {
/// Returns how many reconciliation tasks were started, or `1` if no reconciles were
/// spawned but some _would_ have been spawned if `reconciler_concurrency` units where
/// available. A return value of 0 indicates that everything is fully reconciled already.
fn reconcile_all(&self) -> usize {
fn reconcile_all(&self) -> ReconcileAllResult {
let mut locked = self.inner.write().unwrap();
let (nodes, tenants, scheduler) = locked.parts_mut();
let pageservers = nodes.clone();
@@ -8034,13 +8105,16 @@ impl Service {
// This function is an efficient place to update lazy statistics, since we are walking
// all tenants.
let mut pending_reconciles = 0;
let mut keep_failing_reconciles = 0;
let mut az_violations = 0;
// If we find any tenants to drop from memory, stash them to offload after
// we're done traversing the map of tenants.
let mut drop_detached_tenants = Vec::new();
let mut reconciles_spawned = 0;
let mut spawned_reconciles = 0;
let mut has_delayed_reconciles = false;
for shard in tenants.values_mut() {
// Accumulate scheduling statistics
if let (Some(attached), Some(preferred)) =
@@ -8060,18 +8134,32 @@ impl Service {
// If there is something delayed, then return a nonzero count so that
// callers like reconcile_all_now do not incorrectly get the impression
// that the system is in a quiescent state.
reconciles_spawned = std::cmp::max(1, reconciles_spawned);
has_delayed_reconciles = true;
pending_reconciles += 1;
continue;
}
// Eventual consistency: if an earlier reconcile job failed, and the shard is still
// dirty, spawn another one
let consecutive_errors_count = shard.consecutive_errors_count;
if self
.maybe_reconcile_shard(shard, &pageservers, ReconcilerPriority::Normal)
.is_some()
{
reconciles_spawned += 1;
spawned_reconciles += 1;
// Count shards that are keep-failing. We still want to reconcile them
// to avoid a situation where a shard is stuck.
// But we don't want to consider them when deciding to run optimizations.
if consecutive_errors_count >= MAX_CONSECUTIVE_RECONCILIATION_ERRORS {
tracing::warn!(
tenant_id=%shard.tenant_shard_id.tenant_id,
shard_id=%shard.tenant_shard_id.shard_slug(),
"Shard reconciliation is keep-failing: {} errors",
consecutive_errors_count
);
keep_failing_reconciles += 1;
}
} else if shard.delayed_reconcile {
// Shard wanted to reconcile but for some reason couldn't.
pending_reconciles += 1;
@@ -8110,7 +8198,16 @@ impl Service {
.storage_controller_pending_reconciles
.set(pending_reconciles as i64);
reconciles_spawned
metrics::METRICS_REGISTRY
.metrics_group
.storage_controller_keep_failing_reconciles
.set(keep_failing_reconciles as i64);
ReconcileAllResult::new(
spawned_reconciles,
keep_failing_reconciles,
has_delayed_reconciles,
)
}
/// `optimize` in this context means identifying shards which have valid scheduled locations, but
@@ -8783,13 +8880,13 @@ impl Service {
/// also wait for any generated Reconcilers to complete. Calling this until it returns zero should
/// put the system into a quiescent state where future background reconciliations won't do anything.
pub(crate) async fn reconcile_all_now(&self) -> Result<usize, ReconcileWaitError> {
let reconciles_spawned = self.reconcile_all();
let reconciles_spawned = if reconciles_spawned == 0 {
let reconcile_all_result = self.reconcile_all();
let mut spawned_reconciles = reconcile_all_result.spawned_reconciles;
if reconcile_all_result.can_run_optimizations() {
// Only optimize when we are otherwise idle
self.optimize_all().await
} else {
reconciles_spawned
};
let optimization_reconciles = self.optimize_all().await;
spawned_reconciles += optimization_reconciles;
}
let waiters = {
let mut waiters = Vec::new();
@@ -8826,11 +8923,11 @@ impl Service {
tracing::info!(
"{} reconciles in reconcile_all, {} waiters",
reconciles_spawned,
spawned_reconciles,
waiter_count
);
Ok(std::cmp::max(waiter_count, reconciles_spawned))
Ok(std::cmp::max(waiter_count, spawned_reconciles))
}
async fn stop_reconciliations(&self, reason: StopReconciliationsReason) {

View File

@@ -236,40 +236,30 @@ impl Service {
F: std::future::Future<Output = mgmt_api::Result<T>> + Send + 'static,
T: Sync + Send + 'static,
{
let target_sk_count = safekeepers.len();
if target_sk_count == 0 {
return Err(ApiError::InternalServerError(anyhow::anyhow!(
"timeline configured without any safekeepers"
)));
}
if target_sk_count < self.config.timeline_safekeeper_count {
tracing::warn!(
"running a quorum operation with {} safekeepers, which is less than configured {} safekeepers per timeline",
target_sk_count,
self.config.timeline_safekeeper_count
);
}
let results = self
.tenant_timeline_safekeeper_op(safekeepers, op, timeout)
.await?;
// Now check if quorum was reached in results.
let target_sk_count = safekeepers.len();
let quorum_size = match target_sk_count {
0 => {
return Err(ApiError::InternalServerError(anyhow::anyhow!(
"timeline configured without any safekeepers",
)));
}
1 | 2 => {
#[cfg(feature = "testing")]
{
// In test settings, it is allowed to have one or two safekeepers
target_sk_count
}
#[cfg(not(feature = "testing"))]
{
// The region is misconfigured: we need at least three safekeepers to be configured
// in order to schedule work to them
tracing::warn!(
"couldn't find at least 3 safekeepers for timeline, found: {:?}",
target_sk_count
);
return Err(ApiError::InternalServerError(anyhow::anyhow!(
"couldn't find at least 3 safekeepers to put timeline to"
)));
}
}
_ => target_sk_count / 2 + 1,
};
let quorum_size = target_sk_count / 2 + 1;
let success_count = results.iter().filter(|res| res.is_ok()).count();
if success_count < quorum_size {
// Failure
@@ -815,7 +805,7 @@ impl Service {
Safekeeper::from_persistence(
crate::persistence::SafekeeperPersistence::from_upsert(
record,
SkSchedulingPolicy::Pause,
SkSchedulingPolicy::Activating,
),
CancellationToken::new(),
use_https,
@@ -856,27 +846,36 @@ impl Service {
.await?;
let node_id = NodeId(id as u64);
// After the change has been persisted successfully, update the in-memory state
{
let mut locked = self.inner.write().unwrap();
let mut safekeepers = (*locked.safekeepers).clone();
let sk = safekeepers
.get_mut(&node_id)
.ok_or(DatabaseError::Logical("Not found".to_string()))?;
sk.set_scheduling_policy(scheduling_policy);
self.set_safekeeper_scheduling_policy_in_mem(node_id, scheduling_policy)
.await
}
match scheduling_policy {
SkSchedulingPolicy::Active => {
locked
.safekeeper_reconcilers
.start_reconciler(node_id, self);
}
SkSchedulingPolicy::Decomissioned | SkSchedulingPolicy::Pause => {
locked.safekeeper_reconcilers.stop_reconciler(node_id);
}
pub(crate) async fn set_safekeeper_scheduling_policy_in_mem(
self: &Arc<Service>,
node_id: NodeId,
scheduling_policy: SkSchedulingPolicy,
) -> Result<(), DatabaseError> {
let mut locked = self.inner.write().unwrap();
let mut safekeepers = (*locked.safekeepers).clone();
let sk = safekeepers
.get_mut(&node_id)
.ok_or(DatabaseError::Logical("Not found".to_string()))?;
sk.set_scheduling_policy(scheduling_policy);
match scheduling_policy {
SkSchedulingPolicy::Active => {
locked
.safekeeper_reconcilers
.start_reconciler(node_id, self);
}
SkSchedulingPolicy::Decomissioned
| SkSchedulingPolicy::Pause
| SkSchedulingPolicy::Activating => {
locked.safekeeper_reconcilers.stop_reconciler(node_id);
}
locked.safekeepers = Arc::new(safekeepers);
}
locked.safekeepers = Arc::new(safekeepers);
Ok(())
}

View File

@@ -131,6 +131,15 @@ pub(crate) struct TenantShard {
#[serde(serialize_with = "read_last_error")]
pub(crate) last_error: std::sync::Arc<std::sync::Mutex<Option<Arc<ReconcileError>>>>,
/// Number of consecutive reconciliation errors that have occurred for this shard.
///
/// When this count reaches MAX_CONSECUTIVE_RECONCILIATION_ERRORS, the tenant shard
/// will be countered as keep-failing in `reconcile_all` calculations. This will lead to
/// allowing optimizations to run even with some failing shards.
///
/// The counter is reset to 0 after a successful reconciliation.
pub(crate) consecutive_errors_count: usize,
/// If we have a pending compute notification that for some reason we weren't able to send,
/// set this to true. If this is set, calls to [`Self::get_reconcile_needed`] will return Yes
/// and trigger a Reconciler run. This is the mechanism by which compute notifications are included in the scope
@@ -594,6 +603,7 @@ impl TenantShard {
waiter: Arc::new(SeqWait::new(Sequence(0))),
error_waiter: Arc::new(SeqWait::new(Sequence(0))),
last_error: Arc::default(),
consecutive_errors_count: 0,
pending_compute_notification: false,
scheduling_policy: ShardSchedulingPolicy::default(),
preferred_node: None,
@@ -1859,6 +1869,7 @@ impl TenantShard {
waiter: Arc::new(SeqWait::new(Sequence::initial())),
error_waiter: Arc::new(SeqWait::new(Sequence::initial())),
last_error: Arc::default(),
consecutive_errors_count: 0,
pending_compute_notification: false,
delayed_reconcile: false,
scheduling_policy: serde_json::from_str(&tsp.scheduling_policy).unwrap(),

View File

@@ -989,6 +989,102 @@ def test_storage_controller_compute_hook_retry(
)
@run_only_on_default_postgres("postgres behavior is not relevant")
def test_storage_controller_compute_hook_keep_failing(
httpserver: HTTPServer,
neon_env_builder: NeonEnvBuilder,
httpserver_listen_address: ListenAddress,
):
neon_env_builder.num_pageservers = 4
neon_env_builder.storage_controller_config = {"use_local_compute_notifications": False}
(host, port) = httpserver_listen_address
neon_env_builder.control_plane_hooks_api = f"http://{host}:{port}"
# Set up CP handler for compute notifications
status_by_tenant: dict[TenantId, int] = {}
def handler(request: Request):
notify_request = request.json
assert notify_request is not None
status = status_by_tenant[TenantId(notify_request["tenant_id"])]
log.info(f"Notify request[{status}]: {notify_request}")
return Response(status=status)
httpserver.expect_request("/notify-attach", method="PUT").respond_with_handler(handler)
# Run neon environment
env = neon_env_builder.init_configs()
env.start()
# Create two tenants:
# - The first tenant is banned by CP and contains only one shard
# - The second tenant is allowed by CP and contains four shards
banned_tenant = TenantId.generate()
status_by_tenant[banned_tenant] = 200 # we will ban this tenant later
env.create_tenant(banned_tenant, placement_policy='{"Attached": 1}')
shard_count = 4
allowed_tenant = TenantId.generate()
status_by_tenant[allowed_tenant] = 200
env.create_tenant(allowed_tenant, shard_count=shard_count, placement_policy='{"Attached": 1}')
# Find the pageserver of the banned tenant
banned_tenant_ps = env.get_tenant_pageserver(banned_tenant)
assert banned_tenant_ps is not None
alive_pageservers = [p for p in env.pageservers if p.id != banned_tenant_ps.id]
# Stop pageserver and ban tenant to trigger failed reconciliation
status_by_tenant[banned_tenant] = 423
banned_tenant_ps.stop()
env.storage_controller.allowed_errors.append(NOTIFY_BLOCKED_LOG)
env.storage_controller.allowed_errors.extend(NOTIFY_FAILURE_LOGS)
env.storage_controller.allowed_errors.append(".*Shard reconciliation is keep-failing.*")
env.storage_controller.node_configure(banned_tenant_ps.id, {"availability": "Offline"})
# Migrate all allowed tenant shards to the first alive pageserver
# to trigger storage controller optimizations due to affinity rules
for shard_number in range(shard_count):
env.storage_controller.tenant_shard_migrate(
TenantShardId(allowed_tenant, shard_number, shard_count),
alive_pageservers[0].id,
config=StorageControllerMigrationConfig(prewarm=False, override_scheduler=True),
)
# Make some reconcile_all calls to trigger optimizations
# RECONCILE_COUNT must be greater than storcon's MAX_CONSECUTIVE_RECONCILIATION_ERRORS
RECONCILE_COUNT = 12
for i in range(RECONCILE_COUNT):
try:
n = env.storage_controller.reconcile_all()
log.info(f"Reconciliation attempt {i} finished with success: {n}")
except StorageControllerApiException as e:
assert "Control plane tenant busy" in str(e)
log.info(f"Reconciliation attempt {i} finished with failure")
banned_descr = env.storage_controller.tenant_describe(banned_tenant)
assert banned_descr["shards"][0]["is_pending_compute_notification"] is True
time.sleep(2)
# Check that the allowed tenant shards are optimized due to affinity rules
locations = alive_pageservers[0].http_client().tenant_list_locations()["tenant_shards"]
not_optimized_shard_count = 0
for loc in locations:
tsi = TenantShardId.parse(loc[0])
if tsi.tenant_id != allowed_tenant:
continue
if loc[1]["mode"] == "AttachedSingle":
not_optimized_shard_count += 1
log.info(f"Shard {tsi} seen in mode {loc[1]['mode']}")
assert not_optimized_shard_count < shard_count, "At least one shard should be optimized"
# Unban the tenant and run reconciliations
status_by_tenant[banned_tenant] = 200
env.storage_controller.reconcile_all()
banned_descr = env.storage_controller.tenant_describe(banned_tenant)
assert banned_descr["shards"][0]["is_pending_compute_notification"] is False
@run_only_on_default_postgres("this test doesn't start an endpoint")
def test_storage_controller_compute_hook_revert(
httpserver: HTTPServer,
@@ -3530,18 +3626,21 @@ def test_safekeeper_deployment_time_update(neon_env_builder: NeonEnvBuilder):
# some small tests for the scheduling policy querying and returning APIs
newest_info = target.get_safekeeper(inserted["id"])
assert newest_info
assert newest_info["scheduling_policy"] == "Pause"
target.safekeeper_scheduling_policy(inserted["id"], "Active")
newest_info = target.get_safekeeper(inserted["id"])
assert newest_info
assert newest_info["scheduling_policy"] == "Active"
# Ensure idempotency
target.safekeeper_scheduling_policy(inserted["id"], "Active")
newest_info = target.get_safekeeper(inserted["id"])
assert newest_info
assert newest_info["scheduling_policy"] == "Active"
# change back to paused again
assert (
newest_info["scheduling_policy"] == "Activating"
or newest_info["scheduling_policy"] == "Active"
)
target.safekeeper_scheduling_policy(inserted["id"], "Pause")
newest_info = target.get_safekeeper(inserted["id"])
assert newest_info
assert newest_info["scheduling_policy"] == "Pause"
# Ensure idempotency
target.safekeeper_scheduling_policy(inserted["id"], "Pause")
newest_info = target.get_safekeeper(inserted["id"])
assert newest_info
assert newest_info["scheduling_policy"] == "Pause"
# change back to active again
target.safekeeper_scheduling_policy(inserted["id"], "Active")
def storcon_heartbeat():
assert env.storage_controller.log_contains(
@@ -3554,6 +3653,57 @@ def test_safekeeper_deployment_time_update(neon_env_builder: NeonEnvBuilder):
target.safekeeper_scheduling_policy(inserted["id"], "Decomissioned")
@run_only_on_default_postgres("this is like a 'unit test' against storcon db")
def test_safekeeper_activating_to_active(neon_env_builder: NeonEnvBuilder):
env = neon_env_builder.init_configs()
env.start()
fake_id = 5
target = env.storage_controller
assert target.get_safekeeper(fake_id) is None
start_sks = target.get_safekeepers()
sk_0 = env.safekeepers[0]
body = {
"active": True,
"id": fake_id,
"created_at": "2023-10-25T09:11:25Z",
"updated_at": "2024-08-28T11:32:43Z",
"region_id": "aws-eu-central-1",
"host": "localhost",
"port": sk_0.port.pg,
"http_port": sk_0.port.http,
"https_port": None,
"version": 5957,
"availability_zone_id": "eu-central-1a",
}
target.on_safekeeper_deploy(fake_id, body)
inserted = target.get_safekeeper(fake_id)
assert inserted is not None
assert target.get_safekeepers() == start_sks + [inserted]
assert eq_safekeeper_records(body, inserted)
def safekeeper_is_active():
newest_info = target.get_safekeeper(inserted["id"])
assert newest_info
assert newest_info["scheduling_policy"] == "Active"
wait_until(safekeeper_is_active)
target.safekeeper_scheduling_policy(inserted["id"], "Activating")
wait_until(safekeeper_is_active)
# Now decomission it
target.safekeeper_scheduling_policy(inserted["id"], "Decomissioned")
def eq_safekeeper_records(a: dict[str, Any], b: dict[str, Any]) -> bool:
compared = [dict(a), dict(b)]

View File

@@ -1889,6 +1889,31 @@ def test_timeline_detach_with_aux_files_with_detach_v1(
assert set(http.list_aux_files(env.initial_tenant, branch_timeline_id, lsn1).keys()) == set([])
def test_detach_ancestors_with_no_writes(
neon_env_builder: NeonEnvBuilder,
):
env = neon_env_builder.init_start()
endpoint = env.endpoints.create_start("main", tenant_id=env.initial_tenant)
wait_for_last_flush_lsn(env, endpoint, env.initial_tenant, env.initial_timeline)
endpoint.safe_psql(
"SELECT pg_create_logical_replication_slot('test_slot_parent_1', 'pgoutput')"
)
wait_for_last_flush_lsn(env, endpoint, env.initial_tenant, env.initial_timeline)
endpoint.stop()
for i in range(0, 5):
if i == 0:
ancestor_name = "main"
else:
ancestor_name = f"b{i}"
tlid = env.create_branch(f"b{i + 1}", ancestor_branch_name=ancestor_name)
client = env.pageserver.http_client()
client.detach_ancestor(tenant_id=env.initial_tenant, timeline_id=tlid)
# TODO:
# - branch near existing L1 boundary, image layers?
# - investigate: why are layers started at uneven lsn? not just after branching, but in general.

View File

@@ -2740,3 +2740,85 @@ def test_pull_timeline_partial_segment_integrity(neon_env_builder: NeonEnvBuilde
raise Exception("Uneviction did not happen on source safekeeper yet")
wait_until(unevicted)
def test_timeline_disk_usage_limit(neon_env_builder: NeonEnvBuilder):
"""
Test that the timeline disk usage circuit breaker works as expected. We test that:
1. The circuit breaker kicks in when the timeline's disk usage exceeds the configured limit,
and it causes writes to hang.
2. The hanging writes unblock when the issue resolves (by restarting the safekeeper in the
test to simulate a more realistic production troubleshooting scenario).
3. We can continue to write as normal after the issue resolves.
4. There is no data corruption throughout the test.
"""
# Set up environment with a very small disk usage limit (1KB)
neon_env_builder.num_safekeepers = 1
remote_storage_kind = s3_storage()
neon_env_builder.enable_safekeeper_remote_storage(remote_storage_kind)
# Set a very small disk usage limit (1KB)
neon_env_builder.safekeeper_extra_opts = ["--max-timeline-disk-usage-bytes=1024"]
env = neon_env_builder.init_start()
# Create a timeline and endpoint
env.create_branch("test_timeline_disk_usage_limit")
endpoint = env.endpoints.create_start("test_timeline_disk_usage_limit")
# Get the safekeeper
sk = env.safekeepers[0]
# Inject a failpoint to stop WAL backup
with sk.http_client() as http_cli:
http_cli.configure_failpoints([("backup-lsn-range-pausable", "pause")])
# Write some data that will exceed the 1KB limit. While the failpoint is active, this operation
# will hang as Postgres encounters safekeeper-returned errors and retries.
def run_hanging_insert():
with closing(endpoint.connect()) as bg_conn:
with bg_conn.cursor() as bg_cur:
# This should generate more than 1KB of WAL
bg_cur.execute("create table t(key int, value text)")
bg_cur.execute("insert into t select generate_series(1,2000), 'payload'")
# Start the inserts in a background thread
bg_thread = threading.Thread(target=run_hanging_insert)
bg_thread.start()
# Wait for the error message to appear in the compute log
def error_logged():
return endpoint.log_contains("WAL storage utilization exceeds configured limit") is not None
wait_until(error_logged)
log.info("Found expected error message in compute log, resuming.")
# Sanity check that the hanging insert is indeed still hanging. Otherwise means the circuit breaker we
# implemented didn't work as expected.
time.sleep(2)
assert bg_thread.is_alive(), (
"The hanging insert somehow unblocked without resolving the disk usage issue!"
)
log.info("Restarting the safekeeper to resume WAL backup.")
# Restart the safekeeper with defaults to both clear the failpoint and resume the larger disk usage limit.
for sk in env.safekeepers:
sk.stop().start(extra_opts=[])
# The hanging insert will now complete. Join the background thread so that we can
# verify that the insert completed successfully.
bg_thread.join(timeout=120)
assert not bg_thread.is_alive(), "Hanging insert did not complete after safekeeper restart"
log.info("Hanging insert unblocked.")
# Verify we can continue to write as normal
with closing(endpoint.connect()) as conn:
with conn.cursor() as cur:
cur.execute("insert into t select generate_series(2001,3000), 'payload'")
# Sanity check data correctness
with closing(endpoint.connect()) as conn:
with conn.cursor() as cur:
cur.execute("select count(*) from t")
# 2000 rows from first insert + 1000 from last insert
assert cur.fetchone() == (3000,)