Compare commits

..

5 Commits

Author SHA1 Message Date
Heikki Linnakangas
16d6898e44 git add missing file 2025-06-12 02:37:59 +03:00
Heikki Linnakangas
10b936bf03 Use a custom Rust implementation to replace the LFC hash table
The new implementation lives in a separately allocated shared memory
area, which could be resized. Resizing it isn't actually implemented
yet, though. It would require some co-operation from the LFC code.
2025-06-05 18:31:29 +03:00
Heikki Linnakangas
6145cfd1c2 Move neon-shmem facility to separate module within the crate 2025-06-05 18:13:03 +03:00
Heikki Linnakangas
96b4de1de6 Make LFC chunk size a compile-time constant
A runtime setting is nicer, but the next commit will replace the hash
table with a different implementation that requires the value size to
be a compile-time constant.
2025-06-05 18:08:40 +03:00
Heikki Linnakangas
9fdf5fbb7e Use a separate freelist to track LFC "holes"
When the LFC is shrunk, we punch holes in the underlying file to
release the disk space to the OS. We tracked it in the same hash table
as the in-use entries, because that was convenient. However, I'm
working on being able to shrink the hash table too, and once we do
that, we'll need some other place to track the holes. Implement a
simple scheme of an in-memory array and a chain of on-disk blocks for
that.
2025-06-05 18:08:35 +03:00
94 changed files with 3021 additions and 2359 deletions

132
Cargo.lock generated
View File

