Compare commits

..

23 Commits

Author SHA1 Message Date
Konstantin Knizhnik
3fe7f8802a Register reset_min_request_lsn using before_shmem_exit to make it be called before ProcKill which cleans MyProc 2025-07-28 21:59:35 +03:00
Konstantin Knizhnik
7841303df9 Cleanup perf_counters on backend exit 2025-07-28 21:04:15 +03:00
Konstantin Knizhnik
0e9d3fd2b1 Cleanup perf_counters on backend exit 2025-07-28 21:04:15 +03:00
Konstantin Knizhnik
b2a87b501f Fix cleanup of min_request_lsn on backend exit 2025-07-28 21:04:13 +03:00
Konstantin Knizhnik
06417f2ff9 Some minor refactoring addressing review comments 2025-07-28 21:03:50 +03:00
Konstantin Knizhnik
252876515c Rewrite min-request-lsn reset mechanism on backend exit 2025-07-28 21:03:49 +03:00
Konstantin Knizhnik
d56a72afec Reset backend's perf cpounters on exit 2025-07-28 21:03:09 +03:00
Konstantin Knizhnik
3847ab73a7 Cleanup perf_counters on backend exit 2025-07-28 21:02:14 +03:00
Kosntantin Knizhnik
3e5bbe7027 Address review comments 2025-07-28 21:02:14 +03:00
Kosntantin Knizhnik
546a45f57a Do flush only iof there are no in-flight prefetch requests 2025-07-28 21:02:14 +03:00
Kosntantin Knizhnik
588cb289d5 Do flush only iof there are no in-glight prefetch requests 2025-07-28 21:02:14 +03:00
Kosntantin Knizhnik
b947deb07c Flush requests in prefetch+_pump_state 2025-07-28 21:02:14 +03:00
Kosntantin Knizhnik
2f455baa73 Correctly handle communicator_reconfigure_timeout in case of replica promotion 2025-07-28 21:02:14 +03:00
Kosntantin Knizhnik
f2e65e1d2c Remove assert checks from communicator_reconfigure_timeout_if_needed 2025-07-28 21:02:14 +03:00
Kosntantin Knizhnik
df49def453 Add assert check thsat timeout is registered 2025-07-28 21:02:14 +03:00
Kosntantin Knizhnik
1ef0f71a95 Always configure prefetch timeout at replicas 2025-07-28 21:02:14 +03:00
Konstantin Knizhnik
96df649858 Update pgxn/neon/pagestore_smgr.c
Co-authored-by: Matthias van de Meent <matthias@neon.tech>
2025-07-28 21:02:14 +03:00
Kosntantin Knizhnik
9bfba1b087 Implement new apporach of calculating min in-flight LSN in prefetch_pump_state 2025-07-28 21:02:14 +03:00
Kosntantin Knizhnik
b41b85f8ec Perform page LSN check only for v3 version of protocol 2025-07-28 21:02:14 +03:00
Kosntantin Knizhnik
32b801ea1c Fix mistypings 2025-07-28 21:02:14 +03:00
Kosntantin Knizhnik
89496a32d0 Return end record LSN in log_newpages_copy 2025-07-28 21:02:14 +03:00
Kosntantin Knizhnik
9617b8d328 Update makefile 2025-07-28 21:02:14 +03:00
Kosntantin Knizhnik
a504516b8c Maintain min in-flight prefetch request LSN 2025-07-28 21:02:14 +03:00
77 changed files with 959 additions and 2413 deletions

View File

@@ -146,9 +146,7 @@ jobs:
with:
file: build-tools/Dockerfile
context: .
attests: |
type=provenance,mode=max
type=sbom,generator=docker.io/docker/buildkit-syft-scanner:1
provenance: false
push: true
pull: true
build-args: |

View File

@@ -634,9 +634,7 @@ jobs:
DEBIAN_VERSION=bookworm
secrets: |
SUBZERO_ACCESS_TOKEN=${{ secrets.CI_ACCESS_TOKEN }}
attests: |
type=provenance,mode=max
type=sbom,generator=docker.io/docker/buildkit-syft-scanner:1
provenance: false
push: true
pull: true
file: Dockerfile
@@ -749,9 +747,7 @@ jobs:
PG_VERSION=${{ matrix.version.pg }}
BUILD_TAG=${{ needs.meta.outputs.release-tag || needs.meta.outputs.build-tag }}
DEBIAN_VERSION=${{ matrix.version.debian }}
attests: |
type=provenance,mode=max
type=sbom,generator=docker.io/docker/buildkit-syft-scanner:1
provenance: false
push: true
pull: true
file: compute/compute-node.Dockerfile
@@ -770,9 +766,7 @@ jobs:
PG_VERSION=${{ matrix.version.pg }}
BUILD_TAG=${{ needs.meta.outputs.release-tag || needs.meta.outputs.build-tag }}
DEBIAN_VERSION=${{ matrix.version.debian }}
attests: |
type=provenance,mode=max
type=sbom,generator=docker.io/docker/buildkit-syft-scanner:1
provenance: false
push: true
pull: true
file: compute/compute-node.Dockerfile

View File

@@ -72,10 +72,9 @@ jobs:
options: --init --user root
services:
clickhouse:
image: clickhouse/clickhouse-server:25.6
image: clickhouse/clickhouse-server:24.8
env:
CLICKHOUSE_PASSWORD: ${{ needs.generate-ch-tmppw.outputs.tmp_val }}
PGSSLCERT: /tmp/postgresql.crt
ports:
- 9000:9000
- 8123:8123

3
Cargo.lock generated
View File