@@ -1086,6 +1086,25 @@ version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5"
[[package]]
name = "cbindgen"
version = "0.28.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "eadd868a2ce9ca38de7eeafdcec9c7065ef89b42b32f0839278d55f35c54d1ff"
dependencies = [
"clap",
"heck 0.4.1",
"indexmap 2.9.0",
"log",
"proc-macro2",
"quote",
"serde",
"serde_json",
"syn 2.0.100",
"tempfile",
"toml",
]
[[package]]
name = "cc"
version = "1.2.16"
@@ -1212,7 +1231,7 @@ version = "4.5.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4ac6a0c7b1a9e9a5186361f67dfa1b88213572f427fb9ab038efb2bd8c582dab"
dependencies = [
"heck",
"heck 0.5.0",
"proc-macro2",
"quote",
"syn 2.0.100",
@@ -1270,6 +1289,14 @@ dependencies = [
"unicode-width",
]
[[package]]
name = "communicator"
version = "0.1.0"
dependencies = [
"cbindgen",
"neon-shmem",
]
[[package]]
name = "compute_api"
version = "0.1.0"
@@ -1278,13 +1305,10 @@ dependencies = [
"chrono",
"indexmap 2.9.0",
"jsonwebtoken",
"postgres",
"regex",
"remote_storage",
"serde",
"serde_json",
"tokio-postgres",
"url",
"utils",
]
@@ -1448,7 +1472,6 @@ dependencies = [
"regex",
"reqwest",
"safekeeper_api",
"safekeeper_client",
"scopeguard",
"serde",
"serde_json",
@@ -1940,7 +1963,7 @@ checksum = "0892a17df262a24294c382f0d5997571006e7a4348b4327557c4ff1cd4a8bccc"
dependencies = [
"darling",
"either",
"heck",
"heck 0.5.0",
"proc-macro2",
"quote",
"syn 2.0.100",
@@ -2058,7 +2081,6 @@ dependencies = [
"axum-extra",
"camino",
"camino-tempfile",
"clap",
"futures",
"http-body-util",
"itertools 0.10.5",
@@ -2505,6 +2527,18 @@ dependencies = [
"wasm-bindgen",
]
[[package]]
name = "getrandom"
version = "0.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "26145e563e54f2cadc477553f1ec5ee650b00862f0a58bcd12cbdc5f0ea2d2f4"
dependencies = [
"cfg-if",
"libc",
"r-efi",
"wasi 0.14.2+wasi-0.2.4",
]
[[package]]
name = "gettid"
version = "0.1.3"
@@ -2717,6 +2751,12 @@ dependencies = [
"http 1.1.0",
]
[[package]]
name = "heck"
version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8"
[[package]]
name = "heck"
version = "0.5.0"
@@ -3653,7 +3693,7 @@ version = "0.0.22"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b9e6777fc80a575f9503d908c8b498782a6c3ee88a06cb416dc3941401e43b94"
dependencies = [
"heck",
"heck 0.5.0",
"proc-macro2",
"quote",
"syn 2.0.100",
@@ -3715,7 +3755,7 @@ dependencies = [
"procfs",
"prometheus",
"rand 0.8.5",
"rand_distr",
"rand_distr 0.4.3",
"twox-hash",
]
@@ -3804,6 +3844,8 @@ name = "neon-shmem"
version = "0.1.0"
dependencies = [
"nix 0.30.1",
"rand 0.9.1",
"rand_distr 0.5.1",
"tempfile",
"thiserror 1.0.69",
"workspace_hack",
@@ -5097,7 +5139,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "22505a5c94da8e3b7c2996394d1c933236c4d743e81a410bcca4e6989fc066a4"
dependencies = [
"bytes",
"heck",
"heck 0.5.0",
"itertools 0.12.1",
"log",
"multimap",
@@ -5118,7 +5160,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0c1318b19085f08681016926435853bbf7858f9c082d0999b80550ff5d9abe15"
dependencies = [
"bytes",
"heck",
"heck 0.5.0",
"itertools 0.12.1",
"log",
"multimap",
@@ -5243,7 +5285,7 @@ dependencies = [
"postgres_backend",
"pq_proto",
"rand 0.8.5",
"rand_distr",
"rand_distr 0.4.3",
"rcgen",
"redis",
"regex",
@@ -5347,6 +5389,12 @@ dependencies = [
"proc-macro2",
]
[[package]]
name = "r-efi"
version = "5.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "74765f6d916ee2faa39bc8e68e4f3ed8949b48cccdac59983d287a7cb71ce9c5"
[[package]]
name = "rand"
version = "0.7.3"
@@ -5371,6 +5419,16 @@ dependencies = [
"rand_core 0.6.4",
]
[[package]]
name = "rand"
version = "0.9.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9fbfd9d094a40bf3ae768db9361049ace4c0e04a4fd6b359518bd7b73a73dd97"
dependencies = [
"rand_chacha 0.9.0",
"rand_core 0.9.3",
]
[[package]]
name = "rand_chacha"
version = "0.2.2"
@@ -5391,6 +5449,16 @@ dependencies = [
"rand_core 0.6.4",
]
[[package]]
name = "rand_chacha"
version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb"
dependencies = [
"ppv-lite86",
"rand_core 0.9.3",
]
[[package]]
name = "rand_core"
version = "0.5.1"
@@ -5409,6 +5477,15 @@ dependencies = [
"getrandom 0.2.11",
]
[[package]]
name = "rand_core"
version = "0.9.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "99d9a13982dcf210057a8a78572b2217b667c3beacbf3a0d8b454f6f82837d38"
dependencies = [
"getrandom 0.3.3",
]
[[package]]
name = "rand_distr"
version = "0.4.3"
@@ -5419,6 +5496,16 @@ dependencies = [
"rand 0.8.5",
]
[[package]]
name = "rand_distr"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6a8615d50dcf34fa31f7ab52692afec947c4dd0ab803cc87cb3b0b4570ff7463"
dependencies = [
"num-traits",
"rand 0.9.1",
]
[[package]]
name = "rand_hc"
version = "0.2.0"
@@ -6750,7 +6837,6 @@ dependencies = [
"chrono",
"clap",
"clashmap",
"compute_api",
"control_plane",
"cron",
"diesel",
@@ -6906,7 +6992,7 @@ version = "0.26.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4c6bee85a5a24955dc440386795aa378cd9cf82acd5f764469152d2270e581be"
dependencies = [
"heck",
"heck 0.5.0",
"proc-macro2",
"quote",
"rustversion",
@@ -8205,6 +8291,15 @@ version = "0.11.0+wasi-snapshot-preview1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423"
[[package]]
name = "wasi"
version = "0.14.2+wasi-0.2.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9683f9a5a998d873c0d21fcbe3c083009670149a8fab228644b8bd36b2c48cb3"
dependencies = [
"wit-bindgen-rt",
]
[[package]]
name = "wasite"
version = "0.1.0"
@@ -8562,6 +8657,15 @@ dependencies = [
"windows-sys 0.48.0",
]
[[package]]
name = "wit-bindgen-rt"
version = "0.39.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6f42320e61fe2cfd34354ecb597f86f413484a798ba44a8ca1165c58d42da6c1"
dependencies = [
"bitflags 2.8.0",
]
[[package]]
name = "workspace_hack"
version = "0.1.0"

View File

@@ -44,6 +44,7 @@ members = [
"libs/proxy/postgres-types2",
"libs/proxy/tokio-postgres2",
"endpoint_storage",
"pgxn/neon/communicator",
]
[workspace.package]
@@ -251,6 +252,7 @@ desim = { version = "0.1", path = "./libs/desim" }
endpoint_storage = { version = "0.0.1", path = "./endpoint_storage/" }
http-utils = { version = "0.1", path = "./libs/http-utils/" }
metrics = { version = "0.1", path = "./libs/metrics/" }
neon-shmem = { version = "0.1", path = "./libs/neon-shmem/" }
pageserver = { path = "./pageserver" }
pageserver_api = { version = "0.1", path = "./libs/pageserver_api/" }
pageserver_client = { path = "./pageserver/client" }
@@ -278,6 +280,7 @@ walproposer = { version = "0.1", path = "./libs/walproposer/" }
workspace_hack = { version = "0.1", path = "./workspace_hack/" }
## Build dependencies
cbindgen = "0.28.0"
criterion = "0.5.1"
rcgen = "0.13"
rstest = "0.18"

View File

@@ -110,19 +110,6 @@ RUN set -e \
# System postgres for use with client libraries (e.g. in storage controller)
postgresql-15 \
openssl \
unzip \
curl \
&& ARCH=$(uname -m) \
&& if [ "$ARCH" = "x86_64" ]; then \
curl "https://awscli.amazonaws.com/awscli-exe-linux-x86_64.zip" -o "awscliv2.zip"; \
elif [ "$ARCH" = "aarch64" ]; then \
curl "https://awscli.amazonaws.com/awscli-exe-linux-aarch64.zip" -o "awscliv2.zip"; \
else \
echo "Unsupported architecture: $ARCH" && exit 1; \
fi \
&& unzip awscliv2.zip \
&& ./aws/install \
&& rm -rf aws awscliv2.zip \
&& rm -f /etc/apt/apt.conf.d/80-retries \
&& rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/* \
&& useradd -d /data neon \

View File

@@ -18,10 +18,12 @@ ifeq ($(BUILD_TYPE),release)
PG_LDFLAGS = $(LDFLAGS)
# Unfortunately, `--profile=...` is a nightly feature
CARGO_BUILD_FLAGS += --release
NEON_CARGO_ARTIFACT_TARGET_DIR = $(ROOT_PROJECT_DIR)/target/release
else ifeq ($(BUILD_TYPE),debug)
PG_CONFIGURE_OPTS = --enable-debug --with-openssl --enable-cassert --enable-depend
PG_CFLAGS += -O0 -g3 $(CFLAGS)
PG_LDFLAGS = $(LDFLAGS)
NEON_CARGO_ARTIFACT_TARGET_DIR = $(ROOT_PROJECT_DIR)/target/debug
else
$(error Bad build type '$(BUILD_TYPE)', see Makefile for options)
endif
@@ -180,11 +182,16 @@ postgres-check-%: postgres-%
.PHONY: neon-pg-ext-%
neon-pg-ext-%: postgres-%
+@echo "Compiling communicator $*"
$(CARGO_CMD_PREFIX) cargo build -p communicator $(CARGO_BUILD_FLAGS)
+@echo "Compiling neon $*"
mkdir -p $(POSTGRES_INSTALL_DIR)/build/neon-$*
$(MAKE) PG_CONFIG=$(POSTGRES_INSTALL_DIR)/$*/bin/pg_config COPT='$(COPT)' \
LIBCOMMUNICATOR_PATH=$(NEON_CARGO_ARTIFACT_TARGET_DIR) \
-C $(POSTGRES_INSTALL_DIR)/build/neon-$* \
-f $(ROOT_PROJECT_DIR)/pgxn/neon/Makefile install
+@echo "Compiling neon_walredo $*"
mkdir -p $(POSTGRES_INSTALL_DIR)/build/neon-walredo-$*
$(MAKE) PG_CONFIG=$(POSTGRES_INSTALL_DIR)/$*/bin/pg_config COPT='$(COPT)' \

View File

@@ -145,47 +145,28 @@ impl Cli {
}
}
impl Cli {
pub fn get_config(&self) -> Result<ComputeConfig> {
// First, read the config from the path if provided
if let Some(ref config) = self.config {
let file = File::open(config)?;
return Ok(serde_json::from_reader(&file)?);
}
// If the config wasn't provided in the CLI arguments, then retrieve it from
// the control plane
match get_config_from_control_plane(
self.control_plane_uri.as_ref().unwrap(),
&self.compute_id,
) {
Ok(config) => Ok(config),
Err(e) => {
error!(
"cannot get response from control plane: {}\n\
neither spec nor confirmation that compute is in the Empty state was received",
e
);
Err(e)
}
}
}
}
#[tokio::main]
async fn main() -> Result<()> {
fn main() -> Result<()> {
let cli = Cli::parse();
let scenario = failpoint_support::init();
init().await?;
// For historical reasons, the main thread that processes the config and launches postgres
// is synchronous, but we always have this tokio runtime available and we "enter" it so
// that you can use tokio::spawn() and tokio::runtime::Handle::current().block_on(...)
// from all parts of compute_ctl.
let runtime = tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()?;
let _rt_guard = runtime.enter();
runtime.block_on(init())?;
// enable core dumping for all child processes
setrlimit(Resource::CORE, rlimit::INFINITY, rlimit::INFINITY)?;
let connstr = Url::parse(&cli.connstr).context("cannot parse connstr as a URL")?;
let config = cli.get_config()?;
let config = get_config(&cli)?;
let compute_node = ComputeNode::new(
ComputeNodeParams {
@@ -210,7 +191,7 @@ async fn main() -> Result<()> {
config,
)?;
let exit_code = compute_node.run().await?;
let exit_code = compute_node.run()?;
scenario.teardown();
@@ -232,6 +213,28 @@ async fn init() -> Result<()> {
Ok(())
}
fn get_config(cli: &Cli) -> Result<ComputeConfig> {
// First, read the config from the path if provided
if let Some(ref config) = cli.config {
let file = File::open(config)?;
return Ok(serde_json::from_reader(&file)?);
}
// If the config wasn't provided in the CLI arguments, then retrieve it from
// the control plane
match get_config_from_control_plane(cli.control_plane_uri.as_ref().unwrap(), &cli.compute_id) {
Ok(config) => Ok(config),
Err(e) => {
error!(
"cannot get response from control plane: {}\n\
neither spec nor confirmation that compute is in the Empty state was received",
e
);
Err(e)
}
}
}
fn deinit_and_exit(exit_code: Option<i32>) -> ! {
// Shutdown trace pipeline gracefully, so that it has a chance to send any
// pending traces before we exit. Shutting down OTEL tracing provider may

View File

@@ -15,8 +15,12 @@ use itertools::Itertools;
use nix::sys::signal::{Signal, kill};
use nix::unistd::Pid;
use once_cell::sync::Lazy;
use postgres;
use postgres::NoTls;
use postgres::error::SqlState;
use remote_storage::{DownloadError, RemotePath};
use std::collections::{HashMap, HashSet};
use std::net::SocketAddr;
use std::os::unix::fs::{PermissionsExt, symlink};
use std::path::Path;
use std::process::{Command, Stdio};
@@ -26,9 +30,9 @@ use std::sync::{Arc, Condvar, Mutex, RwLock};
use std::time::{Duration, Instant};
use std::{env, fs};
use tokio::spawn;
use tokio_postgres::{NoTls, error::SqlState};
use tracing::{Instrument, debug, error, info, instrument, warn};
use url::Url;
use utils::id::{TenantId, TimelineId};
use utils::lsn::Lsn;
use utils::measured_stream::MeasuredReader;
@@ -140,7 +144,7 @@ pub struct ComputeState {
/// Compute spec. This can be received from the CLI or - more likely -
/// passed by the control plane with a /configure HTTP request.
pub spec: Option<ComputeSpec>,
pub pspec: Option<ParsedSpec>,
/// If the spec is passed by a /configure request, 'startup_span' is the
/// /configure request's tracing span. The main thread enters it when it
@@ -167,7 +171,7 @@ impl ComputeState {
status: ComputeStatus::Empty,
last_active: None,
error: None,
spec: None,
pspec: None,
startup_span: None,
metrics: ComputeMetrics::default(),
lfc_prewarm_state: LfcPrewarmState::default(),
@@ -199,6 +203,94 @@ impl Default for ComputeState {
}
}
#[derive(Clone, Debug)]
pub struct ParsedSpec {
pub spec: ComputeSpec,
pub tenant_id: TenantId,
pub timeline_id: TimelineId,
pub pageserver_connstr: String,
pub safekeeper_connstrings: Vec<String>,
pub storage_auth_token: Option<String>,
pub endpoint_storage_addr: Option<SocketAddr>,
pub endpoint_storage_token: Option<String>,
}
impl TryFrom<ComputeSpec> for ParsedSpec {
type Error = String;
fn try_from(spec: ComputeSpec) -> Result<Self, String> {
// Extract the options from the spec file that are needed to connect to
// the storage system.
//
// For backwards-compatibility, the top-level fields in the spec file
// may be empty. In that case, we need to dig them from the GUCs in the
// cluster.settings field.
let pageserver_connstr = spec
.pageserver_connstring
.clone()
.or_else(|| spec.cluster.settings.find("neon.pageserver_connstring"))
.ok_or("pageserver connstr should be provided")?;
let safekeeper_connstrings = if spec.safekeeper_connstrings.is_empty() {
if matches!(spec.mode, ComputeMode::Primary) {
spec.cluster
.settings
.find("neon.safekeepers")
.ok_or("safekeeper connstrings should be provided")?
.split(',')
.map(|str| str.to_string())
.collect()
} else {
vec![]
}
} else {
spec.safekeeper_connstrings.clone()
};
let storage_auth_token = spec.storage_auth_token.clone();
let tenant_id: TenantId = if let Some(tenant_id) = spec.tenant_id {
tenant_id
} else {
spec.cluster
.settings
.find("neon.tenant_id")
.ok_or("tenant id should be provided")
.map(|s| TenantId::from_str(&s))?
.or(Err("invalid tenant id"))?
};
let timeline_id: TimelineId = if let Some(timeline_id) = spec.timeline_id {
timeline_id
} else {
spec.cluster
.settings
.find("neon.timeline_id")
.ok_or("timeline id should be provided")
.map(|s| TimelineId::from_str(&s))?
.or(Err("invalid timeline id"))?
};
let endpoint_storage_addr: Option<SocketAddr> = spec
.endpoint_storage_addr
.clone()
.or_else(|| spec.cluster.settings.find("neon.endpoint_storage_addr"))
.unwrap_or_default()
.parse()
.ok();
let endpoint_storage_token = spec
.endpoint_storage_token
.clone()
.or_else(|| spec.cluster.settings.find("neon.endpoint_storage_token"));
Ok(ParsedSpec {
spec,
pageserver_connstr,
safekeeper_connstrings,
storage_auth_token,
tenant_id,
timeline_id,
endpoint_storage_addr,
endpoint_storage_token,
})
}
}
/// If we are a VM, returns a [`Command`] that will run in the `neon-postgres`
/// cgroup. Otherwise returns the default `Command::new(cmd)`
///
@@ -276,7 +368,10 @@ impl ComputeNode {
tokio_conn_conf.options(&options);
let mut new_state = ComputeState::new();
new_state.spec = config.spec;
if let Some(spec) = config.spec {
let pspec = ParsedSpec::try_from(spec).map_err(|msg| anyhow::anyhow!(msg))?;
new_state.pspec = Some(pspec);
}
Ok(ComputeNode {
params,
@@ -291,10 +386,10 @@ impl ComputeNode {
/// Top-level control flow of compute_ctl. Returns a process exit code we should
/// exit with.
pub async fn run(self) -> Result<Option<i32>> {
pub fn run(self) -> Result<Option<i32>> {
let this = Arc::new(self);
let cli_spec = this.state.lock().unwrap().spec.clone();
let cli_spec = this.state.lock().unwrap().pspec.clone();
// If this is a pooled VM, prewarm before starting HTTP server and becoming
// available for binding. Prewarming helps Postgres start quicker later,
@@ -330,7 +425,7 @@ impl ComputeNode {
// If we got a spec from the CLI already, use that. Otherwise wait for the
// control plane to pass it to us with a /configure HTTP request
let spec = if let Some(cli_spec) = cli_spec {
let pspec = if let Some(cli_spec) = cli_spec {
cli_spec
} else {
this.wait_spec()?
@@ -343,11 +438,11 @@ impl ComputeNode {
let mut vm_monitor = None;
let mut pg_process: Option<PostgresHandle> = None;
match this.start_compute(&mut pg_process).await {
match this.start_compute(&mut pg_process) {
Ok(()) => {
// Success! Launch remaining services (just vm-monitor currently)
vm_monitor =
Some(this.start_vm_monitor(spec.disable_lfc_resizing.unwrap_or(false)));
Some(this.start_vm_monitor(pspec.spec.disable_lfc_resizing.unwrap_or(false)));
}
Err(err) => {
// Something went wrong with the startup. Log it and expose the error to
@@ -392,7 +487,7 @@ impl ComputeNode {
}
// Reap the postgres process
delay_exit |= this.cleanup_after_postgres_exit().await?;
delay_exit |= this.cleanup_after_postgres_exit()?;
// If launch failed, keep serving HTTP requests for a while, so the cloud
// control plane can get the actual error.
@@ -403,7 +498,7 @@ impl ComputeNode {
Ok(exit_code)
}
pub fn wait_spec(&self) -> Result<ComputeSpec> {
pub fn wait_spec(&self) -> Result<ParsedSpec> {
info!("no compute spec provided, waiting");
let mut state = self.state.lock().unwrap();
while state.status != ComputeStatus::ConfigurationPending {
@@ -411,7 +506,7 @@ impl ComputeNode {
}
info!("got spec, continue configuration");
let spec = state.spec.as_ref().unwrap().clone();
let spec = state.pspec.as_ref().unwrap().clone();
// Record for how long we slept waiting for the spec.
let now = Utc::now();
@@ -444,7 +539,7 @@ impl ComputeNode {
///
/// Note that this is in the critical path of a compute cold start. Keep this fast.
/// Try to do things concurrently, to hide the latencies.
async fn start_compute(self: &Arc<Self>, pg_handle: &mut Option<PostgresHandle>) -> Result<()> {
fn start_compute(self: &Arc<Self>, pg_handle: &mut Option<PostgresHandle>) -> Result<()> {
let compute_state: ComputeState;
let start_compute_span;
@@ -479,17 +574,18 @@ impl ComputeNode {
compute_state = state_guard.clone()
}
let spec = compute_state.spec.as_ref().expect("spec must be set");
let pspec = compute_state.pspec.as_ref().expect("spec must be set");
info!(
"starting compute for operation {}, tenant {}, timeline {}, project {}, branch {}, endpoint {}, features {:?}, spec.remote_extensions {:?}",
spec.operation_uuid.as_deref().unwrap_or("None"),
spec.tenant_id,
spec.timeline_id,
spec.project_id,
spec.branch_id,
spec.endpoint_id,
spec.features,
spec.remote_extensions,
"starting compute for project {}, operation {}, tenant {}, timeline {}, project {}, branch {}, endpoint {}, features {:?}, spec.remote_extensions {:?}",
pspec.spec.cluster.cluster_id.as_deref().unwrap_or("None"),
pspec.spec.operation_uuid.as_deref().unwrap_or("None"),
pspec.tenant_id,
pspec.timeline_id,
pspec.spec.project_id.as_deref().unwrap_or("None"),
pspec.spec.branch_id.as_deref().unwrap_or("None"),
pspec.spec.endpoint_id.as_deref().unwrap_or("None"),
pspec.spec.features,
pspec.spec.remote_extensions,
);
////// PRE-STARTUP PHASE: things that need to be finished before we start the Postgres process
@@ -510,8 +606,8 @@ impl ComputeNode {
let tls_config = self.tls_config(&pspec.spec);
// If there are any remote extensions in shared_preload_libraries, start downloading them
if spec.remote_extensions.is_some() {
let (this, spec) = (self.clone(), spec.clone());
if pspec.spec.remote_extensions.is_some() {
let (this, spec) = (self.clone(), pspec.spec.clone());
pre_tasks.spawn(async move {
this.download_preload_extensions(&spec)
.in_current_span()
@@ -522,11 +618,13 @@ impl ComputeNode {
// Prepare pgdata directory. This downloads the basebackup, among other things.
{
let (this, cs) = (self.clone(), compute_state.clone());
pre_tasks.spawn(async move { this.prepare_pgdata(&cs).await });
pre_tasks.spawn_blocking_child(move || this.prepare_pgdata(&cs));
}
// Resize swap to the desired size if the compute spec says so
if let (Some(size_bytes), true) = (spec.swap_size_bytes, self.params.resize_swap_on_bind) {
if let (Some(size_bytes), true) =
(pspec.spec.swap_size_bytes, self.params.resize_swap_on_bind)
{
pre_tasks.spawn_blocking_child(move || {
// To avoid 'swapoff' hitting postgres startup, we need to run resize-swap to completion
// *before* starting postgres.
@@ -544,7 +642,7 @@ impl ComputeNode {
// Set disk quota if the compute spec says so
if let (Some(disk_quota_bytes), Some(disk_quota_fs_mountpoint)) = (
spec.disk_quota_bytes,
pspec.spec.disk_quota_bytes,
self.params.set_disk_quota_for_fs.as_ref(),
) {
let disk_quota_fs_mountpoint = disk_quota_fs_mountpoint.clone();
@@ -559,7 +657,7 @@ impl ComputeNode {
}
// tune pgbouncer
if let Some(pgbouncer_settings) = &spec.pgbouncer_settings {
if let Some(pgbouncer_settings) = &pspec.spec.pgbouncer_settings {
info!("tuning pgbouncer");
let pgbouncer_settings = pgbouncer_settings.clone();
@@ -577,7 +675,7 @@ impl ComputeNode {
}
// configure local_proxy
if let Some(local_proxy) = &spec.local_proxy_config {
if let Some(local_proxy) = &pspec.spec.local_proxy_config {
info!("configuring local_proxy");
// Spawn a background task to do the configuration,
@@ -595,7 +693,7 @@ impl ComputeNode {
}
// Configure and start rsyslog for compliance audit logging
match spec.audit_log_level {
match pspec.spec.audit_log_level {
ComputeAudit::Hipaa | ComputeAudit::Extended | ComputeAudit::Full => {
let remote_endpoint =
std::env::var("AUDIT_LOGGING_ENDPOINT").unwrap_or("".to_string());
@@ -606,10 +704,16 @@ impl ComputeNode {
let log_directory_path = Path::new(&self.params.pgdata).join("log");
let log_directory_path = log_directory_path.to_string_lossy().to_string();
// Add project_id,endpoint_id to identify the logs.
//
// These ids are passed from cplane,
let endpoint_id = pspec.spec.endpoint_id.as_deref().unwrap_or("");
let project_id = pspec.spec.project_id.as_deref().unwrap_or("");
configure_audit_rsyslog(
log_directory_path.clone(),
&spec.endpoint_id,
&spec.project_id,
endpoint_id,
project_id,
&remote_endpoint,
)?;
@@ -620,7 +724,7 @@ impl ComputeNode {
}
// Configure and start rsyslog for Postgres logs export
let conf = PostgresLogsRsyslogConfig::new(spec.logs_export_host.as_deref());
let conf = PostgresLogsRsyslogConfig::new(pspec.spec.logs_export_host.as_deref());
configure_postgres_logs_export(conf)?;
// Launch remaining service threads
@@ -628,20 +732,21 @@ impl ComputeNode {
let _configurator_handle = launch_configurator(self);
// Wait for all the pre-tasks to finish before starting postgres
while let Some(res) = pre_tasks.join_next().await {
let rt = tokio::runtime::Handle::current();
while let Some(res) = rt.block_on(pre_tasks.join_next()) {
res??;
}
////// START POSTGRES
let start_time = Utc::now();
let pg_process = self.start_postgres(spec.storage_auth_token.clone())?;
let pg_process = self.start_postgres(pspec.storage_auth_token.clone())?;
let postmaster_pid = pg_process.pid();
*pg_handle = Some(pg_process);
// If this is a primary endpoint, perform some post-startup configuration before
// opening it up for the world.
let config_time = Utc::now();
if spec.mode == ComputeMode::Primary {
if pspec.spec.mode == ComputeMode::Primary {
self.configure_as_primary(&compute_state)?;
let conf = self.get_tokio_conn_conf(None);
@@ -682,7 +787,6 @@ impl ComputeNode {
if pspec.spec.autoprewarm {
self.prewarm_lfc();
}
Ok(())
}
@@ -758,14 +862,14 @@ impl ComputeNode {
}
}
async fn cleanup_after_postgres_exit(&self) -> Result<bool> {
fn cleanup_after_postgres_exit(&self) -> Result<bool> {
// Maybe sync safekeepers again, to speed up next startup
let compute_state = self.state.lock().unwrap().clone();
let spec = compute_state.spec.as_ref().expect("spec must be set");
if matches!(spec.mode, compute_api::spec::ComputeMode::Primary) {
let pspec = compute_state.pspec.as_ref().expect("spec must be set");
if matches!(pspec.spec.mode, compute_api::spec::ComputeMode::Primary) {
info!("syncing safekeepers on shutdown");
let storage_auth_token = spec.storage_auth_token.clone();
let lsn = self.sync_safekeepers(storage_auth_token).await?;
let storage_auth_token = pspec.storage_auth_token.clone();
let lsn = self.sync_safekeepers(storage_auth_token)?;
info!("synced safekeepers at lsn {lsn}");
}
@@ -788,12 +892,13 @@ impl ComputeNode {
/// Check that compute node has corresponding feature enabled.
pub fn has_feature(&self, feature: ComputeFeature) -> bool {
self.state
.lock()
.unwrap()
.spec
.as_ref()
.is_some_and(|spec| spec.features.contains(&feature))
let state = self.state.lock().unwrap();
if let Some(s) = state.pspec.as_ref() {
s.spec.features.contains(&feature)
} else {
false
}
}
pub fn set_status(&self, status: ComputeStatus) {
@@ -810,15 +915,13 @@ impl ComputeNode {
self.state.lock().unwrap().status
}
pub fn get_timeline_id(&self) -> TimelineId {
pub fn get_timeline_id(&self) -> Option<TimelineId> {
self.state
.lock()
.unwrap()
.spec
.pspec
.as_ref()
.unwrap()
.timeline_id
.clone()
.map(|s| s.timeline_id)
}
// Remove `pgdata` directory and create it again with right permissions.
@@ -837,10 +940,11 @@ impl ComputeNode {
// unarchive it to `pgdata` directory overriding all its previous content.
#[instrument(skip_all, fields(%lsn))]
fn try_get_basebackup(&self, compute_state: &ComputeState, lsn: Lsn) -> Result<()> {
let spec = compute_state.spec.as_ref().expect("spec must be set");
let spec = compute_state.pspec.as_ref().expect("spec must be set");
let start_time = Instant::now();
let mut config = postgres::Config::from(&spec.pageservers[0]);
let shard0_connstr = spec.pageserver_connstr.split(',').next().unwrap();
let mut config = postgres::Config::from_str(shard0_connstr)?;
// Use the storage auth token from the config file, if given.
// Note: this overrides any password set in the connection string.
@@ -852,17 +956,20 @@ impl ComputeNode {
}
config.application_name("compute_ctl");
if let Some(spec) = &compute_state.spec {
config.options(&format!("-c neon.compute_mode={}", spec.mode.to_type_str()));
if let Some(spec) = &compute_state.pspec {
config.options(&format!(
"-c neon.compute_mode={}",
spec.spec.mode.to_type_str()
));
}
// Connect to pageserver
let mut client = config.connect(postgres::NoTls)?;
let mut client = config.connect(NoTls)?;
let pageserver_connect_micros = start_time.elapsed().as_micros() as u64;
let basebackup_cmd = match lsn {
Lsn(0) => {
if spec.mode != ComputeMode::Primary {
if spec.spec.mode != ComputeMode::Primary {
format!(
"basebackup {} {} --gzip --replica",
spec.tenant_id, spec.timeline_id
@@ -872,7 +979,7 @@ impl ComputeNode {
}
}
_ => {
if spec.mode != ComputeMode::Primary {
if spec.spec.mode != ComputeMode::Primary {
format!(
"basebackup {} {} {} --gzip --replica",
spec.tenant_id, spec.timeline_id, lsn
@@ -948,34 +1055,35 @@ impl ComputeNode {
compute_state: &ComputeState,
) -> Result<Option<Lsn>> {
// Construct a connection config for each safekeeper
let spec = compute_state
.spec
let pspec: ParsedSpec = compute_state
.pspec
.as_ref()
.expect("spec must be set")
.clone();
let safekeepers = spec
.safekeepers
.iter()
.map(|s| {
let mut config = tokio_postgres::Config::from(s);
let sk_connstrs: Vec<String> = pspec.safekeeper_connstrings.clone();
let sk_configs = sk_connstrs.into_iter().map(|connstr| {
// Format connstr
let id = connstr.clone();
let connstr = format!("postgresql://no_user@{}", connstr);
let options = format!(
"-c timeline_id={} tenant_id={}",
pspec.timeline_id, pspec.tenant_id
);
config.user("no_user");
config.options(&format!(
"-c timeline_id={} tenant_id={}",
spec.timeline_id, spec.tenant_id
));
if let Some(storage_auth_token) = &spec.storage_auth_token {
config.password(storage_auth_token);
}
// Construct client
let mut config = tokio_postgres::Config::from_str(&connstr).unwrap();
config.options(&options);
if let Some(storage_auth_token) = pspec.storage_auth_token.clone() {
config.password(storage_auth_token);
}
(format!("{}:{}", s.host, s.port), config)
})
.collect::<Vec<(String, tokio_postgres::Config)>>();
(id, config)
});
// Create task set to query all safekeepers
let mut tasks = FuturesUnordered::new();
let quorum = safekeepers.len() / 2 + 1;
for (id, config) in safekeepers {
let quorum = sk_configs.len() / 2 + 1;
for (id, config) in sk_configs {
let timeout = tokio::time::Duration::from_millis(100);
let task = tokio::time::timeout(timeout, ping_safekeeper(id, config));
tasks.push(tokio::spawn(task));
@@ -1020,13 +1128,11 @@ impl ComputeNode {
// Fast path for sync_safekeepers. If they're already synced we get the lsn
// in one roundtrip. If not, we should do a full sync_safekeepers.
#[instrument(skip_all)]
pub async fn check_safekeepers_synced(
&self,
compute_state: &ComputeState,
) -> Result<Option<Lsn>> {
pub fn check_safekeepers_synced(&self, compute_state: &ComputeState) -> Result<Option<Lsn>> {
let start_time = Utc::now();
let result = self.check_safekeepers_synced_async(compute_state).await;
let rt = tokio::runtime::Handle::current();
let result = rt.block_on(self.check_safekeepers_synced_async(compute_state));
// Record runtime
self.state.lock().unwrap().metrics.sync_sk_check_ms = Utc::now()
@@ -1040,7 +1146,7 @@ impl ComputeNode {
// Run `postgres` in a special mode with `--sync-safekeepers` argument
// and return the reported LSN back to the caller.
#[instrument(skip_all)]
pub async fn sync_safekeepers(&self, storage_auth_token: Option<String>) -> Result<Lsn> {
pub fn sync_safekeepers(&self, storage_auth_token: Option<String>) -> Result<Lsn> {
let start_time = Utc::now();
let mut sync_handle = maybe_cgexec(&self.params.pgbin)
@@ -1072,8 +1178,8 @@ impl ComputeNode {
SYNC_SAFEKEEPERS_PID.store(0, Ordering::SeqCst);
// Process has exited, so we can join the logs thread.
let _ = logs_handle
.await
let _ = tokio::runtime::Handle::current()
.block_on(logs_handle)
.map_err(|e| tracing::error!("log task panicked: {:?}", e));
if !sync_output.status.success() {
@@ -1099,8 +1205,9 @@ impl ComputeNode {
/// Do all the preparations like PGDATA directory creation, configuration,
/// safekeepers sync, basebackup, etc.
#[instrument(skip_all)]
pub async fn prepare_pgdata(&self, compute_state: &ComputeState) -> Result<()> {
let spec = compute_state.spec.as_ref().expect("spec must be set");
pub fn prepare_pgdata(&self, compute_state: &ComputeState) -> Result<()> {
let pspec = compute_state.pspec.as_ref().expect("spec must be set");
let spec = &pspec.spec;
let pgdata_path = Path::new(&self.params.pgdata);
let tls_config = self.tls_config(&pspec.spec);
@@ -1109,7 +1216,7 @@ impl ComputeNode {
self.create_pgdata()?;
config::write_postgres_conf(
pgdata_path,
spec,
&pspec.spec,
self.params.internal_http_port,
tls_config,
)?;
@@ -1120,13 +1227,11 @@ impl ComputeNode {
let lsn = match spec.mode {
ComputeMode::Primary => {
info!("checking if safekeepers are synced");
let lsn = if let Ok(Some(lsn)) = self.check_safekeepers_synced(compute_state).await
{
let lsn = if let Ok(Some(lsn)) = self.check_safekeepers_synced(compute_state) {
lsn
} else {
info!("starting safekeepers syncing");
self.sync_safekeepers(spec.storage_auth_token.clone())
.await
self.sync_safekeepers(pspec.storage_auth_token.clone())
.with_context(|| "failed to sync safekeepers")?
};
info!("safekeepers synced at LSN {}", lsn);
@@ -1143,13 +1248,13 @@ impl ComputeNode {
};
info!(
"getting basebackup@{} from pageserver {}:{}",
lsn, spec.pageservers[0].host, spec.pageservers[0].port
"getting basebackup@{} from pageserver {}",
lsn, &pspec.pageserver_connstr
);
self.get_basebackup(compute_state, lsn).with_context(|| {
format!(
"failed to get basebackup@{} from pageserver {}:{}",
lsn, spec.pageservers[0].host, spec.pageservers[0].port
"failed to get basebackup@{} from pageserver {}",
lsn, &pspec.pageserver_connstr
)
})?;
@@ -1431,9 +1536,10 @@ impl ComputeNode {
let conf = Arc::new(conf);
let spec = Arc::new(
compute_state
.spec
.pspec
.as_ref()
.expect("spec must be set")
.spec
.clone(),
);
@@ -1502,7 +1608,7 @@ impl ComputeNode {
/// as it's used to reconfigure a previously started and configured Postgres node.
#[instrument(skip_all)]
pub fn reconfigure(&self) -> Result<()> {
let spec = self.state.lock().unwrap().spec.as_ref().unwrap().clone();
let spec = self.state.lock().unwrap().pspec.clone().unwrap().spec;
let tls_config = self.tls_config(&spec);
@@ -1584,10 +1690,10 @@ impl ComputeNode {
#[instrument(skip_all)]
pub fn configure_as_primary(&self, compute_state: &ComputeState) -> Result<()> {
let spec = compute_state.spec.as_ref().expect("spec must be set");
let pspec = compute_state.pspec.as_ref().expect("spec must be set");
assert!(spec.mode == ComputeMode::Primary);
if !spec.skip_pg_catalog_updates {
assert!(pspec.spec.mode == ComputeMode::Primary);
if !pspec.spec.skip_pg_catalog_updates {
let pgdata_path = Path::new(&self.params.pgdata);
// temporarily reset max_cluster_size in config
// to avoid the possibility of hitting the limit, while we are applying config:
@@ -2083,23 +2189,24 @@ LIMIT 100",
/// the pageserver connection strings has changed.
///
/// The operation will time out after a specified duration.
pub fn wait_timeout_while_pageservers_unchanged(&self, duration: Duration) {
pub fn wait_timeout_while_pageserver_connstr_unchanged(&self, duration: Duration) {
let state = self.state.lock().unwrap();
let old_pageservers = state
.spec
let old_pageserver_connstr = state
.pspec
.as_ref()
.expect("spec must be set")
.pageservers
.pageserver_connstr
.clone();
let mut unchanged = true;
let _ = self
.state_changed
.wait_timeout_while(state, duration, |s| {
let current_pageservers = &s.spec.as_ref().expect("spec must be set").pageservers;
unchanged = current_pageservers
.iter()
.zip(&old_pageservers)
.all(|(c, o)| c == o);
let pageserver_connstr = &s
.pspec
.as_ref()
.expect("spec must be set")
.pageserver_connstr;
unchanged = pageserver_connstr == &old_pageserver_connstr;
unchanged
})
.unwrap();

View File

@@ -3,7 +3,6 @@ use anyhow::{Context, Result, bail};
use async_compression::tokio::bufread::{ZstdDecoder, ZstdEncoder};
use compute_api::responses::LfcOffloadState;
use compute_api::responses::LfcPrewarmState;
use compute_api::spec::ComputeSpec;
use http::StatusCode;
use reqwest::Client;
use std::sync::Arc;
@@ -26,30 +25,24 @@ struct EndpointStoragePair {
}
const KEY: &str = "lfc_state";
impl TryFrom<&ComputeSpec> for EndpointStoragePair {
impl TryFrom<&crate::compute::ParsedSpec> for EndpointStoragePair {
type Error = anyhow::Error;
fn try_from(spec: &ComputeSpec) -> Result<Self, Self::Error> {
let Some(ref addr) = spec.endpoint_storage_addr else {
bail!("spec.endpoint_storage_addr missing")
fn try_from(pspec: &crate::compute::ParsedSpec) -> Result<Self, Self::Error> {
let Some(ref endpoint_id) = pspec.spec.endpoint_id else {
bail!("pspec.endpoint_id missing")
};
let url = format!(
"http://{addr}/{tenant_id}/{timeline_id}/{endpoint_id}/{key}",
addr = addr,
tenant_id = spec.tenant_id,
timeline_id = spec.timeline_id,
endpoint_id = spec.endpoint_id,
key = KEY
);
let Some(ref token) = spec.endpoint_storage_token else {
bail!("spec.endpoint_storage_token missing")
let Some(ref base_uri) = pspec.endpoint_storage_addr else {
bail!("pspec.endpoint_storage_addr missing")
};
let tenant_id = pspec.tenant_id;
let timeline_id = pspec.timeline_id;
Ok(EndpointStoragePair {
url,
token: token.clone(),
})
let url = format!("http://{base_uri}/{tenant_id}/{timeline_id}/{endpoint_id}/{KEY}");
let Some(ref token) = pspec.endpoint_storage_token else {
bail!("pspec.endpoint_storage_token missing")
};
let token = token.clone();
Ok(EndpointStoragePair { url, token })
}
}
@@ -118,7 +111,7 @@ impl ComputeNode {
fn endpoint_storage_pair(&self) -> Result<EndpointStoragePair> {
let state = self.state.lock().unwrap();
state.spec.as_ref().unwrap().try_into()
state.pspec.as_ref().unwrap().try_into()
}
async fn prewarm_impl(&self) -> Result<()> {

View File

@@ -56,24 +56,13 @@ pub fn write_postgres_conf(
// Add options for connecting to storage
writeln!(file, "# Neon storage settings")?;
if !spec.pageservers.is_empty() {
writeln!(
file,
"neon.pageserver_connstring={}",
escape_conf_value(
&spec
.pageservers
.iter()
.map(|p| format!("host={} port={}", p.host, p.port))
.collect::<Vec<_>>()
.join(",")
)
)?;
if let Some(s) = &spec.pageserver_connstring {
writeln!(file, "neon.pageserver_connstring={}", escape_conf_value(s))?;
}
if let Some(stripe_size) = spec.shard_stripe_size {
writeln!(file, "neon.stripe_size={stripe_size}")?;
}
if !spec.safekeepers.is_empty() {
if !spec.safekeeper_connstrings.is_empty() {
let mut neon_safekeepers_value = String::new();
tracing::info!(
"safekeepers_connstrings is not zero, gen: {:?}",
@@ -83,45 +72,32 @@ pub fn write_postgres_conf(
if let Some(generation) = spec.safekeepers_generation {
write!(neon_safekeepers_value, "g#{}:", generation)?;
}
neon_safekeepers_value.push_str(
&spec
.safekeepers
.iter()
.map(|s| format!("{}:{}", s.host.to_string(), s.port))
.collect::<Vec<_>>()
.join(","),
);
neon_safekeepers_value.push_str(&spec.safekeeper_connstrings.join(","));
writeln!(
file,
"neon.safekeepers={}",
escape_conf_value(&neon_safekeepers_value)
)?;
}
writeln!(
file,
"neon.tenant_id={}",
escape_conf_value(&spec.tenant_id.to_string())
)?;
writeln!(
file,
"neon.timeline_id={}",
escape_conf_value(&spec.timeline_id.to_string())
)?;
writeln!(
file,
"neon.project_id={}",
escape_conf_value(&spec.project_id)
)?;
writeln!(
file,
"neon.branch_id={}",
escape_conf_value(&spec.branch_id)
)?;
writeln!(
file,
"neon.endpoint_id={}",
escape_conf_value(&spec.endpoint_id)
)?;
if let Some(s) = &spec.tenant_id {
writeln!(file, "neon.tenant_id={}", escape_conf_value(&s.to_string()))?;
}
if let Some(s) = &spec.timeline_id {
writeln!(
file,
"neon.timeline_id={}",
escape_conf_value(&s.to_string())
)?;
}
if let Some(s) = &spec.project_id {
writeln!(file, "neon.project_id={}", escape_conf_value(s))?;
}
if let Some(s) = &spec.branch_id {
writeln!(file, "neon.branch_id={}", escape_conf_value(s))?;
}
if let Some(s) = &spec.endpoint_id {
writeln!(file, "neon.endpoint_id={}", escape_conf_value(s))?;
}
// tls
if let Some(tls_config) = tls_config {

View File

@@ -8,7 +8,7 @@ use http::StatusCode;
use tokio::task;
use tracing::info;
use crate::compute::ComputeNode;
use crate::compute::{ComputeNode, ParsedSpec};
use crate::http::JsonResponse;
use crate::http::extract::Json;
@@ -22,6 +22,11 @@ pub(in crate::http) async fn configure(
State(compute): State<Arc<ComputeNode>>,
request: Json<ConfigurationRequest>,
) -> Response {
let pspec = match ParsedSpec::try_from(request.0.spec) {
Ok(p) => p,
Err(e) => return JsonResponse::error(StatusCode::BAD_REQUEST, e),
};
// XXX: wrap state update under lock in a code block. Otherwise, we will try
// to `Send` `mut state` into the spawned thread bellow, which will cause
// the following rustc error:
@@ -38,7 +43,7 @@ pub(in crate::http) async fn configure(
// configure request for tracing purposes.
state.startup_span = Some(tracing::Span::current());
state.spec = Some(request.spec.clone());
state.pspec = Some(pspec);
state.set_status(ComputeStatus::ConfigurationPending, &compute.state_changed);
drop(state);
}

View File

@@ -31,7 +31,8 @@ pub(in crate::http) async fn download_extension(
let ext = {
let state = compute.state.lock().unwrap();
let spec = &state.spec.as_ref().unwrap();
let pspec = state.pspec.as_ref().unwrap();
let spec = &pspec.spec;
let remote_extensions = match spec.remote_extensions.as_ref() {
Some(r) => r,

View File

@@ -21,8 +21,14 @@ impl From<&ComputeState> for ComputeStatusResponse {
fn from(state: &ComputeState) -> Self {
ComputeStatusResponse {
start_time: state.start_time,
tenant: state.spec.as_ref().map(|spec| spec.tenant_id.to_string()),
timeline: state.spec.as_ref().map(|spec| spec.timeline_id.to_string()),
tenant: state
.pspec
.as_ref()
.map(|pspec| pspec.tenant_id.to_string()),
timeline: state
.pspec
.as_ref()
.map(|pspec| pspec.timeline_id.to_string()),
status: state.status,
last_active: state.last_active,
error: state.error.clone(),

View File

@@ -18,8 +18,8 @@ use crate::compute::ComputeNode;
pub fn launch_lsn_lease_bg_task_for_static(compute: &Arc<ComputeNode>) {
let (tenant_id, timeline_id, lsn) = {
let state = compute.state.lock().unwrap();
let spec = state.spec.as_ref().expect("Spec must be set");
match spec.mode {
let spec = state.pspec.as_ref().expect("Spec must be set");
match spec.spec.mode {
ComputeMode::Static(lsn) => (spec.tenant_id, spec.timeline_id, lsn),
_ => return,
}
@@ -58,7 +58,7 @@ fn lsn_lease_bg_task(
"Request succeeded, sleeping for {} seconds",
sleep_duration.as_secs()
);
compute.wait_timeout_while_pageservers_unchanged(sleep_duration);
compute.wait_timeout_while_pageserver_connstr_unchanged(sleep_duration);
}
}
@@ -79,11 +79,18 @@ fn acquire_lsn_lease_with_retry(
let configs = {
let state = compute.state.lock().unwrap();
let spec = state.spec.as_ref().expect("spec must be set");
let spec = state.pspec.as_ref().expect("spec must be set");
spec.pageservers
.iter()
.map(|p| postgres::Config::from(p))
let conn_strings = spec.pageserver_connstr.split(',');
conn_strings
.map(|connstr| {
let mut config = postgres::Config::from_str(connstr).expect("Invalid connstr");
if let Some(storage_auth_token) = &spec.storage_auth_token {
config.password(storage_auth_token.clone());
}
config
})
.collect::<Vec<_>>()
};
@@ -98,7 +105,7 @@ fn acquire_lsn_lease_with_retry(
Err(e) => {
warn!("Failed to acquire lsn lease: {e} (attempt {attempts})");
compute.wait_timeout_while_pageservers_unchanged(Duration::from_millis(
compute.wait_timeout_while_pageserver_connstr_unchanged(Duration::from_millis(
retry_period_ms as u64,
));
retry_period_ms *= 1.5;

View File

@@ -4,7 +4,7 @@ use std::future::Future;
use std::iter::{empty, once};
use std::sync::Arc;
use anyhow::Result;
use anyhow::{Context, Result};
use compute_api::responses::ComputeStatus;
use compute_api::spec::{ComputeAudit, ComputeSpec, Database, PgIdent, Role};
use futures::future::join_all;
@@ -74,7 +74,7 @@ impl ComputeNode {
let mut drop_subscriptions_done = false;
if spec.drop_subscriptions_before_start {
let timeline_id = self.get_timeline_id();
let timeline_id = self.get_timeline_id().context("timeline_id must be set")?;
info!("Checking if drop subscription operation was already performed for timeline_id: {}", timeline_id);

View File

@@ -37,7 +37,7 @@ pub async fn ping_safekeeper(
// Parse result
info!("done with {}", id);
if let tokio_postgres::SimpleQueryMessage::Row(row) = &result[0] {
if let postgres::SimpleQueryMessage::Row(row) = &result[0] {
use std::str::FromStr;
let response = TimelineStatusResponse::Ok(TimelineStatusOkResponse {
flush_lsn: Lsn::from_str(row.get("flush_lsn").unwrap())?,

View File

@@ -36,7 +36,6 @@ pageserver_api.workspace = true
pageserver_client.workspace = true
postgres_backend.workspace = true
safekeeper_api.workspace = true
safekeeper_client.workspace = true
postgres_connection.workspace = true
storage_broker.workspace = true
http-utils.workspace = true

View File

@@ -45,7 +45,7 @@ use pageserver_api::models::{
use pageserver_api::shard::{DEFAULT_STRIPE_SIZE, ShardCount, ShardStripeSize, TenantShardId};
use postgres_backend::AuthType;
use postgres_connection::parse_host_port;
use safekeeper_api::membership::{SafekeeperGeneration, SafekeeperId};
use safekeeper_api::membership::SafekeeperGeneration;
use safekeeper_api::{
DEFAULT_HTTP_LISTEN_PORT as DEFAULT_SAFEKEEPER_HTTP_PORT,
DEFAULT_PG_LISTEN_PORT as DEFAULT_SAFEKEEPER_PG_PORT,
@@ -1255,45 +1255,6 @@ async fn handle_timeline(cmd: &TimelineCmd, env: &mut local_env::LocalEnv) -> Re
pageserver
.timeline_import(tenant_id, timeline_id, base, pg_wal, args.pg_version)
.await?;
if env.storage_controller.timelines_onto_safekeepers {
println!("Creating timeline on safekeeper ...");
let timeline_info = pageserver
.timeline_info(
TenantShardId::unsharded(tenant_id),
timeline_id,
pageserver_client::mgmt_api::ForceAwaitLogicalSize::No,
)
.await?;
let default_sk = SafekeeperNode::from_env(env, env.safekeepers.first().unwrap());
let default_host = default_sk
.conf
.listen_addr
.clone()
.unwrap_or_else(|| "localhost".to_string());
let mconf = safekeeper_api::membership::Configuration {
generation: SafekeeperGeneration::new(1),
members: safekeeper_api::membership::MemberSet {
m: vec![SafekeeperId {
host: default_host,
id: default_sk.conf.id,
pg_port: default_sk.conf.pg_port,
}],
},
new_members: None,
};
let pg_version = args.pg_version * 10000;
let req = safekeeper_api::models::TimelineCreateRequest {
tenant_id,
timeline_id,
mconf,
pg_version,
system_id: None,
wal_seg_size: None,
start_lsn: timeline_info.last_record_lsn,
commit_lsn: None,
};
default_sk.create_timeline(&req).await?;
}
env.register_branch_mapping(branch_name.to_string(), tenant_id, timeline_id)?;
println!("Done");
}
@@ -1493,10 +1454,7 @@ async fn handle_endpoint(subcmd: &EndpointCmd, env: &local_env::LocalEnv) -> Res
let conf = env.get_pageserver_conf(pageserver_id).unwrap();
let parsed = parse_host_port(&conf.listen_pg_addr).expect("Bad config");
(
vec![compute_api::spec::Pageserver {
host: parsed.0,
port: parsed.1.unwrap_or(5432),
}],
vec![(parsed.0, parsed.1.unwrap_or(5432))],
// If caller is telling us what pageserver to use, this is not a tenant which is
// full managed by storage controller, therefore not sharded.
DEFAULT_STRIPE_SIZE,
@@ -1519,11 +1477,11 @@ async fn handle_endpoint(subcmd: &EndpointCmd, env: &local_env::LocalEnv) -> Res
.await?;
}
anyhow::Ok(compute_api::spec::Pageserver {
host: Host::parse(&shard.listen_pg_addr)
anyhow::Ok((
Host::parse(&shard.listen_pg_addr)
.expect("Storage controller reported bad hostname"),
port: shard.listen_pg_port,
})
shard.listen_pg_port,
))
}),
)
.await?;
@@ -1579,10 +1537,10 @@ async fn handle_endpoint(subcmd: &EndpointCmd, env: &local_env::LocalEnv) -> Res
.with_context(|| format!("postgres endpoint {endpoint_id} is not found"))?;
let pageservers = if let Some(ps_id) = args.endpoint_pageserver_id {
let pageserver = PageServerNode::from_env(env, env.get_pageserver_conf(ps_id)?);
vec![compute_api::spec::Pageserver {
host: pageserver.pg_connection_config.host().clone(),
port: pageserver.pg_connection_config.port(),
}]
vec![(
pageserver.pg_connection_config.host().clone(),
pageserver.pg_connection_config.port(),
)]
} else {
let storage_controller = StorageController::from_env(env);
storage_controller
@@ -1590,10 +1548,12 @@ async fn handle_endpoint(subcmd: &EndpointCmd, env: &local_env::LocalEnv) -> Res
.await?
.shards
.into_iter()
.map(|shard| compute_api::spec::Pageserver {
host: Host::parse(&shard.listen_pg_addr)
.expect("Storage controller reported malformed host"),
port: shard.listen_pg_port,
.map(|shard| {
(
Host::parse(&shard.listen_pg_addr)
.expect("Storage controller reported malformed host"),
shard.listen_pg_port,
)
})
.collect::<Vec<_>>()
};

View File

@@ -52,8 +52,8 @@ use compute_api::responses::{
ComputeConfig, ComputeCtlConfig, ComputeStatus, ComputeStatusResponse, TlsConfig,
};
use compute_api::spec::{
Cluster, ComputeAudit, ComputeFeature, ComputeMode, ComputeSpec, Database, Pageserver, PgIdent,
RemoteExtSpec, Role, Safekeeper,
Cluster, ComputeAudit, ComputeFeature, ComputeMode, ComputeSpec, Database, PgIdent,
RemoteExtSpec, Role,
};
use jsonwebtoken::jwk::{
AlgorithmParameters, CommonParameters, EllipticCurve, Jwk, JwkSet, KeyAlgorithm, KeyOperations,
@@ -606,25 +606,29 @@ impl Endpoint {
}
}
fn safekeepers_from_nodes(&self, ids: Vec<NodeId>) -> Result<Vec<Safekeeper>> {
let mut s = Vec::new();
fn build_pageserver_connstr(pageservers: &[(Host, u16)]) -> String {
pageservers
.iter()
.map(|(host, port)| format!("postgresql://no_user@{host}:{port}"))
.collect::<Vec<_>>()
.join(",")
}
/// Map safekeepers ids to the actual connection strings.
fn build_safekeepers_connstrs(&self, sk_ids: Vec<NodeId>) -> Result<Vec<String>> {
let mut safekeeper_connstrings = Vec::new();
if self.mode == ComputeMode::Primary {
for id in ids {
for sk_id in sk_ids {
let sk = self
.env
.safekeepers
.iter()
.find(|node| node.id == id)
.ok_or_else(|| anyhow!("safekeeper {id} does not exist"))?;
s.push(Safekeeper {
host: Host::parse("127.0.0.1")?,
port: sk.get_compute_port(),
});
.find(|node| node.id == sk_id)
.ok_or_else(|| anyhow!("safekeeper {sk_id} does not exist"))?;
safekeeper_connstrings.push(format!("127.0.0.1:{}", sk.get_compute_port()));
}
}
Ok(s)
Ok(safekeeper_connstrings)
}
/// Generate a JWT with the correct claims.
@@ -650,7 +654,7 @@ impl Endpoint {
endpoint_storage_addr: String,
safekeepers_generation: Option<SafekeeperGeneration>,
safekeepers: Vec<NodeId>,
pageservers: Vec<Pageserver>,
pageservers: Vec<(Host, u16)>,
remote_ext_base_url: Option<&String>,
shard_stripe_size: usize,
create_test_user: bool,
@@ -668,6 +672,11 @@ impl Endpoint {
std::fs::remove_dir_all(self.pgdata())?;
}
let pageserver_connstring = Self::build_pageserver_connstr(&pageservers);
assert!(!pageserver_connstring.is_empty());
let safekeeper_connstrings = self.build_safekeepers_connstrs(safekeepers)?;
// check for file remote_extensions_spec.json
// if it is present, read it and pass to compute_ctl
let remote_extensions_spec_path = self.endpoint_path().join("remote_extensions_spec.json");
@@ -718,34 +727,15 @@ impl Endpoint {
postgresql_conf: Some(postgresql_conf.clone()),
},
delta_operations: None,
tenant_id: self.tenant_id.clone(),
timeline_id: self.timeline_id.clone(),
project_id: self.tenant_id.to_string(),
branch_id: self.timeline_id.to_string(),
endpoint_id: self.endpoint_id.clone(),
tenant_id: Some(self.tenant_id),
timeline_id: Some(self.timeline_id),
project_id: None,
branch_id: None,
endpoint_id: Some(self.endpoint_id.clone()),
mode: self.mode,
pageservers,
safekeepers: {
let mut s = Vec::new();
if self.mode == ComputeMode::Primary {
for id in safekeepers {
let sk = self
.env
.safekeepers
.iter()
.find(|node| node.id == id)
.ok_or_else(|| anyhow!("safekeeper {id} does not exist"))?;
s.push(Safekeeper {
host: Host::parse("127.0.0.1")?,
port: sk.get_compute_port(),
});
}
}
s
},
pageserver_connstring: Some(pageserver_connstring),
safekeepers_generation: safekeepers_generation.map(|g| g.into_inner()),
safekeeper_connstrings,
storage_auth_token: auth_token.clone(),
remote_extensions,
pgbouncer_settings: None,
@@ -949,7 +939,7 @@ impl Endpoint {
pub async fn reconfigure(
&self,
pageservers: Vec<Pageserver>,
mut pageservers: Vec<(Host, u16)>,
stripe_size: Option<ShardStripeSize>,
safekeepers: Option<Vec<NodeId>>,
) -> Result<()> {
@@ -968,24 +958,30 @@ impl Endpoint {
if pageservers.is_empty() {
let storage_controller = StorageController::from_env(&self.env);
let locate_result = storage_controller.tenant_locate(self.tenant_id).await?;
spec.pageservers = locate_result
pageservers = locate_result
.shards
.into_iter()
.map(|shard| Pageserver {
host: Host::parse(&shard.listen_pg_addr)
.expect("Storage controller reported bad hostname"),
port: shard.listen_pg_port,
.map(|shard| {
(
Host::parse(&shard.listen_pg_addr)
.expect("Storage controller reported bad hostname"),
shard.listen_pg_port,
)
})
.collect::<Vec<_>>();
}
let pageserver_connstr = Self::build_pageserver_connstr(&pageservers);
assert!(!pageserver_connstr.is_empty());
spec.pageserver_connstring = Some(pageserver_connstr);
if stripe_size.is_some() {
spec.shard_stripe_size = stripe_size.map(|s| s.0 as usize);
}
// If safekeepers are not specified, don't change them.
if let Some(safekeepers) = safekeepers {
spec.safekeepers = self.safekeepers_from_nodes(safekeepers)?;
let safekeeper_connstrings = self.build_safekeepers_connstrs(safekeepers)?;
spec.safekeeper_connstrings = safekeeper_connstrings;
}
let client = reqwest::Client::builder()

View File

@@ -635,16 +635,4 @@ impl PageServerNode {
Ok(())
}
pub async fn timeline_info(
&self,
tenant_shard_id: TenantShardId,
timeline_id: TimelineId,
force_await_logical_size: mgmt_api::ForceAwaitLogicalSize,
) -> anyhow::Result<TimelineInfo> {
let timeline_info = self
.http_client
.timeline_info(tenant_shard_id, timeline_id, force_await_logical_size)
.await?;
Ok(timeline_info)
}
}

View File

@@ -6,6 +6,7 @@
//! .neon/safekeepers/<safekeeper id>
//! ```
use std::error::Error as _;
use std::future::Future;
use std::io::Write;
use std::path::PathBuf;
use std::time::Duration;
@@ -13,9 +14,9 @@ use std::{io, result};
use anyhow::Context;
use camino::Utf8PathBuf;
use http_utils::error::HttpErrorBody;
use postgres_connection::PgConnectionConfig;
use safekeeper_api::models::TimelineCreateRequest;
use safekeeper_client::mgmt_api;
use reqwest::{IntoUrl, Method};
use thiserror::Error;
use utils::auth::{Claims, Scope};
use utils::id::NodeId;
@@ -34,14 +35,25 @@ pub enum SafekeeperHttpError {
type Result<T> = result::Result<T, SafekeeperHttpError>;
fn err_from_client_err(err: mgmt_api::Error) -> SafekeeperHttpError {
use mgmt_api::Error::*;
match err {
ApiError(_, str) => SafekeeperHttpError::Response(str),
Cancelled => SafekeeperHttpError::Response("Cancelled".to_owned()),
ReceiveBody(err) => SafekeeperHttpError::Transport(err),
ReceiveErrorBody(err) => SafekeeperHttpError::Response(err),
Timeout(str) => SafekeeperHttpError::Response(format!("timeout: {str}")),
pub(crate) trait ResponseErrorMessageExt: Sized {
fn error_from_body(self) -> impl Future<Output = Result<Self>> + Send;
}
impl ResponseErrorMessageExt for reqwest::Response {
async fn error_from_body(self) -> Result<Self> {
let status = self.status();
if !(status.is_client_error() || status.is_server_error()) {
return Ok(self);
}
// reqwest does not export its error construction utility functions, so let's craft the message ourselves
let url = self.url().to_owned();
Err(SafekeeperHttpError::Response(
match self.json::<HttpErrorBody>().await {
Ok(err_body) => format!("Error: {}", err_body.msg),
Err(_) => format!("Http error ({}) at {}.", status.as_u16(), url),
},
))
}
}
@@ -58,8 +70,9 @@ pub struct SafekeeperNode {
pub pg_connection_config: PgConnectionConfig,
pub env: LocalEnv,
pub http_client: mgmt_api::Client,
pub http_client: reqwest::Client,
pub listen_addr: String,
pub http_base_url: String,
}
impl SafekeeperNode {
@@ -69,14 +82,13 @@ impl SafekeeperNode {
} else {
"127.0.0.1".to_string()
};
let jwt = None;
let http_base_url = format!("http://{}:{}", listen_addr, conf.http_port);
SafekeeperNode {
id: conf.id,
conf: conf.clone(),
pg_connection_config: Self::safekeeper_connection_config(&listen_addr, conf.pg_port),
env: env.clone(),
http_client: mgmt_api::Client::new(env.create_http_client(), http_base_url, jwt),
http_client: env.create_http_client(),
http_base_url: format!("http://{}:{}/v1", listen_addr, conf.http_port),
listen_addr,
}
}
@@ -266,19 +278,20 @@ impl SafekeeperNode {
)
}
pub async fn check_status(&self) -> Result<()> {
self.http_client
.status()
.await
.map_err(err_from_client_err)?;
Ok(())
fn http_request<U: IntoUrl>(&self, method: Method, url: U) -> reqwest::RequestBuilder {
// TODO: authentication
//if self.env.auth_type == AuthType::NeonJWT {
// builder = builder.bearer_auth(&self.env.safekeeper_auth_token)
//}
self.http_client.request(method, url)
}
pub async fn create_timeline(&self, req: &TimelineCreateRequest) -> Result<()> {
self.http_client
.create_timeline(req)
.await
.map_err(err_from_client_err)?;
pub async fn check_status(&self) -> Result<()> {
self.http_request(Method::GET, format!("{}/{}", self.http_base_url, "status"))
.send()
.await?
.error_from_body()
.await?;
Ok(())
}
}

View File

@@ -61,16 +61,10 @@ enum Command {
#[arg(long)]
scheduling: Option<NodeSchedulingPolicy>,
},
// Set a node status as deleted.
NodeDelete {
#[arg(long)]
node_id: NodeId,
},
/// Delete a tombstone of node from the storage controller.
NodeDeleteTombstone {
#[arg(long)]
node_id: NodeId,
},
/// Modify a tenant's policies in the storage controller
TenantPolicy {
#[arg(long)]
@@ -88,8 +82,6 @@ enum Command {
},
/// List nodes known to the storage controller
Nodes {},
/// List soft deleted nodes known to the storage controller
NodeTombstones {},
/// List tenants known to the storage controller
Tenants {
/// If this field is set, it will list the tenants on a specific node
@@ -908,39 +900,6 @@ async fn main() -> anyhow::Result<()> {
.dispatch::<(), ()>(Method::DELETE, format!("control/v1/node/{node_id}"), None)
.await?;
}
Command::NodeDeleteTombstone { node_id } => {
storcon_client
.dispatch::<(), ()>(
Method::DELETE,
format!("debug/v1/tombstone/{node_id}"),
None,
)
.await?;
}
Command::NodeTombstones {} => {
let mut resp = storcon_client
.dispatch::<(), Vec<NodeDescribeResponse>>(
Method::GET,
"debug/v1/tombstone".to_string(),
None,
)
.await?;
resp.sort_by(|a, b| a.listen_http_addr.cmp(&b.listen_http_addr));
let mut table = comfy_table::Table::new();
table.set_header(["Id", "Hostname", "AZ", "Scheduling", "Availability"]);
for node in resp {
table.add_row([
format!("{}", node.id),
node.listen_http_addr,
node.availability_zone_id,
format!("{:?}", node.scheduling),
format!("{:?}", node.availability),
]);
}
println!("{table}");
}
Command::TenantSetTimeBasedEviction {
tenant_id,
period,

View File

@@ -8,7 +8,6 @@ anyhow.workspace = true
axum-extra.workspace = true
axum.workspace = true
camino.workspace = true
clap.workspace = true
futures.workspace = true
jsonwebtoken.workspace = true
prometheus.workspace = true

View File

@@ -4,8 +4,6 @@
//! for large computes.
mod app;
use anyhow::Context;
use clap::Parser;
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use tracing::info;
use utils::logging;
@@ -14,26 +12,9 @@ const fn max_upload_file_limit() -> usize {
100 * 1024 * 1024
}
const fn listen() -> SocketAddr {
SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 51243)
}
#[derive(Parser)]
struct Args {
#[arg(exclusive = true)]
config_file: Option<String>,
#[arg(long, default_value = "false", requires = "config")]
/// to allow testing k8s helm chart where we don't have s3 credentials
no_s3_check_on_startup: bool,
#[arg(long, value_name = "FILE")]
/// inline config mode for k8s helm chart
config: Option<String>,
}
#[derive(serde::Deserialize)]
#[serde(tag = "type")]
struct Config {
#[serde(default = "listen")]
listen: std::net::SocketAddr,
pemfile: camino::Utf8PathBuf,
#[serde(flatten)]
@@ -50,18 +31,13 @@ async fn main() -> anyhow::Result<()> {
logging::Output::Stdout,
)?;
let args = Args::parse();
let config: Config = if let Some(config_path) = args.config_file {
info!("Reading config from {config_path}");
let config = std::fs::read_to_string(config_path)?;
serde_json::from_str(&config).context("parsing config")?
} else if let Some(config) = args.config {
info!("Reading inline config");
serde_json::from_str(&config).context("parsing config")?
} else {
anyhow::bail!("Supply either config file path or --config=inline-config");
};
let config: String = std::env::args().skip(1).take(1).collect();
if config.is_empty() {
anyhow::bail!("Usage: endpoint_storage config.json")
}
info!("Reading config from {config}");
let config = std::fs::read_to_string(config.clone())?;
let config: Config = serde_json::from_str(&config).context("parsing config")?;
info!("Reading pemfile from {}", config.pemfile.clone());
let pemfile = std::fs::read(config.pemfile.clone())?;
info!("Loading public key from {}", config.pemfile.clone());
@@ -72,9 +48,7 @@ async fn main() -> anyhow::Result<()> {
let storage = remote_storage::GenericRemoteStorage::from_config(&config.storage_config).await?;
let cancel = tokio_util::sync::CancellationToken::new();
if !args.no_s3_check_on_startup {
app::check_storage_permissions(&storage, cancel.clone()).await?;
}
app::check_storage_permissions(&storage, cancel.clone()).await?;
let proxy = std::sync::Arc::new(endpoint_storage::Storage {
auth,

View File

@@ -9,11 +9,8 @@ anyhow.workspace = true
chrono.workspace = true
indexmap.workspace = true
jsonwebtoken.workspace = true
postgres.workspace = true
serde.workspace = true
serde_json.workspace = true
tokio-postgres.workspace = true
url.workspace = true
regex.workspace = true
utils = { path = "../utils" }

View File

@@ -9,8 +9,7 @@ use indexmap::IndexMap;
use regex::Regex;
use remote_storage::RemotePath;
use serde::{Deserialize, Serialize};
use url::Host;
use utils::id::{BranchId, EndpointId, ProjectId, TenantId, TimelineId};
use utils::id::{TenantId, TimelineId};
use utils::lsn::Lsn;
use crate::responses::TlsConfig;
@@ -22,77 +21,13 @@ pub type PgIdent = String;
/// String type alias representing Postgres extension version
pub type ExtVersion = String;
/// Pageserver settings.
#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
pub struct Pageserver {
/// Hostname of the pageserver.
pub host: Host,
/// Port that the safekeeper listens on.
pub port: u16,
}
impl From<&Pageserver> for postgres::Config {
fn from(ps: &Pageserver) -> Self {
let mut config = postgres::Config::new();
config.host(&ps.host.to_string());
config.port(ps.port);
config
}
}
impl From<&Pageserver> for tokio_postgres::Config {
fn from(ps: &Pageserver) -> Self {
let mut config = tokio_postgres::Config::new();
config.host(&ps.host.to_string());
config.port(ps.port);
config
}
}
/// Safekeeper settings.
#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
pub struct Safekeeper {
/// Hostname of the safekeeper.
pub host: Host,
/// Port that the safekeeper listens on.
pub port: u16,
}
impl From<&Safekeeper> for postgres::Config {
fn from(sk: &Safekeeper) -> Self {
let mut config = postgres::Config::new();
config.host(&sk.host.to_string());
config.port(sk.port);
config
}
}
impl From<&Safekeeper> for tokio_postgres::Config {
fn from(sk: &Safekeeper) -> Self {
let mut config = tokio_postgres::Config::new();
config.host(&sk.host.to_string());
config.port(sk.port);
config
}
}
fn default_reconfigure_concurrency() -> usize {
1
}
/// Cluster spec or configuration represented as an optional number of
/// delta operations + final cluster state description.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
pub struct ComputeSpec {
pub format_version: f32,
@@ -155,13 +90,25 @@ pub struct ComputeSpec {
// Information needed to connect to the storage layer.
//
// `tenant_id`, `timeline_id` and `pageserver_connstring` are always needed.
//
// Depending on `mode`, this can be a primary read-write node, a read-only
// replica, or a read-only node pinned at an older LSN.
// `safekeeper_connstrings` must be set for a primary.
pub pageservers: Vec<Pageserver>,
//
// For backwards compatibility, the control plane may leave out all of
// these, and instead set the "neon.tenant_id", "neon.timeline_id",
// etc. GUCs in cluster.settings. TODO: Once the control plane has been
// updated to fill these fields, we can make these non optional.
pub tenant_id: Option<TenantId>,
pub timeline_id: Option<TimelineId>,
pub pageserver_connstring: Option<String>,
#[serde(default)]
pub safekeepers_generation: Option<u32>,
// More neon ids that we expose to the compute_ctl
// and to postgres as neon extension GUCs.
pub project_id: Option<String>,
pub branch_id: Option<String>,
pub endpoint_id: Option<String>,
/// Safekeeper membership config generation. It is put in
/// neon.safekeepers GUC and serves two purposes:
@@ -173,18 +120,9 @@ pub struct ComputeSpec {
/// Note: it could be SafekeeperGeneration, but this needs linking
/// compute_ctl with postgres_ffi.
#[serde(default)]
pub safekeepers: Vec<Safekeeper>,
/// The Neon tenant ID. Exposed to Postgres as `neon.tenant_id`.
pub tenant_id: TenantId,
/// The Neon timeline ID. Exposed to Postgres as `neon.timeline_id`.
pub timeline_id: TimelineId,
/// The Neon project ID. Exposed to Postgres as `neon.project_id`.
pub project_id: ProjectId,
/// The Neon branch ID. Exposed to Postgres as `neon.branch_id`.
pub branch_id: BranchId,
/// The Neon endpoint ID. Exposed to Postgres as `neon.endpoint_id`.
pub endpoint_id: EndpointId,
pub safekeepers_generation: Option<u32>,
#[serde(default)]
pub safekeeper_connstrings: Vec<String>,
#[serde(default)]
pub mode: ComputeMode,

View File

@@ -6,8 +6,12 @@ license.workspace = true
[dependencies]
thiserror.workspace = true
nix.workspace=true
nix.workspace = true
workspace_hack = { version = "0.1", path = "../../workspace_hack" }
[dev-dependencies]
rand = "0.9.1"
rand_distr = "0.5.1"
[target.'cfg(target_os = "macos")'.dependencies]
tempfile = "3.14.0"

304
libs/neon-shmem/src/hash.rs Normal file
View File

@@ -0,0 +1,304 @@
//! Hash table implementation on top of 'shmem'
//!
//! Features required in the long run by the communicator project:
//!
//! [X] Accessible from both Postgres processes and rust threads in the communicator process
//! [X] Low latency
//! [ ] Scalable to lots of concurrent accesses (currently relies on caller for locking)
//! [ ] Resizable
use std::fmt::Debug;
use std::hash::{DefaultHasher, Hash, Hasher};
use std::mem::MaybeUninit;
use crate::shmem::ShmemHandle;
mod core;
pub mod entry;
#[cfg(test)]
mod tests;
use core::CoreHashMap;
use entry::{Entry, OccupiedEntry};
#[derive(Debug)]
pub struct OutOfMemoryError();
pub struct HashMapInit<'a, K, V> {
// Hash table can be allocated in a fixed memory area, or in a resizeable ShmemHandle.
shmem_handle: Option<ShmemHandle>,
shared_ptr: *mut HashMapShared<'a, K, V>,
}
pub struct HashMapAccess<'a, K, V> {
shmem_handle: Option<ShmemHandle>,
shared_ptr: *mut HashMapShared<'a, K, V>,
}
unsafe impl<'a, K: Sync, V: Sync> Sync for HashMapAccess<'a, K, V> {}
unsafe impl<'a, K: Send, V: Send> Send for HashMapAccess<'a, K, V> {}
impl<'a, K, V> HashMapInit<'a, K, V> {
pub fn attach_writer(self) -> HashMapAccess<'a, K, V> {
HashMapAccess {
shmem_handle: self.shmem_handle,
shared_ptr: self.shared_ptr,
}
}
pub fn attach_reader(self) -> HashMapAccess<'a, K, V> {
// no difference to attach_writer currently
self.attach_writer()
}
}
/// This is stored in the shared memory area
///
/// NOTE: We carve out the parts from a contiguous chunk. Growing and shrinking the hash table
/// relies on the memory layout! The data structures are laid out in the contiguous shared memory
/// area as follows:
///
/// HashMapShared
/// [buckets]
/// [dictionary]
///
/// In between the above parts, there can be padding bytes to align the parts correctly.
struct HashMapShared<'a, K, V> {
inner: CoreHashMap<'a, K, V>,
}
impl<'a, K, V> HashMapInit<'a, K, V>
where
K: Clone + Hash + Eq,
{
pub fn estimate_size(num_buckets: u32) -> usize {
// add some margin to cover alignment etc.
CoreHashMap::<K, V>::estimate_size(num_buckets) + size_of::<HashMapShared<K, V>>() + 1000
}
pub fn init_in_fixed_area(
num_buckets: u32,
area: &'a mut [MaybeUninit<u8>],
) -> HashMapInit<'a, K, V> {
Self::init_common(num_buckets, None, area.as_mut_ptr().cast(), area.len())
}
/// Initialize a new hash map in the given shared memory area
pub fn init_in_shmem(num_buckets: u32, mut shmem: ShmemHandle) -> HashMapInit<'a, K, V> {
let size = Self::estimate_size(num_buckets);
shmem
.set_size(size)
.expect("could not resize shared memory area");
let ptr = unsafe { shmem.data_ptr.as_mut() };
Self::init_common(num_buckets, Some(shmem), ptr, size)
}
fn init_common(
num_buckets: u32,
shmem_handle: Option<ShmemHandle>,
area_ptr: *mut u8,
area_len: usize,
) -> HashMapInit<'a, K, V> {
// carve out the HashMapShared struct from the area.
let mut ptr: *mut u8 = area_ptr;
let end_ptr: *mut u8 = unsafe { area_ptr.add(area_len) };
ptr = unsafe { ptr.add(ptr.align_offset(align_of::<HashMapShared<K, V>>())) };
let shared_ptr: *mut HashMapShared<K, V> = ptr.cast();
ptr = unsafe { ptr.add(size_of::<HashMapShared<K, V>>()) };
// carve out the buckets
ptr = unsafe { ptr.byte_add(ptr.align_offset(align_of::<core::Bucket<K, V>>())) };
let buckets_ptr = ptr;
ptr = unsafe { ptr.add(size_of::<core::Bucket<K, V>>() * num_buckets as usize) };
// use remaining space for the dictionary
ptr = unsafe { ptr.byte_add(ptr.align_offset(align_of::<u32>())) };
assert!(ptr.addr() < end_ptr.addr());
let dictionary_ptr = ptr;
let dictionary_size = unsafe { end_ptr.byte_offset_from(ptr) / size_of::<u32>() as isize };
assert!(dictionary_size > 0);
let buckets =
unsafe { std::slice::from_raw_parts_mut(buckets_ptr.cast(), num_buckets as usize) };
let dictionary = unsafe {
std::slice::from_raw_parts_mut(dictionary_ptr.cast(), dictionary_size as usize)
};
let hashmap = CoreHashMap::new(buckets, dictionary);
unsafe {
std::ptr::write(shared_ptr, HashMapShared { inner: hashmap });
}
HashMapInit {
shmem_handle: shmem_handle,
shared_ptr,
}
}
}
impl<'a, K, V> HashMapAccess<'a, K, V>
where
K: Clone + Hash + Eq,
{
pub fn get_hash_value(&self, key: &K) -> u64 {
let mut hasher = DefaultHasher::new();
key.hash(&mut hasher);
hasher.finish()
}
pub fn get_with_hash<'e>(&'e self, key: &K, hash: u64) -> Option<&'e V> {
let map = unsafe { self.shared_ptr.as_ref() }.unwrap();
map.inner.get_with_hash(key, hash)
}
pub fn entry_with_hash(&mut self, key: K, hash: u64) -> Entry<'a, '_, K, V> {
let map = unsafe { self.shared_ptr.as_mut() }.unwrap();
map.inner.entry_with_hash(key, hash)
}
pub fn remove_with_hash(&mut self, key: &K, hash: u64) {
let map = unsafe { self.shared_ptr.as_mut() }.unwrap();
match map.inner.entry_with_hash(key.clone(), hash) {
Entry::Occupied(e) => {
e.remove();
}
Entry::Vacant(_) => {}
};
}
pub fn entry_at_bucket(&mut self, pos: usize) -> Option<OccupiedEntry<'a, '_, K, V>> {
let map = unsafe { self.shared_ptr.as_mut() }.unwrap();
map.inner.entry_at_bucket(pos)
}
pub fn get_num_buckets(&self) -> usize {
let map = unsafe { self.shared_ptr.as_ref() }.unwrap();
map.inner.get_num_buckets()
}
/// Return the key and value stored in bucket with given index. This can be used to
/// iterate through the hash map. (An Iterator might be nicer. The communicator's
/// clock algorithm needs to _slowly_ iterate through all buckets with its clock hand,
/// without holding a lock. If we switch to an Iterator, it must not hold the lock.)
pub fn get_at_bucket(&self, pos: usize) -> Option<&(K, V)> {
let map = unsafe { self.shared_ptr.as_ref() }.unwrap();
if pos >= map.inner.buckets.len() {
return None;
}
let bucket = &map.inner.buckets[pos];
bucket.inner.as_ref()
}
pub fn get_bucket_for_value(&self, val_ptr: *const V) -> usize {
let map = unsafe { self.shared_ptr.as_ref() }.unwrap();
let origin = map.inner.buckets.as_ptr();
let idx = (val_ptr as usize - origin as usize) / (size_of::<V>() as usize);
assert!(idx < map.inner.buckets.len());
idx
}
// for metrics
pub fn get_num_buckets_in_use(&self) -> usize {
let map = unsafe { self.shared_ptr.as_ref() }.unwrap();
map.inner.buckets_in_use as usize
}
/// Grow
///
/// 1. grow the underlying shared memory area
/// 2. Initialize new buckets. This overwrites the current dictionary
/// 3. Recalculate the dictionary
pub fn grow(&mut self, num_buckets: u32) -> Result<(), crate::shmem::Error> {
let map = unsafe { self.shared_ptr.as_mut() }.unwrap();
let inner = &mut map.inner;
let old_num_buckets = inner.buckets.len() as u32;
if num_buckets < old_num_buckets {
panic!("grow called with a smaller number of buckets");
}
if num_buckets == old_num_buckets {
return Ok(());
}
let shmem_handle = self
.shmem_handle
.as_ref()
.expect("grow called on a fixed-size hash table");
let size_bytes = HashMapInit::<K, V>::estimate_size(num_buckets);
shmem_handle.set_size(size_bytes)?;
let end_ptr: *mut u8 = unsafe { shmem_handle.data_ptr.as_ptr().add(size_bytes) };
// Initialize new buckets. The new buckets are linked to the free list. NB: This overwrites
// the dictionary!
let buckets_ptr = inner.buckets.as_mut_ptr();
unsafe {
for i in old_num_buckets..num_buckets {
let bucket_ptr = buckets_ptr.add(i as usize);
bucket_ptr.write(core::Bucket {
next: if i < num_buckets {
i as u32 + 1
} else {
inner.free_head
},
inner: None,
});
}
}
// Recalculate the dictionary
let buckets;
let dictionary;
unsafe {
let buckets_end_ptr = buckets_ptr.add(num_buckets as usize);
let dictionary_ptr: *mut u32 = buckets_end_ptr
.byte_add(buckets_end_ptr.align_offset(align_of::<u32>()))
.cast();
let dictionary_size: usize =
end_ptr.byte_offset_from(buckets_end_ptr) as usize / size_of::<u32>();
buckets = std::slice::from_raw_parts_mut(buckets_ptr, num_buckets as usize);
dictionary = std::slice::from_raw_parts_mut(dictionary_ptr, dictionary_size);
}
for i in 0..dictionary.len() {
dictionary[i] = core::INVALID_POS;
}
for i in 0..old_num_buckets as usize {
if buckets[i].inner.is_none() {
continue;
}
let mut hasher = DefaultHasher::new();
buckets[i].inner.as_ref().unwrap().0.hash(&mut hasher);
let hash = hasher.finish();
let pos: usize = (hash % dictionary.len() as u64) as usize;
buckets[i].next = dictionary[pos];
dictionary[pos] = i as u32;
}
// Finally, update the CoreHashMap struct
inner.dictionary = dictionary;
inner.buckets = buckets;
inner.free_head = old_num_buckets;
Ok(())
}
// TODO: Shrinking is a multi-step process that requires co-operation from the caller
//
// 1. The caller must first call begin_shrink(). That forbids allocation of higher-numbered
// buckets.
//
// 2. Next, the caller must evict all entries in higher-numbered buckets.
//
// 3. Finally, call finish_shrink(). This recomputes the dictionary and shrinks the underlying
// shmem area
}

View File

@@ -0,0 +1,174 @@
//! Simple hash table with chaining
//!
//! # Resizing
//!
use std::hash::Hash;
use std::mem::MaybeUninit;
use crate::hash::entry::{Entry, OccupiedEntry, PrevPos, VacantEntry};
pub(crate) const INVALID_POS: u32 = u32::MAX;
// Bucket
pub(crate) struct Bucket<K, V> {
pub(crate) next: u32,
pub(crate) inner: Option<(K, V)>,
}
pub(crate) struct CoreHashMap<'a, K, V> {
pub(crate) dictionary: &'a mut [u32],
pub(crate) buckets: &'a mut [Bucket<K, V>],
pub(crate) free_head: u32,
pub(crate) _user_list_head: u32,
// metrics
pub(crate) buckets_in_use: u32,
}
#[derive(Debug)]
pub struct FullError();
impl<'a, K: Hash + Eq, V> CoreHashMap<'a, K, V>
where
K: Clone + Hash + Eq,
{
const FILL_FACTOR: f32 = 0.60;
pub fn estimate_size(num_buckets: u32) -> usize {
let mut size = 0;
// buckets
size += size_of::<Bucket<K, V>>() * num_buckets as usize;
// dictionary
size += (f32::ceil((size_of::<u32>() * num_buckets as usize) as f32 / Self::FILL_FACTOR))
as usize;
size
}
pub fn new(
buckets: &'a mut [MaybeUninit<Bucket<K, V>>],
dictionary: &'a mut [MaybeUninit<u32>],
) -> CoreHashMap<'a, K, V> {
// Initialize the buckets
for i in 0..buckets.len() {
buckets[i].write(Bucket {
next: if i < buckets.len() - 1 {
i as u32 + 1
} else {
INVALID_POS
},
inner: None,
});
}
// Initialize the dictionary
for i in 0..dictionary.len() {
dictionary[i].write(INVALID_POS);
}
// TODO: use std::slice::assume_init_mut() once it stabilizes
let buckets =
unsafe { std::slice::from_raw_parts_mut(buckets.as_mut_ptr().cast(), buckets.len()) };
let dictionary = unsafe {
std::slice::from_raw_parts_mut(dictionary.as_mut_ptr().cast(), dictionary.len())
};
CoreHashMap {
dictionary,
buckets,
free_head: 0,
buckets_in_use: 0,
_user_list_head: INVALID_POS,
}
}
pub fn get_with_hash(&self, key: &K, hash: u64) -> Option<&V> {
let mut next = self.dictionary[hash as usize % self.dictionary.len()];
loop {
if next == INVALID_POS {
return None;
}
let bucket = &self.buckets[next as usize];
let (bucket_key, bucket_value) = bucket.inner.as_ref().expect("entry is in use");
if bucket_key == key {
return Some(&bucket_value);
}
next = bucket.next;
}
}
// all updates are done through Entry
pub fn entry_with_hash(&mut self, key: K, hash: u64) -> Entry<'a, '_, K, V> {
let dict_pos = hash as usize % self.dictionary.len();
let first = self.dictionary[dict_pos];
if first == INVALID_POS {
// no existing entry
return Entry::Vacant(VacantEntry {
map: self,
key,
dict_pos: dict_pos as u32,
});
}
let mut prev_pos = PrevPos::First(dict_pos as u32);
let mut next = first;
loop {
let bucket = &mut self.buckets[next as usize];
let (bucket_key, _bucket_value) = bucket.inner.as_mut().expect("entry is in use");
if *bucket_key == key {
// found existing entry
return Entry::Occupied(OccupiedEntry {
map: self,
_key: key,
prev_pos,
bucket_pos: next,
});
}
if bucket.next == INVALID_POS {
// No existing entry
return Entry::Vacant(VacantEntry {
map: self,
key,
dict_pos: dict_pos as u32,
});
}
prev_pos = PrevPos::Chained(next);
next = bucket.next;
}
}
pub fn get_num_buckets(&self) -> usize {
self.buckets.len()
}
pub fn entry_at_bucket(&mut self, pos: usize) -> Option<OccupiedEntry<K, V>> {
if pos >= self.buckets.len() {
return None;
}
todo!()
//self.buckets[pos].inner.as_ref()
}
pub(crate) fn alloc_bucket(&mut self, key: K, value: V) -> Result<u32, FullError> {
let pos = self.free_head;
if pos == INVALID_POS {
return Err(FullError());
}
let bucket = &mut self.buckets[pos as usize];
self.free_head = bucket.next;
self.buckets_in_use += 1;
bucket.next = INVALID_POS;
bucket.inner = Some((key, value));
return Ok(pos);
}
}

View File

@@ -0,0 +1,91 @@
//! Like std::collections::hash_map::Entry;
use crate::hash::core::{CoreHashMap, FullError, INVALID_POS};
use std::hash::Hash;
use std::mem;
pub enum Entry<'a, 'b, K, V> {
Occupied(OccupiedEntry<'a, 'b, K, V>),
Vacant(VacantEntry<'a, 'b, K, V>),
}
pub(crate) enum PrevPos {
First(u32),
Chained(u32),
}
pub struct OccupiedEntry<'a, 'b, K, V> {
pub(crate) map: &'b mut CoreHashMap<'a, K, V>,
pub(crate) _key: K, // The key of the occupied entry
pub(crate) prev_pos: PrevPos,
pub(crate) bucket_pos: u32, // The position of the bucket in the CoreHashMap's buckets array
}
impl<'a, 'b, K, V> OccupiedEntry<'a, 'b, K, V> {
pub fn get(&self) -> &V {
&self.map.buckets[self.bucket_pos as usize]
.inner
.as_ref()
.unwrap()
.1
}
pub fn get_mut(&mut self) -> &mut V {
&mut self.map.buckets[self.bucket_pos as usize]
.inner
.as_mut()
.unwrap()
.1
}
pub fn insert(&mut self, value: V) -> V {
let bucket = &mut self.map.buckets[self.bucket_pos as usize];
// This assumes inner is Some, which it must be for an OccupiedEntry
let old_value = mem::replace(&mut bucket.inner.as_mut().unwrap().1, value);
old_value
}
pub fn remove(self) -> V {
// CoreHashMap::remove returns Option<(K, V)>. We know it's Some for an OccupiedEntry.
let bucket = &mut self.map.buckets[self.bucket_pos as usize];
// unlink it from the chain
match self.prev_pos {
PrevPos::First(dict_pos) => self.map.dictionary[dict_pos as usize] = bucket.next,
PrevPos::Chained(bucket_pos) => {
self.map.buckets[bucket_pos as usize].next = bucket.next
}
}
// and add it to the freelist
let bucket = &mut self.map.buckets[self.bucket_pos as usize];
let old_value = bucket.inner.take();
bucket.next = self.map.free_head;
self.map.free_head = self.bucket_pos;
self.map.buckets_in_use -= 1;
return old_value.unwrap().1;
}
}
pub struct VacantEntry<'a, 'b, K, V> {
pub(crate) map: &'b mut CoreHashMap<'a, K, V>,
pub(crate) key: K, // The key to insert
pub(crate) dict_pos: u32,
}
impl<'a, 'b, K: Clone + Hash + Eq, V> VacantEntry<'a, 'b, K, V> {
pub fn insert(self, value: V) -> Result<&'b mut V, FullError> {
let pos = self.map.alloc_bucket(self.key, value)?;
if pos == INVALID_POS {
return Err(FullError());
}
let bucket = &mut self.map.buckets[pos as usize];
bucket.next = self.map.dictionary[self.dict_pos as usize];
self.map.dictionary[self.dict_pos as usize] = pos;
let result = &mut self.map.buckets[pos as usize].inner.as_mut().unwrap().1;
return Ok(result);
}
}

View File

@@ -0,0 +1,220 @@
use std::collections::BTreeMap;
use std::collections::HashSet;
use std::fmt::{Debug, Formatter};
use std::sync::atomic::{AtomicUsize, Ordering};
use crate::hash::HashMapAccess;
use crate::hash::HashMapInit;
use crate::hash::UpdateAction;
use crate::shmem::ShmemHandle;
use rand::seq::SliceRandom;
use rand::{Rng, RngCore};
use rand_distr::Zipf;
const TEST_KEY_LEN: usize = 16;
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
struct TestKey([u8; TEST_KEY_LEN]);
impl From<&TestKey> for u128 {
fn from(val: &TestKey) -> u128 {
u128::from_be_bytes(val.0)
}
}
impl From<u128> for TestKey {
fn from(val: u128) -> TestKey {
TestKey(val.to_be_bytes())
}
}
impl<'a> From<&'a [u8]> for TestKey {
fn from(bytes: &'a [u8]) -> TestKey {
TestKey(bytes.try_into().unwrap())
}
}
fn test_inserts<K: Into<TestKey> + Copy>(keys: &[K]) {
const MAX_MEM_SIZE: usize = 10000000;
let shmem = ShmemHandle::new("test_inserts", 0, MAX_MEM_SIZE).unwrap();
let init_struct = HashMapInit::<TestKey, usize>::init_in_shmem(100000, shmem);
let w = init_struct.attach_writer();
for (idx, k) in keys.iter().enumerate() {
let res = w.insert(&(*k).into(), idx);
assert!(res.is_ok());
}
for (idx, k) in keys.iter().enumerate() {
let x = w.get(&(*k).into());
let value = x.as_deref().copied();
assert_eq!(value, Some(idx));
}
//eprintln!("stats: {:?}", tree_writer.get_statistics());
}
#[test]
fn dense() {
// This exercises splitting a node with prefix
let keys: &[u128] = &[0, 1, 2, 3, 256];
test_inserts(keys);
// Dense keys
let mut keys: Vec<u128> = (0..10000).collect();
test_inserts(&keys);
// Do the same in random orders
for _ in 1..10 {
keys.shuffle(&mut rand::rng());
test_inserts(&keys);
}
}
#[test]
fn sparse() {
// sparse keys
let mut keys: Vec<TestKey> = Vec::new();
let mut used_keys = HashSet::new();
for _ in 0..10000 {
loop {
let key = rand::random::<u128>();
if used_keys.get(&key).is_some() {
continue;
}
used_keys.insert(key);
keys.push(key.into());
break;
}
}
test_inserts(&keys);
}
struct TestValue(AtomicUsize);
impl TestValue {
fn new(val: usize) -> TestValue {
TestValue(AtomicUsize::new(val))
}
fn load(&self) -> usize {
self.0.load(Ordering::Relaxed)
}
}
impl Clone for TestValue {
fn clone(&self) -> TestValue {
TestValue::new(self.load())
}
}
impl Debug for TestValue {
fn fmt(&self, fmt: &mut Formatter<'_>) -> Result<(), std::fmt::Error> {
write!(fmt, "{:?}", self.load())
}
}
#[derive(Clone, Debug)]
struct TestOp(TestKey, Option<usize>);
fn apply_op(
op: &TestOp,
sut: &HashMapAccess<TestKey, TestValue>,
shadow: &mut BTreeMap<TestKey, usize>,
) {
eprintln!("applying op: {op:?}");
// apply the change to the shadow tree first
let shadow_existing = if let Some(v) = op.1 {
shadow.insert(op.0, v)
} else {
shadow.remove(&op.0)
};
// apply to Art tree
sut.update_with_fn(&op.0, |existing| {
assert_eq!(existing.map(TestValue::load), shadow_existing);
match (existing, op.1) {
(None, None) => UpdateAction::Nothing,
(None, Some(new_val)) => UpdateAction::Insert(TestValue::new(new_val)),
(Some(_old_val), None) => UpdateAction::Remove,
(Some(old_val), Some(new_val)) => {
old_val.0.store(new_val, Ordering::Relaxed);
UpdateAction::Nothing
}
}
})
.expect("out of memory");
}
#[test]
fn random_ops() {
const MAX_MEM_SIZE: usize = 10000000;
let shmem = ShmemHandle::new("test_inserts", 0, MAX_MEM_SIZE).unwrap();
let init_struct = HashMapInit::<TestKey, TestValue>::init_in_shmem(100000, shmem);
let writer = init_struct.attach_writer();
let mut shadow: std::collections::BTreeMap<TestKey, usize> = BTreeMap::new();
let distribution = Zipf::new(u128::MAX as f64, 1.1).unwrap();
let mut rng = rand::rng();
for i in 0..100000 {
let key: TestKey = (rng.sample(distribution) as u128).into();
let op = TestOp(key, if rng.random_bool(0.75) { Some(i) } else { None });
apply_op(&op, &writer, &mut shadow);
if i % 1000 == 0 {
eprintln!("{i} ops processed");
//eprintln!("stats: {:?}", tree_writer.get_statistics());
//test_iter(&tree_writer, &shadow);
}
}
}
#[test]
fn test_grow() {
const MEM_SIZE: usize = 10000000;
let shmem = ShmemHandle::new("test_grow", 0, MEM_SIZE).unwrap();
let init_struct = HashMapInit::<TestKey, TestValue>::init_in_shmem(1000, shmem);
let writer = init_struct.attach_writer();
let mut shadow: std::collections::BTreeMap<TestKey, usize> = BTreeMap::new();
let mut rng = rand::rng();
for i in 0..10000 {
let key: TestKey = ((rng.next_u32() % 1000) as u128).into();
let op = TestOp(key, if rng.random_bool(0.75) { Some(i) } else { None });
apply_op(&op, &writer, &mut shadow);
if i % 1000 == 0 {
eprintln!("{i} ops processed");
//eprintln!("stats: {:?}", tree_writer.get_statistics());
//test_iter(&tree_writer, &shadow);
}
}
writer.grow(1500).unwrap();
for i in 0..10000 {
let key: TestKey = ((rng.next_u32() % 1500) as u128).into();
let op = TestOp(key, if rng.random_bool(0.75) { Some(i) } else { None });
apply_op(&op, &writer, &mut shadow);
if i % 1000 == 0 {
eprintln!("{i} ops processed");
//eprintln!("stats: {:?}", tree_writer.get_statistics());
//test_iter(&tree_writer, &shadow);
}
}
}

View File

@@ -1,418 +1,4 @@
//! Shared memory utilities for neon communicator
use std::num::NonZeroUsize;
use std::os::fd::{AsFd, BorrowedFd, OwnedFd};
use std::ptr::NonNull;
use std::sync::atomic::{AtomicUsize, Ordering};
use nix::errno::Errno;
use nix::sys::mman::MapFlags;
use nix::sys::mman::ProtFlags;
use nix::sys::mman::mmap as nix_mmap;
use nix::sys::mman::munmap as nix_munmap;
use nix::unistd::ftruncate as nix_ftruncate;
/// ShmemHandle represents a shared memory area that can be shared by processes over fork().
/// Unlike shared memory allocated by Postgres, this area is resizable, up to 'max_size' that's
/// specified at creation.
///
/// The area is backed by an anonymous file created with memfd_create(). The full address space for
/// 'max_size' is reserved up-front with mmap(), but whenever you call [`ShmemHandle::set_size`],
/// the underlying file is resized. Do not access the area beyond the current size. Currently, that
/// will cause the file to be expanded, but we might use mprotect() etc. to enforce that in the
/// future.
pub struct ShmemHandle {
/// memfd file descriptor
fd: OwnedFd,
max_size: usize,
// Pointer to the beginning of the shared memory area. The header is stored there.
shared_ptr: NonNull<SharedStruct>,
// Pointer to the beginning of the user data
pub data_ptr: NonNull<u8>,
}
/// This is stored at the beginning in the shared memory area.
struct SharedStruct {
max_size: usize,
/// Current size of the backing file. The high-order bit is used for the RESIZE_IN_PROGRESS flag
current_size: AtomicUsize,
}
const RESIZE_IN_PROGRESS: usize = 1 << 63;
const HEADER_SIZE: usize = std::mem::size_of::<SharedStruct>();
/// Error type returned by the ShmemHandle functions.
#[derive(thiserror::Error, Debug)]
#[error("{msg}: {errno}")]
pub struct Error {
pub msg: String,
pub errno: Errno,
}
impl Error {
fn new(msg: &str, errno: Errno) -> Error {
Error {
msg: msg.to_string(),
errno,
}
}
}
impl ShmemHandle {
/// Create a new shared memory area. To communicate between processes, the processes need to be
/// fork()'d after calling this, so that the ShmemHandle is inherited by all processes.
///
/// If the ShmemHandle is dropped, the memory is unmapped from the current process. Other
/// processes can continue using it, however.
pub fn new(name: &str, initial_size: usize, max_size: usize) -> Result<ShmemHandle, Error> {
// create the backing anonymous file.
let fd = create_backing_file(name)?;
Self::new_with_fd(fd, initial_size, max_size)
}
fn new_with_fd(
fd: OwnedFd,
initial_size: usize,
max_size: usize,
) -> Result<ShmemHandle, Error> {
// We reserve the high-order bit for the RESIZE_IN_PROGRESS flag, and the actual size
// is a little larger than this because of the SharedStruct header. Make the upper limit
// somewhat smaller than that, because with anything close to that, you'll run out of
// memory anyway.
if max_size >= 1 << 48 {
panic!("max size {} too large", max_size);
}
if initial_size > max_size {
panic!("initial size {initial_size} larger than max size {max_size}");
}
// The actual initial / max size is the one given by the caller, plus the size of
// 'SharedStruct'.
let initial_size = HEADER_SIZE + initial_size;
let max_size = NonZeroUsize::new(HEADER_SIZE + max_size).unwrap();
// Reserve address space for it with mmap
//
// TODO: Use MAP_HUGETLB if possible
let start_ptr = unsafe {
nix_mmap(
None,
max_size,
ProtFlags::PROT_READ | ProtFlags::PROT_WRITE,
MapFlags::MAP_SHARED,
&fd,
0,
)
}
.map_err(|e| Error::new("mmap failed: {e}", e))?;
// Reserve space for the initial size
enlarge_file(fd.as_fd(), initial_size as u64)?;
// Initialize the header
let shared: NonNull<SharedStruct> = start_ptr.cast();
unsafe {
shared.write(SharedStruct {
max_size: max_size.into(),
current_size: AtomicUsize::new(initial_size),
})
};
// The user data begins after the header
let data_ptr = unsafe { start_ptr.cast().add(HEADER_SIZE) };
Ok(ShmemHandle {
fd,
max_size: max_size.into(),
shared_ptr: shared,
data_ptr,
})
}
// return reference to the header
fn shared(&self) -> &SharedStruct {
unsafe { self.shared_ptr.as_ref() }
}
/// Resize the shared memory area. 'new_size' must not be larger than the 'max_size' specified
/// when creating the area.
///
/// This may only be called from one process/thread concurrently. We detect that case
/// and return an Error.
pub fn set_size(&self, new_size: usize) -> Result<(), Error> {
let new_size = new_size + HEADER_SIZE;
let shared = self.shared();
if new_size > self.max_size {
panic!(
"new size ({} is greater than max size ({})",
new_size, self.max_size
);
}
assert_eq!(self.max_size, shared.max_size);
// Lock the area by setting the bit in 'current_size'
//
// Ordering::Relaxed would probably be sufficient here, as we don't access any other memory
// and the posix_fallocate/ftruncate call is surely a synchronization point anyway. But
// since this is not performance-critical, better safe than sorry .
let mut old_size = shared.current_size.load(Ordering::Acquire);
loop {
if (old_size & RESIZE_IN_PROGRESS) != 0 {
return Err(Error::new(
"concurrent resize detected",
Errno::UnknownErrno,
));
}
match shared.current_size.compare_exchange(
old_size,
new_size,
Ordering::Acquire,
Ordering::Relaxed,
) {
Ok(_) => break,
Err(x) => old_size = x,
}
}
// Ok, we got the lock.
//
// NB: If anything goes wrong, we *must* clear the bit!
let result = {
use std::cmp::Ordering::{Equal, Greater, Less};
match new_size.cmp(&old_size) {
Less => nix_ftruncate(&self.fd, new_size as i64).map_err(|e| {
Error::new("could not shrink shmem segment, ftruncate failed: {e}", e)
}),
Equal => Ok(()),
Greater => enlarge_file(self.fd.as_fd(), new_size as u64),
}
};
// Unlock
shared.current_size.store(
if result.is_ok() { new_size } else { old_size },
Ordering::Release,
);
result
}
/// Returns the current user-visible size of the shared memory segment.
///
/// NOTE: a concurrent set_size() call can change the size at any time. It is the caller's
/// responsibility not to access the area beyond the current size.
pub fn current_size(&self) -> usize {
let total_current_size =
self.shared().current_size.load(Ordering::Relaxed) & !RESIZE_IN_PROGRESS;
total_current_size - HEADER_SIZE
}
}
impl Drop for ShmemHandle {
fn drop(&mut self) {
// SAFETY: The pointer was obtained from mmap() with the given size.
// We unmap the entire region.
let _ = unsafe { nix_munmap(self.shared_ptr.cast(), self.max_size) };
// The fd is dropped automatically by OwnedFd.
}
}
/// Create a "backing file" for the shared memory area. On Linux, use memfd_create(), to create an
/// anonymous in-memory file. One macos, fall back to a regular file. That's good enough for
/// development and testing, but in production we want the file to stay in memory.
///
/// disable 'unused_variables' warnings, because in the macos path, 'name' is unused.
#[allow(unused_variables)]
fn create_backing_file(name: &str) -> Result<OwnedFd, Error> {
#[cfg(not(target_os = "macos"))]
{
nix::sys::memfd::memfd_create(name, nix::sys::memfd::MFdFlags::empty())
.map_err(|e| Error::new("memfd_create failed: {e}", e))
}
#[cfg(target_os = "macos")]
{
let file = tempfile::tempfile().map_err(|e| {
Error::new(
"could not create temporary file to back shmem area: {e}",
nix::errno::Errno::from_raw(e.raw_os_error().unwrap_or(0)),
)
})?;
Ok(OwnedFd::from(file))
}
}
fn enlarge_file(fd: BorrowedFd, size: u64) -> Result<(), Error> {
// Use posix_fallocate() to enlarge the file. It reserves the space correctly, so that
// we don't get a segfault later when trying to actually use it.
#[cfg(not(target_os = "macos"))]
{
nix::fcntl::posix_fallocate(fd, 0, size as i64).map_err(|e| {
Error::new(
"could not grow shmem segment, posix_fallocate failed: {e}",
e,
)
})
}
// As a fallback on macos, which doesn't have posix_fallocate, use plain 'fallocate'
#[cfg(target_os = "macos")]
{
nix::unistd::ftruncate(fd, size as i64)
.map_err(|e| Error::new("could not grow shmem segment, ftruncate failed: {e}", e))
}
}
#[cfg(test)]
mod tests {
use super::*;
use nix::unistd::ForkResult;
use std::ops::Range;
/// check that all bytes in given range have the expected value.
fn assert_range(ptr: *const u8, expected: u8, range: Range<usize>) {
for i in range {
let b = unsafe { *(ptr.add(i)) };
assert_eq!(expected, b, "unexpected byte at offset {}", i);
}
}
/// Write 'b' to all bytes in the given range
fn write_range(ptr: *mut u8, b: u8, range: Range<usize>) {
unsafe { std::ptr::write_bytes(ptr.add(range.start), b, range.end - range.start) };
}
// simple single-process test of growing and shrinking
#[test]
fn test_shmem_resize() -> Result<(), Error> {
let max_size = 1024 * 1024;
let init_struct = ShmemHandle::new("test_shmem_resize", 0, max_size)?;
assert_eq!(init_struct.current_size(), 0);
// Initial grow
let size1 = 10000;
init_struct.set_size(size1).unwrap();
assert_eq!(init_struct.current_size(), size1);
// Write some data
let data_ptr = init_struct.data_ptr.as_ptr();
write_range(data_ptr, 0xAA, 0..size1);
assert_range(data_ptr, 0xAA, 0..size1);
// Shrink
let size2 = 5000;
init_struct.set_size(size2).unwrap();
assert_eq!(init_struct.current_size(), size2);
// Grow again
let size3 = 20000;
init_struct.set_size(size3).unwrap();
assert_eq!(init_struct.current_size(), size3);
// Try to read it. The area that was shrunk and grown again should read as all zeros now
assert_range(data_ptr, 0xAA, 0..5000);
assert_range(data_ptr, 0, 5000..size1);
// Try to grow beyond max_size
//let size4 = max_size + 1;
//assert!(init_struct.set_size(size4).is_err());
// Dropping init_struct should unmap the memory
drop(init_struct);
Ok(())
}
/// This is used in tests to coordinate between test processes. It's like std::sync::Barrier,
/// but is stored in the shared memory area and works across processes. It's implemented by
/// polling, because e.g. standard rust mutexes are not guaranteed to work across processes.
struct SimpleBarrier {
num_procs: usize,
count: AtomicUsize,
}
impl SimpleBarrier {
unsafe fn init(ptr: *mut SimpleBarrier, num_procs: usize) {
unsafe {
*ptr = SimpleBarrier {
num_procs,
count: AtomicUsize::new(0),
}
}
}
pub fn wait(&self) {
let old = self.count.fetch_add(1, Ordering::Relaxed);
let generation = old / self.num_procs;
let mut current = old + 1;
while current < (generation + 1) * self.num_procs {
std::thread::sleep(std::time::Duration::from_millis(10));
current = self.count.load(Ordering::Relaxed);
}
}
}
#[test]
fn test_multi_process() {
// Initialize
let max_size = 1_000_000_000_000;
let init_struct = ShmemHandle::new("test_multi_process", 0, max_size).unwrap();
let ptr = init_struct.data_ptr.as_ptr();
// Store the SimpleBarrier in the first 1k of the area.
init_struct.set_size(10000).unwrap();
let barrier_ptr: *mut SimpleBarrier = unsafe {
ptr.add(ptr.align_offset(std::mem::align_of::<SimpleBarrier>()))
.cast()
};
unsafe { SimpleBarrier::init(barrier_ptr, 2) };
let barrier = unsafe { barrier_ptr.as_ref().unwrap() };
// Fork another test process. The code after this runs in both processes concurrently.
let fork_result = unsafe { nix::unistd::fork().unwrap() };
// In the parent, fill bytes between 1000..2000. In the child, between 2000..3000
if fork_result.is_parent() {
write_range(ptr, 0xAA, 1000..2000);
} else {
write_range(ptr, 0xBB, 2000..3000);
}
barrier.wait();
// Verify the contents. (in both processes)
assert_range(ptr, 0xAA, 1000..2000);
assert_range(ptr, 0xBB, 2000..3000);
// Grow, from the child this time
let size = 10_000_000;
if !fork_result.is_parent() {
init_struct.set_size(size).unwrap();
}
barrier.wait();
// make some writes at the end
if fork_result.is_parent() {
write_range(ptr, 0xAA, (size - 10)..size);
} else {
write_range(ptr, 0xBB, (size - 20)..(size - 10));
}
barrier.wait();
// Verify the contents. (This runs in both processes)
assert_range(ptr, 0, (size - 1000)..(size - 20));
assert_range(ptr, 0xBB, (size - 20)..(size - 10));
assert_range(ptr, 0xAA, (size - 10)..size);
if let ForkResult::Parent { child } = fork_result {
nix::sys::wait::waitpid(child, None).unwrap();
}
}
}
pub mod hash;
pub mod shmem;

View File

@@ -0,0 +1,418 @@
//! Dynamically resizable contiguous chunk of shared memory
use std::num::NonZeroUsize;
use std::os::fd::{AsFd, BorrowedFd, OwnedFd};
use std::ptr::NonNull;
use std::sync::atomic::{AtomicUsize, Ordering};
use nix::errno::Errno;
use nix::sys::mman::MapFlags;
use nix::sys::mman::ProtFlags;
use nix::sys::mman::mmap as nix_mmap;
use nix::sys::mman::munmap as nix_munmap;
use nix::unistd::ftruncate as nix_ftruncate;
/// ShmemHandle represents a shared memory area that can be shared by processes over fork().
/// Unlike shared memory allocated by Postgres, this area is resizable, up to 'max_size' that's
/// specified at creation.
///
/// The area is backed by an anonymous file created with memfd_create(). The full address space for
/// 'max_size' is reserved up-front with mmap(), but whenever you call [`ShmemHandle::set_size`],
/// the underlying file is resized. Do not access the area beyond the current size. Currently, that
/// will cause the file to be expanded, but we might use mprotect() etc. to enforce that in the
/// future.
pub struct ShmemHandle {
/// memfd file descriptor
fd: OwnedFd,
max_size: usize,
// Pointer to the beginning of the shared memory area. The header is stored there.
shared_ptr: NonNull<SharedStruct>,
// Pointer to the beginning of the user data
pub data_ptr: NonNull<u8>,
}
/// This is stored at the beginning in the shared memory area.
struct SharedStruct {
max_size: usize,
/// Current size of the backing file. The high-order bit is used for the RESIZE_IN_PROGRESS flag
current_size: AtomicUsize,
}
const RESIZE_IN_PROGRESS: usize = 1 << 63;
const HEADER_SIZE: usize = std::mem::size_of::<SharedStruct>();
/// Error type returned by the ShmemHandle functions.
#[derive(thiserror::Error, Debug)]
#[error("{msg}: {errno}")]
pub struct Error {
pub msg: String,
pub errno: Errno,
}
impl Error {
fn new(msg: &str, errno: Errno) -> Error {
Error {
msg: msg.to_string(),
errno,
}
}
}
impl ShmemHandle {
/// Create a new shared memory area. To communicate between processes, the processes need to be
/// fork()'d after calling this, so that the ShmemHandle is inherited by all processes.
///
/// If the ShmemHandle is dropped, the memory is unmapped from the current process. Other
/// processes can continue using it, however.
pub fn new(name: &str, initial_size: usize, max_size: usize) -> Result<ShmemHandle, Error> {
// create the backing anonymous file.
let fd = create_backing_file(name)?;
Self::new_with_fd(fd, initial_size, max_size)
}
fn new_with_fd(
fd: OwnedFd,
initial_size: usize,
max_size: usize,
) -> Result<ShmemHandle, Error> {
// We reserve the high-order bit for the RESIZE_IN_PROGRESS flag, and the actual size
// is a little larger than this because of the SharedStruct header. Make the upper limit
// somewhat smaller than that, because with anything close to that, you'll run out of
// memory anyway.
if max_size >= 1 << 48 {
panic!("max size {} too large", max_size);
}
if initial_size > max_size {
panic!("initial size {initial_size} larger than max size {max_size}");
}
// The actual initial / max size is the one given by the caller, plus the size of
// 'SharedStruct'.
let initial_size = HEADER_SIZE + initial_size;
let max_size = NonZeroUsize::new(HEADER_SIZE + max_size).unwrap();
// Reserve address space for it with mmap
//
// TODO: Use MAP_HUGETLB if possible
let start_ptr = unsafe {
nix_mmap(
None,
max_size,
ProtFlags::PROT_READ | ProtFlags::PROT_WRITE,
MapFlags::MAP_SHARED,
&fd,
0,
)
}
.map_err(|e| Error::new("mmap failed: {e}", e))?;
// Reserve space for the initial size
enlarge_file(fd.as_fd(), initial_size as u64)?;
// Initialize the header
let shared: NonNull<SharedStruct> = start_ptr.cast();
unsafe {
shared.write(SharedStruct {
max_size: max_size.into(),
current_size: AtomicUsize::new(initial_size),
})
};
// The user data begins after the header
let data_ptr = unsafe { start_ptr.cast().add(HEADER_SIZE) };
Ok(ShmemHandle {
fd,
max_size: max_size.into(),
shared_ptr: shared,
data_ptr,
})
}
// return reference to the header
fn shared(&self) -> &SharedStruct {
unsafe { self.shared_ptr.as_ref() }
}
/// Resize the shared memory area. 'new_size' must not be larger than the 'max_size' specified
/// when creating the area.
///
/// This may only be called from one process/thread concurrently. We detect that case
/// and return an Error.
pub fn set_size(&self, new_size: usize) -> Result<(), Error> {
let new_size = new_size + HEADER_SIZE;
let shared = self.shared();
if new_size > self.max_size {
panic!(
"new size ({} is greater than max size ({})",
new_size, self.max_size
);
}
assert_eq!(self.max_size, shared.max_size);
// Lock the area by setting the bit in 'current_size'
//
// Ordering::Relaxed would probably be sufficient here, as we don't access any other memory
// and the posix_fallocate/ftruncate call is surely a synchronization point anyway. But
// since this is not performance-critical, better safe than sorry .
let mut old_size = shared.current_size.load(Ordering::Acquire);
loop {
if (old_size & RESIZE_IN_PROGRESS) != 0 {
return Err(Error::new(
"concurrent resize detected",
Errno::UnknownErrno,
));
}
match shared.current_size.compare_exchange(
old_size,
new_size,
Ordering::Acquire,
Ordering::Relaxed,
) {
Ok(_) => break,
Err(x) => old_size = x,
}
}
// Ok, we got the lock.
//
// NB: If anything goes wrong, we *must* clear the bit!
let result = {
use std::cmp::Ordering::{Equal, Greater, Less};
match new_size.cmp(&old_size) {
Less => nix_ftruncate(&self.fd, new_size as i64).map_err(|e| {
Error::new("could not shrink shmem segment, ftruncate failed: {e}", e)
}),
Equal => Ok(()),
Greater => enlarge_file(self.fd.as_fd(), new_size as u64),
}
};
// Unlock
shared.current_size.store(
if result.is_ok() { new_size } else { old_size },
Ordering::Release,
);
result
}
/// Returns the current user-visible size of the shared memory segment.
///
/// NOTE: a concurrent set_size() call can change the size at any time. It is the caller's
/// responsibility not to access the area beyond the current size.
pub fn current_size(&self) -> usize {
let total_current_size =
self.shared().current_size.load(Ordering::Relaxed) & !RESIZE_IN_PROGRESS;
total_current_size - HEADER_SIZE
}
}
impl Drop for ShmemHandle {
fn drop(&mut self) {
// SAFETY: The pointer was obtained from mmap() with the given size.
// We unmap the entire region.
let _ = unsafe { nix_munmap(self.shared_ptr.cast(), self.max_size) };
// The fd is dropped automatically by OwnedFd.
}
}
/// Create a "backing file" for the shared memory area. On Linux, use memfd_create(), to create an
/// anonymous in-memory file. One macos, fall back to a regular file. That's good enough for
/// development and testing, but in production we want the file to stay in memory.
///
/// disable 'unused_variables' warnings, because in the macos path, 'name' is unused.
#[allow(unused_variables)]
fn create_backing_file(name: &str) -> Result<OwnedFd, Error> {
#[cfg(not(target_os = "macos"))]
{
nix::sys::memfd::memfd_create(name, nix::sys::memfd::MFdFlags::empty())
.map_err(|e| Error::new("memfd_create failed: {e}", e))
}
#[cfg(target_os = "macos")]
{
let file = tempfile::tempfile().map_err(|e| {
Error::new(
"could not create temporary file to back shmem area: {e}",
nix::errno::Errno::from_raw(e.raw_os_error().unwrap_or(0)),
)
})?;
Ok(OwnedFd::from(file))
}
}
fn enlarge_file(fd: BorrowedFd, size: u64) -> Result<(), Error> {
// Use posix_fallocate() to enlarge the file. It reserves the space correctly, so that
// we don't get a segfault later when trying to actually use it.
#[cfg(not(target_os = "macos"))]
{
nix::fcntl::posix_fallocate(fd, 0, size as i64).map_err(|e| {
Error::new(
"could not grow shmem segment, posix_fallocate failed: {e}",
e,
)
})
}
// As a fallback on macos, which doesn't have posix_fallocate, use plain 'fallocate'
#[cfg(target_os = "macos")]
{
nix::unistd::ftruncate(fd, size as i64)
.map_err(|e| Error::new("could not grow shmem segment, ftruncate failed: {e}", e))
}
}
#[cfg(test)]
mod tests {
use super::*;
use nix::unistd::ForkResult;
use std::ops::Range;
/// check that all bytes in given range have the expected value.
fn assert_range(ptr: *const u8, expected: u8, range: Range<usize>) {
for i in range {
let b = unsafe { *(ptr.add(i)) };
assert_eq!(expected, b, "unexpected byte at offset {}", i);
}
}
/// Write 'b' to all bytes in the given range
fn write_range(ptr: *mut u8, b: u8, range: Range<usize>) {
unsafe { std::ptr::write_bytes(ptr.add(range.start), b, range.end - range.start) };
}
// simple single-process test of growing and shrinking
#[test]
fn test_shmem_resize() -> Result<(), Error> {
let max_size = 1024 * 1024;
let init_struct = ShmemHandle::new("test_shmem_resize", 0, max_size)?;
assert_eq!(init_struct.current_size(), 0);
// Initial grow
let size1 = 10000;
init_struct.set_size(size1).unwrap();
assert_eq!(init_struct.current_size(), size1);
// Write some data
let data_ptr = init_struct.data_ptr.as_ptr();
write_range(data_ptr, 0xAA, 0..size1);
assert_range(data_ptr, 0xAA, 0..size1);
// Shrink
let size2 = 5000;
init_struct.set_size(size2).unwrap();
assert_eq!(init_struct.current_size(), size2);
// Grow again
let size3 = 20000;
init_struct.set_size(size3).unwrap();
assert_eq!(init_struct.current_size(), size3);
// Try to read it. The area that was shrunk and grown again should read as all zeros now
assert_range(data_ptr, 0xAA, 0..5000);
assert_range(data_ptr, 0, 5000..size1);
// Try to grow beyond max_size
//let size4 = max_size + 1;
//assert!(init_struct.set_size(size4).is_err());
// Dropping init_struct should unmap the memory
drop(init_struct);
Ok(())
}
/// This is used in tests to coordinate between test processes. It's like std::sync::Barrier,
/// but is stored in the shared memory area and works across processes. It's implemented by
/// polling, because e.g. standard rust mutexes are not guaranteed to work across processes.
struct SimpleBarrier {
num_procs: usize,
count: AtomicUsize,
}
impl SimpleBarrier {
unsafe fn init(ptr: *mut SimpleBarrier, num_procs: usize) {
unsafe {
*ptr = SimpleBarrier {
num_procs,
count: AtomicUsize::new(0),
}
}
}
pub fn wait(&self) {
let old = self.count.fetch_add(1, Ordering::Relaxed);
let generation = old / self.num_procs;
let mut current = old + 1;
while current < (generation + 1) * self.num_procs {
std::thread::sleep(std::time::Duration::from_millis(10));
current = self.count.load(Ordering::Relaxed);
}
}
}
#[test]
fn test_multi_process() {
// Initialize
let max_size = 1_000_000_000_000;
let init_struct = ShmemHandle::new("test_multi_process", 0, max_size).unwrap();
let ptr = init_struct.data_ptr.as_ptr();
// Store the SimpleBarrier in the first 1k of the area.
init_struct.set_size(10000).unwrap();
let barrier_ptr: *mut SimpleBarrier = unsafe {
ptr.add(ptr.align_offset(std::mem::align_of::<SimpleBarrier>()))
.cast()
};
unsafe { SimpleBarrier::init(barrier_ptr, 2) };
let barrier = unsafe { barrier_ptr.as_ref().unwrap() };
// Fork another test process. The code after this runs in both processes concurrently.
let fork_result = unsafe { nix::unistd::fork().unwrap() };
// In the parent, fill bytes between 1000..2000. In the child, between 2000..3000
if fork_result.is_parent() {
write_range(ptr, 0xAA, 1000..2000);
} else {
write_range(ptr, 0xBB, 2000..3000);
}
barrier.wait();
// Verify the contents. (in both processes)
assert_range(ptr, 0xAA, 1000..2000);
assert_range(ptr, 0xBB, 2000..3000);
// Grow, from the child this time
let size = 10_000_000;
if !fork_result.is_parent() {
init_struct.set_size(size).unwrap();
}
barrier.wait();
// make some writes at the end
if fork_result.is_parent() {
write_range(ptr, 0xAA, (size - 10)..size);
} else {
write_range(ptr, 0xBB, (size - 20)..(size - 10));
}
barrier.wait();
// Verify the contents. (This runs in both processes)
assert_range(ptr, 0, (size - 1000)..(size - 20));
assert_range(ptr, 0xBB, (size - 20)..(size - 10));
assert_range(ptr, 0xAA, (size - 10)..size);
if let ForkResult::Parent { child } = fork_result {
nix::sys::wait::waitpid(child, None).unwrap();
}
}
}

View File

@@ -344,35 +344,6 @@ impl Default for ShardSchedulingPolicy {
}
}
#[derive(Serialize, Deserialize, Clone, Copy, Eq, PartialEq, Debug)]
pub enum NodeLifecycle {
Active,
Deleted,
}
impl FromStr for NodeLifecycle {
type Err = anyhow::Error;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"active" => Ok(Self::Active),
"deleted" => Ok(Self::Deleted),
_ => Err(anyhow::anyhow!("Unknown node lifecycle '{s}'")),
}
}
}
impl From<NodeLifecycle> for String {
fn from(value: NodeLifecycle) -> String {
use NodeLifecycle::*;
match value {
Active => "active",
Deleted => "deleted",
}
.to_string()
}
}
#[derive(Serialize, Deserialize, Clone, Copy, Eq, PartialEq, Debug)]
pub enum NodeSchedulingPolicy {
Active,

View File

@@ -10,7 +10,7 @@ use crate::{Error, cancel_query_raw, connect_socket};
pub(crate) async fn cancel_query<T>(
config: Option<SocketConfig>,
ssl_mode: SslMode,
tls: T,
mut tls: T,
process_id: i32,
secret_key: i32,
) -> Result<(), Error>

View File

@@ -17,6 +17,7 @@ use crate::{Client, Connection, Error};
/// TLS configuration.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[non_exhaustive]
pub enum SslMode {
/// Do not use TLS.
Disable,
@@ -230,7 +231,7 @@ impl Config {
/// Requires the `runtime` Cargo feature (enabled by default).
pub async fn connect<T>(
&self,
tls: &T,
tls: T,
) -> Result<(Client, Connection<TcpStream, T::Stream>), Error>
where
T: MakeTlsConnect<TcpStream>,

View File

@@ -13,7 +13,7 @@ use crate::tls::{MakeTlsConnect, TlsConnect};
use crate::{Client, Config, Connection, Error, RawConnection};
pub async fn connect<T>(
tls: &T,
mut tls: T,
config: &Config,
) -> Result<(Client, Connection<TcpStream, T::Stream>), Error>
where

View File

@@ -47,7 +47,7 @@ pub trait MakeTlsConnect<S> {
/// Creates a new `TlsConnect`or.
///
/// The domain name is provided for certificate verification and SNI.
fn make_tls_connect(&self, domain: &str) -> Result<Self::TlsConnect, Self::Error>;
fn make_tls_connect(&mut self, domain: &str) -> Result<Self::TlsConnect, Self::Error>;
}
/// An asynchronous function wrapping a stream in a TLS session.
@@ -85,7 +85,7 @@ impl<S> MakeTlsConnect<S> for NoTls {
type TlsConnect = NoTls;
type Error = NoTlsError;
fn make_tls_connect(&self, _: &str) -> Result<NoTls, NoTlsError> {
fn make_tls_connect(&mut self, _: &str) -> Result<NoTls, NoTlsError> {
Ok(NoTls)
}
}

View File

@@ -13,7 +13,7 @@ use utils::pageserver_feedback::PageserverFeedback;
use crate::membership::Configuration;
use crate::{ServerInfo, Term};
#[derive(Debug, Serialize, Deserialize)]
#[derive(Debug, Serialize)]
pub struct SafekeeperStatus {
pub id: NodeId,
}

View File

@@ -295,11 +295,7 @@ pub struct TenantId(Id);
id_newtype!(TenantId);
/// Type representing a project ID.
pub type ProjectId = String;
/// Type representing a branch ID.
pub type BranchId = String;
/// Type representing an endpoint ID.
/// If needed, reuse small string from proxy/src/types.rc
pub type EndpointId = String;
// A pair uniquely identifying Neon instance.

View File

@@ -23,7 +23,6 @@ use pageserver::deletion_queue::DeletionQueue;
use pageserver::disk_usage_eviction_task::{self, launch_disk_usage_global_eviction_task};
use pageserver::feature_resolver::FeatureResolver;
use pageserver::metrics::{STARTUP_DURATION, STARTUP_IS_LOADING};
use pageserver::page_service::GrpcPageServiceHandler;
use pageserver::task_mgr::{
BACKGROUND_RUNTIME, COMPUTE_REQUEST_RUNTIME, MGMT_REQUEST_RUNTIME, WALRECEIVER_RUNTIME,
};
@@ -815,7 +814,7 @@ fn start_pageserver(
// necessary?
let mut page_service_grpc = None;
if let Some(grpc_listener) = grpc_listener {
page_service_grpc = Some(GrpcPageServiceHandler::spawn(
page_service_grpc = Some(page_service::spawn_grpc(
tenant_manager.clone(),
grpc_auth,
otel_guard.as_ref().map(|g| g.dispatch.clone()),

View File

@@ -1053,15 +1053,6 @@ pub(crate) static TENANT_STATE_METRIC: Lazy<UIntGaugeVec> = Lazy::new(|| {
.expect("Failed to register pageserver_tenant_states_count metric")
});
pub(crate) static TIMELINE_STATE_METRIC: Lazy<UIntGaugeVec> = Lazy::new(|| {
register_uint_gauge_vec!(
"pageserver_timeline_states_count",
"Count of timelines per state",
&["state"]
)
.expect("Failed to register pageserver_timeline_states_count metric")
});
/// A set of broken tenants.
///
/// These are expected to be so rare that a set is fine. Set as in a new timeseries per each broken
@@ -3334,8 +3325,6 @@ impl TimelineMetrics {
&timeline_id,
);
TIMELINE_STATE_METRIC.with_label_values(&["active"]).inc();
TimelineMetrics {
tenant_id,
shard_id,
@@ -3490,8 +3479,6 @@ impl TimelineMetrics {
return;
}
TIMELINE_STATE_METRIC.with_label_values(&["active"]).dec();
let tenant_id = &self.tenant_id;
let timeline_id = &self.timeline_id;
let shard_id = &self.shard_id;

View File

@@ -169,6 +169,99 @@ pub fn spawn(
Listener { cancel, task }
}
/// Spawns a gRPC server for the page service.
///
/// TODO: move this onto GrpcPageServiceHandler::spawn().
/// TODO: this doesn't support TLS. We need TLS reloading via ReloadingCertificateResolver, so we
/// need to reimplement the TCP+TLS accept loop ourselves.
pub fn spawn_grpc(
tenant_manager: Arc<TenantManager>,
auth: Option<Arc<SwappableJwtAuth>>,
perf_trace_dispatch: Option<Dispatch>,
get_vectored_concurrent_io: GetVectoredConcurrentIo,
listener: std::net::TcpListener,
) -> anyhow::Result<CancellableTask> {
let cancel = CancellationToken::new();
let ctx = RequestContextBuilder::new(TaskKind::PageRequestHandler)
.download_behavior(DownloadBehavior::Download)
.perf_span_dispatch(perf_trace_dispatch)
.detached_child();
let gate = Gate::default();
// Set up the TCP socket. We take a preconfigured TcpListener to bind the
// port early during startup.
let incoming = {
let _runtime = COMPUTE_REQUEST_RUNTIME.enter(); // required by TcpListener::from_std
listener.set_nonblocking(true)?;
tonic::transport::server::TcpIncoming::from(tokio::net::TcpListener::from_std(listener)?)
.with_nodelay(Some(GRPC_TCP_NODELAY))
.with_keepalive(Some(GRPC_TCP_KEEPALIVE_TIME))
};
// Set up the gRPC server.
//
// TODO: consider tuning window sizes.
let mut server = tonic::transport::Server::builder()
.http2_keepalive_interval(Some(GRPC_HTTP2_KEEPALIVE_INTERVAL))
.http2_keepalive_timeout(Some(GRPC_HTTP2_KEEPALIVE_TIMEOUT))
.max_concurrent_streams(Some(GRPC_MAX_CONCURRENT_STREAMS));
// Main page service stack. Uses a mix of Tonic interceptors and Tower layers:
//
// * Interceptors: can inspect and modify the gRPC request. Sync code only, runs before service.
//
// * Layers: allow async code, can run code after the service response. However, only has access
// to the raw HTTP request/response, not the gRPC types.
let page_service_handler = GrpcPageServiceHandler {
tenant_manager,
ctx,
gate_guard: gate.enter().expect("gate was just created"),
get_vectored_concurrent_io,
};
let observability_layer = ObservabilityLayer;
let mut tenant_interceptor = TenantMetadataInterceptor;
let mut auth_interceptor = TenantAuthInterceptor::new(auth);
let page_service = tower::ServiceBuilder::new()
// Create tracing span and record request start time.
.layer(observability_layer)
// Intercept gRPC requests.
.layer(tonic::service::InterceptorLayer::new(move |mut req| {
// Extract tenant metadata.
req = tenant_interceptor.call(req)?;
// Authenticate tenant JWT token.
req = auth_interceptor.call(req)?;
Ok(req)
}))
.service(proto::PageServiceServer::new(page_service_handler));
let server = server.add_service(page_service);
// Reflection service for use with e.g. grpcurl.
let reflection_service = tonic_reflection::server::Builder::configure()
.register_encoded_file_descriptor_set(proto::FILE_DESCRIPTOR_SET)
.build_v1()?;
let server = server.add_service(reflection_service);
// Spawn server task.
let task_cancel = cancel.clone();
let task = COMPUTE_REQUEST_RUNTIME.spawn(task_mgr::exit_on_panic_or_error(
"grpc listener",
async move {
let result = server
.serve_with_incoming_shutdown(incoming, task_cancel.cancelled())
.await;
if result.is_ok() {
// TODO: revisit shutdown logic once page service is implemented.
gate.close().await;
}
result
},
));
Ok(CancellableTask { task, cancel })
}
impl Listener {
pub async fn stop_accepting(self) -> Connections {
self.cancel.cancel();
@@ -3273,101 +3366,6 @@ pub struct GrpcPageServiceHandler {
}
impl GrpcPageServiceHandler {
/// Spawns a gRPC server for the page service.
///
/// TODO: this doesn't support TLS. We need TLS reloading via ReloadingCertificateResolver, so we
/// need to reimplement the TCP+TLS accept loop ourselves.
pub fn spawn(
tenant_manager: Arc<TenantManager>,
auth: Option<Arc<SwappableJwtAuth>>,
perf_trace_dispatch: Option<Dispatch>,
get_vectored_concurrent_io: GetVectoredConcurrentIo,
listener: std::net::TcpListener,
) -> anyhow::Result<CancellableTask> {
let cancel = CancellationToken::new();
let ctx = RequestContextBuilder::new(TaskKind::PageRequestHandler)
.download_behavior(DownloadBehavior::Download)
.perf_span_dispatch(perf_trace_dispatch)
.detached_child();
let gate = Gate::default();
// Set up the TCP socket. We take a preconfigured TcpListener to bind the
// port early during startup.
let incoming = {
let _runtime = COMPUTE_REQUEST_RUNTIME.enter(); // required by TcpListener::from_std
listener.set_nonblocking(true)?;
tonic::transport::server::TcpIncoming::from(tokio::net::TcpListener::from_std(
listener,
)?)
.with_nodelay(Some(GRPC_TCP_NODELAY))
.with_keepalive(Some(GRPC_TCP_KEEPALIVE_TIME))
};
// Set up the gRPC server.
//
// TODO: consider tuning window sizes.
let mut server = tonic::transport::Server::builder()
.http2_keepalive_interval(Some(GRPC_HTTP2_KEEPALIVE_INTERVAL))
.http2_keepalive_timeout(Some(GRPC_HTTP2_KEEPALIVE_TIMEOUT))
.max_concurrent_streams(Some(GRPC_MAX_CONCURRENT_STREAMS));
// Main page service stack. Uses a mix of Tonic interceptors and Tower layers:
//
// * Interceptors: can inspect and modify the gRPC request. Sync code only, runs before service.
//
// * Layers: allow async code, can run code after the service response. However, only has access
// to the raw HTTP request/response, not the gRPC types.
let page_service_handler = GrpcPageServiceHandler {
tenant_manager,
ctx,
gate_guard: gate.enter().expect("gate was just created"),
get_vectored_concurrent_io,
};
let observability_layer = ObservabilityLayer;
let mut tenant_interceptor = TenantMetadataInterceptor;
let mut auth_interceptor = TenantAuthInterceptor::new(auth);
let page_service = tower::ServiceBuilder::new()
// Create tracing span and record request start time.
.layer(observability_layer)
// Intercept gRPC requests.
.layer(tonic::service::InterceptorLayer::new(move |mut req| {
// Extract tenant metadata.
req = tenant_interceptor.call(req)?;
// Authenticate tenant JWT token.
req = auth_interceptor.call(req)?;
Ok(req)
}))
// Run the page service.
.service(proto::PageServiceServer::new(page_service_handler));
let server = server.add_service(page_service);
// Reflection service for use with e.g. grpcurl.
let reflection_service = tonic_reflection::server::Builder::configure()
.register_encoded_file_descriptor_set(proto::FILE_DESCRIPTOR_SET)
.build_v1()?;
let server = server.add_service(reflection_service);
// Spawn server task.
let task_cancel = cancel.clone();
let task = COMPUTE_REQUEST_RUNTIME.spawn(task_mgr::exit_on_panic_or_error(
"grpc listener",
async move {
let result = server
.serve_with_incoming_shutdown(incoming, task_cancel.cancelled())
.await;
if result.is_ok() {
// TODO: revisit shutdown logic once page service is implemented.
gate.close().await;
}
result
},
));
Ok(CancellableTask { task, cancel })
}
/// Errors if the request is executed on a non-zero shard. Only shard 0 has a complete view of
/// relations and their sizes, as well as SLRU segments and similar data.
#[allow(clippy::result_large_err)]

View File

@@ -89,8 +89,7 @@ use crate::l0_flush::L0FlushGlobalState;
use crate::metrics::{
BROKEN_TENANTS_SET, CIRCUIT_BREAKERS_BROKEN, CIRCUIT_BREAKERS_UNBROKEN, CONCURRENT_INITDBS,
INITDB_RUN_TIME, INITDB_SEMAPHORE_ACQUISITION_TIME, TENANT, TENANT_OFFLOADED_TIMELINES,
TENANT_STATE_METRIC, TENANT_SYNTHETIC_SIZE_METRIC, TIMELINE_STATE_METRIC,
remove_tenant_metrics,
TENANT_STATE_METRIC, TENANT_SYNTHETIC_SIZE_METRIC, remove_tenant_metrics,
};
use crate::task_mgr::TaskKind;
use crate::tenant::config::LocationMode;
@@ -545,28 +544,6 @@ pub struct OffloadedTimeline {
/// Part of the `OffloadedTimeline` object's lifecycle: this needs to be set before we drop it
pub deleted_from_ancestor: AtomicBool,
_metrics_guard: OffloadedTimelineMetricsGuard,
}
/// Increases the offloaded timeline count metric when created, and decreases when dropped.
struct OffloadedTimelineMetricsGuard;
impl OffloadedTimelineMetricsGuard {
fn new() -> Self {
TIMELINE_STATE_METRIC
.with_label_values(&["offloaded"])
.inc();
Self
}
}
impl Drop for OffloadedTimelineMetricsGuard {
fn drop(&mut self) {
TIMELINE_STATE_METRIC
.with_label_values(&["offloaded"])
.dec();
}
}
impl OffloadedTimeline {
@@ -599,8 +576,6 @@ impl OffloadedTimeline {
delete_progress: timeline.delete_progress.clone(),
deleted_from_ancestor: AtomicBool::new(false),
_metrics_guard: OffloadedTimelineMetricsGuard::new(),
})
}
fn from_manifest(tenant_shard_id: TenantShardId, manifest: &OffloadedTimelineManifest) -> Self {
@@ -620,7 +595,6 @@ impl OffloadedTimeline {
archived_at,
delete_progress: TimelineDeleteProgress::default(),
deleted_from_ancestor: AtomicBool::new(false),
_metrics_guard: OffloadedTimelineMetricsGuard::new(),
}
}
fn manifest(&self) -> OffloadedTimelineManifest {

View File

@@ -1671,12 +1671,7 @@ impl TenantManager {
}
}
// Phase 5: Shut down the parent shard. We leave it on disk in case the split fails and we
// have to roll back to the parent shard, avoiding a cold start. It will be cleaned up once
// the storage controller commits the split, or if all else fails, on the next restart.
//
// TODO: We don't flush the ephemeral layer here, because the split is likely to succeed and
// catching up the parent should be reasonably quick. Consider using FreezeAndFlush instead.
// Phase 5: Shut down the parent shard, and erase it from disk
let (_guard, progress) = completion::channel();
match parent.shutdown(progress, ShutdownMode::Hard).await {
Ok(()) => {}
@@ -1684,6 +1679,11 @@ impl TenantManager {
other.wait().await;
}
}
let local_tenant_directory = self.conf.tenant_path(&tenant_shard_id);
let tmp_path = safe_rename_tenant_dir(&local_tenant_directory)
.await
.with_context(|| format!("local tenant directory {local_tenant_directory:?} rename"))?;
self.background_purges.spawn(tmp_path);
fail::fail_point!("shard-split-pre-finish", |_| Err(anyhow::anyhow!(
"failpoint"
@@ -1846,70 +1846,42 @@ impl TenantManager {
shutdown_all_tenants0(self.tenants).await
}
/// Detaches a tenant, and removes its local files asynchronously.
///
/// File removal is idempotent: even if the tenant has already been removed, this will still
/// remove any local files. This is used during shard splits, where we leave the parent shard's
/// files around in case we have to roll back the split.
pub(crate) async fn detach_tenant(
&self,
conf: &'static PageServerConf,
tenant_shard_id: TenantShardId,
deletion_queue_client: &DeletionQueueClient,
) -> Result<(), TenantStateError> {
if let Some(tmp_path) = self
let tmp_path = self
.detach_tenant0(conf, tenant_shard_id, deletion_queue_client)
.await?
{
self.background_purges.spawn(tmp_path);
}
.await?;
self.background_purges.spawn(tmp_path);
Ok(())
}
/// Detaches a tenant. This renames the tenant directory to a temporary path and returns it,
/// allowing the caller to delete it asynchronously. Returns None if the dir is already removed.
async fn detach_tenant0(
&self,
conf: &'static PageServerConf,
tenant_shard_id: TenantShardId,
deletion_queue_client: &DeletionQueueClient,
) -> Result<Option<Utf8PathBuf>, TenantStateError> {
) -> Result<Utf8PathBuf, TenantStateError> {
let tenant_dir_rename_operation = |tenant_id_to_clean: TenantShardId| async move {
let local_tenant_directory = conf.tenant_path(&tenant_id_to_clean);
if !tokio::fs::try_exists(&local_tenant_directory).await? {
// If the tenant directory doesn't exist, it's already cleaned up.
return Ok(None);
}
safe_rename_tenant_dir(&local_tenant_directory)
.await
.with_context(|| {
format!("local tenant directory {local_tenant_directory:?} rename")
})
.map(Some)
};
let mut removal_result = remove_tenant_from_memory(
let removal_result = remove_tenant_from_memory(
self.tenants,
tenant_shard_id,
tenant_dir_rename_operation(tenant_shard_id),
)
.await;
// If the tenant was not found, it was likely already removed. Attempt to remove the tenant
// directory on disk anyway. For example, during shard splits, we shut down and remove the
// parent shard, but leave its directory on disk in case we have to roll back the split.
//
// TODO: it would be better to leave the parent shard attached until the split is committed.
// This will be needed by the gRPC page service too, such that a compute can continue to
// read from the parent shard until it's notified about the new child shards. See:
// <https://github.com/neondatabase/neon/issues/11728>.
if let Err(TenantStateError::SlotError(TenantSlotError::NotFound(_))) = removal_result {
removal_result = tenant_dir_rename_operation(tenant_shard_id)
.await
.map_err(TenantStateError::Other);
}
// Flush pending deletions, so that they have a good chance of passing validation
// before this tenant is potentially re-attached elsewhere.
deletion_queue_client.flush_advisory();

View File

@@ -1055,8 +1055,8 @@ pub(crate) enum WaitLsnWaiter<'a> {
/// Argument to [`Timeline::shutdown`].
#[derive(Debug, Clone, Copy)]
pub(crate) enum ShutdownMode {
/// Graceful shutdown, may do a lot of I/O as we flush any open layers to disk. This method can
/// take multiple seconds for a busy timeline.
/// Graceful shutdown, may do a lot of I/O as we flush any open layers to disk and then
/// also to remote storage. This method can easily take multiple seconds for a busy timeline.
///
/// While we are flushing, we continue to accept read I/O for LSNs ingested before
/// the call to [`Timeline::shutdown`].

View File

@@ -1,6 +1,5 @@
# pgxs/neon/Makefile
MODULE_big = neon
OBJS = \
$(WIN32RES) \
@@ -22,7 +21,8 @@ OBJS = \
walproposer.o \
walproposer_pg.o \
control_plane_connector.o \
walsender_hooks.o
walsender_hooks.o \
$(LIBCOMMUNICATOR_PATH)/libcommunicator.a
PG_CPPFLAGS = -I$(libpq_srcdir)
SHLIB_LINK_INTERNAL = $(libpq)

View File

@@ -0,0 +1,13 @@
[package]
name = "communicator"
version = "0.1.0"
edition = "2024"
[lib]
crate-type = ["staticlib"]
[dependencies]
neon-shmem.workspace = true
[build-dependencies]
cbindgen.workspace = true

View File

@@ -0,0 +1,8 @@
This package will evolve into a "compute-pageserver communicator"
process and machinery. For now, it just provides wrappers on the
neon-shmem Rust crate, to allow using it in the C implementation of
the LFC.
At compilation time, pgxn/neon/communicator/ produces a static
library, libcommunicator.a. It is linked to the neon.so extension
library.

View File

@@ -0,0 +1,22 @@
use std::env;
fn main() -> Result<(), Box<dyn std::error::Error>> {
let crate_dir = env::var("CARGO_MANIFEST_DIR").unwrap();
cbindgen::generate(crate_dir).map_or_else(
|error| match error {
cbindgen::Error::ParseSyntaxError { .. } => {
// This means there was a syntax error in the Rust sources. Don't panic, because
// we want the build to continue and the Rust compiler to hit the error. The
// Rust compiler produces a better error message than cbindgen.
eprintln!("Generating C bindings failed because of a Rust syntax error");
}
e => panic!("Unable to generate C bindings: {:?}", e),
},
|bindings| {
bindings.write_to_file("communicator_bindings.h");
},
);
Ok(())
}

View File

@@ -0,0 +1,4 @@
language = "C"
[enum]
prefix_with_name = true

View File

@@ -0,0 +1,240 @@
//! Glue code to allow using the Rust shmem hash map implementation from C code
//!
//! For convience of adapting existing code, the interface provided somewhat resembles the dynahash
//! interface.
//!
//! NOTE: The caller is responsible for locking! The caller is expected to hold the PostgreSQL
//! LWLock, 'lfc_lock', while accessing the hash table, in shared or exclusive mode as appropriate.
use std::ffi::c_void;
use std::marker::PhantomData;
use neon_shmem::hash::entry::Entry;
use neon_shmem::hash::{HashMapAccess, HashMapInit};
use neon_shmem::shmem::ShmemHandle;
/// NB: This must match the definition of BufferTag in Postgres C headers. We could use bindgen to
/// generate this from the C headers, but prefer to not introduce dependency on bindgen for now.
///
/// Note that there are no padding bytes. If the corresponding C struct has padding bytes, the C C
/// code must clear them.
#[derive(Clone, Debug, Hash, Eq, PartialEq)]
#[repr(C)]
pub struct FileCacheKey {
pub _spc_id: u32,
pub _db_id: u32,
pub _rel_number: u32,
pub _fork_num: u32,
pub _block_num: u32,
}
/// Like with FileCacheKey, this must match the definition of FileCacheEntry in file_cache.c. We
/// don't look at the contents here though, it's sufficent that the size and alignment matches.
#[derive(Clone, Debug, Default)]
#[repr(C)]
pub struct FileCacheEntry {
pub _offset: u32,
pub _access_count: u32,
pub _prev: *mut FileCacheEntry,
pub _next: *mut FileCacheEntry,
pub _state: [u32; 8],
}
/// XXX: This could be just:
///
/// ```ignore
/// type FileCacheHashMapHandle = HashMapInit<'a, FileCacheKey, FileCacheEntry>
/// ```
///
/// but with that, cbindgen generates a broken typedef in the C header file which doesn't
/// compile. It apparently gets confused by the generics.
#[repr(transparent)]
pub struct FileCacheHashMapHandle<'a>(
pub *mut c_void,
PhantomData<HashMapInit<'a, FileCacheKey, FileCacheEntry>>,
);
impl<'a> From<Box<HashMapInit<'a, FileCacheKey, FileCacheEntry>>> for FileCacheHashMapHandle<'a> {
fn from(x: Box<HashMapInit<'a, FileCacheKey, FileCacheEntry>>) -> Self {
FileCacheHashMapHandle(Box::into_raw(x) as *mut c_void, PhantomData::default())
}
}
impl<'a> From<FileCacheHashMapHandle<'a>> for Box<HashMapInit<'a, FileCacheKey, FileCacheEntry>> {
fn from(x: FileCacheHashMapHandle) -> Self {
unsafe { Box::from_raw(x.0.cast()) }
}
}
/// XXX: same for this
#[repr(transparent)]
pub struct FileCacheHashMapAccess<'a>(
pub *mut c_void,
PhantomData<HashMapAccess<'a, FileCacheKey, FileCacheEntry>>,
);
impl<'a> From<Box<HashMapAccess<'a, FileCacheKey, FileCacheEntry>>> for FileCacheHashMapAccess<'a> {
fn from(x: Box<HashMapAccess<'a, FileCacheKey, FileCacheEntry>>) -> Self {
// Convert the Box into a raw mutable pointer to the HashMapAccess itself.
// This transfers ownership of the HashMapAccess (and its contained ShmemHandle)
// to the raw pointer. The C caller is now responsible for managing this memory.
FileCacheHashMapAccess(Box::into_raw(x) as *mut c_void, PhantomData::default())
}
}
impl<'a> FileCacheHashMapAccess<'a> {
fn as_ref(self) -> &'a HashMapAccess<'a, FileCacheKey, FileCacheEntry> {
let ptr: *mut HashMapAccess<'_, FileCacheKey, FileCacheEntry> = self.0.cast();
unsafe { ptr.as_ref().unwrap() }
}
fn as_mut(self) -> &'a mut HashMapAccess<'a, FileCacheKey, FileCacheEntry> {
let ptr: *mut HashMapAccess<'_, FileCacheKey, FileCacheEntry> = self.0.cast();
unsafe { ptr.as_mut().unwrap() }
}
}
/// Initialize the shared memory area at postmaster startup. The returned handle is inherited
/// by all the backend processes across fork()
#[unsafe(no_mangle)]
pub extern "C" fn bcomm_file_cache_shmem_init<'a>(
initial_num_buckets: u32,
max_num_buckets: u32,
) -> FileCacheHashMapHandle<'a> {
let max_bytes = HashMapInit::<FileCacheKey, FileCacheEntry>::estimate_size(max_num_buckets);
let shmem_handle =
ShmemHandle::new("lfc mapping", 0, max_bytes).expect("shmem initialization failed");
let handle = HashMapInit::<FileCacheKey, FileCacheEntry>::init_in_shmem(
initial_num_buckets,
shmem_handle,
);
Box::new(handle).into()
}
/// Initialize the access to the shared memory area in a backend process.
///
/// XXX: I'm not sure if this actually gets called in each process, or if the returned struct
/// is also inherited across fork(). It currently works either way but if this did more
/// initialization that needed to be done after fork(), then it would matter.
#[unsafe(no_mangle)]
pub extern "C" fn bcomm_file_cache_shmem_access<'a>(
handle: FileCacheHashMapHandle<'a>,
) -> FileCacheHashMapAccess<'a> {
let handle: Box<HashMapInit<'_, FileCacheKey, FileCacheEntry>> = handle.into();
Box::new(handle.attach_writer()).into()
}
/// Return the current number of buckets in the hash table
#[unsafe(no_mangle)]
pub extern "C" fn bcomm_file_cache_get_num_buckets<'a>(
map: FileCacheHashMapAccess<'static>,
) -> u32 {
let map = map.as_ref();
map.get_num_buckets().try_into().unwrap()
}
/// Look up the entry with given key and hash.
///
/// This is similar to dynahash's hash_search(... , HASH_FIND)
#[unsafe(no_mangle)]
pub extern "C" fn bcomm_file_cache_hash_find<'a>(
map: FileCacheHashMapAccess<'static>,
key: &FileCacheKey,
hash: u64,
) -> Option<&'static FileCacheEntry> {
let map = map.as_ref();
map.get_with_hash(key, hash)
}
/// Look up the entry at given bucket position
///
/// This has no direct equivalent in the dynahash interface, but can be used to
/// iterate through all entries in the hash table.
#[unsafe(no_mangle)]
pub extern "C" fn bcomm_file_cache_hash_get_at_pos<'a>(
map: FileCacheHashMapAccess<'static>,
pos: u32,
) -> Option<&'static FileCacheEntry> {
let map = map.as_ref();
map.get_at_bucket(pos as usize).map(|(_k, v)| v)
}
/// Remove entry, given a pointer to the value.
///
/// This is equivalent to dynahash hash_search(entry->key, HASH_REMOVE), where 'entry'
/// is an entry you have previously looked up
#[unsafe(no_mangle)]
pub extern "C" fn bcomm_file_cache_hash_remove_entry<'a, 'b>(
map: FileCacheHashMapAccess,
entry: *mut FileCacheEntry,
) {
let map = map.as_mut();
let pos = map.get_bucket_for_value(entry);
match map.entry_at_bucket(pos) {
Some(e) => {
e.remove();
}
None => {
// todo: shouldn't happen, panic?
}
}
}
/// Compute the hash for given key
///
/// This is equivalent to dynahash get_hash_value() function. We use Rust's default hasher
/// for calculating the hash though.
#[unsafe(no_mangle)]
pub extern "C" fn bcomm_file_cache_get_hash_value<'a, 'b>(
map: FileCacheHashMapAccess<'static>,
key: &FileCacheKey,
) -> u64 {
map.as_ref().get_hash_value(key)
}
/// Insert a new entry to the hash table
///
/// This is equivalent to dynahash hash_search(..., HASH_ENTER).
#[unsafe(no_mangle)]
pub extern "C" fn bcomm_file_cache_hash_enter<'a, 'b>(
map: FileCacheHashMapAccess,
key: &FileCacheKey,
hash: u64,
found: &mut bool,
) -> *mut FileCacheEntry {
match map.as_mut().entry_with_hash(key.clone(), hash) {
Entry::Occupied(mut e) => {
*found = true;
e.get_mut()
}
Entry::Vacant(e) => {
*found = false;
let initial_value = FileCacheEntry::default();
e.insert(initial_value).expect("TODO: hash table full")
}
}
}
/// Get the key for a given entry, which must be present in the hash table.
///
/// Dynahash requires the key to be part of the "value" struct, so you can always
/// access the key with something like `entry->key`. The Rust implementation however
/// stores the key separately. This function extracts the separately stored key.
#[unsafe(no_mangle)]
pub extern "C" fn bcomm_file_cache_hash_get_key_for_entry<'a, 'b>(
map: FileCacheHashMapAccess,
entry: *const FileCacheEntry,
) -> Option<&FileCacheKey> {
let map = map.as_ref();
let pos = map.get_bucket_for_value(entry);
map.get_at_bucket(pos as usize).map(|(k, _v)| k)
}
/// Remove all entries from the hash table
#[unsafe(no_mangle)]
pub extern "C" fn bcomm_file_cache_hash_reset<'a, 'b>(map: FileCacheHashMapAccess) {
let map = map.as_mut();
let num_buckets = map.get_num_buckets();
for i in 0..num_buckets {
if let Some(e) = map.entry_at_bucket(i) {
e.remove();
}
}
}

View File

@@ -0,0 +1 @@
pub mod file_cache_hashmap;

View File

@@ -21,7 +21,7 @@
#include "access/xlog.h"
#include "funcapi.h"
#include "miscadmin.h"
#include "common/hashfn.h"
#include "common/file_utils.h"
#include "pgstat.h"
#include "port/pg_iovec.h"
#include "postmaster/bgworker.h"
@@ -36,7 +36,6 @@
#include "storage/procsignal.h"
#include "tcop/tcopprot.h"
#include "utils/builtins.h"
#include "utils/dynahash.h"
#include "utils/guc.h"
#if PG_VERSION_NUM >= 150000
@@ -46,6 +45,7 @@
#include "hll.h"
#include "bitmap.h"
#include "file_cache.h"
#include "file_cache_rust_hash.h"
#include "neon.h"
#include "neon_lwlsncache.h"
#include "neon_perf_counters.h"
@@ -64,7 +64,7 @@
*
* Cache is always reconstructed at node startup, so we do not need to save mapping somewhere and worry about
* its consistency.
*
*
* ## Holes
*
@@ -76,13 +76,15 @@
* fallocate(FALLOC_FL_PUNCH_HOLE) call. The nominal size of the file doesn't
* shrink, but the disk space it uses does.
*
* Each hole is tracked by a dummy FileCacheEntry, which are kept in the
* 'holes' linked list. They are entered into the chunk hash table, with a
* special key where the blockNumber is used to store the 'offset' of the
* hole, and all other fields are zero. Holes are never looked up in the hash
* table, we only enter them there to have a FileCacheEntry that we can keep
* in the linked list. If the soft limit is raised again, we reuse the holes
* before extending the nominal size of the file.
* Each hole is tracked in a freelist. The freelist consists of two parts: a
* fixed-size array in shared memory, and a linked chain of on-disk
* blocks. When the in-memory array fills up, it's flushed to a new on-disk
* chunk. If the soft limit is raised again, we reuse the holes before
* extending the nominal size of the file.
*
* The in-memory freelist array is protected by 'lfc_lock', while the on-disk
* chain is protected by a separate 'lfc_freelist_lock'. Locking rule to
* avoid deadlocks: always acquire lfc_freelist_lock first, then lfc_lock.
*/
/* Local file storage allocation chunk.
@@ -92,13 +94,15 @@
* 1Mb chunks can reduce hash map size to 320Mb.
* 2. Improve access locality, subsequent pages will be allocated together improving seqscan speed
*/
#define MAX_BLOCKS_PER_CHUNK_LOG 7 /* 1Mb chunk */
#define MAX_BLOCKS_PER_CHUNK (1 << MAX_BLOCKS_PER_CHUNK_LOG)
#define BLOCKS_PER_CHUNK_LOG 7 /* 1Mb chunk */
#define BLOCKS_PER_CHUNK (1 << BLOCKS_PER_CHUNK_LOG)
#define MB ((uint64)1024*1024)
#define SIZE_MB_TO_CHUNKS(size) ((uint32)((size) * MB / BLCKSZ >> lfc_chunk_size_log))
#define BLOCK_TO_CHUNK_OFF(blkno) ((blkno) & (lfc_blocks_per_chunk-1))
#define SIZE_MB_TO_CHUNKS(size) ((uint32)((size) * MB / BLCKSZ >> BLOCKS_PER_CHUNK_LOG))
#define BLOCK_TO_CHUNK_OFF(blkno) ((blkno) & (BLOCKS_PER_CHUNK-1))
#define INVALID_OFFSET (0xffffffff)
/*
* Blocks are read or written to LFC file outside LFC critical section.
@@ -119,15 +123,18 @@ typedef enum FileCacheBlockState
typedef struct FileCacheEntry
{
BufferTag key;
uint32 hash;
uint32 offset;
uint32 access_count;
dlist_node list_node; /* LRU/holes list node */
uint32 state[FLEXIBLE_ARRAY_MEMBER]; /* two bits per block */
dlist_node list_node; /* LRU list node */
uint32 state[(BLOCKS_PER_CHUNK * 2 + 31) / 32]; /* two bits per block */
} FileCacheEntry;
#define FILE_CACHE_ENRTY_SIZE MAXALIGN(offsetof(FileCacheEntry, state) + (lfc_blocks_per_chunk*2+31)/32*4)
/* Todo: alignment must be the same too */
StaticAssertDecl(sizeof(FileCacheEntry) == sizeof(RustFileCacheEntry),
"Rust and C declarations of FileCacheEntry are incompatible");
StaticAssertDecl(sizeof(BufferTag) == sizeof(RustFileCacheKey),
"Rust and C declarations of FileCacheKey are incompatible");
#define GET_STATE(entry, i) (((entry)->state[(i) / 16] >> ((i) % 16 * 2)) & 3)
#define SET_STATE(entry, i, new_state) (entry)->state[(i) / 16] = ((entry)->state[(i) / 16] & ~(3 << ((i) % 16 * 2))) | ((new_state) << ((i) % 16 * 2))
@@ -136,6 +143,9 @@ typedef struct FileCacheEntry
#define MAX_PREWARM_WORKERS 8
#define FREELIST_ENTRIES_PER_CHUNK (BLOCKS_PER_CHUNK * BLCKSZ / sizeof(uint32) - 2)
typedef struct PrewarmWorkerState
{
uint32 prewarmed_pages;
@@ -161,7 +171,6 @@ typedef struct FileCacheControl
uint64 evicted_pages; /* number of evicted pages */
dlist_head lru; /* double linked list for LRU replacement
* algorithm */
dlist_head holes; /* double linked list of punched holes */
HyperLogLogState wss_estimation; /* estimation of working set size */
ConditionVariable cv[N_COND_VARS]; /* turnstile of condition variables */
PrewarmWorkerState prewarm_workers[MAX_PREWARM_WORKERS];
@@ -172,23 +181,39 @@ typedef struct FileCacheControl
bool prewarm_active;
bool prewarm_canceled;
dsm_handle prewarm_lfc_state_handle;
/*
* Free list. This is large enough to hold one chunks worth of entries.
*/
uint32 freelist_size;
uint32 freelist_head;
uint32 num_free_pages;
uint32 free_pages[FREELIST_ENTRIES_PER_CHUNK];
} FileCacheControl;
typedef struct FreeListChunk
{
uint32 next;
uint32 num_free_pages;
uint32 free_pages[FREELIST_ENTRIES_PER_CHUNK];
} FreeListChunk;
#define FILE_CACHE_STATE_MAGIC 0xfcfcfcfc
#define FILE_CACHE_STATE_BITMAP(fcs) ((uint8*)&(fcs)->chunks[(fcs)->n_chunks])
#define FILE_CACHE_STATE_SIZE_FOR_CHUNKS(n_chunks) (sizeof(FileCacheState) + (n_chunks)*sizeof(BufferTag) + (((n_chunks) * lfc_blocks_per_chunk)+7)/8)
#define FILE_CACHE_STATE_SIZE_FOR_CHUNKS(n_chunks) (sizeof(FileCacheState) + (n_chunks)*sizeof(BufferTag) + (((n_chunks) * BLOCKS_PER_CHUNK)+7)/8)
#define FILE_CACHE_STATE_SIZE(fcs) (sizeof(FileCacheState) + (fcs->n_chunks)*sizeof(BufferTag) + (((fcs->n_chunks) << fcs->chunk_size_log)+7)/8)
static HTAB *lfc_hash;
static FileCacheHashMapHandle lfc_hash_handle;
static FileCacheHashMapAccess lfc_hash;
static int lfc_desc = -1;
static LWLockId lfc_lock;
static LWLockId lfc_freelist_lock;
static int lfc_max_size;
static int lfc_size_limit;
static int lfc_prewarm_limit;
static int lfc_prewarm_batch;
static int lfc_chunk_size_log = MAX_BLOCKS_PER_CHUNK_LOG;
static int lfc_blocks_per_chunk = MAX_BLOCKS_PER_CHUNK;
static int lfc_blocks_per_chunk_ro = BLOCKS_PER_CHUNK;
static char *lfc_path;
static uint64 lfc_generation;
static FileCacheControl *lfc_ctl;
@@ -205,6 +230,11 @@ bool AmPrewarmWorker;
#define LFC_ENABLED() (lfc_ctl->limit != 0)
static bool freelist_push(uint32 offset);
static bool freelist_prepare_pop(void);
static uint32 freelist_pop(void);
static bool freelist_is_empty(void);
/*
* Close LFC file if opened.
* All backends should close their LFC files once LFC is disabled.
@@ -232,15 +262,9 @@ lfc_switch_off(void)
if (LFC_ENABLED())
{
HASH_SEQ_STATUS status;
FileCacheEntry *entry;
/* Invalidate hash */
hash_seq_init(&status, lfc_hash);
while ((entry = hash_seq_search(&status)) != NULL)
{
hash_search_with_hash_value(lfc_hash, &entry->key, entry->hash, HASH_REMOVE, NULL);
}
file_cache_hash_reset(lfc_hash);
lfc_ctl->generation += 1;
lfc_ctl->size = 0;
lfc_ctl->pinned = 0;
@@ -248,7 +272,9 @@ lfc_switch_off(void)
lfc_ctl->used_pages = 0;
lfc_ctl->limit = 0;
dlist_init(&lfc_ctl->lru);
dlist_init(&lfc_ctl->holes);
lfc_ctl->freelist_head = INVALID_OFFSET;
lfc_ctl->num_free_pages = 0;
/*
* We need to use unlink to to avoid races in LFC write, because it is not
@@ -317,8 +343,8 @@ lfc_ensure_opened(void)
static void
lfc_shmem_startup(void)
{
size_t size;
bool found;
static HASHCTL info;
if (prev_shmem_startup_hook)
{
@@ -327,27 +353,29 @@ lfc_shmem_startup(void)
LWLockAcquire(AddinShmemInitLock, LW_EXCLUSIVE);
lfc_ctl = (FileCacheControl *) ShmemInitStruct("lfc", sizeof(FileCacheControl), &found);
size = sizeof(FileCacheControl);
lfc_ctl = (FileCacheControl *) ShmemInitStruct("lfc", size, &found);
if (!found)
{
int fd;
uint32 n_chunks = SIZE_MB_TO_CHUNKS(lfc_max_size);
lfc_lock = (LWLockId) GetNamedLWLockTranche("lfc_lock");
info.keysize = sizeof(BufferTag);
info.entrysize = FILE_CACHE_ENRTY_SIZE;
lfc_freelist_lock = (LWLockId) GetNamedLWLockTranche("lfc_freelist_lock");
/*
* n_chunks+1 because we add new element to hash table before eviction
* of victim
*/
lfc_hash = ShmemInitHash("lfc_hash",
n_chunks + 1, n_chunks + 1,
&info,
HASH_ELEM | HASH_BLOBS);
memset(lfc_ctl, 0, sizeof(FileCacheControl));
lfc_hash_handle = file_cache_hash_shmem_init(n_chunks + 1, n_chunks + 1);
memset(lfc_ctl, 0, offsetof(FileCacheControl, free_pages));
dlist_init(&lfc_ctl->lru);
dlist_init(&lfc_ctl->holes);
lfc_ctl->freelist_size = FREELIST_ENTRIES_PER_CHUNK;
lfc_ctl->freelist_head = INVALID_OFFSET;
lfc_ctl->num_free_pages = 0;
/* Initialize hyper-log-log structure for estimating working set size */
initSHLL(&lfc_ctl->wss_estimation);
@@ -371,18 +399,25 @@ lfc_shmem_startup(void)
}
LWLockRelease(AddinShmemInitLock);
lfc_hash = file_cache_hash_shmem_access(lfc_hash_handle);
}
static void
lfc_shmem_request(void)
{
size_t size;
#if PG_VERSION_NUM>=150000
if (prev_shmem_request_hook)
prev_shmem_request_hook();
#endif
RequestAddinShmemSpace(sizeof(FileCacheControl) + hash_estimate_size(SIZE_MB_TO_CHUNKS(lfc_max_size) + 1, FILE_CACHE_ENRTY_SIZE));
size = sizeof(FileCacheControl);
RequestAddinShmemSpace(size);
RequestNamedLWLockTranche("lfc_lock", 1);
RequestNamedLWLockTranche("lfc_freelist_lock", 2);
}
static bool
@@ -398,24 +433,6 @@ is_normal_backend(void)
return lfc_ctl && MyProc && UsedShmemSegAddr && !IsParallelWorker();
}
static bool
lfc_check_chunk_size(int *newval, void **extra, GucSource source)
{
if (*newval & (*newval - 1))
{
elog(ERROR, "LFC chunk size should be power of two");
return false;
}
return true;
}
static void
lfc_change_chunk_size(int newval, void* extra)
{
lfc_chunk_size_log = pg_ceil_log2_32(newval);
}
static bool
lfc_check_limit_hook(int *newval, void **extra, GucSource source)
{
@@ -435,12 +452,14 @@ lfc_change_limit_hook(int newval, void *extra)
if (!lfc_ctl || !is_normal_backend())
return;
LWLockAcquire(lfc_freelist_lock, LW_EXCLUSIVE);
LWLockAcquire(lfc_lock, LW_EXCLUSIVE);
/* Open LFC file only if LFC was enabled or we are going to reenable it */
if (newval == 0 && !LFC_ENABLED())
{
LWLockRelease(lfc_lock);
LWLockRelease(lfc_freelist_lock);
/* File should be reopened if LFC is reenabled */
lfc_close_file();
return;
@@ -449,6 +468,7 @@ lfc_change_limit_hook(int newval, void *extra)
if (!lfc_ensure_opened())
{
LWLockRelease(lfc_lock);
LWLockRelease(lfc_freelist_lock);
return;
}
@@ -464,35 +484,30 @@ lfc_change_limit_hook(int newval, void *extra)
* returning their space to file system
*/
FileCacheEntry *victim = dlist_container(FileCacheEntry, list_node, dlist_pop_head_node(&lfc_ctl->lru));
FileCacheEntry *hole;
uint32 offset = victim->offset;
uint32 hash;
bool found;
BufferTag holetag;
CriticalAssert(victim->access_count == 0);
#ifdef FALLOC_FL_PUNCH_HOLE
if (fallocate(lfc_desc, FALLOC_FL_PUNCH_HOLE | FALLOC_FL_KEEP_SIZE, (off_t) victim->offset * lfc_blocks_per_chunk * BLCKSZ, lfc_blocks_per_chunk * BLCKSZ) < 0)
if (fallocate(lfc_desc, FALLOC_FL_PUNCH_HOLE | FALLOC_FL_KEEP_SIZE, (off_t) victim->offset * BLOCKS_PER_CHUNK * BLCKSZ, BLOCKS_PER_CHUNK * BLCKSZ) < 0)
neon_log(LOG, "Failed to punch hole in file: %m");
#endif
/* We remove the old entry, and re-enter a hole to the hash table */
for (int i = 0; i < lfc_blocks_per_chunk; i++)
/* We remove the entry, and enter a hole to the freelist */
for (int i = 0; i < BLOCKS_PER_CHUNK; i++)
{
bool is_page_cached = GET_STATE(victim, i) == AVAILABLE;
lfc_ctl->used_pages -= is_page_cached;
lfc_ctl->evicted_pages += is_page_cached;
}
hash_search_with_hash_value(lfc_hash, &victim->key, victim->hash, HASH_REMOVE, NULL);
file_cache_hash_remove_entry(lfc_hash, victim);
memset(&holetag, 0, sizeof(holetag));
holetag.blockNum = offset;
hash = get_hash_value(lfc_hash, &holetag);
hole = hash_search_with_hash_value(lfc_hash, &holetag, hash, HASH_ENTER, &found);
hole->hash = hash;
hole->offset = offset;
hole->access_count = 0;
CriticalAssert(!found);
dlist_push_tail(&lfc_ctl->holes, &hole->list_node);
if (!freelist_push(offset))
{
/* freelist_push already logged the error */
lfc_switch_off();
LWLockRelease(lfc_lock);
LWLockRelease(lfc_freelist_lock);
return;
}
lfc_ctl->used -= 1;
}
@@ -504,6 +519,7 @@ lfc_change_limit_hook(int newval, void *extra)
neon_log(DEBUG1, "set local file cache limit to %d", new_size);
LWLockRelease(lfc_lock);
LWLockRelease(lfc_freelist_lock);
}
void
@@ -579,14 +595,14 @@ lfc_init(void)
DefineCustomIntVariable("neon.file_cache_chunk_size",
"LFC chunk size in blocks (should be power of two)",
NULL,
&lfc_blocks_per_chunk,
MAX_BLOCKS_PER_CHUNK,
1,
MAX_BLOCKS_PER_CHUNK,
PGC_POSTMASTER,
&lfc_blocks_per_chunk_ro,
BLOCKS_PER_CHUNK,
BLOCKS_PER_CHUNK,
BLOCKS_PER_CHUNK,
PGC_INTERNAL,
GUC_UNIT_BLOCKS,
lfc_check_chunk_size,
lfc_change_chunk_size,
NULL,
NULL,
NULL);
DefineCustomIntVariable("neon.file_cache_prewarm_limit",
@@ -649,19 +665,19 @@ lfc_get_state(size_t max_entries)
fcs = (FileCacheState*)palloc0(state_size);
SET_VARSIZE(fcs, state_size);
fcs->magic = FILE_CACHE_STATE_MAGIC;
fcs->chunk_size_log = lfc_chunk_size_log;
fcs->chunk_size_log = BLOCKS_PER_CHUNK_LOG;
fcs->n_chunks = n_entries;
bitmap = FILE_CACHE_STATE_BITMAP(fcs);
dlist_reverse_foreach(iter, &lfc_ctl->lru)
{
FileCacheEntry *entry = dlist_container(FileCacheEntry, list_node, iter.cur);
fcs->chunks[i] = entry->key;
for (int j = 0; j < lfc_blocks_per_chunk; j++)
fcs->chunks[i] = *file_cache_hash_get_key_for_entry(lfc_hash, entry);
for (int j = 0; j < BLOCKS_PER_CHUNK; j++)
{
if (GET_STATE(entry, j) != UNAVAILABLE)
{
BITMAP_SET(bitmap, i*lfc_blocks_per_chunk + j);
BITMAP_SET(bitmap, i*BLOCKS_PER_CHUNK + j);
n_pages += 1;
}
}
@@ -670,7 +686,7 @@ lfc_get_state(size_t max_entries)
}
Assert(i == n_entries);
fcs->n_pages = n_pages;
Assert(pg_popcount((char*)bitmap, ((n_entries << lfc_chunk_size_log) + 7)/8) == n_pages);
Assert(pg_popcount((char*)bitmap, ((n_entries << BLOCKS_PER_CHUNK_LOG) + 7)/8) == n_pages);
elog(LOG, "LFC: save state of %d chunks %d pages", (int)n_entries, (int)n_pages);
}
@@ -726,7 +742,7 @@ lfc_prewarm(FileCacheState* fcs, uint32 n_workers)
}
fcs_chunk_size_log = fcs->chunk_size_log;
if (fcs_chunk_size_log > MAX_BLOCKS_PER_CHUNK_LOG)
if (fcs_chunk_size_log > BLOCKS_PER_CHUNK_LOG)
{
elog(ERROR, "LFC: Invalid chunk size log: %u", fcs->chunk_size_log);
}
@@ -945,7 +961,7 @@ lfc_invalidate(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber nblocks)
{
BufferTag tag;
FileCacheEntry *entry;
uint32 hash;
uint64 hash;
if (lfc_maybe_disabled()) /* fast exit if file cache is disabled */
return;
@@ -958,14 +974,14 @@ lfc_invalidate(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber nblocks)
LWLockAcquire(lfc_lock, LW_EXCLUSIVE);
if (LFC_ENABLED())
{
for (BlockNumber blkno = 0; blkno < nblocks; blkno += lfc_blocks_per_chunk)
for (BlockNumber blkno = 0; blkno < nblocks; blkno += BLOCKS_PER_CHUNK)
{
tag.blockNum = blkno;
hash = get_hash_value(lfc_hash, &tag);
entry = hash_search_with_hash_value(lfc_hash, &tag, hash, HASH_FIND, NULL);
hash = file_cache_hash_get_hash_value(lfc_hash, &tag);
entry = file_cache_hash_find(lfc_hash, &tag, hash);
if (entry != NULL)
{
for (int i = 0; i < lfc_blocks_per_chunk; i++)
for (int i = 0; i < BLOCKS_PER_CHUNK; i++)
{
if (GET_STATE(entry, i) == AVAILABLE)
{
@@ -990,7 +1006,7 @@ lfc_cache_contains(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno)
FileCacheEntry *entry;
int chunk_offs = BLOCK_TO_CHUNK_OFF(blkno);
bool found = false;
uint32 hash;
uint64 hash;
if (lfc_maybe_disabled()) /* fast exit if file cache is disabled */
return false;
@@ -1000,12 +1016,12 @@ lfc_cache_contains(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno)
tag.blockNum = blkno - chunk_offs;
CriticalAssert(BufTagGetRelNumber(&tag) != InvalidRelFileNumber);
hash = get_hash_value(lfc_hash, &tag);
hash = file_cache_hash_get_hash_value(lfc_hash, &tag);
LWLockAcquire(lfc_lock, LW_SHARED);
if (LFC_ENABLED())
{
entry = hash_search_with_hash_value(lfc_hash, &tag, hash, HASH_FIND, NULL);
entry = file_cache_hash_find(lfc_hash, &tag, hash);
found = entry != NULL && GET_STATE(entry, chunk_offs) != UNAVAILABLE;
}
LWLockRelease(lfc_lock);
@@ -1024,7 +1040,7 @@ lfc_cache_containsv(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno,
FileCacheEntry *entry;
uint32 chunk_offs;
int found = 0;
uint32 hash;
uint64 hash;
int i = 0;
if (lfc_maybe_disabled()) /* fast exit if file cache is disabled */
@@ -1037,7 +1053,7 @@ lfc_cache_containsv(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno,
chunk_offs = BLOCK_TO_CHUNK_OFF(blkno);
tag.blockNum = blkno - chunk_offs;
hash = get_hash_value(lfc_hash, &tag);
hash = file_cache_hash_get_hash_value(lfc_hash, &tag);
LWLockAcquire(lfc_lock, LW_SHARED);
@@ -1048,12 +1064,12 @@ lfc_cache_containsv(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno,
}
while (true)
{
int this_chunk = Min(nblocks - i, lfc_blocks_per_chunk - chunk_offs);
entry = hash_search_with_hash_value(lfc_hash, &tag, hash, HASH_FIND, NULL);
int this_chunk = Min(nblocks - i, BLOCKS_PER_CHUNK - chunk_offs);
entry = file_cache_hash_find(lfc_hash, &tag, hash);
if (entry != NULL)
{
for (; chunk_offs < lfc_blocks_per_chunk && i < nblocks; chunk_offs++, i++)
for (; chunk_offs < BLOCKS_PER_CHUNK && i < nblocks; chunk_offs++, i++)
{
if (GET_STATE(entry, chunk_offs) != UNAVAILABLE)
{
@@ -1079,7 +1095,7 @@ lfc_cache_containsv(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno,
*/
chunk_offs = BLOCK_TO_CHUNK_OFF(blkno + i);
tag.blockNum = (blkno + i) - chunk_offs;
hash = get_hash_value(lfc_hash, &tag);
hash = file_cache_hash_get_hash_value(lfc_hash, &tag);
}
LWLockRelease(lfc_lock);
@@ -1128,7 +1144,7 @@ lfc_readv_select(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno,
BufferTag tag;
FileCacheEntry *entry;
ssize_t rc;
uint32 hash;
uint64 hash;
uint64 generation;
uint32 entry_offset;
int blocks_read = 0;
@@ -1154,9 +1170,9 @@ lfc_readv_select(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno,
while (nblocks > 0)
{
struct iovec iov[PG_IOV_MAX];
uint8 chunk_mask[MAX_BLOCKS_PER_CHUNK / 8] = {0};
uint8 chunk_mask[BLOCKS_PER_CHUNK / 8] = {0};
int chunk_offs = BLOCK_TO_CHUNK_OFF(blkno);
int blocks_in_chunk = Min(nblocks, lfc_blocks_per_chunk - chunk_offs);
int blocks_in_chunk = Min(nblocks, BLOCKS_PER_CHUNK - chunk_offs);
int iteration_hits = 0;
int iteration_misses = 0;
uint64 io_time_us = 0;
@@ -1206,7 +1222,7 @@ lfc_readv_select(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno,
Assert(iov_last_used - first_block_in_chunk_read >= n_blocks_to_read);
tag.blockNum = blkno - chunk_offs;
hash = get_hash_value(lfc_hash, &tag);
hash = file_cache_hash_get_hash_value(lfc_hash, &tag);
cv = &lfc_ctl->cv[hash % N_COND_VARS];
LWLockAcquire(lfc_lock, LW_EXCLUSIVE);
@@ -1219,13 +1235,13 @@ lfc_readv_select(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno,
return blocks_read;
}
entry = hash_search_with_hash_value(lfc_hash, &tag, hash, HASH_FIND, NULL);
entry = file_cache_hash_find(lfc_hash, &tag, hash);
/* Approximate working set for the blocks assumed in this entry */
for (int i = 0; i < blocks_in_chunk; i++)
{
tag.blockNum = blkno + i;
addSHLL(&lfc_ctl->wss_estimation, hash_bytes((uint8_t const*)&tag, sizeof(tag)));
addSHLL(&lfc_ctl->wss_estimation, file_cache_hash_get_hash_value(lfc_hash, &tag));
}
if (entry == NULL)
@@ -1296,7 +1312,7 @@ lfc_readv_select(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno,
if (iteration_hits != 0)
{
/* chunk offset (# of pages) into the LFC file */
off_t first_read_offset = (off_t) entry_offset * lfc_blocks_per_chunk;
off_t first_read_offset = (off_t) entry_offset * BLOCKS_PER_CHUNK;
int nwrite = iov_last_used - first_block_in_chunk_read;
/* offset of first IOV */
first_read_offset += chunk_offs + first_block_in_chunk_read;
@@ -1373,14 +1389,14 @@ lfc_readv_select(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno,
* Returns false if there are no unpinned entries and chunk can not be added.
*/
static bool
lfc_init_new_entry(FileCacheEntry* entry, uint32 hash)
lfc_init_new_entry(FileCacheEntry *entry)
{
/*-----------
* If the chunk wasn't already in the LFC then we have these
* options, in order of preference:
*
* Unless there is no space available, we can:
* 1. Use an entry from the `holes` list, and
* 1. Use an entry from the freelist, and
* 2. Create a new entry.
* We can always, regardless of space in the LFC:
* 3. evict an entry from LRU, and
@@ -1388,17 +1404,10 @@ lfc_init_new_entry(FileCacheEntry* entry, uint32 hash)
*/
if (lfc_ctl->used < lfc_ctl->limit)
{
if (!dlist_is_empty(&lfc_ctl->holes))
if (!freelist_is_empty())
{
/* We can reuse a hole that was left behind when the LFC was shrunk previously */
FileCacheEntry *hole = dlist_container(FileCacheEntry, list_node,
dlist_pop_head_node(&lfc_ctl->holes));
uint32 offset = hole->offset;
bool hole_found;
hash_search_with_hash_value(lfc_hash, &hole->key,
hole->hash, HASH_REMOVE, &hole_found);
CriticalAssert(hole_found);
uint32 offset = freelist_pop();
lfc_ctl->used += 1;
entry->offset = offset; /* reuse the hole */
@@ -1427,7 +1436,7 @@ lfc_init_new_entry(FileCacheEntry* entry, uint32 hash)
FileCacheEntry *victim = dlist_container(FileCacheEntry, list_node,
dlist_pop_head_node(&lfc_ctl->lru));
for (int i = 0; i < lfc_blocks_per_chunk; i++)
for (int i = 0; i < BLOCKS_PER_CHUNK; i++)
{
bool is_page_cached = GET_STATE(victim, i) == AVAILABLE;
lfc_ctl->used_pages -= is_page_cached;
@@ -1436,24 +1445,21 @@ lfc_init_new_entry(FileCacheEntry* entry, uint32 hash)
CriticalAssert(victim->access_count == 0);
entry->offset = victim->offset; /* grab victim's chunk */
hash_search_with_hash_value(lfc_hash, &victim->key,
victim->hash, HASH_REMOVE, NULL);
file_cache_hash_remove_entry(lfc_hash, victim);
neon_log(DEBUG2, "Swap file cache page");
}
else
{
/* Can't add this chunk - we don't have the space for it */
hash_search_with_hash_value(lfc_hash, &entry->key, hash,
HASH_REMOVE, NULL);
file_cache_hash_remove_entry(lfc_hash, entry);
lfc_ctl->prewarm_canceled = true; /* cancel prewarm if LFC limit is reached */
return false;
}
entry->access_count = 1;
entry->hash = hash;
lfc_ctl->pinned += 1;
for (int i = 0; i < lfc_blocks_per_chunk; i++)
for (int i = 0; i < BLOCKS_PER_CHUNK; i++)
SET_STATE(entry, i, UNAVAILABLE);
return true;
@@ -1490,7 +1496,7 @@ lfc_prefetch(NRelFileInfo rinfo, ForkNumber forknum, BlockNumber blkno,
FileCacheEntry *entry;
ssize_t rc;
bool found;
uint32 hash;
uint64 hash;
uint64 generation;
uint32 entry_offset;
instr_time io_start, io_end;
@@ -1509,9 +1515,10 @@ lfc_prefetch(NRelFileInfo rinfo, ForkNumber forknum, BlockNumber blkno,
CriticalAssert(BufTagGetRelNumber(&tag) != InvalidRelFileNumber);
tag.blockNum = blkno - chunk_offs;
hash = get_hash_value(lfc_hash, &tag);
hash = file_cache_hash_get_hash_value(lfc_hash, &tag);
cv = &lfc_ctl->cv[hash % N_COND_VARS];
retry:
LWLockAcquire(lfc_lock, LW_EXCLUSIVE);
if (!LFC_ENABLED() || !lfc_ensure_opened())
@@ -1520,6 +1527,9 @@ lfc_prefetch(NRelFileInfo rinfo, ForkNumber forknum, BlockNumber blkno,
return false;
}
if (!freelist_prepare_pop())
goto retry;
lwlsn = neon_get_lwlsn(rinfo, forknum, blkno);
if (lwlsn > lsn)
@@ -1530,12 +1540,12 @@ lfc_prefetch(NRelFileInfo rinfo, ForkNumber forknum, BlockNumber blkno,
return false;
}
entry = hash_search_with_hash_value(lfc_hash, &tag, hash, HASH_ENTER, &found);
entry = file_cache_hash_enter(lfc_hash, &tag, hash, &found);
if (lfc_prewarm_update_ws_estimation)
{
tag.blockNum = blkno;
addSHLL(&lfc_ctl->wss_estimation, hash_bytes((uint8_t const*)&tag, sizeof(tag)));
addSHLL(&lfc_ctl->wss_estimation, file_cache_hash_get_hash_value(lfc_hash, &tag));
}
if (found)
{
@@ -1557,7 +1567,7 @@ lfc_prefetch(NRelFileInfo rinfo, ForkNumber forknum, BlockNumber blkno,
}
else
{
if (!lfc_init_new_entry(entry, hash))
if (!lfc_init_new_entry(entry))
{
/*
* We can't process this chunk due to lack of space in LFC,
@@ -1578,7 +1588,7 @@ lfc_prefetch(NRelFileInfo rinfo, ForkNumber forknum, BlockNumber blkno,
pgstat_report_wait_start(WAIT_EVENT_NEON_LFC_WRITE);
INSTR_TIME_SET_CURRENT(io_start);
rc = pwrite(lfc_desc, buffer, BLCKSZ,
((off_t) entry_offset * lfc_blocks_per_chunk + chunk_offs) * BLCKSZ);
((off_t) entry_offset * BLOCKS_PER_CHUNK + chunk_offs) * BLCKSZ);
INSTR_TIME_SET_CURRENT(io_end);
pgstat_report_wait_end();
@@ -1640,7 +1650,7 @@ lfc_writev(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno,
FileCacheEntry *entry;
ssize_t rc;
bool found;
uint32 hash;
uint64 hash;
uint64 generation;
uint32 entry_offset;
int buf_offset = 0;
@@ -1653,6 +1663,7 @@ lfc_writev(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno,
CriticalAssert(BufTagGetRelNumber(&tag) != InvalidRelFileNumber);
retry:
LWLockAcquire(lfc_lock, LW_EXCLUSIVE);
if (!LFC_ENABLED() || !lfc_ensure_opened())
@@ -1662,6 +1673,9 @@ lfc_writev(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno,
}
generation = lfc_ctl->generation;
if (!freelist_prepare_pop())
goto retry;
/*
* For every chunk that has blocks we're interested in, we
* 1. get the chunk header
@@ -1675,7 +1689,7 @@ lfc_writev(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno,
{
struct iovec iov[PG_IOV_MAX];
int chunk_offs = BLOCK_TO_CHUNK_OFF(blkno);
int blocks_in_chunk = Min(nblocks, lfc_blocks_per_chunk - chunk_offs);
int blocks_in_chunk = Min(nblocks, BLOCKS_PER_CHUNK - chunk_offs);
instr_time io_start, io_end;
ConditionVariable* cv;
@@ -1688,16 +1702,16 @@ lfc_writev(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno,
}
tag.blockNum = blkno - chunk_offs;
hash = get_hash_value(lfc_hash, &tag);
hash = file_cache_hash_get_hash_value(lfc_hash, &tag);
cv = &lfc_ctl->cv[hash % N_COND_VARS];
entry = hash_search_with_hash_value(lfc_hash, &tag, hash, HASH_ENTER, &found);
entry = file_cache_hash_enter(lfc_hash, &tag, hash, &found);
/* Approximate working set for the blocks assumed in this entry */
for (int i = 0; i < blocks_in_chunk; i++)
{
tag.blockNum = blkno + i;
addSHLL(&lfc_ctl->wss_estimation, hash_bytes((uint8_t const*)&tag, sizeof(tag)));
addSHLL(&lfc_ctl->wss_estimation, file_cache_hash_get_hash_value(lfc_hash, &tag));
}
if (found)
@@ -1714,7 +1728,7 @@ lfc_writev(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno,
}
else
{
if (!lfc_init_new_entry(entry, hash))
if (!lfc_init_new_entry(entry))
{
/*
* We can't process this chunk due to lack of space in LFC,
@@ -1763,7 +1777,7 @@ lfc_writev(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno,
pgstat_report_wait_start(WAIT_EVENT_NEON_LFC_WRITE);
INSTR_TIME_SET_CURRENT(io_start);
rc = pwritev(lfc_desc, iov, blocks_in_chunk,
((off_t) entry_offset * lfc_blocks_per_chunk + chunk_offs) * BLCKSZ);
((off_t) entry_offset * BLOCKS_PER_CHUNK + chunk_offs) * BLCKSZ);
INSTR_TIME_SET_CURRENT(io_end);
pgstat_report_wait_end();
@@ -1823,6 +1837,140 @@ lfc_writev(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno,
LWLockRelease(lfc_lock);
}
/**** freelist management ****/
/*
* Prerequisites:
* - The caller is holding 'lfc_lock'. XXX
*/
static bool
freelist_prepare_pop(void)
{
/*
* If the in-memory freelist is empty, but there are more blocks available, load them.
*
* TODO: if there
*/
if (lfc_ctl->num_free_pages == 0 && lfc_ctl->freelist_head != INVALID_OFFSET)
{
uint32 freelist_head;
FreeListChunk *freelist_chunk;
size_t bytes_read;
LWLockRelease(lfc_lock);
LWLockAcquire(lfc_freelist_lock, LW_EXCLUSIVE);
if (!(lfc_ctl->num_free_pages == 0 && lfc_ctl->freelist_head != INVALID_OFFSET))
{
/* someone else did the work for us while we were not holding the lock */
LWLockRelease(lfc_freelist_lock);
return false;
}
freelist_head = lfc_ctl->freelist_head;
freelist_chunk = palloc(BLOCKS_PER_CHUNK * BLCKSZ);
bytes_read = 0;
while (bytes_read < BLOCKS_PER_CHUNK * BLCKSZ)
{
ssize_t rc;
rc = pread(lfc_desc, freelist_chunk, BLOCKS_PER_CHUNK * BLCKSZ - bytes_read, (off_t) freelist_head * BLOCKS_PER_CHUNK * BLCKSZ + bytes_read);
if (rc < 0)
{
lfc_disable("read freelist page");
return false;
}
bytes_read += rc;
}
LWLockAcquire(lfc_lock, LW_EXCLUSIVE);
if (lfc_generation != lfc_ctl->generation)
{
LWLockRelease(lfc_lock);
return false;
}
Assert(lfc_ctl->freelist_head == freelist_head);
Assert(lfc_ctl->num_free_pages == 0);
lfc_ctl->freelist_head = freelist_chunk->next;
lfc_ctl->num_free_pages = freelist_chunk->num_free_pages;
memcpy(lfc_ctl->free_pages, freelist_chunk->free_pages, lfc_ctl->num_free_pages * sizeof(uint32));
pfree(freelist_chunk);
LWLockRelease(lfc_lock);
LWLockRelease(lfc_freelist_lock);
return false;
}
return true;
}
/*
* Prerequisites:
* - The caller is holding 'lfc_lock' and 'lfc_freelist_lock'.
*
* Returns 'false' on error.
*/
static bool
freelist_push(uint32 offset)
{
Assert(lfc_ctl->freelist_size == FREELIST_ENTRIES_PER_CHUNK);
if (lfc_ctl->num_free_pages == lfc_ctl->freelist_size)
{
FreeListChunk *freelist_chunk;
struct iovec iov;
ssize_t rc;
freelist_chunk = palloc(BLOCKS_PER_CHUNK * BLCKSZ);
/* write the existing entries to the chunk on disk */
freelist_chunk->next = lfc_ctl->freelist_head;
freelist_chunk->num_free_pages = lfc_ctl->num_free_pages;
memcpy(freelist_chunk->free_pages, lfc_ctl->free_pages, lfc_ctl->num_free_pages * sizeof(uint32));
/* Use the passed-in offset to hold the freelist chunk itself */
iov.iov_base = freelist_chunk;
iov.iov_len = BLOCKS_PER_CHUNK * BLCKSZ;
rc = pg_pwritev_with_retry(lfc_desc, &iov, 1, (off_t) offset * BLOCKS_PER_CHUNK * BLCKSZ);
pfree(freelist_chunk);
if (rc < 0)
return false;
lfc_ctl->freelist_head = offset;
lfc_ctl->num_free_pages = 0;
}
else
{
lfc_ctl->free_pages[lfc_ctl->num_free_pages] = offset;
lfc_ctl->num_free_pages++;
}
return true;
}
static uint32
freelist_pop(void)
{
uint32 result;
/* The caller should've checked that the list is not empty */
Assert(lfc_ctl->num_free_pages > 0);
result = lfc_ctl->free_pages[lfc_ctl->num_free_pages - 1];
lfc_ctl->num_free_pages--;
return result;
}
static bool
freelist_is_empty(void)
{
return lfc_ctl->num_free_pages == 0;
}
typedef struct
{
TupleDesc tupdesc;
@@ -1919,7 +2067,7 @@ neon_get_lfc_stats(PG_FUNCTION_ARGS)
break;
case 8:
key = "file_cache_chunk_size_pages";
value = lfc_blocks_per_chunk;
value = BLOCKS_PER_CHUNK;
break;
case 9:
key = "file_cache_chunks_pinned";
@@ -1990,7 +2138,6 @@ local_cache_pages(PG_FUNCTION_ARGS)
if (SRF_IS_FIRSTCALL())
{
HASH_SEQ_STATUS status;
FileCacheEntry *entry;
uint32 n_pages = 0;
@@ -2046,15 +2193,16 @@ local_cache_pages(PG_FUNCTION_ARGS)
if (LFC_ENABLED())
{
hash_seq_init(&status, lfc_hash);
while ((entry = hash_seq_search(&status)) != NULL)
uint32 num_buckets = file_cache_hash_get_num_buckets(lfc_hash);
for (uint32 pos = 0; pos < num_buckets; pos++)
{
/* Skip hole tags */
if (NInfoGetRelNumber(BufTagGetNRelFileInfo(entry->key)) != 0)
{
for (int i = 0; i < lfc_blocks_per_chunk; i++)
n_pages += GET_STATE(entry, i) == AVAILABLE;
}
entry = file_cache_hash_get_at_pos(lfc_hash, pos);
if (entry == NULL)
continue;
for (int i = 0; i < BLOCKS_PER_CHUNK; i++)
n_pages += GET_STATE(entry, i) == AVAILABLE;
}
}
}
@@ -2076,25 +2224,28 @@ local_cache_pages(PG_FUNCTION_ARGS)
* in the fctx->record structure.
*/
uint32 n = 0;
uint32 num_buckets = file_cache_hash_get_num_buckets(lfc_hash);
hash_seq_init(&status, lfc_hash);
while ((entry = hash_seq_search(&status)) != NULL)
for (uint32 pos = 0; pos < num_buckets; pos++)
{
for (int i = 0; i < lfc_blocks_per_chunk; i++)
entry = file_cache_hash_get_at_pos(lfc_hash, pos);
if (entry == NULL)
continue;
for (int i = 0; i < BLOCKS_PER_CHUNK; i++)
{
if (NInfoGetRelNumber(BufTagGetNRelFileInfo(entry->key)) != 0)
const BufferTag *key = file_cache_hash_get_key_for_entry(lfc_hash, entry);
if (GET_STATE(entry, i) == AVAILABLE)
{
if (GET_STATE(entry, i) == AVAILABLE)
{
fctx->record[n].pageoffs = entry->offset * lfc_blocks_per_chunk + i;
fctx->record[n].relfilenode = NInfoGetRelNumber(BufTagGetNRelFileInfo(entry->key));
fctx->record[n].reltablespace = NInfoGetSpcOid(BufTagGetNRelFileInfo(entry->key));
fctx->record[n].reldatabase = NInfoGetDbOid(BufTagGetNRelFileInfo(entry->key));
fctx->record[n].forknum = entry->key.forkNum;
fctx->record[n].blocknum = entry->key.blockNum + i;
fctx->record[n].accesscount = entry->access_count;
n += 1;
}
fctx->record[n].pageoffs = entry->offset * BLOCKS_PER_CHUNK + i;
fctx->record[n].relfilenode = NInfoGetRelNumber(BufTagGetNRelFileInfo(*key));
fctx->record[n].reltablespace = NInfoGetSpcOid(BufTagGetNRelFileInfo(*key));
fctx->record[n].reldatabase = NInfoGetDbOid(BufTagGetNRelFileInfo(*key));
fctx->record[n].forknum = key->forkNum;
fctx->record[n].blocknum = key->blockNum + i;
fctx->record[n].accesscount = entry->access_count;
n += 1;
}
}
}

View File

@@ -18,6 +18,11 @@ pub(super) async fn authenticate(
secret: AuthSecret,
) -> auth::Result<ComputeCredentials> {
let scram_keys = match secret {
#[cfg(any(test, feature = "testing"))]
AuthSecret::Md5(_) => {
debug!("auth endpoint chooses MD5");
return Err(auth::AuthError::MalformedPassword("MD5 not supported"));
}
AuthSecret::Scram(secret) => {
debug!("auth endpoint chooses SCRAM");

View File

@@ -6,9 +6,10 @@ use thiserror::Error;
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::{info, info_span};
use super::ComputeCredentialKeys;
use crate::auth::IpPattern;
use crate::auth::backend::ComputeUserInfo;
use crate::cache::Cached;
use crate::compute::AuthInfo;
use crate::config::AuthenticationConfig;
use crate::context::RequestContext;
use crate::control_plane::client::cplane_proxy_v1;
@@ -97,11 +98,15 @@ impl ConsoleRedirectBackend {
ctx: &RequestContext,
auth_config: &'static AuthenticationConfig,
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
) -> auth::Result<(ConsoleRedirectNodeInfo, AuthInfo, ComputeUserInfo)> {
) -> auth::Result<(
ConsoleRedirectNodeInfo,
ComputeUserInfo,
Option<Vec<IpPattern>>,
)> {
authenticate(ctx, auth_config, &self.console_uri, client)
.await
.map(|(node_info, auth_info, user_info)| {
(ConsoleRedirectNodeInfo(node_info), auth_info, user_info)
.map(|(node_info, user_info, ip_allowlist)| {
(ConsoleRedirectNodeInfo(node_info), user_info, ip_allowlist)
})
}
}
@@ -116,6 +121,10 @@ impl ComputeConnectBackend for ConsoleRedirectNodeInfo {
) -> Result<CachedNodeInfo, control_plane::errors::WakeComputeError> {
Ok(Cached::new_uncached(self.0.clone()))
}
fn get_keys(&self) -> &ComputeCredentialKeys {
&ComputeCredentialKeys::None
}
}
async fn authenticate(
@@ -123,7 +132,7 @@ async fn authenticate(
auth_config: &'static AuthenticationConfig,
link_uri: &reqwest::Url,
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
) -> auth::Result<(NodeInfo, AuthInfo, ComputeUserInfo)> {
) -> auth::Result<(NodeInfo, ComputeUserInfo, Option<Vec<IpPattern>>)> {
ctx.set_auth_method(crate::context::AuthMethod::ConsoleRedirect);
// registering waiter can fail if we get unlucky with rng.
@@ -183,24 +192,10 @@ async fn authenticate(
client.write_message(BeMessage::NoticeResponse("Connecting to database."));
// Backwards compatibility. pg_sni_proxy uses "--" in domain names
// while direct connections do not. Once we migrate to pg_sni_proxy
// everywhere, we can remove this.
let ssl_mode = if db_info.host.contains("--") {
// we need TLS connection with SNI info to properly route it
SslMode::Require
} else {
SslMode::Disable
};
let conn_info = compute::ConnectInfo {
host: db_info.host.into(),
port: db_info.port,
ssl_mode,
host_addr: None,
};
let auth_info =
AuthInfo::for_console_redirect(&db_info.dbname, &db_info.user, db_info.password.as_deref());
// This config should be self-contained, because we won't
// take username or dbname from client's startup message.
let mut config = compute::ConnCfg::new(db_info.host.to_string(), db_info.port);
config.dbname(&db_info.dbname).user(&db_info.user);
let user: RoleName = db_info.user.into();
let user_info = ComputeUserInfo {
@@ -214,12 +209,26 @@ async fn authenticate(
ctx.set_project(db_info.aux.clone());
info!("woken up a compute node");
// Backwards compatibility. pg_sni_proxy uses "--" in domain names
// while direct connections do not. Once we migrate to pg_sni_proxy
// everywhere, we can remove this.
if db_info.host.contains("--") {
// we need TLS connection with SNI info to properly route it
config.ssl_mode(SslMode::Require);
} else {
config.ssl_mode(SslMode::Disable);
}
if let Some(password) = db_info.password {
config.password(password.as_ref());
}
Ok((
NodeInfo {
conn_info,
config,
aux: db_info.aux,
},
auth_info,
user_info,
db_info.allowed_ips,
))
}

View File

@@ -1,12 +1,11 @@
use std::net::SocketAddr;
use arc_swap::ArcSwapOption;
use postgres_client::config::SslMode;
use tokio::sync::Semaphore;
use super::jwt::{AuthRule, FetchAuthRules};
use crate::auth::backend::jwt::FetchAuthRulesError;
use crate::compute::ConnectInfo;
use crate::compute::ConnCfg;
use crate::compute_ctl::ComputeCtlApi;
use crate::context::RequestContext;
use crate::control_plane::NodeInfo;
@@ -30,12 +29,7 @@ impl LocalBackend {
api: http::Endpoint::new(compute_ctl, http::new_client()),
},
node_info: NodeInfo {
conn_info: ConnectInfo {
host_addr: Some(postgres_addr.ip()),
host: postgres_addr.ip().to_string().into(),
port: postgres_addr.port(),
ssl_mode: SslMode::Disable,
},
config: ConnCfg::new(postgres_addr.ip().to_string(), postgres_addr.port()),
// TODO(conrad): make this better reflect compute info rather than endpoint info.
aux: MetricsAuxInfo {
endpoint_id: EndpointIdTag::get_interner().get_or_intern("local"),

View File

@@ -168,6 +168,8 @@ impl ComputeUserInfo {
#[cfg_attr(test, derive(Debug))]
pub(crate) enum ComputeCredentialKeys {
#[cfg(any(test, feature = "testing"))]
Password(Vec<u8>),
AuthKeys(AuthKeys),
JwtPayload(Vec<u8>),
None,
@@ -417,6 +419,13 @@ impl ComputeConnectBackend for Backend<'_, ComputeCredentials> {
Self::Local(local) => Ok(Cached::new_uncached(local.node_info.clone())),
}
}
fn get_keys(&self) -> &ComputeCredentialKeys {
match self {
Self::ControlPlane(_, creds) => &creds.keys,
Self::Local(_) => &ComputeCredentialKeys::None,
}
}
}
#[cfg(test)]

View File

@@ -169,6 +169,13 @@ pub(crate) async fn validate_password_and_exchange(
secret: AuthSecret,
) -> super::Result<sasl::Outcome<ComputeCredentialKeys>> {
match secret {
#[cfg(any(test, feature = "testing"))]
AuthSecret::Md5(_) => {
// test only
Ok(sasl::Outcome::Success(ComputeCredentialKeys::Password(
password.to_owned(),
)))
}
// perform scram authentication as both client and server to validate the keys
AuthSecret::Scram(scram_secret) => {
let outcome = crate::scram::exchange(pool, endpoint, &scram_secret, password).await?;

View File

@@ -18,7 +18,6 @@ use crate::types::{EndpointId, RoleName};
#[async_trait]
pub(crate) trait ProjectInfoCache {
fn invalidate_endpoint_access(&self, endpoint_id: EndpointIdInt);
fn invalidate_endpoint_access_for_project(&self, project_id: ProjectIdInt);
fn invalidate_endpoint_access_for_org(&self, account_id: AccountIdInt);
fn invalidate_role_secret_for_project(&self, project_id: ProjectIdInt, role_name: RoleNameInt);
@@ -101,13 +100,6 @@ pub struct ProjectInfoCacheImpl {
#[async_trait]
impl ProjectInfoCache for ProjectInfoCacheImpl {
fn invalidate_endpoint_access(&self, endpoint_id: EndpointIdInt) {
info!("invalidating endpoint access for `{endpoint_id}`");
if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) {
endpoint_info.invalidate_endpoint();
}
}
fn invalidate_endpoint_access_for_project(&self, project_id: ProjectIdInt) {
info!("invalidating endpoint access for project `{project_id}`");
let endpoints = self

View File

@@ -24,6 +24,7 @@ use crate::pqproto::CancelKeyData;
use crate::rate_limiter::LeakyBucketRateLimiter;
use crate::redis::keys::KeyPrefix;
use crate::redis::kv_ops::RedisKVClient;
use crate::tls::postgres_rustls::MakeRustlsConnect;
type IpSubnetKey = IpNet;
@@ -496,8 +497,10 @@ impl CancelClosure {
) -> Result<(), CancelError> {
let socket = TcpStream::connect(self.socket_addr).await?;
let tls = <_ as MakeTlsConnect<tokio::net::TcpStream>>::make_tls_connect(
compute_config,
let mut mk_tls =
crate::tls::postgres_rustls::MakeRustlsConnect::new(compute_config.tls.clone());
let tls = <MakeRustlsConnect as MakeTlsConnect<tokio::net::TcpStream>>::make_tls_connect(
&mut mk_tls,
&self.hostname,
)
.map_err(|e| CancelError::IO(std::io::Error::other(e.to_string())))?;

View File

@@ -1,24 +1,21 @@
mod tls;
use std::fmt::Debug;
use std::io;
use std::net::{IpAddr, SocketAddr};
use std::net::SocketAddr;
use std::time::Duration;
use futures::{FutureExt, TryFutureExt};
use itertools::Itertools;
use postgres_client::config::{AuthKeys, SslMode};
use postgres_client::maybe_tls_stream::MaybeTlsStream;
use postgres_client::tls::MakeTlsConnect;
use postgres_client::{CancelToken, NoTls, RawConnection};
use postgres_client::{CancelToken, RawConnection};
use postgres_protocol::message::backend::NoticeResponseBody;
use rustls::pki_types::InvalidDnsNameError;
use thiserror::Error;
use tokio::net::{TcpStream, lookup_host};
use tracing::{debug, error, info, warn};
use crate::auth::backend::{ComputeCredentialKeys, ComputeUserInfo};
use crate::auth::backend::ComputeUserInfo;
use crate::auth::parse_endpoint_param;
use crate::cancellation::CancelClosure;
use crate::compute::tls::TlsError;
use crate::config::ComputeConfig;
use crate::context::RequestContext;
use crate::control_plane::client::ApiLockError;
@@ -28,6 +25,7 @@ use crate::error::{ReportableError, UserFacingError};
use crate::metrics::{Metrics, NumDbConnectionsGuard};
use crate::pqproto::StartupMessageParams;
use crate::proxy::neon_option;
use crate::tls::postgres_rustls::MakeRustlsConnect;
use crate::types::Host;
pub const COULD_NOT_CONNECT: &str = "Couldn't connect to compute node";
@@ -40,7 +38,10 @@ pub(crate) enum ConnectionError {
Postgres(#[from] postgres_client::Error),
#[error("{COULD_NOT_CONNECT}: {0}")]
TlsError(#[from] TlsError),
CouldNotConnect(#[from] io::Error),
#[error("{COULD_NOT_CONNECT}: {0}")]
TlsError(#[from] InvalidDnsNameError),
#[error("{COULD_NOT_CONNECT}: {0}")]
WakeComputeError(#[from] WakeComputeError),
@@ -72,7 +73,7 @@ impl UserFacingError for ConnectionError {
ConnectionError::TooManyConnectionAttempts(_) => {
"Failed to acquire permit to connect to the database. Too many database connection attempts are currently ongoing.".to_owned()
}
ConnectionError::TlsError(_) => COULD_NOT_CONNECT.to_owned(),
_ => COULD_NOT_CONNECT.to_owned(),
}
}
}
@@ -84,6 +85,7 @@ impl ReportableError for ConnectionError {
crate::error::ErrorKind::Postgres
}
ConnectionError::Postgres(_) => crate::error::ErrorKind::Compute,
ConnectionError::CouldNotConnect(_) => crate::error::ErrorKind::Compute,
ConnectionError::TlsError(_) => crate::error::ErrorKind::Compute,
ConnectionError::WakeComputeError(e) => e.get_error_kind(),
ConnectionError::TooManyConnectionAttempts(e) => e.get_error_kind(),
@@ -94,85 +96,34 @@ impl ReportableError for ConnectionError {
/// A pair of `ClientKey` & `ServerKey` for `SCRAM-SHA-256`.
pub(crate) type ScramKeys = postgres_client::config::ScramKeys<32>;
/// A config for establishing a connection to compute node.
/// Eventually, `postgres_client` will be replaced with something better.
/// Newtype allows us to implement methods on top of it.
#[derive(Clone)]
pub enum Auth {
/// Only used during console-redirect.
Password(Vec<u8>),
/// Used by sql-over-http, ws, tcp.
Scram(Box<ScramKeys>),
}
/// A config for authenticating to the compute node.
pub(crate) struct AuthInfo {
/// None for local-proxy, as we use trust-based localhost auth.
/// Some for sql-over-http, ws, tcp, and in most cases for console-redirect.
/// Might be None for console-redirect, but that's only a consequence of testing environments ATM.
auth: Option<Auth>,
server_params: StartupMessageParams,
/// Console redirect sets user and database, we shouldn't re-use those from the params.
skip_db_user: bool,
}
/// Contains only the data needed to establish a secure connection to compute.
#[derive(Clone)]
pub struct ConnectInfo {
pub host_addr: Option<IpAddr>,
pub host: Host,
pub port: u16,
pub ssl_mode: SslMode,
}
pub(crate) struct ConnCfg(Box<postgres_client::Config>);
/// Creation and initialization routines.
impl AuthInfo {
pub(crate) fn for_console_redirect(db: &str, user: &str, pw: Option<&str>) -> Self {
let mut server_params = StartupMessageParams::default();
server_params.insert("database", db);
server_params.insert("user", user);
Self {
auth: pw.map(|pw| Auth::Password(pw.as_bytes().to_owned())),
server_params,
skip_db_user: true,
impl ConnCfg {
pub(crate) fn new(host: String, port: u16) -> Self {
Self(Box::new(postgres_client::Config::new(host, port)))
}
/// Reuse password or auth keys from the other config.
pub(crate) fn reuse_password(&mut self, other: Self) {
if let Some(password) = other.get_password() {
self.password(password);
}
if let Some(keys) = other.get_auth_keys() {
self.auth_keys(keys);
}
}
pub(crate) fn with_auth_keys(keys: &ComputeCredentialKeys) -> Self {
Self {
auth: match keys {
ComputeCredentialKeys::AuthKeys(AuthKeys::ScramSha256(auth_keys)) => {
Some(Auth::Scram(Box::new(*auth_keys)))
}
ComputeCredentialKeys::JwtPayload(_) | ComputeCredentialKeys::None => None,
},
server_params: StartupMessageParams::default(),
skip_db_user: false,
pub(crate) fn get_host(&self) -> Host {
match self.0.get_host() {
postgres_client::config::Host::Tcp(s) => s.into(),
}
}
}
impl ConnectInfo {
pub fn to_postgres_client_config(&self) -> postgres_client::Config {
let mut config = postgres_client::Config::new(self.host.to_string(), self.port);
config.ssl_mode(self.ssl_mode);
if let Some(host_addr) = self.host_addr {
config.set_host_addr(host_addr);
}
config
}
}
impl AuthInfo {
fn enrich(&self, mut config: postgres_client::Config) -> postgres_client::Config {
match &self.auth {
Some(Auth::Scram(keys)) => config.auth_keys(AuthKeys::ScramSha256(**keys)),
Some(Auth::Password(pw)) => config.password(pw),
None => &mut config,
};
for (k, v) in self.server_params.iter() {
config.set_param(k, v);
}
config
}
/// Apply startup message params to the connection config.
pub(crate) fn set_startup_params(
@@ -181,26 +132,27 @@ impl AuthInfo {
arbitrary_params: bool,
) {
if !arbitrary_params {
self.server_params.insert("client_encoding", "UTF8");
self.set_param("client_encoding", "UTF8");
}
for (k, v) in params.iter() {
match k {
// Only set `user` if it's not present in the config.
// Console redirect auth flow takes username from the console's response.
"user" | "database" if self.skip_db_user => {}
"user" if self.user_is_set() => {}
"database" if self.db_is_set() => {}
"options" => {
if let Some(options) = filtered_options(v) {
self.server_params.insert(k, &options);
self.set_param(k, &options);
}
}
"user" | "database" | "application_name" | "replication" => {
self.server_params.insert(k, v);
self.set_param(k, v);
}
// if we allow arbitrary params, then we forward them through.
// this is a flag for a period of backwards compatibility
k if arbitrary_params => {
self.server_params.insert(k, v);
self.set_param(k, v);
}
_ => {}
}
@@ -208,13 +160,25 @@ impl AuthInfo {
}
}
impl ConnectInfo {
/// Establish a raw TCP+TLS connection to the compute node.
async fn connect_raw(
&self,
config: &ComputeConfig,
) -> Result<(SocketAddr, MaybeTlsStream<TcpStream, RustlsStream>), TlsError> {
let timeout = config.timeout;
impl std::ops::Deref for ConnCfg {
type Target = postgres_client::Config;
fn deref(&self) -> &Self::Target {
&self.0
}
}
/// For now, let's make it easier to setup the config.
impl std::ops::DerefMut for ConnCfg {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}
impl ConnCfg {
/// Establish a raw TCP connection to the compute node.
async fn connect_raw(&self, timeout: Duration) -> io::Result<(SocketAddr, TcpStream, &str)> {
use postgres_client::config::Host;
// wrap TcpStream::connect with timeout
let connect_with_timeout = |addrs| {
@@ -244,32 +208,34 @@ impl ConnectInfo {
// We can't reuse connection establishing logic from `postgres_client` here,
// because it has no means for extracting the underlying socket which we
// require for our business.
let port = self.port;
let host = &*self.host;
let port = self.0.get_port();
let host = self.0.get_host();
let addrs = match self.host_addr {
let host = match host {
Host::Tcp(host) => host.as_str(),
};
let addrs = match self.0.get_host_addr() {
Some(addr) => vec![SocketAddr::new(addr, port)],
None => lookup_host((host, port)).await?.collect(),
};
match connect_once(&*addrs).await {
Ok((sockaddr, stream)) => Ok((
sockaddr,
tls::connect_tls(stream, self.ssl_mode, config, host).await?,
)),
Ok((sockaddr, stream)) => Ok((sockaddr, stream, host)),
Err(err) => {
warn!("couldn't connect to compute node at {host}:{port}: {err}");
Err(TlsError::Connection(err))
Err(err)
}
}
}
}
type RustlsStream = <ComputeConfig as MakeTlsConnect<tokio::net::TcpStream>>::Stream;
type RustlsStream = <MakeRustlsConnect as MakeTlsConnect<tokio::net::TcpStream>>::Stream;
pub(crate) struct PostgresConnection {
/// Socket connected to a compute node.
pub(crate) stream: MaybeTlsStream<tokio::net::TcpStream, RustlsStream>,
pub(crate) stream:
postgres_client::maybe_tls_stream::MaybeTlsStream<tokio::net::TcpStream, RustlsStream>,
/// PostgreSQL connection parameters.
pub(crate) params: std::collections::HashMap<String, String>,
/// Query cancellation token.
@@ -282,23 +248,28 @@ pub(crate) struct PostgresConnection {
_guage: NumDbConnectionsGuard<'static>,
}
impl ConnectInfo {
impl ConnCfg {
/// Connect to a corresponding compute node.
pub(crate) async fn connect(
&self,
ctx: &RequestContext,
aux: MetricsAuxInfo,
auth: &AuthInfo,
config: &ComputeConfig,
user_info: ComputeUserInfo,
) -> Result<PostgresConnection, ConnectionError> {
let mut tmp_config = auth.enrich(self.to_postgres_client_config());
// we setup SSL early in `ConnectInfo::connect_raw`.
tmp_config.ssl_mode(SslMode::Disable);
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
let (socket_addr, stream) = self.connect_raw(config).await?;
let connection = tmp_config.connect_raw(stream, NoTls).await?;
let (socket_addr, stream, host) = self.connect_raw(config.timeout).await?;
drop(pause);
let mut mk_tls = crate::tls::postgres_rustls::MakeRustlsConnect::new(config.tls.clone());
let tls = <MakeRustlsConnect as MakeTlsConnect<tokio::net::TcpStream>>::make_tls_connect(
&mut mk_tls,
host,
)?;
// connect_raw() will not use TLS if sslmode is "disable"
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
let connection = self.0.connect_raw(stream, tls).await?;
drop(pause);
let RawConnection {
@@ -311,14 +282,13 @@ impl ConnectInfo {
tracing::Span::current().record("pid", tracing::field::display(process_id));
tracing::Span::current().record("compute_id", tracing::field::display(&aux.compute_id));
let MaybeTlsStream::Raw(stream) = stream.into_inner();
let stream = stream.into_inner();
// TODO: lots of useful info but maybe we can move it elsewhere (eg traces?)
info!(
cold_start_info = ctx.cold_start_info().as_str(),
"connected to compute node at {} ({socket_addr}) sslmode={:?}, latency={}, query_id={}",
self.host,
self.ssl_mode,
"connected to compute node at {host} ({socket_addr}) sslmode={:?}, latency={}, query_id={}",
self.0.get_ssl_mode(),
ctx.get_proxy_latency(),
ctx.get_testodrome_id().unwrap_or_default(),
);
@@ -329,11 +299,11 @@ impl ConnectInfo {
socket_addr,
CancelToken {
socket_config: None,
ssl_mode: self.ssl_mode,
ssl_mode: self.0.get_ssl_mode(),
process_id,
secret_key,
},
self.host.to_string(),
host.to_string(),
user_info,
);

View File

@@ -1,63 +0,0 @@
use futures::FutureExt;
use postgres_client::config::SslMode;
use postgres_client::maybe_tls_stream::MaybeTlsStream;
use postgres_client::tls::{MakeTlsConnect, TlsConnect};
use rustls::pki_types::InvalidDnsNameError;
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncWrite};
use crate::pqproto::request_tls;
use crate::proxy::retry::CouldRetry;
#[derive(Debug, Error)]
pub enum TlsError {
#[error(transparent)]
Dns(#[from] InvalidDnsNameError),
#[error(transparent)]
Connection(#[from] std::io::Error),
#[error("TLS required but not provided")]
Required,
}
impl CouldRetry for TlsError {
fn could_retry(&self) -> bool {
match self {
TlsError::Dns(_) => false,
TlsError::Connection(err) => err.could_retry(),
// perhaps compute didn't realise it supports TLS?
TlsError::Required => true,
}
}
}
pub async fn connect_tls<S, T>(
mut stream: S,
mode: SslMode,
tls: &T,
host: &str,
) -> Result<MaybeTlsStream<S, T::Stream>, TlsError>
where
S: AsyncRead + AsyncWrite + Unpin + Send,
T: MakeTlsConnect<
S,
Error = InvalidDnsNameError,
TlsConnect: TlsConnect<S, Error = std::io::Error, Future: Send>,
>,
{
match mode {
SslMode::Disable => return Ok(MaybeTlsStream::Raw(stream)),
SslMode::Prefer | SslMode::Require => {}
}
if !request_tls(&mut stream).await? {
if SslMode::Require == mode {
return Err(TlsError::Required);
}
return Ok(MaybeTlsStream::Raw(stream));
}
Ok(MaybeTlsStream::Tls(
tls.make_tls_connect(host)?.connect(stream).boxed().await?,
))
}

View File

@@ -210,20 +210,20 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
ctx.set_db_options(params.clone());
let (node_info, mut auth_info, user_info) = match backend
let (node_info, user_info, _ip_allowlist) = match backend
.authenticate(ctx, &config.authentication_config, &mut stream)
.await
{
Ok(auth_result) => auth_result,
Err(e) => Err(stream.throw_error(e, Some(ctx)).await)?,
};
auth_info.set_startup_params(&params, true);
let node = connect_to_compute(
ctx,
&TcpMechanism {
user_info,
auth: auth_info,
params_compat: true,
params: &params,
locks: &config.connect_compute_locks,
},
&node_info,

View File

@@ -261,18 +261,24 @@ impl NeonControlPlaneClient {
Some(_) => SslMode::Require,
None => SslMode::Disable,
};
let host = match body.server_name {
Some(host) => host.into(),
None => host.into(),
let host_name = match body.server_name {
Some(host) => host,
None => host.to_owned(),
};
// Don't set anything but host and port! This config will be cached.
// We'll set username and such later using the startup message.
// TODO: add more type safety (in progress).
let mut config = compute::ConnCfg::new(host_name, port);
if let Some(addr) = host_addr {
config.set_host_addr(addr);
}
config.ssl_mode(ssl_mode);
let node = NodeInfo {
conn_info: compute::ConnectInfo {
host_addr,
host,
port,
ssl_mode,
},
config,
aux: body.aux,
};

View File

@@ -6,7 +6,6 @@ use std::str::FromStr;
use std::sync::Arc;
use futures::TryFutureExt;
use postgres_client::config::SslMode;
use thiserror::Error;
use tokio_postgres::Client;
use tracing::{Instrument, error, info, info_span, warn};
@@ -15,7 +14,6 @@ use crate::auth::IpPattern;
use crate::auth::backend::ComputeUserInfo;
use crate::auth::backend::jwt::AuthRule;
use crate::cache::Cached;
use crate::compute::ConnectInfo;
use crate::context::RequestContext;
use crate::control_plane::errors::{
ControlPlaneError, GetAuthInfoError, GetEndpointJwksError, WakeComputeError,
@@ -26,9 +24,9 @@ use crate::control_plane::{
RoleAccessControl,
};
use crate::intern::RoleNameInt;
use crate::scram;
use crate::types::{BranchId, EndpointId, ProjectId, RoleName};
use crate::url::ApiUrl;
use crate::{compute, scram};
#[derive(Debug, Error)]
enum MockApiError {
@@ -89,7 +87,8 @@ impl MockControlPlane {
.await?
{
info!("got a secret: {entry}"); // safe since it's not a prod scenario
scram::ServerSecret::parse(&entry).map(AuthSecret::Scram)
let secret = scram::ServerSecret::parse(&entry).map(AuthSecret::Scram);
secret.or_else(|| parse_md5(&entry).map(AuthSecret::Md5))
} else {
warn!("user '{role}' does not exist");
None
@@ -171,23 +170,25 @@ impl MockControlPlane {
async fn do_wake_compute(&self) -> Result<NodeInfo, WakeComputeError> {
let port = self.endpoint.port().unwrap_or(5432);
let conn_info = match self.endpoint.host_str() {
None => ConnectInfo {
host_addr: Some(IpAddr::V4(Ipv4Addr::LOCALHOST)),
host: "localhost".into(),
port,
ssl_mode: SslMode::Disable,
},
Some(host) => ConnectInfo {
host_addr: IpAddr::from_str(host).ok(),
host: host.into(),
port,
ssl_mode: SslMode::Disable,
},
let mut config = match self.endpoint.host_str() {
None => {
let mut config = compute::ConnCfg::new("localhost".to_string(), port);
config.set_host_addr(IpAddr::V4(Ipv4Addr::LOCALHOST));
config
}
Some(host) => {
let mut config = compute::ConnCfg::new(host.to_string(), port);
if let Ok(addr) = IpAddr::from_str(host) {
config.set_host_addr(addr);
}
config
}
};
config.ssl_mode(postgres_client::config::SslMode::Disable);
let node = NodeInfo {
conn_info,
config,
aux: MetricsAuxInfo {
endpoint_id: (&EndpointId::from("endpoint")).into(),
project_id: (&ProjectId::from("project")).into(),
@@ -265,3 +266,12 @@ impl super::ControlPlaneApi for MockControlPlane {
self.do_wake_compute().map_ok(Cached::new_uncached).await
}
}
fn parse_md5(input: &str) -> Option<[u8; 16]> {
let text = input.strip_prefix("md5")?;
let mut bytes = [0u8; 16];
hex::decode_to_slice(text, &mut bytes).ok()?;
Some(bytes)
}

View File

@@ -11,8 +11,8 @@ pub(crate) mod errors;
use std::sync::Arc;
use crate::auth::backend::ComputeUserInfo;
use crate::auth::backend::jwt::AuthRule;
use crate::auth::backend::{ComputeCredentialKeys, ComputeUserInfo};
use crate::auth::{AuthError, IpPattern, check_peer_addr_is_in_list};
use crate::cache::{Cached, TimedLru};
use crate::config::ComputeConfig;
@@ -39,6 +39,10 @@ pub mod mgmt;
/// Auth secret which is managed by the cloud.
#[derive(Clone, Eq, PartialEq, Debug)]
pub(crate) enum AuthSecret {
#[cfg(any(test, feature = "testing"))]
/// Md5 hash of user's password.
Md5([u8; 16]),
/// [SCRAM](crate::scram) authentication info.
Scram(scram::ServerSecret),
}
@@ -59,9 +63,13 @@ pub(crate) struct AuthInfo {
}
/// Info for establishing a connection to a compute node.
/// This is what we get after auth succeeded, but not before!
#[derive(Clone)]
pub(crate) struct NodeInfo {
pub(crate) conn_info: compute::ConnectInfo,
/// Compute node connection params.
/// It's sad that we have to clone this, but this will improve
/// once we migrate to a bespoke connection logic.
pub(crate) config: compute::ConnCfg,
/// Labels for proxy's metrics.
pub(crate) aux: MetricsAuxInfo,
@@ -71,14 +79,26 @@ impl NodeInfo {
pub(crate) async fn connect(
&self,
ctx: &RequestContext,
auth: &compute::AuthInfo,
config: &ComputeConfig,
user_info: ComputeUserInfo,
) -> Result<compute::PostgresConnection, compute::ConnectionError> {
self.conn_info
.connect(ctx, self.aux.clone(), auth, config, user_info)
self.config
.connect(ctx, self.aux.clone(), config, user_info)
.await
}
pub(crate) fn reuse_settings(&mut self, other: Self) {
self.config.reuse_password(other.config);
}
pub(crate) fn set_keys(&mut self, keys: &ComputeCredentialKeys) {
match keys {
#[cfg(any(test, feature = "testing"))]
ComputeCredentialKeys::Password(password) => self.config.password(password),
ComputeCredentialKeys::AuthKeys(auth_keys) => self.config.auth_keys(*auth_keys),
ComputeCredentialKeys::JwtPayload(_) | ComputeCredentialKeys::None => &mut self.config,
};
}
}
#[derive(Copy, Clone, Default)]

View File

@@ -610,11 +610,11 @@ pub enum RedisEventsCount {
BranchCreated,
ProjectCreated,
CancelSession,
InvalidateRole,
InvalidateEndpoint,
InvalidateProject,
InvalidateProjects,
InvalidateOrg,
PasswordUpdate,
AllowedIpsUpdate,
AllowedVpcEndpointIdsUpdateForProjects,
AllowedVpcEndpointIdsUpdateForAllProjectsInOrg,
BlockPublicOrVpcAccessUpdate,
}
pub struct ThreadPoolWorkers(usize);

View File

@@ -2,8 +2,8 @@ use async_trait::async_trait;
use tokio::time;
use tracing::{debug, info, warn};
use crate::auth::backend::ComputeUserInfo;
use crate::compute::{self, AuthInfo, COULD_NOT_CONNECT, PostgresConnection};
use crate::auth::backend::{ComputeCredentialKeys, ComputeUserInfo};
use crate::compute::{self, COULD_NOT_CONNECT, PostgresConnection};
use crate::config::{ComputeConfig, RetryConfig};
use crate::context::RequestContext;
use crate::control_plane::errors::WakeComputeError;
@@ -13,6 +13,7 @@ use crate::error::ReportableError;
use crate::metrics::{
ConnectOutcome, ConnectionFailureKind, Metrics, RetriesMetricGroup, RetryType,
};
use crate::pqproto::StartupMessageParams;
use crate::proxy::retry::{CouldRetry, ShouldRetryWakeCompute, retry_after, should_retry};
use crate::proxy::wake_compute::wake_compute;
use crate::types::Host;
@@ -47,6 +48,8 @@ pub(crate) trait ConnectMechanism {
node_info: &control_plane::CachedNodeInfo,
config: &ComputeConfig,
) -> Result<Self::Connection, Self::ConnectError>;
fn update_connect_config(&self, conf: &mut compute::ConnCfg);
}
#[async_trait]
@@ -55,17 +58,24 @@ pub(crate) trait ComputeConnectBackend {
&self,
ctx: &RequestContext,
) -> Result<CachedNodeInfo, control_plane::errors::WakeComputeError>;
fn get_keys(&self) -> &ComputeCredentialKeys;
}
pub(crate) struct TcpMechanism {
pub(crate) auth: AuthInfo,
pub(crate) struct TcpMechanism<'a> {
pub(crate) params_compat: bool,
/// KV-dictionary with PostgreSQL connection params.
pub(crate) params: &'a StartupMessageParams,
/// connect_to_compute concurrency lock
pub(crate) locks: &'static ApiLocks<Host>,
pub(crate) user_info: ComputeUserInfo,
}
#[async_trait]
impl ConnectMechanism for TcpMechanism {
impl ConnectMechanism for TcpMechanism<'_> {
type Connection = PostgresConnection;
type ConnectError = compute::ConnectionError;
type Error = compute::ConnectionError;
@@ -80,12 +90,13 @@ impl ConnectMechanism for TcpMechanism {
node_info: &control_plane::CachedNodeInfo,
config: &ComputeConfig,
) -> Result<PostgresConnection, Self::Error> {
let permit = self.locks.get_permit(&node_info.conn_info.host).await?;
permit.release_result(
node_info
.connect(ctx, &self.auth, config, self.user_info.clone())
.await,
)
let host = node_info.config.get_host();
let permit = self.locks.get_permit(&host).await?;
permit.release_result(node_info.connect(ctx, config, self.user_info.clone()).await)
}
fn update_connect_config(&self, config: &mut compute::ConnCfg) {
config.set_startup_params(self.params, self.params_compat);
}
}
@@ -103,9 +114,12 @@ where
M::Error: From<WakeComputeError>,
{
let mut num_retries = 0;
let node_info =
let mut node_info =
wake_compute(&mut num_retries, ctx, user_info, wake_compute_retry_config).await?;
node_info.set_keys(user_info.get_keys());
mechanism.update_connect_config(&mut node_info.config);
// try once
let err = match mechanism.connect_once(ctx, &node_info, compute).await {
Ok(res) => {
@@ -141,9 +155,14 @@ where
} else {
// if we failed to connect, it's likely that the compute node was suspended, wake a new compute node
debug!("compute node's state has likely changed; requesting a wake-up");
invalidate_cache(node_info);
let old_node_info = invalidate_cache(node_info);
// TODO: increment num_retries?
wake_compute(&mut num_retries, ctx, user_info, wake_compute_retry_config).await?
let mut node_info =
wake_compute(&mut num_retries, ctx, user_info, wake_compute_retry_config).await?;
node_info.reuse_settings(old_node_info);
mechanism.update_connect_config(&mut node_info.config);
node_info
};
// now that we have a new node, try connect to it repeatedly.

View File

@@ -8,7 +8,7 @@ use std::io::{self, Cursor};
use bytes::{Buf, BufMut};
use itertools::Itertools;
use rand::distributions::{Distribution, Standard};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio::io::{AsyncRead, AsyncReadExt};
use zerocopy::{FromBytes, Immutable, IntoBytes, big_endian};
pub type ErrorCode = [u8; 5];
@@ -53,28 +53,6 @@ impl fmt::Debug for ProtocolVersion {
}
}
/// <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/include/libpq/pqcomm.h#L118>
const MAX_STARTUP_PACKET_LENGTH: usize = 10000;
const RESERVED_INVALID_MAJOR_VERSION: u16 = 1234;
/// <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/include/libpq/pqcomm.h#L132>
const CANCEL_REQUEST_CODE: ProtocolVersion = ProtocolVersion::new(1234, 5678);
/// <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/include/libpq/pqcomm.h#L166>
const NEGOTIATE_SSL_CODE: ProtocolVersion = ProtocolVersion::new(1234, 5679);
/// <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/include/libpq/pqcomm.h#L167>
const NEGOTIATE_GSS_CODE: ProtocolVersion = ProtocolVersion::new(1234, 5680);
/// This first reads the startup message header, is 8 bytes.
/// The first 4 bytes is a big-endian message length, and the next 4 bytes is a version number.
///
/// The length value is inclusive of the header. For example,
/// an empty message will always have length 8.
#[derive(Clone, Copy, FromBytes, IntoBytes, Immutable)]
#[repr(C)]
struct StartupHeader {
len: big_endian::U32,
version: ProtocolVersion,
}
/// read the type from the stream using zerocopy.
///
/// not cancel safe.
@@ -88,38 +66,32 @@ macro_rules! read {
}};
}
/// Returns true if TLS is supported.
///
/// This is not cancel safe.
pub async fn request_tls<S>(stream: &mut S) -> io::Result<bool>
where
S: AsyncRead + AsyncWrite + Unpin,
{
let payload = StartupHeader {
len: 8.into(),
version: NEGOTIATE_SSL_CODE,
};
stream.write_all(payload.as_bytes()).await?;
stream.flush().await?;
// we expect back either `S` or `N` as a single byte.
let mut res = *b"0";
stream.read_exact(&mut res).await?;
debug_assert!(
res == *b"S" || res == *b"N",
"unexpected SSL negotiation response: {}",
char::from(res[0]),
);
// S for SSL.
Ok(res == *b"S")
}
pub async fn read_startup<S>(stream: &mut S) -> io::Result<FeStartupPacket>
where
S: AsyncRead + Unpin,
{
/// <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/include/libpq/pqcomm.h#L118>
const MAX_STARTUP_PACKET_LENGTH: usize = 10000;
const RESERVED_INVALID_MAJOR_VERSION: u16 = 1234;
/// <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/include/libpq/pqcomm.h#L132>
const CANCEL_REQUEST_CODE: ProtocolVersion = ProtocolVersion::new(1234, 5678);
/// <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/include/libpq/pqcomm.h#L166>
const NEGOTIATE_SSL_CODE: ProtocolVersion = ProtocolVersion::new(1234, 5679);
/// <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/include/libpq/pqcomm.h#L167>
const NEGOTIATE_GSS_CODE: ProtocolVersion = ProtocolVersion::new(1234, 5680);
/// This first reads the startup message header, is 8 bytes.
/// The first 4 bytes is a big-endian message length, and the next 4 bytes is a version number.
///
/// The length value is inclusive of the header. For example,
/// an empty message will always have length 8.
#[derive(Clone, Copy, FromBytes, IntoBytes, Immutable)]
#[repr(C)]
struct StartupHeader {
len: big_endian::U32,
version: ProtocolVersion,
}
let header = read!(stream => StartupHeader);
// <https://github.com/postgres/postgres/blob/04bcf9e19a4261fe9c7df37c777592c2e10c32a7/src/backend/tcop/backend_startup.c#L378-L382>
@@ -592,9 +564,10 @@ mod tests {
use tokio::io::{AsyncWriteExt, duplex};
use zerocopy::IntoBytes;
use super::ProtocolVersion;
use crate::pqproto::{FeStartupPacket, read_message, read_startup};
use super::ProtocolVersion;
#[tokio::test]
async fn reject_large_startup() {
// we're going to define a v3.0 startup message with far too many parameters.

View File

@@ -358,19 +358,21 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
}
};
let creds = match &user_info {
auth::Backend::ControlPlane(_, creds) => creds,
let compute_user_info = match &user_info {
auth::Backend::ControlPlane(_, info) => &info.info,
auth::Backend::Local(_) => unreachable!("local proxy does not run tcp proxy service"),
};
let params_compat = creds.info.options.get(NeonOptions::PARAMS_COMPAT).is_some();
let mut auth_info = compute::AuthInfo::with_auth_keys(&creds.keys);
auth_info.set_startup_params(&params, params_compat);
let params_compat = compute_user_info
.options
.get(NeonOptions::PARAMS_COMPAT)
.is_some();
let res = connect_to_compute(
ctx,
&TcpMechanism {
user_info: creds.info.clone(),
auth: auth_info,
user_info: compute_user_info.clone(),
params_compat,
params: &params,
locks: &config.connect_compute_locks,
},
&user_info,

View File

@@ -100,9 +100,9 @@ impl CouldRetry for compute::ConnectionError {
fn could_retry(&self) -> bool {
match self {
compute::ConnectionError::Postgres(err) => err.could_retry(),
compute::ConnectionError::TlsError(err) => err.could_retry(),
compute::ConnectionError::CouldNotConnect(err) => err.could_retry(),
compute::ConnectionError::WakeComputeError(err) => err.could_retry(),
compute::ConnectionError::TooManyConnectionAttempts(_) => false,
_ => false,
}
}
}

View File

@@ -8,7 +8,7 @@ use std::time::Duration;
use anyhow::{Context, bail};
use async_trait::async_trait;
use http::StatusCode;
use postgres_client::config::{AuthKeys, ScramKeys, SslMode};
use postgres_client::config::SslMode;
use postgres_client::tls::{MakeTlsConnect, NoTls};
use retry::{ShouldRetryWakeCompute, retry_after};
use rstest::rstest;
@@ -29,6 +29,7 @@ use crate::control_plane::{self, CachedNodeInfo, NodeInfo, NodeInfoCache};
use crate::error::ErrorKind;
use crate::pglb::connect_compute::ConnectMechanism;
use crate::tls::client_config::compute_client_config_with_certs;
use crate::tls::postgres_rustls::MakeRustlsConnect;
use crate::tls::server_config::CertResolver;
use crate::types::{BranchId, EndpointId, ProjectId};
use crate::{sasl, scram};
@@ -71,14 +72,13 @@ struct ClientConfig<'a> {
hostname: &'a str,
}
type TlsConnect<S> = <ComputeConfig as MakeTlsConnect<S>>::TlsConnect;
type TlsConnect<S> = <MakeRustlsConnect as MakeTlsConnect<S>>::TlsConnect;
impl ClientConfig<'_> {
fn make_tls_connect(self) -> anyhow::Result<TlsConnect<DuplexStream>> {
Ok(crate::tls::postgres_rustls::make_tls_connect(
&self.config,
self.hostname,
)?)
let mut mk = MakeRustlsConnect::new(self.config);
let tls = MakeTlsConnect::<DuplexStream>::make_tls_connect(&mut mk, self.hostname)?;
Ok(tls)
}
}
@@ -497,6 +497,8 @@ impl ConnectMechanism for TestConnectMechanism {
x => panic!("expecting action {x:?}, connect is called instead"),
}
}
fn update_connect_config(&self, _conf: &mut compute::ConnCfg) {}
}
impl TestControlPlaneClient for TestConnectMechanism {
@@ -555,12 +557,7 @@ impl TestControlPlaneClient for TestConnectMechanism {
fn helper_create_cached_node_info(cache: &'static NodeInfoCache) -> CachedNodeInfo {
let node = NodeInfo {
conn_info: compute::ConnectInfo {
host: "test".into(),
port: 5432,
ssl_mode: SslMode::Disable,
host_addr: None,
},
config: compute::ConnCfg::new("test".to_owned(), 5432),
aux: MetricsAuxInfo {
endpoint_id: (&EndpointId::from("endpoint")).into(),
project_id: (&ProjectId::from("project")).into(),
@@ -584,10 +581,7 @@ fn helper_create_connect_info(
user: "user".into(),
options: NeonOptions::parse_options_raw(""),
},
keys: ComputeCredentialKeys::AuthKeys(AuthKeys::ScramSha256(ScramKeys {
client_key: [0; 32],
server_key: [0; 32],
})),
keys: ComputeCredentialKeys::Password("password".into()),
},
)
}

View File

@@ -3,12 +3,12 @@ use std::sync::Arc;
use futures::StreamExt;
use redis::aio::PubSub;
use serde::Deserialize;
use serde::{Deserialize, Serialize};
use tokio_util::sync::CancellationToken;
use super::connection_with_credentials_provider::ConnectionWithCredentialsProvider;
use crate::cache::project_info::ProjectInfoCache;
use crate::intern::{AccountIdInt, EndpointIdInt, ProjectIdInt, RoleNameInt};
use crate::intern::{AccountIdInt, ProjectIdInt, RoleNameInt};
use crate::metrics::{Metrics, RedisErrors, RedisEventsCount};
const CPLANE_CHANNEL_NAME: &str = "neondb-proxy-ws-updates";
@@ -27,37 +27,42 @@ struct NotificationHeader<'a> {
topic: &'a str,
}
#[derive(Clone, Debug, Deserialize, Eq, PartialEq)]
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
#[serde(tag = "topic", content = "data")]
enum Notification {
pub(crate) enum Notification {
#[serde(
rename = "/account_settings_update",
alias = "/allowed_vpc_endpoints_updated_for_org",
rename = "/allowed_ips_updated",
deserialize_with = "deserialize_json_string"
)]
AccountSettingsUpdate(InvalidateAccount),
AllowedIpsUpdate {
allowed_ips_update: AllowedIpsUpdate,
},
#[serde(
rename = "/endpoint_settings_update",
rename = "/block_public_or_vpc_access_updated",
deserialize_with = "deserialize_json_string"
)]
EndpointSettingsUpdate(InvalidateEndpoint),
BlockPublicOrVpcAccessUpdated {
block_public_or_vpc_access_updated: BlockPublicOrVpcAccessUpdated,
},
#[serde(
rename = "/project_settings_update",
alias = "/allowed_ips_updated",
alias = "/block_public_or_vpc_access_updated",
alias = "/allowed_vpc_endpoints_updated_for_projects",
rename = "/allowed_vpc_endpoints_updated_for_org",
deserialize_with = "deserialize_json_string"
)]
ProjectSettingsUpdate(InvalidateProject),
AllowedVpcEndpointsUpdatedForOrg {
allowed_vpc_endpoints_updated_for_org: AllowedVpcEndpointsUpdatedForOrg,
},
#[serde(
rename = "/role_setting_update",
alias = "/password_updated",
rename = "/allowed_vpc_endpoints_updated_for_projects",
deserialize_with = "deserialize_json_string"
)]
RoleSettingUpdate(InvalidateRole),
AllowedVpcEndpointsUpdatedForProjects {
allowed_vpc_endpoints_updated_for_projects: AllowedVpcEndpointsUpdatedForProjects,
},
#[serde(
rename = "/password_updated",
deserialize_with = "deserialize_json_string"
)]
PasswordUpdate { password_update: PasswordUpdate },
#[serde(
other,
@@ -67,56 +72,28 @@ enum Notification {
UnknownTopic,
}
#[derive(Clone, Debug, Deserialize, Eq, PartialEq)]
#[serde(rename_all = "snake_case")]
enum InvalidateEndpoint {
EndpointId(EndpointIdInt),
EndpointIds(Vec<EndpointIdInt>),
}
impl std::ops::Deref for InvalidateEndpoint {
type Target = [EndpointIdInt];
fn deref(&self) -> &Self::Target {
match self {
Self::EndpointId(id) => std::slice::from_ref(id),
Self::EndpointIds(ids) => ids,
}
}
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
pub(crate) struct AllowedIpsUpdate {
project_id: ProjectIdInt,
}
#[derive(Clone, Debug, Deserialize, Eq, PartialEq)]
#[serde(rename_all = "snake_case")]
enum InvalidateProject {
ProjectId(ProjectIdInt),
ProjectIds(Vec<ProjectIdInt>),
}
impl std::ops::Deref for InvalidateProject {
type Target = [ProjectIdInt];
fn deref(&self) -> &Self::Target {
match self {
Self::ProjectId(id) => std::slice::from_ref(id),
Self::ProjectIds(ids) => ids,
}
}
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
pub(crate) struct BlockPublicOrVpcAccessUpdated {
project_id: ProjectIdInt,
}
#[derive(Clone, Debug, Deserialize, Eq, PartialEq)]
#[serde(rename_all = "snake_case")]
enum InvalidateAccount {
AccountId(AccountIdInt),
AccountIds(Vec<AccountIdInt>),
}
impl std::ops::Deref for InvalidateAccount {
type Target = [AccountIdInt];
fn deref(&self) -> &Self::Target {
match self {
Self::AccountId(id) => std::slice::from_ref(id),
Self::AccountIds(ids) => ids,
}
}
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
pub(crate) struct AllowedVpcEndpointsUpdatedForOrg {
account_id: AccountIdInt,
}
#[derive(Clone, Debug, Deserialize, Eq, PartialEq)]
struct InvalidateRole {
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
pub(crate) struct AllowedVpcEndpointsUpdatedForProjects {
project_ids: Vec<ProjectIdInt>,
}
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
pub(crate) struct PasswordUpdate {
project_id: ProjectIdInt,
role_name: RoleNameInt,
}
@@ -200,29 +177,41 @@ impl<C: ProjectInfoCache + Send + Sync + 'static> MessageHandler<C> {
tracing::debug!(?msg, "received a message");
match msg {
Notification::RoleSettingUpdate { .. }
| Notification::EndpointSettingsUpdate { .. }
| Notification::ProjectSettingsUpdate { .. }
| Notification::AccountSettingsUpdate { .. } => {
Notification::AllowedIpsUpdate { .. }
| Notification::PasswordUpdate { .. }
| Notification::BlockPublicOrVpcAccessUpdated { .. }
| Notification::AllowedVpcEndpointsUpdatedForOrg { .. }
| Notification::AllowedVpcEndpointsUpdatedForProjects { .. } => {
invalidate_cache(self.cache.clone(), msg.clone());
let m = &Metrics::get().proxy.redis_events_count;
match msg {
Notification::RoleSettingUpdate { .. } => {
m.inc(RedisEventsCount::InvalidateRole);
}
Notification::EndpointSettingsUpdate { .. } => {
m.inc(RedisEventsCount::InvalidateEndpoint);
}
Notification::ProjectSettingsUpdate { .. } => {
m.inc(RedisEventsCount::InvalidateProject);
}
Notification::AccountSettingsUpdate { .. } => {
m.inc(RedisEventsCount::InvalidateOrg);
}
Notification::UnknownTopic => {}
if matches!(msg, Notification::AllowedIpsUpdate { .. }) {
Metrics::get()
.proxy
.redis_events_count
.inc(RedisEventsCount::AllowedIpsUpdate);
} else if matches!(msg, Notification::PasswordUpdate { .. }) {
Metrics::get()
.proxy
.redis_events_count
.inc(RedisEventsCount::PasswordUpdate);
} else if matches!(
msg,
Notification::AllowedVpcEndpointsUpdatedForProjects { .. }
) {
Metrics::get()
.proxy
.redis_events_count
.inc(RedisEventsCount::AllowedVpcEndpointIdsUpdateForProjects);
} else if matches!(msg, Notification::AllowedVpcEndpointsUpdatedForOrg { .. }) {
Metrics::get()
.proxy
.redis_events_count
.inc(RedisEventsCount::AllowedVpcEndpointIdsUpdateForAllProjectsInOrg);
} else if matches!(msg, Notification::BlockPublicOrVpcAccessUpdated { .. }) {
Metrics::get()
.proxy
.redis_events_count
.inc(RedisEventsCount::BlockPublicOrVpcAccessUpdate);
}
// TODO: add additional metrics for the other event types.
// It might happen that the invalid entry is on the way to be cached.
@@ -244,23 +233,30 @@ impl<C: ProjectInfoCache + Send + Sync + 'static> MessageHandler<C> {
fn invalidate_cache<C: ProjectInfoCache>(cache: Arc<C>, msg: Notification) {
match msg {
Notification::EndpointSettingsUpdate(ids) => ids
.iter()
.for_each(|&id| cache.invalidate_endpoint_access(id)),
Notification::AccountSettingsUpdate(ids) => ids
.iter()
.for_each(|&id| cache.invalidate_endpoint_access_for_org(id)),
Notification::ProjectSettingsUpdate(ids) => ids
.iter()
.for_each(|&id| cache.invalidate_endpoint_access_for_project(id)),
Notification::RoleSettingUpdate(InvalidateRole {
project_id,
role_name,
}) => cache.invalidate_role_secret_for_project(project_id, role_name),
Notification::AllowedIpsUpdate {
allowed_ips_update: AllowedIpsUpdate { project_id },
}
| Notification::BlockPublicOrVpcAccessUpdated {
block_public_or_vpc_access_updated: BlockPublicOrVpcAccessUpdated { project_id },
} => cache.invalidate_endpoint_access_for_project(project_id),
Notification::AllowedVpcEndpointsUpdatedForOrg {
allowed_vpc_endpoints_updated_for_org: AllowedVpcEndpointsUpdatedForOrg { account_id },
} => cache.invalidate_endpoint_access_for_org(account_id),
Notification::AllowedVpcEndpointsUpdatedForProjects {
allowed_vpc_endpoints_updated_for_projects:
AllowedVpcEndpointsUpdatedForProjects { project_ids },
} => {
for project in project_ids {
cache.invalidate_endpoint_access_for_project(project);
}
}
Notification::PasswordUpdate {
password_update:
PasswordUpdate {
project_id,
role_name,
},
} => cache.invalidate_role_secret_for_project(project_id, role_name),
Notification::UnknownTopic => unreachable!(),
}
}
@@ -357,32 +353,11 @@ mod tests {
let result: Notification = serde_json::from_str(&text)?;
assert_eq!(
result,
Notification::ProjectSettingsUpdate(InvalidateProject::ProjectId((&project_id).into()))
);
Ok(())
}
#[test]
fn parse_multiple_projects() -> anyhow::Result<()> {
let project_id1: ProjectId = "new_project1".into();
let project_id2: ProjectId = "new_project2".into();
let data = format!("{{\"project_ids\": [\"{project_id1}\",\"{project_id2}\"]}}");
let text = json!({
"type": "message",
"topic": "/allowed_vpc_endpoints_updated_for_projects",
"data": data,
"extre_fields": "something"
})
.to_string();
let result: Notification = serde_json::from_str(&text)?;
assert_eq!(
result,
Notification::ProjectSettingsUpdate(InvalidateProject::ProjectIds(vec![
(&project_id1).into(),
(&project_id2).into()
]))
Notification::AllowedIpsUpdate {
allowed_ips_update: AllowedIpsUpdate {
project_id: (&project_id).into()
}
}
);
Ok(())
@@ -404,10 +379,12 @@ mod tests {
let result: Notification = serde_json::from_str(&text)?;
assert_eq!(
result,
Notification::RoleSettingUpdate(InvalidateRole {
project_id: (&project_id).into(),
role_name: (&role_name).into(),
})
Notification::PasswordUpdate {
password_update: PasswordUpdate {
project_id: (&project_id).into(),
role_name: (&role_name).into(),
}
}
);
Ok(())

View File

@@ -23,6 +23,7 @@ use super::local_conn_pool::{self, EXT_NAME, EXT_SCHEMA, EXT_VERSION, LocalConnP
use crate::auth::backend::local::StaticAuthRules;
use crate::auth::backend::{ComputeCredentials, ComputeUserInfo};
use crate::auth::{self, AuthError};
use crate::compute;
use crate::compute_ctl::{
ComputeCtlError, ExtensionInstallRequest, Privilege, SetRoleGrantsRequest,
};
@@ -304,13 +305,12 @@ impl PoolingBackend {
tracing::Span::current().record("conn_id", display(conn_id));
info!(%conn_id, "local_pool: opening a new connection '{conn_info}'");
let mut node_info = local_backend.node_info.clone();
let (key, jwk) = create_random_jwk();
let mut config = local_backend
.node_info
.conn_info
.to_postgres_client_config();
config
let config = node_info
.config
.user(&conn_info.user_info.user)
.dbname(&conn_info.dbname)
.set_param(
@@ -322,7 +322,7 @@ impl PoolingBackend {
);
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
let (client, connection) = config.connect(&postgres_client::NoTls).await?;
let (client, connection) = config.connect(postgres_client::NoTls).await?;
drop(pause);
let pid = client.get_process_id();
@@ -336,7 +336,7 @@ impl PoolingBackend {
connection,
key,
conn_id,
local_backend.node_info.aux.clone(),
node_info.aux.clone(),
);
{
@@ -512,16 +512,19 @@ impl ConnectMechanism for TokioMechanism {
node_info: &CachedNodeInfo,
compute_config: &ComputeConfig,
) -> Result<Self::Connection, Self::ConnectError> {
let permit = self.locks.get_permit(&node_info.conn_info.host).await?;
let host = node_info.config.get_host();
let permit = self.locks.get_permit(&host).await?;
let mut config = node_info.conn_info.to_postgres_client_config();
let mut config = (*node_info.config).clone();
let config = config
.user(&self.conn_info.user_info.user)
.dbname(&self.conn_info.dbname)
.connect_timeout(compute_config.timeout);
let mk_tls =
crate::tls::postgres_rustls::MakeRustlsConnect::new(compute_config.tls.clone());
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
let res = config.connect(compute_config).await;
let res = config.connect(mk_tls).await;
drop(pause);
let (client, connection) = permit.release_result(res)?;
@@ -545,6 +548,8 @@ impl ConnectMechanism for TokioMechanism {
node_info.aux.clone(),
))
}
fn update_connect_config(&self, _config: &mut compute::ConnCfg) {}
}
struct HyperMechanism {
@@ -568,20 +573,20 @@ impl ConnectMechanism for HyperMechanism {
node_info: &CachedNodeInfo,
config: &ComputeConfig,
) -> Result<Self::Connection, Self::ConnectError> {
let host_addr = node_info.conn_info.host_addr;
let host = &node_info.conn_info.host;
let permit = self.locks.get_permit(host).await?;
let host_addr = node_info.config.get_host_addr();
let host = node_info.config.get_host();
let permit = self.locks.get_permit(&host).await?;
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
let tls = if node_info.conn_info.ssl_mode == SslMode::Disable {
let tls = if node_info.config.get_ssl_mode() == SslMode::Disable {
None
} else {
Some(&config.tls)
};
let port = node_info.conn_info.port;
let res = connect_http2(host_addr, host, port, config.timeout, tls).await;
let port = node_info.config.get_port();
let res = connect_http2(host_addr, &host, port, config.timeout, tls).await;
drop(pause);
let (client, connection) = permit.release_result(res)?;
@@ -604,6 +609,8 @@ impl ConnectMechanism for HyperMechanism {
node_info.aux.clone(),
))
}
fn update_connect_config(&self, _config: &mut compute::ConnCfg) {}
}
async fn connect_http2(

View File

@@ -23,12 +23,12 @@ use super::conn_pool_lib::{
Client, ClientDataEnum, ClientInnerCommon, ClientInnerExt, ConnInfo, EndpointConnPool,
GlobalConnPool,
};
use crate::config::ComputeConfig;
use crate::context::RequestContext;
use crate::control_plane::messages::MetricsAuxInfo;
use crate::metrics::Metrics;
use crate::tls::postgres_rustls::MakeRustlsConnect;
type TlsStream = <ComputeConfig as MakeTlsConnect<TcpStream>>::Stream;
type TlsStream = <MakeRustlsConnect as MakeTlsConnect<TcpStream>>::Stream;
#[derive(Debug, Clone)]
pub(crate) struct ConnInfoWithAuth {

View File

@@ -2,11 +2,10 @@ use std::convert::TryFrom;
use std::sync::Arc;
use postgres_client::tls::MakeTlsConnect;
use rustls::pki_types::{InvalidDnsNameError, ServerName};
use rustls::ClientConfig;
use rustls::pki_types::ServerName;
use tokio::io::{AsyncRead, AsyncWrite};
use crate::config::ComputeConfig;
mod private {
use std::future::Future;
use std::io;
@@ -124,27 +123,36 @@ mod private {
}
}
impl<S> MakeTlsConnect<S> for ComputeConfig
/// A `MakeTlsConnect` implementation using `rustls`.
///
/// That way you can connect to PostgreSQL using `rustls` as the TLS stack.
#[derive(Clone)]
pub struct MakeRustlsConnect {
pub config: Arc<ClientConfig>,
}
impl MakeRustlsConnect {
/// Creates a new `MakeRustlsConnect` from the provided `ClientConfig`.
#[must_use]
pub fn new(config: Arc<ClientConfig>) -> Self {
Self { config }
}
}
impl<S> MakeTlsConnect<S> for MakeRustlsConnect
where
S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
type Stream = private::RustlsStream<S>;
type TlsConnect = private::RustlsConnect;
type Error = InvalidDnsNameError;
type Error = rustls::pki_types::InvalidDnsNameError;
fn make_tls_connect(&self, hostname: &str) -> Result<Self::TlsConnect, Self::Error> {
make_tls_connect(&self.tls, hostname)
fn make_tls_connect(&mut self, hostname: &str) -> Result<Self::TlsConnect, Self::Error> {
ServerName::try_from(hostname).map(|dns_name| {
private::RustlsConnect(private::RustlsConnectData {
hostname: dns_name.to_owned(),
connector: Arc::clone(&self.config).into(),
})
})
}
}
pub fn make_tls_connect(
tls: &Arc<rustls::ClientConfig>,
hostname: &str,
) -> Result<private::RustlsConnect, InvalidDnsNameError> {
ServerName::try_from(hostname).map(|dns_name| {
private::RustlsConnect(private::RustlsConnectData {
hostname: dns_name.to_owned(),
connector: tls.clone().into(),
})
})
}

View File

@@ -8,8 +8,8 @@ use std::error::Error as _;
use http_utils::error::HttpErrorBody;
use reqwest::{IntoUrl, Method, StatusCode};
use safekeeper_api::models::{
self, PullTimelineRequest, PullTimelineResponse, SafekeeperStatus, SafekeeperUtilization,
TimelineCreateRequest, TimelineStatus,
self, PullTimelineRequest, PullTimelineResponse, SafekeeperUtilization, TimelineCreateRequest,
TimelineStatus,
};
use utils::id::{NodeId, TenantId, TimelineId};
use utils::logging::SecretString;
@@ -183,12 +183,6 @@ impl Client {
self.get(&uri).await
}
pub async fn status(&self) -> Result<SafekeeperStatus> {
let uri = format!("{}/v1/status", self.mgmt_api_endpoint);
let resp = self.get(&uri).await?;
resp.json().await.map_err(Error::ReceiveBody)
}
pub async fn utilization(&self) -> Result<SafekeeperUtilization> {
let uri = format!("{}/v1/utilization", self.mgmt_api_endpoint);
let resp = self.get(&uri).await?;

View File

@@ -395,8 +395,6 @@ pub enum TimelineError {
Cancelled(TenantTimelineId),
#[error("Timeline {0} was not found in global map")]
NotFound(TenantTimelineId),
#[error("Timeline {0} has been deleted")]
Deleted(TenantTimelineId),
#[error("Timeline {0} creation is in progress")]
CreationInProgress(TenantTimelineId),
#[error("Timeline {0} exists on disk, but wasn't loaded on startup")]

View File

@@ -78,13 +78,7 @@ impl GlobalTimelinesState {
Some(GlobalMapTimeline::CreationInProgress) => {
Err(TimelineError::CreationInProgress(*ttid))
}
None => {
if self.has_tombstone(ttid) {
Err(TimelineError::Deleted(*ttid))
} else {
Err(TimelineError::NotFound(*ttid))
}
}
None => Err(TimelineError::NotFound(*ttid)),
}
}

View File

@@ -65,9 +65,8 @@ diesel-async = { version = "0.5.2", features = ["postgres", "bb8", "async-connec
diesel_migrations = { version = "2.2.0" }
scoped-futures = "0.1.4"
compute_api = { path = "../libs/compute_api/" }
http-utils = { path = "../libs/http-utils/" }
utils = { path = "../libs/utils/" }
metrics = { path = "../libs/metrics/" }
control_plane = { path = "../control_plane" }
workspace_hack = { version = "0.1", path = "../workspace_hack" }
workspace_hack = { version = "0.1", path = "../workspace_hack" }

View File

@@ -1 +0,0 @@
ALTER TABLE nodes DROP COLUMN lifecycle;

View File

@@ -1 +0,0 @@
ALTER TABLE nodes ADD COLUMN lifecycle VARCHAR NOT NULL DEFAULT 'active';

View File

@@ -428,10 +428,7 @@ impl ComputeHook {
.expect("Unknown pageserver");
let (pg_host, pg_port) = parse_host_port(&ps_conf.listen_pg_addr)
.expect("Unable to parse listen_pg_addr");
compute_api::spec::Pageserver {
host: pg_host,
port: pg_port.unwrap_or(5432),
}
(pg_host, pg_port.unwrap_or(5432))
})
.collect::<Vec<_>>();

View File

@@ -907,42 +907,6 @@ async fn handle_node_delete(req: Request<Body>) -> Result<Response<Body>, ApiErr
json_response(StatusCode::OK, state.service.node_delete(node_id).await?)
}
async fn handle_tombstone_list(req: Request<Body>) -> Result<Response<Body>, ApiError> {
check_permissions(&req, Scope::Admin)?;
let req = match maybe_forward(req).await {
ForwardOutcome::Forwarded(res) => {
return res;
}
ForwardOutcome::NotForwarded(req) => req,
};
let state = get_state(&req);
let mut nodes = state.service.tombstone_list().await?;
nodes.sort_by_key(|n| n.get_id());
let api_nodes = nodes.into_iter().map(|n| n.describe()).collect::<Vec<_>>();
json_response(StatusCode::OK, api_nodes)
}
async fn handle_tombstone_delete(req: Request<Body>) -> Result<Response<Body>, ApiError> {
check_permissions(&req, Scope::Admin)?;
let req = match maybe_forward(req).await {
ForwardOutcome::Forwarded(res) => {
return res;
}
ForwardOutcome::NotForwarded(req) => req,
};
let state = get_state(&req);
let node_id: NodeId = parse_request_param(&req, "node_id")?;
json_response(
StatusCode::OK,
state.service.tombstone_delete(node_id).await?,
)
}
async fn handle_node_configure(req: Request<Body>) -> Result<Response<Body>, ApiError> {
check_permissions(&req, Scope::Admin)?;
@@ -2098,20 +2062,6 @@ pub fn make_router(
.post("/debug/v1/node/:node_id/drop", |r| {
named_request_span(r, handle_node_drop, RequestName("debug_v1_node_drop"))
})
.delete("/debug/v1/tombstone/:node_id", |r| {
named_request_span(
r,
handle_tombstone_delete,
RequestName("debug_v1_tombstone_delete"),
)
})
.get("/debug/v1/tombstone", |r| {
named_request_span(
r,
handle_tombstone_list,
RequestName("debug_v1_tombstone_list"),
)
})
.post("/debug/v1/tenant/:tenant_id/import", |r| {
named_request_span(
r,

View File

@@ -2,7 +2,7 @@ use std::str::FromStr;
use std::time::Duration;
use pageserver_api::controller_api::{
AvailabilityZone, NodeAvailability, NodeDescribeResponse, NodeLifecycle, NodeRegisterRequest,
AvailabilityZone, NodeAvailability, NodeDescribeResponse, NodeRegisterRequest,
NodeSchedulingPolicy, TenantLocateResponseShard,
};
use pageserver_api::shard::TenantShardId;
@@ -29,7 +29,6 @@ pub(crate) struct Node {
availability: NodeAvailability,
scheduling: NodeSchedulingPolicy,
lifecycle: NodeLifecycle,
listen_http_addr: String,
listen_http_port: u16,
@@ -229,7 +228,6 @@ impl Node {
listen_pg_addr,
listen_pg_port,
scheduling: NodeSchedulingPolicy::Active,
lifecycle: NodeLifecycle::Active,
availability: NodeAvailability::Offline,
availability_zone_id,
use_https,
@@ -241,7 +239,6 @@ impl Node {
NodePersistence {
node_id: self.id.0 as i64,
scheduling_policy: self.scheduling.into(),
lifecycle: self.lifecycle.into(),
listen_http_addr: self.listen_http_addr.clone(),
listen_http_port: self.listen_http_port as i32,
listen_https_port: self.listen_https_port.map(|x| x as i32),
@@ -266,7 +263,6 @@ impl Node {
availability: NodeAvailability::Offline,
scheduling: NodeSchedulingPolicy::from_str(&np.scheduling_policy)
.expect("Bad scheduling policy in DB"),
lifecycle: NodeLifecycle::from_str(&np.lifecycle).expect("Bad lifecycle in DB"),
listen_http_addr: np.listen_http_addr,
listen_http_port: np.listen_http_port as u16,
listen_https_port: np.listen_https_port.map(|x| x as u16),

View File

@@ -19,7 +19,7 @@ use futures::FutureExt;
use futures::future::BoxFuture;
use itertools::Itertools;
use pageserver_api::controller_api::{
AvailabilityZone, MetadataHealthRecord, NodeLifecycle, NodeSchedulingPolicy, PlacementPolicy,
AvailabilityZone, MetadataHealthRecord, NodeSchedulingPolicy, PlacementPolicy,
SafekeeperDescribeResponse, ShardSchedulingPolicy, SkSchedulingPolicy,
};
use pageserver_api::models::{ShardImportStatus, TenantConfig};
@@ -102,7 +102,6 @@ pub(crate) enum DatabaseOperation {
UpdateNode,
DeleteNode,
ListNodes,
ListTombstones,
BeginShardSplit,
CompleteShardSplit,
AbortShardSplit,
@@ -358,8 +357,6 @@ impl Persistence {
}
/// When a node is first registered, persist it before using it for anything
/// If the provided node_id already exists, it will be error.
/// The common case is when a node marked for deletion wants to register.
pub(crate) async fn insert_node(&self, node: &Node) -> DatabaseResult<()> {
let np = &node.to_persistent();
self.with_measured_conn(DatabaseOperation::InsertNode, move |conn| {
@@ -376,41 +373,19 @@ impl Persistence {
/// At startup, populate the list of nodes which our shards may be placed on
pub(crate) async fn list_nodes(&self) -> DatabaseResult<Vec<NodePersistence>> {
use crate::schema::nodes::dsl::*;
let result: Vec<NodePersistence> = self
let nodes: Vec<NodePersistence> = self
.with_measured_conn(DatabaseOperation::ListNodes, move |conn| {
Box::pin(async move {
Ok(crate::schema::nodes::table
.filter(lifecycle.ne(String::from(NodeLifecycle::Deleted)))
.load::<NodePersistence>(conn)
.await?)
})
})
.await?;
tracing::info!("list_nodes: loaded {} nodes", result.len());
tracing::info!("list_nodes: loaded {} nodes", nodes.len());
Ok(result)
}
pub(crate) async fn list_tombstones(&self) -> DatabaseResult<Vec<NodePersistence>> {
use crate::schema::nodes::dsl::*;
let result: Vec<NodePersistence> = self
.with_measured_conn(DatabaseOperation::ListTombstones, move |conn| {
Box::pin(async move {
Ok(crate::schema::nodes::table
.filter(lifecycle.eq(String::from(NodeLifecycle::Deleted)))
.load::<NodePersistence>(conn)
.await?)
})
})
.await?;
tracing::info!("list_tombstones: loaded {} nodes", result.len());
Ok(result)
Ok(nodes)
}
pub(crate) async fn update_node<V>(
@@ -429,7 +404,6 @@ impl Persistence {
Box::pin(async move {
let updated = diesel::update(nodes)
.filter(node_id.eq(input_node_id.0 as i64))
.filter(lifecycle.ne(String::from(NodeLifecycle::Deleted)))
.set(values)
.execute(conn)
.await?;
@@ -473,57 +447,6 @@ impl Persistence {
.await
}
/// Tombstone is a special state where the node is not deleted from the database,
/// but it is not available for usage.
/// The main reason for it is to prevent the flaky node to register.
pub(crate) async fn set_tombstone(&self, del_node_id: NodeId) -> DatabaseResult<()> {
use crate::schema::nodes::dsl::*;
self.update_node(
del_node_id,
lifecycle.eq(String::from(NodeLifecycle::Deleted)),
)
.await
}
pub(crate) async fn delete_node(&self, del_node_id: NodeId) -> DatabaseResult<()> {
use crate::schema::nodes::dsl::*;
self.with_measured_conn(DatabaseOperation::DeleteNode, move |conn| {
Box::pin(async move {
// You can hard delete a node only if it has a tombstone.
// So we need to check if the node has lifecycle set to deleted.
let node_to_delete = nodes
.filter(node_id.eq(del_node_id.0 as i64))
.first::<NodePersistence>(conn)
.await
.optional()?;
if let Some(np) = node_to_delete {
let lc = NodeLifecycle::from_str(&np.lifecycle).map_err(|e| {
DatabaseError::Logical(format!(
"Node {} has invalid lifecycle: {}",
del_node_id, e
))
})?;
if lc != NodeLifecycle::Deleted {
return Err(DatabaseError::Logical(format!(
"Node {} was not soft deleted before, cannot hard delete it",
del_node_id
)));
}
diesel::delete(nodes)
.filter(node_id.eq(del_node_id.0 as i64))
.execute(conn)
.await?;
}
Ok(())
})
})
.await
}
/// At startup, load the high level state for shards, such as their config + policy. This will
/// be enriched at runtime with state discovered on pageservers.
///
@@ -620,6 +543,21 @@ impl Persistence {
.await
}
pub(crate) async fn delete_node(&self, del_node_id: NodeId) -> DatabaseResult<()> {
use crate::schema::nodes::dsl::*;
self.with_measured_conn(DatabaseOperation::DeleteNode, move |conn| {
Box::pin(async move {
diesel::delete(nodes)
.filter(node_id.eq(del_node_id.0 as i64))
.execute(conn)
.await?;
Ok(())
})
})
.await
}
/// When a tenant invokes the /re-attach API, this function is responsible for doing an efficient
/// batched increment of the generations of all tenants whose generation_pageserver is equal to
/// the node that called /re-attach.
@@ -633,20 +571,6 @@ impl Persistence {
let updated = self
.with_measured_conn(DatabaseOperation::ReAttach, move |conn| {
Box::pin(async move {
// Check if the node is not marked as deleted
let deleted_node: i64 = nodes
.filter(node_id.eq(input_node_id.0 as i64))
.filter(lifecycle.eq(String::from(NodeLifecycle::Deleted)))
.count()
.get_result(conn)
.await?;
if deleted_node > 0 {
return Err(DatabaseError::Logical(format!(
"Node {} is marked as deleted, re-attach is not allowed",
input_node_id
)));
}
let rows_updated = diesel::update(tenant_shards)
.filter(generation_pageserver.eq(input_node_id.0 as i64))
.set(generation.eq(generation + 1))
@@ -2124,7 +2048,6 @@ pub(crate) struct NodePersistence {
pub(crate) listen_pg_port: i32,
pub(crate) availability_zone_id: String,
pub(crate) listen_https_port: Option<i32>,
pub(crate) lifecycle: String,
}
/// Tenant metadata health status that are stored durably.

View File

@@ -33,7 +33,6 @@ diesel::table! {
listen_pg_port -> Int4,
availability_zone_id -> Varchar,
listen_https_port -> Nullable<Int4>,
lifecycle -> Varchar,
}
}

View File

@@ -166,7 +166,6 @@ enum NodeOperations {
Register,
Configure,
Delete,
DeleteTombstone,
}
/// The leadership status for the storage controller process.
@@ -1108,8 +1107,7 @@ impl Service {
observed
}
/// Used during [`Self::startup_reconcile`] and shard splits: detach a list of unknown-to-us
/// tenants from pageservers.
/// Used during [`Self::startup_reconcile`]: detach a list of unknown-to-us tenants from pageservers.
///
/// This is safe to run in the background, because if we don't have this TenantShardId in our map of
/// tenants, then it is probably something incompletely deleted before: we will not fight with any
@@ -6212,11 +6210,7 @@ impl Service {
}
}
fail::fail_point!("shard-split-pre-complete", |_| Err(ApiError::Conflict(
"failpoint".to_string()
)));
pausable_failpoint!("shard-split-pre-complete-pause");
pausable_failpoint!("shard-split-pre-complete");
// TODO: if the pageserver restarted concurrently with our split API call,
// the actual generation of the child shard might differ from the generation
@@ -6238,15 +6232,6 @@ impl Service {
let (response, child_locations, waiters) =
self.tenant_shard_split_commit_inmem(tenant_id, new_shard_count, new_stripe_size);
// Notify all page servers to detach and clean up the old shards because they will no longer
// be needed. This is best-effort: if it fails, it will be cleaned up on a subsequent
// Pageserver re-attach/startup.
let shards_to_cleanup = targets
.iter()
.map(|target| (target.parent_id, target.node.get_id()))
.collect();
self.cleanup_locations(shards_to_cleanup).await;
// Send compute notifications for all the new shards
let mut failed_notifications = Vec::new();
for (child_id, child_ps, stripe_size) in child_locations {
@@ -6924,7 +6909,7 @@ impl Service {
/// detaching or deleting it on pageservers. We do not try and re-schedule any
/// tenants that were on this node.
pub(crate) async fn node_drop(&self, node_id: NodeId) -> Result<(), ApiError> {
self.persistence.set_tombstone(node_id).await?;
self.persistence.delete_node(node_id).await?;
let mut locked = self.inner.write().unwrap();
@@ -7048,10 +7033,9 @@ impl Service {
// That is safe because in Service::spawn we only use generation_pageserver if it refers to a node
// that exists.
// 2. Actually delete the node from in-memory state and set tombstone to the database
// for preventing the node to register again.
// 2. Actually delete the node from the database and from in-memory state
tracing::info!("Deleting node from database");
self.persistence.set_tombstone(node_id).await?;
self.persistence.delete_node(node_id).await?;
Ok(())
}
@@ -7070,35 +7054,6 @@ impl Service {
Ok(nodes)
}
pub(crate) async fn tombstone_list(&self) -> Result<Vec<Node>, ApiError> {
self.persistence
.list_tombstones()
.await?
.into_iter()
.map(|np| Node::from_persistent(np, false))
.collect::<Result<Vec<_>, _>>()
.map_err(ApiError::InternalServerError)
}
pub(crate) async fn tombstone_delete(&self, node_id: NodeId) -> Result<(), ApiError> {
let _node_lock = trace_exclusive_lock(
&self.node_op_locks,
node_id,
NodeOperations::DeleteTombstone,
)
.await;
if matches!(self.get_node(node_id).await, Err(ApiError::NotFound(_))) {
self.persistence.delete_node(node_id).await?;
Ok(())
} else {
Err(ApiError::Conflict(format!(
"Node {} is in use, consider using tombstone API first",
node_id
)))
}
}
pub(crate) async fn get_node(&self, node_id: NodeId) -> Result<Node, ApiError> {
self.inner
.read()
@@ -7269,25 +7224,7 @@ impl Service {
};
match registration_status {
RegistrationStatus::New => {
self.persistence.insert_node(&new_node).await.map_err(|e| {
if matches!(
e,
crate::persistence::DatabaseError::Query(
diesel::result::Error::DatabaseError(
diesel::result::DatabaseErrorKind::UniqueViolation,
_,
)
)
) {
// The node can be deleted by tombstone API, and not show up in the list of nodes.
// If you see this error, check tombstones first.
ApiError::Conflict(format!("Node {} is already exists", new_node.get_id()))
} else {
ApiError::from(e)
}
})?;
}
RegistrationStatus::New => self.persistence.insert_node(&new_node).await?,
RegistrationStatus::NeedUpdate => {
self.persistence
.update_node_on_registration(

View File

@@ -2054,14 +2054,6 @@ class NeonStorageController(MetricsGetter, LogUtils):
headers=self.headers(TokenScope.ADMIN),
)
def tombstone_delete(self, node_id):
log.info(f"tombstone_delete({node_id})")
self.request(
"DELETE",
f"{self.api}/debug/v1/tombstone/{node_id}",
headers=self.headers(TokenScope.ADMIN),
)
def node_drain(self, node_id):
log.info(f"node_drain({node_id})")
self.request(
@@ -2118,14 +2110,6 @@ class NeonStorageController(MetricsGetter, LogUtils):
)
return response.json()
def tombstone_list(self):
response = self.request(
"GET",
f"{self.api}/debug/v1/tombstone",
headers=self.headers(TokenScope.ADMIN),
)
return response.json()
def tenant_shard_dump(self):
"""
Debug listing API: dumps the internal map of tenant shards

View File

@@ -87,9 +87,6 @@ def test_import_from_vanilla(test_output_dir, pg_bin, vanilla_pg, neon_env_build
# Set up pageserver for import
neon_env_builder.enable_pageserver_remote_storage(RemoteStorageKind.LOCAL_FS)
neon_env_builder.storage_controller_config = {
"timelines_onto_safekeepers": True,
}
env = neon_env_builder.init_start()
env.pageserver.tenant_create(tenant)

View File

@@ -30,7 +30,6 @@ def test_safekeeper_delete_timeline(neon_env_builder: NeonEnvBuilder, auth_enabl
env.pageserver.allowed_errors.extend(
[
".*Timeline .* was not found in global map.*",
".*Timeline .* has been deleted.*",
".*Timeline .* was cancelled and cannot be used anymore.*",
]
)
@@ -199,7 +198,6 @@ def test_safekeeper_delete_timeline_under_load(neon_env_builder: NeonEnvBuilder)
env.pageserver.allowed_errors.extend(
[
".*Timeline.*was cancelled.*",
".*Timeline.*has been deleted.*",
".*Timeline.*was not found.*",
]
)

View File

@@ -1836,90 +1836,3 @@ def test_sharding_gc(
shard_gc_cutoff_lsn = Lsn(shard_index["metadata_bytes"]["latest_gc_cutoff_lsn"])
log.info(f"Shard {shard_number} cutoff LSN: {shard_gc_cutoff_lsn}")
assert shard_gc_cutoff_lsn == shard_0_gc_cutoff_lsn
def test_split_ps_delete_old_shard_after_commit(neon_env_builder: NeonEnvBuilder):
"""
Check that PageServer only deletes old shards after the split is committed such that it doesn't
have to download a lot of files during abort.
"""
DBNAME = "regression"
init_shard_count = 4
neon_env_builder.num_pageservers = init_shard_count
stripe_size = 32
env = neon_env_builder.init_start(
initial_tenant_shard_count=init_shard_count, initial_tenant_shard_stripe_size=stripe_size
)
env.storage_controller.allowed_errors.extend(
[
# All split failures log a warning when they enqueue the abort operation
".*Enqueuing background abort.*",
# Tolerate any error logs that mention a failpoint
".*failpoint.*",
]
)
endpoint = env.endpoints.create("main")
endpoint.respec(skip_pg_catalog_updates=False)
endpoint.start()
# Write some initial data.
endpoint.safe_psql(f"CREATE DATABASE {DBNAME}")
endpoint.safe_psql("CREATE TABLE usertable ( YCSB_KEY INT, FIELD0 TEXT);")
for _ in range(1000):
endpoint.safe_psql(
"INSERT INTO usertable SELECT random(), repeat('a', 1000);", log_query=False
)
# Record how many bytes we've downloaded before the split.
def collect_downloaded_bytes() -> list[float | None]:
downloaded_bytes = []
for page_server in env.pageservers:
metric = page_server.http_client().get_metric_value(
"pageserver_remote_ondemand_downloaded_bytes_total"
)
downloaded_bytes.append(metric)
return downloaded_bytes
downloaded_bytes_before = collect_downloaded_bytes()
# Attempt to split the tenant, but fail the split before it completes.
env.storage_controller.configure_failpoints(("shard-split-pre-complete", "return(1)"))
with pytest.raises(StorageControllerApiException):
env.storage_controller.tenant_shard_split(env.initial_tenant, shard_count=16)
# Wait until split is aborted.
def check_split_is_aborted():
tenants = env.storage_controller.tenant_list()
assert len(tenants) == 1
shards = tenants[0]["shards"]
assert len(shards) == 4
for shard in shards:
assert not shard["is_splitting"]
assert not shard["is_reconciling"]
# Make sure all new shards have been deleted.
valid_shards = 0
for ps in env.pageservers:
for tenant_dir in os.listdir(ps.workdir / "tenants"):
try:
tenant_shard_id = TenantShardId.parse(tenant_dir)
valid_shards += 1
assert tenant_shard_id.shard_count == 4
except ValueError:
log.info(f"{tenant_dir} is not valid tenant shard id")
assert valid_shards >= 4
wait_until(check_split_is_aborted)
endpoint.safe_psql("SELECT count(*) from usertable;", log_query=False)
# Make sure we didn't download anything following the aborted split.
downloaded_bytes_after = collect_downloaded_bytes()
assert downloaded_bytes_before == downloaded_bytes_after
endpoint.stop_and_destroy()

View File

@@ -2956,7 +2956,7 @@ def test_storage_controller_leadership_transfer_during_split(
env.storage_controller.allowed_errors.extend(
[".*Unexpected child shard count.*", ".*Enqueuing background abort.*"]
)
pause_failpoint = "shard-split-pre-complete-pause"
pause_failpoint = "shard-split-pre-complete"
env.storage_controller.configure_failpoints((pause_failpoint, "pause"))
split_fut = executor.submit(
@@ -3003,7 +3003,7 @@ def test_storage_controller_leadership_transfer_during_split(
env.storage_controller.request(
"PUT",
f"http://127.0.0.1:{storage_controller_1_port}/debug/v1/failpoints",
json=[{"name": pause_failpoint, "actions": "off"}],
json=[{"name": "shard-split-pre-complete", "actions": "off"}],
headers=env.storage_controller.headers(TokenScope.ADMIN),
)
@@ -3093,58 +3093,6 @@ def test_storage_controller_ps_restarted_during_drain(neon_env_builder: NeonEnvB
wait_until(reconfigure_node_again)
def test_ps_unavailable_after_delete(neon_env_builder: NeonEnvBuilder):
neon_env_builder.num_pageservers = 3
env = neon_env_builder.init_start()
def assert_nodes_count(n: int):
nodes = env.storage_controller.node_list()
assert len(nodes) == n
# Nodes count must remain the same before deletion
assert_nodes_count(3)
ps = env.pageservers[0]
env.storage_controller.node_delete(ps.id)
# After deletion, the node count must be reduced
assert_nodes_count(2)
# Running pageserver CLI init in a separate thread
with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
log.info("Restarting tombstoned pageserver...")
ps.stop()
ps_start_fut = executor.submit(lambda: ps.start(await_active=False))
# After deleted pageserver restart, the node count must remain the same
assert_nodes_count(2)
tombstones = env.storage_controller.tombstone_list()
assert len(tombstones) == 1 and tombstones[0]["id"] == ps.id
env.storage_controller.tombstone_delete(ps.id)
tombstones = env.storage_controller.tombstone_list()
assert len(tombstones) == 0
# Wait for the pageserver start operation to complete.
# If it fails with an exception, we try restarting the pageserver since the failure
# may be due to the storage controller refusing to register the node.
# However, if we get a TimeoutError that means the pageserver is completely hung,
# which is an unexpected failure mode that we'll let propagate up.
try:
ps_start_fut.result(timeout=20)
except TimeoutError:
raise
except Exception:
log.info("Restarting deleted pageserver...")
ps.restart()
# Finally, the node can be registered again after tombstone is deleted
wait_until(lambda: assert_nodes_count(3))
def test_storage_controller_timeline_crud_race(neon_env_builder: NeonEnvBuilder):
"""
The storage controller is meant to handle the case where a timeline CRUD operation races

View File

@@ -12,7 +12,7 @@ from typing import TYPE_CHECKING
import pytest
import requests
from fixtures.common_types import Lsn, TenantId, TimelineArchivalState, TimelineId
from fixtures.common_types import Lsn, TenantId, TimelineId
from fixtures.log_helper import log
from fixtures.metrics import (
PAGESERVER_GLOBAL_METRICS,
@@ -299,65 +299,6 @@ def test_pageserver_metrics_removed_after_detach(neon_env_builder: NeonEnvBuilde
assert post_detach_samples == set()
def test_pageserver_metrics_removed_after_offload(neon_env_builder: NeonEnvBuilder):
"""Tests that when a timeline is offloaded, the tenant specific metrics are not left behind"""
neon_env_builder.enable_pageserver_remote_storage(RemoteStorageKind.MOCK_S3)
neon_env_builder.num_safekeepers = 3
env = neon_env_builder.init_start()
tenant_1, _ = env.create_tenant()
timeline_1 = env.create_timeline("test_metrics_removed_after_offload_1", tenant_id=tenant_1)
timeline_2 = env.create_timeline("test_metrics_removed_after_offload_2", tenant_id=tenant_1)
endpoint_tenant1 = env.endpoints.create_start(
"test_metrics_removed_after_offload_1", tenant_id=tenant_1
)
endpoint_tenant2 = env.endpoints.create_start(
"test_metrics_removed_after_offload_2", tenant_id=tenant_1
)
for endpoint in [endpoint_tenant1, endpoint_tenant2]:
with closing(endpoint.connect()) as conn:
with conn.cursor() as cur:
cur.execute("CREATE TABLE t(key int primary key, value text)")
cur.execute("INSERT INTO t SELECT generate_series(1,100000), 'payload'")
cur.execute("SELECT sum(key) FROM t")
assert cur.fetchone() == (5000050000,)
endpoint.stop()
def get_ps_metric_samples_for_timeline(
tenant_id: TenantId, timeline_id: TimelineId
) -> list[Sample]:
ps_metrics = env.pageserver.http_client().get_metrics()
samples = []
for metric_name in ps_metrics.metrics:
for sample in ps_metrics.query_all(
name=metric_name,
filter={"tenant_id": str(tenant_id), "timeline_id": str(timeline_id)},
):
samples.append(sample)
return samples
for timeline in [timeline_1, timeline_2]:
pre_offload_samples = set(
[x.name for x in get_ps_metric_samples_for_timeline(tenant_1, timeline)]
)
assert len(pre_offload_samples) > 0, f"expected at least one sample for {timeline}"
env.pageserver.http_client().timeline_archival_config(
tenant_1,
timeline,
state=TimelineArchivalState.ARCHIVED,
)
env.pageserver.http_client().timeline_offload(tenant_1, timeline)
post_offload_samples = set(
[x.name for x in get_ps_metric_samples_for_timeline(tenant_1, timeline)]
)
assert post_offload_samples == set()
def test_pageserver_with_empty_tenants(neon_env_builder: NeonEnvBuilder):
env = neon_env_builder.init_start()

View File

@@ -433,7 +433,6 @@ def test_wal_backup(neon_env_builder: NeonEnvBuilder):
env.pageserver.allowed_errors.extend(
[
".*Timeline .* was not found in global map.*",
".*Timeline .* has been deleted.*",
".*Timeline .* was cancelled and cannot be used anymore.*",
]
)
@@ -1935,7 +1934,6 @@ def test_membership_api(neon_env_builder: NeonEnvBuilder):
env.pageserver.allowed_errors.extend(
[
".*Timeline .* was not found in global map.*",
".*Timeline .* has been deleted.*",
".*Timeline .* was cancelled and cannot be used anymore.*",
]
)