@@ -5077,6 +5077,8 @@ dependencies = [
"crc32c",
"criterion",
"env_logger",
"log",
"memoffset 0.9.0",
"once_cell",
"postgres",
"postgres_ffi_types",
@@ -5517,7 +5519,6 @@ dependencies = [
"workspace_hack",
"x509-cert",
"zerocopy 0.8.24",
"zeroize",
]
[[package]]

View File

@@ -135,6 +135,7 @@ lock_api = "0.4.13"
md5 = "0.7.0"
measured = { version = "0.0.22", features=["lasso"] }
measured-process = { version = "0.0.22" }
memoffset = "0.9"
moka = { version = "0.12", features = ["sync"] }
nix = { version = "0.30.1", features = ["dir", "fs", "mman", "process", "socket", "signal", "poll"] }
# Do not update to >= 7.0.0, at least. The update will have a significant impact
@@ -233,10 +234,9 @@ uuid = { version = "1.6.1", features = ["v4", "v7", "serde"] }
walkdir = "2.3.2"
rustls-native-certs = "0.8"
whoami = "1.5.1"
zerocopy = { version = "0.8", features = ["derive", "simd"] }
json-structural-diff = { version = "0.2.0" }
x509-cert = { version = "0.2.5" }
zerocopy = { version = "0.8", features = ["derive", "simd"] }
zeroize = "1.8"
## TODO replace this with tracing
env_logger = "0.11"

View File

@@ -103,7 +103,7 @@ RUN --mount=type=secret,uid=1000,id=SUBZERO_ACCESS_TOKEN \
&& if [ -s /run/secrets/SUBZERO_ACCESS_TOKEN ]; then \
export CARGO_FEATURES="rest_broker"; \
fi \
&& RUSTFLAGS="-Clinker=clang -Clink-arg=-fuse-ld=mold -Clink-arg=-Wl,--no-rosegment -Cforce-frame-pointers=yes ${ADDITIONAL_RUSTFLAGS}" cargo auditable build \
&& RUSTFLAGS="-Clinker=clang -Clink-arg=-fuse-ld=mold -Clink-arg=-Wl,--no-rosegment -Cforce-frame-pointers=yes ${ADDITIONAL_RUSTFLAGS}" cargo build \
--features $CARGO_FEATURES \
--bin pg_sni_router \
--bin pageserver \

View File

@@ -299,7 +299,6 @@ WORKDIR /home/nonroot
ENV RUSTC_VERSION=1.88.0
ENV RUSTUP_HOME="/home/nonroot/.rustup"
ENV PATH="/home/nonroot/.cargo/bin:${PATH}"
ARG CARGO_AUDITABLE_VERSION=0.7.0
ARG RUSTFILT_VERSION=0.2.1
ARG CARGO_HAKARI_VERSION=0.9.36
ARG CARGO_DENY_VERSION=0.18.2
@@ -315,16 +314,14 @@ RUN curl -sSO https://static.rust-lang.org/rustup/dist/$(uname -m)-unknown-linux
. "$HOME/.cargo/env" && \
cargo --version && rustup --version && \
rustup component add llvm-tools rustfmt clippy && \
cargo install cargo-auditable --locked --version "${CARGO_AUDITABLE_VERSION}" && \
cargo auditable install cargo-auditable --locked --version "${CARGO_AUDITABLE_VERSION}" --force && \
cargo auditable install rustfilt --version "${RUSTFILT_VERSION}" && \
cargo auditable install cargo-hakari --locked --version "${CARGO_HAKARI_VERSION}" && \
cargo auditable install cargo-deny --locked --version "${CARGO_DENY_VERSION}" && \
cargo auditable install cargo-hack --locked --version "${CARGO_HACK_VERSION}" && \
cargo auditable install cargo-nextest --locked --version "${CARGO_NEXTEST_VERSION}" && \
cargo auditable install cargo-chef --locked --version "${CARGO_CHEF_VERSION}" && \
cargo auditable install diesel_cli --locked --version "${CARGO_DIESEL_CLI_VERSION}" \
--features postgres-bundled --no-default-features && \
cargo install rustfilt --locked --version "${RUSTFILT_VERSION}" && \
cargo install cargo-hakari --locked --version "${CARGO_HAKARI_VERSION}" && \
cargo install cargo-deny --locked --version "${CARGO_DENY_VERSION}" && \
cargo install cargo-hack --locked --version "${CARGO_HACK_VERSION}" && \
cargo install cargo-nextest --locked --version "${CARGO_NEXTEST_VERSION}" && \
cargo install cargo-chef --locked --version "${CARGO_CHEF_VERSION}" && \
cargo install diesel_cli --locked --version "${CARGO_DIESEL_CLI_VERSION}" \
--features postgres-bundled --no-default-features && \
rm -rf /home/nonroot/.cargo/registry && \
rm -rf /home/nonroot/.cargo/git

View File

@@ -1,11 +1,5 @@
commit 5eb393810cf7c7bafa4e394dad2e349e2a8cb2cb
Author: Alexey Masterov <alexey.masterov@databricks.com>
Date: Mon Jul 28 18:11:02 2025 +0200
Patch for pg_repack
diff --git a/regress/Makefile b/regress/Makefile
index bf6edcb..110e734 100644
index bf6edcb..89b4c7f 100644
--- a/regress/Makefile
+++ b/regress/Makefile
@@ -17,7 +17,7 @@ INTVERSION := $(shell echo $$(($$(echo $(VERSION).0 | sed 's/\([[:digit:]]\{1,\}
@@ -13,36 +7,18 @@ index bf6edcb..110e734 100644
#
-REGRESS := init-extension repack-setup repack-run error-on-invalid-idx no-error-on-invalid-idx after-schema repack-check nosuper tablespace get_order_by trigger
+REGRESS := init-extension noautovacuum repack-setup repack-run error-on-invalid-idx no-error-on-invalid-idx after-schema repack-check nosuper get_order_by trigger autovacuum
+REGRESS := init-extension repack-setup repack-run error-on-invalid-idx no-error-on-invalid-idx after-schema repack-check nosuper get_order_by trigger
USE_PGXS = 1 # use pgxs if not in contrib directory
PGXS := $(shell $(PG_CONFIG) --pgxs)
diff --git a/regress/expected/autovacuum.out b/regress/expected/autovacuum.out
new file mode 100644
index 0000000..e7f2363
--- /dev/null
+++ b/regress/expected/autovacuum.out
@@ -0,0 +1,7 @@
+ALTER SYSTEM SET autovacuum='on';
+SELECT pg_reload_conf();
+ pg_reload_conf
+----------------
+ t
+(1 row)
+
diff --git a/regress/expected/noautovacuum.out b/regress/expected/noautovacuum.out
new file mode 100644
index 0000000..fc7978e
--- /dev/null
+++ b/regress/expected/noautovacuum.out
@@ -0,0 +1,7 @@
+ALTER SYSTEM SET autovacuum='off';
+SELECT pg_reload_conf();
+ pg_reload_conf
+----------------
+ t
+(1 row)
+
diff --git a/regress/expected/init-extension.out b/regress/expected/init-extension.out
index 9f2e171..f6e4f8d 100644
--- a/regress/expected/init-extension.out
+++ b/regress/expected/init-extension.out
@@ -1,3 +1,2 @@
SET client_min_messages = warning;
CREATE EXTENSION pg_repack;
-RESET client_min_messages;
diff --git a/regress/expected/nosuper.out b/regress/expected/nosuper.out
index 8d0a94e..63b68bf 100644
--- a/regress/expected/nosuper.out
@@ -74,22 +50,14 @@ index 8d0a94e..63b68bf 100644
INFO: repacking table "public.tbl_cluster"
ERROR: query failed: ERROR: current transaction is aborted, commands ignored until end of transaction block
DETAIL: query was: RESET lock_timeout
diff --git a/regress/sql/autovacuum.sql b/regress/sql/autovacuum.sql
new file mode 100644
index 0000000..a8eda63
--- /dev/null
+++ b/regress/sql/autovacuum.sql
@@ -0,0 +1,2 @@
+ALTER SYSTEM SET autovacuum='on';
+SELECT pg_reload_conf();
diff --git a/regress/sql/noautovacuum.sql b/regress/sql/noautovacuum.sql
new file mode 100644
index 0000000..13d4836
--- /dev/null
+++ b/regress/sql/noautovacuum.sql
@@ -0,0 +1,2 @@
+ALTER SYSTEM SET autovacuum='off';
+SELECT pg_reload_conf();
diff --git a/regress/sql/init-extension.sql b/regress/sql/init-extension.sql
index 9f2e171..f6e4f8d 100644
--- a/regress/sql/init-extension.sql
+++ b/regress/sql/init-extension.sql
@@ -1,3 +1,2 @@
SET client_min_messages = warning;
CREATE EXTENSION pg_repack;
-RESET client_min_messages;
diff --git a/regress/sql/nosuper.sql b/regress/sql/nosuper.sql
index 072f0fa..dbe60f8 100644
--- a/regress/sql/nosuper.sql

View File

@@ -82,15 +82,6 @@ struct Cli {
#[arg(long, default_value_t = 3081)]
pub internal_http_port: u16,
/// Backwards-compatible --http-port for Hadron deployments. Functionally the
/// same as --external-http-port.
#[arg(
long,
conflicts_with = "external_http_port",
conflicts_with = "internal_http_port"
)]
pub http_port: Option<u16>,
#[arg(short = 'D', long, value_name = "DATADIR")]
pub pgdata: String,
@@ -190,26 +181,6 @@ impl Cli {
}
}
// Hadron helpers to get compatible compute_ctl http ports from Cli. The old `--http-port`
// arg is used and acts the same as `--external-http-port`. The internal http port is defined
// to be http_port + 1. Hadron runs in the dblet environment which uses the host network, so
// we need to be careful with the ports to choose.
fn get_external_http_port(cli: &Cli) -> u16 {
if cli.lakebase_mode {
return cli.http_port.unwrap_or(cli.external_http_port);
}
cli.external_http_port
}
fn get_internal_http_port(cli: &Cli) -> u16 {
if cli.lakebase_mode {
return cli
.http_port
.map(|p| p + 1)
.unwrap_or(cli.internal_http_port);
}
cli.internal_http_port
}
fn main() -> Result<()> {
let cli = Cli::parse();
@@ -234,18 +205,13 @@ fn main() -> Result<()> {
// enable core dumping for all child processes
setrlimit(Resource::CORE, rlimit::INFINITY, rlimit::INFINITY)?;
if cli.lakebase_mode {
installed_extensions::initialize_metrics();
hadron_metrics::initialize_metrics();
}
installed_extensions::initialize_metrics();
hadron_metrics::initialize_metrics();
let connstr = Url::parse(&cli.connstr).context("cannot parse connstr as a URL")?;
let config = get_config(&cli)?;
let external_http_port = get_external_http_port(&cli);
let internal_http_port = get_internal_http_port(&cli);
let compute_node = ComputeNode::new(
ComputeNodeParams {
compute_id: cli.compute_id,
@@ -254,8 +220,8 @@ fn main() -> Result<()> {
pgdata: cli.pgdata.clone(),
pgbin: cli.pgbin.clone(),
pgversion: get_pg_version_string(&cli.pgbin),
external_http_port,
internal_http_port,
external_http_port: cli.external_http_port,
internal_http_port: cli.internal_http_port,
remote_ext_base_url: cli.remote_ext_base_url.clone(),
resize_swap_on_bind: cli.resize_swap_on_bind,
set_disk_quota_for_fs: cli.set_disk_quota_for_fs,

View File

@@ -6,8 +6,7 @@ use compute_api::responses::{
LfcPrewarmState, PromoteState, TlsConfig,
};
use compute_api::spec::{
ComputeAudit, ComputeFeature, ComputeMode, ComputeSpec, ExtVersion, GenericOption,
PageserverProtocol, PgIdent, Role,
ComputeAudit, ComputeFeature, ComputeMode, ComputeSpec, ExtVersion, PageserverProtocol, PgIdent,
};
use futures::StreamExt;
use futures::future::join_all;
@@ -414,66 +413,6 @@ struct StartVmMonitorResult {
vm_monitor: Option<JoinHandle<Result<()>>>,
}
// BEGIN_HADRON
/// This function creates roles that are used by Databricks.
/// These roles are not needs to be botostrapped at PG Compute provisioning time.
/// The auth method for these roles are configured in databricks_pg_hba.conf in universe repository.
pub(crate) fn create_databricks_roles() -> Vec<String> {
let roles = vec![
// Role for prometheus_stats_exporter
Role {
name: "databricks_monitor".to_string(),
// This uses "local" connection and auth method for that is "trust", so no password is needed.
encrypted_password: None,
options: Some(vec![GenericOption {
name: "IN ROLE pg_monitor".to_string(),
value: None,
vartype: "string".to_string(),
}]),
},
// Role for brickstore control plane
Role {
name: "databricks_control_plane".to_string(),
// Certificate user does not need password.
encrypted_password: None,
options: Some(vec![GenericOption {
name: "SUPERUSER".to_string(),
value: None,
vartype: "string".to_string(),
}]),
},
// Role for brickstore httpgateway.
Role {
name: "databricks_gateway".to_string(),
// Certificate user does not need password.
encrypted_password: None,
options: None,
},
];
roles
.into_iter()
.map(|role| {
let query = format!(
r#"
DO $$
BEGIN
IF NOT EXISTS (
SELECT FROM pg_catalog.pg_roles WHERE rolname = '{}')
THEN
CREATE ROLE {} {};
END IF;
END
$$;"#,
role.name,
role.name.pg_quote(),
role.to_pg_options(),
);
query
})
.collect()
}
/// Databricks-specific environment variables to be passed to the `postgres` sub-process.
pub struct DatabricksEnvVars {
/// The Databricks "endpoint ID" of the compute instance. Used by `postgres` to check
@@ -482,27 +421,14 @@ pub struct DatabricksEnvVars {
/// Hostname of the Databricks workspace URL this compute instance belongs to.
/// Used by postgres to verify Databricks PAT tokens.
pub workspace_host: String,
pub lakebase_mode: bool,
}
impl DatabricksEnvVars {
pub fn new(
compute_spec: &ComputeSpec,
compute_id: Option<&String>,
instance_id: Option<String>,
lakebase_mode: bool,
) -> Self {
let endpoint_id = if let Some(instance_id) = instance_id {
// Use instance_id as endpoint_id if it is set. This code path is for PuPr model.
instance_id
} else {
// Use compute_id as endpoint_id if instance_id is not set. The code path is for PrPr model.
// compute_id is a string format of "{endpoint_id}/{compute_idx}"
// endpoint_id is a uuid. We only need to pass down endpoint_id to postgres.
// Panics if compute_id is not set or not in the expected format.
compute_id.unwrap().split('/').next().unwrap().to_string()
};
pub fn new(compute_spec: &ComputeSpec, compute_id: Option<&String>) -> Self {
// compute_id is a string format of "{endpoint_id}/{compute_idx}"
// endpoint_id is a uuid. We only need to pass down endpoint_id to postgres.
// Panics if compute_id is not set or not in the expected format.
let endpoint_id = compute_id.unwrap().split('/').next().unwrap().to_string();
let workspace_host = compute_spec
.databricks_settings
.as_ref()
@@ -511,7 +437,6 @@ impl DatabricksEnvVars {
Self {
endpoint_id,
workspace_host,
lakebase_mode,
}
}
@@ -521,10 +446,6 @@ impl DatabricksEnvVars {
/// Convert DatabricksEnvVars to a list of string pairs that can be passed as env vars. Consumes `self`.
pub fn to_env_var_list(self) -> Vec<(String, String)> {
if !self.lakebase_mode {
// In neon env, we don't need to pass down the env vars to postgres.
return vec![];
}
vec![
(
Self::DATABRICKS_ENDPOINT_ID_ENVVAR.to_string(),
@@ -574,11 +495,7 @@ impl ComputeNode {
let mut new_state = ComputeState::new();
if let Some(spec) = config.spec {
let pspec = ParsedSpec::try_from(spec).map_err(|msg| anyhow::anyhow!(msg))?;
if params.lakebase_mode {
ComputeNode::set_spec(&params, &mut new_state, pspec);
} else {
new_state.pspec = Some(pspec);
}
new_state.pspec = Some(pspec);
}
Ok(ComputeNode {
@@ -1176,14 +1093,7 @@ impl ComputeNode {
// If it is something different then create_dir() will error out anyway.
let pgdata = &self.params.pgdata;
let _ok = fs::remove_dir_all(pgdata);
if self.params.lakebase_mode {
// Ignore creation errors if the directory already exists (e.g. mounting it ahead of time).
// If it is something different then PG startup will error out anyway.
let _ok = fs::create_dir(pgdata);
} else {
fs::create_dir(pgdata)?;
}
fs::create_dir(pgdata)?;
fs::set_permissions(pgdata, fs::Permissions::from_mode(0o700))?;
Ok(())
@@ -1662,7 +1572,7 @@ impl ComputeNode {
// symlink doesn't affect anything.
//
// See https://github.com/neondatabase/autoscaling/issues/800
std::fs::remove_dir_all(pgdata_path.join("pg_dynshmem"))?;
std::fs::remove_dir(pgdata_path.join("pg_dynshmem"))?;
symlink("/dev/shm/", pgdata_path.join("pg_dynshmem"))?;
match spec.mode {
@@ -1677,12 +1587,6 @@ impl ComputeNode {
/// Start and stop a postgres process to warm up the VM for startup.
pub fn prewarm_postgres_vm_memory(&self) -> Result<()> {
if self.params.lakebase_mode {
// We are running in Hadron mode. Disabling this prewarming step for now as it could run
// into dblet port conflicts and also doesn't add much value with our current infra.
info!("Skipping postgres prewarming in Hadron mode");
return Ok(());
}
info!("prewarming VM memory");
// Create pgdata
@@ -1744,12 +1648,7 @@ impl ComputeNode {
let databricks_env_vars = {
let state = self.state.lock().unwrap();
let spec = &state.pspec.as_ref().unwrap().spec;
DatabricksEnvVars::new(
spec,
Some(&self.params.compute_id),
self.params.instance_id.clone(),
self.params.lakebase_mode,
)
DatabricksEnvVars::new(spec, Some(&self.params.compute_id))
};
info!(
@@ -1921,15 +1820,7 @@ impl ComputeNode {
/// Do initial configuration of the already started Postgres.
#[instrument(skip_all)]
pub fn apply_config(&self, compute_state: &ComputeState) -> Result<()> {
let mut conf = self.get_tokio_conn_conf(Some("compute_ctl:apply_config"));
if self.params.lakebase_mode {
// Set a 2-minute statement_timeout for the session applying config. The individual SQL statements
// used in apply_spec_sql() should not take long (they are just creating users and installing
// extensions). If any of them are stuck for an extended period of time it usually indicates a
// pageserver connectivity problem and we should bail out.
conf.options("-c statement_timeout=2min");
}
let conf = self.get_tokio_conn_conf(Some("compute_ctl:apply_config"));
let conf = Arc::new(conf);
let spec = Arc::new(
@@ -2247,17 +2138,7 @@ impl ComputeNode {
pub fn check_for_core_dumps(&self) -> Result<()> {
let core_dump_dir = match std::env::consts::OS {
"macos" => Path::new("/cores/"),
// BEGIN HADRON
// NB: Read core dump files from a fixed location outside of
// the data directory since `compute_ctl` wipes the data directory
// across container restarts.
_ => {
if self.params.lakebase_mode {
Path::new("/databricks/logs/brickstore")
} else {
Path::new(&self.params.pgdata)
}
} // END HADRON
_ => Path::new(&self.params.pgdata),
};
// Collect core dump paths if any
@@ -2570,7 +2451,7 @@ LIMIT 100",
if let Some(libs) = spec.cluster.settings.find("shared_preload_libraries") {
libs_vec = libs
.split(&[',', '\'', ' '])
.filter(|s| *s != "neon" && *s != "databricks_auth" && !s.is_empty())
.filter(|s| *s != "neon" && !s.is_empty())
.map(str::to_string)
.collect();
}
@@ -2589,7 +2470,7 @@ LIMIT 100",
if let Some(libs) = shared_preload_libraries_line.split("='").nth(1) {
preload_libs_vec = libs
.split(&[',', '\'', ' '])
.filter(|s| *s != "neon" && *s != "databricks_auth" && !s.is_empty())
.filter(|s| *s != "neon" && !s.is_empty())
.map(str::to_string)
.collect();
}

View File

@@ -13,19 +13,17 @@ use tokio_postgres::Client;
use tokio_postgres::error::SqlState;
use tracing::{Instrument, debug, error, info, info_span, instrument, warn};
use crate::compute::{ComputeNode, ComputeNodeParams, ComputeState, create_databricks_roles};
use crate::hadron_metrics::COMPUTE_CONFIGURE_STATEMENT_TIMEOUT_ERRORS;
use crate::compute::{ComputeNode, ComputeNodeParams, ComputeState};
use crate::pg_helpers::{
DatabaseExt, Escaping, GenericOptionsSearch, RoleExt, get_existing_dbs_async,
get_existing_roles_async,
};
use crate::spec_apply::ApplySpecPhase::{
AddDatabricksGrants, AlterDatabricksRoles, CreateAndAlterDatabases, CreateAndAlterRoles,
CreateAvailabilityCheck, CreateDatabricksMisc, CreateDatabricksRoles, CreatePgauditExtension,
CreateAndAlterDatabases, CreateAndAlterRoles, CreateAvailabilityCheck, CreatePgauditExtension,
CreatePgauditlogtofileExtension, CreatePrivilegedRole, CreateSchemaNeon,
DisablePostgresDBPgAudit, DropInvalidDatabases, DropRoles, FinalizeDropLogicalSubscriptions,
HandleDatabricksAuthExtension, HandleNeonExtension, HandleOtherExtensions,
RenameAndDeleteDatabases, RenameRoles, RunInEachDatabase,
HandleNeonExtension, HandleOtherExtensions, RenameAndDeleteDatabases, RenameRoles,
RunInEachDatabase,
};
use crate::spec_apply::PerDatabasePhase::{
ChangeSchemaPerms, DeleteDBRoleReferences, DropLogicalSubscriptions,
@@ -168,7 +166,6 @@ impl ComputeNode {
concurrency_token.clone(),
db,
[DropLogicalSubscriptions].to_vec(),
self.params.lakebase_mode,
);
Ok(tokio::spawn(fut))
@@ -189,33 +186,15 @@ impl ComputeNode {
};
}
let phases = if self.params.lakebase_mode {
vec![
CreatePrivilegedRole,
// BEGIN_HADRON
CreateDatabricksRoles,
AlterDatabricksRoles,
// END_HADRON
for phase in [
CreatePrivilegedRole,
DropInvalidDatabases,
RenameRoles,
CreateAndAlterRoles,
RenameAndDeleteDatabases,
CreateAndAlterDatabases,
CreateSchemaNeon,
]
} else {
vec![
CreatePrivilegedRole,
DropInvalidDatabases,
RenameRoles,
CreateAndAlterRoles,
RenameAndDeleteDatabases,
CreateAndAlterDatabases,
CreateSchemaNeon,
]
};
for phase in phases {
] {
info!("Applying phase {:?}", &phase);
apply_operations(
params.clone(),
@@ -224,7 +203,6 @@ impl ComputeNode {
jwks_roles.clone(),
phase,
|| async { Ok(&client) },
self.params.lakebase_mode,
)
.await?;
}
@@ -276,7 +254,6 @@ impl ComputeNode {
concurrency_token.clone(),
db,
phases,
self.params.lakebase_mode,
);
Ok(tokio::spawn(fut))
@@ -288,28 +265,12 @@ impl ComputeNode {
handle.await??;
}
let mut phases = if self.params.lakebase_mode {
vec![
HandleOtherExtensions,
HandleNeonExtension, // This step depends on CreateSchemaNeon
// BEGIN_HADRON
HandleDatabricksAuthExtension,
// END_HADRON
CreateAvailabilityCheck,
DropRoles,
// BEGIN_HADRON
AddDatabricksGrants,
CreateDatabricksMisc,
// END_HADRON
]
} else {
vec![
let mut phases = vec![
HandleOtherExtensions,
HandleNeonExtension, // This step depends on CreateSchemaNeon
CreateAvailabilityCheck,
DropRoles,
]
};
];
// This step depends on CreateSchemaNeon
if spec.drop_subscriptions_before_start && !drop_subscriptions_done {
@@ -342,7 +303,6 @@ impl ComputeNode {
jwks_roles.clone(),
phase,
|| async { Ok(&client) },
self.params.lakebase_mode,
)
.await?;
}
@@ -368,7 +328,6 @@ impl ComputeNode {
concurrency_token: Arc<tokio::sync::Semaphore>,
db: DB,
subphases: Vec<PerDatabasePhase>,
lakebase_mode: bool,
) -> Result<()> {
let _permit = concurrency_token.acquire().await?;
@@ -396,7 +355,6 @@ impl ComputeNode {
let client = client_conn.as_ref().unwrap();
Ok(client)
},
lakebase_mode,
)
.await?;
}
@@ -519,10 +477,6 @@ pub enum PerDatabasePhase {
#[derive(Clone, Debug)]
pub enum ApplySpecPhase {
CreatePrivilegedRole,
// BEGIN_HADRON
CreateDatabricksRoles,
AlterDatabricksRoles,
// END_HADRON
DropInvalidDatabases,
RenameRoles,
CreateAndAlterRoles,
@@ -535,14 +489,7 @@ pub enum ApplySpecPhase {
DisablePostgresDBPgAudit,
HandleOtherExtensions,
HandleNeonExtension,
// BEGIN_HADRON
HandleDatabricksAuthExtension,
// END_HADRON
CreateAvailabilityCheck,
// BEGIN_HADRON
AddDatabricksGrants,
CreateDatabricksMisc,
// END_HADRON
DropRoles,
FinalizeDropLogicalSubscriptions,
}
@@ -578,7 +525,6 @@ pub async fn apply_operations<'a, Fut, F>(
jwks_roles: Arc<HashSet<String>>,
apply_spec_phase: ApplySpecPhase,
client: F,
lakebase_mode: bool,
) -> Result<()>
where
F: FnOnce() -> Fut,
@@ -625,23 +571,6 @@ where
},
query
);
if !lakebase_mode {
return res;
}
// BEGIN HADRON
if let Err(e) = res.as_ref() {
if let Some(sql_state) = e.code() {
if sql_state.code() == "57014" {
// SQL State 57014 (ERRCODE_QUERY_CANCELED) is used for statement timeouts.
// Increment the counter whenever a statement timeout occurs. Timeouts on
// this configuration path can only occur due to PS connectivity problems that
// Postgres failed to recover from.
COMPUTE_CONFIGURE_STATEMENT_TIMEOUT_ERRORS.inc();
}
}
}
// END HADRON
res
}
.instrument(inspan)
@@ -683,35 +612,6 @@ async fn get_operations<'a>(
),
comment: None,
}))),
// BEGIN_HADRON
// New Hadron phase
ApplySpecPhase::CreateDatabricksRoles => {
let queries = create_databricks_roles();
let operations = queries.into_iter().map(|query| Operation {
query,
comment: None,
});
Ok(Box::new(operations))
}
// Backfill existing databricks_reader_* roles with statement timeout from GUC
ApplySpecPhase::AlterDatabricksRoles => {
let query = String::from(include_str!(
"sql/alter_databricks_reader_roles_timeout.sql"
));
let operations = once(Operation {
query,
comment: Some(
"Backfill existing databricks_reader_* roles with statement timeout"
.to_string(),
),
});
Ok(Box::new(operations))
}
// End of new Hadron Phase
// END_HADRON
ApplySpecPhase::DropInvalidDatabases => {
let mut ctx = ctx.write().await;
let databases = &mut ctx.dbs;
@@ -1081,10 +981,7 @@ async fn get_operations<'a>(
// N.B. this has to be properly dollar-escaped with `pg_quote_dollar()`
role_name = escaped_role,
outer_tag = outer_tag,
)
// HADRON change:
.replace("neon_superuser", &params.privileged_role_name),
// HADRON change end ,
),
comment: None,
},
// This now will only drop privileges of the role
@@ -1120,8 +1017,7 @@ async fn get_operations<'a>(
comment: None,
},
Operation {
query: String::from(include_str!("sql/default_grants.sql"))
.replace("neon_superuser", &params.privileged_role_name),
query: String::from(include_str!("sql/default_grants.sql")),
comment: None,
},
]
@@ -1190,28 +1086,6 @@ async fn get_operations<'a>(
Ok(Box::new(operations))
}
// BEGIN_HADRON
// Note: we may want to version the extension someday, but for now we just drop it and recreate it.
ApplySpecPhase::HandleDatabricksAuthExtension => {
let operations = vec![
Operation {
query: String::from("DROP EXTENSION IF EXISTS databricks_auth"),
comment: Some(String::from("dropping existing databricks_auth extension")),
},
Operation {
query: String::from("CREATE EXTENSION databricks_auth"),
comment: Some(String::from("creating databricks_auth extension")),
},
Operation {
query: String::from("GRANT SELECT ON databricks_auth_metrics TO pg_monitor"),
comment: Some(String::from("grant select on databricks auth counters")),
},
]
.into_iter();
Ok(Box::new(operations))
}
// END_HADRON
ApplySpecPhase::CreateAvailabilityCheck => Ok(Box::new(once(Operation {
query: String::from(include_str!("sql/add_availabilitycheck_tables.sql")),
comment: None,
@@ -1229,63 +1103,6 @@ async fn get_operations<'a>(
Ok(Box::new(operations))
}
// BEGIN_HADRON
// New Hadron phases
//
// Grants permissions to roles that are used by Databricks.
ApplySpecPhase::AddDatabricksGrants => {
let operations = vec![
Operation {
query: String::from("GRANT USAGE ON SCHEMA neon TO databricks_monitor"),
comment: Some(String::from(
"Permissions needed to execute neon.* functions (in the postgres database)",
)),
},
Operation {
query: String::from(
"GRANT SELECT, INSERT, UPDATE ON health_check TO databricks_monitor",
),
comment: Some(String::from("Permissions needed for read and write probes")),
},
Operation {
query: String::from(
"GRANT EXECUTE ON FUNCTION pg_ls_dir(text) TO databricks_monitor",
),
comment: Some(String::from(
"Permissions needed to monitor .snap file counts",
)),
},
Operation {
query: String::from(
"GRANT SELECT ON neon.neon_perf_counters TO databricks_monitor",
),
comment: Some(String::from(
"Permissions needed to access neon performance counters view",
)),
},
Operation {
query: String::from(
"GRANT EXECUTE ON FUNCTION neon.get_perf_counters() TO databricks_monitor",
),
comment: Some(String::from(
"Permissions needed to execute the underlying performance counters function",
)),
},
]
.into_iter();
Ok(Box::new(operations))
}
// Creates minor objects that are used by Databricks.
ApplySpecPhase::CreateDatabricksMisc => Ok(Box::new(once(Operation {
query: String::from(include_str!("sql/create_databricks_misc.sql")),
comment: Some(String::from(
"The function databricks_monitor uses to convert exception to 0 or 1",
)),
}))),
// End of new Hadron phases
// END_HADRON
ApplySpecPhase::FinalizeDropLogicalSubscriptions => Ok(Box::new(once(Operation {
query: String::from(include_str!("sql/finalize_drop_subscriptions.sql")),
comment: None,

View File

@@ -1,25 +0,0 @@
DO $$
DECLARE
reader_role RECORD;
timeout_value TEXT;
BEGIN
-- Get the current GUC setting for reader statement timeout
SELECT current_setting('databricks.reader_statement_timeout', true) INTO timeout_value;
-- Only proceed if timeout_value is not null/empty and not '0' (disabled)
IF timeout_value IS NOT NULL AND timeout_value != '' AND timeout_value != '0' THEN
-- Find all databricks_reader_* roles and update their statement_timeout
FOR reader_role IN
SELECT r.rolname
FROM pg_roles r
WHERE r.rolname ~ '^databricks_reader_\d+$'
LOOP
-- Apply the timeout setting to the role (will overwrite existing setting)
EXECUTE format('ALTER ROLE %I SET statement_timeout = %L',
reader_role.rolname, timeout_value);
RAISE LOG 'Updated statement_timeout = % for role %', timeout_value, reader_role.rolname;
END LOOP;
END IF;
END
$$;

View File

@@ -1,15 +0,0 @@
ALTER ROLE databricks_monitor SET statement_timeout = '60s';
CREATE OR REPLACE FUNCTION health_check_write_succeeds()
RETURNS INTEGER AS $$
BEGIN
INSERT INTO health_check VALUES (1, now())
ON CONFLICT (id) DO UPDATE
SET updated_at = now();
RETURN 1;
EXCEPTION WHEN OTHERS THEN
RAISE EXCEPTION '[DATABRICKS_SMGR] health_check failed: [%] %', SQLSTATE, SQLERRM;
RETURN 0;
END;
$$ LANGUAGE plpgsql;

View File

@@ -558,11 +558,11 @@ async fn add_request_id_header_to_response(
mut res: Response<Body>,
req_info: RequestInfo,
) -> Result<Response<Body>, ApiError> {
if let Some(request_id) = req_info.context::<RequestId>()
&& let Ok(request_header_value) = HeaderValue::from_str(&request_id.0)
{
res.headers_mut()
.insert(&X_REQUEST_ID_HEADER, request_header_value);
if let Some(request_id) = req_info.context::<RequestId>() {
if let Ok(request_header_value) = HeaderValue::from_str(&request_id.0) {
res.headers_mut()
.insert(&X_REQUEST_ID_HEADER, request_header_value);
};
};
Ok(res)

View File

@@ -72,10 +72,10 @@ impl Server {
if err.is_incomplete_message() || err.is_closed() || err.is_timeout() {
return true;
}
if let Some(inner) = err.source()
&& let Some(io) = inner.downcast_ref::<std::io::Error>()
{
return suppress_io_error(io);
if let Some(inner) = err.source() {
if let Some(io) = inner.downcast_ref::<std::io::Error>() {
return suppress_io_error(io);
}
}
false
}

View File

@@ -363,7 +363,7 @@ where
// TODO: An Iterator might be nicer. The communicator's clock algorithm needs to
// _slowly_ iterate through all buckets with its clock hand, without holding a lock.
// If we switch to an Iterator, it must not hold the lock.
pub fn get_at_bucket(&self, pos: usize) -> Option<ValueReadGuard<'_, (K, V)>> {
pub fn get_at_bucket(&self, pos: usize) -> Option<ValueReadGuard<(K, V)>> {
let map = unsafe { self.shared_ptr.as_ref() }.unwrap().read();
if pos >= map.buckets.len() {
return None;

View File

@@ -9,7 +9,10 @@ regex.workspace = true
bytes.workspace = true
anyhow.workspace = true
crc32c.workspace = true
criterion.workspace = true
once_cell.workspace = true
log.workspace = true
memoffset.workspace = true
pprof.workspace = true
thiserror.workspace = true
serde.workspace = true
@@ -19,7 +22,6 @@ tracing.workspace = true
postgres_versioninfo.workspace = true
[dev-dependencies]
criterion.workspace = true
env_logger.workspace = true
postgres.workspace = true

View File

@@ -34,8 +34,9 @@ const SIZEOF_CONTROLDATA: usize = size_of::<ControlFileData>();
impl ControlFileData {
/// Compute the offset of the `crc` field within the `ControlFileData` struct.
/// Equivalent to offsetof(ControlFileData, crc) in C.
const fn pg_control_crc_offset() -> usize {
std::mem::offset_of!(ControlFileData, crc)
// Someday this can be const when the right compiler features land.
fn pg_control_crc_offset() -> usize {
memoffset::offset_of!(ControlFileData, crc)
}
///

View File

@@ -4,11 +4,12 @@
use crate::pg_constants;
use crate::transaction_id_precedes;
use bytes::BytesMut;
use log::*;
use super::bindings::MultiXactId;
pub fn transaction_id_set_status(xid: u32, status: u8, page: &mut BytesMut) {
tracing::trace!(
trace!(
"handle_apply_request for RM_XACT_ID-{} (1-commit, 2-abort, 3-sub_commit)",
status
);

View File

@@ -14,6 +14,7 @@ use super::xlog_utils::*;
use crate::WAL_SEGMENT_SIZE;
use bytes::{Buf, BufMut, Bytes, BytesMut};
use crc32c::*;
use log::*;
use std::cmp::min;
use std::num::NonZeroU32;
use utils::lsn::Lsn;
@@ -235,7 +236,7 @@ impl WalStreamDecoderHandler for WalStreamDecoder {
// XLOG_SWITCH records are special. If we see one, we need to skip
// to the next WAL segment.
let next_lsn = if xlogrec.is_xlog_switch_record() {
tracing::trace!("saw xlog switch record at {}", self.lsn);
trace!("saw xlog switch record at {}", self.lsn);
self.lsn + self.lsn.calc_padding(WAL_SEGMENT_SIZE as u64)
} else {
// Pad to an 8-byte boundary

View File

@@ -23,6 +23,8 @@ use crate::{WAL_SEGMENT_SIZE, XLOG_BLCKSZ};
use bytes::BytesMut;
use bytes::{Buf, Bytes};
use log::*;
use serde::Serialize;
use std::ffi::{CString, OsStr};
use std::fs::File;
@@ -233,7 +235,7 @@ pub fn find_end_of_wal(
let mut curr_lsn = start_lsn;
let mut buf = [0u8; XLOG_BLCKSZ];
let pg_version = MY_PGVERSION;
tracing::debug!("find_end_of_wal PG_VERSION: {}", pg_version);
debug!("find_end_of_wal PG_VERSION: {}", pg_version);
let mut decoder = WalStreamDecoder::new(start_lsn, pg_version);
@@ -245,7 +247,7 @@ pub fn find_end_of_wal(
match open_wal_segment(&seg_file_path)? {
None => {
// no more segments
tracing::debug!(
debug!(
"find_end_of_wal reached end at {:?}, segment {:?} doesn't exist",
result, seg_file_path
);
@@ -258,7 +260,7 @@ pub fn find_end_of_wal(
while curr_lsn.segment_number(wal_seg_size) == segno {
let bytes_read = segment.read(&mut buf)?;
if bytes_read == 0 {
tracing::debug!(
debug!(
"find_end_of_wal reached end at {:?}, EOF in segment {:?} at offset {}",
result,
seg_file_path,
@@ -274,7 +276,7 @@ pub fn find_end_of_wal(
match decoder.poll_decode() {
Ok(Some(record)) => result = record.0,
Err(e) => {
tracing::debug!(
debug!(
"find_end_of_wal reached end at {:?}, decode error: {:?}",
result, e
);

View File

@@ -9,7 +9,7 @@ use postgres_protocol2::message::backend::{ErrorFields, ErrorResponseBody};
pub use self::sqlstate::*;
#[allow(clippy::unreadable_literal)]
pub mod sqlstate;
mod sqlstate;
/// The severity of a Postgres error or notice.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]

View File

@@ -49,7 +49,7 @@ impl PerfSpan {
}
}
pub fn enter(&self) -> PerfSpanEntered<'_> {
pub fn enter(&self) -> PerfSpanEntered {
if let Some(ref id) = self.inner.id() {
self.dispatch.enter(id);
}

View File

@@ -14,9 +14,9 @@ use utils::logging::warn_slow;
use crate::pool::{ChannelPool, ClientGuard, ClientPool, StreamGuard, StreamPool};
use crate::retry::Retry;
use crate::split::GetPageSplitter;
use compute_api::spec::PageserverProtocol;
use pageserver_page_api as page_api;
use pageserver_page_api::GetPageSplitter;
use utils::id::{TenantId, TimelineId};
use utils::shard::{ShardCount, ShardIndex, ShardNumber, ShardStripeSize};

View File

@@ -1,5 +1,6 @@
mod client;
mod pool;
mod retry;
mod split;
pub use client::{PageserverClient, ShardSpec};

View File

@@ -3,18 +3,18 @@ use std::collections::HashMap;
use anyhow::anyhow;
use bytes::Bytes;
use crate::model::*;
use pageserver_api::key::rel_block_to_key;
use pageserver_api::shard::key_to_shard_number;
use pageserver_page_api as page_api;
use utils::shard::{ShardCount, ShardIndex, ShardStripeSize};
/// Splits GetPageRequests that straddle shard boundaries and assembles the responses.
/// TODO: add tests for this.
pub struct GetPageSplitter {
/// Split requests by shard index.
requests: HashMap<ShardIndex, GetPageRequest>,
requests: HashMap<ShardIndex, page_api::GetPageRequest>,
/// The response being assembled. Preallocated with empty pages, to be filled in.
response: GetPageResponse,
response: page_api::GetPageResponse,
/// Maps the offset in `request.block_numbers` and `response.pages` to the owning shard. Used
/// to assemble the response pages in the same order as the original request.
block_shards: Vec<ShardIndex>,
@@ -24,7 +24,7 @@ impl GetPageSplitter {
/// Checks if the given request only touches a single shard, and returns the shard ID. This is
/// the common case, so we check first in order to avoid unnecessary allocations and overhead.
pub fn for_single_shard(
req: &GetPageRequest,
req: &page_api::GetPageRequest,
count: ShardCount,
stripe_size: Option<ShardStripeSize>,
) -> anyhow::Result<Option<ShardIndex>> {
@@ -57,7 +57,7 @@ impl GetPageSplitter {
/// Splits the given request.
pub fn split(
req: GetPageRequest,
req: page_api::GetPageRequest,
count: ShardCount,
stripe_size: Option<ShardStripeSize>,
) -> anyhow::Result<Self> {
@@ -84,7 +84,7 @@ impl GetPageSplitter {
requests
.entry(shard_id)
.or_insert_with(|| GetPageRequest {
.or_insert_with(|| page_api::GetPageRequest {
request_id: req.request_id,
request_class: req.request_class,
rel: req.rel,
@@ -98,16 +98,16 @@ impl GetPageSplitter {
// Construct a response to be populated by shard responses. Preallocate empty page slots
// with the expected block numbers.
let response = GetPageResponse {
let response = page_api::GetPageResponse {
request_id: req.request_id,
status_code: GetPageStatusCode::Ok,
status_code: page_api::GetPageStatusCode::Ok,
reason: None,
rel: req.rel,
pages: req
.block_numbers
.into_iter()
.map(|block_number| {
Page {
page_api::Page {
block_number,
image: Bytes::new(), // empty page slot to be filled in
}
@@ -123,7 +123,9 @@ impl GetPageSplitter {
}
/// Drains the per-shard requests, moving them out of the splitter to avoid extra allocations.
pub fn drain_requests(&mut self) -> impl Iterator<Item = (ShardIndex, GetPageRequest)> {
pub fn drain_requests(
&mut self,
) -> impl Iterator<Item = (ShardIndex, page_api::GetPageRequest)> {
self.requests.drain()
}
@@ -133,10 +135,10 @@ impl GetPageSplitter {
pub fn add_response(
&mut self,
shard_id: ShardIndex,
response: GetPageResponse,
response: page_api::GetPageResponse,
) -> anyhow::Result<()> {
// The caller should already have converted status codes into tonic::Status.
if response.status_code != GetPageStatusCode::Ok {
if response.status_code != page_api::GetPageStatusCode::Ok {
return Err(anyhow!(
"unexpected non-OK response for shard {shard_id}: {} {}",
response.status_code,
@@ -207,7 +209,7 @@ impl GetPageSplitter {
/// Fetches the final, assembled response.
#[allow(clippy::result_large_err)]
pub fn get_response(self) -> anyhow::Result<GetPageResponse> {
pub fn get_response(self) -> anyhow::Result<page_api::GetPageResponse> {
// Check that the response is complete.
for (i, page) in self.response.pages.iter().enumerate() {
if page.image.is_empty() {

View File

@@ -19,9 +19,7 @@ pub mod proto {
}
mod client;
mod model;
mod split;
pub use client::Client;
mod model;
pub use model::*;
pub use split::GetPageSplitter;

View File

@@ -16,8 +16,7 @@ use anyhow::{Context as _, bail};
use bytes::{Buf as _, BufMut as _, BytesMut};
use chrono::Utc;
use futures::future::BoxFuture;
use futures::stream::FuturesUnordered;
use futures::{FutureExt, Stream, StreamExt as _};
use futures::{FutureExt, Stream};
use itertools::Itertools;
use jsonwebtoken::TokenData;
use once_cell::sync::OnceCell;
@@ -36,8 +35,8 @@ use pageserver_api::pagestream_api::{
};
use pageserver_api::reltag::SlruKind;
use pageserver_api::shard::TenantShardId;
use pageserver_page_api as page_api;
use pageserver_page_api::proto;
use pageserver_page_api::{self as page_api, GetPageSplitter};
use postgres_backend::{
AuthType, PostgresBackend, PostgresBackendReader, QueryError, is_expected_io_error,
};
@@ -467,6 +466,13 @@ impl TimelineHandles {
self.handles
.get(timeline_id, shard_selector, &self.wrapper)
.await
.map_err(|e| match e {
timeline::handle::GetError::TenantManager(e) => e,
timeline::handle::GetError::PerTimelineStateShutDown => {
trace!("per-timeline state shut down");
GetActiveTimelineError::Timeline(GetTimelineError::ShuttingDown)
}
})
}
fn tenant_id(&self) -> Option<TenantId> {
@@ -482,9 +488,11 @@ pub(crate) struct TenantManagerWrapper {
tenant_id: once_cell::sync::OnceCell<TenantId>,
}
#[derive(Debug)]
pub(crate) struct TenantManagerTypes;
impl timeline::handle::Types for TenantManagerTypes {
type TenantManagerError = GetActiveTimelineError;
type TenantManager = TenantManagerWrapper;
type Timeline = TenantManagerCacheItem;
}
@@ -3424,6 +3432,18 @@ impl GrpcPageServiceHandler {
Ok(CancellableTask { task, cancel })
}
/// Errors if the request is executed on a non-zero shard. Only shard 0 has a complete view of
/// relations and their sizes, as well as SLRU segments and similar data.
#[allow(clippy::result_large_err)]
fn ensure_shard_zero(timeline: &Handle<TenantManagerTypes>) -> Result<(), tonic::Status> {
match timeline.get_shard_index().shard_number.0 {
0 => Ok(()),
shard => Err(tonic::Status::invalid_argument(format!(
"request must execute on shard zero (is shard {shard})",
))),
}
}
/// Generates a PagestreamRequest header from a ReadLsn and request ID.
fn make_hdr(
read_lsn: page_api::ReadLsn,
@@ -3438,72 +3458,30 @@ impl GrpcPageServiceHandler {
}
}
/// Acquires a timeline handle for the given request. The shard index must match a local shard.
/// Acquires a timeline handle for the given request.
///
/// NB: this will fail during shard splits, see comment on [`Self::maybe_split_get_page`].
/// TODO: during shard splits, the compute may still be sending requests to the parent shard
/// until the entire split is committed and the compute is notified. Consider installing a
/// temporary shard router from the parent to the children while the split is in progress.
///
/// TODO: consider moving this to a middleware layer; all requests need it. Needs to manage
/// the TimelineHandles lifecycle.
///
/// TODO: untangle acquisition from TenantManagerWrapper::resolve() and Cache::get(), to avoid
/// the unnecessary overhead.
async fn get_request_timeline(
&self,
req: &tonic::Request<impl Any>,
) -> Result<Handle<TenantManagerTypes>, GetActiveTimelineError> {
let TenantTimelineId {
tenant_id,
timeline_id,
} = *extract::<TenantTimelineId>(req);
let ttid = *extract::<TenantTimelineId>(req);
let shard_index = *extract::<ShardIndex>(req);
let shard_selector = ShardSelector::Known(shard_index);
// TODO: untangle acquisition from TenantManagerWrapper::resolve() and Cache::get(), to
// avoid the unnecessary overhead.
TimelineHandles::new(self.tenant_manager.clone())
.get(tenant_id, timeline_id, ShardSelector::Known(shard_index))
.get(ttid.tenant_id, ttid.timeline_id, shard_selector)
.await
}
/// Acquires a timeline handle for the given request, which must be for shard zero. Most
/// metadata requests are only valid on shard zero.
///
/// NB: during an ongoing shard split, the compute will keep talking to the parent shard until
/// the split is committed, but the parent shard may have been removed in the meanwhile. In that
/// case, we reroute the request to the new child shard. See [`Self::maybe_split_get_page`].
///
/// TODO: revamp the split protocol to avoid this child routing.
async fn get_request_timeline_shard_zero(
&self,
req: &tonic::Request<impl Any>,
) -> Result<Handle<TenantManagerTypes>, tonic::Status> {
let TenantTimelineId {
tenant_id,
timeline_id,
} = *extract::<TenantTimelineId>(req);
let shard_index = *extract::<ShardIndex>(req);
if shard_index.shard_number.0 != 0 {
return Err(tonic::Status::invalid_argument(format!(
"request only valid on shard zero (requested shard {shard_index})",
)));
}
// TODO: untangle acquisition from TenantManagerWrapper::resolve() and Cache::get(), to
// avoid the unnecessary overhead.
let mut handles = TimelineHandles::new(self.tenant_manager.clone());
match handles
.get(tenant_id, timeline_id, ShardSelector::Known(shard_index))
.await
{
Ok(timeline) => Ok(timeline),
Err(err) => {
// We may be in the middle of a shard split. Try to find a child shard 0.
if let Ok(timeline) = handles
.get(tenant_id, timeline_id, ShardSelector::Zero)
.await
&& timeline.get_shard_index().shard_count > shard_index.shard_count
{
return Ok(timeline);
}
Err(err.into())
}
}
}
/// Starts a SmgrOpTimer at received_at, throttles the request, and records execution start.
/// Only errors if the timeline is shutting down.
///
@@ -3533,22 +3511,28 @@ impl GrpcPageServiceHandler {
/// TODO: get_vectored() currently enforces a batch limit of 32. Postgres will typically send
/// batches up to effective_io_concurrency = 100. Either we have to accept large batches, or
/// split them up in the client or server.
#[instrument(skip_all, fields(
req_id = %req.request_id,
rel = %req.rel,
blkno = %req.block_numbers[0],
blks = %req.block_numbers.len(),
lsn = %req.read_lsn,
))]
#[instrument(skip_all, fields(req_id, rel, blkno, blks, req_lsn, mod_lsn))]
async fn get_page(
ctx: &RequestContext,
timeline: Handle<TenantManagerTypes>,
req: page_api::GetPageRequest,
timeline: &WeakHandle<TenantManagerTypes>,
req: proto::GetPageRequest,
io_concurrency: IoConcurrency,
received_at: Instant,
) -> Result<page_api::GetPageResponse, tonic::Status> {
) -> Result<proto::GetPageResponse, tonic::Status> {
let received_at = Instant::now();
let timeline = timeline.upgrade()?;
let ctx = ctx.with_scope_page_service_pagestream(&timeline);
// Validate the request, decorate the span, and convert it to a Pagestream request.
let req = page_api::GetPageRequest::try_from(req)?;
span_record!(
req_id = %req.request_id,
rel = %req.rel,
blkno = %req.block_numbers[0],
blks = %req.block_numbers.len(),
lsn = %req.read_lsn,
);
for &blkno in &req.block_numbers {
let shard = timeline.get_shard_identity();
let key = rel_block_to_key(req.rel, blkno);
@@ -3636,95 +3620,7 @@ impl GrpcPageServiceHandler {
};
}
Ok(resp)
}
/// Processes a GetPage request when there is a potential shard split in progress. We have to
/// reroute the request to any local child shards, and split batch requests that straddle
/// multiple child shards.
///
/// Parent shards are split and removed incrementally (there may be many parent shards when
/// splitting an already-sharded tenant), but the compute is only notified once the overall
/// split commits, which can take several minutes. In the meanwhile, the compute will be sending
/// requests to the parent shards.
///
/// TODO: add test infrastructure to provoke this situation frequently and for long periods of
/// time, to properly exercise it.
///
/// TODO: revamp the split protocol to avoid this, e.g.:
/// * Keep the parent shard until the split commits and the compute is notified.
/// * Notify the compute about each subsplit.
/// * Return an error that updates the compute's shard map.
#[instrument(skip_all)]
#[allow(clippy::too_many_arguments)]
async fn maybe_split_get_page(
ctx: &RequestContext,
handles: &mut TimelineHandles,
tenant_id: TenantId,
timeline_id: TimelineId,
parent: ShardIndex,
req: page_api::GetPageRequest,
io_concurrency: IoConcurrency,
received_at: Instant,
) -> Result<page_api::GetPageResponse, tonic::Status> {
// Check the first page to see if we have any child shards at all. Otherwise, the compute is
// just talking to the wrong Pageserver. If the parent has been split, the shard now owning
// the page must have a higher shard count.
let timeline = handles
.get(
tenant_id,
timeline_id,
ShardSelector::Page(rel_block_to_key(req.rel, req.block_numbers[0])),
)
.await?;
let shard_id = timeline.get_shard_identity();
if shard_id.count <= parent.shard_count {
return Err(HandleUpgradeError::ShutDown.into()); // emulate original error
}
// Fast path: the request fits in a single shard.
if let Some(shard_index) =
GetPageSplitter::for_single_shard(&req, shard_id.count, Some(shard_id.stripe_size))
.map_err(|err| tonic::Status::internal(err.to_string()))?
{
// We got the shard ID from the first page, so these must be equal.
assert_eq!(shard_index.shard_number, shard_id.number);
assert_eq!(shard_index.shard_count, shard_id.count);
return Self::get_page(ctx, timeline, req, io_concurrency, received_at).await;
}
// The request spans multiple shards; split it and dispatch parallel requests. All pages
// were originally in the parent shard, and during a split all children are local, so we
// expect to find local shards for all pages.
let mut splitter = GetPageSplitter::split(req, shard_id.count, Some(shard_id.stripe_size))
.map_err(|err| tonic::Status::internal(err.to_string()))?;
let mut shard_requests = FuturesUnordered::new();
for (shard_index, shard_req) in splitter.drain_requests() {
let timeline = handles
.get(tenant_id, timeline_id, ShardSelector::Known(shard_index))
.await?;
let future = Self::get_page(
ctx,
timeline,
shard_req,
io_concurrency.clone(),
received_at,
)
.map(move |result| result.map(|resp| (shard_index, resp)));
shard_requests.push(future);
}
while let Some((shard_index, shard_response)) = shard_requests.next().await.transpose()? {
splitter
.add_response(shard_index, shard_response)
.map_err(|err| tonic::Status::internal(err.to_string()))?;
}
splitter
.get_response()
.map_err(|err| tonic::Status::internal(err.to_string()))
Ok(resp.into())
}
}
@@ -3753,10 +3649,11 @@ impl proto::PageService for GrpcPageServiceHandler {
// to be the sweet spot where throughput is saturated.
const CHUNK_SIZE: usize = 256 * 1024;
let timeline = self.get_request_timeline_shard_zero(&req).await?;
let timeline = self.get_request_timeline(&req).await?;
let ctx = self.ctx.with_scope_timeline(&timeline);
// Validate the request and decorate the span.
Self::ensure_shard_zero(&timeline)?;
if timeline.is_archived() == Some(true) {
return Err(tonic::Status::failed_precondition("timeline is archived"));
}
@@ -3872,10 +3769,11 @@ impl proto::PageService for GrpcPageServiceHandler {
req: tonic::Request<proto::GetDbSizeRequest>,
) -> Result<tonic::Response<proto::GetDbSizeResponse>, tonic::Status> {
let received_at = extract::<ReceivedAt>(&req).0;
let timeline = self.get_request_timeline_shard_zero(&req).await?;
let timeline = self.get_request_timeline(&req).await?;
let ctx = self.ctx.with_scope_page_service_pagestream(&timeline);
// Validate the request, decorate the span, and convert it to a Pagestream request.
Self::ensure_shard_zero(&timeline)?;
let req: page_api::GetDbSizeRequest = req.into_inner().try_into()?;
span_record!(db_oid=%req.db_oid, lsn=%req.read_lsn);
@@ -3904,29 +3802,14 @@ impl proto::PageService for GrpcPageServiceHandler {
req: tonic::Request<tonic::Streaming<proto::GetPageRequest>>,
) -> Result<tonic::Response<Self::GetPagesStream>, tonic::Status> {
// Extract the timeline from the request and check that it exists.
//
// NB: during shard splits, the compute may still send requests to the parent shard. We'll
// reroute requests to the child shards below, but we also detect the common cases here
// where either the shard exists or no shards exist at all. If we have a child shard, we
// can't acquire a weak handle because we don't know which child shard to use yet.
let TenantTimelineId {
tenant_id,
timeline_id,
} = *extract::<TenantTimelineId>(&req);
let ttid = *extract::<TenantTimelineId>(&req);
let shard_index = *extract::<ShardIndex>(&req);
let shard_selector = ShardSelector::Known(shard_index);
let mut handles = TimelineHandles::new(self.tenant_manager.clone());
let timeline = match handles
.get(tenant_id, timeline_id, ShardSelector::Known(shard_index))
.await
{
// The timeline shard exists. Keep a weak handle to reuse for each request.
Ok(timeline) => Some(timeline.downgrade()),
// The shard doesn't exist, but a child shard does. We'll reroute requests later.
Err(_) if self.tenant_manager.has_child_shard(tenant_id, shard_index) => None,
// Failed to fetch the timeline, and no child shard exists. Error out.
Err(err) => return Err(err.into()),
};
handles
.get(ttid.tenant_id, ttid.timeline_id, shard_selector)
.await?;
// Spawn an IoConcurrency sidecar, if enabled.
let gate_guard = self
@@ -3943,9 +3826,11 @@ impl proto::PageService for GrpcPageServiceHandler {
let mut reqs = req.into_inner();
let resps = async_stream::try_stream! {
let timeline = handles
.get(ttid.tenant_id, ttid.timeline_id, shard_selector)
.await?
.downgrade();
loop {
// Wait for the next client request.
//
// NB: Tonic considers the entire stream to be an in-flight request and will wait
// for it to complete before shutting down. React to cancellation between requests.
let req = tokio::select! {
@@ -3958,44 +3843,16 @@ impl proto::PageService for GrpcPageServiceHandler {
Err(err) => Err(err),
},
}?;
let received_at = Instant::now();
let req_id = req.request_id.map(page_api::RequestID::from).unwrap_or_default();
// Process the request, using a closure to capture errors.
let process_request = async || {
let req = page_api::GetPageRequest::try_from(req)?;
// Fast path: use the pre-acquired timeline handle.
if let Some(Ok(timeline)) = timeline.as_ref().map(|t| t.upgrade()) {
return Self::get_page(&ctx, timeline, req, io_concurrency.clone(), received_at)
.instrument(span.clone()) // propagate request span
.await
}
// The timeline handle is stale. During shard splits, the compute may still be
// sending requests to the parent shard. Try to re-route requests to the child
// shards, and split any batch requests that straddle multiple child shards.
Self::maybe_split_get_page(
&ctx,
&mut handles,
tenant_id,
timeline_id,
shard_index,
req,
io_concurrency.clone(),
received_at,
)
let result = Self::get_page(&ctx, &timeline, req, io_concurrency.clone())
.instrument(span.clone()) // propagate request span
.await
};
// Return the response. Convert per-request errors to GetPageResponses if
// appropriate, or terminate the stream with a tonic::Status.
yield match process_request().await {
Ok(resp) => resp.into(),
.await;
yield match result {
Ok(resp) => resp,
// Convert per-request errors to GetPageResponses as appropriate, or terminate
// the stream with a tonic::Status. Log the error regardless, since
// ObservabilityLayer can't automatically log stream errors.
Err(status) => {
// Log the error, since ObservabilityLayer won't see stream errors.
// TODO: it would be nice if we could propagate the get_page() fields here.
span.in_scope(|| {
warn!("request failed with {:?}: {}", status.code(), status.message());
@@ -4015,10 +3872,11 @@ impl proto::PageService for GrpcPageServiceHandler {
req: tonic::Request<proto::GetRelSizeRequest>,
) -> Result<tonic::Response<proto::GetRelSizeResponse>, tonic::Status> {
let received_at = extract::<ReceivedAt>(&req).0;
let timeline = self.get_request_timeline_shard_zero(&req).await?;
let timeline = self.get_request_timeline(&req).await?;
let ctx = self.ctx.with_scope_page_service_pagestream(&timeline);
// Validate the request, decorate the span, and convert it to a Pagestream request.
Self::ensure_shard_zero(&timeline)?;
let req: page_api::GetRelSizeRequest = req.into_inner().try_into()?;
let allow_missing = req.allow_missing;
@@ -4051,10 +3909,11 @@ impl proto::PageService for GrpcPageServiceHandler {
req: tonic::Request<proto::GetSlruSegmentRequest>,
) -> Result<tonic::Response<proto::GetSlruSegmentResponse>, tonic::Status> {
let received_at = extract::<ReceivedAt>(&req).0;
let timeline = self.get_request_timeline_shard_zero(&req).await?;
let timeline = self.get_request_timeline(&req).await?;
let ctx = self.ctx.with_scope_page_service_pagestream(&timeline);
// Validate the request, decorate the span, and convert it to a Pagestream request.
Self::ensure_shard_zero(&timeline)?;
let req: page_api::GetSlruSegmentRequest = req.into_inner().try_into()?;
span_record!(kind=%req.kind, segno=%req.segno, lsn=%req.read_lsn);
@@ -4084,10 +3943,6 @@ impl proto::PageService for GrpcPageServiceHandler {
&self,
req: tonic::Request<proto::LeaseLsnRequest>,
) -> Result<tonic::Response<proto::LeaseLsnResponse>, tonic::Status> {
// TODO: this won't work during shard splits, as the request is directed at a specific shard
// but the parent shard is removed before the split commits and the compute is notified
// (which can take several minutes for large tenants). That's also the case for the libpq
// implementation, so we keep the behavior for now.
let timeline = self.get_request_timeline(&req).await?;
let ctx = self.ctx.with_scope_timeline(&timeline);

View File

@@ -826,18 +826,6 @@ impl TenantManager {
peek_slot.is_some()
}
/// Returns whether a local shard exists that's a child of the given tenant shard. Note that
/// this just checks for any shard with a larger shard count, and it may not be a direct child
/// of the given shard (their keyspace may not overlap).
pub(crate) fn has_child_shard(&self, tenant_id: TenantId, shard_index: ShardIndex) -> bool {
match &*self.tenants.read().unwrap() {
TenantsMap::Initializing => false,
TenantsMap::Open(slots) | TenantsMap::ShuttingDown(slots) => slots
.range(TenantShardId::tenant_range(tenant_id))
.any(|(tsid, _)| tsid.shard_count > shard_index.shard_count),
}
}
#[instrument(skip_all, fields(tenant_id=%tenant_shard_id.tenant_id, shard_id=%tenant_shard_id.shard_slug()))]
pub(crate) async fn upsert_location(
&self,
@@ -1534,13 +1522,6 @@ impl TenantManager {
self.resources.deletion_queue_client.flush_advisory();
// Phase 2: Put the parent shard to InProgress and grab a reference to the parent Tenant
//
// TODO: keeping the parent as InProgress while spawning the children causes read
// unavailability, as we can't acquire a new timeline handle for it (existing handles appear
// to still work though, even downgraded ones). The parent should be available for reads
// until the children are ready -- potentially until *all* subsplits across all parent
// shards are complete and the compute has been notified. See:
// <https://databricks.atlassian.net/browse/LKB-672>.
drop(tenant);
let mut parent_slot_guard =
self.tenant_map_acquire_slot(&tenant_shard_id, TenantSlotAcquireMode::Any)?;

View File

@@ -224,11 +224,11 @@ use tracing::{instrument, trace};
use utils::id::TimelineId;
use utils::shard::{ShardIndex, ShardNumber};
use crate::page_service::GetActiveTimelineError;
use crate::tenant::GetTimelineError;
use crate::tenant::mgr::{GetActiveTenantError, ShardSelector};
use crate::tenant::mgr::ShardSelector;
pub(crate) trait Types: Sized {
/// The requirement for Debug is so that #[derive(Debug)] works in some places.
pub(crate) trait Types: Sized + std::fmt::Debug {
type TenantManagerError: Sized + std::fmt::Debug;
type TenantManager: TenantManager<Self> + Sized;
type Timeline: Timeline<Self> + Sized;
}
@@ -307,11 +307,12 @@ impl<T: Types> Default for PerTimelineState<T> {
/// Abstract view of [`crate::tenant::mgr`], for testability.
pub(crate) trait TenantManager<T: Types> {
/// Invoked by [`Cache::get`] to resolve a [`ShardTimelineId`] to a [`Types::Timeline`].
/// Errors are returned as [`GetError::TenantManager`].
async fn resolve(
&self,
timeline_id: TimelineId,
shard_selector: ShardSelector,
) -> Result<T::Timeline, GetActiveTimelineError>;
) -> Result<T::Timeline, T::TenantManagerError>;
}
/// Abstract view of an [`Arc<Timeline>`], for testability.
@@ -321,6 +322,13 @@ pub(crate) trait Timeline<T: Types> {
fn per_timeline_state(&self) -> &PerTimelineState<T>;
}
/// Errors returned by [`Cache::get`].
#[derive(Debug)]
pub(crate) enum GetError<T: Types> {
TenantManager(T::TenantManagerError),
PerTimelineStateShutDown,
}
/// Internal type used in [`Cache::get`].
enum RoutingResult<T: Types> {
FastPath(Handle<T>),
@@ -337,7 +345,7 @@ impl<T: Types> Cache<T> {
timeline_id: TimelineId,
shard_selector: ShardSelector,
tenant_manager: &T::TenantManager,
) -> Result<Handle<T>, GetActiveTimelineError> {
) -> Result<Handle<T>, GetError<T>> {
const GET_MAX_RETRIES: usize = 10;
const RETRY_BACKOFF: Duration = Duration::from_millis(100);
let mut attempt = 0;
@@ -348,11 +356,7 @@ impl<T: Types> Cache<T> {
.await
{
Ok(handle) => return Ok(handle),
Err(
e @ GetActiveTimelineError::Tenant(GetActiveTenantError::WaitForActiveTimeout {
..
}),
) => {
Err(e) => {
// Retry on tenant manager error to handle tenant split more gracefully
if attempt < GET_MAX_RETRIES {
tokio::time::sleep(RETRY_BACKOFF).await;
@@ -366,7 +370,6 @@ impl<T: Types> Cache<T> {
return Err(e);
}
}
Err(err) => return Err(err),
}
}
}
@@ -385,7 +388,7 @@ impl<T: Types> Cache<T> {
timeline_id: TimelineId,
shard_selector: ShardSelector,
tenant_manager: &T::TenantManager,
) -> Result<Handle<T>, GetActiveTimelineError> {
) -> Result<Handle<T>, GetError<T>> {
// terminates because when every iteration we remove an element from the map
let miss: ShardSelector = loop {
let routing_state = self.shard_routing(timeline_id, shard_selector);
@@ -465,50 +468,60 @@ impl<T: Types> Cache<T> {
timeline_id: TimelineId,
shard_selector: ShardSelector,
tenant_manager: &T::TenantManager,
) -> Result<Handle<T>, GetActiveTimelineError> {
let timeline = tenant_manager.resolve(timeline_id, shard_selector).await?;
let key = timeline.shard_timeline_id();
match &shard_selector {
ShardSelector::Zero => assert_eq!(key.shard_index.shard_number, ShardNumber(0)),
ShardSelector::Page(_) => (), // gotta trust tenant_manager
ShardSelector::Known(idx) => assert_eq!(idx, &key.shard_index),
}
) -> Result<Handle<T>, GetError<T>> {
match tenant_manager.resolve(timeline_id, shard_selector).await {
Ok(timeline) => {
let key = timeline.shard_timeline_id();
match &shard_selector {
ShardSelector::Zero => assert_eq!(key.shard_index.shard_number, ShardNumber(0)),
ShardSelector::Page(_) => (), // gotta trust tenant_manager
ShardSelector::Known(idx) => assert_eq!(idx, &key.shard_index),
}
trace!("creating new HandleInner");
let timeline = Arc::new(timeline);
let handle_inner_arc = Arc::new(Mutex::new(HandleInner::Open(Arc::clone(&timeline))));
let handle_weak = WeakHandle {
inner: Arc::downgrade(&handle_inner_arc),
};
let handle = handle_weak
.upgrade()
.ok()
.expect("we just created it and it's not linked anywhere yet");
let mut lock_guard = timeline
.per_timeline_state()
.handles
.lock()
.expect("mutex poisoned");
let Some(per_timeline_state) = &mut *lock_guard else {
return Err(GetActiveTimelineError::Timeline(
GetTimelineError::ShuttingDown,
));
};
let replaced = per_timeline_state.insert(self.id, Arc::clone(&handle_inner_arc));
assert!(replaced.is_none(), "some earlier code left a stale handle");
match self.map.entry(key) {
hash_map::Entry::Occupied(_o) => {
// This cannot not happen because
// 1. we're the _miss_ handle, i.e., `self.map` didn't contain an entry and
// 2. we were holding &mut self during .resolve().await above, so, no other thread can have inserted a handle
// while we were waiting for the tenant manager.
unreachable!()
}
hash_map::Entry::Vacant(v) => {
v.insert(handle_weak);
trace!("creating new HandleInner");
let timeline = Arc::new(timeline);
let handle_inner_arc =
Arc::new(Mutex::new(HandleInner::Open(Arc::clone(&timeline))));
let handle_weak = WeakHandle {
inner: Arc::downgrade(&handle_inner_arc),
};
let handle = handle_weak
.upgrade()
.ok()
.expect("we just created it and it's not linked anywhere yet");
{
let mut lock_guard = timeline
.per_timeline_state()
.handles
.lock()
.expect("mutex poisoned");
match &mut *lock_guard {
Some(per_timeline_state) => {
let replaced =
per_timeline_state.insert(self.id, Arc::clone(&handle_inner_arc));
assert!(replaced.is_none(), "some earlier code left a stale handle");
match self.map.entry(key) {
hash_map::Entry::Occupied(_o) => {
// This cannot not happen because
// 1. we're the _miss_ handle, i.e., `self.map` didn't contain an entry and
// 2. we were holding &mut self during .resolve().await above, so, no other thread can have inserted a handle
// while we were waiting for the tenant manager.
unreachable!()
}
hash_map::Entry::Vacant(v) => {
v.insert(handle_weak);
}
}
}
None => {
return Err(GetError::PerTimelineStateShutDown);
}
}
}
Ok(handle)
}
Err(e) => Err(GetError::TenantManager(e)),
}
Ok(handle)
}
}
@@ -642,8 +655,7 @@ mod tests {
use pageserver_api::models::ShardParameters;
use pageserver_api::reltag::RelTag;
use pageserver_api::shard::DEFAULT_STRIPE_SIZE;
use utils::id::TenantId;
use utils::shard::{ShardCount, TenantShardId};
use utils::shard::ShardCount;
use utils::sync::gate::GateGuard;
use super::*;
@@ -653,6 +665,7 @@ mod tests {
#[derive(Debug)]
struct TestTypes;
impl Types for TestTypes {
type TenantManagerError = anyhow::Error;
type TenantManager = StubManager;
type Timeline = Entered;
}
@@ -703,48 +716,40 @@ mod tests {
&self,
timeline_id: TimelineId,
shard_selector: ShardSelector,
) -> Result<Entered, GetActiveTimelineError> {
fn enter_gate(
timeline: &StubTimeline,
) -> Result<Arc<GateGuard>, GetActiveTimelineError> {
Ok(Arc::new(timeline.gate.enter().map_err(|_| {
GetActiveTimelineError::Timeline(GetTimelineError::ShuttingDown)
})?))
}
) -> anyhow::Result<Entered> {
for timeline in &self.shards {
if timeline.id == timeline_id {
let enter_gate = || {
let gate_guard = timeline.gate.enter()?;
let gate_guard = Arc::new(gate_guard);
anyhow::Ok(gate_guard)
};
match &shard_selector {
ShardSelector::Zero if timeline.shard.is_shard_zero() => {
return Ok(Entered {
timeline: Arc::clone(timeline),
gate_guard: enter_gate(timeline)?,
gate_guard: enter_gate()?,
});
}
ShardSelector::Zero => continue,
ShardSelector::Page(key) if timeline.shard.is_key_local(key) => {
return Ok(Entered {
timeline: Arc::clone(timeline),
gate_guard: enter_gate(timeline)?,
gate_guard: enter_gate()?,
});
}
ShardSelector::Page(_) => continue,
ShardSelector::Known(idx) if idx == &timeline.shard.shard_index() => {
return Ok(Entered {
timeline: Arc::clone(timeline),
gate_guard: enter_gate(timeline)?,
gate_guard: enter_gate()?,
});
}
ShardSelector::Known(_) => continue,
}
}
}
Err(GetActiveTimelineError::Timeline(
GetTimelineError::NotFound {
tenant_id: TenantShardId::unsharded(TenantId::from([0; 16])),
timeline_id,
},
))
anyhow::bail!("not found")
}
}

View File

@@ -48,6 +48,8 @@ DATA = \
neon--1.3--1.4.sql \
neon--1.4--1.5.sql \
neon--1.5--1.6.sql \
neon--1.6--1.7.sql \
neon--1.7--1.6.sql \
neon--1.6--1.5.sql \
neon--1.5--1.4.sql \
neon--1.4--1.3.sql \

View File

@@ -260,7 +260,7 @@ typedef struct PrefetchState
/* the buffers */
prfh_hash *prf_hash;
int max_shard_no;
int max_unflushed_shard_no;
/* Mark shards involved in prefetch */
uint8 shard_bitmap[(MAX_SHARDS + 7)/8];
PrefetchRequest prf_buffer[]; /* prefetch buffers */
@@ -300,6 +300,7 @@ static void prefetch_do_request(PrefetchRequest *slot, neon_request_lsns *force_
static bool prefetch_wait_for(uint64 ring_index);
static void prefetch_cleanup_trailing_unused(void);
static inline void prefetch_set_unused(uint64 ring_index);
static bool prefetch_flush_requests(void);
static bool neon_prefetch_response_usable(neon_request_lsns *request_lsns,
PrefetchRequest *slot);
@@ -469,13 +470,26 @@ communicator_prefetch_pump_state(void)
{
START_PREFETCH_RECEIVE_WORK();
if (MyPState->ring_receive == MyPState->ring_flush && MyPState->ring_flush < MyPState->ring_unused)
{
/*
* Flush request to avoid requests pending for arbitrary long time,
* pinning LSN and holding GC at PS.
*/
if (!prefetch_flush_requests())
{
END_PREFETCH_RECEIVE_WORK();
return;
}
}
while (MyPState->ring_receive != MyPState->ring_flush)
{
NeonResponse *response;
PrefetchRequest *slot;
MemoryContext old;
uint64 my_ring_index = MyPState->ring_receive;
slot = GetPrfSlot(MyPState->ring_receive);
slot = GetPrfSlot(my_ring_index);
old = MemoryContextSwitchTo(MyPState->errctx);
response = page_server->try_receive(slot->shard_no);
@@ -489,12 +503,12 @@ communicator_prefetch_pump_state(void)
/* The slot should still be valid */
if (slot->status != PRFS_REQUESTED ||
slot->response != NULL ||
slot->my_ring_index != MyPState->ring_receive)
slot->my_ring_index != my_ring_index)
{
neon_shard_log(slot->shard_no, PANIC,
"Incorrect prefetch slot state after receive: status=%d response=%p my=" UINT64_FORMAT " receive=" UINT64_FORMAT "",
slot->status, slot->response,
slot->my_ring_index, MyPState->ring_receive);
slot->my_ring_index, my_ring_index);
}
/* update prefetch state */
MyPState->n_responses_buffered += 1;
@@ -522,6 +536,19 @@ communicator_prefetch_pump_state(void)
END_PREFETCH_RECEIVE_WORK();
if (RecoveryInProgress())
{
/*
* Update backend's min in-flight prefetch LSN.
*/
XLogRecPtr min_backend_prefetch_lsn = last_replay_lsn != InvalidXLogRecPtr ? last_replay_lsn : GetXLogReplayRecPtr(NULL);
for (uint64_t ring_index = MyPState->ring_receive; ring_index < MyPState->ring_unused; ring_index++)
{
PrefetchRequest* slot = GetPrfSlot(ring_index);
min_backend_prefetch_lsn = Min(slot->request_lsns.request_lsn, min_backend_prefetch_lsn);
}
MIN_BACKEND_REQUEST_LSN = min_backend_prefetch_lsn;
}
communicator_reconfigure_timeout_if_needed();
}
@@ -561,7 +588,7 @@ readahead_buffer_resize(int newsize, void *extra)
newPState->ring_last = newsize;
newPState->ring_unused = newsize;
newPState->ring_receive = newsize;
newPState->max_shard_no = MyPState->max_shard_no;
newPState->max_unflushed_shard_no = MyPState->max_unflushed_shard_no;
memcpy(newPState->shard_bitmap, MyPState->shard_bitmap, sizeof(MyPState->shard_bitmap));
/*
@@ -661,6 +688,7 @@ consume_prefetch_responses(void)
{
if (MyPState->ring_receive < MyPState->ring_unused)
prefetch_wait_for(MyPState->ring_unused - 1);
/*
* We know for sure we're not working on any prefetch pages after
* this.
@@ -690,7 +718,7 @@ prefetch_cleanup_trailing_unused(void)
static bool
prefetch_flush_requests(void)
{
for (shardno_t shard_no = 0; shard_no < MyPState->max_shard_no; shard_no++)
for (shardno_t shard_no = 0; shard_no < MyPState->max_unflushed_shard_no; shard_no++)
{
if (BITMAP_ISSET(MyPState->shard_bitmap, shard_no))
{
@@ -699,7 +727,8 @@ prefetch_flush_requests(void)
BITMAP_CLR(MyPState->shard_bitmap, shard_no);
}
}
MyPState->max_shard_no = 0;
MyPState->max_unflushed_shard_no = 0;
MyPState->ring_flush = MyPState->ring_unused;
return true;
}
@@ -723,7 +752,6 @@ prefetch_wait_for(uint64 ring_index)
{
if (!prefetch_flush_requests())
return false;
MyPState->ring_flush = MyPState->ring_unused;
}
Assert(MyPState->ring_unused > ring_index);
@@ -802,6 +830,7 @@ prefetch_read(PrefetchRequest *slot)
old = MemoryContextSwitchTo(MyPState->errctx);
response = (NeonResponse *) page_server->receive(shard_no);
MemoryContextSwitchTo(old);
if (response)
{
check_getpage_response(slot, response);
@@ -1010,11 +1039,16 @@ prefetch_do_request(PrefetchRequest *slot, neon_request_lsns *force_request_lsns
Assert(mySlotNo == MyPState->ring_unused);
if (force_request_lsns)
{
slot->request_lsns = *force_request_lsns;
}
else
{
neon_get_request_lsns(BufTagGetNRelFileInfo(slot->buftag),
slot->buftag.forkNum, slot->buftag.blockNum,
&slot->request_lsns, 1);
last_replay_lsn = InvalidXLogRecPtr;
}
request.hdr.lsn = slot->request_lsns.request_lsn;
request.hdr.not_modified_since = slot->request_lsns.not_modified_since;
@@ -1033,7 +1067,7 @@ prefetch_do_request(PrefetchRequest *slot, neon_request_lsns *force_request_lsns
MyPState->n_unused -= 1;
MyPState->ring_unused += 1;
BITMAP_SET(MyPState->shard_bitmap, slot->shard_no);
MyPState->max_shard_no = Max(slot->shard_no+1, MyPState->max_shard_no);
MyPState->max_unflushed_shard_no = Max(slot->shard_no+1, MyPState->max_unflushed_shard_no);
/* update slot state */
slot->status = PRFS_REQUESTED;
@@ -1041,6 +1075,25 @@ prefetch_do_request(PrefetchRequest *slot, neon_request_lsns *force_request_lsns
Assert(!found);
}
/*
* Check that returned page LSN is consistent with request lsns
*/
static void
check_page_lsn(NeonGetPageResponse* resp)
{
if (neon_protocol_version < 3) /* no information to check */
return;
if (PageGetLSN(resp->page) > resp->req.hdr.not_modified_since)
neon_log(PANIC, "Invalid getpage response version: %X/%08X is higher than last modified LSN %X/%08X",
LSN_FORMAT_ARGS(PageGetLSN(resp->page)),
LSN_FORMAT_ARGS(resp->req.hdr.not_modified_since));
if (PageGetLSN(resp->page) > resp->req.hdr.lsn)
neon_log(PANIC, "Invalid getpage response version: %X/%08X is higher than request LSN %X/%08X",
LSN_FORMAT_ARGS(PageGetLSN(resp->page)),
LSN_FORMAT_ARGS(resp->req.hdr.lsn));
}
/*
* Lookup of already received prefetch requests. Only already received responses matching required LSNs are accepted.
* Present pages are marked in "mask" bitmap and total number of such pages is returned.
@@ -1064,7 +1117,7 @@ communicator_prefetch_lookupv(NRelFileInfo rinfo, ForkNumber forknum, BlockNumbe
for (int i = 0; i < nblocks; i++)
{
PrfHashEntry *entry;
NeonGetPageResponse* resp;
hashkey.buftag.blockNum = blocknum + i;
entry = prfh_lookup(MyPState->prf_hash, &hashkey);
@@ -1097,8 +1150,9 @@ communicator_prefetch_lookupv(NRelFileInfo rinfo, ForkNumber forknum, BlockNumbe
continue;
}
Assert(slot->response->tag == T_NeonGetPageResponse); /* checked by check_getpage_response when response was assigned to the slot */
memcpy(buffers[i], ((NeonGetPageResponse*)slot->response)->page, BLCKSZ);
resp = (NeonGetPageResponse*)slot->response;
check_page_lsn(resp);
memcpy(buffers[i], resp->page, BLCKSZ);
/*
* With lfc_store_prefetch_result=true prefetch result is stored in LFC in prefetch_pump_state when response is received
@@ -1391,7 +1445,6 @@ Retry:
*/
goto Retry;
}
MyPState->ring_flush = MyPState->ring_unused;
}
return last_ring_index;
@@ -1461,10 +1514,12 @@ page_server_request(void const *req)
MyNeonCounters->pageserver_open_requests--;
} while (resp == NULL);
cancel_before_shmem_exit(prefetch_on_exit, Int32GetDatum(shard_no));
last_replay_lsn = InvalidXLogRecPtr;
}
PG_CATCH();
{
cancel_before_shmem_exit(prefetch_on_exit, Int32GetDatum(shard_no));
last_replay_lsn = InvalidXLogRecPtr;
/* Nothing should cancel disconnect: we should not leave connection in opaque state */
HOLD_INTERRUPTS();
page_server->disconnect(shard_no);
@@ -1864,6 +1919,13 @@ nm_to_string(NeonMessage *msg)
return s.data;
}
static void
reset_min_request_lsn(int code, Datum arg)
{
if (MyProcNumber != -1)
MIN_BACKEND_REQUEST_LSN = InvalidXLogRecPtr;
}
/*
* communicator_init() -- Initialize per-backend private state
*/
@@ -1875,6 +1937,8 @@ communicator_init(void)
if (MyPState != NULL)
return;
before_shmem_exit(reset_min_request_lsn, 0);
/*
* Sanity check that theperf counters array is sized correctly. We got
* this wrong once, and the formula for max number of backends and aux
@@ -1884,7 +1948,7 @@ communicator_init(void)
* the check here. That's OK, we don't expect the logic to change in old
* releases.
*/
#if PG_VERSION_NUM>=150000
#if PG_MAJORVERSION_NUM >= 15
if (MyNeonCounters >= &neon_per_backend_counters_shared[NUM_NEON_PERF_COUNTER_SLOTS])
elog(ERROR, "MyNeonCounters points past end of array");
#endif
@@ -2223,6 +2287,7 @@ Retry:
case T_NeonGetPageResponse:
{
NeonGetPageResponse* getpage_resp = (NeonGetPageResponse *) resp;
check_page_lsn(getpage_resp);
memcpy(buffer, getpage_resp->page, BLCKSZ);
/*
@@ -2499,12 +2564,30 @@ communicator_reconfigure_timeout_if_needed(void)
!AmPrewarmWorker && /* do not pump prefetch state in prewarm worker */
readahead_getpage_pull_timeout_ms > 0;
if (!needs_set && MIN_BACKEND_REQUEST_LSN != InvalidXLogRecPtr)
{
if (last_replay_lsn == InvalidXLogRecPtr)
MIN_BACKEND_REQUEST_LSN = InvalidXLogRecPtr;
else
needs_set = true; /* Can not reset MIN_BACKEND_REQUEST_LSN now, have to do it later */
}
if (needs_set != timeout_set)
{
/* The background writer doens't (shouldn't) read any pages */
Assert(!AmBackgroundWriterProcess());
/* The checkpointer doens't (shouldn't) read any pages */
Assert(!AmCheckpointerProcess());
/*
* The background writer/checkpointer doens't (shouldn't) read any pages.
* And definitely they should not run on replica.
* The only case when we can get here is replica promotion.
*/
if (AmBackgroundWriterProcess() || AmCheckpointerProcess())
{
MIN_BACKEND_REQUEST_LSN = InvalidXLogRecPtr;
if (timeout_set)
{
disable_timeout(PS_TIMEOUT_ID, false);
timeout_set = false;
}
return;
}
if (unlikely(PS_TIMEOUT_ID == 0))
{
@@ -2537,14 +2620,6 @@ communicator_reconfigure_timeout_if_needed(void)
static void
pagestore_timeout_handler(void)
{
#if PG_MAJORVERSION_NUM <= 14
/*
* PG14: Setting a repeating timeout is not possible, so we signal here
* that the timeout has already been reset, and by telling the system
* that system will re-schedule it later if we need to.
*/
timeout_set = false;
#endif
timeout_signaled = true;
InterruptPending = true;
}
@@ -2564,6 +2639,14 @@ communicator_processinterrupts(void)
if (!readpage_reentrant_guard && readahead_getpage_pull_timeout_ms > 0)
communicator_prefetch_pump_state();
#if PG_MAJORVERSION_NUM <= 14
/*
* PG14: Setting a repeating timeout is not possible, so we signal here
* that the timeout has already been reset, and by telling the system
* that system will re-schedule it later if we need to.
*/
timeout_set = false;
#endif
timeout_signaled = false;
communicator_reconfigure_timeout_if_needed();
}
@@ -2573,3 +2656,28 @@ communicator_processinterrupts(void)
return prev_interrupt_cb();
}
PG_FUNCTION_INFO_V1(neon_communicator_min_inflight_request_lsn);
Datum
neon_communicator_min_inflight_request_lsn(PG_FUNCTION_ARGS)
{
if (RecoveryInProgress())
{
/* Do not hold GC for primary */
PG_RETURN_INT64(UINT64_MAX);
}
else
{
XLogRecPtr min_lsn = GetXLogReplayRecPtr(NULL);
size_t n_procs = ProcGlobal->allProcCount;
for (size_t i = 0; i < n_procs; i++)
{
if (neon_per_backend_counters_shared[i].min_request_lsn != InvalidXLogRecPtr)
{
min_lsn = Min(min_lsn, neon_per_backend_counters_shared[i].min_request_lsn);
}
}
PG_RETURN_INT64(min_lsn);
}
}

View File

@@ -1832,46 +1832,125 @@ lfc_writev(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno,
LWLockRelease(lfc_lock);
}
/*
* Return metrics about the LFC.
*
* The return format is a palloc'd array of LfcStatsEntrys. The size
* of the returned array is returned in *num_entries.
*/
LfcStatsEntry *
lfc_get_stats(size_t *num_entries)
typedef struct
{
LfcStatsEntry *entries;
size_t n = 0;
TupleDesc tupdesc;
} NeonGetStatsCtx;
#define MAX_ENTRIES 10
entries = palloc(sizeof(LfcStatsEntry) * MAX_ENTRIES);
#define NUM_NEON_GET_STATS_COLS 2
entries[n++] = (LfcStatsEntry) {"file_cache_chunk_size_pages", lfc_ctl == NULL,
lfc_ctl ? lfc_blocks_per_chunk : 0 };
entries[n++] = (LfcStatsEntry) {"file_cache_misses", lfc_ctl == NULL,
lfc_ctl ? lfc_ctl->misses : 0};
entries[n++] = (LfcStatsEntry) {"file_cache_hits", lfc_ctl == NULL,
lfc_ctl ? lfc_ctl->hits : 0 };
entries[n++] = (LfcStatsEntry) {"file_cache_used", lfc_ctl == NULL,
lfc_ctl ? lfc_ctl->used : 0 };
entries[n++] = (LfcStatsEntry) {"file_cache_writes", lfc_ctl == NULL,
lfc_ctl ? lfc_ctl->writes : 0 };
entries[n++] = (LfcStatsEntry) {"file_cache_size", lfc_ctl == NULL,
lfc_ctl ? lfc_ctl->size : 0 };
entries[n++] = (LfcStatsEntry) {"file_cache_used_pages", lfc_ctl == NULL,
lfc_ctl ? lfc_ctl->used_pages : 0 };
entries[n++] = (LfcStatsEntry) {"file_cache_evicted_pages", lfc_ctl == NULL,
lfc_ctl ? lfc_ctl->evicted_pages : 0 };
entries[n++] = (LfcStatsEntry) {"file_cache_limit", lfc_ctl == NULL,
lfc_ctl ? lfc_ctl->limit : 0 };
entries[n++] = (LfcStatsEntry) {"file_cache_chunks_pinned", lfc_ctl == NULL,
lfc_ctl ? lfc_ctl->pinned : 0 };
Assert(n <= MAX_ENTRIES);
#undef MAX_ENTRIES
PG_FUNCTION_INFO_V1(neon_get_lfc_stats);
Datum
neon_get_lfc_stats(PG_FUNCTION_ARGS)
{
FuncCallContext *funcctx;
NeonGetStatsCtx *fctx;
MemoryContext oldcontext;
TupleDesc tupledesc;
Datum result;
HeapTuple tuple;
char const *key;
uint64 value = 0;
Datum values[NUM_NEON_GET_STATS_COLS];
bool nulls[NUM_NEON_GET_STATS_COLS];
*num_entries = n;
return entries;
if (SRF_IS_FIRSTCALL())
{
funcctx = SRF_FIRSTCALL_INIT();
/* Switch context when allocating stuff to be used in later calls */
oldcontext = MemoryContextSwitchTo(funcctx->multi_call_memory_ctx);
/* Create a user function context for cross-call persistence */
fctx = (NeonGetStatsCtx *) palloc(sizeof(NeonGetStatsCtx));
/* Construct a tuple descriptor for the result rows. */
tupledesc = CreateTemplateTupleDesc(NUM_NEON_GET_STATS_COLS);
TupleDescInitEntry(tupledesc, (AttrNumber) 1, "lfc_key",
TEXTOID, -1, 0);
TupleDescInitEntry(tupledesc, (AttrNumber) 2, "lfc_value",
INT8OID, -1, 0);
fctx->tupdesc = BlessTupleDesc(tupledesc);
funcctx->user_fctx = fctx;
/* Return to original context when allocating transient memory */
MemoryContextSwitchTo(oldcontext);
}
funcctx = SRF_PERCALL_SETUP();
/* Get the saved state */
fctx = (NeonGetStatsCtx *) funcctx->user_fctx;
switch (funcctx->call_cntr)
{
case 0:
key = "file_cache_misses";
if (lfc_ctl)
value = lfc_ctl->misses;
break;
case 1:
key = "file_cache_hits";
if (lfc_ctl)
value = lfc_ctl->hits;
break;
case 2:
key = "file_cache_used";
if (lfc_ctl)
value = lfc_ctl->used;
break;
case 3:
key = "file_cache_writes";
if (lfc_ctl)
value = lfc_ctl->writes;
break;
case 4:
key = "file_cache_size";
if (lfc_ctl)
value = lfc_ctl->size;
break;
case 5:
key = "file_cache_used_pages";
if (lfc_ctl)
value = lfc_ctl->used_pages;
break;
case 6:
key = "file_cache_evicted_pages";
if (lfc_ctl)
value = lfc_ctl->evicted_pages;
break;
case 7:
key = "file_cache_limit";
if (lfc_ctl)
value = lfc_ctl->limit;
break;
case 8:
key = "file_cache_chunk_size_pages";
value = lfc_blocks_per_chunk;
break;
case 9:
key = "file_cache_chunks_pinned";
if (lfc_ctl)
value = lfc_ctl->pinned;
break;
default:
SRF_RETURN_DONE(funcctx);
}
values[0] = PointerGetDatum(cstring_to_text(key));
nulls[0] = false;
if (lfc_ctl)
{
nulls[1] = false;
values[1] = Int64GetDatum(value);
}
else
nulls[1] = true;
tuple = heap_form_tuple(fctx->tupdesc, values, nulls);
result = HeapTupleGetDatum(tuple);
SRF_RETURN_NEXT(funcctx, result);
}
@@ -1879,86 +1958,193 @@ lfc_get_stats(size_t *num_entries)
* Function returning data from the local file cache
* relation node/tablespace/database/blocknum and access_counter
*/
LocalCachePagesRec *
lfc_local_cache_pages(size_t *num_entries)
PG_FUNCTION_INFO_V1(local_cache_pages);
/*
* Record structure holding the to be exposed cache data.
*/
typedef struct
{
HASH_SEQ_STATUS status;
FileCacheEntry *entry;
size_t n_pages;
size_t n;
LocalCachePagesRec *result;
uint32 pageoffs;
Oid relfilenode;
Oid reltablespace;
Oid reldatabase;
ForkNumber forknum;
BlockNumber blocknum;
uint16 accesscount;
} LocalCachePagesRec;
if (!lfc_ctl)
{
*num_entries = 0;
return NULL;
}
/*
* Function context for data persisting over repeated calls.
*/
typedef struct
{
TupleDesc tupdesc;
LocalCachePagesRec *record;
} LocalCachePagesContext;
LWLockAcquire(lfc_lock, LW_SHARED);
if (!LFC_ENABLED())
{
LWLockRelease(lfc_lock);
*num_entries = 0;
return NULL;
}
/* Count the pages first */
n_pages = 0;
hash_seq_init(&status, lfc_hash);
while ((entry = hash_seq_search(&status)) != NULL)
#define NUM_LOCALCACHE_PAGES_ELEM 7
Datum
local_cache_pages(PG_FUNCTION_ARGS)
{
FuncCallContext *funcctx;
Datum result;
MemoryContext oldcontext;
LocalCachePagesContext *fctx; /* User function context. */
TupleDesc tupledesc;
TupleDesc expected_tupledesc;
HeapTuple tuple;
if (SRF_IS_FIRSTCALL())
{
/* Skip hole tags */
if (NInfoGetRelNumber(BufTagGetNRelFileInfo(entry->key)) != 0)
HASH_SEQ_STATUS status;
FileCacheEntry *entry;
uint32 n_pages = 0;
funcctx = SRF_FIRSTCALL_INIT();
/* Switch context when allocating stuff to be used in later calls */
oldcontext = MemoryContextSwitchTo(funcctx->multi_call_memory_ctx);
/* Create a user function context for cross-call persistence */
fctx = (LocalCachePagesContext *) palloc(sizeof(LocalCachePagesContext));
/*
* To smoothly support upgrades from version 1.0 of this extension
* transparently handle the (non-)existence of the pinning_backends
* column. We unfortunately have to get the result type for that... -
* we can't use the result type determined by the function definition
* without potentially crashing when somebody uses the old (or even
* wrong) function definition though.
*/
if (get_call_result_type(fcinfo, NULL, &expected_tupledesc) != TYPEFUNC_COMPOSITE)
neon_log(ERROR, "return type must be a row type");
if (expected_tupledesc->natts != NUM_LOCALCACHE_PAGES_ELEM)
neon_log(ERROR, "incorrect number of output arguments");
/* Construct a tuple descriptor for the result rows. */
tupledesc = CreateTemplateTupleDesc(expected_tupledesc->natts);
TupleDescInitEntry(tupledesc, (AttrNumber) 1, "pageoffs",
INT8OID, -1, 0);
#if PG_MAJORVERSION_NUM < 16
TupleDescInitEntry(tupledesc, (AttrNumber) 2, "relfilenode",
OIDOID, -1, 0);
#else
TupleDescInitEntry(tupledesc, (AttrNumber) 2, "relfilenumber",
OIDOID, -1, 0);
#endif
TupleDescInitEntry(tupledesc, (AttrNumber) 3, "reltablespace",
OIDOID, -1, 0);
TupleDescInitEntry(tupledesc, (AttrNumber) 4, "reldatabase",
OIDOID, -1, 0);
TupleDescInitEntry(tupledesc, (AttrNumber) 5, "relforknumber",
INT2OID, -1, 0);
TupleDescInitEntry(tupledesc, (AttrNumber) 6, "relblocknumber",
INT8OID, -1, 0);
TupleDescInitEntry(tupledesc, (AttrNumber) 7, "accesscount",
INT4OID, -1, 0);
fctx->tupdesc = BlessTupleDesc(tupledesc);
if (lfc_ctl)
{
for (int i = 0; i < lfc_blocks_per_chunk; i++)
n_pages += GET_STATE(entry, i) == AVAILABLE;
}
}
LWLockAcquire(lfc_lock, LW_SHARED);
if (n_pages == 0)
{
LWLockRelease(lfc_lock);
*num_entries = 0;
return NULL;
}
result = (LocalCachePagesRec *)
MemoryContextAllocHuge(CurrentMemoryContext,
sizeof(LocalCachePagesRec) * n_pages);
/*
* Scan through all the cache entries, saving the relevant fields
* in the result structure.
*/
n = 0;
hash_seq_init(&status, lfc_hash);
while ((entry = hash_seq_search(&status)) != NULL)
{
for (int i = 0; i < lfc_blocks_per_chunk; i++)
{
if (NInfoGetRelNumber(BufTagGetNRelFileInfo(entry->key)) != 0)
if (LFC_ENABLED())
{
if (GET_STATE(entry, i) == AVAILABLE)
hash_seq_init(&status, lfc_hash);
while ((entry = hash_seq_search(&status)) != NULL)
{
result[n].pageoffs = entry->offset * lfc_blocks_per_chunk + i;
result[n].relfilenode = NInfoGetRelNumber(BufTagGetNRelFileInfo(entry->key));
result[n].reltablespace = NInfoGetSpcOid(BufTagGetNRelFileInfo(entry->key));
result[n].reldatabase = NInfoGetDbOid(BufTagGetNRelFileInfo(entry->key));
result[n].forknum = entry->key.forkNum;
result[n].blocknum = entry->key.blockNum + i;
result[n].accesscount = entry->access_count;
n += 1;
/* Skip hole tags */
if (NInfoGetRelNumber(BufTagGetNRelFileInfo(entry->key)) != 0)
{
for (int i = 0; i < lfc_blocks_per_chunk; i++)
n_pages += GET_STATE(entry, i) == AVAILABLE;
}
}
}
}
}
Assert(n_pages == n);
LWLockRelease(lfc_lock);
fctx->record = (LocalCachePagesRec *)
MemoryContextAllocHuge(CurrentMemoryContext,
sizeof(LocalCachePagesRec) * n_pages);
*num_entries = n_pages;
return result;
/* Set max calls and remember the user function context. */
funcctx->max_calls = n_pages;
funcctx->user_fctx = fctx;
/* Return to original context when allocating transient memory */
MemoryContextSwitchTo(oldcontext);
if (n_pages != 0)
{
/*
* Scan through all the cache entries, saving the relevant fields
* in the fctx->record structure.
*/
uint32 n = 0;
hash_seq_init(&status, lfc_hash);
while ((entry = hash_seq_search(&status)) != NULL)
{
for (int i = 0; i < lfc_blocks_per_chunk; i++)
{
if (NInfoGetRelNumber(BufTagGetNRelFileInfo(entry->key)) != 0)
{
if (GET_STATE(entry, i) == AVAILABLE)
{
fctx->record[n].pageoffs = entry->offset * lfc_blocks_per_chunk + i;
fctx->record[n].relfilenode = NInfoGetRelNumber(BufTagGetNRelFileInfo(entry->key));
fctx->record[n].reltablespace = NInfoGetSpcOid(BufTagGetNRelFileInfo(entry->key));
fctx->record[n].reldatabase = NInfoGetDbOid(BufTagGetNRelFileInfo(entry->key));
fctx->record[n].forknum = entry->key.forkNum;
fctx->record[n].blocknum = entry->key.blockNum + i;
fctx->record[n].accesscount = entry->access_count;
n += 1;
}
}
}
}
Assert(n_pages == n);
}
if (lfc_ctl)
LWLockRelease(lfc_lock);
}
funcctx = SRF_PERCALL_SETUP();
/* Get the saved state */
fctx = funcctx->user_fctx;
if (funcctx->call_cntr < funcctx->max_calls)
{
uint32 i = funcctx->call_cntr;
Datum values[NUM_LOCALCACHE_PAGES_ELEM];
bool nulls[NUM_LOCALCACHE_PAGES_ELEM] = {
false, false, false, false, false, false, false
};
values[0] = Int64GetDatum((int64) fctx->record[i].pageoffs);
values[1] = ObjectIdGetDatum(fctx->record[i].relfilenode);
values[2] = ObjectIdGetDatum(fctx->record[i].reltablespace);
values[3] = ObjectIdGetDatum(fctx->record[i].reldatabase);
values[4] = ObjectIdGetDatum(fctx->record[i].forknum);
values[5] = Int64GetDatum((int64) fctx->record[i].blocknum);
values[6] = Int32GetDatum(fctx->record[i].accesscount);
/* Build and return the tuple. */
tuple = heap_form_tuple(fctx->tupdesc, values, nulls);
result = HeapTupleGetDatum(tuple);
SRF_RETURN_NEXT(funcctx, result);
}
else
SRF_RETURN_DONE(funcctx);
}
/*
* Internal implementation of the approximate_working_set_size_seconds()
* function.

View File

@@ -47,26 +47,6 @@ extern bool lfc_prefetch(NRelFileInfo rinfo, ForkNumber forknum, BlockNumber blk
extern FileCacheState* lfc_get_state(size_t max_entries);
extern void lfc_prewarm(FileCacheState* fcs, uint32 n_workers);
typedef struct LfcStatsEntry
{
const char *metric_name;
bool isnull;
uint64 value;
} LfcStatsEntry;
extern LfcStatsEntry *lfc_get_stats(size_t *num_entries);
typedef struct
{
uint32 pageoffs;
Oid relfilenode;
Oid reltablespace;
Oid reldatabase;
ForkNumber forknum;
BlockNumber blocknum;
uint16 accesscount;
} LocalCachePagesRec;
extern LocalCachePagesRec *lfc_local_cache_pages(size_t *num_entries);
extern int32 lfc_approximate_working_set_size_seconds(time_t duration, bool reset);

View File

@@ -0,0 +1,3 @@
create function neon_communicator_min_inflight_request_lsn() returns pg_catalog.pg_lsn
AS 'MODULE_PATHNAME', 'neon_communicator_min_inflight_request_lsn'
LANGUAGE C;

View File

@@ -0,0 +1 @@
drop function neon_communicator_min_inflight_request_lsn();

View File

@@ -625,15 +625,11 @@ _PG_init(void)
ExecutorEnd_hook = neon_ExecutorEnd;
}
/* Various functions exposed at SQL level */
PG_FUNCTION_INFO_V1(pg_cluster_size);
PG_FUNCTION_INFO_V1(backpressure_lsns);
PG_FUNCTION_INFO_V1(backpressure_throttling_time);
PG_FUNCTION_INFO_V1(approximate_working_set_size_seconds);
PG_FUNCTION_INFO_V1(approximate_working_set_size);
PG_FUNCTION_INFO_V1(neon_get_lfc_stats);
PG_FUNCTION_INFO_V1(local_cache_pages);
Datum
pg_cluster_size(PG_FUNCTION_ARGS)
@@ -708,76 +704,6 @@ approximate_working_set_size(PG_FUNCTION_ARGS)
PG_RETURN_INT32(dc);
}
Datum
neon_get_lfc_stats(PG_FUNCTION_ARGS)
{
#define NUM_NEON_GET_STATS_COLS 2
ReturnSetInfo *rsinfo = (ReturnSetInfo *) fcinfo->resultinfo;
LfcStatsEntry *entries;
size_t num_entries;
InitMaterializedSRF(fcinfo, 0);
/* lfc_get_stats() does all the heavy lifting */
entries = lfc_get_stats(&num_entries);
/* Convert the LfcStatsEntrys to a result set */
for (size_t i = 0; i < num_entries; i++)
{
LfcStatsEntry *entry = &entries[i];
Datum values[NUM_NEON_GET_STATS_COLS];
bool nulls[NUM_NEON_GET_STATS_COLS];
values[0] = CStringGetTextDatum(entry->metric_name);
nulls[0] = false;
values[1] = Int64GetDatum(entry->isnull ? 0 : entry->value);
nulls[1] = entry->isnull;
tuplestore_putvalues(rsinfo->setResult, rsinfo->setDesc, values, nulls);
}
PG_RETURN_VOID();
#undef NUM_NEON_GET_STATS_COLS
}
Datum
local_cache_pages(PG_FUNCTION_ARGS)
{
#define NUM_LOCALCACHE_PAGES_COLS 7
ReturnSetInfo *rsinfo = (ReturnSetInfo *) fcinfo->resultinfo;
LocalCachePagesRec *entries;
size_t num_entries;
InitMaterializedSRF(fcinfo, 0);
/* lfc_local_cache_pages() does all the heavy lifting */
entries = lfc_local_cache_pages(&num_entries);
/* Convert the LocalCachePagesRec structs to a result set */
for (size_t i = 0; i < num_entries; i++)
{
LocalCachePagesRec *entry = &entries[i];
Datum values[NUM_LOCALCACHE_PAGES_COLS];
bool nulls[NUM_LOCALCACHE_PAGES_COLS] = {
false, false, false, false, false, false, false
};
values[0] = Int64GetDatum((int64) entry->pageoffs);
values[1] = ObjectIdGetDatum(entry->relfilenode);
values[2] = ObjectIdGetDatum(entry->reltablespace);
values[3] = ObjectIdGetDatum(entry->reldatabase);
values[4] = ObjectIdGetDatum(entry->forknum);
values[5] = Int64GetDatum((int64) entry->blocknum);
values[6] = Int32GetDatum(entry->accesscount);
/* Build and return the tuple. */
tuplestore_putvalues(rsinfo->setResult, rsinfo->setDesc, values, nulls);
}
PG_RETURN_VOID();
#undef NUM_LOCALCACHE_PAGES_COLS
}
/*
* Initialization stage 2: make requests for the amount of shared memory we
* will need.

View File

@@ -42,7 +42,6 @@ NeonPerfCountersShmemRequest(void)
}
void
NeonPerfCountersShmemInit(void)
{

View File

@@ -154,6 +154,11 @@ typedef struct
* Histogram of query execution time.
*/
QTHistogramData query_time_hist;
/*
* Minimal LSN of in-fligth request requests
*/
XLogRecPtr min_request_lsn;
} neon_per_backend_counters;
/* Pointer to the shared memory array of neon_per_backend_counters structs */
@@ -169,6 +174,12 @@ extern neon_per_backend_counters *neon_per_backend_counters_shared;
#define MyNeonCounters (&neon_per_backend_counters_shared[MyProcNumber])
/*
* Backend-local minimal in-flight request LSN.
* We store it in neon_per_backend_counters_shared and not in separate array to minimize false cache sharing
*/
#define MIN_BACKEND_REQUEST_LSN MyNeonCounters->min_request_lsn
extern void inc_getpage_wait(uint64 latency);
extern void inc_page_cache_read_wait(uint64 latency);
extern void inc_page_cache_write_wait(uint64 latency);

View File

@@ -243,6 +243,7 @@ extern char *neon_timeline;
extern char *neon_tenant;
extern int32 max_cluster_size;
extern int neon_protocol_version;
extern XLogRecPtr last_replay_lsn;
extern shardno_t get_shard_number(BufferTag* tag);

View File

@@ -96,6 +96,8 @@ typedef enum
int debug_compare_local;
XLogRecPtr last_replay_lsn;
static NRelFileInfo unlogged_build_rel_info;
static UnloggedBuildPhase unlogged_build_phase = UNLOGGED_BUILD_NOT_IN_PROGRESS;
@@ -159,7 +161,7 @@ log_newpages_copy(NRelFileInfo * rinfo, ForkNumber forkNum, BlockNumber blkno,
page_std);
}
return ProcLastRecPtr;
return GetXLogInsertRecPtr();
}
#endif /* PG_MAJORVERSION_NUM >= 17 */
@@ -588,6 +590,17 @@ neon_get_request_lsns(NRelFileInfo rinfo, ForkNumber forknum, BlockNumber blkno,
/* Request the page at the end of the last fully replayed LSN. */
XLogRecPtr replay_lsn = GetXLogReplayRecPtr(NULL);
if (MIN_BACKEND_REQUEST_LSN == InvalidXLogRecPtr)
{
/* mark the backend's replay_lsn as "we have a request ongoing", blocking the expiration of any current LSN */
MIN_BACKEND_REQUEST_LSN = replay_lsn;
/* make sure memory operations are in correct order, even in concurrent systems */
pg_memory_barrier();
/* get the current LSN to register */
replay_lsn = GetXLogReplayRecPtr(NULL);
MIN_BACKEND_REQUEST_LSN = replay_lsn;
}
last_replay_lsn = replay_lsn;
for (int i = 0; i < nblocks; i++)
{
neon_request_lsns *result = &output[i];

View File

@@ -107,7 +107,6 @@ uuid.workspace = true
x509-cert.workspace = true
redis.workspace = true
zerocopy.workspace = true
zeroize.workspace = true
# uncomment this to use the real subzero-core crate
# subzero-core = { git = "https://github.com/neondatabase/subzero", rev = "396264617e78e8be428682f87469bb25429af88a", features = ["postgresql"], optional = true }
# this is a stub for the subzero-core crate

View File

@@ -6,7 +6,7 @@ use crate::auth::{self, AuthFlow};
use crate::config::AuthenticationConfig;
use crate::context::RequestContext;
use crate::control_plane::AuthSecret;
use crate::intern::{EndpointIdInt, RoleNameInt};
use crate::intern::EndpointIdInt;
use crate::sasl;
use crate::stream::{self, Stream};
@@ -25,15 +25,13 @@ pub(crate) async fn authenticate_cleartext(
ctx.set_auth_method(crate::context::AuthMethod::Cleartext);
let ep = EndpointIdInt::from(&info.endpoint);
let role = RoleNameInt::from(&info.user);
let auth_flow = AuthFlow::new(
client,
auth::CleartextPassword {
secret,
endpoint: ep,
role,
pool: config.scram_thread_pool.clone(),
pool: config.thread_pool.clone(),
},
);
let auth_outcome = {

View File

@@ -25,7 +25,7 @@ use crate::control_plane::messages::EndpointRateLimitConfig;
use crate::control_plane::{
self, AccessBlockerFlags, AuthSecret, ControlPlaneApi, EndpointAccessControl, RoleAccessControl,
};
use crate::intern::{EndpointIdInt, RoleNameInt};
use crate::intern::EndpointIdInt;
use crate::pqproto::BeMessage;
use crate::proxy::NeonOptions;
use crate::proxy::wake_compute::WakeComputeBackend;
@@ -273,11 +273,9 @@ async fn authenticate_with_secret(
) -> auth::Result<ComputeCredentials> {
if let Some(password) = unauthenticated_password {
let ep = EndpointIdInt::from(&info.endpoint);
let role = RoleNameInt::from(&info.user);
let auth_outcome =
validate_password_and_exchange(&config.scram_thread_pool, ep, role, &password, secret)
.await?;
validate_password_and_exchange(&config.thread_pool, ep, &password, secret).await?;
let keys = match auth_outcome {
crate::sasl::Outcome::Success(key) => key,
crate::sasl::Outcome::Failure(reason) => {
@@ -501,7 +499,7 @@ mod tests {
static CONFIG: Lazy<AuthenticationConfig> = Lazy::new(|| AuthenticationConfig {
jwks_cache: JwkCache::default(),
scram_thread_pool: ThreadPool::new(1),
thread_pool: ThreadPool::new(1),
scram_protocol_timeout: std::time::Duration::from_secs(5),
ip_allowlist_check_enabled: true,
is_vpc_acccess_proxy: false,

View File

@@ -10,7 +10,7 @@ use super::backend::ComputeCredentialKeys;
use super::{AuthError, PasswordHackPayload};
use crate::context::RequestContext;
use crate::control_plane::AuthSecret;
use crate::intern::{EndpointIdInt, RoleNameInt};
use crate::intern::EndpointIdInt;
use crate::pqproto::{BeAuthenticationSaslMessage, BeMessage};
use crate::sasl;
use crate::scram::threadpool::ThreadPool;
@@ -46,7 +46,6 @@ pub(crate) struct PasswordHack;
pub(crate) struct CleartextPassword {
pub(crate) pool: Arc<ThreadPool>,
pub(crate) endpoint: EndpointIdInt,
pub(crate) role: RoleNameInt,
pub(crate) secret: AuthSecret,
}
@@ -112,7 +111,6 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, CleartextPassword> {
let outcome = validate_password_and_exchange(
&self.state.pool,
self.state.endpoint,
self.state.role,
password,
self.state.secret,
)
@@ -167,15 +165,13 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, Scram<'_>> {
pub(crate) async fn validate_password_and_exchange(
pool: &ThreadPool,
endpoint: EndpointIdInt,
role: RoleNameInt,
password: &[u8],
secret: AuthSecret,
) -> super::Result<sasl::Outcome<ComputeCredentialKeys>> {
match secret {
// perform scram authentication as both client and server to validate the keys
AuthSecret::Scram(scram_secret) => {
let outcome =
crate::scram::exchange(pool, endpoint, role, &scram_secret, password).await?;
let outcome = crate::scram::exchange(pool, endpoint, &scram_secret, password).await?;
let client_key = match outcome {
sasl::Outcome::Success(client_key) => client_key,

View File

@@ -29,7 +29,7 @@ use crate::config::{
};
use crate::control_plane::locks::ApiLocks;
use crate::http::health_server::AppMetrics;
use crate::metrics::{Metrics, ServiceInfo};
use crate::metrics::{Metrics, ServiceInfo, ThreadPoolMetrics};
use crate::rate_limiter::{EndpointRateLimiter, LeakyBucketConfig, RateBucketInfo};
use crate::scram::threadpool::ThreadPool;
use crate::serverless::cancel_set::CancelSet;
@@ -114,6 +114,8 @@ pub async fn run() -> anyhow::Result<()> {
let _panic_hook_guard = utils::logging::replace_panic_hook_with_tracing_panic_hook();
let _sentry_guard = init_sentry(Some(GIT_VERSION.into()), &[]);
Metrics::install(Arc::new(ThreadPoolMetrics::new(0)));
// TODO: refactor these to use labels
debug!("Version: {GIT_VERSION}");
debug!("Build_tag: {BUILD_TAG}");
@@ -282,7 +284,7 @@ fn build_config(args: &LocalProxyCliArgs) -> anyhow::Result<&'static ProxyConfig
http_config,
authentication_config: AuthenticationConfig {
jwks_cache: JwkCache::default(),
scram_thread_pool: ThreadPool::new(0),
thread_pool: ThreadPool::new(0),
scram_protocol_timeout: Duration::from_secs(10),
ip_allowlist_check_enabled: true,
is_vpc_acccess_proxy: false,

View File

@@ -26,7 +26,7 @@ use utils::project_git_version;
use utils::sentry_init::init_sentry;
use crate::context::RequestContext;
use crate::metrics::{Metrics, ServiceInfo};
use crate::metrics::{Metrics, ServiceInfo, ThreadPoolMetrics};
use crate::pglb::TlsRequired;
use crate::pqproto::FeStartupPacket;
use crate::protocol2::ConnectionInfo;
@@ -80,6 +80,8 @@ pub async fn run() -> anyhow::Result<()> {
let _panic_hook_guard = utils::logging::replace_panic_hook_with_tracing_panic_hook();
let _sentry_guard = init_sentry(Some(GIT_VERSION.into()), &[]);
Metrics::install(Arc::new(ThreadPoolMetrics::new(0)));
let args = cli().get_matches();
let destination: String = args
.get_one::<String>("dest")

View File

@@ -617,12 +617,7 @@ pub async fn run() -> anyhow::Result<()> {
/// ProxyConfig is created at proxy startup, and lives forever.
fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
let thread_pool = ThreadPool::new(args.scram_thread_pool_size);
Metrics::get()
.proxy
.scram_pool
.0
.set(thread_pool.metrics.clone())
.ok();
Metrics::install(thread_pool.metrics.clone());
let tls_config = match (&args.tls_key, &args.tls_cert) {
(Some(key_path), Some(cert_path)) => Some(config::configure_tls(
@@ -695,7 +690,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
};
let authentication_config = AuthenticationConfig {
jwks_cache: JwkCache::default(),
scram_thread_pool: thread_pool,
thread_pool,
scram_protocol_timeout: args.scram_protocol_timeout,
ip_allowlist_check_enabled: !args.is_private_access_proxy,
is_vpc_acccess_proxy: args.is_private_access_proxy,

View File

@@ -8,7 +8,6 @@ use futures::{FutureExt, TryFutureExt};
use itertools::Itertools;
use postgres_client::config::{AuthKeys, ChannelBinding, SslMode};
use postgres_client::connect_raw::StartupStream;
use postgres_client::error::SqlState;
use postgres_client::maybe_tls_stream::MaybeTlsStream;
use postgres_client::tls::MakeTlsConnect;
use thiserror::Error;
@@ -23,7 +22,7 @@ use crate::context::RequestContext;
use crate::control_plane::client::ApiLockError;
use crate::control_plane::errors::WakeComputeError;
use crate::control_plane::messages::MetricsAuxInfo;
use crate::error::{ErrorKind, ReportableError, UserFacingError};
use crate::error::{ReportableError, UserFacingError};
use crate::metrics::{Metrics, NumDbConnectionsGuard};
use crate::pqproto::StartupMessageParams;
use crate::proxy::connect_compute::TlsNegotiation;
@@ -66,13 +65,12 @@ impl UserFacingError for PostgresError {
}
impl ReportableError for PostgresError {
fn get_error_kind(&self) -> ErrorKind {
fn get_error_kind(&self) -> crate::error::ErrorKind {
match self {
PostgresError::Postgres(err) => match err.as_db_error() {
Some(err) if err.code() == &SqlState::INVALID_CATALOG_NAME => ErrorKind::User,
Some(_) => ErrorKind::Postgres,
None => ErrorKind::Compute,
},
PostgresError::Postgres(e) if e.as_db_error().is_some() => {
crate::error::ErrorKind::Postgres
}
PostgresError::Postgres(_) => crate::error::ErrorKind::Compute,
}
}
}
@@ -112,9 +110,9 @@ impl UserFacingError for ConnectionError {
}
impl ReportableError for ConnectionError {
fn get_error_kind(&self) -> ErrorKind {
fn get_error_kind(&self) -> crate::error::ErrorKind {
match self {
ConnectionError::TlsError(_) => ErrorKind::Compute,
ConnectionError::TlsError(_) => crate::error::ErrorKind::Compute,
ConnectionError::WakeComputeError(e) => e.get_error_kind(),
ConnectionError::TooManyConnectionAttempts(e) => e.get_error_kind(),
#[cfg(test)]

View File

@@ -19,7 +19,7 @@ use crate::control_plane::messages::{EndpointJwksResponse, JwksSettings};
use crate::ext::TaskExt;
use crate::intern::RoleNameInt;
use crate::rate_limiter::{RateLimitAlgorithm, RateLimiterConfig};
use crate::scram;
use crate::scram::threadpool::ThreadPool;
use crate::serverless::GlobalConnPoolOptions;
use crate::serverless::cancel_set::CancelSet;
#[cfg(feature = "rest_broker")]
@@ -75,7 +75,7 @@ pub struct HttpConfig {
}
pub struct AuthenticationConfig {
pub scram_thread_pool: Arc<scram::threadpool::ThreadPool>,
pub thread_pool: Arc<ThreadPool>,
pub scram_protocol_timeout: tokio::time::Duration,
pub ip_allowlist_check_enabled: bool,
pub is_vpc_acccess_proxy: bool,

View File

@@ -5,7 +5,6 @@ use measured::label::{
FixedCardinalitySet, LabelGroupSet, LabelGroupVisitor, LabelName, LabelSet, LabelValue,
StaticLabelSet,
};
use measured::metric::group::Encoding;
use measured::metric::histogram::Thresholds;
use measured::metric::name::MetricName;
use measured::{
@@ -19,10 +18,10 @@ use crate::control_plane::messages::ColdStartInfo;
use crate::error::ErrorKind;
#[derive(MetricGroup)]
#[metric(new())]
#[metric(new(thread_pool: Arc<ThreadPoolMetrics>))]
pub struct Metrics {
#[metric(namespace = "proxy")]
#[metric(init = ProxyMetrics::new())]
#[metric(init = ProxyMetrics::new(thread_pool))]
pub proxy: ProxyMetrics,
#[metric(namespace = "wake_compute_lock")]
@@ -35,27 +34,34 @@ pub struct Metrics {
pub cache: CacheMetrics,
}
static SELF: OnceLock<Metrics> = OnceLock::new();
impl Metrics {
#[track_caller]
pub fn install(thread_pool: Arc<ThreadPoolMetrics>) {
let mut metrics = Metrics::new(thread_pool);
metrics.proxy.errors_total.init_all_dense();
metrics.proxy.redis_errors_total.init_all_dense();
metrics.proxy.redis_events_count.init_all_dense();
metrics.proxy.retries_metric.init_all_dense();
metrics.proxy.connection_failures_total.init_all_dense();
SELF.set(metrics)
.ok()
.expect("proxy metrics must not be installed more than once");
}
pub fn get() -> &'static Self {
static SELF: OnceLock<Metrics> = OnceLock::new();
#[cfg(test)]
return SELF.get_or_init(|| Metrics::new(Arc::new(ThreadPoolMetrics::new(0))));
SELF.get_or_init(|| {
let mut metrics = Metrics::new();
metrics.proxy.errors_total.init_all_dense();
metrics.proxy.redis_errors_total.init_all_dense();
metrics.proxy.redis_events_count.init_all_dense();
metrics.proxy.retries_metric.init_all_dense();
metrics.proxy.connection_failures_total.init_all_dense();
metrics
})
#[cfg(not(test))]
SELF.get()
.expect("proxy metrics must be installed by the main() function")
}
}
#[derive(MetricGroup)]
#[metric(new())]
#[metric(new(thread_pool: Arc<ThreadPoolMetrics>))]
pub struct ProxyMetrics {
#[metric(flatten)]
pub db_connections: CounterPairVec<NumDbConnectionsGauge>,
@@ -128,9 +134,6 @@ pub struct ProxyMetrics {
/// Number of TLS handshake failures
pub tls_handshake_failures: Counter,
/// Number of SHA 256 rounds executed.
pub sha_rounds: Counter,
/// HLL approximate cardinality of endpoints that are connecting
pub connecting_endpoints: HyperLogLogVec<StaticLabelSet<Protocol>, 32>,
@@ -148,25 +151,8 @@ pub struct ProxyMetrics {
pub connect_compute_lock: ApiLockMetrics,
#[metric(namespace = "scram_pool")]
pub scram_pool: OnceLockWrapper<Arc<ThreadPoolMetrics>>,
}
/// A Wrapper over [`OnceLock`] to implement [`MetricGroup`].
pub struct OnceLockWrapper<T>(pub OnceLock<T>);
impl<T> Default for OnceLockWrapper<T> {
fn default() -> Self {
Self(OnceLock::new())
}
}
impl<Enc: Encoding, T: MetricGroup<Enc>> MetricGroup<Enc> for OnceLockWrapper<T> {
fn collect_group_into(&self, enc: &mut Enc) -> Result<(), Enc::Err> {
if let Some(inner) = self.0.get() {
inner.collect_group_into(enc)?;
}
Ok(())
}
#[metric(init = thread_pool)]
pub scram_pool: Arc<ThreadPoolMetrics>,
}
#[derive(MetricGroup)]
@@ -567,6 +553,14 @@ impl From<bool> for Bool {
}
}
#[derive(LabelGroup)]
#[label(set = InvalidEndpointsSet)]
pub struct InvalidEndpointsGroup {
pub protocol: Protocol,
pub rejected: Bool,
pub outcome: ConnectOutcome,
}
#[derive(LabelGroup)]
#[label(set = RetriesMetricSet)]
pub struct RetriesMetricGroup {
@@ -733,7 +727,6 @@ pub enum CacheKind {
ProjectInfoEndpoints,
ProjectInfoRoles,
Schema,
Pbkdf2,
}
#[derive(FixedCardinalityLabel, Clone, Copy, Debug)]

View File

@@ -1,84 +0,0 @@
use tokio::time::Instant;
use zeroize::Zeroize as _;
use super::pbkdf2;
use crate::cache::Cached;
use crate::cache::common::{Cache, count_cache_insert, count_cache_outcome, eviction_listener};
use crate::intern::{EndpointIdInt, RoleNameInt};
use crate::metrics::{CacheKind, Metrics};
pub(crate) struct Pbkdf2Cache(moka::sync::Cache<(EndpointIdInt, RoleNameInt), Pbkdf2CacheEntry>);
pub(crate) type CachedPbkdf2<'a> = Cached<&'a Pbkdf2Cache>;
impl Cache for Pbkdf2Cache {
type Key = (EndpointIdInt, RoleNameInt);
type Value = Pbkdf2CacheEntry;
fn invalidate(&self, info: &(EndpointIdInt, RoleNameInt)) {
self.0.invalidate(info);
}
}
/// To speed up password hashing for more active customers, we store the tail results of the
/// PBKDF2 algorithm. If the output of PBKDF2 is U1 ^ U2 ^ ⋯ ^ Uc, then we store
/// suffix = U17 ^ U18 ^ ⋯ ^ Uc. We only need to calculate U1 ^ U2 ^ ⋯ ^ U15 ^ U16
/// to determine the final result.
///
/// The suffix alone isn't enough to crack the password. The stored_key is still required.
/// While both are cached in memory, given they're in different locations is makes it much
/// harder to exploit, even if any such memory exploit exists in proxy.
#[derive(Clone)]
pub struct Pbkdf2CacheEntry {
/// corresponds to [`super::ServerSecret::cached_at`]
pub(super) cached_from: Instant,
pub(super) suffix: pbkdf2::Block,
}
impl Drop for Pbkdf2CacheEntry {
fn drop(&mut self) {
self.suffix.zeroize();
}
}
impl Pbkdf2Cache {
pub fn new() -> Self {
const SIZE: u64 = 100;
const TTL: std::time::Duration = std::time::Duration::from_secs(60);
let builder = moka::sync::Cache::builder()
.name("pbkdf2")
.max_capacity(SIZE)
// We use time_to_live so we don't refresh the lifetime for an invalid password attempt.
.time_to_live(TTL);
Metrics::get()
.cache
.capacity
.set(CacheKind::Pbkdf2, SIZE as i64);
let builder =
builder.eviction_listener(|_k, _v, cause| eviction_listener(CacheKind::Pbkdf2, cause));
Self(builder.build())
}
pub fn insert(&self, endpoint: EndpointIdInt, role: RoleNameInt, value: Pbkdf2CacheEntry) {
count_cache_insert(CacheKind::Pbkdf2);
self.0.insert((endpoint, role), value);
}
fn get(&self, endpoint: EndpointIdInt, role: RoleNameInt) -> Option<Pbkdf2CacheEntry> {
count_cache_outcome(CacheKind::Pbkdf2, self.0.get(&(endpoint, role)))
}
pub fn get_entry(
&self,
endpoint: EndpointIdInt,
role: RoleNameInt,
) -> Option<CachedPbkdf2<'_>> {
self.get(endpoint, role).map(|value| Cached {
token: Some((self, (endpoint, role))),
value,
})
}
}

View File

@@ -4,8 +4,10 @@ use std::convert::Infallible;
use base64::Engine as _;
use base64::prelude::BASE64_STANDARD;
use tracing::{debug, trace};
use hmac::{Hmac, Mac};
use sha2::Sha256;
use super::ScramKey;
use super::messages::{
ClientFinalMessage, ClientFirstMessage, OwnedServerFirstMessage, SCRAM_RAW_NONCE_LEN,
};
@@ -13,10 +15,8 @@ use super::pbkdf2::Pbkdf2;
use super::secret::ServerSecret;
use super::signature::SignatureBuilder;
use super::threadpool::ThreadPool;
use super::{ScramKey, pbkdf2};
use crate::intern::{EndpointIdInt, RoleNameInt};
use crate::intern::EndpointIdInt;
use crate::sasl::{self, ChannelBinding, Error as SaslError};
use crate::scram::cache::Pbkdf2CacheEntry;
/// The only channel binding mode we currently support.
#[derive(Debug)]
@@ -77,113 +77,46 @@ impl<'a> Exchange<'a> {
}
}
// copied from <https://github.com/neondatabase/rust-postgres/blob/20031d7a9ee1addeae6e0968e3899ae6bf01cee2/postgres-protocol/src/authentication/sasl.rs#L236-L248>
async fn derive_client_key(
pool: &ThreadPool,
endpoint: EndpointIdInt,
password: &[u8],
salt: &[u8],
iterations: u32,
) -> pbkdf2::Block {
pool.spawn_job(endpoint, Pbkdf2::start(password, salt, iterations))
.await
) -> ScramKey {
let salted_password = pool
.spawn_job(endpoint, Pbkdf2::start(password, salt, iterations))
.await;
let make_key = |name| {
let key = Hmac::<Sha256>::new_from_slice(&salted_password)
.expect("HMAC is able to accept all key sizes")
.chain_update(name)
.finalize();
<[u8; 32]>::from(key.into_bytes())
};
make_key(b"Client Key").into()
}
/// For cleartext flow, we need to derive the client key to
/// 1. authenticate the client.
/// 2. authenticate with compute.
pub(crate) async fn exchange(
pool: &ThreadPool,
endpoint: EndpointIdInt,
role: RoleNameInt,
secret: &ServerSecret,
password: &[u8],
) -> sasl::Result<sasl::Outcome<super::ScramKey>> {
if secret.iterations > CACHED_ROUNDS {
exchange_with_cache(pool, endpoint, role, secret, password).await
} else {
let salt = BASE64_STANDARD.decode(&*secret.salt_base64)?;
let hash = derive_client_key(pool, endpoint, password, &salt, secret.iterations).await;
Ok(validate_pbkdf2(secret, &hash))
}
}
/// Compute the client key using a cache. We cache the suffix of the pbkdf2 result only,
/// which is not enough by itself to perform an offline brute force.
async fn exchange_with_cache(
pool: &ThreadPool,
endpoint: EndpointIdInt,
role: RoleNameInt,
secret: &ServerSecret,
password: &[u8],
) -> sasl::Result<sasl::Outcome<super::ScramKey>> {
let salt = BASE64_STANDARD.decode(&*secret.salt_base64)?;
let client_key = derive_client_key(pool, endpoint, password, &salt, secret.iterations).await;
debug_assert!(
secret.iterations > CACHED_ROUNDS,
"we should not cache password data if there isn't enough rounds needed"
);
// compute the prefix of the pbkdf2 output.
let prefix = derive_client_key(pool, endpoint, password, &salt, CACHED_ROUNDS).await;
if let Some(entry) = pool.cache.get_entry(endpoint, role) {
// hot path: let's check the threadpool cache
if secret.cached_at == entry.cached_from {
// cache is valid. compute the full hash by adding the prefix to the suffix.
let mut hash = prefix;
pbkdf2::xor_assign(&mut hash, &entry.suffix);
let outcome = validate_pbkdf2(secret, &hash);
if matches!(outcome, sasl::Outcome::Success(_)) {
trace!("password validated from cache");
}
return Ok(outcome);
}
// cached key is no longer valid.
debug!("invalidating cached password");
entry.invalidate();
}
// slow path: full password hash.
let hash = derive_client_key(pool, endpoint, password, &salt, secret.iterations).await;
let outcome = validate_pbkdf2(secret, &hash);
let client_key = match outcome {
sasl::Outcome::Success(client_key) => client_key,
sasl::Outcome::Failure(_) => return Ok(outcome),
};
trace!("storing cached password");
// time to cache, compute the suffix by subtracting the prefix from the hash.
let mut suffix = hash;
pbkdf2::xor_assign(&mut suffix, &prefix);
pool.cache.insert(
endpoint,
role,
Pbkdf2CacheEntry {
cached_from: secret.cached_at,
suffix,
},
);
Ok(sasl::Outcome::Success(client_key))
}
fn validate_pbkdf2(secret: &ServerSecret, hash: &pbkdf2::Block) -> sasl::Outcome<ScramKey> {
let client_key = super::ScramKey::client_key(&(*hash).into());
if secret.is_password_invalid(&client_key).into() {
sasl::Outcome::Failure("password doesn't match")
Ok(sasl::Outcome::Failure("password doesn't match"))
} else {
sasl::Outcome::Success(client_key)
Ok(sasl::Outcome::Success(client_key))
}
}
const CACHED_ROUNDS: u32 = 16;
impl SaslInitial {
fn transition(
&self,

View File

@@ -1,12 +1,6 @@
//! Tools for client/server/stored key management.
use hmac::Mac as _;
use sha2::Digest as _;
use subtle::ConstantTimeEq;
use zeroize::Zeroize as _;
use crate::metrics::Metrics;
use crate::scram::pbkdf2::Prf;
/// Faithfully taken from PostgreSQL.
pub(crate) const SCRAM_KEY_LEN: usize = 32;
@@ -20,12 +14,6 @@ pub(crate) struct ScramKey {
bytes: [u8; SCRAM_KEY_LEN],
}
impl Drop for ScramKey {
fn drop(&mut self) {
self.bytes.zeroize();
}
}
impl PartialEq for ScramKey {
fn eq(&self, other: &Self) -> bool {
self.ct_eq(other).into()
@@ -40,26 +28,12 @@ impl ConstantTimeEq for ScramKey {
impl ScramKey {
pub(crate) fn sha256(&self) -> Self {
Metrics::get().proxy.sha_rounds.inc_by(1);
Self {
bytes: sha2::Sha256::digest(self.as_bytes()).into(),
}
super::sha256([self.as_ref()]).into()
}
pub(crate) fn as_bytes(&self) -> [u8; SCRAM_KEY_LEN] {
self.bytes
}
pub(crate) fn client_key(b: &[u8; 32]) -> Self {
// Prf::new_from_slice will run 2 sha256 rounds.
// Update + Finalize run 2 sha256 rounds.
Metrics::get().proxy.sha_rounds.inc_by(4);
let mut prf = Prf::new_from_slice(b).expect("HMAC is able to accept all key sizes");
prf.update(b"Client Key");
let client_key: [u8; 32] = prf.finalize().into_bytes().into();
client_key.into()
}
}
impl From<[u8; SCRAM_KEY_LEN]> for ScramKey {

View File

@@ -6,7 +6,6 @@
//! * <https://github.com/postgres/postgres/blob/94226d4506e66d6e7cbf4b391f1e7393c1962841/src/backend/libpq/auth-scram.c>
//! * <https://github.com/postgres/postgres/blob/94226d4506e66d6e7cbf4b391f1e7393c1962841/src/interfaces/libpq/fe-auth-scram.c>
mod cache;
mod countmin;
mod exchange;
mod key;
@@ -19,8 +18,10 @@ pub mod threadpool;
use base64::Engine as _;
use base64::prelude::BASE64_STANDARD;
pub(crate) use exchange::{Exchange, exchange};
use hmac::{Hmac, Mac};
pub(crate) use key::ScramKey;
pub(crate) use secret::ServerSecret;
use sha2::{Digest, Sha256};
const SCRAM_SHA_256: &str = "SCRAM-SHA-256";
const SCRAM_SHA_256_PLUS: &str = "SCRAM-SHA-256-PLUS";
@@ -41,13 +42,29 @@ fn base64_decode_array<const N: usize>(input: impl AsRef<[u8]>) -> Option<[u8; N
Some(bytes)
}
/// This function essentially is `Hmac(sha256, key, input)`.
/// Further reading: <https://datatracker.ietf.org/doc/html/rfc2104>.
fn hmac_sha256<'a>(key: &[u8], parts: impl IntoIterator<Item = &'a [u8]>) -> [u8; 32] {
let mut mac = Hmac::<Sha256>::new_from_slice(key).expect("bad key size");
parts.into_iter().for_each(|s| mac.update(s));
mac.finalize().into_bytes().into()
}
fn sha256<'a>(parts: impl IntoIterator<Item = &'a [u8]>) -> [u8; 32] {
let mut hasher = Sha256::new();
parts.into_iter().for_each(|s| hasher.update(s));
hasher.finalize().into()
}
#[cfg(test)]
mod tests {
use super::threadpool::ThreadPool;
use super::{Exchange, ServerSecret};
use crate::intern::{EndpointIdInt, RoleNameInt};
use crate::intern::EndpointIdInt;
use crate::sasl::{Mechanism, Step};
use crate::types::{EndpointId, RoleName};
use crate::types::EndpointId;
#[test]
fn snapshot() {
@@ -97,34 +114,23 @@ mod tests {
);
}
async fn check(
pool: &ThreadPool,
scram_secret: &ServerSecret,
password: &[u8],
) -> Result<(), &'static str> {
async fn run_round_trip_test(server_password: &str, client_password: &str) {
let pool = ThreadPool::new(1);
let ep = EndpointId::from("foo");
let ep = EndpointIdInt::from(ep);
let role = RoleName::from("user");
let role = RoleNameInt::from(&role);
let outcome = super::exchange(pool, ep, role, scram_secret, password)
let scram_secret = ServerSecret::build(server_password).await.unwrap();
let outcome = super::exchange(&pool, ep, &scram_secret, client_password.as_bytes())
.await
.unwrap();
match outcome {
crate::sasl::Outcome::Success(_) => Ok(()),
crate::sasl::Outcome::Failure(r) => Err(r),
crate::sasl::Outcome::Success(_) => {}
crate::sasl::Outcome::Failure(r) => panic!("{r}"),
}
}
async fn run_round_trip_test(server_password: &str, client_password: &str) {
let pool = ThreadPool::new(1);
let scram_secret = ServerSecret::build(server_password).await.unwrap();
check(&pool, &scram_secret, client_password.as_bytes())
.await
.unwrap();
}
#[tokio::test]
async fn round_trip() {
run_round_trip_test("pencil", "pencil").await;
@@ -135,27 +141,4 @@ mod tests {
async fn failure() {
run_round_trip_test("pencil", "eraser").await;
}
#[tokio::test]
#[tracing_test::traced_test]
async fn password_cache() {
let pool = ThreadPool::new(1);
let scram_secret = ServerSecret::build("password").await.unwrap();
// wrong passwords are not added to cache
check(&pool, &scram_secret, b"wrong").await.unwrap_err();
assert!(!logs_contain("storing cached password"));
// correct passwords get cached
check(&pool, &scram_secret, b"password").await.unwrap();
assert!(logs_contain("storing cached password"));
// wrong passwords do not match the cache
check(&pool, &scram_secret, b"wrong").await.unwrap_err();
assert!(!logs_contain("password validated from cache"));
// correct passwords match the cache
check(&pool, &scram_secret, b"password").await.unwrap();
assert!(logs_contain("password validated from cache"));
}
}

View File

@@ -1,50 +1,25 @@
//! For postgres password authentication, we need to perform a PBKDF2 using
//! PRF=HMAC-SHA2-256, producing only 1 block (32 bytes) of output key.
use hmac::Mac as _;
use hmac::digest::consts::U32;
use hmac::digest::generic_array::GenericArray;
use zeroize::Zeroize as _;
use crate::metrics::Metrics;
/// The Psuedo-random function used during PBKDF2 and the SCRAM-SHA-256 handshake.
pub type Prf = hmac::Hmac<sha2::Sha256>;
pub(crate) type Block = GenericArray<u8, U32>;
use hmac::{Hmac, Mac};
use sha2::Sha256;
pub(crate) struct Pbkdf2 {
hmac: Prf,
/// U{r-1} for whatever iteration r we are currently on.
prev: Block,
/// the output of `fold(xor, U{1}..U{r})` for whatever iteration r we are currently on.
hi: Block,
/// number of iterations left
hmac: Hmac<Sha256>,
prev: GenericArray<u8, U32>,
hi: GenericArray<u8, U32>,
iterations: u32,
}
impl Drop for Pbkdf2 {
fn drop(&mut self) {
self.prev.zeroize();
self.hi.zeroize();
}
}
// inspired from <https://github.com/neondatabase/rust-postgres/blob/20031d7a9ee1addeae6e0968e3899ae6bf01cee2/postgres-protocol/src/authentication/sasl.rs#L36-L61>
impl Pbkdf2 {
pub(crate) fn start(pw: &[u8], salt: &[u8], iterations: u32) -> Self {
pub(crate) fn start(str: &[u8], salt: &[u8], iterations: u32) -> Self {
// key the HMAC and derive the first block in-place
let mut hmac = Prf::new_from_slice(pw).expect("HMAC is able to accept all key sizes");
// U1 = PRF(Password, Salt + INT_32_BE(i))
// i = 1 since we only need 1 block of output.
let mut hmac =
Hmac::<Sha256>::new_from_slice(str).expect("HMAC is able to accept all key sizes");
hmac.update(salt);
hmac.update(&1u32.to_be_bytes());
let init_block = hmac.finalize_reset().into_bytes();
// Prf::new_from_slice will run 2 sha256 rounds.
// Our update + finalize run 2 sha256 rounds for each pbkdf2 round.
Metrics::get().proxy.sha_rounds.inc_by(4);
Self {
hmac,
// one iteration spent above
@@ -58,11 +33,7 @@ impl Pbkdf2 {
(self.iterations).clamp(0, 4096)
}
/// For "fairness", we implement PBKDF2 with cooperative yielding, which is why we use this `turn`
/// function that only executes a fixed number of iterations before continuing.
///
/// Task must be rescheuled if this returns [`std::task::Poll::Pending`].
pub(crate) fn turn(&mut self) -> std::task::Poll<Block> {
pub(crate) fn turn(&mut self) -> std::task::Poll<[u8; 32]> {
let Self {
hmac,
prev,
@@ -73,37 +44,25 @@ impl Pbkdf2 {
// only do up to 4096 iterations per turn for fairness
let n = (*iterations).clamp(0, 4096);
for _ in 0..n {
let next = single_round(hmac, prev);
xor_assign(hi, &next);
*prev = next;
}
hmac.update(prev);
let block = hmac.finalize_reset().into_bytes();
// Our update + finalize run 2 sha256 rounds for each pbkdf2 round.
Metrics::get().proxy.sha_rounds.inc_by(2 * n as u64);
for (hi_byte, &b) in hi.iter_mut().zip(block.iter()) {
*hi_byte ^= b;
}
*prev = block;
}
*iterations -= n;
if *iterations == 0 {
std::task::Poll::Ready(*hi)
std::task::Poll::Ready((*hi).into())
} else {
std::task::Poll::Pending
}
}
}
#[inline(always)]
pub fn xor_assign(x: &mut Block, y: &Block) {
for (x, &y) in std::iter::zip(x, y) {
*x ^= y;
}
}
#[inline(always)]
fn single_round(prf: &mut Prf, ui: &Block) -> Block {
// Ui = PRF(Password, Ui-1)
prf.update(ui);
prf.finalize_reset().into_bytes()
}
#[cfg(test)]
mod tests {
use pbkdf2::pbkdf2_hmac_array;
@@ -117,11 +76,11 @@ mod tests {
let pass = b"Ne0n_!5_50_C007";
let mut job = Pbkdf2::start(pass, salt, 60000);
let hash: [u8; 32] = loop {
let hash = loop {
let std::task::Poll::Ready(hash) = job.turn() else {
continue;
};
break hash.into();
break hash;
};
let expected = pbkdf2_hmac_array::<Sha256, 32>(pass, salt, 60000);

View File

@@ -3,7 +3,6 @@
use base64::Engine as _;
use base64::prelude::BASE64_STANDARD;
use subtle::{Choice, ConstantTimeEq};
use tokio::time::Instant;
use super::base64_decode_array;
use super::key::ScramKey;
@@ -12,9 +11,6 @@ use super::key::ScramKey;
/// and is used throughout the authentication process.
#[derive(Clone, Eq, PartialEq, Debug)]
pub(crate) struct ServerSecret {
/// When this secret was cached.
pub(crate) cached_at: Instant,
/// Number of iterations for `PBKDF2` function.
pub(crate) iterations: u32,
/// Salt used to hash user's password.
@@ -38,7 +34,6 @@ impl ServerSecret {
params.split_once(':').zip(keys.split_once(':'))?;
let secret = ServerSecret {
cached_at: Instant::now(),
iterations: iterations.parse().ok()?,
salt_base64: salt.into(),
stored_key: base64_decode_array(stored_key)?.into(),
@@ -59,7 +54,6 @@ impl ServerSecret {
/// See `auth-scram.c : mock_scram_secret` for details.
pub(crate) fn mock(nonce: [u8; 32]) -> Self {
Self {
cached_at: Instant::now(),
// this doesn't reveal much information as we're going to use
// iteration count 1 for our generated passwords going forward.
// PG16 users can set iteration count=1 already today.

View File

@@ -1,10 +1,6 @@
//! Tools for client/server signature management.
use hmac::Mac as _;
use super::key::{SCRAM_KEY_LEN, ScramKey};
use crate::metrics::Metrics;
use crate::scram::pbkdf2::Prf;
/// A collection of message parts needed to derive the client's signature.
#[derive(Debug)]
@@ -16,18 +12,15 @@ pub(crate) struct SignatureBuilder<'a> {
impl SignatureBuilder<'_> {
pub(crate) fn build(&self, key: &ScramKey) -> Signature {
// don't know exactly. this is a rough approx
Metrics::get().proxy.sha_rounds.inc_by(8);
let parts = [
self.client_first_message_bare.as_bytes(),
b",",
self.server_first_message.as_bytes(),
b",",
self.client_final_message_without_proof.as_bytes(),
];
let mut mac = Prf::new_from_slice(key.as_ref()).expect("HMAC accepts all key sizes");
mac.update(self.client_first_message_bare.as_bytes());
mac.update(b",");
mac.update(self.server_first_message.as_bytes());
mac.update(b",");
mac.update(self.client_final_message_without_proof.as_bytes());
Signature {
bytes: mac.finalize().into_bytes().into(),
}
super::hmac_sha256(key.as_ref(), parts).into()
}
}

View File

@@ -15,8 +15,6 @@ use futures::FutureExt;
use rand::rngs::SmallRng;
use rand::{Rng, SeedableRng};
use super::cache::Pbkdf2Cache;
use super::pbkdf2;
use super::pbkdf2::Pbkdf2;
use crate::intern::EndpointIdInt;
use crate::metrics::{ThreadPoolMetrics, ThreadPoolWorkerId};
@@ -25,10 +23,6 @@ use crate::scram::countmin::CountMinSketch;
pub struct ThreadPool {
runtime: Option<tokio::runtime::Runtime>,
pub metrics: Arc<ThreadPoolMetrics>,
// we hash a lot of passwords.
// we keep a cache of partial hashes for faster validation.
pub(super) cache: Pbkdf2Cache,
}
/// How often to reset the sketch values
@@ -74,7 +68,6 @@ impl ThreadPool {
Self {
runtime: Some(runtime),
metrics: Arc::new(ThreadPoolMetrics::new(n_workers as usize)),
cache: Pbkdf2Cache::new(),
}
})
}
@@ -137,7 +130,7 @@ struct JobSpec {
}
impl Future for JobSpec {
type Output = pbkdf2::Block;
type Output = [u8; 32];
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
STATE.with_borrow_mut(|state| {
@@ -173,10 +166,10 @@ impl Future for JobSpec {
}
}
pub(crate) struct JobHandle(tokio::task::JoinHandle<pbkdf2::Block>);
pub(crate) struct JobHandle(tokio::task::JoinHandle<[u8; 32]>);
impl Future for JobHandle {
type Output = pbkdf2::Block;
type Output = [u8; 32];
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match self.0.poll_unpin(cx) {
@@ -210,10 +203,10 @@ mod tests {
.spawn_job(ep, Pbkdf2::start(b"password", &salt, 4096))
.await;
let expected = &[
let expected = [
10, 114, 73, 188, 140, 222, 196, 156, 214, 184, 79, 157, 119, 242, 16, 31, 53, 242,
178, 43, 95, 8, 225, 182, 122, 40, 219, 21, 89, 147, 64, 140,
];
assert_eq!(actual.as_slice(), expected);
assert_eq!(actual, expected);
}
}

View File

@@ -4,7 +4,6 @@ use std::time::Duration;
use ed25519_dalek::SigningKey;
use hyper_util::rt::{TokioExecutor, TokioIo, TokioTimer};
use jose_jwk::jose_b64;
use postgres_client::error::SqlState;
use postgres_client::maybe_tls_stream::MaybeTlsStream;
use rand_core::OsRng;
use tracing::field::display;
@@ -27,7 +26,7 @@ use crate::context::RequestContext;
use crate::control_plane::client::ApiLockError;
use crate::control_plane::errors::{GetAuthInfoError, WakeComputeError};
use crate::error::{ErrorKind, ReportableError, UserFacingError};
use crate::intern::{EndpointIdInt, RoleNameInt};
use crate::intern::EndpointIdInt;
use crate::pqproto::StartupMessageParams;
use crate::proxy::{connect_auth, connect_compute};
use crate::rate_limiter::EndpointRateLimiter;
@@ -77,11 +76,9 @@ impl PoolingBackend {
};
let ep = EndpointIdInt::from(&user_info.endpoint);
let role = RoleNameInt::from(&user_info.user);
let auth_outcome = crate::auth::validate_password_and_exchange(
&self.config.authentication_config.scram_thread_pool,
&self.config.authentication_config.thread_pool,
ep,
role,
password,
secret,
)
@@ -460,14 +457,15 @@ impl ReportableError for HttpConnError {
match self {
HttpConnError::ConnectError(_) => ErrorKind::Compute,
HttpConnError::ConnectionClosedAbruptly(_) => ErrorKind::Compute,
HttpConnError::PostgresConnectionError(p) => match p.as_db_error() {
// user provided a wrong database name
Some(err) if err.code() == &SqlState::INVALID_CATALOG_NAME => ErrorKind::User,
// postgres rejected the connection
Some(_) => ErrorKind::Postgres,
// couldn't even reach postgres
None => ErrorKind::Compute,
},
HttpConnError::PostgresConnectionError(p) => {
if p.as_db_error().is_some() {
// postgres rejected the connection
ErrorKind::Postgres
} else {
// couldn't even reach postgres
ErrorKind::Compute
}
}
HttpConnError::LocalProxyConnectionError(_) => ErrorKind::Compute,
HttpConnError::ComputeCtl(_) => ErrorKind::Service,
HttpConnError::JwtPayloadError(_) => ErrorKind::User,

View File

@@ -192,29 +192,34 @@ pub(crate) async fn handle(
let line = get(db_error, |db| db.line().map(|l| l.to_string()));
let routine = get(db_error, |db| db.routine());
if db_error.is_some() && error_kind == ErrorKind::User {
// this error contains too much info, and it's not an error we care about.
if tracing::enabled!(Level::DEBUG) {
debug!(
match &e {
SqlOverHttpError::Postgres(e)
if e.as_db_error().is_some() && error_kind == ErrorKind::User =>
{
// this error contains too much info, and it's not an error we care about.
if tracing::enabled!(Level::DEBUG) {
tracing::debug!(
kind=error_kind.to_metric_label(),
error=%e,
msg=message,
"forwarding error to user"
);
} else {
tracing::info!(
kind = error_kind.to_metric_label(),
error = "bad query",
"forwarding error to user"
);
}
}
_ => {
tracing::info!(
kind=error_kind.to_metric_label(),
error=%e,
msg=message,
"forwarding error to user"
);
} else {
info!(
kind = error_kind.to_metric_label(),
error = "bad query",
"forwarding error to user"
);
}
} else {
info!(
kind=error_kind.to_metric_label(),
error=%e,
msg=message,
"forwarding error to user"
);
}
json_response(

View File

@@ -102,7 +102,7 @@ pub struct ReportedError {
}
impl ReportedError {
pub fn new(e: impl UserFacingError + Into<anyhow::Error>) -> Self {
pub fn new(e: (impl UserFacingError + Into<anyhow::Error>)) -> Self {
let error_kind = e.get_error_kind();
Self {
source: e.into(),

View File

@@ -12,7 +12,7 @@ use futures::stream::{self, FuturesOrdered};
use postgres_ffi::v14::xlog_utils::XLogSegNoOffsetToRecPtr;
use postgres_ffi::{PG_TLI, XLogFileName, XLogSegNo};
use remote_storage::{
DownloadError, DownloadOpts, GenericRemoteStorage, ListingMode, RemotePath, StorageMetadata,
DownloadOpts, GenericRemoteStorage, ListingMode, RemotePath, StorageMetadata,
};
use safekeeper_api::models::PeerInfo;
use tokio::fs::File;
@@ -607,9 +607,6 @@ pub(crate) async fn copy_partial_segment(
storage.copy_object(source, destination, &cancel).await
}
const WAL_READ_WARN_THRESHOLD: u32 = 2;
const WAL_READ_MAX_RETRIES: u32 = 3;
pub async fn read_object(
storage: &GenericRemoteStorage,
file_path: &RemotePath,
@@ -623,23 +620,12 @@ pub async fn read_object(
byte_start: std::ops::Bound::Included(offset),
..Default::default()
};
// This retry only solves the connect errors: subsequent reads can still fail as this function returns
// a stream.
let download = backoff::retry(
|| async { storage.download(file_path, &opts, &cancel).await },
DownloadError::is_permanent,
WAL_READ_WARN_THRESHOLD,
WAL_READ_MAX_RETRIES,
"download WAL segment",
&cancel,
)
.await
.ok_or_else(|| DownloadError::Cancelled)
.and_then(|x| x)
.with_context(|| {
format!("Failed to open WAL segment download stream for remote path {file_path:?}")
})?;
let download = storage
.download(file_path, &opts, &cancel)
.await
.with_context(|| {
format!("Failed to open WAL segment download stream for remote path {file_path:?}")
})?;
let reader = tokio_util::io::StreamReader::new(download.download_stream);

View File

@@ -74,4 +74,4 @@ http-utils = { path = "../libs/http-utils/" }
utils = { path = "../libs/utils/" }
metrics = { path = "../libs/metrics/" }
control_plane = { path = "../control_plane" }
workspace_hack = { version = "0.1", path = "../workspace_hack" }
workspace_hack = { version = "0.1", path = "../workspace_hack" }

View File

@@ -1,7 +0,0 @@
/// Type of the storage node (pageserver or safekeeper) that we are updating DNS records for. Different types of nodes will have
/// different-looking DNS names in the DNS zone.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum NodeType {
Pageserver,
Safekeeper,
}

View File

@@ -1,433 +0,0 @@
#![allow(dead_code, unused)]
use std::collections::{HashMap, HashSet};
use diesel::Queryable;
use diesel::dsl::min;
use diesel::prelude::*;
use diesel_async::AsyncConnection;
use diesel_async::AsyncPgConnection;
use diesel_async::RunQueryDsl;
use itertools::Itertools;
use pageserver_api::controller_api::SCSafekeeperTimelinesResponse;
use scoped_futures::ScopedFutureExt;
use serde::{Deserialize, Serialize};
use utils::id::{NodeId, TenantId, TimelineId};
use uuid::Uuid;
use crate::hadron_dns::NodeType;
use crate::hadron_requests::NodeConnectionInfo;
use crate::persistence::{DatabaseError, DatabaseResult};
use crate::schema::{hadron_safekeepers, nodes};
use crate::sk_node::SafeKeeperNode;
use std::str::FromStr;
// The Safe Keeper node database representation (for Diesel).
#[derive(
Clone, Serialize, Deserialize, Queryable, Selectable, Insertable, Eq, PartialEq, AsChangeset,
)]
#[diesel(table_name = crate::schema::hadron_safekeepers)]
pub(crate) struct HadronSafekeeperRow {
pub(crate) sk_node_id: i64,
pub(crate) listen_http_addr: String,
pub(crate) listen_http_port: i32,
pub(crate) listen_pg_addr: String,
pub(crate) listen_pg_port: i32,
}
#[derive(
Clone, Serialize, Deserialize, Queryable, Selectable, Insertable, Eq, PartialEq, AsChangeset,
)]
#[diesel(table_name = crate::schema::hadron_timeline_safekeepers)]
pub(crate) struct HadronTimelineSafekeeper {
pub(crate) timeline_id: String,
pub(crate) sk_node_id: i64,
pub(crate) legacy_endpoint_id: Option<Uuid>,
}
pub async fn execute_sk_upsert(
conn: &mut AsyncPgConnection,
sk_row: HadronSafekeeperRow,
) -> DatabaseResult<()> {
// SQL:
// INSERT INTO hadron_safekeepers (sk_node_id, listen_http_addr, listen_http_port, listen_pg_addr, listen_pg_port)
// VALUES ($1, $2, $3, $4, $5)
// ON CONFLICT (sk_node_id)
// DO UPDATE SET listen_http_addr = $2, listen_http_port = $3, listen_pg_addr = $4, listen_pg_port = $5;
use crate::schema::hadron_safekeepers::dsl::*;
diesel::insert_into(hadron_safekeepers)
.values(&sk_row)
.on_conflict(sk_node_id)
.do_update()
.set(&sk_row)
.execute(conn)
.await?;
Ok(())
}
// Load all safekeeper nodes and their associated timelines from the meta PG. This query is supposed
// to run only once on HCC startup and is used to construct the SafeKeeperScheduler state. Performs
// scans of the hadron_safekeepers and hadron_timeline_safekeepers tables.
pub async fn scan_safekeepers_and_scheduled_timelines(
conn: &mut AsyncPgConnection,
) -> DatabaseResult<HashMap<NodeId, SafeKeeperNode>> {
use crate::schema::hadron_safekeepers;
use crate::schema::hadron_timeline_safekeepers;
// We first scan the hadron_safekeepers table to constuct the SafeKeeperNode objects. We don't know anything about
// the timelines scheduled to the safekeepers after this step. We then scan the hadron_timeline_safekeepers table
// to populate the data structures in the SafeKeeperNode objects to reflect the timelines scheduled to the safekeepers.
let mut results: HashMap<NodeId, SafeKeeperNode> = hadron_safekeepers::table
.select((
hadron_safekeepers::sk_node_id,
hadron_safekeepers::listen_http_addr,
hadron_safekeepers::listen_http_port,
hadron_safekeepers::listen_pg_addr,
hadron_safekeepers::listen_pg_port,
))
.load::<HadronSafekeeperRow>(conn)
.await?
.into_iter()
.map(|row| {
let sk_node = SafeKeeperNode {
id: NodeId(row.sk_node_id as u64),
listen_http_addr: row.listen_http_addr.clone(),
listen_http_port: row.listen_http_port as u16,
listen_pg_addr: row.listen_pg_addr.clone(),
listen_pg_port: row.listen_pg_port as u16,
legacy_endpoints: HashMap::new(),
timelines: HashSet::new(),
};
(sk_node.id, sk_node)
})
.collect();
let timeline_sk_rows = hadron_timeline_safekeepers::table
.select((
hadron_timeline_safekeepers::sk_node_id,
hadron_timeline_safekeepers::timeline_id,
hadron_timeline_safekeepers::legacy_endpoint_id,
))
.load::<(i64, String, Option<Uuid>)>(conn)
.await?;
for (sk_node_id, timeline_id, legacy_endpoint_id) in timeline_sk_rows {
if let Some(sk_node) = results.get_mut(&NodeId(sk_node_id as u64)) {
let parsed_timeline_id =
TimelineId::from_str(&timeline_id).map_err(|e: hex::FromHexError| {
DatabaseError::Logical(format!("Failed to parse timeline IDs: {e}"))
})?;
sk_node.timelines.insert(parsed_timeline_id);
if let Some(legacy_endpoint_id) = legacy_endpoint_id {
sk_node
.legacy_endpoints
.insert(legacy_endpoint_id, parsed_timeline_id);
}
}
}
Ok(results)
}
// Queries the hadron_timeline_safekeepers table to get the safekeepers assigned to the passed
// timeline. If none are found, persists the input proposed safekeepers to the table and returns
// them.
pub async fn idempotently_persist_or_get_existing_timeline_safekeepers(
conn: &mut AsyncPgConnection,
timeline_id: TimelineId,
safekeepers: &[NodeId],
) -> DatabaseResult<Vec<NodeId>> {
use crate::schema::hadron_timeline_safekeepers;
// Confirm and persist the timeline-safekeeper mapping. If there are existing safekeepers
// assigned to the timeline in the database, treat those as the source of truth.
let existing_safekeepers: Vec<i64> = hadron_timeline_safekeepers::table
.select(hadron_timeline_safekeepers::sk_node_id)
.filter(hadron_timeline_safekeepers::timeline_id.eq(timeline_id.to_string()))
.load::<i64>(conn)
.await?;
let confirmed_safekeepers: Vec<NodeId> = if existing_safekeepers.is_empty() {
let proposed_safekeeper_endpoint_rows_result: Result<Vec<HadronTimelineSafekeeper>, _> =
safekeepers
.iter()
.map(|sk_node_id| {
i64::try_from(sk_node_id.0).map(|sk_node_id| HadronTimelineSafekeeper {
timeline_id: timeline_id.to_string(),
sk_node_id,
legacy_endpoint_id: None,
})
})
.collect();
let proposed_safekeeper_endpoint_rows =
proposed_safekeeper_endpoint_rows_result.map_err(|e| {
DatabaseError::Logical(format!("Failed to convert safekeeper IDs: {e}"))
})?;
diesel::insert_into(hadron_timeline_safekeepers::table)
.values(&proposed_safekeeper_endpoint_rows)
.execute(conn)
.await?;
safekeepers.to_owned()
} else {
let safekeeper_result: Result<Vec<NodeId>, _> = existing_safekeepers
.into_iter()
.map(|arg0: i64| u64::try_from(arg0).map(NodeId))
.collect();
safekeeper_result
.map_err(|e| DatabaseError::Logical(format!("Failed to convert safekeeper IDs: {e}")))?
};
Ok(confirmed_safekeepers)
}
pub async fn delete_timeline_safekeepers(
conn: &mut AsyncPgConnection,
timeline_id: TimelineId,
) -> DatabaseResult<()> {
use crate::schema::hadron_timeline_safekeepers;
diesel::delete(hadron_timeline_safekeepers::table)
.filter(hadron_timeline_safekeepers::timeline_id.eq(timeline_id.to_string()))
.execute(conn)
.await?;
Ok(())
}
pub(crate) async fn execute_safekeeper_list_timelines(
conn: &mut AsyncPgConnection,
safekeeper_id: i64,
) -> DatabaseResult<SCSafekeeperTimelinesResponse> {
use crate::schema::hadron_timeline_safekeepers;
use pageserver_api::controller_api::SCSafekeeperTimelinesResponse;
conn.transaction(|conn| {
async move {
let mut sk_timelines = SCSafekeeperTimelinesResponse {
timelines: Vec::new(),
safekeeper_peers: Vec::new(),
};
// Find all timelines <String>
let timeline_ids = hadron_timeline_safekeepers::table
.select(hadron_timeline_safekeepers::timeline_id)
.filter(hadron_timeline_safekeepers::sk_node_id.eq(safekeeper_id))
.load::<String>(conn)
.await
.into_iter()
.flatten()
.collect_vec();
// Find the peers for each timeline. <timeline_id, sk_node_id>
let timeline_peers = hadron_timeline_safekeepers::table
.select((
hadron_timeline_safekeepers::timeline_id,
hadron_timeline_safekeepers::sk_node_id,
))
.filter(hadron_timeline_safekeepers::timeline_id.eq_any(&timeline_ids))
.load::<(String, i64)>(conn)
.await
.into_iter()
.flatten()
.collect_vec();
let mut timeline_peers_map = HashMap::new();
let mut seen = HashSet::new();
let mut unique_sks = Vec::new();
for (timeline_id, sk_node_id) in timeline_peers {
timeline_peers_map
.entry(timeline_id)
.or_insert_with(Vec::new)
.push(sk_node_id);
if seen.insert(sk_node_id) {
unique_sks.push(sk_node_id);
}
}
// Find SK info.
let mut found_sk_nodes = HashSet::new();
hadron_safekeepers::table
.select((
hadron_safekeepers::sk_node_id,
hadron_safekeepers::listen_http_addr,
hadron_safekeepers::listen_http_port,
))
.filter(hadron_safekeepers::sk_node_id.eq_any(&unique_sks))
.load::<(i64, String, i32)>(conn)
.await
.into_iter()
.flatten()
.for_each(|(sk_node_id, listen_http_addr, http_port)| {
found_sk_nodes.insert(sk_node_id);
sk_timelines.safekeeper_peers.push(
pageserver_api::controller_api::TimelineSafekeeperPeer {
node_id: utils::id::NodeId(sk_node_id as u64),
listen_http_addr,
http_port,
},
);
});
// Prepare timeline response.
for timeline_id in timeline_ids {
if !timeline_peers_map.contains_key(&timeline_id) {
continue;
}
let peers = timeline_peers_map.get(&timeline_id).unwrap();
// Check peers exist.
if !peers
.iter()
.all(|sk_node_id| found_sk_nodes.contains(sk_node_id))
{
continue;
}
let timeline = pageserver_api::controller_api::SCSafekeeperTimeline {
timeline_id: TimelineId::from_str(&timeline_id).unwrap(),
peers: peers
.iter()
.map(|sk_node_id| utils::id::NodeId(*sk_node_id as u64))
.collect(),
};
sk_timelines.timelines.push(timeline);
}
Ok(sk_timelines)
}
.scope_boxed()
})
.await
}
/// Stores details about connecting to pageserver and safekeeper nodes for a given tenant and
/// timeline.
pub struct PageserverAndSafekeeperConnectionInfo {
pub pageserver_conn_info: Vec<NodeConnectionInfo>,
pub safekeeper_conn_info: Vec<NodeConnectionInfo>,
}
/// Retrieves the connection information for the pageserver and safekeepers associated with the
/// given tenant and timeline.
pub async fn get_pageserver_and_safekeeper_connection_info(
conn: &mut AsyncPgConnection,
tenant_id: TenantId,
timeline_id: TimelineId,
) -> DatabaseResult<PageserverAndSafekeeperConnectionInfo> {
conn.transaction(|conn| {
async move {
// Fetch details about pageserver, which is associated with the input tenant.
let pageserver_conn_info =
get_pageserver_connection_info(conn, &tenant_id.to_string()).await?;
// Fetch details about safekeepers, which are associated with the input timeline.
let safekeeper_conn_info =
get_safekeeper_connection_info(conn, &timeline_id.to_string()).await?;
Ok(PageserverAndSafekeeperConnectionInfo {
pageserver_conn_info,
safekeeper_conn_info,
})
}
.scope_boxed()
})
.await
}
async fn get_safekeeper_connection_info(
conn: &mut AsyncPgConnection,
timeline_id: &str,
) -> DatabaseResult<Vec<NodeConnectionInfo>> {
use crate::schema::hadron_safekeepers;
use crate::schema::hadron_timeline_safekeepers;
Ok(hadron_timeline_safekeepers::table
.inner_join(
hadron_safekeepers::table
.on(hadron_timeline_safekeepers::sk_node_id.eq(hadron_safekeepers::sk_node_id)),
)
.select((
hadron_safekeepers::sk_node_id,
hadron_safekeepers::listen_pg_addr,
hadron_safekeepers::listen_pg_port,
))
.filter(hadron_timeline_safekeepers::timeline_id.eq(timeline_id.to_string()))
.load::<(i64, String, i32)>(conn)
.await?
.into_iter()
.map(|(node_id, addr, port)| {
NodeConnectionInfo::new(
NodeType::Safekeeper,
NodeId(node_id as u64),
addr,
port as u16,
)
})
.collect())
}
async fn get_pageserver_connection_info(
conn: &mut AsyncPgConnection,
tenant_id: &str,
) -> DatabaseResult<Vec<NodeConnectionInfo>> {
use crate::schema::tenant_shards;
// When the tenant is being split, it'll contain both old shards and new shards. Until the tenant split is committed,
// we should always use the old shards.
// NOTE: we only support tenant split without tennat merge. Thus shard count could only increase.
let min_shard_count = match tenant_shards::table
.select(min(tenant_shards::shard_count))
.filter(tenant_shards::tenant_id.eq(tenant_id))
.first::<Option<i32>>(conn)
.await
.optional()?
{
Some(Some(count)) => count,
Some(None) => {
// Tenant doesn't exist. It's possible that it was deleted before we got the request.
return Ok(vec![]);
}
None => {
// This is never supposed to happen because `SELECT min()` should always return one row.
return Err(DatabaseError::Logical(format!(
"Unexpected empty query result for min(shard_count) query. Tenant ID {tenant_id}"
)));
}
};
let shards: Vec<NodeConnectionInfo> = nodes::table
.inner_join(
tenant_shards::table.on(nodes::node_id
.nullable()
.eq(tenant_shards::generation_pageserver)),
)
.select((nodes::node_id, nodes::listen_pg_addr, nodes::listen_pg_port))
.filter(tenant_shards::tenant_id.eq(&tenant_id.to_string()))
.order(tenant_shards::shard_number.asc())
.filter(tenant_shards::shard_count.eq(min_shard_count))
.load::<(i64, String, i32)>(conn)
.await?
.into_iter()
.map(|(node_id, addr, port)| {
NodeConnectionInfo::new(
NodeType::Pageserver,
NodeId(node_id as u64),
addr,
port as u16,
)
})
.collect();
if !shards.is_empty() && !shards.len().is_power_of_two() {
return Err(DatabaseError::Logical(format!(
"Tenant {} has unexpected shard count {} (not a power of 2)",
tenant_id,
shards.len()
)));
}
Ok(shards)
}

View File

@@ -1,34 +0,0 @@
use utils::id::NodeId;
use crate::hadron_dns::NodeType;
/// Internal representation of how a compute node should connect to a PS or SK node. HCC uses this struct to
/// construct connection strings that are passed to the compute node via the compute spec. This struct is never
/// serialized or sent over the wire.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord)]
pub(crate) struct NodeConnectionInfo {
// Type of the node.
node_type: NodeType,
// Node ID. Unique for each node type.
pub(crate) node_id: NodeId,
// The hostname reported by the node when it registers. This is the hostname we store in the meta PG, and is
// typically the k8s cluster DNS name of the node. Note that this may not be resolvable from compute nodes running
// on dblet. For this reason, this hostname is usually not communicated to the compute node. Instead, HCC computes
// a DNS name of the node in the Cloud DNS hosted zone based on `node_type` and `node_id` and advertise the DNS name
// to compute nodes. This hostname here is used as a fallback in tests or other scenarios where we do not have the
// Cloud DNS hosted zone available.
registration_hostname: String,
// The PG wire protocol port on the PS or SK node.
port: u16,
}
impl NodeConnectionInfo {
pub(crate) fn new(node_type: NodeType, node_id: NodeId, hostname: String, port: u16) -> Self {
NodeConnectionInfo {
node_type,
node_id,
registration_hostname: hostname,
port,
}
}
}

View File

@@ -6,9 +6,6 @@ extern crate hyper0 as hyper;
mod auth;
mod background_node_operations;
mod compute_hook;
pub mod hadron_dns;
mod hadron_queries;
pub mod hadron_requests;
pub mod hadron_utils;
mod heartbeater;
pub mod http;
@@ -26,7 +23,6 @@ mod safekeeper_client;
mod scheduler;
mod schema;
pub mod service;
mod sk_node;
mod tenant_shard;
mod timeline_import;

View File

@@ -16,7 +16,7 @@ use pageserver_api::config::PostHogConfig;
use reqwest::Certificate;
use storage_controller::http::make_router;
use storage_controller::metrics::preinitialize_metrics;
use storage_controller::persistence::{Persistence, PersistenceConfig};
use storage_controller::persistence::Persistence;
use storage_controller::service::chaos_injector::ChaosInjector;
use storage_controller::service::feature_flag::FeatureFlagService;
use storage_controller::service::{
@@ -229,15 +229,6 @@ struct Cli {
/// **Feature Flag** Whether the storage controller should act to rectify pageserver-reported local disk loss.
#[arg(long, default_value = "false")]
handle_ps_local_disk_loss: bool,
#[arg(long)]
db_max_connections: Option<u32>,
#[arg(long)]
db_idle_connection_timeout: Option<humantime::Duration>,
#[arg(long)]
db_max_connection_lifetime: Option<humantime::Duration>,
}
enum StrictMode {
@@ -347,7 +338,7 @@ fn main() -> anyhow::Result<()> {
tokio::runtime::Builder::new_current_thread()
// We use spawn_blocking for database operations, so require approximately
// as many blocking threads as we will open database connections.
.max_blocking_threads(PersistenceConfig::MAX_CONNECTIONS_DEFAULT as usize)
.max_blocking_threads(Persistence::MAX_CONNECTIONS as usize)
.enable_all()
.build()
.unwrap()
@@ -438,19 +429,6 @@ async fn async_main() -> anyhow::Result<()> {
None
};
let db_idle_connection_timeout: Option<Duration> = args
.db_idle_connection_timeout
.map(humantime::Duration::into);
let db_max_connection_lifetime: Option<Duration> = args
.db_max_connection_lifetime
.map(humantime::Duration::into);
let persistence_config = PersistenceConfig::new(
args.db_max_connections,
db_idle_connection_timeout,
db_max_connection_lifetime,
);
let config = Config {
pageserver_jwt_token: secrets.pageserver_jwt_token,
safekeeper_jwt_token: secrets.safekeeper_jwt_token,
@@ -504,13 +482,12 @@ async fn async_main() -> anyhow::Result<()> {
.map(humantime::Duration::into)
.unwrap_or(Duration::MAX),
handle_ps_local_disk_loss: args.handle_ps_local_disk_loss,
persistence_config,
};
// Validate that we can connect to the database
Persistence::await_connection(&secrets.database_url, args.db_connect_timeout.into()).await?;
let persistence = Arc::new(Persistence::new(secrets.database_url, persistence_config).await);
let persistence = Arc::new(Persistence::new(secrets.database_url).await);
let service = Service::spawn(config, persistence.clone()).await?;

View File

@@ -46,31 +46,11 @@ impl TenantShardDrain {
&self,
tenants: &BTreeMap<TenantShardId, TenantShard>,
scheduler: &Scheduler,
) -> TenantShardDrainAction {
let Some(tenant_shard) = tenants.get(&self.tenant_shard_id) else {
return TenantShardDrainAction::Skip;
};
) -> Option<NodeId> {
let tenant_shard = tenants.get(&self.tenant_shard_id)?;
if *tenant_shard.intent.get_attached() != Some(self.drained_node) {
// If the intent attached node is not the drained node, check the observed state
// of the shard on the drained node. If it is Attached*, it means the shard is
// beeing migrated from the drained node. The drain loop needs to wait for the
// reconciliation to complete for a smooth draining.
use pageserver_api::models::LocationConfigMode::*;
let attach_mode = tenant_shard
.observed
.locations
.get(&self.drained_node)
.and_then(|observed| observed.conf.as_ref().map(|conf| conf.mode));
return match (attach_mode, tenant_shard.intent.get_attached()) {
(Some(AttachedSingle | AttachedMulti | AttachedStale), Some(intent_node_id)) => {
TenantShardDrainAction::Reconcile(*intent_node_id)
}
_ => TenantShardDrainAction::Skip,
};
return None;
}
// Only tenants with a normal (Active) scheduling policy are proactively moved
@@ -83,19 +63,19 @@ impl TenantShardDrain {
}
ShardSchedulingPolicy::Pause | ShardSchedulingPolicy::Stop => {
// If we have been asked to avoid rescheduling this shard, then do not migrate it during a drain
return TenantShardDrainAction::Skip;
return None;
}
}
match tenant_shard.preferred_secondary(scheduler) {
Some(node) => TenantShardDrainAction::RescheduleToSecondary(node),
Some(node) => Some(node),
None => {
tracing::warn!(
tenant_id=%self.tenant_shard_id.tenant_id, shard_id=%self.tenant_shard_id.shard_slug(),
"No eligible secondary while draining {}", self.drained_node
);
TenantShardDrainAction::Skip
None
}
}
}
@@ -158,17 +138,3 @@ impl TenantShardDrain {
}
}
}
/// Action to take when draining a tenant shard.
pub(crate) enum TenantShardDrainAction {
/// The tenant shard is on the draining node.
/// Reschedule the tenant shard to a secondary location.
/// Holds a destination node id to reschedule to.
RescheduleToSecondary(NodeId),
/// The tenant shard is beeing migrated from the draining node.
/// Wait for the reconciliation to complete.
/// Holds the intent attached node id.
Reconcile(NodeId),
/// The tenant shard is not eligible for drainining, skip it.
Skip,
}

View File

@@ -20,8 +20,7 @@ use futures::future::BoxFuture;
use itertools::Itertools;
use pageserver_api::controller_api::{
AvailabilityZone, MetadataHealthRecord, NodeLifecycle, NodeSchedulingPolicy, PlacementPolicy,
SCSafekeeperTimelinesResponse, SafekeeperDescribeResponse, ShardSchedulingPolicy,
SkSchedulingPolicy,
SafekeeperDescribeResponse, ShardSchedulingPolicy, SkSchedulingPolicy,
};
use pageserver_api::models::{ShardImportStatus, TenantConfig};
use pageserver_api::shard::{
@@ -38,19 +37,10 @@ use utils::id::{NodeId, TenantId, TimelineId};
use utils::lsn::Lsn;
use self::split_state::SplitState;
use crate::hadron_queries::HadronSafekeeperRow;
use crate::hadron_queries::PageserverAndSafekeeperConnectionInfo;
use crate::hadron_queries::delete_timeline_safekeepers;
use crate::hadron_queries::execute_safekeeper_list_timelines;
use crate::hadron_queries::execute_sk_upsert;
use crate::hadron_queries::get_pageserver_and_safekeeper_connection_info;
use crate::hadron_queries::idempotently_persist_or_get_existing_timeline_safekeepers;
use crate::hadron_queries::scan_safekeepers_and_scheduled_timelines;
use crate::metrics::{
DatabaseQueryErrorLabelGroup, DatabaseQueryLatencyLabelGroup, METRICS_REGISTRY,
};
use crate::node::Node;
use crate::sk_node::SafeKeeperNode;
use crate::timeline_import::{
TimelineImport, TimelineImportUpdateError, TimelineImportUpdateFollowUp,
};
@@ -87,38 +77,6 @@ pub struct Persistence {
connection_pool: Pool<AsyncPgConnection>,
}
#[derive(Copy, Clone)]
pub struct PersistenceConfig {
max_connections: u32,
idle_connection_timeout: Duration,
max_connection_lifetime: Duration,
}
impl PersistenceConfig {
// If unspecified, use neon.com defaults
//
// The default postgres connection limit is 100. We use up to 99, to leave one free for a human admin under
// normal circumstances. This assumes we have exclusive use of the database cluster to which we connect.
pub const MAX_CONNECTIONS_DEFAULT: u32 = 99;
// We don't want to keep a lot of connections alive: close them down promptly if they aren't being used.
pub const IDLE_CONNECTION_TIMEOUT_DEFAULT: Duration = Duration::from_secs(10);
pub const MAX_CONNECTION_LIFETIME_DEFAULT: Duration = Duration::from_secs(60);
pub fn new(
max_connections: Option<u32>,
idle_connection_timeout: Option<Duration>,
max_connection_lifetime: Option<Duration>,
) -> Self {
PersistenceConfig {
max_connections: max_connections.unwrap_or(Self::MAX_CONNECTIONS_DEFAULT),
idle_connection_timeout: idle_connection_timeout
.unwrap_or(Self::IDLE_CONNECTION_TIMEOUT_DEFAULT),
max_connection_lifetime: max_connection_lifetime
.unwrap_or(Self::MAX_CONNECTION_LIFETIME_DEFAULT),
}
}
}
/// Legacy format, for use in JSON compat objects in test environment
#[derive(Serialize, Deserialize)]
struct JsonPersistence {
@@ -185,20 +143,6 @@ pub(crate) enum DatabaseOperation {
DeleteTimelineImport,
ListTimelineImports,
IsTenantImportingTimeline,
// Brickstore Hadron
UpsertSafeKeeperNode,
LoadSafeKeepersAndEndpoints,
EnsureHadronEndpointTransaction,
DeleteHadronEndpoint,
GetHadronEndpointInfo,
FetchComputeSpec,
GetTenandIdByEndpointId,
GetTenantShardsByEndpointId,
GetComputeNamesByTenantId,
GetOrCreateHadronTimelineSafekeeper,
FetchPageServerAndSafeKeeperConnections,
DeleteHadronTimeline,
ListSafekeeperTimelines,
}
#[must_use]
@@ -235,7 +179,11 @@ impl Persistence {
// normal circumstances. This assumes we have exclusive use of the database cluster to which we connect.
pub const MAX_CONNECTIONS: u32 = 99;
pub async fn new(database_url: String, config: PersistenceConfig) -> Self {
// We don't want to keep a lot of connections alive: close them down promptly if they aren't being used.
const IDLE_CONNECTION_TIMEOUT: Duration = Duration::from_secs(10);
const MAX_CONNECTION_LIFETIME: Duration = Duration::from_secs(60);
pub async fn new(database_url: String) -> Self {
let mut mgr_config = ManagerConfig::default();
mgr_config.custom_setup = Box::new(establish_connection_rustls);
@@ -247,9 +195,9 @@ impl Persistence {
// We will use a connection pool: this is primarily to _limit_ our connection count, rather than to optimize time
// to execute queries (database queries are not generally on latency-sensitive paths).
let connection_pool = Pool::builder()
.max_size(config.max_connections)
.max_lifetime(Some(config.max_connection_lifetime))
.idle_timeout(Some(config.idle_connection_timeout))
.max_size(Self::MAX_CONNECTIONS)
.max_lifetime(Some(Self::MAX_CONNECTION_LIFETIME))
.idle_timeout(Some(Self::IDLE_CONNECTION_TIMEOUT))
// Always keep at least one connection ready to go
.min_idle(Some(1))
.test_on_check_out(true)
@@ -2187,134 +2135,6 @@ impl Persistence {
})
.await
}
////////////////////////////////////////////////////////////////
//////////////////////// Hadron methods ////////////////////////
//////////////////////// (Brickstore) //////////////////////////
////////////////////////////////////////////////////////////////
/// Upsert a SafeKeeper node.
#[allow(unused)]
pub(crate) async fn upsert_sk_node(&self, sk_node: &SafeKeeperNode) -> DatabaseResult<()> {
let sk_row = sk_node.to_database_row();
self.with_measured_conn(DatabaseOperation::UpsertSafeKeeperNode, move |conn| {
// Incantation to make the borrow checker happy
let sk_row_clone = sk_row.clone();
Box::pin(async move { execute_sk_upsert(conn, sk_row_clone).await })
})
.await
}
/// Load all Safe Keeper nodes and their scheduled endpoints from the database. This method is called at startup to
/// populate the SafeKeeperScheduler.
#[allow(unused)]
pub(crate) async fn load_safekeeper_scheduling_data(
&self,
) -> DatabaseResult<HashMap<NodeId, SafeKeeperNode>> {
let sk_nodes: HashMap<NodeId, SafeKeeperNode> = self
.with_measured_conn(
DatabaseOperation::LoadSafeKeepersAndEndpoints,
move |conn| {
// Retrieve all Safe Keeper nodes from the hadron_safekeepers table, and all timelines (grouped by
// safe keeper IDs) from the hadron_timeline_safekeepers table.
Box::pin(async move { scan_safekeepers_and_scheduled_timelines(conn).await })
},
)
.await?;
tracing::info!(
"load_safekeepers_and_endpoints: loaded {} safekeepers",
sk_nodes.len()
);
Ok(sk_nodes)
}
#[allow(unused)]
pub(crate) async fn get_or_assign_safekeepers_to_timeline(
&self,
timeline_id: TimelineId,
safekeepers: Vec<NodeId>,
) -> DatabaseResult<Vec<NodeId>> {
self.with_measured_conn(
DatabaseOperation::GetOrCreateHadronTimelineSafekeeper,
move |conn| {
let safekeepers_clone = safekeepers.clone();
Box::pin(async move {
idempotently_persist_or_get_existing_timeline_safekeepers(
conn,
timeline_id,
&safekeepers_clone,
)
.await
})
},
)
.await
}
#[allow(unused)]
pub(crate) async fn delete_hadron_timeline_safekeepers(
&self,
timeline_id: TimelineId,
) -> DatabaseResult<()> {
self.with_measured_conn(DatabaseOperation::DeleteHadronTimeline, move |conn| {
Box::pin(async move {
delete_timeline_safekeepers(conn, timeline_id).await?;
Ok(())
})
})
.await
}
#[allow(unused)]
pub(crate) async fn get_pageserver_and_safekeepers(
&self,
tenant_id: TenantId,
timeline_id: TimelineId,
) -> DatabaseResult<PageserverAndSafekeeperConnectionInfo> {
self.with_measured_conn(
DatabaseOperation::FetchPageServerAndSafeKeeperConnections,
move |conn| {
Box::pin(async move {
get_pageserver_and_safekeeper_connection_info(conn, tenant_id, timeline_id)
.await
})
},
)
.await
}
#[allow(unused)]
pub(crate) async fn list_hadron_safekeepers(&self) -> DatabaseResult<Vec<HadronSafekeeperRow>> {
let safekeepers: Vec<HadronSafekeeperRow> = self
.with_measured_conn(DatabaseOperation::ListNodes, move |conn| {
Box::pin(async move {
Ok(crate::schema::hadron_safekeepers::table
.load::<HadronSafekeeperRow>(conn)
.await?)
})
})
.await?;
tracing::info!(
"list_hadron_safekeepers: loaded {} nodes",
safekeepers.len()
);
Ok(safekeepers)
}
#[allow(unused)]
pub(crate) async fn safekeeper_list_timelines(
&self,
id: i64,
) -> DatabaseResult<SCSafekeeperTimelinesResponse> {
self.with_measured_conn(DatabaseOperation::ListSafekeeperTimelines, move |conn| {
Box::pin(async move { execute_safekeeper_list_timelines(conn, id).await })
})
.await
}
}
pub(crate) fn load_certs() -> anyhow::Result<Arc<rustls::RootCertStore>> {
@@ -2403,45 +2223,15 @@ fn client_config_with_root_certs() -> anyhow::Result<rustls::ClientConfig> {
})
}
// Hadron's implementation of establish_connection_rustls which avoids hogging the tokio executor thread during
// CPU-intensive operations in postgres connection and session establishments.
// Compared to the original implementation this function performs the following tasks using spawn_blocking to avoid
// hogging the tokio executor thread:
// 1. Parsing and decoding root certificates during rustls client config setup.
// 2. The tokio_postgres::connect() call, which performs the TLS handshake and the postgres password authentication.
fn establish_connection_rustls(config: &str) -> BoxFuture<ConnectionResult<AsyncPgConnection>> {
let fut = async move {
let fut = async {
// We first set up the way we want rustls to work.
let rustls_config = tokio::task::spawn_blocking(client_config_with_root_certs)
.await
.map_err(|e| {
ConnectionError::BadConnection(format!(
"Error in spawn_blocking client_config_with_root_certs: {e}"
))
})
.and_then(|r| {
r.map_err(|e| {
ConnectionError::BadConnection(format!(
"Error in client_config_with_root_certs: {e}"
))
})
})?;
let rustls_config = client_config_with_root_certs()
.map_err(|err| ConnectionError::BadConnection(format!("{err:?}")))?;
let tls = tokio_postgres_rustls::MakeRustlsConnect::new(rustls_config);
// Perform the expensive TLS handshake and SCRAM SHA calculations in a blocking task
let task_owned_config = config.to_owned();
let (client, conn) = tokio::task::spawn_blocking(move || {
tokio::runtime::Handle::current()
.block_on(async { tokio_postgres::connect(&task_owned_config, tls).await })
})
.await
.map_err(|e| {
ConnectionError::BadConnection(format!(
"Error in spawn_blocking tokio_postgres::connect: {e}"
))
})
.and_then(|r| r.map_err(|e| ConnectionError::BadConnection(e.to_string())))?;
let (client, conn) = tokio_postgres::connect(config, tls)
.await
.map_err(|e| ConnectionError::BadConnection(e.to_string()))?;
AsyncPgConnection::try_from_client_and_connection(client, conn).await
};

View File

@@ -79,13 +79,13 @@ use crate::id_lock_map::{
use crate::leadership::Leadership;
use crate::metrics;
use crate::node::{AvailabilityTransition, Node};
use crate::operation_utils::{self, TenantShardDrain, TenantShardDrainAction};
use crate::operation_utils::{self, TenantShardDrain};
use crate::pageserver_client::PageserverClient;
use crate::peer_client::GlobalObservedState;
use crate::persistence::split_state::SplitState;
use crate::persistence::{
AbortShardSplitStatus, ControllerPersistence, DatabaseError, DatabaseResult,
MetadataHealthPersistence, Persistence, PersistenceConfig, ShardGenerationState, TenantFilter,
MetadataHealthPersistence, Persistence, ShardGenerationState, TenantFilter,
TenantShardPersistence,
};
use crate::reconciler::{
@@ -490,8 +490,6 @@ pub struct Config {
// Feature flag: Whether the storage controller should act to rectify pageserver-reported local disk loss.
pub handle_ps_local_disk_loss: bool,
pub persistence_config: PersistenceConfig,
}
impl From<DatabaseError> for ApiError {
@@ -1276,7 +1274,7 @@ impl Service {
// Always attempt autosplits. Sharding is crucial for bulk ingest performance, so we
// must be responsive when new projects begin ingesting and reach the threshold.
self.autosplit_tenants().await;
},
}
_ = self.reconcilers_cancel.cancelled() => return
}
}
@@ -8878,9 +8876,6 @@ impl Service {
for (_tenant_id, schedule_context, shards) in
TenantShardExclusiveIterator::new(tenants, ScheduleMode::Speculative)
{
if work.len() >= MAX_OPTIMIZATIONS_PLAN_PER_PASS {
break;
}
for shard in shards {
if work.len() >= MAX_OPTIMIZATIONS_PLAN_PER_PASS {
break;
@@ -9645,16 +9640,16 @@ impl Service {
tenant_shard_id: tid,
};
let drain_action = {
let dest_node_id = {
let locked = self.inner.read().unwrap();
tid_drain.tenant_shard_eligible_for_drain(&locked.tenants, &locked.scheduler)
};
let dest_node_id = match drain_action {
TenantShardDrainAction::RescheduleToSecondary(dest_node_id) => dest_node_id,
TenantShardDrainAction::Reconcile(intent_node_id) => intent_node_id,
TenantShardDrainAction::Skip => {
continue;
match tid_drain
.tenant_shard_eligible_for_drain(&locked.tenants, &locked.scheduler)
{
Some(node_id) => node_id,
None => {
continue;
}
}
};
@@ -9689,16 +9684,14 @@ impl Service {
{
let mut locked = self.inner.write().unwrap();
let (nodes, tenants, scheduler) = locked.parts_mut();
let rescheduled = tid_drain.reschedule_to_secondary(
dest_node_id,
tenants,
scheduler,
nodes,
)?;
let tenant_shard = match drain_action {
TenantShardDrainAction::RescheduleToSecondary(dest_node_id) => tid_drain
.reschedule_to_secondary(dest_node_id, tenants, scheduler, nodes)?,
TenantShardDrainAction::Reconcile(_) => tenants.get_mut(&tid),
// Note: Unreachable, handled above.
TenantShardDrainAction::Skip => None,
};
if let Some(tenant_shard) = tenant_shard {
if let Some(tenant_shard) = rescheduled {
let waiter = self.maybe_configured_reconcile_shard(
tenant_shard,
nodes,

View File

@@ -1,56 +0,0 @@
use serde::Serialize;
use std::collections::{HashMap, HashSet};
use utils::id::{NodeId, TimelineId};
use uuid::Uuid;
use crate::hadron_queries::HadronSafekeeperRow;
// In-memory representation of a Safe Keeper node.
#[derive(Clone, Serialize)]
pub(crate) struct SafeKeeperNode {
pub(crate) id: NodeId,
pub(crate) listen_http_addr: String,
pub(crate) listen_http_port: u16,
pub(crate) listen_pg_addr: String,
pub(crate) listen_pg_port: u16,
// All timelines scheduled to this SK node. Some of the timelines may be associated with
// a legacy "endpoint", a deprecated concept used in HCC compute CRUD APIs. The "endpoint"
// concept will be retired after Public Preview launch.
pub(crate) timelines: HashSet<TimelineId>,
// All legacy endpoints and their associated timelines scheduled to this SK node.
// Invariant: The timelines referenced in this map must be present in the `timelines` set above.
pub(crate) legacy_endpoints: HashMap<Uuid, TimelineId>,
}
impl SafeKeeperNode {
#[allow(unused)]
pub(crate) fn new(
id: NodeId,
listen_http_addr: String,
listen_http_port: u16,
listen_pg_addr: String,
listen_pg_port: u16,
) -> Self {
Self {
id,
listen_http_addr,
listen_http_port,
listen_pg_addr,
listen_pg_port,
legacy_endpoints: HashMap::new(),
timelines: HashSet::new(),
}
}
#[allow(unused)]
pub(crate) fn to_database_row(&self) -> HadronSafekeeperRow {
HadronSafekeeperRow {
sk_node_id: self.id.0 as i64,
listen_http_addr: self.listen_http_addr.clone(),
listen_http_port: self.listen_http_port as i32,
listen_pg_addr: self.listen_pg_addr.clone(),
listen_pg_port: self.listen_pg_port as i32,
}
}
}

View File

@@ -812,6 +812,8 @@ impl TenantShard {
/// if the swap is not possible and leaves the intent state in its original state.
///
/// Arguments:
/// `attached_to`: the currently attached location matching the intent state (may be None if the
/// shard is not attached)
/// `promote_to`: an optional secondary location of this tenant shard. If set to None, we ask
/// the scheduler to recommend a node
pub(crate) fn reschedule_to_secondary(

View File

@@ -55,7 +55,7 @@ def test_neon_extension_compatibility(neon_env_builder: NeonEnvBuilder):
# Ensure that the default version is also updated in the neon.control file
assert cur.fetchone() == ("1.6",)
cur.execute("SELECT * from neon.NEON_STAT_FILE_CACHE")
all_versions = ["1.6", "1.5", "1.4", "1.3", "1.2", "1.1", "1.0"]
all_versions = ["1.7", "1.6", "1.5", "1.4", "1.3", "1.2", "1.1", "1.0"]
current_version = "1.6"
for idx, begin_version in enumerate(all_versions):
for target_version in all_versions[idx + 1 :]:

View File

@@ -129,10 +129,7 @@ def test_readonly_node_gc(neon_env_builder: NeonEnvBuilder):
Test static endpoint is protected from GC by acquiring and renewing lsn leases.
"""
LSN_LEASE_LENGTH = (
14 # This value needs to be large enough for compute_ctl to send two lease requests.
)
LSN_LEASE_LENGTH = 8
neon_env_builder.num_pageservers = 2
# GC is manual triggered.
env = neon_env_builder.init_start(
@@ -233,15 +230,6 @@ def test_readonly_node_gc(neon_env_builder: NeonEnvBuilder):
log.info(f"`SELECT` query succeed after GC, {ctx=}")
return offset
# It's not reliable to let the compute renew the lease in this test case as we have a very tight
# lease timeout. Therefore, the test case itself will renew the lease.
#
# This is a workaround to make the test case more deterministic.
def renew_lease(env: NeonEnv, lease_lsn: Lsn):
env.storage_controller.pageserver_api().timeline_lsn_lease(
env.initial_tenant, env.initial_timeline, lease_lsn
)
# Insert some records on main branch
with env.endpoints.create_start("main", config_lines=["shared_buffers=1MB"]) as ep_main:
with ep_main.cursor() as cur:
@@ -254,9 +242,6 @@ def test_readonly_node_gc(neon_env_builder: NeonEnvBuilder):
XLOG_BLCKSZ = 8192
lsn = Lsn((int(lsn) // XLOG_BLCKSZ) * XLOG_BLCKSZ)
# We need to mock the way cplane works: it gets a lease for a branch before starting the compute.
renew_lease(env, lsn)
with env.endpoints.create_start(
branch_name="main",
endpoint_id="static",
@@ -266,6 +251,9 @@ def test_readonly_node_gc(neon_env_builder: NeonEnvBuilder):
cur.execute("SELECT count(*) FROM t0")
assert cur.fetchone() == (ROW_COUNT,)
# Wait for static compute to renew lease at least once.
time.sleep(LSN_LEASE_LENGTH / 2)
generate_updates_on_main(env, ep_main, 3, end=100)
offset = trigger_gc_and_select(
@@ -275,10 +263,10 @@ def test_readonly_node_gc(neon_env_builder: NeonEnvBuilder):
# Trigger Pageserver restarts
for ps in env.pageservers:
ps.stop()
# Static compute should have at least one lease request failure due to connection.
time.sleep(LSN_LEASE_LENGTH / 2)
ps.start()
renew_lease(env, lsn)
trigger_gc_and_select(
env,
ep_static,
@@ -294,9 +282,6 @@ def test_readonly_node_gc(neon_env_builder: NeonEnvBuilder):
)
env.storage_controller.reconcile_until_idle()
# Wait for static compute to renew lease on the new pageserver.
time.sleep(LSN_LEASE_LENGTH + 3)
trigger_gc_and_select(
env,
ep_static,
@@ -307,6 +292,7 @@ def test_readonly_node_gc(neon_env_builder: NeonEnvBuilder):
# Do some update so we can increment gc_cutoff
generate_updates_on_main(env, ep_main, i, end=100)
# Wait for the existing lease to expire.
time.sleep(LSN_LEASE_LENGTH + 1)
# Now trigger GC again, layers should be removed.

View File

@@ -1,6 +1,5 @@
from __future__ import annotations
import os
from concurrent.futures import ThreadPoolExecutor
from typing import TYPE_CHECKING
@@ -769,14 +768,6 @@ def test_lsn_lease_storcon(neon_env_builder: NeonEnvBuilder):
"compaction_period": "0s",
}
env = neon_env_builder.init_start(initial_tenant_conf=conf)
# ShardSplit is slow in debug builds, so ignore the warning
if os.getenv("BUILD_TYPE", "debug") == "debug":
env.storage_controller.allowed_errors.extend(
[
".*Exclusive lock by ShardSplit was held.*",
]
)
with env.endpoints.create_start(
"main",
) as ep: