Compare commits

..

5 Commits

Author SHA1 Message Date
Heikki Linnakangas
16d6898e44 git add missing file 2025-06-12 02:37:59 +03:00
Heikki Linnakangas
10b936bf03 Use a custom Rust implementation to replace the LFC hash table
The new implementation lives in a separately allocated shared memory
area, which could be resized. Resizing it isn't actually implemented
yet, though. It would require some co-operation from the LFC code.
2025-06-05 18:31:29 +03:00
Heikki Linnakangas
6145cfd1c2 Move neon-shmem facility to separate module within the crate 2025-06-05 18:13:03 +03:00
Heikki Linnakangas
96b4de1de6 Make LFC chunk size a compile-time constant
A runtime setting is nicer, but the next commit will replace the hash
table with a different implementation that requires the value size to
be a compile-time constant.
2025-06-05 18:08:40 +03:00
Heikki Linnakangas
9fdf5fbb7e Use a separate freelist to track LFC "holes"
When the LFC is shrunk, we punch holes in the underlying file to
release the disk space to the OS. We tracked it in the same hash table
as the in-use entries, because that was convenient. However, I'm
working on being able to shrink the hash table too, and once we do
that, we'll need some other place to track the holes. Implement a
simple scheme of an in-memory array and a chain of on-disk blocks for
that.
2025-06-05 18:08:35 +03:00
87 changed files with 3036 additions and 2653 deletions

128
Cargo.lock generated
View File

@@ -1086,6 +1086,25 @@ version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5"
[[package]]
name = "cbindgen"
version = "0.28.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "eadd868a2ce9ca38de7eeafdcec9c7065ef89b42b32f0839278d55f35c54d1ff"
dependencies = [
"clap",
"heck 0.4.1",
"indexmap 2.9.0",
"log",
"proc-macro2",
"quote",
"serde",
"serde_json",
"syn 2.0.100",
"tempfile",
"toml",
]
[[package]]
name = "cc"
version = "1.2.16"
@@ -1212,7 +1231,7 @@ version = "4.5.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4ac6a0c7b1a9e9a5186361f67dfa1b88213572f427fb9ab038efb2bd8c582dab"
dependencies = [
"heck",
"heck 0.5.0",
"proc-macro2",
"quote",
"syn 2.0.100",
@@ -1270,6 +1289,14 @@ dependencies = [
"unicode-width",
]
[[package]]
name = "communicator"
version = "0.1.0"
dependencies = [
"cbindgen",
"neon-shmem",
]
[[package]]
name = "compute_api"
version = "0.1.0"
@@ -1445,7 +1472,6 @@ dependencies = [
"regex",
"reqwest",
"safekeeper_api",
"safekeeper_client",
"scopeguard",
"serde",
"serde_json",
@@ -1937,7 +1963,7 @@ checksum = "0892a17df262a24294c382f0d5997571006e7a4348b4327557c4ff1cd4a8bccc"
dependencies = [
"darling",
"either",
"heck",
"heck 0.5.0",
"proc-macro2",
"quote",
"syn 2.0.100",
@@ -2055,7 +2081,6 @@ dependencies = [
"axum-extra",
"camino",
"camino-tempfile",
"clap",
"futures",
"http-body-util",
"itertools 0.10.5",
@@ -2502,6 +2527,18 @@ dependencies = [
"wasm-bindgen",
]
[[package]]
name = "getrandom"
version = "0.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "26145e563e54f2cadc477553f1ec5ee650b00862f0a58bcd12cbdc5f0ea2d2f4"
dependencies = [
"cfg-if",
"libc",
"r-efi",
"wasi 0.14.2+wasi-0.2.4",
]
[[package]]
name = "gettid"
version = "0.1.3"
@@ -2714,6 +2751,12 @@ dependencies = [
"http 1.1.0",
]
[[package]]
name = "heck"
version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8"
[[package]]
name = "heck"
version = "0.5.0"
@@ -3650,7 +3693,7 @@ version = "0.0.22"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b9e6777fc80a575f9503d908c8b498782a6c3ee88a06cb416dc3941401e43b94"
dependencies = [
"heck",
"heck 0.5.0",
"proc-macro2",
"quote",
"syn 2.0.100",
@@ -3712,7 +3755,7 @@ dependencies = [
"procfs",
"prometheus",
"rand 0.8.5",
"rand_distr",
"rand_distr 0.4.3",
"twox-hash",
]
@@ -3801,6 +3844,8 @@ name = "neon-shmem"
version = "0.1.0"
dependencies = [
"nix 0.30.1",
"rand 0.9.1",
"rand_distr 0.5.1",
"tempfile",
"thiserror 1.0.69",
"workspace_hack",
@@ -5094,7 +5139,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "22505a5c94da8e3b7c2996394d1c933236c4d743e81a410bcca4e6989fc066a4"
dependencies = [
"bytes",
"heck",
"heck 0.5.0",
"itertools 0.12.1",
"log",
"multimap",
@@ -5115,7 +5160,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0c1318b19085f08681016926435853bbf7858f9c082d0999b80550ff5d9abe15"
dependencies = [
"bytes",
"heck",
"heck 0.5.0",
"itertools 0.12.1",
"log",
"multimap",
@@ -5240,7 +5285,7 @@ dependencies = [
"postgres_backend",
"pq_proto",
"rand 0.8.5",
"rand_distr",
"rand_distr 0.4.3",
"rcgen",
"redis",
"regex",
@@ -5344,6 +5389,12 @@ dependencies = [
"proc-macro2",
]
[[package]]
name = "r-efi"
version = "5.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "74765f6d916ee2faa39bc8e68e4f3ed8949b48cccdac59983d287a7cb71ce9c5"
[[package]]
name = "rand"
version = "0.7.3"
@@ -5368,6 +5419,16 @@ dependencies = [
"rand_core 0.6.4",
]
[[package]]
name = "rand"
version = "0.9.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9fbfd9d094a40bf3ae768db9361049ace4c0e04a4fd6b359518bd7b73a73dd97"
dependencies = [
"rand_chacha 0.9.0",
"rand_core 0.9.3",
]
[[package]]
name = "rand_chacha"
version = "0.2.2"
@@ -5388,6 +5449,16 @@ dependencies = [
"rand_core 0.6.4",
]
[[package]]
name = "rand_chacha"
version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb"
dependencies = [
"ppv-lite86",
"rand_core 0.9.3",
]
[[package]]
name = "rand_core"
version = "0.5.1"
@@ -5406,6 +5477,15 @@ dependencies = [
"getrandom 0.2.11",
]
[[package]]
name = "rand_core"
version = "0.9.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "99d9a13982dcf210057a8a78572b2217b667c3beacbf3a0d8b454f6f82837d38"
dependencies = [
"getrandom 0.3.3",
]
[[package]]
name = "rand_distr"
version = "0.4.3"
@@ -5416,6 +5496,16 @@ dependencies = [
"rand 0.8.5",
]
[[package]]
name = "rand_distr"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6a8615d50dcf34fa31f7ab52692afec947c4dd0ab803cc87cb3b0b4570ff7463"
dependencies = [
"num-traits",
"rand 0.9.1",
]
[[package]]
name = "rand_hc"
version = "0.2.0"
@@ -6902,7 +6992,7 @@ version = "0.26.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4c6bee85a5a24955dc440386795aa378cd9cf82acd5f764469152d2270e581be"
dependencies = [
"heck",
"heck 0.5.0",
"proc-macro2",
"quote",
"rustversion",
@@ -8201,6 +8291,15 @@ version = "0.11.0+wasi-snapshot-preview1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423"
[[package]]
name = "wasi"
version = "0.14.2+wasi-0.2.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9683f9a5a998d873c0d21fcbe3c083009670149a8fab228644b8bd36b2c48cb3"
dependencies = [
"wit-bindgen-rt",
]
[[package]]
name = "wasite"
version = "0.1.0"
@@ -8558,6 +8657,15 @@ dependencies = [
"windows-sys 0.48.0",
]
[[package]]
name = "wit-bindgen-rt"
version = "0.39.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6f42320e61fe2cfd34354ecb597f86f413484a798ba44a8ca1165c58d42da6c1"
dependencies = [
"bitflags 2.8.0",
]
[[package]]
name = "workspace_hack"
version = "0.1.0"

View File

@@ -44,6 +44,7 @@ members = [
"libs/proxy/postgres-types2",
"libs/proxy/tokio-postgres2",
"endpoint_storage",
"pgxn/neon/communicator",
]
[workspace.package]
@@ -251,6 +252,7 @@ desim = { version = "0.1", path = "./libs/desim" }
endpoint_storage = { version = "0.0.1", path = "./endpoint_storage/" }
http-utils = { version = "0.1", path = "./libs/http-utils/" }
metrics = { version = "0.1", path = "./libs/metrics/" }
neon-shmem = { version = "0.1", path = "./libs/neon-shmem/" }
pageserver = { path = "./pageserver" }
pageserver_api = { version = "0.1", path = "./libs/pageserver_api/" }
pageserver_client = { path = "./pageserver/client" }
@@ -278,6 +280,7 @@ walproposer = { version = "0.1", path = "./libs/walproposer/" }
workspace_hack = { version = "0.1", path = "./workspace_hack/" }
## Build dependencies
cbindgen = "0.28.0"
criterion = "0.5.1"
rcgen = "0.13"
rstest = "0.18"

View File

@@ -110,19 +110,6 @@ RUN set -e \
# System postgres for use with client libraries (e.g. in storage controller)
postgresql-15 \
openssl \
unzip \
curl \
&& ARCH=$(uname -m) \
&& if [ "$ARCH" = "x86_64" ]; then \
curl "https://awscli.amazonaws.com/awscli-exe-linux-x86_64.zip" -o "awscliv2.zip"; \
elif [ "$ARCH" = "aarch64" ]; then \
curl "https://awscli.amazonaws.com/awscli-exe-linux-aarch64.zip" -o "awscliv2.zip"; \
else \
echo "Unsupported architecture: $ARCH" && exit 1; \
fi \
&& unzip awscliv2.zip \
&& ./aws/install \
&& rm -rf aws awscliv2.zip \
&& rm -f /etc/apt/apt.conf.d/80-retries \
&& rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/* \
&& useradd -d /data neon \

View File

@@ -18,10 +18,12 @@ ifeq ($(BUILD_TYPE),release)
PG_LDFLAGS = $(LDFLAGS)
# Unfortunately, `--profile=...` is a nightly feature
CARGO_BUILD_FLAGS += --release
NEON_CARGO_ARTIFACT_TARGET_DIR = $(ROOT_PROJECT_DIR)/target/release
else ifeq ($(BUILD_TYPE),debug)
PG_CONFIGURE_OPTS = --enable-debug --with-openssl --enable-cassert --enable-depend
PG_CFLAGS += -O0 -g3 $(CFLAGS)
PG_LDFLAGS = $(LDFLAGS)
NEON_CARGO_ARTIFACT_TARGET_DIR = $(ROOT_PROJECT_DIR)/target/debug
else
$(error Bad build type '$(BUILD_TYPE)', see Makefile for options)
endif
@@ -180,11 +182,16 @@ postgres-check-%: postgres-%
.PHONY: neon-pg-ext-%
neon-pg-ext-%: postgres-%
+@echo "Compiling communicator $*"
$(CARGO_CMD_PREFIX) cargo build -p communicator $(CARGO_BUILD_FLAGS)
+@echo "Compiling neon $*"
mkdir -p $(POSTGRES_INSTALL_DIR)/build/neon-$*
$(MAKE) PG_CONFIG=$(POSTGRES_INSTALL_DIR)/$*/bin/pg_config COPT='$(COPT)' \
LIBCOMMUNICATOR_PATH=$(NEON_CARGO_ARTIFACT_TARGET_DIR) \
-C $(POSTGRES_INSTALL_DIR)/build/neon-$* \
-f $(ROOT_PROJECT_DIR)/pgxn/neon/Makefile install
+@echo "Compiling neon_walredo $*"
mkdir -p $(POSTGRES_INSTALL_DIR)/build/neon-walredo-$*
$(MAKE) PG_CONFIG=$(POSTGRES_INSTALL_DIR)/$*/bin/pg_config COPT='$(COPT)' \

View File

@@ -36,7 +36,6 @@ pageserver_api.workspace = true
pageserver_client.workspace = true
postgres_backend.workspace = true
safekeeper_api.workspace = true
safekeeper_client.workspace = true
postgres_connection.workspace = true
storage_broker.workspace = true
http-utils.workspace = true

View File

@@ -45,7 +45,7 @@ use pageserver_api::models::{
use pageserver_api::shard::{DEFAULT_STRIPE_SIZE, ShardCount, ShardStripeSize, TenantShardId};
use postgres_backend::AuthType;
use postgres_connection::parse_host_port;
use safekeeper_api::membership::{SafekeeperGeneration, SafekeeperId};
use safekeeper_api::membership::SafekeeperGeneration;
use safekeeper_api::{
DEFAULT_HTTP_LISTEN_PORT as DEFAULT_SAFEKEEPER_HTTP_PORT,
DEFAULT_PG_LISTEN_PORT as DEFAULT_SAFEKEEPER_PG_PORT,
@@ -1255,45 +1255,6 @@ async fn handle_timeline(cmd: &TimelineCmd, env: &mut local_env::LocalEnv) -> Re
pageserver
.timeline_import(tenant_id, timeline_id, base, pg_wal, args.pg_version)
.await?;
if env.storage_controller.timelines_onto_safekeepers {
println!("Creating timeline on safekeeper ...");
let timeline_info = pageserver
.timeline_info(
TenantShardId::unsharded(tenant_id),
timeline_id,
pageserver_client::mgmt_api::ForceAwaitLogicalSize::No,
)
.await?;
let default_sk = SafekeeperNode::from_env(env, env.safekeepers.first().unwrap());
let default_host = default_sk
.conf
.listen_addr
.clone()
.unwrap_or_else(|| "localhost".to_string());
let mconf = safekeeper_api::membership::Configuration {
generation: SafekeeperGeneration::new(1),
members: safekeeper_api::membership::MemberSet {
m: vec![SafekeeperId {
host: default_host,
id: default_sk.conf.id,
pg_port: default_sk.conf.pg_port,
}],
},
new_members: None,
};
let pg_version = args.pg_version * 10000;
let req = safekeeper_api::models::TimelineCreateRequest {
tenant_id,
timeline_id,
mconf,
pg_version,
system_id: None,
wal_seg_size: None,
start_lsn: timeline_info.last_record_lsn,
commit_lsn: None,
};
default_sk.create_timeline(&req).await?;
}
env.register_branch_mapping(branch_name.to_string(), tenant_id, timeline_id)?;
println!("Done");
}

View File

@@ -635,16 +635,4 @@ impl PageServerNode {
Ok(())
}
pub async fn timeline_info(
&self,
tenant_shard_id: TenantShardId,
timeline_id: TimelineId,
force_await_logical_size: mgmt_api::ForceAwaitLogicalSize,
) -> anyhow::Result<TimelineInfo> {
let timeline_info = self
.http_client
.timeline_info(tenant_shard_id, timeline_id, force_await_logical_size)
.await?;
Ok(timeline_info)
}
}

View File

@@ -6,6 +6,7 @@
//! .neon/safekeepers/<safekeeper id>
//! ```
use std::error::Error as _;
use std::future::Future;
use std::io::Write;
use std::path::PathBuf;
use std::time::Duration;
@@ -13,9 +14,9 @@ use std::{io, result};
use anyhow::Context;
use camino::Utf8PathBuf;
use http_utils::error::HttpErrorBody;
use postgres_connection::PgConnectionConfig;
use safekeeper_api::models::TimelineCreateRequest;
use safekeeper_client::mgmt_api;
use reqwest::{IntoUrl, Method};
use thiserror::Error;
use utils::auth::{Claims, Scope};
use utils::id::NodeId;
@@ -34,14 +35,25 @@ pub enum SafekeeperHttpError {
type Result<T> = result::Result<T, SafekeeperHttpError>;
fn err_from_client_err(err: mgmt_api::Error) -> SafekeeperHttpError {
use mgmt_api::Error::*;
match err {
ApiError(_, str) => SafekeeperHttpError::Response(str),
Cancelled => SafekeeperHttpError::Response("Cancelled".to_owned()),
ReceiveBody(err) => SafekeeperHttpError::Transport(err),
ReceiveErrorBody(err) => SafekeeperHttpError::Response(err),
Timeout(str) => SafekeeperHttpError::Response(format!("timeout: {str}")),
pub(crate) trait ResponseErrorMessageExt: Sized {
fn error_from_body(self) -> impl Future<Output = Result<Self>> + Send;
}
impl ResponseErrorMessageExt for reqwest::Response {
async fn error_from_body(self) -> Result<Self> {
let status = self.status();
if !(status.is_client_error() || status.is_server_error()) {
return Ok(self);
}
// reqwest does not export its error construction utility functions, so let's craft the message ourselves
let url = self.url().to_owned();
Err(SafekeeperHttpError::Response(
match self.json::<HttpErrorBody>().await {
Ok(err_body) => format!("Error: {}", err_body.msg),
Err(_) => format!("Http error ({}) at {}.", status.as_u16(), url),
},
))
}
}
@@ -58,8 +70,9 @@ pub struct SafekeeperNode {
pub pg_connection_config: PgConnectionConfig,
pub env: LocalEnv,
pub http_client: mgmt_api::Client,
pub http_client: reqwest::Client,
pub listen_addr: String,
pub http_base_url: String,
}
impl SafekeeperNode {
@@ -69,14 +82,13 @@ impl SafekeeperNode {
} else {
"127.0.0.1".to_string()
};
let jwt = None;
let http_base_url = format!("http://{}:{}", listen_addr, conf.http_port);
SafekeeperNode {
id: conf.id,
conf: conf.clone(),
pg_connection_config: Self::safekeeper_connection_config(&listen_addr, conf.pg_port),
env: env.clone(),
http_client: mgmt_api::Client::new(env.create_http_client(), http_base_url, jwt),
http_client: env.create_http_client(),
http_base_url: format!("http://{}:{}/v1", listen_addr, conf.http_port),
listen_addr,
}
}
@@ -266,19 +278,20 @@ impl SafekeeperNode {
)
}
pub async fn check_status(&self) -> Result<()> {
self.http_client
.status()
.await
.map_err(err_from_client_err)?;
Ok(())
fn http_request<U: IntoUrl>(&self, method: Method, url: U) -> reqwest::RequestBuilder {
// TODO: authentication
//if self.env.auth_type == AuthType::NeonJWT {
// builder = builder.bearer_auth(&self.env.safekeeper_auth_token)
//}
self.http_client.request(method, url)
}
pub async fn create_timeline(&self, req: &TimelineCreateRequest) -> Result<()> {
self.http_client
.create_timeline(req)
.await
.map_err(err_from_client_err)?;
pub async fn check_status(&self) -> Result<()> {
self.http_request(Method::GET, format!("{}/{}", self.http_base_url, "status"))
.send()
.await?
.error_from_body()
.await?;
Ok(())
}
}

View File

@@ -61,16 +61,10 @@ enum Command {
#[arg(long)]
scheduling: Option<NodeSchedulingPolicy>,
},
// Set a node status as deleted.
NodeDelete {
#[arg(long)]
node_id: NodeId,
},
/// Delete a tombstone of node from the storage controller.
NodeDeleteTombstone {
#[arg(long)]
node_id: NodeId,
},
/// Modify a tenant's policies in the storage controller
TenantPolicy {
#[arg(long)]
@@ -88,8 +82,6 @@ enum Command {
},
/// List nodes known to the storage controller
Nodes {},
/// List soft deleted nodes known to the storage controller
NodeTombstones {},
/// List tenants known to the storage controller
Tenants {
/// If this field is set, it will list the tenants on a specific node
@@ -908,39 +900,6 @@ async fn main() -> anyhow::Result<()> {
.dispatch::<(), ()>(Method::DELETE, format!("control/v1/node/{node_id}"), None)
.await?;
}
Command::NodeDeleteTombstone { node_id } => {
storcon_client
.dispatch::<(), ()>(
Method::DELETE,
format!("debug/v1/tombstone/{node_id}"),
None,
)
.await?;
}
Command::NodeTombstones {} => {
let mut resp = storcon_client
.dispatch::<(), Vec<NodeDescribeResponse>>(
Method::GET,
"debug/v1/tombstone".to_string(),
None,
)
.await?;
resp.sort_by(|a, b| a.listen_http_addr.cmp(&b.listen_http_addr));
let mut table = comfy_table::Table::new();
table.set_header(["Id", "Hostname", "AZ", "Scheduling", "Availability"]);
for node in resp {
table.add_row([
format!("{}", node.id),
node.listen_http_addr,
node.availability_zone_id,
format!("{:?}", node.scheduling),
format!("{:?}", node.availability),
]);
}
println!("{table}");
}
Command::TenantSetTimeBasedEviction {
tenant_id,
period,

View File

@@ -8,7 +8,6 @@ anyhow.workspace = true
axum-extra.workspace = true
axum.workspace = true
camino.workspace = true
clap.workspace = true
futures.workspace = true
jsonwebtoken.workspace = true
prometheus.workspace = true

View File

@@ -4,8 +4,6 @@
//! for large computes.
mod app;
use anyhow::Context;
use clap::Parser;
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use tracing::info;
use utils::logging;
@@ -14,26 +12,9 @@ const fn max_upload_file_limit() -> usize {
100 * 1024 * 1024
}
const fn listen() -> SocketAddr {
SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 51243)
}
#[derive(Parser)]
struct Args {
#[arg(exclusive = true)]
config_file: Option<String>,
#[arg(long, default_value = "false", requires = "config")]
/// to allow testing k8s helm chart where we don't have s3 credentials
no_s3_check_on_startup: bool,
#[arg(long, value_name = "FILE")]
/// inline config mode for k8s helm chart
config: Option<String>,
}
#[derive(serde::Deserialize)]
#[serde(tag = "type")]
struct Config {
#[serde(default = "listen")]
listen: std::net::SocketAddr,
pemfile: camino::Utf8PathBuf,
#[serde(flatten)]
@@ -50,18 +31,13 @@ async fn main() -> anyhow::Result<()> {
logging::Output::Stdout,
)?;
let args = Args::parse();
let config: Config = if let Some(config_path) = args.config_file {
info!("Reading config from {config_path}");
let config = std::fs::read_to_string(config_path)?;
serde_json::from_str(&config).context("parsing config")?
} else if let Some(config) = args.config {
info!("Reading inline config");
serde_json::from_str(&config).context("parsing config")?
} else {
anyhow::bail!("Supply either config file path or --config=inline-config");
};
let config: String = std::env::args().skip(1).take(1).collect();
if config.is_empty() {
anyhow::bail!("Usage: endpoint_storage config.json")
}
info!("Reading config from {config}");
let config = std::fs::read_to_string(config.clone())?;
let config: Config = serde_json::from_str(&config).context("parsing config")?;
info!("Reading pemfile from {}", config.pemfile.clone());
let pemfile = std::fs::read(config.pemfile.clone())?;
info!("Loading public key from {}", config.pemfile.clone());
@@ -72,9 +48,7 @@ async fn main() -> anyhow::Result<()> {
let storage = remote_storage::GenericRemoteStorage::from_config(&config.storage_config).await?;
let cancel = tokio_util::sync::CancellationToken::new();
if !args.no_s3_check_on_startup {
app::check_storage_permissions(&storage, cancel.clone()).await?;
}
app::check_storage_permissions(&storage, cancel.clone()).await?;
let proxy = std::sync::Arc::new(endpoint_storage::Storage {
auth,

View File

@@ -6,8 +6,12 @@ license.workspace = true
[dependencies]
thiserror.workspace = true
nix.workspace=true
nix.workspace = true
workspace_hack = { version = "0.1", path = "../../workspace_hack" }
[dev-dependencies]
rand = "0.9.1"
rand_distr = "0.5.1"
[target.'cfg(target_os = "macos")'.dependencies]
tempfile = "3.14.0"

304
libs/neon-shmem/src/hash.rs Normal file
View File

@@ -0,0 +1,304 @@
//! Hash table implementation on top of 'shmem'
//!
//! Features required in the long run by the communicator project:
//!
//! [X] Accessible from both Postgres processes and rust threads in the communicator process
//! [X] Low latency
//! [ ] Scalable to lots of concurrent accesses (currently relies on caller for locking)
//! [ ] Resizable
use std::fmt::Debug;
use std::hash::{DefaultHasher, Hash, Hasher};
use std::mem::MaybeUninit;
use crate::shmem::ShmemHandle;
mod core;
pub mod entry;
#[cfg(test)]
mod tests;
use core::CoreHashMap;
use entry::{Entry, OccupiedEntry};
#[derive(Debug)]
pub struct OutOfMemoryError();
pub struct HashMapInit<'a, K, V> {
// Hash table can be allocated in a fixed memory area, or in a resizeable ShmemHandle.
shmem_handle: Option<ShmemHandle>,
shared_ptr: *mut HashMapShared<'a, K, V>,
}
pub struct HashMapAccess<'a, K, V> {
shmem_handle: Option<ShmemHandle>,
shared_ptr: *mut HashMapShared<'a, K, V>,
}
unsafe impl<'a, K: Sync, V: Sync> Sync for HashMapAccess<'a, K, V> {}
unsafe impl<'a, K: Send, V: Send> Send for HashMapAccess<'a, K, V> {}
impl<'a, K, V> HashMapInit<'a, K, V> {
pub fn attach_writer(self) -> HashMapAccess<'a, K, V> {
HashMapAccess {
shmem_handle: self.shmem_handle,
shared_ptr: self.shared_ptr,
}
}
pub fn attach_reader(self) -> HashMapAccess<'a, K, V> {
// no difference to attach_writer currently
self.attach_writer()
}
}
/// This is stored in the shared memory area
///
/// NOTE: We carve out the parts from a contiguous chunk. Growing and shrinking the hash table
/// relies on the memory layout! The data structures are laid out in the contiguous shared memory
/// area as follows:
///
/// HashMapShared
/// [buckets]
/// [dictionary]
///
/// In between the above parts, there can be padding bytes to align the parts correctly.
struct HashMapShared<'a, K, V> {
inner: CoreHashMap<'a, K, V>,
}
impl<'a, K, V> HashMapInit<'a, K, V>
where
K: Clone + Hash + Eq,
{
pub fn estimate_size(num_buckets: u32) -> usize {
// add some margin to cover alignment etc.
CoreHashMap::<K, V>::estimate_size(num_buckets) + size_of::<HashMapShared<K, V>>() + 1000
}
pub fn init_in_fixed_area(
num_buckets: u32,
area: &'a mut [MaybeUninit<u8>],
) -> HashMapInit<'a, K, V> {
Self::init_common(num_buckets, None, area.as_mut_ptr().cast(), area.len())
}
/// Initialize a new hash map in the given shared memory area
pub fn init_in_shmem(num_buckets: u32, mut shmem: ShmemHandle) -> HashMapInit<'a, K, V> {
let size = Self::estimate_size(num_buckets);
shmem
.set_size(size)
.expect("could not resize shared memory area");
let ptr = unsafe { shmem.data_ptr.as_mut() };
Self::init_common(num_buckets, Some(shmem), ptr, size)
}
fn init_common(
num_buckets: u32,
shmem_handle: Option<ShmemHandle>,
area_ptr: *mut u8,
area_len: usize,
) -> HashMapInit<'a, K, V> {
// carve out the HashMapShared struct from the area.
let mut ptr: *mut u8 = area_ptr;
let end_ptr: *mut u8 = unsafe { area_ptr.add(area_len) };
ptr = unsafe { ptr.add(ptr.align_offset(align_of::<HashMapShared<K, V>>())) };
let shared_ptr: *mut HashMapShared<K, V> = ptr.cast();
ptr = unsafe { ptr.add(size_of::<HashMapShared<K, V>>()) };
// carve out the buckets
ptr = unsafe { ptr.byte_add(ptr.align_offset(align_of::<core::Bucket<K, V>>())) };
let buckets_ptr = ptr;
ptr = unsafe { ptr.add(size_of::<core::Bucket<K, V>>() * num_buckets as usize) };
// use remaining space for the dictionary
ptr = unsafe { ptr.byte_add(ptr.align_offset(align_of::<u32>())) };
assert!(ptr.addr() < end_ptr.addr());
let dictionary_ptr = ptr;
let dictionary_size = unsafe { end_ptr.byte_offset_from(ptr) / size_of::<u32>() as isize };
assert!(dictionary_size > 0);
let buckets =
unsafe { std::slice::from_raw_parts_mut(buckets_ptr.cast(), num_buckets as usize) };
let dictionary = unsafe {
std::slice::from_raw_parts_mut(dictionary_ptr.cast(), dictionary_size as usize)
};
let hashmap = CoreHashMap::new(buckets, dictionary);
unsafe {
std::ptr::write(shared_ptr, HashMapShared { inner: hashmap });
}
HashMapInit {
shmem_handle: shmem_handle,
shared_ptr,
}
}
}
impl<'a, K, V> HashMapAccess<'a, K, V>
where
K: Clone + Hash + Eq,
{
pub fn get_hash_value(&self, key: &K) -> u64 {
let mut hasher = DefaultHasher::new();
key.hash(&mut hasher);
hasher.finish()
}
pub fn get_with_hash<'e>(&'e self, key: &K, hash: u64) -> Option<&'e V> {
let map = unsafe { self.shared_ptr.as_ref() }.unwrap();
map.inner.get_with_hash(key, hash)
}
pub fn entry_with_hash(&mut self, key: K, hash: u64) -> Entry<'a, '_, K, V> {
let map = unsafe { self.shared_ptr.as_mut() }.unwrap();
map.inner.entry_with_hash(key, hash)
}
pub fn remove_with_hash(&mut self, key: &K, hash: u64) {
let map = unsafe { self.shared_ptr.as_mut() }.unwrap();
match map.inner.entry_with_hash(key.clone(), hash) {
Entry::Occupied(e) => {
e.remove();
}
Entry::Vacant(_) => {}
};
}
pub fn entry_at_bucket(&mut self, pos: usize) -> Option<OccupiedEntry<'a, '_, K, V>> {
let map = unsafe { self.shared_ptr.as_mut() }.unwrap();
map.inner.entry_at_bucket(pos)
}
pub fn get_num_buckets(&self) -> usize {
let map = unsafe { self.shared_ptr.as_ref() }.unwrap();
map.inner.get_num_buckets()
}
/// Return the key and value stored in bucket with given index. This can be used to
/// iterate through the hash map. (An Iterator might be nicer. The communicator's
/// clock algorithm needs to _slowly_ iterate through all buckets with its clock hand,
/// without holding a lock. If we switch to an Iterator, it must not hold the lock.)
pub fn get_at_bucket(&self, pos: usize) -> Option<&(K, V)> {
let map = unsafe { self.shared_ptr.as_ref() }.unwrap();
if pos >= map.inner.buckets.len() {
return None;
}
let bucket = &map.inner.buckets[pos];
bucket.inner.as_ref()
}
pub fn get_bucket_for_value(&self, val_ptr: *const V) -> usize {
let map = unsafe { self.shared_ptr.as_ref() }.unwrap();
let origin = map.inner.buckets.as_ptr();
let idx = (val_ptr as usize - origin as usize) / (size_of::<V>() as usize);
assert!(idx < map.inner.buckets.len());
idx
}
// for metrics
pub fn get_num_buckets_in_use(&self) -> usize {
let map = unsafe { self.shared_ptr.as_ref() }.unwrap();
map.inner.buckets_in_use as usize
}
/// Grow
///
/// 1. grow the underlying shared memory area
/// 2. Initialize new buckets. This overwrites the current dictionary
/// 3. Recalculate the dictionary
pub fn grow(&mut self, num_buckets: u32) -> Result<(), crate::shmem::Error> {
let map = unsafe { self.shared_ptr.as_mut() }.unwrap();
let inner = &mut map.inner;
let old_num_buckets = inner.buckets.len() as u32;
if num_buckets < old_num_buckets {
panic!("grow called with a smaller number of buckets");
}
if num_buckets == old_num_buckets {
return Ok(());
}
let shmem_handle = self
.shmem_handle
.as_ref()
.expect("grow called on a fixed-size hash table");
let size_bytes = HashMapInit::<K, V>::estimate_size(num_buckets);
shmem_handle.set_size(size_bytes)?;
let end_ptr: *mut u8 = unsafe { shmem_handle.data_ptr.as_ptr().add(size_bytes) };
// Initialize new buckets. The new buckets are linked to the free list. NB: This overwrites
// the dictionary!
let buckets_ptr = inner.buckets.as_mut_ptr();
unsafe {
for i in old_num_buckets..num_buckets {
let bucket_ptr = buckets_ptr.add(i as usize);
bucket_ptr.write(core::Bucket {
next: if i < num_buckets {
i as u32 + 1
} else {
inner.free_head
},
inner: None,
});
}
}
// Recalculate the dictionary
let buckets;
let dictionary;
unsafe {
let buckets_end_ptr = buckets_ptr.add(num_buckets as usize);
let dictionary_ptr: *mut u32 = buckets_end_ptr
.byte_add(buckets_end_ptr.align_offset(align_of::<u32>()))
.cast();
let dictionary_size: usize =
end_ptr.byte_offset_from(buckets_end_ptr) as usize / size_of::<u32>();
buckets = std::slice::from_raw_parts_mut(buckets_ptr, num_buckets as usize);
dictionary = std::slice::from_raw_parts_mut(dictionary_ptr, dictionary_size);
}
for i in 0..dictionary.len() {
dictionary[i] = core::INVALID_POS;
}
for i in 0..old_num_buckets as usize {
if buckets[i].inner.is_none() {
continue;
}
let mut hasher = DefaultHasher::new();
buckets[i].inner.as_ref().unwrap().0.hash(&mut hasher);
let hash = hasher.finish();
let pos: usize = (hash % dictionary.len() as u64) as usize;
buckets[i].next = dictionary[pos];
dictionary[pos] = i as u32;
}
// Finally, update the CoreHashMap struct
inner.dictionary = dictionary;
inner.buckets = buckets;
inner.free_head = old_num_buckets;
Ok(())
}
// TODO: Shrinking is a multi-step process that requires co-operation from the caller
//
// 1. The caller must first call begin_shrink(). That forbids allocation of higher-numbered
// buckets.
//
// 2. Next, the caller must evict all entries in higher-numbered buckets.
//
// 3. Finally, call finish_shrink(). This recomputes the dictionary and shrinks the underlying
// shmem area
}

View File

@@ -0,0 +1,174 @@
//! Simple hash table with chaining
//!
//! # Resizing
//!
use std::hash::Hash;
use std::mem::MaybeUninit;
use crate::hash::entry::{Entry, OccupiedEntry, PrevPos, VacantEntry};
pub(crate) const INVALID_POS: u32 = u32::MAX;
// Bucket
pub(crate) struct Bucket<K, V> {
pub(crate) next: u32,
pub(crate) inner: Option<(K, V)>,
}
pub(crate) struct CoreHashMap<'a, K, V> {
pub(crate) dictionary: &'a mut [u32],
pub(crate) buckets: &'a mut [Bucket<K, V>],
pub(crate) free_head: u32,
pub(crate) _user_list_head: u32,
// metrics
pub(crate) buckets_in_use: u32,
}
#[derive(Debug)]
pub struct FullError();
impl<'a, K: Hash + Eq, V> CoreHashMap<'a, K, V>
where
K: Clone + Hash + Eq,
{
const FILL_FACTOR: f32 = 0.60;
pub fn estimate_size(num_buckets: u32) -> usize {
let mut size = 0;
// buckets
size += size_of::<Bucket<K, V>>() * num_buckets as usize;
// dictionary
size += (f32::ceil((size_of::<u32>() * num_buckets as usize) as f32 / Self::FILL_FACTOR))
as usize;
size
}
pub fn new(
buckets: &'a mut [MaybeUninit<Bucket<K, V>>],
dictionary: &'a mut [MaybeUninit<u32>],
) -> CoreHashMap<'a, K, V> {
// Initialize the buckets
for i in 0..buckets.len() {
buckets[i].write(Bucket {
next: if i < buckets.len() - 1 {
i as u32 + 1
} else {
INVALID_POS
},
inner: None,
});
}
// Initialize the dictionary
for i in 0..dictionary.len() {
dictionary[i].write(INVALID_POS);
}
// TODO: use std::slice::assume_init_mut() once it stabilizes
let buckets =
unsafe { std::slice::from_raw_parts_mut(buckets.as_mut_ptr().cast(), buckets.len()) };
let dictionary = unsafe {
std::slice::from_raw_parts_mut(dictionary.as_mut_ptr().cast(), dictionary.len())
};
CoreHashMap {
dictionary,
buckets,
free_head: 0,
buckets_in_use: 0,
_user_list_head: INVALID_POS,
}
}
pub fn get_with_hash(&self, key: &K, hash: u64) -> Option<&V> {
let mut next = self.dictionary[hash as usize % self.dictionary.len()];
loop {
if next == INVALID_POS {
return None;
}
let bucket = &self.buckets[next as usize];
let (bucket_key, bucket_value) = bucket.inner.as_ref().expect("entry is in use");
if bucket_key == key {
return Some(&bucket_value);
}
next = bucket.next;
}
}
// all updates are done through Entry
pub fn entry_with_hash(&mut self, key: K, hash: u64) -> Entry<'a, '_, K, V> {
let dict_pos = hash as usize % self.dictionary.len();
let first = self.dictionary[dict_pos];
if first == INVALID_POS {
// no existing entry
return Entry::Vacant(VacantEntry {
map: self,
key,
dict_pos: dict_pos as u32,
});
}
let mut prev_pos = PrevPos::First(dict_pos as u32);
let mut next = first;
loop {
let bucket = &mut self.buckets[next as usize];
let (bucket_key, _bucket_value) = bucket.inner.as_mut().expect("entry is in use");
if *bucket_key == key {
// found existing entry
return Entry::Occupied(OccupiedEntry {
map: self,
_key: key,
prev_pos,
bucket_pos: next,
});
}
if bucket.next == INVALID_POS {
// No existing entry
return Entry::Vacant(VacantEntry {
map: self,
key,
dict_pos: dict_pos as u32,
});
}
prev_pos = PrevPos::Chained(next);
next = bucket.next;
}
}
pub fn get_num_buckets(&self) -> usize {
self.buckets.len()
}
pub fn entry_at_bucket(&mut self, pos: usize) -> Option<OccupiedEntry<K, V>> {
if pos >= self.buckets.len() {
return None;
}
todo!()
//self.buckets[pos].inner.as_ref()
}
pub(crate) fn alloc_bucket(&mut self, key: K, value: V) -> Result<u32, FullError> {
let pos = self.free_head;
if pos == INVALID_POS {
return Err(FullError());
}
let bucket = &mut self.buckets[pos as usize];
self.free_head = bucket.next;
self.buckets_in_use += 1;
bucket.next = INVALID_POS;
bucket.inner = Some((key, value));
return Ok(pos);
}
}

View File

@@ -0,0 +1,91 @@
//! Like std::collections::hash_map::Entry;
use crate::hash::core::{CoreHashMap, FullError, INVALID_POS};
use std::hash::Hash;
use std::mem;
pub enum Entry<'a, 'b, K, V> {
Occupied(OccupiedEntry<'a, 'b, K, V>),
Vacant(VacantEntry<'a, 'b, K, V>),
}
pub(crate) enum PrevPos {
First(u32),
Chained(u32),
}
pub struct OccupiedEntry<'a, 'b, K, V> {
pub(crate) map: &'b mut CoreHashMap<'a, K, V>,
pub(crate) _key: K, // The key of the occupied entry
pub(crate) prev_pos: PrevPos,
pub(crate) bucket_pos: u32, // The position of the bucket in the CoreHashMap's buckets array
}
impl<'a, 'b, K, V> OccupiedEntry<'a, 'b, K, V> {
pub fn get(&self) -> &V {
&self.map.buckets[self.bucket_pos as usize]
.inner
.as_ref()
.unwrap()
.1
}
pub fn get_mut(&mut self) -> &mut V {
&mut self.map.buckets[self.bucket_pos as usize]
.inner
.as_mut()
.unwrap()
.1
}
pub fn insert(&mut self, value: V) -> V {
let bucket = &mut self.map.buckets[self.bucket_pos as usize];
// This assumes inner is Some, which it must be for an OccupiedEntry
let old_value = mem::replace(&mut bucket.inner.as_mut().unwrap().1, value);
old_value
}
pub fn remove(self) -> V {
// CoreHashMap::remove returns Option<(K, V)>. We know it's Some for an OccupiedEntry.
let bucket = &mut self.map.buckets[self.bucket_pos as usize];
// unlink it from the chain
match self.prev_pos {
PrevPos::First(dict_pos) => self.map.dictionary[dict_pos as usize] = bucket.next,
PrevPos::Chained(bucket_pos) => {
self.map.buckets[bucket_pos as usize].next = bucket.next
}
}
// and add it to the freelist
let bucket = &mut self.map.buckets[self.bucket_pos as usize];
let old_value = bucket.inner.take();
bucket.next = self.map.free_head;
self.map.free_head = self.bucket_pos;
self.map.buckets_in_use -= 1;
return old_value.unwrap().1;
}
}
pub struct VacantEntry<'a, 'b, K, V> {
pub(crate) map: &'b mut CoreHashMap<'a, K, V>,
pub(crate) key: K, // The key to insert
pub(crate) dict_pos: u32,
}
impl<'a, 'b, K: Clone + Hash + Eq, V> VacantEntry<'a, 'b, K, V> {
pub fn insert(self, value: V) -> Result<&'b mut V, FullError> {
let pos = self.map.alloc_bucket(self.key, value)?;
if pos == INVALID_POS {
return Err(FullError());
}
let bucket = &mut self.map.buckets[pos as usize];
bucket.next = self.map.dictionary[self.dict_pos as usize];
self.map.dictionary[self.dict_pos as usize] = pos;
let result = &mut self.map.buckets[pos as usize].inner.as_mut().unwrap().1;
return Ok(result);
}
}

View File

@@ -0,0 +1,220 @@
use std::collections::BTreeMap;
use std::collections::HashSet;
use std::fmt::{Debug, Formatter};
use std::sync::atomic::{AtomicUsize, Ordering};
use crate::hash::HashMapAccess;
use crate::hash::HashMapInit;
use crate::hash::UpdateAction;
use crate::shmem::ShmemHandle;
use rand::seq::SliceRandom;
use rand::{Rng, RngCore};
use rand_distr::Zipf;
const TEST_KEY_LEN: usize = 16;
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
struct TestKey([u8; TEST_KEY_LEN]);
impl From<&TestKey> for u128 {
fn from(val: &TestKey) -> u128 {
u128::from_be_bytes(val.0)
}
}
impl From<u128> for TestKey {
fn from(val: u128) -> TestKey {
TestKey(val.to_be_bytes())
}
}
impl<'a> From<&'a [u8]> for TestKey {
fn from(bytes: &'a [u8]) -> TestKey {
TestKey(bytes.try_into().unwrap())
}
}
fn test_inserts<K: Into<TestKey> + Copy>(keys: &[K]) {
const MAX_MEM_SIZE: usize = 10000000;
let shmem = ShmemHandle::new("test_inserts", 0, MAX_MEM_SIZE).unwrap();
let init_struct = HashMapInit::<TestKey, usize>::init_in_shmem(100000, shmem);
let w = init_struct.attach_writer();
for (idx, k) in keys.iter().enumerate() {
let res = w.insert(&(*k).into(), idx);
assert!(res.is_ok());
}
for (idx, k) in keys.iter().enumerate() {
let x = w.get(&(*k).into());
let value = x.as_deref().copied();
assert_eq!(value, Some(idx));
}
//eprintln!("stats: {:?}", tree_writer.get_statistics());
}
#[test]
fn dense() {
// This exercises splitting a node with prefix
let keys: &[u128] = &[0, 1, 2, 3, 256];
test_inserts(keys);
// Dense keys
let mut keys: Vec<u128> = (0..10000).collect();
test_inserts(&keys);
// Do the same in random orders
for _ in 1..10 {
keys.shuffle(&mut rand::rng());
test_inserts(&keys);
}
}
#[test]
fn sparse() {
// sparse keys
let mut keys: Vec<TestKey> = Vec::new();
let mut used_keys = HashSet::new();
for _ in 0..10000 {
loop {
let key = rand::random::<u128>();
if used_keys.get(&key).is_some() {
continue;
}
used_keys.insert(key);
keys.push(key.into());
break;
}
}
test_inserts(&keys);
}
struct TestValue(AtomicUsize);
impl TestValue {
fn new(val: usize) -> TestValue {
TestValue(AtomicUsize::new(val))
}
fn load(&self) -> usize {
self.0.load(Ordering::Relaxed)
}
}
impl Clone for TestValue {
fn clone(&self) -> TestValue {
TestValue::new(self.load())
}
}
impl Debug for TestValue {
fn fmt(&self, fmt: &mut Formatter<'_>) -> Result<(), std::fmt::Error> {
write!(fmt, "{:?}", self.load())
}
}
#[derive(Clone, Debug)]
struct TestOp(TestKey, Option<usize>);
fn apply_op(
op: &TestOp,
sut: &HashMapAccess<TestKey, TestValue>,
shadow: &mut BTreeMap<TestKey, usize>,
) {
eprintln!("applying op: {op:?}");
// apply the change to the shadow tree first
let shadow_existing = if let Some(v) = op.1 {
shadow.insert(op.0, v)
} else {
shadow.remove(&op.0)
};
// apply to Art tree
sut.update_with_fn(&op.0, |existing| {
assert_eq!(existing.map(TestValue::load), shadow_existing);
match (existing, op.1) {
(None, None) => UpdateAction::Nothing,
(None, Some(new_val)) => UpdateAction::Insert(TestValue::new(new_val)),
(Some(_old_val), None) => UpdateAction::Remove,
(Some(old_val), Some(new_val)) => {
old_val.0.store(new_val, Ordering::Relaxed);
UpdateAction::Nothing
}
}
})
.expect("out of memory");
}
#[test]
fn random_ops() {
const MAX_MEM_SIZE: usize = 10000000;
let shmem = ShmemHandle::new("test_inserts", 0, MAX_MEM_SIZE).unwrap();
let init_struct = HashMapInit::<TestKey, TestValue>::init_in_shmem(100000, shmem);
let writer = init_struct.attach_writer();
let mut shadow: std::collections::BTreeMap<TestKey, usize> = BTreeMap::new();
let distribution = Zipf::new(u128::MAX as f64, 1.1).unwrap();
let mut rng = rand::rng();
for i in 0..100000 {
let key: TestKey = (rng.sample(distribution) as u128).into();
let op = TestOp(key, if rng.random_bool(0.75) { Some(i) } else { None });
apply_op(&op, &writer, &mut shadow);
if i % 1000 == 0 {
eprintln!("{i} ops processed");
//eprintln!("stats: {:?}", tree_writer.get_statistics());
//test_iter(&tree_writer, &shadow);
}
}
}
#[test]
fn test_grow() {
const MEM_SIZE: usize = 10000000;
let shmem = ShmemHandle::new("test_grow", 0, MEM_SIZE).unwrap();
let init_struct = HashMapInit::<TestKey, TestValue>::init_in_shmem(1000, shmem);
let writer = init_struct.attach_writer();
let mut shadow: std::collections::BTreeMap<TestKey, usize> = BTreeMap::new();
let mut rng = rand::rng();
for i in 0..10000 {
let key: TestKey = ((rng.next_u32() % 1000) as u128).into();
let op = TestOp(key, if rng.random_bool(0.75) { Some(i) } else { None });
apply_op(&op, &writer, &mut shadow);
if i % 1000 == 0 {
eprintln!("{i} ops processed");
//eprintln!("stats: {:?}", tree_writer.get_statistics());
//test_iter(&tree_writer, &shadow);
}
}
writer.grow(1500).unwrap();
for i in 0..10000 {
let key: TestKey = ((rng.next_u32() % 1500) as u128).into();
let op = TestOp(key, if rng.random_bool(0.75) { Some(i) } else { None });
apply_op(&op, &writer, &mut shadow);
if i % 1000 == 0 {
eprintln!("{i} ops processed");
//eprintln!("stats: {:?}", tree_writer.get_statistics());
//test_iter(&tree_writer, &shadow);
}
}
}

View File

@@ -1,418 +1,4 @@
//! Shared memory utilities for neon communicator
use std::num::NonZeroUsize;
use std::os::fd::{AsFd, BorrowedFd, OwnedFd};
use std::ptr::NonNull;
use std::sync::atomic::{AtomicUsize, Ordering};
use nix::errno::Errno;
use nix::sys::mman::MapFlags;
use nix::sys::mman::ProtFlags;
use nix::sys::mman::mmap as nix_mmap;
use nix::sys::mman::munmap as nix_munmap;
use nix::unistd::ftruncate as nix_ftruncate;
/// ShmemHandle represents a shared memory area that can be shared by processes over fork().
/// Unlike shared memory allocated by Postgres, this area is resizable, up to 'max_size' that's
/// specified at creation.
///
/// The area is backed by an anonymous file created with memfd_create(). The full address space for
/// 'max_size' is reserved up-front with mmap(), but whenever you call [`ShmemHandle::set_size`],
/// the underlying file is resized. Do not access the area beyond the current size. Currently, that
/// will cause the file to be expanded, but we might use mprotect() etc. to enforce that in the
/// future.
pub struct ShmemHandle {
/// memfd file descriptor
fd: OwnedFd,
max_size: usize,
// Pointer to the beginning of the shared memory area. The header is stored there.
shared_ptr: NonNull<SharedStruct>,
// Pointer to the beginning of the user data
pub data_ptr: NonNull<u8>,
}
/// This is stored at the beginning in the shared memory area.
struct SharedStruct {
max_size: usize,
/// Current size of the backing file. The high-order bit is used for the RESIZE_IN_PROGRESS flag
current_size: AtomicUsize,
}
const RESIZE_IN_PROGRESS: usize = 1 << 63;
const HEADER_SIZE: usize = std::mem::size_of::<SharedStruct>();
/// Error type returned by the ShmemHandle functions.
#[derive(thiserror::Error, Debug)]
#[error("{msg}: {errno}")]
pub struct Error {
pub msg: String,
pub errno: Errno,
}
impl Error {
fn new(msg: &str, errno: Errno) -> Error {
Error {
msg: msg.to_string(),
errno,
}
}
}
impl ShmemHandle {
/// Create a new shared memory area. To communicate between processes, the processes need to be
/// fork()'d after calling this, so that the ShmemHandle is inherited by all processes.
///
/// If the ShmemHandle is dropped, the memory is unmapped from the current process. Other
/// processes can continue using it, however.
pub fn new(name: &str, initial_size: usize, max_size: usize) -> Result<ShmemHandle, Error> {
// create the backing anonymous file.
let fd = create_backing_file(name)?;
Self::new_with_fd(fd, initial_size, max_size)
}
fn new_with_fd(
fd: OwnedFd,
initial_size: usize,
max_size: usize,
) -> Result<ShmemHandle, Error> {
// We reserve the high-order bit for the RESIZE_IN_PROGRESS flag, and the actual size
// is a little larger than this because of the SharedStruct header. Make the upper limit
// somewhat smaller than that, because with anything close to that, you'll run out of
// memory anyway.
if max_size >= 1 << 48 {
panic!("max size {} too large", max_size);
}
if initial_size > max_size {
panic!("initial size {initial_size} larger than max size {max_size}");
}
// The actual initial / max size is the one given by the caller, plus the size of
// 'SharedStruct'.
let initial_size = HEADER_SIZE + initial_size;
let max_size = NonZeroUsize::new(HEADER_SIZE + max_size).unwrap();
// Reserve address space for it with mmap
//
// TODO: Use MAP_HUGETLB if possible
let start_ptr = unsafe {
nix_mmap(
None,
max_size,
ProtFlags::PROT_READ | ProtFlags::PROT_WRITE,
MapFlags::MAP_SHARED,
&fd,
0,
)
}
.map_err(|e| Error::new("mmap failed: {e}", e))?;
// Reserve space for the initial size
enlarge_file(fd.as_fd(), initial_size as u64)?;
// Initialize the header
let shared: NonNull<SharedStruct> = start_ptr.cast();
unsafe {
shared.write(SharedStruct {
max_size: max_size.into(),
current_size: AtomicUsize::new(initial_size),
})
};
// The user data begins after the header
let data_ptr = unsafe { start_ptr.cast().add(HEADER_SIZE) };
Ok(ShmemHandle {
fd,
max_size: max_size.into(),
shared_ptr: shared,
data_ptr,
})
}
// return reference to the header
fn shared(&self) -> &SharedStruct {
unsafe { self.shared_ptr.as_ref() }
}
/// Resize the shared memory area. 'new_size' must not be larger than the 'max_size' specified
/// when creating the area.
///
/// This may only be called from one process/thread concurrently. We detect that case
/// and return an Error.
pub fn set_size(&self, new_size: usize) -> Result<(), Error> {
let new_size = new_size + HEADER_SIZE;
let shared = self.shared();
if new_size > self.max_size {
panic!(
"new size ({} is greater than max size ({})",
new_size, self.max_size
);
}
assert_eq!(self.max_size, shared.max_size);
// Lock the area by setting the bit in 'current_size'
//
// Ordering::Relaxed would probably be sufficient here, as we don't access any other memory
// and the posix_fallocate/ftruncate call is surely a synchronization point anyway. But
// since this is not performance-critical, better safe than sorry .
let mut old_size = shared.current_size.load(Ordering::Acquire);
loop {
if (old_size & RESIZE_IN_PROGRESS) != 0 {
return Err(Error::new(
"concurrent resize detected",
Errno::UnknownErrno,
));
}
match shared.current_size.compare_exchange(
old_size,
new_size,
Ordering::Acquire,
Ordering::Relaxed,
) {
Ok(_) => break,
Err(x) => old_size = x,
}
}
// Ok, we got the lock.
//
// NB: If anything goes wrong, we *must* clear the bit!
let result = {
use std::cmp::Ordering::{Equal, Greater, Less};
match new_size.cmp(&old_size) {
Less => nix_ftruncate(&self.fd, new_size as i64).map_err(|e| {
Error::new("could not shrink shmem segment, ftruncate failed: {e}", e)
}),
Equal => Ok(()),
Greater => enlarge_file(self.fd.as_fd(), new_size as u64),
}
};
// Unlock
shared.current_size.store(
if result.is_ok() { new_size } else { old_size },
Ordering::Release,
);
result
}
/// Returns the current user-visible size of the shared memory segment.
///
/// NOTE: a concurrent set_size() call can change the size at any time. It is the caller's
/// responsibility not to access the area beyond the current size.
pub fn current_size(&self) -> usize {
let total_current_size =
self.shared().current_size.load(Ordering::Relaxed) & !RESIZE_IN_PROGRESS;
total_current_size - HEADER_SIZE
}
}
impl Drop for ShmemHandle {
fn drop(&mut self) {
// SAFETY: The pointer was obtained from mmap() with the given size.
// We unmap the entire region.
let _ = unsafe { nix_munmap(self.shared_ptr.cast(), self.max_size) };
// The fd is dropped automatically by OwnedFd.
}
}
/// Create a "backing file" for the shared memory area. On Linux, use memfd_create(), to create an
/// anonymous in-memory file. One macos, fall back to a regular file. That's good enough for
/// development and testing, but in production we want the file to stay in memory.
///
/// disable 'unused_variables' warnings, because in the macos path, 'name' is unused.
#[allow(unused_variables)]
fn create_backing_file(name: &str) -> Result<OwnedFd, Error> {
#[cfg(not(target_os = "macos"))]
{
nix::sys::memfd::memfd_create(name, nix::sys::memfd::MFdFlags::empty())
.map_err(|e| Error::new("memfd_create failed: {e}", e))
}
#[cfg(target_os = "macos")]
{
let file = tempfile::tempfile().map_err(|e| {
Error::new(
"could not create temporary file to back shmem area: {e}",
nix::errno::Errno::from_raw(e.raw_os_error().unwrap_or(0)),
)
})?;
Ok(OwnedFd::from(file))
}
}
fn enlarge_file(fd: BorrowedFd, size: u64) -> Result<(), Error> {
// Use posix_fallocate() to enlarge the file. It reserves the space correctly, so that
// we don't get a segfault later when trying to actually use it.
#[cfg(not(target_os = "macos"))]
{
nix::fcntl::posix_fallocate(fd, 0, size as i64).map_err(|e| {
Error::new(
"could not grow shmem segment, posix_fallocate failed: {e}",
e,
)
})
}
// As a fallback on macos, which doesn't have posix_fallocate, use plain 'fallocate'
#[cfg(target_os = "macos")]
{
nix::unistd::ftruncate(fd, size as i64)
.map_err(|e| Error::new("could not grow shmem segment, ftruncate failed: {e}", e))
}
}
#[cfg(test)]
mod tests {
use super::*;
use nix::unistd::ForkResult;
use std::ops::Range;
/// check that all bytes in given range have the expected value.
fn assert_range(ptr: *const u8, expected: u8, range: Range<usize>) {
for i in range {
let b = unsafe { *(ptr.add(i)) };
assert_eq!(expected, b, "unexpected byte at offset {}", i);
}
}
/// Write 'b' to all bytes in the given range
fn write_range(ptr: *mut u8, b: u8, range: Range<usize>) {
unsafe { std::ptr::write_bytes(ptr.add(range.start), b, range.end - range.start) };
}
// simple single-process test of growing and shrinking
#[test]
fn test_shmem_resize() -> Result<(), Error> {
let max_size = 1024 * 1024;
let init_struct = ShmemHandle::new("test_shmem_resize", 0, max_size)?;
assert_eq!(init_struct.current_size(), 0);
// Initial grow
let size1 = 10000;
init_struct.set_size(size1).unwrap();
assert_eq!(init_struct.current_size(), size1);
// Write some data
let data_ptr = init_struct.data_ptr.as_ptr();
write_range(data_ptr, 0xAA, 0..size1);
assert_range(data_ptr, 0xAA, 0..size1);
// Shrink
let size2 = 5000;
init_struct.set_size(size2).unwrap();
assert_eq!(init_struct.current_size(), size2);
// Grow again
let size3 = 20000;
init_struct.set_size(size3).unwrap();
assert_eq!(init_struct.current_size(), size3);
// Try to read it. The area that was shrunk and grown again should read as all zeros now
assert_range(data_ptr, 0xAA, 0..5000);
assert_range(data_ptr, 0, 5000..size1);
// Try to grow beyond max_size
//let size4 = max_size + 1;
//assert!(init_struct.set_size(size4).is_err());
// Dropping init_struct should unmap the memory
drop(init_struct);
Ok(())
}
/// This is used in tests to coordinate between test processes. It's like std::sync::Barrier,
/// but is stored in the shared memory area and works across processes. It's implemented by
/// polling, because e.g. standard rust mutexes are not guaranteed to work across processes.
struct SimpleBarrier {
num_procs: usize,
count: AtomicUsize,
}
impl SimpleBarrier {
unsafe fn init(ptr: *mut SimpleBarrier, num_procs: usize) {
unsafe {
*ptr = SimpleBarrier {
num_procs,
count: AtomicUsize::new(0),
}
}
}
pub fn wait(&self) {
let old = self.count.fetch_add(1, Ordering::Relaxed);
let generation = old / self.num_procs;
let mut current = old + 1;
while current < (generation + 1) * self.num_procs {
std::thread::sleep(std::time::Duration::from_millis(10));
current = self.count.load(Ordering::Relaxed);
}
}
}
#[test]
fn test_multi_process() {
// Initialize
let max_size = 1_000_000_000_000;
let init_struct = ShmemHandle::new("test_multi_process", 0, max_size).unwrap();
let ptr = init_struct.data_ptr.as_ptr();
// Store the SimpleBarrier in the first 1k of the area.
init_struct.set_size(10000).unwrap();
let barrier_ptr: *mut SimpleBarrier = unsafe {
ptr.add(ptr.align_offset(std::mem::align_of::<SimpleBarrier>()))
.cast()
};
unsafe { SimpleBarrier::init(barrier_ptr, 2) };
let barrier = unsafe { barrier_ptr.as_ref().unwrap() };
// Fork another test process. The code after this runs in both processes concurrently.
let fork_result = unsafe { nix::unistd::fork().unwrap() };
// In the parent, fill bytes between 1000..2000. In the child, between 2000..3000
if fork_result.is_parent() {
write_range(ptr, 0xAA, 1000..2000);
} else {
write_range(ptr, 0xBB, 2000..3000);
}
barrier.wait();
// Verify the contents. (in both processes)
assert_range(ptr, 0xAA, 1000..2000);
assert_range(ptr, 0xBB, 2000..3000);
// Grow, from the child this time
let size = 10_000_000;
if !fork_result.is_parent() {
init_struct.set_size(size).unwrap();
}
barrier.wait();
// make some writes at the end
if fork_result.is_parent() {
write_range(ptr, 0xAA, (size - 10)..size);
} else {
write_range(ptr, 0xBB, (size - 20)..(size - 10));
}
barrier.wait();
// Verify the contents. (This runs in both processes)
assert_range(ptr, 0, (size - 1000)..(size - 20));
assert_range(ptr, 0xBB, (size - 20)..(size - 10));
assert_range(ptr, 0xAA, (size - 10)..size);
if let ForkResult::Parent { child } = fork_result {
nix::sys::wait::waitpid(child, None).unwrap();
}
}
}
pub mod hash;
pub mod shmem;

View File

@@ -0,0 +1,418 @@
//! Dynamically resizable contiguous chunk of shared memory
use std::num::NonZeroUsize;
use std::os::fd::{AsFd, BorrowedFd, OwnedFd};
use std::ptr::NonNull;
use std::sync::atomic::{AtomicUsize, Ordering};
use nix::errno::Errno;
use nix::sys::mman::MapFlags;
use nix::sys::mman::ProtFlags;
use nix::sys::mman::mmap as nix_mmap;
use nix::sys::mman::munmap as nix_munmap;
use nix::unistd::ftruncate as nix_ftruncate;
/// ShmemHandle represents a shared memory area that can be shared by processes over fork().
/// Unlike shared memory allocated by Postgres, this area is resizable, up to 'max_size' that's
/// specified at creation.
///
/// The area is backed by an anonymous file created with memfd_create(). The full address space for
/// 'max_size' is reserved up-front with mmap(), but whenever you call [`ShmemHandle::set_size`],
/// the underlying file is resized. Do not access the area beyond the current size. Currently, that
/// will cause the file to be expanded, but we might use mprotect() etc. to enforce that in the
/// future.
pub struct ShmemHandle {
/// memfd file descriptor
fd: OwnedFd,
max_size: usize,
// Pointer to the beginning of the shared memory area. The header is stored there.
shared_ptr: NonNull<SharedStruct>,
// Pointer to the beginning of the user data
pub data_ptr: NonNull<u8>,
}
/// This is stored at the beginning in the shared memory area.
struct SharedStruct {
max_size: usize,
/// Current size of the backing file. The high-order bit is used for the RESIZE_IN_PROGRESS flag
current_size: AtomicUsize,
}
const RESIZE_IN_PROGRESS: usize = 1 << 63;
const HEADER_SIZE: usize = std::mem::size_of::<SharedStruct>();
/// Error type returned by the ShmemHandle functions.
#[derive(thiserror::Error, Debug)]
#[error("{msg}: {errno}")]
pub struct Error {
pub msg: String,
pub errno: Errno,
}
impl Error {
fn new(msg: &str, errno: Errno) -> Error {
Error {
msg: msg.to_string(),
errno,
}
}
}
impl ShmemHandle {
/// Create a new shared memory area. To communicate between processes, the processes need to be
/// fork()'d after calling this, so that the ShmemHandle is inherited by all processes.
///
/// If the ShmemHandle is dropped, the memory is unmapped from the current process. Other
/// processes can continue using it, however.
pub fn new(name: &str, initial_size: usize, max_size: usize) -> Result<ShmemHandle, Error> {
// create the backing anonymous file.
let fd = create_backing_file(name)?;
Self::new_with_fd(fd, initial_size, max_size)
}
fn new_with_fd(
fd: OwnedFd,
initial_size: usize,
max_size: usize,
) -> Result<ShmemHandle, Error> {
// We reserve the high-order bit for the RESIZE_IN_PROGRESS flag, and the actual size
// is a little larger than this because of the SharedStruct header. Make the upper limit
// somewhat smaller than that, because with anything close to that, you'll run out of
// memory anyway.
if max_size >= 1 << 48 {
panic!("max size {} too large", max_size);
}
if initial_size > max_size {
panic!("initial size {initial_size} larger than max size {max_size}");
}
// The actual initial / max size is the one given by the caller, plus the size of
// 'SharedStruct'.
let initial_size = HEADER_SIZE + initial_size;
let max_size = NonZeroUsize::new(HEADER_SIZE + max_size).unwrap();
// Reserve address space for it with mmap
//
// TODO: Use MAP_HUGETLB if possible
let start_ptr = unsafe {
nix_mmap(
None,
max_size,
ProtFlags::PROT_READ | ProtFlags::PROT_WRITE,
MapFlags::MAP_SHARED,
&fd,
0,
)
}
.map_err(|e| Error::new("mmap failed: {e}", e))?;
// Reserve space for the initial size
enlarge_file(fd.as_fd(), initial_size as u64)?;
// Initialize the header
let shared: NonNull<SharedStruct> = start_ptr.cast();
unsafe {
shared.write(SharedStruct {
max_size: max_size.into(),
current_size: AtomicUsize::new(initial_size),
})
};
// The user data begins after the header
let data_ptr = unsafe { start_ptr.cast().add(HEADER_SIZE) };
Ok(ShmemHandle {
fd,
max_size: max_size.into(),
shared_ptr: shared,
data_ptr,
})
}
// return reference to the header
fn shared(&self) -> &SharedStruct {
unsafe { self.shared_ptr.as_ref() }
}
/// Resize the shared memory area. 'new_size' must not be larger than the 'max_size' specified
/// when creating the area.
///
/// This may only be called from one process/thread concurrently. We detect that case
/// and return an Error.
pub fn set_size(&self, new_size: usize) -> Result<(), Error> {
let new_size = new_size + HEADER_SIZE;
let shared = self.shared();
if new_size > self.max_size {
panic!(
"new size ({} is greater than max size ({})",
new_size, self.max_size
);
}
assert_eq!(self.max_size, shared.max_size);
// Lock the area by setting the bit in 'current_size'
//
// Ordering::Relaxed would probably be sufficient here, as we don't access any other memory
// and the posix_fallocate/ftruncate call is surely a synchronization point anyway. But
// since this is not performance-critical, better safe than sorry .
let mut old_size = shared.current_size.load(Ordering::Acquire);
loop {
if (old_size & RESIZE_IN_PROGRESS) != 0 {
return Err(Error::new(
"concurrent resize detected",
Errno::UnknownErrno,
));
}
match shared.current_size.compare_exchange(
old_size,
new_size,
Ordering::Acquire,
Ordering::Relaxed,
) {
Ok(_) => break,
Err(x) => old_size = x,
}
}
// Ok, we got the lock.
//
// NB: If anything goes wrong, we *must* clear the bit!
let result = {
use std::cmp::Ordering::{Equal, Greater, Less};
match new_size.cmp(&old_size) {
Less => nix_ftruncate(&self.fd, new_size as i64).map_err(|e| {
Error::new("could not shrink shmem segment, ftruncate failed: {e}", e)
}),
Equal => Ok(()),
Greater => enlarge_file(self.fd.as_fd(), new_size as u64),
}
};
// Unlock
shared.current_size.store(
if result.is_ok() { new_size } else { old_size },
Ordering::Release,
);
result
}
/// Returns the current user-visible size of the shared memory segment.
///
/// NOTE: a concurrent set_size() call can change the size at any time. It is the caller's
/// responsibility not to access the area beyond the current size.
pub fn current_size(&self) -> usize {
let total_current_size =
self.shared().current_size.load(Ordering::Relaxed) & !RESIZE_IN_PROGRESS;
total_current_size - HEADER_SIZE
}
}
impl Drop for ShmemHandle {
fn drop(&mut self) {
// SAFETY: The pointer was obtained from mmap() with the given size.
// We unmap the entire region.
let _ = unsafe { nix_munmap(self.shared_ptr.cast(), self.max_size) };
// The fd is dropped automatically by OwnedFd.
}
}
/// Create a "backing file" for the shared memory area. On Linux, use memfd_create(), to create an
/// anonymous in-memory file. One macos, fall back to a regular file. That's good enough for
/// development and testing, but in production we want the file to stay in memory.
///
/// disable 'unused_variables' warnings, because in the macos path, 'name' is unused.
#[allow(unused_variables)]
fn create_backing_file(name: &str) -> Result<OwnedFd, Error> {
#[cfg(not(target_os = "macos"))]
{
nix::sys::memfd::memfd_create(name, nix::sys::memfd::MFdFlags::empty())
.map_err(|e| Error::new("memfd_create failed: {e}", e))
}
#[cfg(target_os = "macos")]
{
let file = tempfile::tempfile().map_err(|e| {
Error::new(
"could not create temporary file to back shmem area: {e}",
nix::errno::Errno::from_raw(e.raw_os_error().unwrap_or(0)),
)
})?;
Ok(OwnedFd::from(file))
}
}
fn enlarge_file(fd: BorrowedFd, size: u64) -> Result<(), Error> {
// Use posix_fallocate() to enlarge the file. It reserves the space correctly, so that
// we don't get a segfault later when trying to actually use it.
#[cfg(not(target_os = "macos"))]
{
nix::fcntl::posix_fallocate(fd, 0, size as i64).map_err(|e| {
Error::new(
"could not grow shmem segment, posix_fallocate failed: {e}",
e,
)
})
}
// As a fallback on macos, which doesn't have posix_fallocate, use plain 'fallocate'
#[cfg(target_os = "macos")]
{
nix::unistd::ftruncate(fd, size as i64)
.map_err(|e| Error::new("could not grow shmem segment, ftruncate failed: {e}", e))
}
}
#[cfg(test)]
mod tests {
use super::*;
use nix::unistd::ForkResult;
use std::ops::Range;
/// check that all bytes in given range have the expected value.
fn assert_range(ptr: *const u8, expected: u8, range: Range<usize>) {
for i in range {
let b = unsafe { *(ptr.add(i)) };
assert_eq!(expected, b, "unexpected byte at offset {}", i);
}
}
/// Write 'b' to all bytes in the given range
fn write_range(ptr: *mut u8, b: u8, range: Range<usize>) {
unsafe { std::ptr::write_bytes(ptr.add(range.start), b, range.end - range.start) };
}
// simple single-process test of growing and shrinking
#[test]
fn test_shmem_resize() -> Result<(), Error> {
let max_size = 1024 * 1024;
let init_struct = ShmemHandle::new("test_shmem_resize", 0, max_size)?;
assert_eq!(init_struct.current_size(), 0);
// Initial grow
let size1 = 10000;
init_struct.set_size(size1).unwrap();
assert_eq!(init_struct.current_size(), size1);
// Write some data
let data_ptr = init_struct.data_ptr.as_ptr();
write_range(data_ptr, 0xAA, 0..size1);
assert_range(data_ptr, 0xAA, 0..size1);
// Shrink
let size2 = 5000;
init_struct.set_size(size2).unwrap();
assert_eq!(init_struct.current_size(), size2);
// Grow again
let size3 = 20000;
init_struct.set_size(size3).unwrap();
assert_eq!(init_struct.current_size(), size3);
// Try to read it. The area that was shrunk and grown again should read as all zeros now
assert_range(data_ptr, 0xAA, 0..5000);
assert_range(data_ptr, 0, 5000..size1);
// Try to grow beyond max_size
//let size4 = max_size + 1;
//assert!(init_struct.set_size(size4).is_err());
// Dropping init_struct should unmap the memory
drop(init_struct);
Ok(())
}
/// This is used in tests to coordinate between test processes. It's like std::sync::Barrier,
/// but is stored in the shared memory area and works across processes. It's implemented by
/// polling, because e.g. standard rust mutexes are not guaranteed to work across processes.
struct SimpleBarrier {
num_procs: usize,
count: AtomicUsize,
}
impl SimpleBarrier {
unsafe fn init(ptr: *mut SimpleBarrier, num_procs: usize) {
unsafe {
*ptr = SimpleBarrier {
num_procs,
count: AtomicUsize::new(0),
}
}
}
pub fn wait(&self) {
let old = self.count.fetch_add(1, Ordering::Relaxed);
let generation = old / self.num_procs;
let mut current = old + 1;
while current < (generation + 1) * self.num_procs {
std::thread::sleep(std::time::Duration::from_millis(10));
current = self.count.load(Ordering::Relaxed);
}
}
}
#[test]
fn test_multi_process() {
// Initialize
let max_size = 1_000_000_000_000;
let init_struct = ShmemHandle::new("test_multi_process", 0, max_size).unwrap();
let ptr = init_struct.data_ptr.as_ptr();
// Store the SimpleBarrier in the first 1k of the area.
init_struct.set_size(10000).unwrap();
let barrier_ptr: *mut SimpleBarrier = unsafe {
ptr.add(ptr.align_offset(std::mem::align_of::<SimpleBarrier>()))
.cast()
};
unsafe { SimpleBarrier::init(barrier_ptr, 2) };
let barrier = unsafe { barrier_ptr.as_ref().unwrap() };
// Fork another test process. The code after this runs in both processes concurrently.
let fork_result = unsafe { nix::unistd::fork().unwrap() };
// In the parent, fill bytes between 1000..2000. In the child, between 2000..3000
if fork_result.is_parent() {
write_range(ptr, 0xAA, 1000..2000);
} else {
write_range(ptr, 0xBB, 2000..3000);
}
barrier.wait();
// Verify the contents. (in both processes)
assert_range(ptr, 0xAA, 1000..2000);
assert_range(ptr, 0xBB, 2000..3000);
// Grow, from the child this time
let size = 10_000_000;
if !fork_result.is_parent() {
init_struct.set_size(size).unwrap();
}
barrier.wait();
// make some writes at the end
if fork_result.is_parent() {
write_range(ptr, 0xAA, (size - 10)..size);
} else {
write_range(ptr, 0xBB, (size - 20)..(size - 10));
}
barrier.wait();
// Verify the contents. (This runs in both processes)
assert_range(ptr, 0, (size - 1000)..(size - 20));
assert_range(ptr, 0xBB, (size - 20)..(size - 10));
assert_range(ptr, 0xAA, (size - 10)..size);
if let ForkResult::Parent { child } = fork_result {
nix::sys::wait::waitpid(child, None).unwrap();
}
}
}

View File

@@ -344,35 +344,6 @@ impl Default for ShardSchedulingPolicy {
}
}
#[derive(Serialize, Deserialize, Clone, Copy, Eq, PartialEq, Debug)]
pub enum NodeLifecycle {
Active,
Deleted,
}
impl FromStr for NodeLifecycle {
type Err = anyhow::Error;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"active" => Ok(Self::Active),
"deleted" => Ok(Self::Deleted),
_ => Err(anyhow::anyhow!("Unknown node lifecycle '{s}'")),
}
}
}
impl From<NodeLifecycle> for String {
fn from(value: NodeLifecycle) -> String {
use NodeLifecycle::*;
match value {
Active => "active",
Deleted => "deleted",
}
.to_string()
}
}
#[derive(Serialize, Deserialize, Clone, Copy, Eq, PartialEq, Debug)]
pub enum NodeSchedulingPolicy {
Active,

View File

@@ -10,7 +10,7 @@ use crate::{Error, cancel_query_raw, connect_socket};
pub(crate) async fn cancel_query<T>(
config: Option<SocketConfig>,
ssl_mode: SslMode,
tls: T,
mut tls: T,
process_id: i32,
secret_key: i32,
) -> Result<(), Error>

View File

@@ -17,6 +17,7 @@ use crate::{Client, Connection, Error};
/// TLS configuration.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[non_exhaustive]
pub enum SslMode {
/// Do not use TLS.
Disable,
@@ -230,7 +231,7 @@ impl Config {
/// Requires the `runtime` Cargo feature (enabled by default).
pub async fn connect<T>(
&self,
tls: &T,
tls: T,
) -> Result<(Client, Connection<TcpStream, T::Stream>), Error>
where
T: MakeTlsConnect<TcpStream>,

View File

@@ -13,7 +13,7 @@ use crate::tls::{MakeTlsConnect, TlsConnect};
use crate::{Client, Config, Connection, Error, RawConnection};
pub async fn connect<T>(
tls: &T,
mut tls: T,
config: &Config,
) -> Result<(Client, Connection<TcpStream, T::Stream>), Error>
where

View File

@@ -47,7 +47,7 @@ pub trait MakeTlsConnect<S> {
/// Creates a new `TlsConnect`or.
///
/// The domain name is provided for certificate verification and SNI.
fn make_tls_connect(&self, domain: &str) -> Result<Self::TlsConnect, Self::Error>;
fn make_tls_connect(&mut self, domain: &str) -> Result<Self::TlsConnect, Self::Error>;
}
/// An asynchronous function wrapping a stream in a TLS session.
@@ -85,7 +85,7 @@ impl<S> MakeTlsConnect<S> for NoTls {
type TlsConnect = NoTls;
type Error = NoTlsError;
fn make_tls_connect(&self, _: &str) -> Result<NoTls, NoTlsError> {
fn make_tls_connect(&mut self, _: &str) -> Result<NoTls, NoTlsError> {
Ok(NoTls)
}
}

View File

@@ -10,7 +10,7 @@ use std::sync::Arc;
use std::time::{Duration, SystemTime};
use std::{env, io};
use anyhow::{Context, Result, anyhow};
use anyhow::{Context, Result};
use azure_core::request_options::{IfMatchCondition, MaxResults, Metadata, Range};
use azure_core::{Continuable, HttpClient, RetryOptions, TransportOptions};
use azure_storage::StorageCredentials;
@@ -37,7 +37,6 @@ use crate::metrics::{AttemptOutcome, RequestKind, start_measuring_requests};
use crate::{
ConcurrencyLimiter, Download, DownloadError, DownloadKind, DownloadOpts, Listing, ListingMode,
ListingObject, RemotePath, RemoteStorage, StorageMetadata, TimeTravelError, TimeoutOrCancel,
Version, VersionKind,
};
pub struct AzureBlobStorage {
@@ -406,39 +405,6 @@ impl AzureBlobStorage {
pub fn container_name(&self) -> &str {
&self.container_name
}
async fn list_versions_with_permit(
&self,
_permit: &tokio::sync::SemaphorePermit<'_>,
prefix: Option<&RemotePath>,
mode: ListingMode,
max_keys: Option<NonZeroU32>,
cancel: &CancellationToken,
) -> Result<crate::VersionListing, DownloadError> {
let customize_builder = |mut builder: ListBlobsBuilder| {
builder = builder.include_versions(true);
// We do not return this info back to `VersionListing` yet.
builder = builder.include_deleted(true);
builder
};
let kind = RequestKind::ListVersions;
let mut stream = std::pin::pin!(self.list_streaming_for_fn(
prefix,
mode,
max_keys,
cancel,
kind,
customize_builder
));
let mut combined: crate::VersionListing =
stream.next().await.expect("At least one item required")?;
while let Some(list) = stream.next().await {
let list = list?;
combined.versions.extend(list.versions.into_iter());
}
Ok(combined)
}
}
trait ListingCollector {
@@ -522,10 +488,27 @@ impl RemoteStorage for AzureBlobStorage {
max_keys: Option<NonZeroU32>,
cancel: &CancellationToken,
) -> std::result::Result<crate::VersionListing, DownloadError> {
let customize_builder = |mut builder: ListBlobsBuilder| {
builder = builder.include_versions(true);
builder
};
let kind = RequestKind::ListVersions;
let permit = self.permit(kind, cancel).await?;
self.list_versions_with_permit(&permit, prefix, mode, max_keys, cancel)
.await
let mut stream = std::pin::pin!(self.list_streaming_for_fn(
prefix,
mode,
max_keys,
cancel,
kind,
customize_builder
));
let mut combined: crate::VersionListing =
stream.next().await.expect("At least one item required")?;
while let Some(list) = stream.next().await {
let list = list?;
combined.versions.extend(list.versions.into_iter());
}
Ok(combined)
}
async fn head_object(
@@ -820,158 +803,14 @@ impl RemoteStorage for AzureBlobStorage {
async fn time_travel_recover(
&self,
prefix: Option<&RemotePath>,
timestamp: SystemTime,
done_if_after: SystemTime,
cancel: &CancellationToken,
_prefix: Option<&RemotePath>,
_timestamp: SystemTime,
_done_if_after: SystemTime,
_cancel: &CancellationToken,
) -> Result<(), TimeTravelError> {
let msg = "PLEASE NOTE: Azure Blob storage time-travel recovery may not work as expected "
.to_string()
+ "for some specific files. If a file gets deleted but then overwritten and we want to recover "
+ "to the time during the file was not present, this functionality will recover the file. Only "
+ "use the functionality for services that can tolerate this. For example, recovering a state of the "
+ "pageserver tenants.";
tracing::error!("{}", msg);
let kind = RequestKind::TimeTravel;
let permit = self.permit(kind, cancel).await?;
let mode = ListingMode::NoDelimiter;
let version_listing = self
.list_versions_with_permit(&permit, prefix, mode, None, cancel)
.await
.map_err(|err| match err {
DownloadError::Other(e) => TimeTravelError::Other(e),
DownloadError::Cancelled => TimeTravelError::Cancelled,
other => TimeTravelError::Other(other.into()),
})?;
let versions_and_deletes = version_listing.versions;
tracing::info!(
"Built list for time travel with {} versions and deletions",
versions_and_deletes.len()
);
// Work on the list of references instead of the objects directly,
// otherwise we get lifetime errors in the sort_by_key call below.
let mut versions_and_deletes = versions_and_deletes.iter().collect::<Vec<_>>();
versions_and_deletes.sort_by_key(|vd| (&vd.key, &vd.last_modified));
let mut vds_for_key = HashMap::<_, Vec<_>>::new();
for vd in &versions_and_deletes {
let Version { key, .. } = &vd;
let version_id = vd.version_id().map(|v| v.0.as_str());
if version_id == Some("null") {
return Err(TimeTravelError::Other(anyhow!(
"Received ListVersions response for key={key} with version_id='null', \
indicating either disabled versioning, or legacy objects with null version id values"
)));
}
tracing::trace!("Parsing version key={key} kind={:?}", vd.kind);
vds_for_key.entry(key).or_default().push(vd);
}
let warn_threshold = 3;
let max_retries = 10;
let is_permanent = |e: &_| matches!(e, TimeTravelError::Cancelled);
for (key, versions) in vds_for_key {
let last_vd = versions.last().unwrap();
let key = self.relative_path_to_name(key);
if last_vd.last_modified > done_if_after {
tracing::debug!("Key {key} has version later than done_if_after, skipping");
continue;
}
// the version we want to restore to.
let version_to_restore_to =
match versions.binary_search_by_key(&timestamp, |tpl| tpl.last_modified) {
Ok(v) => v,
Err(e) => e,
};
if version_to_restore_to == versions.len() {
tracing::debug!("Key {key} has no changes since timestamp, skipping");
continue;
}
let mut do_delete = false;
if version_to_restore_to == 0 {
// All versions more recent, so the key didn't exist at the specified time point.
tracing::debug!(
"All {} versions more recent for {key}, deleting",
versions.len()
);
do_delete = true;
} else {
match &versions[version_to_restore_to - 1] {
Version {
kind: VersionKind::Version(version_id),
..
} => {
let source_url = format!(
"{}/{}?versionid={}",
self.client
.url()
.map_err(|e| TimeTravelError::Other(anyhow!("{e}")))?,
key,
version_id.0
);
tracing::debug!(
"Promoting old version {} for {key} at {}...",
version_id.0,
source_url
);
backoff::retry(
|| async {
let blob_client = self.client.blob_client(key.clone());
let op = blob_client.copy(Url::from_str(&source_url).unwrap());
tokio::select! {
res = op => res.map_err(|e| TimeTravelError::Other(e.into())),
_ = cancel.cancelled() => Err(TimeTravelError::Cancelled),
}
},
is_permanent,
warn_threshold,
max_retries,
"copying object version for time_travel_recover",
cancel,
)
.await
.ok_or_else(|| TimeTravelError::Cancelled)
.and_then(|x| x)?;
tracing::info!(?version_id, %key, "Copied old version in Azure blob storage");
}
Version {
kind: VersionKind::DeletionMarker,
..
} => {
do_delete = true;
}
}
};
if do_delete {
if matches!(last_vd.kind, VersionKind::DeletionMarker) {
// Key has since been deleted (but there was some history), no need to do anything
tracing::debug!("Key {key} already deleted, skipping.");
} else {
tracing::debug!("Deleting {key}...");
self.delete(&RemotePath::from_string(&key).unwrap(), cancel)
.await
.map_err(|e| {
// delete_oid0 will use TimeoutOrCancel
if TimeoutOrCancel::caused_by_cancel(&e) {
TimeTravelError::Cancelled
} else {
TimeTravelError::Other(e)
}
})?;
}
}
}
Ok(())
// TODO use Azure point in time recovery feature for this
// https://learn.microsoft.com/en-us/azure/storage/blobs/point-in-time-restore-overview
Err(TimeTravelError::Unimplemented)
}
}

View File

@@ -1022,7 +1022,6 @@ impl RemoteStorage for S3Bucket {
let Version { key, .. } = &vd;
let version_id = vd.version_id().map(|v| v.0.as_str());
if version_id == Some("null") {
// TODO: check the behavior of using the SDK on a non-versioned container
return Err(TimeTravelError::Other(anyhow!(
"Received ListVersions response for key={key} with version_id='null', \
indicating either disabled versioning, or legacy objects with null version id values"

View File

@@ -13,7 +13,7 @@ use utils::pageserver_feedback::PageserverFeedback;
use crate::membership::Configuration;
use crate::{ServerInfo, Term};
#[derive(Debug, Serialize, Deserialize)]
#[derive(Debug, Serialize)]
pub struct SafekeeperStatus {
pub id: NodeId,
}

View File

@@ -23,7 +23,6 @@ use pageserver::deletion_queue::DeletionQueue;
use pageserver::disk_usage_eviction_task::{self, launch_disk_usage_global_eviction_task};
use pageserver::feature_resolver::FeatureResolver;
use pageserver::metrics::{STARTUP_DURATION, STARTUP_IS_LOADING};
use pageserver::page_service::GrpcPageServiceHandler;
use pageserver::task_mgr::{
BACKGROUND_RUNTIME, COMPUTE_REQUEST_RUNTIME, MGMT_REQUEST_RUNTIME, WALRECEIVER_RUNTIME,
};
@@ -573,8 +572,7 @@ fn start_pageserver(
tokio::sync::mpsc::unbounded_channel();
let deletion_queue_client = deletion_queue.new_client();
let background_purges = mgr::BackgroundPurges::default();
let tenant_manager = mgr::init(
let tenant_manager = BACKGROUND_RUNTIME.block_on(mgr::init_tenant_mgr(
conf,
background_purges.clone(),
TenantSharedResources {
@@ -585,10 +583,10 @@ fn start_pageserver(
basebackup_prepare_sender,
feature_resolver,
},
order,
shutdown_pageserver.clone(),
);
))?;
let tenant_manager = Arc::new(tenant_manager);
BACKGROUND_RUNTIME.block_on(mgr::init_tenant_mgr(tenant_manager.clone(), order))?;
let basebackup_cache = BasebackupCache::spawn(
BACKGROUND_RUNTIME.handle(),
@@ -816,7 +814,7 @@ fn start_pageserver(
// necessary?
let mut page_service_grpc = None;
if let Some(grpc_listener) = grpc_listener {
page_service_grpc = Some(GrpcPageServiceHandler::spawn(
page_service_grpc = Some(page_service::spawn_grpc(
tenant_manager.clone(),
grpc_auth,
otel_guard.as_ref().map(|g| g.dispatch.clone()),

View File

@@ -1,6 +1,5 @@
use std::{collections::HashMap, sync::Arc, time::Duration};
use pageserver_api::config::NodeMetadata;
use posthog_client_lite::{
CaptureEvent, FeatureResolverBackgroundLoop, PostHogClientConfig, PostHogEvaluationError,
PostHogFlagFilterPropertyValue,
@@ -87,35 +86,7 @@ impl FeatureResolver {
}
}
}
// TODO: move this to a background task so that we don't block startup in case of slow disk
let metadata_path = conf.metadata_path();
match std::fs::read_to_string(&metadata_path) {
Ok(metadata_str) => match serde_json::from_str::<NodeMetadata>(&metadata_str) {
Ok(metadata) => {
properties.insert(
"hostname".to_string(),
PostHogFlagFilterPropertyValue::String(metadata.http_host),
);
if let Some(cplane_region) = metadata.other.get("region_id") {
if let Some(cplane_region) = cplane_region.as_str() {
// This region contains the cell number
properties.insert(
"neon_region".to_string(),
PostHogFlagFilterPropertyValue::String(
cplane_region.to_string(),
),
);
}
}
}
Err(e) => {
tracing::warn!("Failed to parse metadata.json: {}", e);
}
},
Err(e) => {
tracing::warn!("Failed to read metadata.json: {}", e);
}
}
// TODO: add pageserver URL.
Arc::new(properties)
};
let fake_tenants = {

View File

@@ -1053,15 +1053,6 @@ pub(crate) static TENANT_STATE_METRIC: Lazy<UIntGaugeVec> = Lazy::new(|| {
.expect("Failed to register pageserver_tenant_states_count metric")
});
pub(crate) static TIMELINE_STATE_METRIC: Lazy<UIntGaugeVec> = Lazy::new(|| {
register_uint_gauge_vec!(
"pageserver_timeline_states_count",
"Count of timelines per state",
&["state"]
)
.expect("Failed to register pageserver_timeline_states_count metric")
});
/// A set of broken tenants.
///
/// These are expected to be so rare that a set is fine. Set as in a new timeseries per each broken
@@ -3334,8 +3325,6 @@ impl TimelineMetrics {
&timeline_id,
);
TIMELINE_STATE_METRIC.with_label_values(&["active"]).inc();
TimelineMetrics {
tenant_id,
shard_id,
@@ -3490,8 +3479,6 @@ impl TimelineMetrics {
return;
}
TIMELINE_STATE_METRIC.with_label_values(&["active"]).dec();
let tenant_id = &self.tenant_id;
let timeline_id = &self.timeline_id;
let shard_id = &self.shard_id;

View File

@@ -169,6 +169,99 @@ pub fn spawn(
Listener { cancel, task }
}
/// Spawns a gRPC server for the page service.
///
/// TODO: move this onto GrpcPageServiceHandler::spawn().
/// TODO: this doesn't support TLS. We need TLS reloading via ReloadingCertificateResolver, so we
/// need to reimplement the TCP+TLS accept loop ourselves.
pub fn spawn_grpc(
tenant_manager: Arc<TenantManager>,
auth: Option<Arc<SwappableJwtAuth>>,
perf_trace_dispatch: Option<Dispatch>,
get_vectored_concurrent_io: GetVectoredConcurrentIo,
listener: std::net::TcpListener,
) -> anyhow::Result<CancellableTask> {
let cancel = CancellationToken::new();
let ctx = RequestContextBuilder::new(TaskKind::PageRequestHandler)
.download_behavior(DownloadBehavior::Download)
.perf_span_dispatch(perf_trace_dispatch)
.detached_child();
let gate = Gate::default();
// Set up the TCP socket. We take a preconfigured TcpListener to bind the
// port early during startup.
let incoming = {
let _runtime = COMPUTE_REQUEST_RUNTIME.enter(); // required by TcpListener::from_std
listener.set_nonblocking(true)?;
tonic::transport::server::TcpIncoming::from(tokio::net::TcpListener::from_std(listener)?)
.with_nodelay(Some(GRPC_TCP_NODELAY))
.with_keepalive(Some(GRPC_TCP_KEEPALIVE_TIME))
};
// Set up the gRPC server.
//
// TODO: consider tuning window sizes.
let mut server = tonic::transport::Server::builder()
.http2_keepalive_interval(Some(GRPC_HTTP2_KEEPALIVE_INTERVAL))
.http2_keepalive_timeout(Some(GRPC_HTTP2_KEEPALIVE_TIMEOUT))
.max_concurrent_streams(Some(GRPC_MAX_CONCURRENT_STREAMS));
// Main page service stack. Uses a mix of Tonic interceptors and Tower layers:
//
// * Interceptors: can inspect and modify the gRPC request. Sync code only, runs before service.
//
// * Layers: allow async code, can run code after the service response. However, only has access
// to the raw HTTP request/response, not the gRPC types.
let page_service_handler = GrpcPageServiceHandler {
tenant_manager,
ctx,
gate_guard: gate.enter().expect("gate was just created"),
get_vectored_concurrent_io,
};
let observability_layer = ObservabilityLayer;
let mut tenant_interceptor = TenantMetadataInterceptor;
let mut auth_interceptor = TenantAuthInterceptor::new(auth);
let page_service = tower::ServiceBuilder::new()
// Create tracing span and record request start time.
.layer(observability_layer)
// Intercept gRPC requests.
.layer(tonic::service::InterceptorLayer::new(move |mut req| {
// Extract tenant metadata.
req = tenant_interceptor.call(req)?;
// Authenticate tenant JWT token.
req = auth_interceptor.call(req)?;
Ok(req)
}))
.service(proto::PageServiceServer::new(page_service_handler));
let server = server.add_service(page_service);
// Reflection service for use with e.g. grpcurl.
let reflection_service = tonic_reflection::server::Builder::configure()
.register_encoded_file_descriptor_set(proto::FILE_DESCRIPTOR_SET)
.build_v1()?;
let server = server.add_service(reflection_service);
// Spawn server task.
let task_cancel = cancel.clone();
let task = COMPUTE_REQUEST_RUNTIME.spawn(task_mgr::exit_on_panic_or_error(
"grpc listener",
async move {
let result = server
.serve_with_incoming_shutdown(incoming, task_cancel.cancelled())
.await;
if result.is_ok() {
// TODO: revisit shutdown logic once page service is implemented.
gate.close().await;
}
result
},
));
Ok(CancellableTask { task, cancel })
}
impl Listener {
pub async fn stop_accepting(self) -> Connections {
self.cancel.cancel();
@@ -3273,101 +3366,6 @@ pub struct GrpcPageServiceHandler {
}
impl GrpcPageServiceHandler {
/// Spawns a gRPC server for the page service.
///
/// TODO: this doesn't support TLS. We need TLS reloading via ReloadingCertificateResolver, so we
/// need to reimplement the TCP+TLS accept loop ourselves.
pub fn spawn(
tenant_manager: Arc<TenantManager>,
auth: Option<Arc<SwappableJwtAuth>>,
perf_trace_dispatch: Option<Dispatch>,
get_vectored_concurrent_io: GetVectoredConcurrentIo,
listener: std::net::TcpListener,
) -> anyhow::Result<CancellableTask> {
let cancel = CancellationToken::new();
let ctx = RequestContextBuilder::new(TaskKind::PageRequestHandler)
.download_behavior(DownloadBehavior::Download)
.perf_span_dispatch(perf_trace_dispatch)
.detached_child();
let gate = Gate::default();
// Set up the TCP socket. We take a preconfigured TcpListener to bind the
// port early during startup.
let incoming = {
let _runtime = COMPUTE_REQUEST_RUNTIME.enter(); // required by TcpListener::from_std
listener.set_nonblocking(true)?;
tonic::transport::server::TcpIncoming::from(tokio::net::TcpListener::from_std(
listener,
)?)
.with_nodelay(Some(GRPC_TCP_NODELAY))
.with_keepalive(Some(GRPC_TCP_KEEPALIVE_TIME))
};
// Set up the gRPC server.
//
// TODO: consider tuning window sizes.
let mut server = tonic::transport::Server::builder()
.http2_keepalive_interval(Some(GRPC_HTTP2_KEEPALIVE_INTERVAL))
.http2_keepalive_timeout(Some(GRPC_HTTP2_KEEPALIVE_TIMEOUT))
.max_concurrent_streams(Some(GRPC_MAX_CONCURRENT_STREAMS));
// Main page service stack. Uses a mix of Tonic interceptors and Tower layers:
//
// * Interceptors: can inspect and modify the gRPC request. Sync code only, runs before service.
//
// * Layers: allow async code, can run code after the service response. However, only has access
// to the raw HTTP request/response, not the gRPC types.
let page_service_handler = GrpcPageServiceHandler {
tenant_manager,
ctx,
gate_guard: gate.enter().expect("gate was just created"),
get_vectored_concurrent_io,
};
let observability_layer = ObservabilityLayer;
let mut tenant_interceptor = TenantMetadataInterceptor;
let mut auth_interceptor = TenantAuthInterceptor::new(auth);
let page_service = tower::ServiceBuilder::new()
// Create tracing span and record request start time.
.layer(observability_layer)
// Intercept gRPC requests.
.layer(tonic::service::InterceptorLayer::new(move |mut req| {
// Extract tenant metadata.
req = tenant_interceptor.call(req)?;
// Authenticate tenant JWT token.
req = auth_interceptor.call(req)?;
Ok(req)
}))
// Run the page service.
.service(proto::PageServiceServer::new(page_service_handler));
let server = server.add_service(page_service);
// Reflection service for use with e.g. grpcurl.
let reflection_service = tonic_reflection::server::Builder::configure()
.register_encoded_file_descriptor_set(proto::FILE_DESCRIPTOR_SET)
.build_v1()?;
let server = server.add_service(reflection_service);
// Spawn server task.
let task_cancel = cancel.clone();
let task = COMPUTE_REQUEST_RUNTIME.spawn(task_mgr::exit_on_panic_or_error(
"grpc listener",
async move {
let result = server
.serve_with_incoming_shutdown(incoming, task_cancel.cancelled())
.await;
if result.is_ok() {
// TODO: revisit shutdown logic once page service is implemented.
gate.close().await;
}
result
},
));
Ok(CancellableTask { task, cancel })
}
/// Errors if the request is executed on a non-zero shard. Only shard 0 has a complete view of
/// relations and their sizes, as well as SLRU segments and similar data.
#[allow(clippy::result_large_err)]

View File

@@ -89,8 +89,7 @@ use crate::l0_flush::L0FlushGlobalState;
use crate::metrics::{
BROKEN_TENANTS_SET, CIRCUIT_BREAKERS_BROKEN, CIRCUIT_BREAKERS_UNBROKEN, CONCURRENT_INITDBS,
INITDB_RUN_TIME, INITDB_SEMAPHORE_ACQUISITION_TIME, TENANT, TENANT_OFFLOADED_TIMELINES,
TENANT_STATE_METRIC, TENANT_SYNTHETIC_SIZE_METRIC, TIMELINE_STATE_METRIC,
remove_tenant_metrics,
TENANT_STATE_METRIC, TENANT_SYNTHETIC_SIZE_METRIC, remove_tenant_metrics,
};
use crate::task_mgr::TaskKind;
use crate::tenant::config::LocationMode;
@@ -545,28 +544,6 @@ pub struct OffloadedTimeline {
/// Part of the `OffloadedTimeline` object's lifecycle: this needs to be set before we drop it
pub deleted_from_ancestor: AtomicBool,
_metrics_guard: OffloadedTimelineMetricsGuard,
}
/// Increases the offloaded timeline count metric when created, and decreases when dropped.
struct OffloadedTimelineMetricsGuard;
impl OffloadedTimelineMetricsGuard {
fn new() -> Self {
TIMELINE_STATE_METRIC
.with_label_values(&["offloaded"])
.inc();
Self
}
}
impl Drop for OffloadedTimelineMetricsGuard {
fn drop(&mut self) {
TIMELINE_STATE_METRIC
.with_label_values(&["offloaded"])
.dec();
}
}
impl OffloadedTimeline {
@@ -599,8 +576,6 @@ impl OffloadedTimeline {
delete_progress: timeline.delete_progress.clone(),
deleted_from_ancestor: AtomicBool::new(false),
_metrics_guard: OffloadedTimelineMetricsGuard::new(),
})
}
fn from_manifest(tenant_shard_id: TenantShardId, manifest: &OffloadedTimelineManifest) -> Self {
@@ -620,7 +595,6 @@ impl OffloadedTimeline {
archived_at,
delete_progress: TimelineDeleteProgress::default(),
deleted_from_ancestor: AtomicBool::new(false),
_metrics_guard: OffloadedTimelineMetricsGuard::new(),
}
}
fn manifest(&self) -> OffloadedTimelineManifest {
@@ -1858,29 +1832,6 @@ impl TenantShard {
}
}
// At this point we've initialized all timelines and are tracking them.
// Now compute the layer visibility for all (not offloaded) timelines.
let compute_visiblity_for = {
let timelines_accessor = self.timelines.lock().unwrap();
let mut timelines_offloaded_accessor = self.timelines_offloaded.lock().unwrap();
timelines_offloaded_accessor.extend(offloaded_timelines_list.into_iter());
// Before activation, populate each Timeline's GcInfo with information about its children
self.initialize_gc_info(&timelines_accessor, &timelines_offloaded_accessor, None);
timelines_accessor.values().cloned().collect::<Vec<_>>()
};
for tl in compute_visiblity_for {
tl.update_layer_visibility().await.with_context(|| {
format!(
"failed initial timeline visibility computation {} for tenant {}",
tl.timeline_id, self.tenant_shard_id
)
})?;
}
// Walk through deleted timelines, resume deletion
for (timeline_id, index_part, remote_timeline_client) in timelines_to_resume_deletions {
remote_timeline_client
@@ -1900,6 +1851,10 @@ impl TenantShard {
.context("resume_deletion")
.map_err(LoadLocalTimelineError::ResumeDeletion)?;
}
{
let mut offloaded_timelines_accessor = self.timelines_offloaded.lock().unwrap();
offloaded_timelines_accessor.extend(offloaded_timelines_list.into_iter());
}
// Stash the preloaded tenant manifest, and upload a new manifest if changed.
//
@@ -3462,6 +3417,9 @@ impl TenantShard {
.values()
.filter(|timeline| !(timeline.is_broken() || timeline.is_stopping()));
// Before activation, populate each Timeline's GcInfo with information about its children
self.initialize_gc_info(&timelines_accessor, &timelines_offloaded_accessor, None);
// Spawn gc and compaction loops. The loops will shut themselves
// down when they notice that the tenant is inactive.
tasks::start_background_loops(self, background_jobs_can_start);

File diff suppressed because it is too large Load Diff

View File

@@ -1055,8 +1055,8 @@ pub(crate) enum WaitLsnWaiter<'a> {
/// Argument to [`Timeline::shutdown`].
#[derive(Debug, Clone, Copy)]
pub(crate) enum ShutdownMode {
/// Graceful shutdown, may do a lot of I/O as we flush any open layers to disk. This method can
/// take multiple seconds for a busy timeline.
/// Graceful shutdown, may do a lot of I/O as we flush any open layers to disk and then
/// also to remote storage. This method can easily take multiple seconds for a busy timeline.
///
/// While we are flushing, we continue to accept read I/O for LSNs ingested before
/// the call to [`Timeline::shutdown`].
@@ -3407,6 +3407,10 @@ impl Timeline {
// TenantShard::create_timeline will wait for these uploads to happen before returning, or
// on retry.
// Now that we have the full layer map, we may calculate the visibility of layers within it (a global scan)
drop(guard); // drop write lock, update_layer_visibility will take a read lock.
self.update_layer_visibility().await?;
info!(
"loaded layer map with {} layers at {}, total physical size: {}",
num_layers, disk_consistent_lsn, total_physical_size
@@ -5897,7 +5901,7 @@ impl Drop for Timeline {
if let Ok(mut gc_info) = ancestor.gc_info.write() {
if !gc_info.remove_child_not_offloaded(self.timeline_id) {
tracing::error!(tenant_id = %self.tenant_shard_id.tenant_id, shard_id = %self.tenant_shard_id.shard_slug(), timeline_id = %self.timeline_id,
"Couldn't remove retain_lsn entry from timeline's parent on drop: already removed");
"Couldn't remove retain_lsn entry from offloaded timeline's parent: already removed");
}
}
}

View File

@@ -1,6 +1,5 @@
# pgxs/neon/Makefile
MODULE_big = neon
OBJS = \
$(WIN32RES) \
@@ -22,7 +21,8 @@ OBJS = \
walproposer.o \
walproposer_pg.o \
control_plane_connector.o \
walsender_hooks.o
walsender_hooks.o \
$(LIBCOMMUNICATOR_PATH)/libcommunicator.a
PG_CPPFLAGS = -I$(libpq_srcdir)
SHLIB_LINK_INTERNAL = $(libpq)

View File

@@ -0,0 +1,13 @@
[package]
name = "communicator"
version = "0.1.0"
edition = "2024"
[lib]
crate-type = ["staticlib"]
[dependencies]
neon-shmem.workspace = true
[build-dependencies]
cbindgen.workspace = true

View File

@@ -0,0 +1,8 @@
This package will evolve into a "compute-pageserver communicator"
process and machinery. For now, it just provides wrappers on the
neon-shmem Rust crate, to allow using it in the C implementation of
the LFC.
At compilation time, pgxn/neon/communicator/ produces a static
library, libcommunicator.a. It is linked to the neon.so extension
library.

View File

@@ -0,0 +1,22 @@
use std::env;
fn main() -> Result<(), Box<dyn std::error::Error>> {
let crate_dir = env::var("CARGO_MANIFEST_DIR").unwrap();
cbindgen::generate(crate_dir).map_or_else(
|error| match error {
cbindgen::Error::ParseSyntaxError { .. } => {
// This means there was a syntax error in the Rust sources. Don't panic, because
// we want the build to continue and the Rust compiler to hit the error. The
// Rust compiler produces a better error message than cbindgen.
eprintln!("Generating C bindings failed because of a Rust syntax error");
}
e => panic!("Unable to generate C bindings: {:?}", e),
},
|bindings| {
bindings.write_to_file("communicator_bindings.h");
},
);
Ok(())
}

View File

@@ -0,0 +1,4 @@
language = "C"
[enum]
prefix_with_name = true

View File

@@ -0,0 +1,240 @@
//! Glue code to allow using the Rust shmem hash map implementation from C code
//!
//! For convience of adapting existing code, the interface provided somewhat resembles the dynahash
//! interface.
//!
//! NOTE: The caller is responsible for locking! The caller is expected to hold the PostgreSQL
//! LWLock, 'lfc_lock', while accessing the hash table, in shared or exclusive mode as appropriate.
use std::ffi::c_void;
use std::marker::PhantomData;
use neon_shmem::hash::entry::Entry;
use neon_shmem::hash::{HashMapAccess, HashMapInit};
use neon_shmem::shmem::ShmemHandle;
/// NB: This must match the definition of BufferTag in Postgres C headers. We could use bindgen to
/// generate this from the C headers, but prefer to not introduce dependency on bindgen for now.
///
/// Note that there are no padding bytes. If the corresponding C struct has padding bytes, the C C
/// code must clear them.
#[derive(Clone, Debug, Hash, Eq, PartialEq)]
#[repr(C)]
pub struct FileCacheKey {
pub _spc_id: u32,
pub _db_id: u32,
pub _rel_number: u32,
pub _fork_num: u32,
pub _block_num: u32,
}
/// Like with FileCacheKey, this must match the definition of FileCacheEntry in file_cache.c. We
/// don't look at the contents here though, it's sufficent that the size and alignment matches.
#[derive(Clone, Debug, Default)]
#[repr(C)]
pub struct FileCacheEntry {
pub _offset: u32,
pub _access_count: u32,
pub _prev: *mut FileCacheEntry,
pub _next: *mut FileCacheEntry,
pub _state: [u32; 8],
}
/// XXX: This could be just:
///
/// ```ignore
/// type FileCacheHashMapHandle = HashMapInit<'a, FileCacheKey, FileCacheEntry>
/// ```
///
/// but with that, cbindgen generates a broken typedef in the C header file which doesn't
/// compile. It apparently gets confused by the generics.
#[repr(transparent)]
pub struct FileCacheHashMapHandle<'a>(
pub *mut c_void,
PhantomData<HashMapInit<'a, FileCacheKey, FileCacheEntry>>,
);
impl<'a> From<Box<HashMapInit<'a, FileCacheKey, FileCacheEntry>>> for FileCacheHashMapHandle<'a> {
fn from(x: Box<HashMapInit<'a, FileCacheKey, FileCacheEntry>>) -> Self {
FileCacheHashMapHandle(Box::into_raw(x) as *mut c_void, PhantomData::default())
}
}
impl<'a> From<FileCacheHashMapHandle<'a>> for Box<HashMapInit<'a, FileCacheKey, FileCacheEntry>> {
fn from(x: FileCacheHashMapHandle) -> Self {
unsafe { Box::from_raw(x.0.cast()) }
}
}
/// XXX: same for this
#[repr(transparent)]
pub struct FileCacheHashMapAccess<'a>(
pub *mut c_void,
PhantomData<HashMapAccess<'a, FileCacheKey, FileCacheEntry>>,
);
impl<'a> From<Box<HashMapAccess<'a, FileCacheKey, FileCacheEntry>>> for FileCacheHashMapAccess<'a> {
fn from(x: Box<HashMapAccess<'a, FileCacheKey, FileCacheEntry>>) -> Self {
// Convert the Box into a raw mutable pointer to the HashMapAccess itself.
// This transfers ownership of the HashMapAccess (and its contained ShmemHandle)
// to the raw pointer. The C caller is now responsible for managing this memory.
FileCacheHashMapAccess(Box::into_raw(x) as *mut c_void, PhantomData::default())
}
}
impl<'a> FileCacheHashMapAccess<'a> {
fn as_ref(self) -> &'a HashMapAccess<'a, FileCacheKey, FileCacheEntry> {
let ptr: *mut HashMapAccess<'_, FileCacheKey, FileCacheEntry> = self.0.cast();
unsafe { ptr.as_ref().unwrap() }
}
fn as_mut(self) -> &'a mut HashMapAccess<'a, FileCacheKey, FileCacheEntry> {
let ptr: *mut HashMapAccess<'_, FileCacheKey, FileCacheEntry> = self.0.cast();
unsafe { ptr.as_mut().unwrap() }
}
}
/// Initialize the shared memory area at postmaster startup. The returned handle is inherited
/// by all the backend processes across fork()
#[unsafe(no_mangle)]
pub extern "C" fn bcomm_file_cache_shmem_init<'a>(
initial_num_buckets: u32,
max_num_buckets: u32,
) -> FileCacheHashMapHandle<'a> {
let max_bytes = HashMapInit::<FileCacheKey, FileCacheEntry>::estimate_size(max_num_buckets);
let shmem_handle =
ShmemHandle::new("lfc mapping", 0, max_bytes).expect("shmem initialization failed");
let handle = HashMapInit::<FileCacheKey, FileCacheEntry>::init_in_shmem(
initial_num_buckets,
shmem_handle,
);
Box::new(handle).into()
}
/// Initialize the access to the shared memory area in a backend process.
///
/// XXX: I'm not sure if this actually gets called in each process, or if the returned struct
/// is also inherited across fork(). It currently works either way but if this did more
/// initialization that needed to be done after fork(), then it would matter.
#[unsafe(no_mangle)]
pub extern "C" fn bcomm_file_cache_shmem_access<'a>(
handle: FileCacheHashMapHandle<'a>,
) -> FileCacheHashMapAccess<'a> {
let handle: Box<HashMapInit<'_, FileCacheKey, FileCacheEntry>> = handle.into();
Box::new(handle.attach_writer()).into()
}
/// Return the current number of buckets in the hash table
#[unsafe(no_mangle)]
pub extern "C" fn bcomm_file_cache_get_num_buckets<'a>(
map: FileCacheHashMapAccess<'static>,
) -> u32 {
let map = map.as_ref();
map.get_num_buckets().try_into().unwrap()
}
/// Look up the entry with given key and hash.
///
/// This is similar to dynahash's hash_search(... , HASH_FIND)
#[unsafe(no_mangle)]
pub extern "C" fn bcomm_file_cache_hash_find<'a>(
map: FileCacheHashMapAccess<'static>,
key: &FileCacheKey,
hash: u64,
) -> Option<&'static FileCacheEntry> {
let map = map.as_ref();
map.get_with_hash(key, hash)
}
/// Look up the entry at given bucket position
///
/// This has no direct equivalent in the dynahash interface, but can be used to
/// iterate through all entries in the hash table.
#[unsafe(no_mangle)]
pub extern "C" fn bcomm_file_cache_hash_get_at_pos<'a>(
map: FileCacheHashMapAccess<'static>,
pos: u32,
) -> Option<&'static FileCacheEntry> {
let map = map.as_ref();
map.get_at_bucket(pos as usize).map(|(_k, v)| v)
}
/// Remove entry, given a pointer to the value.
///
/// This is equivalent to dynahash hash_search(entry->key, HASH_REMOVE), where 'entry'
/// is an entry you have previously looked up
#[unsafe(no_mangle)]
pub extern "C" fn bcomm_file_cache_hash_remove_entry<'a, 'b>(
map: FileCacheHashMapAccess,
entry: *mut FileCacheEntry,
) {
let map = map.as_mut();
let pos = map.get_bucket_for_value(entry);
match map.entry_at_bucket(pos) {
Some(e) => {
e.remove();
}
None => {
// todo: shouldn't happen, panic?
}
}
}
/// Compute the hash for given key
///
/// This is equivalent to dynahash get_hash_value() function. We use Rust's default hasher
/// for calculating the hash though.
#[unsafe(no_mangle)]
pub extern "C" fn bcomm_file_cache_get_hash_value<'a, 'b>(
map: FileCacheHashMapAccess<'static>,
key: &FileCacheKey,
) -> u64 {
map.as_ref().get_hash_value(key)
}
/// Insert a new entry to the hash table
///
/// This is equivalent to dynahash hash_search(..., HASH_ENTER).
#[unsafe(no_mangle)]
pub extern "C" fn bcomm_file_cache_hash_enter<'a, 'b>(
map: FileCacheHashMapAccess,
key: &FileCacheKey,
hash: u64,
found: &mut bool,
) -> *mut FileCacheEntry {
match map.as_mut().entry_with_hash(key.clone(), hash) {
Entry::Occupied(mut e) => {
*found = true;
e.get_mut()
}
Entry::Vacant(e) => {
*found = false;
let initial_value = FileCacheEntry::default();
e.insert(initial_value).expect("TODO: hash table full")
}
}
}
/// Get the key for a given entry, which must be present in the hash table.
///
/// Dynahash requires the key to be part of the "value" struct, so you can always
/// access the key with something like `entry->key`. The Rust implementation however
/// stores the key separately. This function extracts the separately stored key.
#[unsafe(no_mangle)]
pub extern "C" fn bcomm_file_cache_hash_get_key_for_entry<'a, 'b>(
map: FileCacheHashMapAccess,
entry: *const FileCacheEntry,
) -> Option<&FileCacheKey> {
let map = map.as_ref();
let pos = map.get_bucket_for_value(entry);
map.get_at_bucket(pos as usize).map(|(k, _v)| k)
}
/// Remove all entries from the hash table
#[unsafe(no_mangle)]
pub extern "C" fn bcomm_file_cache_hash_reset<'a, 'b>(map: FileCacheHashMapAccess) {
let map = map.as_mut();
let num_buckets = map.get_num_buckets();
for i in 0..num_buckets {
if let Some(e) = map.entry_at_bucket(i) {
e.remove();
}
}
}

View File

@@ -0,0 +1 @@
pub mod file_cache_hashmap;

View File

@@ -21,7 +21,7 @@
#include "access/xlog.h"
#include "funcapi.h"
#include "miscadmin.h"
#include "common/hashfn.h"
#include "common/file_utils.h"
#include "pgstat.h"
#include "port/pg_iovec.h"
#include "postmaster/bgworker.h"
@@ -36,7 +36,6 @@
#include "storage/procsignal.h"
#include "tcop/tcopprot.h"
#include "utils/builtins.h"
#include "utils/dynahash.h"
#include "utils/guc.h"
#if PG_VERSION_NUM >= 150000
@@ -46,6 +45,7 @@
#include "hll.h"
#include "bitmap.h"
#include "file_cache.h"
#include "file_cache_rust_hash.h"
#include "neon.h"
#include "neon_lwlsncache.h"
#include "neon_perf_counters.h"
@@ -64,7 +64,7 @@
*
* Cache is always reconstructed at node startup, so we do not need to save mapping somewhere and worry about
* its consistency.
*
*
* ## Holes
*
@@ -76,13 +76,15 @@
* fallocate(FALLOC_FL_PUNCH_HOLE) call. The nominal size of the file doesn't
* shrink, but the disk space it uses does.
*
* Each hole is tracked by a dummy FileCacheEntry, which are kept in the
* 'holes' linked list. They are entered into the chunk hash table, with a
* special key where the blockNumber is used to store the 'offset' of the
* hole, and all other fields are zero. Holes are never looked up in the hash
* table, we only enter them there to have a FileCacheEntry that we can keep
* in the linked list. If the soft limit is raised again, we reuse the holes
* before extending the nominal size of the file.
* Each hole is tracked in a freelist. The freelist consists of two parts: a
* fixed-size array in shared memory, and a linked chain of on-disk
* blocks. When the in-memory array fills up, it's flushed to a new on-disk
* chunk. If the soft limit is raised again, we reuse the holes before
* extending the nominal size of the file.
*
* The in-memory freelist array is protected by 'lfc_lock', while the on-disk
* chain is protected by a separate 'lfc_freelist_lock'. Locking rule to
* avoid deadlocks: always acquire lfc_freelist_lock first, then lfc_lock.
*/
/* Local file storage allocation chunk.
@@ -92,13 +94,15 @@
* 1Mb chunks can reduce hash map size to 320Mb.
* 2. Improve access locality, subsequent pages will be allocated together improving seqscan speed
*/
#define MAX_BLOCKS_PER_CHUNK_LOG 7 /* 1Mb chunk */
#define MAX_BLOCKS_PER_CHUNK (1 << MAX_BLOCKS_PER_CHUNK_LOG)
#define BLOCKS_PER_CHUNK_LOG 7 /* 1Mb chunk */
#define BLOCKS_PER_CHUNK (1 << BLOCKS_PER_CHUNK_LOG)
#define MB ((uint64)1024*1024)
#define SIZE_MB_TO_CHUNKS(size) ((uint32)((size) * MB / BLCKSZ >> lfc_chunk_size_log))
#define BLOCK_TO_CHUNK_OFF(blkno) ((blkno) & (lfc_blocks_per_chunk-1))
#define SIZE_MB_TO_CHUNKS(size) ((uint32)((size) * MB / BLCKSZ >> BLOCKS_PER_CHUNK_LOG))
#define BLOCK_TO_CHUNK_OFF(blkno) ((blkno) & (BLOCKS_PER_CHUNK-1))
#define INVALID_OFFSET (0xffffffff)
/*
* Blocks are read or written to LFC file outside LFC critical section.
@@ -119,15 +123,18 @@ typedef enum FileCacheBlockState
typedef struct FileCacheEntry
{
BufferTag key;
uint32 hash;
uint32 offset;
uint32 access_count;
dlist_node list_node; /* LRU/holes list node */
uint32 state[FLEXIBLE_ARRAY_MEMBER]; /* two bits per block */
dlist_node list_node; /* LRU list node */
uint32 state[(BLOCKS_PER_CHUNK * 2 + 31) / 32]; /* two bits per block */
} FileCacheEntry;
#define FILE_CACHE_ENRTY_SIZE MAXALIGN(offsetof(FileCacheEntry, state) + (lfc_blocks_per_chunk*2+31)/32*4)
/* Todo: alignment must be the same too */
StaticAssertDecl(sizeof(FileCacheEntry) == sizeof(RustFileCacheEntry),
"Rust and C declarations of FileCacheEntry are incompatible");
StaticAssertDecl(sizeof(BufferTag) == sizeof(RustFileCacheKey),
"Rust and C declarations of FileCacheKey are incompatible");
#define GET_STATE(entry, i) (((entry)->state[(i) / 16] >> ((i) % 16 * 2)) & 3)
#define SET_STATE(entry, i, new_state) (entry)->state[(i) / 16] = ((entry)->state[(i) / 16] & ~(3 << ((i) % 16 * 2))) | ((new_state) << ((i) % 16 * 2))
@@ -136,6 +143,9 @@ typedef struct FileCacheEntry
#define MAX_PREWARM_WORKERS 8
#define FREELIST_ENTRIES_PER_CHUNK (BLOCKS_PER_CHUNK * BLCKSZ / sizeof(uint32) - 2)
typedef struct PrewarmWorkerState
{
uint32 prewarmed_pages;
@@ -161,7 +171,6 @@ typedef struct FileCacheControl
uint64 evicted_pages; /* number of evicted pages */
dlist_head lru; /* double linked list for LRU replacement
* algorithm */
dlist_head holes; /* double linked list of punched holes */
HyperLogLogState wss_estimation; /* estimation of working set size */
ConditionVariable cv[N_COND_VARS]; /* turnstile of condition variables */
PrewarmWorkerState prewarm_workers[MAX_PREWARM_WORKERS];
@@ -172,23 +181,39 @@ typedef struct FileCacheControl
bool prewarm_active;
bool prewarm_canceled;
dsm_handle prewarm_lfc_state_handle;
/*
* Free list. This is large enough to hold one chunks worth of entries.
*/
uint32 freelist_size;
uint32 freelist_head;
uint32 num_free_pages;
uint32 free_pages[FREELIST_ENTRIES_PER_CHUNK];
} FileCacheControl;
typedef struct FreeListChunk
{
uint32 next;
uint32 num_free_pages;
uint32 free_pages[FREELIST_ENTRIES_PER_CHUNK];
} FreeListChunk;
#define FILE_CACHE_STATE_MAGIC 0xfcfcfcfc
#define FILE_CACHE_STATE_BITMAP(fcs) ((uint8*)&(fcs)->chunks[(fcs)->n_chunks])
#define FILE_CACHE_STATE_SIZE_FOR_CHUNKS(n_chunks) (sizeof(FileCacheState) + (n_chunks)*sizeof(BufferTag) + (((n_chunks) * lfc_blocks_per_chunk)+7)/8)
#define FILE_CACHE_STATE_SIZE_FOR_CHUNKS(n_chunks) (sizeof(FileCacheState) + (n_chunks)*sizeof(BufferTag) + (((n_chunks) * BLOCKS_PER_CHUNK)+7)/8)
#define FILE_CACHE_STATE_SIZE(fcs) (sizeof(FileCacheState) + (fcs->n_chunks)*sizeof(BufferTag) + (((fcs->n_chunks) << fcs->chunk_size_log)+7)/8)
static HTAB *lfc_hash;
static FileCacheHashMapHandle lfc_hash_handle;
static FileCacheHashMapAccess lfc_hash;
static int lfc_desc = -1;
static LWLockId lfc_lock;
static LWLockId lfc_freelist_lock;
static int lfc_max_size;
static int lfc_size_limit;
static int lfc_prewarm_limit;
static int lfc_prewarm_batch;
static int lfc_chunk_size_log = MAX_BLOCKS_PER_CHUNK_LOG;
static int lfc_blocks_per_chunk = MAX_BLOCKS_PER_CHUNK;
static int lfc_blocks_per_chunk_ro = BLOCKS_PER_CHUNK;
static char *lfc_path;
static uint64 lfc_generation;
static FileCacheControl *lfc_ctl;
@@ -205,6 +230,11 @@ bool AmPrewarmWorker;
#define LFC_ENABLED() (lfc_ctl->limit != 0)
static bool freelist_push(uint32 offset);
static bool freelist_prepare_pop(void);
static uint32 freelist_pop(void);
static bool freelist_is_empty(void);
/*
* Close LFC file if opened.
* All backends should close their LFC files once LFC is disabled.
@@ -232,15 +262,9 @@ lfc_switch_off(void)
if (LFC_ENABLED())
{
HASH_SEQ_STATUS status;
FileCacheEntry *entry;
/* Invalidate hash */
hash_seq_init(&status, lfc_hash);
while ((entry = hash_seq_search(&status)) != NULL)
{
hash_search_with_hash_value(lfc_hash, &entry->key, entry->hash, HASH_REMOVE, NULL);
}
file_cache_hash_reset(lfc_hash);
lfc_ctl->generation += 1;
lfc_ctl->size = 0;
lfc_ctl->pinned = 0;
@@ -248,7 +272,9 @@ lfc_switch_off(void)
lfc_ctl->used_pages = 0;
lfc_ctl->limit = 0;
dlist_init(&lfc_ctl->lru);
dlist_init(&lfc_ctl->holes);
lfc_ctl->freelist_head = INVALID_OFFSET;
lfc_ctl->num_free_pages = 0;
/*
* We need to use unlink to to avoid races in LFC write, because it is not
@@ -317,8 +343,8 @@ lfc_ensure_opened(void)
static void
lfc_shmem_startup(void)
{
size_t size;
bool found;
static HASHCTL info;
if (prev_shmem_startup_hook)
{
@@ -327,27 +353,29 @@ lfc_shmem_startup(void)
LWLockAcquire(AddinShmemInitLock, LW_EXCLUSIVE);
lfc_ctl = (FileCacheControl *) ShmemInitStruct("lfc", sizeof(FileCacheControl), &found);
size = sizeof(FileCacheControl);
lfc_ctl = (FileCacheControl *) ShmemInitStruct("lfc", size, &found);
if (!found)
{
int fd;
uint32 n_chunks = SIZE_MB_TO_CHUNKS(lfc_max_size);
lfc_lock = (LWLockId) GetNamedLWLockTranche("lfc_lock");
info.keysize = sizeof(BufferTag);
info.entrysize = FILE_CACHE_ENRTY_SIZE;
lfc_freelist_lock = (LWLockId) GetNamedLWLockTranche("lfc_freelist_lock");
/*
* n_chunks+1 because we add new element to hash table before eviction
* of victim
*/
lfc_hash = ShmemInitHash("lfc_hash",
n_chunks + 1, n_chunks + 1,
&info,
HASH_ELEM | HASH_BLOBS);
memset(lfc_ctl, 0, sizeof(FileCacheControl));
lfc_hash_handle = file_cache_hash_shmem_init(n_chunks + 1, n_chunks + 1);
memset(lfc_ctl, 0, offsetof(FileCacheControl, free_pages));
dlist_init(&lfc_ctl->lru);
dlist_init(&lfc_ctl->holes);
lfc_ctl->freelist_size = FREELIST_ENTRIES_PER_CHUNK;
lfc_ctl->freelist_head = INVALID_OFFSET;
lfc_ctl->num_free_pages = 0;
/* Initialize hyper-log-log structure for estimating working set size */
initSHLL(&lfc_ctl->wss_estimation);
@@ -371,18 +399,25 @@ lfc_shmem_startup(void)
}
LWLockRelease(AddinShmemInitLock);
lfc_hash = file_cache_hash_shmem_access(lfc_hash_handle);
}
static void
lfc_shmem_request(void)
{
size_t size;
#if PG_VERSION_NUM>=150000
if (prev_shmem_request_hook)
prev_shmem_request_hook();
#endif
RequestAddinShmemSpace(sizeof(FileCacheControl) + hash_estimate_size(SIZE_MB_TO_CHUNKS(lfc_max_size) + 1, FILE_CACHE_ENRTY_SIZE));
size = sizeof(FileCacheControl);
RequestAddinShmemSpace(size);
RequestNamedLWLockTranche("lfc_lock", 1);
RequestNamedLWLockTranche("lfc_freelist_lock", 2);
}
static bool
@@ -398,24 +433,6 @@ is_normal_backend(void)
return lfc_ctl && MyProc && UsedShmemSegAddr && !IsParallelWorker();
}
static bool
lfc_check_chunk_size(int *newval, void **extra, GucSource source)
{
if (*newval & (*newval - 1))
{
elog(ERROR, "LFC chunk size should be power of two");
return false;
}
return true;
}
static void
lfc_change_chunk_size(int newval, void* extra)
{
lfc_chunk_size_log = pg_ceil_log2_32(newval);
}
static bool
lfc_check_limit_hook(int *newval, void **extra, GucSource source)
{
@@ -435,12 +452,14 @@ lfc_change_limit_hook(int newval, void *extra)
if (!lfc_ctl || !is_normal_backend())
return;
LWLockAcquire(lfc_freelist_lock, LW_EXCLUSIVE);
LWLockAcquire(lfc_lock, LW_EXCLUSIVE);
/* Open LFC file only if LFC was enabled or we are going to reenable it */
if (newval == 0 && !LFC_ENABLED())
{
LWLockRelease(lfc_lock);
LWLockRelease(lfc_freelist_lock);
/* File should be reopened if LFC is reenabled */
lfc_close_file();
return;
@@ -449,6 +468,7 @@ lfc_change_limit_hook(int newval, void *extra)
if (!lfc_ensure_opened())
{
LWLockRelease(lfc_lock);
LWLockRelease(lfc_freelist_lock);
return;
}
@@ -464,35 +484,30 @@ lfc_change_limit_hook(int newval, void *extra)
* returning their space to file system
*/
FileCacheEntry *victim = dlist_container(FileCacheEntry, list_node, dlist_pop_head_node(&lfc_ctl->lru));
FileCacheEntry *hole;
uint32 offset = victim->offset;
uint32 hash;
bool found;
BufferTag holetag;
CriticalAssert(victim->access_count == 0);
#ifdef FALLOC_FL_PUNCH_HOLE
if (fallocate(lfc_desc, FALLOC_FL_PUNCH_HOLE | FALLOC_FL_KEEP_SIZE, (off_t) victim->offset * lfc_blocks_per_chunk * BLCKSZ, lfc_blocks_per_chunk * BLCKSZ) < 0)
if (fallocate(lfc_desc, FALLOC_FL_PUNCH_HOLE | FALLOC_FL_KEEP_SIZE, (off_t) victim->offset * BLOCKS_PER_CHUNK * BLCKSZ, BLOCKS_PER_CHUNK * BLCKSZ) < 0)
neon_log(LOG, "Failed to punch hole in file: %m");
#endif
/* We remove the old entry, and re-enter a hole to the hash table */
for (int i = 0; i < lfc_blocks_per_chunk; i++)
/* We remove the entry, and enter a hole to the freelist */
for (int i = 0; i < BLOCKS_PER_CHUNK; i++)
{
bool is_page_cached = GET_STATE(victim, i) == AVAILABLE;
lfc_ctl->used_pages -= is_page_cached;
lfc_ctl->evicted_pages += is_page_cached;
}
hash_search_with_hash_value(lfc_hash, &victim->key, victim->hash, HASH_REMOVE, NULL);
file_cache_hash_remove_entry(lfc_hash, victim);
memset(&holetag, 0, sizeof(holetag));
holetag.blockNum = offset;
hash = get_hash_value(lfc_hash, &holetag);
hole = hash_search_with_hash_value(lfc_hash, &holetag, hash, HASH_ENTER, &found);
hole->hash = hash;
hole->offset = offset;
hole->access_count = 0;
CriticalAssert(!found);
dlist_push_tail(&lfc_ctl->holes, &hole->list_node);
if (!freelist_push(offset))
{
/* freelist_push already logged the error */
lfc_switch_off();
LWLockRelease(lfc_lock);
LWLockRelease(lfc_freelist_lock);
return;
}
lfc_ctl->used -= 1;
}
@@ -504,6 +519,7 @@ lfc_change_limit_hook(int newval, void *extra)
neon_log(DEBUG1, "set local file cache limit to %d", new_size);
LWLockRelease(lfc_lock);
LWLockRelease(lfc_freelist_lock);
}
void
@@ -579,14 +595,14 @@ lfc_init(void)
DefineCustomIntVariable("neon.file_cache_chunk_size",
"LFC chunk size in blocks (should be power of two)",
NULL,
&lfc_blocks_per_chunk,
MAX_BLOCKS_PER_CHUNK,
1,
MAX_BLOCKS_PER_CHUNK,
PGC_POSTMASTER,
&lfc_blocks_per_chunk_ro,
BLOCKS_PER_CHUNK,
BLOCKS_PER_CHUNK,
BLOCKS_PER_CHUNK,
PGC_INTERNAL,
GUC_UNIT_BLOCKS,
lfc_check_chunk_size,
lfc_change_chunk_size,
NULL,
NULL,
NULL);
DefineCustomIntVariable("neon.file_cache_prewarm_limit",
@@ -649,19 +665,19 @@ lfc_get_state(size_t max_entries)
fcs = (FileCacheState*)palloc0(state_size);
SET_VARSIZE(fcs, state_size);
fcs->magic = FILE_CACHE_STATE_MAGIC;
fcs->chunk_size_log = lfc_chunk_size_log;
fcs->chunk_size_log = BLOCKS_PER_CHUNK_LOG;
fcs->n_chunks = n_entries;
bitmap = FILE_CACHE_STATE_BITMAP(fcs);
dlist_reverse_foreach(iter, &lfc_ctl->lru)
{
FileCacheEntry *entry = dlist_container(FileCacheEntry, list_node, iter.cur);
fcs->chunks[i] = entry->key;
for (int j = 0; j < lfc_blocks_per_chunk; j++)
fcs->chunks[i] = *file_cache_hash_get_key_for_entry(lfc_hash, entry);
for (int j = 0; j < BLOCKS_PER_CHUNK; j++)
{
if (GET_STATE(entry, j) != UNAVAILABLE)
{
BITMAP_SET(bitmap, i*lfc_blocks_per_chunk + j);
BITMAP_SET(bitmap, i*BLOCKS_PER_CHUNK + j);
n_pages += 1;
}
}
@@ -670,7 +686,7 @@ lfc_get_state(size_t max_entries)
}
Assert(i == n_entries);
fcs->n_pages = n_pages;
Assert(pg_popcount((char*)bitmap, ((n_entries << lfc_chunk_size_log) + 7)/8) == n_pages);
Assert(pg_popcount((char*)bitmap, ((n_entries << BLOCKS_PER_CHUNK_LOG) + 7)/8) == n_pages);
elog(LOG, "LFC: save state of %d chunks %d pages", (int)n_entries, (int)n_pages);
}
@@ -726,7 +742,7 @@ lfc_prewarm(FileCacheState* fcs, uint32 n_workers)
}
fcs_chunk_size_log = fcs->chunk_size_log;
if (fcs_chunk_size_log > MAX_BLOCKS_PER_CHUNK_LOG)
if (fcs_chunk_size_log > BLOCKS_PER_CHUNK_LOG)
{
elog(ERROR, "LFC: Invalid chunk size log: %u", fcs->chunk_size_log);
}
@@ -945,7 +961,7 @@ lfc_invalidate(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber nblocks)
{
BufferTag tag;
FileCacheEntry *entry;
uint32 hash;
uint64 hash;
if (lfc_maybe_disabled()) /* fast exit if file cache is disabled */
return;
@@ -958,14 +974,14 @@ lfc_invalidate(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber nblocks)
LWLockAcquire(lfc_lock, LW_EXCLUSIVE);
if (LFC_ENABLED())
{
for (BlockNumber blkno = 0; blkno < nblocks; blkno += lfc_blocks_per_chunk)
for (BlockNumber blkno = 0; blkno < nblocks; blkno += BLOCKS_PER_CHUNK)
{
tag.blockNum = blkno;
hash = get_hash_value(lfc_hash, &tag);
entry = hash_search_with_hash_value(lfc_hash, &tag, hash, HASH_FIND, NULL);
hash = file_cache_hash_get_hash_value(lfc_hash, &tag);
entry = file_cache_hash_find(lfc_hash, &tag, hash);
if (entry != NULL)
{
for (int i = 0; i < lfc_blocks_per_chunk; i++)
for (int i = 0; i < BLOCKS_PER_CHUNK; i++)
{
if (GET_STATE(entry, i) == AVAILABLE)
{
@@ -990,7 +1006,7 @@ lfc_cache_contains(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno)
FileCacheEntry *entry;
int chunk_offs = BLOCK_TO_CHUNK_OFF(blkno);
bool found = false;
uint32 hash;
uint64 hash;
if (lfc_maybe_disabled()) /* fast exit if file cache is disabled */
return false;
@@ -1000,12 +1016,12 @@ lfc_cache_contains(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno)
tag.blockNum = blkno - chunk_offs;
CriticalAssert(BufTagGetRelNumber(&tag) != InvalidRelFileNumber);
hash = get_hash_value(lfc_hash, &tag);
hash = file_cache_hash_get_hash_value(lfc_hash, &tag);
LWLockAcquire(lfc_lock, LW_SHARED);
if (LFC_ENABLED())
{
entry = hash_search_with_hash_value(lfc_hash, &tag, hash, HASH_FIND, NULL);
entry = file_cache_hash_find(lfc_hash, &tag, hash);
found = entry != NULL && GET_STATE(entry, chunk_offs) != UNAVAILABLE;
}
LWLockRelease(lfc_lock);
@@ -1024,7 +1040,7 @@ lfc_cache_containsv(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno,
FileCacheEntry *entry;
uint32 chunk_offs;
int found = 0;
uint32 hash;
uint64 hash;
int i = 0;
if (lfc_maybe_disabled()) /* fast exit if file cache is disabled */
@@ -1037,7 +1053,7 @@ lfc_cache_containsv(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno,
chunk_offs = BLOCK_TO_CHUNK_OFF(blkno);
tag.blockNum = blkno - chunk_offs;
hash = get_hash_value(lfc_hash, &tag);
hash = file_cache_hash_get_hash_value(lfc_hash, &tag);
LWLockAcquire(lfc_lock, LW_SHARED);
@@ -1048,12 +1064,12 @@ lfc_cache_containsv(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno,
}
while (true)
{
int this_chunk = Min(nblocks - i, lfc_blocks_per_chunk - chunk_offs);
entry = hash_search_with_hash_value(lfc_hash, &tag, hash, HASH_FIND, NULL);
int this_chunk = Min(nblocks - i, BLOCKS_PER_CHUNK - chunk_offs);
entry = file_cache_hash_find(lfc_hash, &tag, hash);
if (entry != NULL)
{
for (; chunk_offs < lfc_blocks_per_chunk && i < nblocks; chunk_offs++, i++)
for (; chunk_offs < BLOCKS_PER_CHUNK && i < nblocks; chunk_offs++, i++)
{
if (GET_STATE(entry, chunk_offs) != UNAVAILABLE)
{
@@ -1079,7 +1095,7 @@ lfc_cache_containsv(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno,
*/
chunk_offs = BLOCK_TO_CHUNK_OFF(blkno + i);
tag.blockNum = (blkno + i) - chunk_offs;
hash = get_hash_value(lfc_hash, &tag);
hash = file_cache_hash_get_hash_value(lfc_hash, &tag);
}
LWLockRelease(lfc_lock);
@@ -1128,7 +1144,7 @@ lfc_readv_select(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno,
BufferTag tag;
FileCacheEntry *entry;
ssize_t rc;
uint32 hash;
uint64 hash;
uint64 generation;
uint32 entry_offset;
int blocks_read = 0;
@@ -1154,9 +1170,9 @@ lfc_readv_select(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno,
while (nblocks > 0)
{
struct iovec iov[PG_IOV_MAX];
uint8 chunk_mask[MAX_BLOCKS_PER_CHUNK / 8] = {0};
uint8 chunk_mask[BLOCKS_PER_CHUNK / 8] = {0};
int chunk_offs = BLOCK_TO_CHUNK_OFF(blkno);
int blocks_in_chunk = Min(nblocks, lfc_blocks_per_chunk - chunk_offs);
int blocks_in_chunk = Min(nblocks, BLOCKS_PER_CHUNK - chunk_offs);
int iteration_hits = 0;
int iteration_misses = 0;
uint64 io_time_us = 0;
@@ -1206,7 +1222,7 @@ lfc_readv_select(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno,
Assert(iov_last_used - first_block_in_chunk_read >= n_blocks_to_read);
tag.blockNum = blkno - chunk_offs;
hash = get_hash_value(lfc_hash, &tag);
hash = file_cache_hash_get_hash_value(lfc_hash, &tag);
cv = &lfc_ctl->cv[hash % N_COND_VARS];
LWLockAcquire(lfc_lock, LW_EXCLUSIVE);
@@ -1219,13 +1235,13 @@ lfc_readv_select(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno,
return blocks_read;
}
entry = hash_search_with_hash_value(lfc_hash, &tag, hash, HASH_FIND, NULL);
entry = file_cache_hash_find(lfc_hash, &tag, hash);
/* Approximate working set for the blocks assumed in this entry */
for (int i = 0; i < blocks_in_chunk; i++)
{
tag.blockNum = blkno + i;
addSHLL(&lfc_ctl->wss_estimation, hash_bytes((uint8_t const*)&tag, sizeof(tag)));
addSHLL(&lfc_ctl->wss_estimation, file_cache_hash_get_hash_value(lfc_hash, &tag));
}
if (entry == NULL)
@@ -1296,7 +1312,7 @@ lfc_readv_select(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno,
if (iteration_hits != 0)
{
/* chunk offset (# of pages) into the LFC file */
off_t first_read_offset = (off_t) entry_offset * lfc_blocks_per_chunk;
off_t first_read_offset = (off_t) entry_offset * BLOCKS_PER_CHUNK;
int nwrite = iov_last_used - first_block_in_chunk_read;
/* offset of first IOV */
first_read_offset += chunk_offs + first_block_in_chunk_read;
@@ -1373,14 +1389,14 @@ lfc_readv_select(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno,
* Returns false if there are no unpinned entries and chunk can not be added.
*/
static bool
lfc_init_new_entry(FileCacheEntry* entry, uint32 hash)
lfc_init_new_entry(FileCacheEntry *entry)
{
/*-----------
* If the chunk wasn't already in the LFC then we have these
* options, in order of preference:
*
* Unless there is no space available, we can:
* 1. Use an entry from the `holes` list, and
* 1. Use an entry from the freelist, and
* 2. Create a new entry.
* We can always, regardless of space in the LFC:
* 3. evict an entry from LRU, and
@@ -1388,17 +1404,10 @@ lfc_init_new_entry(FileCacheEntry* entry, uint32 hash)
*/
if (lfc_ctl->used < lfc_ctl->limit)
{
if (!dlist_is_empty(&lfc_ctl->holes))
if (!freelist_is_empty())
{
/* We can reuse a hole that was left behind when the LFC was shrunk previously */
FileCacheEntry *hole = dlist_container(FileCacheEntry, list_node,
dlist_pop_head_node(&lfc_ctl->holes));
uint32 offset = hole->offset;
bool hole_found;
hash_search_with_hash_value(lfc_hash, &hole->key,
hole->hash, HASH_REMOVE, &hole_found);
CriticalAssert(hole_found);
uint32 offset = freelist_pop();
lfc_ctl->used += 1;
entry->offset = offset; /* reuse the hole */
@@ -1427,7 +1436,7 @@ lfc_init_new_entry(FileCacheEntry* entry, uint32 hash)
FileCacheEntry *victim = dlist_container(FileCacheEntry, list_node,
dlist_pop_head_node(&lfc_ctl->lru));
for (int i = 0; i < lfc_blocks_per_chunk; i++)
for (int i = 0; i < BLOCKS_PER_CHUNK; i++)
{
bool is_page_cached = GET_STATE(victim, i) == AVAILABLE;
lfc_ctl->used_pages -= is_page_cached;
@@ -1436,24 +1445,21 @@ lfc_init_new_entry(FileCacheEntry* entry, uint32 hash)
CriticalAssert(victim->access_count == 0);
entry->offset = victim->offset; /* grab victim's chunk */
hash_search_with_hash_value(lfc_hash, &victim->key,
victim->hash, HASH_REMOVE, NULL);
file_cache_hash_remove_entry(lfc_hash, victim);
neon_log(DEBUG2, "Swap file cache page");
}
else
{
/* Can't add this chunk - we don't have the space for it */
hash_search_with_hash_value(lfc_hash, &entry->key, hash,
HASH_REMOVE, NULL);
file_cache_hash_remove_entry(lfc_hash, entry);
lfc_ctl->prewarm_canceled = true; /* cancel prewarm if LFC limit is reached */
return false;
}
entry->access_count = 1;
entry->hash = hash;
lfc_ctl->pinned += 1;
for (int i = 0; i < lfc_blocks_per_chunk; i++)
for (int i = 0; i < BLOCKS_PER_CHUNK; i++)
SET_STATE(entry, i, UNAVAILABLE);
return true;
@@ -1490,7 +1496,7 @@ lfc_prefetch(NRelFileInfo rinfo, ForkNumber forknum, BlockNumber blkno,
FileCacheEntry *entry;
ssize_t rc;
bool found;
uint32 hash;
uint64 hash;
uint64 generation;
uint32 entry_offset;
instr_time io_start, io_end;
@@ -1509,9 +1515,10 @@ lfc_prefetch(NRelFileInfo rinfo, ForkNumber forknum, BlockNumber blkno,
CriticalAssert(BufTagGetRelNumber(&tag) != InvalidRelFileNumber);
tag.blockNum = blkno - chunk_offs;
hash = get_hash_value(lfc_hash, &tag);
hash = file_cache_hash_get_hash_value(lfc_hash, &tag);
cv = &lfc_ctl->cv[hash % N_COND_VARS];
retry:
LWLockAcquire(lfc_lock, LW_EXCLUSIVE);
if (!LFC_ENABLED() || !lfc_ensure_opened())
@@ -1520,6 +1527,9 @@ lfc_prefetch(NRelFileInfo rinfo, ForkNumber forknum, BlockNumber blkno,
return false;
}
if (!freelist_prepare_pop())
goto retry;
lwlsn = neon_get_lwlsn(rinfo, forknum, blkno);
if (lwlsn > lsn)
@@ -1530,12 +1540,12 @@ lfc_prefetch(NRelFileInfo rinfo, ForkNumber forknum, BlockNumber blkno,
return false;
}
entry = hash_search_with_hash_value(lfc_hash, &tag, hash, HASH_ENTER, &found);
entry = file_cache_hash_enter(lfc_hash, &tag, hash, &found);
if (lfc_prewarm_update_ws_estimation)
{
tag.blockNum = blkno;
addSHLL(&lfc_ctl->wss_estimation, hash_bytes((uint8_t const*)&tag, sizeof(tag)));
addSHLL(&lfc_ctl->wss_estimation, file_cache_hash_get_hash_value(lfc_hash, &tag));
}
if (found)
{
@@ -1557,7 +1567,7 @@ lfc_prefetch(NRelFileInfo rinfo, ForkNumber forknum, BlockNumber blkno,
}
else
{
if (!lfc_init_new_entry(entry, hash))
if (!lfc_init_new_entry(entry))
{
/*
* We can't process this chunk due to lack of space in LFC,
@@ -1578,7 +1588,7 @@ lfc_prefetch(NRelFileInfo rinfo, ForkNumber forknum, BlockNumber blkno,
pgstat_report_wait_start(WAIT_EVENT_NEON_LFC_WRITE);
INSTR_TIME_SET_CURRENT(io_start);
rc = pwrite(lfc_desc, buffer, BLCKSZ,
((off_t) entry_offset * lfc_blocks_per_chunk + chunk_offs) * BLCKSZ);
((off_t) entry_offset * BLOCKS_PER_CHUNK + chunk_offs) * BLCKSZ);
INSTR_TIME_SET_CURRENT(io_end);
pgstat_report_wait_end();
@@ -1640,7 +1650,7 @@ lfc_writev(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno,
FileCacheEntry *entry;
ssize_t rc;
bool found;
uint32 hash;
uint64 hash;
uint64 generation;
uint32 entry_offset;
int buf_offset = 0;
@@ -1653,6 +1663,7 @@ lfc_writev(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno,
CriticalAssert(BufTagGetRelNumber(&tag) != InvalidRelFileNumber);
retry:
LWLockAcquire(lfc_lock, LW_EXCLUSIVE);
if (!LFC_ENABLED() || !lfc_ensure_opened())
@@ -1662,6 +1673,9 @@ lfc_writev(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno,
}
generation = lfc_ctl->generation;
if (!freelist_prepare_pop())
goto retry;
/*
* For every chunk that has blocks we're interested in, we
* 1. get the chunk header
@@ -1675,7 +1689,7 @@ lfc_writev(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno,
{
struct iovec iov[PG_IOV_MAX];
int chunk_offs = BLOCK_TO_CHUNK_OFF(blkno);
int blocks_in_chunk = Min(nblocks, lfc_blocks_per_chunk - chunk_offs);
int blocks_in_chunk = Min(nblocks, BLOCKS_PER_CHUNK - chunk_offs);
instr_time io_start, io_end;
ConditionVariable* cv;
@@ -1688,16 +1702,16 @@ lfc_writev(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno,
}
tag.blockNum = blkno - chunk_offs;
hash = get_hash_value(lfc_hash, &tag);
hash = file_cache_hash_get_hash_value(lfc_hash, &tag);
cv = &lfc_ctl->cv[hash % N_COND_VARS];
entry = hash_search_with_hash_value(lfc_hash, &tag, hash, HASH_ENTER, &found);
entry = file_cache_hash_enter(lfc_hash, &tag, hash, &found);
/* Approximate working set for the blocks assumed in this entry */
for (int i = 0; i < blocks_in_chunk; i++)
{
tag.blockNum = blkno + i;
addSHLL(&lfc_ctl->wss_estimation, hash_bytes((uint8_t const*)&tag, sizeof(tag)));
addSHLL(&lfc_ctl->wss_estimation, file_cache_hash_get_hash_value(lfc_hash, &tag));
}
if (found)
@@ -1714,7 +1728,7 @@ lfc_writev(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno,
}
else
{
if (!lfc_init_new_entry(entry, hash))
if (!lfc_init_new_entry(entry))
{
/*
* We can't process this chunk due to lack of space in LFC,
@@ -1763,7 +1777,7 @@ lfc_writev(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno,
pgstat_report_wait_start(WAIT_EVENT_NEON_LFC_WRITE);
INSTR_TIME_SET_CURRENT(io_start);
rc = pwritev(lfc_desc, iov, blocks_in_chunk,
((off_t) entry_offset * lfc_blocks_per_chunk + chunk_offs) * BLCKSZ);
((off_t) entry_offset * BLOCKS_PER_CHUNK + chunk_offs) * BLCKSZ);
INSTR_TIME_SET_CURRENT(io_end);
pgstat_report_wait_end();
@@ -1823,6 +1837,140 @@ lfc_writev(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno,
LWLockRelease(lfc_lock);
}
/**** freelist management ****/
/*
* Prerequisites:
* - The caller is holding 'lfc_lock'. XXX
*/
static bool
freelist_prepare_pop(void)
{
/*
* If the in-memory freelist is empty, but there are more blocks available, load them.
*
* TODO: if there
*/
if (lfc_ctl->num_free_pages == 0 && lfc_ctl->freelist_head != INVALID_OFFSET)
{
uint32 freelist_head;
FreeListChunk *freelist_chunk;
size_t bytes_read;
LWLockRelease(lfc_lock);
LWLockAcquire(lfc_freelist_lock, LW_EXCLUSIVE);
if (!(lfc_ctl->num_free_pages == 0 && lfc_ctl->freelist_head != INVALID_OFFSET))
{
/* someone else did the work for us while we were not holding the lock */
LWLockRelease(lfc_freelist_lock);
return false;
}
freelist_head = lfc_ctl->freelist_head;
freelist_chunk = palloc(BLOCKS_PER_CHUNK * BLCKSZ);
bytes_read = 0;
while (bytes_read < BLOCKS_PER_CHUNK * BLCKSZ)
{
ssize_t rc;
rc = pread(lfc_desc, freelist_chunk, BLOCKS_PER_CHUNK * BLCKSZ - bytes_read, (off_t) freelist_head * BLOCKS_PER_CHUNK * BLCKSZ + bytes_read);
if (rc < 0)
{
lfc_disable("read freelist page");
return false;
}
bytes_read += rc;
}
LWLockAcquire(lfc_lock, LW_EXCLUSIVE);
if (lfc_generation != lfc_ctl->generation)
{
LWLockRelease(lfc_lock);
return false;
}
Assert(lfc_ctl->freelist_head == freelist_head);
Assert(lfc_ctl->num_free_pages == 0);
lfc_ctl->freelist_head = freelist_chunk->next;
lfc_ctl->num_free_pages = freelist_chunk->num_free_pages;
memcpy(lfc_ctl->free_pages, freelist_chunk->free_pages, lfc_ctl->num_free_pages * sizeof(uint32));
pfree(freelist_chunk);
LWLockRelease(lfc_lock);
LWLockRelease(lfc_freelist_lock);
return false;
}
return true;
}
/*
* Prerequisites:
* - The caller is holding 'lfc_lock' and 'lfc_freelist_lock'.
*
* Returns 'false' on error.
*/
static bool
freelist_push(uint32 offset)
{
Assert(lfc_ctl->freelist_size == FREELIST_ENTRIES_PER_CHUNK);
if (lfc_ctl->num_free_pages == lfc_ctl->freelist_size)
{
FreeListChunk *freelist_chunk;
struct iovec iov;
ssize_t rc;
freelist_chunk = palloc(BLOCKS_PER_CHUNK * BLCKSZ);
/* write the existing entries to the chunk on disk */
freelist_chunk->next = lfc_ctl->freelist_head;
freelist_chunk->num_free_pages = lfc_ctl->num_free_pages;
memcpy(freelist_chunk->free_pages, lfc_ctl->free_pages, lfc_ctl->num_free_pages * sizeof(uint32));
/* Use the passed-in offset to hold the freelist chunk itself */
iov.iov_base = freelist_chunk;
iov.iov_len = BLOCKS_PER_CHUNK * BLCKSZ;
rc = pg_pwritev_with_retry(lfc_desc, &iov, 1, (off_t) offset * BLOCKS_PER_CHUNK * BLCKSZ);
pfree(freelist_chunk);
if (rc < 0)
return false;
lfc_ctl->freelist_head = offset;
lfc_ctl->num_free_pages = 0;
}
else
{
lfc_ctl->free_pages[lfc_ctl->num_free_pages] = offset;
lfc_ctl->num_free_pages++;
}
return true;
}
static uint32
freelist_pop(void)
{
uint32 result;
/* The caller should've checked that the list is not empty */
Assert(lfc_ctl->num_free_pages > 0);
result = lfc_ctl->free_pages[lfc_ctl->num_free_pages - 1];
lfc_ctl->num_free_pages--;
return result;
}
static bool
freelist_is_empty(void)
{
return lfc_ctl->num_free_pages == 0;
}
typedef struct
{
TupleDesc tupdesc;
@@ -1919,7 +2067,7 @@ neon_get_lfc_stats(PG_FUNCTION_ARGS)
break;
case 8:
key = "file_cache_chunk_size_pages";
value = lfc_blocks_per_chunk;
value = BLOCKS_PER_CHUNK;
break;
case 9:
key = "file_cache_chunks_pinned";
@@ -1990,7 +2138,6 @@ local_cache_pages(PG_FUNCTION_ARGS)
if (SRF_IS_FIRSTCALL())
{
HASH_SEQ_STATUS status;
FileCacheEntry *entry;
uint32 n_pages = 0;
@@ -2046,15 +2193,16 @@ local_cache_pages(PG_FUNCTION_ARGS)
if (LFC_ENABLED())
{
hash_seq_init(&status, lfc_hash);
while ((entry = hash_seq_search(&status)) != NULL)
uint32 num_buckets = file_cache_hash_get_num_buckets(lfc_hash);
for (uint32 pos = 0; pos < num_buckets; pos++)
{
/* Skip hole tags */
if (NInfoGetRelNumber(BufTagGetNRelFileInfo(entry->key)) != 0)
{
for (int i = 0; i < lfc_blocks_per_chunk; i++)
n_pages += GET_STATE(entry, i) == AVAILABLE;
}
entry = file_cache_hash_get_at_pos(lfc_hash, pos);
if (entry == NULL)
continue;
for (int i = 0; i < BLOCKS_PER_CHUNK; i++)
n_pages += GET_STATE(entry, i) == AVAILABLE;
}
}
}
@@ -2076,25 +2224,28 @@ local_cache_pages(PG_FUNCTION_ARGS)
* in the fctx->record structure.
*/
uint32 n = 0;
uint32 num_buckets = file_cache_hash_get_num_buckets(lfc_hash);
hash_seq_init(&status, lfc_hash);
while ((entry = hash_seq_search(&status)) != NULL)
for (uint32 pos = 0; pos < num_buckets; pos++)
{
for (int i = 0; i < lfc_blocks_per_chunk; i++)
entry = file_cache_hash_get_at_pos(lfc_hash, pos);
if (entry == NULL)
continue;
for (int i = 0; i < BLOCKS_PER_CHUNK; i++)
{
if (NInfoGetRelNumber(BufTagGetNRelFileInfo(entry->key)) != 0)
const BufferTag *key = file_cache_hash_get_key_for_entry(lfc_hash, entry);
if (GET_STATE(entry, i) == AVAILABLE)
{
if (GET_STATE(entry, i) == AVAILABLE)
{
fctx->record[n].pageoffs = entry->offset * lfc_blocks_per_chunk + i;
fctx->record[n].relfilenode = NInfoGetRelNumber(BufTagGetNRelFileInfo(entry->key));
fctx->record[n].reltablespace = NInfoGetSpcOid(BufTagGetNRelFileInfo(entry->key));
fctx->record[n].reldatabase = NInfoGetDbOid(BufTagGetNRelFileInfo(entry->key));
fctx->record[n].forknum = entry->key.forkNum;
fctx->record[n].blocknum = entry->key.blockNum + i;
fctx->record[n].accesscount = entry->access_count;
n += 1;
}
fctx->record[n].pageoffs = entry->offset * BLOCKS_PER_CHUNK + i;
fctx->record[n].relfilenode = NInfoGetRelNumber(BufTagGetNRelFileInfo(*key));
fctx->record[n].reltablespace = NInfoGetSpcOid(BufTagGetNRelFileInfo(*key));
fctx->record[n].reldatabase = NInfoGetDbOid(BufTagGetNRelFileInfo(*key));
fctx->record[n].forknum = key->forkNum;
fctx->record[n].blocknum = key->blockNum + i;
fctx->record[n].accesscount = entry->access_count;
n += 1;
}
}
}

View File

@@ -1135,7 +1135,7 @@ VotesCollectedMset(WalProposer *wp, MemberSet *mset, Safekeeper **msk, StringInf
wp->propTermStartLsn = sk->voteResponse.flushLsn;
wp->donor = sk;
}
wp->truncateLsn = Max(sk->voteResponse.truncateLsn, wp->truncateLsn);
wp->truncateLsn = Max(wp->safekeeper[i].voteResponse.truncateLsn, wp->truncateLsn);
if (n_votes > 0)
appendStringInfoString(s, ", ");

View File

@@ -18,6 +18,11 @@ pub(super) async fn authenticate(
secret: AuthSecret,
) -> auth::Result<ComputeCredentials> {
let scram_keys = match secret {
#[cfg(any(test, feature = "testing"))]
AuthSecret::Md5(_) => {
debug!("auth endpoint chooses MD5");
return Err(auth::AuthError::MalformedPassword("MD5 not supported"));
}
AuthSecret::Scram(secret) => {
debug!("auth endpoint chooses SCRAM");

View File

@@ -6,15 +6,16 @@ use thiserror::Error;
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::{info, info_span};
use super::ComputeCredentialKeys;
use crate::auth::IpPattern;
use crate::auth::backend::ComputeUserInfo;
use crate::cache::Cached;
use crate::compute::AuthInfo;
use crate::config::AuthenticationConfig;
use crate::context::RequestContext;
use crate::control_plane::client::cplane_proxy_v1;
use crate::control_plane::{self, CachedNodeInfo, NodeInfo};
use crate::error::{ReportableError, UserFacingError};
use crate::pglb::connect_compute::WakeComputeBackend;
use crate::pglb::connect_compute::ComputeConnectBackend;
use crate::pqproto::BeMessage;
use crate::proxy::NeonOptions;
use crate::stream::PqStream;
@@ -97,11 +98,15 @@ impl ConsoleRedirectBackend {
ctx: &RequestContext,
auth_config: &'static AuthenticationConfig,
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
) -> auth::Result<(ConsoleRedirectNodeInfo, AuthInfo, ComputeUserInfo)> {
) -> auth::Result<(
ConsoleRedirectNodeInfo,
ComputeUserInfo,
Option<Vec<IpPattern>>,
)> {
authenticate(ctx, auth_config, &self.console_uri, client)
.await
.map(|(node_info, auth_info, user_info)| {
(ConsoleRedirectNodeInfo(node_info), auth_info, user_info)
.map(|(node_info, user_info, ip_allowlist)| {
(ConsoleRedirectNodeInfo(node_info), user_info, ip_allowlist)
})
}
}
@@ -109,13 +114,17 @@ impl ConsoleRedirectBackend {
pub struct ConsoleRedirectNodeInfo(pub(super) NodeInfo);
#[async_trait]
impl WakeComputeBackend for ConsoleRedirectNodeInfo {
impl ComputeConnectBackend for ConsoleRedirectNodeInfo {
async fn wake_compute(
&self,
_ctx: &RequestContext,
) -> Result<CachedNodeInfo, control_plane::errors::WakeComputeError> {
Ok(Cached::new_uncached(self.0.clone()))
}
fn get_keys(&self) -> &ComputeCredentialKeys {
&ComputeCredentialKeys::None
}
}
async fn authenticate(
@@ -123,7 +132,7 @@ async fn authenticate(
auth_config: &'static AuthenticationConfig,
link_uri: &reqwest::Url,
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
) -> auth::Result<(NodeInfo, AuthInfo, ComputeUserInfo)> {
) -> auth::Result<(NodeInfo, ComputeUserInfo, Option<Vec<IpPattern>>)> {
ctx.set_auth_method(crate::context::AuthMethod::ConsoleRedirect);
// registering waiter can fail if we get unlucky with rng.
@@ -183,24 +192,10 @@ async fn authenticate(
client.write_message(BeMessage::NoticeResponse("Connecting to database."));
// Backwards compatibility. pg_sni_proxy uses "--" in domain names
// while direct connections do not. Once we migrate to pg_sni_proxy
// everywhere, we can remove this.
let ssl_mode = if db_info.host.contains("--") {
// we need TLS connection with SNI info to properly route it
SslMode::Require
} else {
SslMode::Disable
};
let conn_info = compute::ConnectInfo {
host: db_info.host.into(),
port: db_info.port,
ssl_mode,
host_addr: None,
};
let auth_info =
AuthInfo::for_console_redirect(&db_info.dbname, &db_info.user, db_info.password.as_deref());
// This config should be self-contained, because we won't
// take username or dbname from client's startup message.
let mut config = compute::ConnCfg::new(db_info.host.to_string(), db_info.port);
config.dbname(&db_info.dbname).user(&db_info.user);
let user: RoleName = db_info.user.into();
let user_info = ComputeUserInfo {
@@ -214,12 +209,26 @@ async fn authenticate(
ctx.set_project(db_info.aux.clone());
info!("woken up a compute node");
// Backwards compatibility. pg_sni_proxy uses "--" in domain names
// while direct connections do not. Once we migrate to pg_sni_proxy
// everywhere, we can remove this.
if db_info.host.contains("--") {
// we need TLS connection with SNI info to properly route it
config.ssl_mode(SslMode::Require);
} else {
config.ssl_mode(SslMode::Disable);
}
if let Some(password) = db_info.password {
config.password(password.as_ref());
}
Ok((
NodeInfo {
conn_info,
config,
aux: db_info.aux,
},
auth_info,
user_info,
db_info.allowed_ips,
))
}

View File

@@ -1,12 +1,11 @@
use std::net::SocketAddr;
use arc_swap::ArcSwapOption;
use postgres_client::config::SslMode;
use tokio::sync::Semaphore;
use super::jwt::{AuthRule, FetchAuthRules};
use crate::auth::backend::jwt::FetchAuthRulesError;
use crate::compute::ConnectInfo;
use crate::compute::ConnCfg;
use crate::compute_ctl::ComputeCtlApi;
use crate::context::RequestContext;
use crate::control_plane::NodeInfo;
@@ -30,12 +29,7 @@ impl LocalBackend {
api: http::Endpoint::new(compute_ctl, http::new_client()),
},
node_info: NodeInfo {
conn_info: ConnectInfo {
host_addr: Some(postgres_addr.ip()),
host: postgres_addr.ip().to_string().into(),
port: postgres_addr.port(),
ssl_mode: SslMode::Disable,
},
config: ConnCfg::new(postgres_addr.ip().to_string(), postgres_addr.port()),
// TODO(conrad): make this better reflect compute info rather than endpoint info.
aux: MetricsAuxInfo {
endpoint_id: EndpointIdTag::get_interner().get_or_intern("local"),

View File

@@ -25,7 +25,7 @@ use crate::control_plane::{
RoleAccessControl,
};
use crate::intern::EndpointIdInt;
use crate::pglb::connect_compute::WakeComputeBackend;
use crate::pglb::connect_compute::ComputeConnectBackend;
use crate::pqproto::BeMessage;
use crate::proxy::NeonOptions;
use crate::rate_limiter::EndpointRateLimiter;
@@ -168,6 +168,8 @@ impl ComputeUserInfo {
#[cfg_attr(test, derive(Debug))]
pub(crate) enum ComputeCredentialKeys {
#[cfg(any(test, feature = "testing"))]
Password(Vec<u8>),
AuthKeys(AuthKeys),
JwtPayload(Vec<u8>),
None,
@@ -407,16 +409,23 @@ impl Backend<'_, ComputeUserInfo> {
}
#[async_trait::async_trait]
impl WakeComputeBackend for Backend<'_, ComputeUserInfo> {
impl ComputeConnectBackend for Backend<'_, ComputeCredentials> {
async fn wake_compute(
&self,
ctx: &RequestContext,
) -> Result<CachedNodeInfo, control_plane::errors::WakeComputeError> {
match self {
Self::ControlPlane(api, info) => api.wake_compute(ctx, info).await,
Self::ControlPlane(api, creds) => api.wake_compute(ctx, &creds.info).await,
Self::Local(local) => Ok(Cached::new_uncached(local.node_info.clone())),
}
}
fn get_keys(&self) -> &ComputeCredentialKeys {
match self {
Self::ControlPlane(_, creds) => &creds.keys,
Self::Local(_) => &ComputeCredentialKeys::None,
}
}
}
#[cfg(test)]

View File

@@ -169,6 +169,13 @@ pub(crate) async fn validate_password_and_exchange(
secret: AuthSecret,
) -> super::Result<sasl::Outcome<ComputeCredentialKeys>> {
match secret {
#[cfg(any(test, feature = "testing"))]
AuthSecret::Md5(_) => {
// test only
Ok(sasl::Outcome::Success(ComputeCredentialKeys::Password(
password.to_owned(),
)))
}
// perform scram authentication as both client and server to validate the keys
AuthSecret::Scram(scram_secret) => {
let outcome = crate::scram::exchange(pool, endpoint, &scram_secret, password).await?;

View File

@@ -18,7 +18,6 @@ use crate::types::{EndpointId, RoleName};
#[async_trait]
pub(crate) trait ProjectInfoCache {
fn invalidate_endpoint_access(&self, endpoint_id: EndpointIdInt);
fn invalidate_endpoint_access_for_project(&self, project_id: ProjectIdInt);
fn invalidate_endpoint_access_for_org(&self, account_id: AccountIdInt);
fn invalidate_role_secret_for_project(&self, project_id: ProjectIdInt, role_name: RoleNameInt);
@@ -101,13 +100,6 @@ pub struct ProjectInfoCacheImpl {
#[async_trait]
impl ProjectInfoCache for ProjectInfoCacheImpl {
fn invalidate_endpoint_access(&self, endpoint_id: EndpointIdInt) {
info!("invalidating endpoint access for `{endpoint_id}`");
if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) {
endpoint_info.invalidate_endpoint();
}
}
fn invalidate_endpoint_access_for_project(&self, project_id: ProjectIdInt) {
info!("invalidating endpoint access for project `{project_id}`");
let endpoints = self

View File

@@ -24,6 +24,7 @@ use crate::pqproto::CancelKeyData;
use crate::rate_limiter::LeakyBucketRateLimiter;
use crate::redis::keys::KeyPrefix;
use crate::redis::kv_ops::RedisKVClient;
use crate::tls::postgres_rustls::MakeRustlsConnect;
type IpSubnetKey = IpNet;
@@ -496,8 +497,10 @@ impl CancelClosure {
) -> Result<(), CancelError> {
let socket = TcpStream::connect(self.socket_addr).await?;
let tls = <_ as MakeTlsConnect<tokio::net::TcpStream>>::make_tls_connect(
compute_config,
let mut mk_tls =
crate::tls::postgres_rustls::MakeRustlsConnect::new(compute_config.tls.clone());
let tls = <MakeRustlsConnect as MakeTlsConnect<tokio::net::TcpStream>>::make_tls_connect(
&mut mk_tls,
&self.hostname,
)
.map_err(|e| CancelError::IO(std::io::Error::other(e.to_string())))?;

View File

@@ -1,24 +1,21 @@
mod tls;
use std::fmt::Debug;
use std::io;
use std::net::{IpAddr, SocketAddr};
use std::net::SocketAddr;
use std::time::Duration;
use futures::{FutureExt, TryFutureExt};
use itertools::Itertools;
use postgres_client::config::{AuthKeys, SslMode};
use postgres_client::maybe_tls_stream::MaybeTlsStream;
use postgres_client::tls::MakeTlsConnect;
use postgres_client::{CancelToken, NoTls, RawConnection};
use postgres_client::{CancelToken, RawConnection};
use postgres_protocol::message::backend::NoticeResponseBody;
use rustls::pki_types::InvalidDnsNameError;
use thiserror::Error;
use tokio::net::{TcpStream, lookup_host};
use tracing::{debug, error, info, warn};
use crate::auth::backend::{ComputeCredentialKeys, ComputeUserInfo};
use crate::auth::backend::ComputeUserInfo;
use crate::auth::parse_endpoint_param;
use crate::cancellation::CancelClosure;
use crate::compute::tls::TlsError;
use crate::config::ComputeConfig;
use crate::context::RequestContext;
use crate::control_plane::client::ApiLockError;
@@ -28,6 +25,7 @@ use crate::error::{ReportableError, UserFacingError};
use crate::metrics::{Metrics, NumDbConnectionsGuard};
use crate::pqproto::StartupMessageParams;
use crate::proxy::neon_option;
use crate::tls::postgres_rustls::MakeRustlsConnect;
use crate::types::Host;
pub const COULD_NOT_CONNECT: &str = "Couldn't connect to compute node";
@@ -40,7 +38,10 @@ pub(crate) enum ConnectionError {
Postgres(#[from] postgres_client::Error),
#[error("{COULD_NOT_CONNECT}: {0}")]
TlsError(#[from] TlsError),
CouldNotConnect(#[from] io::Error),
#[error("{COULD_NOT_CONNECT}: {0}")]
TlsError(#[from] InvalidDnsNameError),
#[error("{COULD_NOT_CONNECT}: {0}")]
WakeComputeError(#[from] WakeComputeError),
@@ -72,7 +73,7 @@ impl UserFacingError for ConnectionError {
ConnectionError::TooManyConnectionAttempts(_) => {
"Failed to acquire permit to connect to the database. Too many database connection attempts are currently ongoing.".to_owned()
}
ConnectionError::TlsError(_) => COULD_NOT_CONNECT.to_owned(),
_ => COULD_NOT_CONNECT.to_owned(),
}
}
}
@@ -84,6 +85,7 @@ impl ReportableError for ConnectionError {
crate::error::ErrorKind::Postgres
}
ConnectionError::Postgres(_) => crate::error::ErrorKind::Compute,
ConnectionError::CouldNotConnect(_) => crate::error::ErrorKind::Compute,
ConnectionError::TlsError(_) => crate::error::ErrorKind::Compute,
ConnectionError::WakeComputeError(e) => e.get_error_kind(),
ConnectionError::TooManyConnectionAttempts(e) => e.get_error_kind(),
@@ -94,85 +96,34 @@ impl ReportableError for ConnectionError {
/// A pair of `ClientKey` & `ServerKey` for `SCRAM-SHA-256`.
pub(crate) type ScramKeys = postgres_client::config::ScramKeys<32>;
/// A config for establishing a connection to compute node.
/// Eventually, `postgres_client` will be replaced with something better.
/// Newtype allows us to implement methods on top of it.
#[derive(Clone)]
pub enum Auth {
/// Only used during console-redirect.
Password(Vec<u8>),
/// Used by sql-over-http, ws, tcp.
Scram(Box<ScramKeys>),
}
/// A config for authenticating to the compute node.
pub(crate) struct AuthInfo {
/// None for local-proxy, as we use trust-based localhost auth.
/// Some for sql-over-http, ws, tcp, and in most cases for console-redirect.
/// Might be None for console-redirect, but that's only a consequence of testing environments ATM.
auth: Option<Auth>,
server_params: StartupMessageParams,
/// Console redirect sets user and database, we shouldn't re-use those from the params.
skip_db_user: bool,
}
/// Contains only the data needed to establish a secure connection to compute.
#[derive(Clone)]
pub struct ConnectInfo {
pub host_addr: Option<IpAddr>,
pub host: Host,
pub port: u16,
pub ssl_mode: SslMode,
}
pub(crate) struct ConnCfg(Box<postgres_client::Config>);
/// Creation and initialization routines.
impl AuthInfo {
pub(crate) fn for_console_redirect(db: &str, user: &str, pw: Option<&str>) -> Self {
let mut server_params = StartupMessageParams::default();
server_params.insert("database", db);
server_params.insert("user", user);
Self {
auth: pw.map(|pw| Auth::Password(pw.as_bytes().to_owned())),
server_params,
skip_db_user: true,
impl ConnCfg {
pub(crate) fn new(host: String, port: u16) -> Self {
Self(Box::new(postgres_client::Config::new(host, port)))
}
/// Reuse password or auth keys from the other config.
pub(crate) fn reuse_password(&mut self, other: Self) {
if let Some(password) = other.get_password() {
self.password(password);
}
if let Some(keys) = other.get_auth_keys() {
self.auth_keys(keys);
}
}
pub(crate) fn with_auth_keys(keys: ComputeCredentialKeys) -> Self {
Self {
auth: match keys {
ComputeCredentialKeys::AuthKeys(AuthKeys::ScramSha256(auth_keys)) => {
Some(Auth::Scram(Box::new(auth_keys)))
}
ComputeCredentialKeys::JwtPayload(_) | ComputeCredentialKeys::None => None,
},
server_params: StartupMessageParams::default(),
skip_db_user: false,
pub(crate) fn get_host(&self) -> Host {
match self.0.get_host() {
postgres_client::config::Host::Tcp(s) => s.into(),
}
}
}
impl ConnectInfo {
pub fn to_postgres_client_config(&self) -> postgres_client::Config {
let mut config = postgres_client::Config::new(self.host.to_string(), self.port);
config.ssl_mode(self.ssl_mode);
if let Some(host_addr) = self.host_addr {
config.set_host_addr(host_addr);
}
config
}
}
impl AuthInfo {
fn enrich(&self, mut config: postgres_client::Config) -> postgres_client::Config {
match &self.auth {
Some(Auth::Scram(keys)) => config.auth_keys(AuthKeys::ScramSha256(**keys)),
Some(Auth::Password(pw)) => config.password(pw),
None => &mut config,
};
for (k, v) in self.server_params.iter() {
config.set_param(k, v);
}
config
}
/// Apply startup message params to the connection config.
pub(crate) fn set_startup_params(
@@ -181,26 +132,27 @@ impl AuthInfo {
arbitrary_params: bool,
) {
if !arbitrary_params {
self.server_params.insert("client_encoding", "UTF8");
self.set_param("client_encoding", "UTF8");
}
for (k, v) in params.iter() {
match k {
// Only set `user` if it's not present in the config.
// Console redirect auth flow takes username from the console's response.
"user" | "database" if self.skip_db_user => {}
"user" if self.user_is_set() => {}
"database" if self.db_is_set() => {}
"options" => {
if let Some(options) = filtered_options(v) {
self.server_params.insert(k, &options);
self.set_param(k, &options);
}
}
"user" | "database" | "application_name" | "replication" => {
self.server_params.insert(k, v);
self.set_param(k, v);
}
// if we allow arbitrary params, then we forward them through.
// this is a flag for a period of backwards compatibility
k if arbitrary_params => {
self.server_params.insert(k, v);
self.set_param(k, v);
}
_ => {}
}
@@ -208,13 +160,25 @@ impl AuthInfo {
}
}
impl ConnectInfo {
/// Establish a raw TCP+TLS connection to the compute node.
async fn connect_raw(
&self,
config: &ComputeConfig,
) -> Result<(SocketAddr, MaybeTlsStream<TcpStream, RustlsStream>), TlsError> {
let timeout = config.timeout;
impl std::ops::Deref for ConnCfg {
type Target = postgres_client::Config;
fn deref(&self) -> &Self::Target {
&self.0
}
}
/// For now, let's make it easier to setup the config.
impl std::ops::DerefMut for ConnCfg {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}
impl ConnCfg {
/// Establish a raw TCP connection to the compute node.
async fn connect_raw(&self, timeout: Duration) -> io::Result<(SocketAddr, TcpStream, &str)> {
use postgres_client::config::Host;
// wrap TcpStream::connect with timeout
let connect_with_timeout = |addrs| {
@@ -244,32 +208,34 @@ impl ConnectInfo {
// We can't reuse connection establishing logic from `postgres_client` here,
// because it has no means for extracting the underlying socket which we
// require for our business.
let port = self.port;
let host = &*self.host;
let port = self.0.get_port();
let host = self.0.get_host();
let addrs = match self.host_addr {
let host = match host {
Host::Tcp(host) => host.as_str(),
};
let addrs = match self.0.get_host_addr() {
Some(addr) => vec![SocketAddr::new(addr, port)],
None => lookup_host((host, port)).await?.collect(),
};
match connect_once(&*addrs).await {
Ok((sockaddr, stream)) => Ok((
sockaddr,
tls::connect_tls(stream, self.ssl_mode, config, host).await?,
)),
Ok((sockaddr, stream)) => Ok((sockaddr, stream, host)),
Err(err) => {
warn!("couldn't connect to compute node at {host}:{port}: {err}");
Err(TlsError::Connection(err))
Err(err)
}
}
}
}
type RustlsStream = <ComputeConfig as MakeTlsConnect<tokio::net::TcpStream>>::Stream;
type RustlsStream = <MakeRustlsConnect as MakeTlsConnect<tokio::net::TcpStream>>::Stream;
pub(crate) struct PostgresConnection {
/// Socket connected to a compute node.
pub(crate) stream: MaybeTlsStream<tokio::net::TcpStream, RustlsStream>,
pub(crate) stream:
postgres_client::maybe_tls_stream::MaybeTlsStream<tokio::net::TcpStream, RustlsStream>,
/// PostgreSQL connection parameters.
pub(crate) params: std::collections::HashMap<String, String>,
/// Query cancellation token.
@@ -282,23 +248,28 @@ pub(crate) struct PostgresConnection {
_guage: NumDbConnectionsGuard<'static>,
}
impl ConnectInfo {
impl ConnCfg {
/// Connect to a corresponding compute node.
pub(crate) async fn connect(
&self,
ctx: &RequestContext,
aux: MetricsAuxInfo,
auth: &AuthInfo,
config: &ComputeConfig,
user_info: ComputeUserInfo,
) -> Result<PostgresConnection, ConnectionError> {
let mut tmp_config = auth.enrich(self.to_postgres_client_config());
// we setup SSL early in `ConnectInfo::connect_raw`.
tmp_config.ssl_mode(SslMode::Disable);
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
let (socket_addr, stream) = self.connect_raw(config).await?;
let connection = tmp_config.connect_raw(stream, NoTls).await?;
let (socket_addr, stream, host) = self.connect_raw(config.timeout).await?;
drop(pause);
let mut mk_tls = crate::tls::postgres_rustls::MakeRustlsConnect::new(config.tls.clone());
let tls = <MakeRustlsConnect as MakeTlsConnect<tokio::net::TcpStream>>::make_tls_connect(
&mut mk_tls,
host,
)?;
// connect_raw() will not use TLS if sslmode is "disable"
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
let connection = self.0.connect_raw(stream, tls).await?;
drop(pause);
let RawConnection {
@@ -311,14 +282,13 @@ impl ConnectInfo {
tracing::Span::current().record("pid", tracing::field::display(process_id));
tracing::Span::current().record("compute_id", tracing::field::display(&aux.compute_id));
let MaybeTlsStream::Raw(stream) = stream.into_inner();
let stream = stream.into_inner();
// TODO: lots of useful info but maybe we can move it elsewhere (eg traces?)
info!(
cold_start_info = ctx.cold_start_info().as_str(),
"connected to compute node at {} ({socket_addr}) sslmode={:?}, latency={}, query_id={}",
self.host,
self.ssl_mode,
"connected to compute node at {host} ({socket_addr}) sslmode={:?}, latency={}, query_id={}",
self.0.get_ssl_mode(),
ctx.get_proxy_latency(),
ctx.get_testodrome_id().unwrap_or_default(),
);
@@ -329,11 +299,11 @@ impl ConnectInfo {
socket_addr,
CancelToken {
socket_config: None,
ssl_mode: self.ssl_mode,
ssl_mode: self.0.get_ssl_mode(),
process_id,
secret_key,
},
self.host.to_string(),
host.to_string(),
user_info,
);

View File

@@ -1,63 +0,0 @@
use futures::FutureExt;
use postgres_client::config::SslMode;
use postgres_client::maybe_tls_stream::MaybeTlsStream;
use postgres_client::tls::{MakeTlsConnect, TlsConnect};
use rustls::pki_types::InvalidDnsNameError;
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncWrite};
use crate::pqproto::request_tls;
use crate::proxy::retry::CouldRetry;
#[derive(Debug, Error)]
pub enum TlsError {
#[error(transparent)]
Dns(#[from] InvalidDnsNameError),
#[error(transparent)]
Connection(#[from] std::io::Error),
#[error("TLS required but not provided")]
Required,
}
impl CouldRetry for TlsError {
fn could_retry(&self) -> bool {
match self {
TlsError::Dns(_) => false,
TlsError::Connection(err) => err.could_retry(),
// perhaps compute didn't realise it supports TLS?
TlsError::Required => true,
}
}
}
pub async fn connect_tls<S, T>(
mut stream: S,
mode: SslMode,
tls: &T,
host: &str,
) -> Result<MaybeTlsStream<S, T::Stream>, TlsError>
where
S: AsyncRead + AsyncWrite + Unpin + Send,
T: MakeTlsConnect<
S,
Error = InvalidDnsNameError,
TlsConnect: TlsConnect<S, Error = std::io::Error, Future: Send>,
>,
{
match mode {
SslMode::Disable => return Ok(MaybeTlsStream::Raw(stream)),
SslMode::Prefer | SslMode::Require => {}
}
if !request_tls(&mut stream).await? {
if SslMode::Require == mode {
return Err(TlsError::Required);
}
return Ok(MaybeTlsStream::Raw(stream));
}
Ok(MaybeTlsStream::Tls(
tls.make_tls_connect(host)?.connect(stream).boxed().await?,
))
}

View File

@@ -210,20 +210,20 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
ctx.set_db_options(params.clone());
let (node_info, mut auth_info, user_info) = match backend
let (node_info, user_info, _ip_allowlist) = match backend
.authenticate(ctx, &config.authentication_config, &mut stream)
.await
{
Ok(auth_result) => auth_result,
Err(e) => Err(stream.throw_error(e, Some(ctx)).await)?,
};
auth_info.set_startup_params(&params, true);
let node = connect_to_compute(
ctx,
&TcpMechanism {
user_info,
auth: auth_info,
params_compat: true,
params: &params,
locks: &config.connect_compute_locks,
},
&node_info,

View File

@@ -261,18 +261,24 @@ impl NeonControlPlaneClient {
Some(_) => SslMode::Require,
None => SslMode::Disable,
};
let host = match body.server_name {
Some(host) => host.into(),
None => host.into(),
let host_name = match body.server_name {
Some(host) => host,
None => host.to_owned(),
};
// Don't set anything but host and port! This config will be cached.
// We'll set username and such later using the startup message.
// TODO: add more type safety (in progress).
let mut config = compute::ConnCfg::new(host_name, port);
if let Some(addr) = host_addr {
config.set_host_addr(addr);
}
config.ssl_mode(ssl_mode);
let node = NodeInfo {
conn_info: compute::ConnectInfo {
host_addr,
host,
port,
ssl_mode,
},
config,
aux: body.aux,
};

View File

@@ -6,7 +6,6 @@ use std::str::FromStr;
use std::sync::Arc;
use futures::TryFutureExt;
use postgres_client::config::SslMode;
use thiserror::Error;
use tokio_postgres::Client;
use tracing::{Instrument, error, info, info_span, warn};
@@ -15,7 +14,6 @@ use crate::auth::IpPattern;
use crate::auth::backend::ComputeUserInfo;
use crate::auth::backend::jwt::AuthRule;
use crate::cache::Cached;
use crate::compute::ConnectInfo;
use crate::context::RequestContext;
use crate::control_plane::errors::{
ControlPlaneError, GetAuthInfoError, GetEndpointJwksError, WakeComputeError,
@@ -26,9 +24,9 @@ use crate::control_plane::{
RoleAccessControl,
};
use crate::intern::RoleNameInt;
use crate::scram;
use crate::types::{BranchId, EndpointId, ProjectId, RoleName};
use crate::url::ApiUrl;
use crate::{compute, scram};
#[derive(Debug, Error)]
enum MockApiError {
@@ -89,7 +87,8 @@ impl MockControlPlane {
.await?
{
info!("got a secret: {entry}"); // safe since it's not a prod scenario
scram::ServerSecret::parse(&entry).map(AuthSecret::Scram)
let secret = scram::ServerSecret::parse(&entry).map(AuthSecret::Scram);
secret.or_else(|| parse_md5(&entry).map(AuthSecret::Md5))
} else {
warn!("user '{role}' does not exist");
None
@@ -171,23 +170,25 @@ impl MockControlPlane {
async fn do_wake_compute(&self) -> Result<NodeInfo, WakeComputeError> {
let port = self.endpoint.port().unwrap_or(5432);
let conn_info = match self.endpoint.host_str() {
None => ConnectInfo {
host_addr: Some(IpAddr::V4(Ipv4Addr::LOCALHOST)),
host: "localhost".into(),
port,
ssl_mode: SslMode::Disable,
},
Some(host) => ConnectInfo {
host_addr: IpAddr::from_str(host).ok(),
host: host.into(),
port,
ssl_mode: SslMode::Disable,
},
let mut config = match self.endpoint.host_str() {
None => {
let mut config = compute::ConnCfg::new("localhost".to_string(), port);
config.set_host_addr(IpAddr::V4(Ipv4Addr::LOCALHOST));
config
}
Some(host) => {
let mut config = compute::ConnCfg::new(host.to_string(), port);
if let Ok(addr) = IpAddr::from_str(host) {
config.set_host_addr(addr);
}
config
}
};
config.ssl_mode(postgres_client::config::SslMode::Disable);
let node = NodeInfo {
conn_info,
config,
aux: MetricsAuxInfo {
endpoint_id: (&EndpointId::from("endpoint")).into(),
project_id: (&ProjectId::from("project")).into(),
@@ -265,3 +266,12 @@ impl super::ControlPlaneApi for MockControlPlane {
self.do_wake_compute().map_ok(Cached::new_uncached).await
}
}
fn parse_md5(input: &str) -> Option<[u8; 16]> {
let text = input.strip_prefix("md5")?;
let mut bytes = [0u8; 16];
hex::decode_to_slice(text, &mut bytes).ok()?;
Some(bytes)
}

View File

@@ -11,8 +11,8 @@ pub(crate) mod errors;
use std::sync::Arc;
use crate::auth::backend::ComputeUserInfo;
use crate::auth::backend::jwt::AuthRule;
use crate::auth::backend::{ComputeCredentialKeys, ComputeUserInfo};
use crate::auth::{AuthError, IpPattern, check_peer_addr_is_in_list};
use crate::cache::{Cached, TimedLru};
use crate::config::ComputeConfig;
@@ -39,6 +39,10 @@ pub mod mgmt;
/// Auth secret which is managed by the cloud.
#[derive(Clone, Eq, PartialEq, Debug)]
pub(crate) enum AuthSecret {
#[cfg(any(test, feature = "testing"))]
/// Md5 hash of user's password.
Md5([u8; 16]),
/// [SCRAM](crate::scram) authentication info.
Scram(scram::ServerSecret),
}
@@ -59,9 +63,13 @@ pub(crate) struct AuthInfo {
}
/// Info for establishing a connection to a compute node.
/// This is what we get after auth succeeded, but not before!
#[derive(Clone)]
pub(crate) struct NodeInfo {
pub(crate) conn_info: compute::ConnectInfo,
/// Compute node connection params.
/// It's sad that we have to clone this, but this will improve
/// once we migrate to a bespoke connection logic.
pub(crate) config: compute::ConnCfg,
/// Labels for proxy's metrics.
pub(crate) aux: MetricsAuxInfo,
@@ -71,14 +79,26 @@ impl NodeInfo {
pub(crate) async fn connect(
&self,
ctx: &RequestContext,
auth: &compute::AuthInfo,
config: &ComputeConfig,
user_info: ComputeUserInfo,
) -> Result<compute::PostgresConnection, compute::ConnectionError> {
self.conn_info
.connect(ctx, self.aux.clone(), auth, config, user_info)
self.config
.connect(ctx, self.aux.clone(), config, user_info)
.await
}
pub(crate) fn reuse_settings(&mut self, other: Self) {
self.config.reuse_password(other.config);
}
pub(crate) fn set_keys(&mut self, keys: &ComputeCredentialKeys) {
match keys {
#[cfg(any(test, feature = "testing"))]
ComputeCredentialKeys::Password(password) => self.config.password(password),
ComputeCredentialKeys::AuthKeys(auth_keys) => self.config.auth_keys(*auth_keys),
ComputeCredentialKeys::JwtPayload(_) | ComputeCredentialKeys::None => &mut self.config,
};
}
}
#[derive(Copy, Clone, Default)]

View File

@@ -610,11 +610,11 @@ pub enum RedisEventsCount {
BranchCreated,
ProjectCreated,
CancelSession,
InvalidateRole,
InvalidateEndpoint,
InvalidateProject,
InvalidateProjects,
InvalidateOrg,
PasswordUpdate,
AllowedIpsUpdate,
AllowedVpcEndpointIdsUpdateForProjects,
AllowedVpcEndpointIdsUpdateForAllProjectsInOrg,
BlockPublicOrVpcAccessUpdate,
}
pub struct ThreadPoolWorkers(usize);

View File

@@ -2,8 +2,8 @@ use async_trait::async_trait;
use tokio::time;
use tracing::{debug, info, warn};
use crate::auth::backend::ComputeUserInfo;
use crate::compute::{self, AuthInfo, COULD_NOT_CONNECT, PostgresConnection};
use crate::auth::backend::{ComputeCredentialKeys, ComputeUserInfo};
use crate::compute::{self, COULD_NOT_CONNECT, PostgresConnection};
use crate::config::{ComputeConfig, RetryConfig};
use crate::context::RequestContext;
use crate::control_plane::errors::WakeComputeError;
@@ -13,6 +13,7 @@ use crate::error::ReportableError;
use crate::metrics::{
ConnectOutcome, ConnectionFailureKind, Metrics, RetriesMetricGroup, RetryType,
};
use crate::pqproto::StartupMessageParams;
use crate::proxy::retry::{CouldRetry, ShouldRetryWakeCompute, retry_after, should_retry};
use crate::proxy::wake_compute::wake_compute;
use crate::types::Host;
@@ -47,25 +48,34 @@ pub(crate) trait ConnectMechanism {
node_info: &control_plane::CachedNodeInfo,
config: &ComputeConfig,
) -> Result<Self::Connection, Self::ConnectError>;
fn update_connect_config(&self, conf: &mut compute::ConnCfg);
}
#[async_trait]
pub(crate) trait WakeComputeBackend {
pub(crate) trait ComputeConnectBackend {
async fn wake_compute(
&self,
ctx: &RequestContext,
) -> Result<CachedNodeInfo, control_plane::errors::WakeComputeError>;
fn get_keys(&self) -> &ComputeCredentialKeys;
}
pub(crate) struct TcpMechanism {
pub(crate) auth: AuthInfo,
pub(crate) struct TcpMechanism<'a> {
pub(crate) params_compat: bool,
/// KV-dictionary with PostgreSQL connection params.
pub(crate) params: &'a StartupMessageParams,
/// connect_to_compute concurrency lock
pub(crate) locks: &'static ApiLocks<Host>,
pub(crate) user_info: ComputeUserInfo,
}
#[async_trait]
impl ConnectMechanism for TcpMechanism {
impl ConnectMechanism for TcpMechanism<'_> {
type Connection = PostgresConnection;
type ConnectError = compute::ConnectionError;
type Error = compute::ConnectionError;
@@ -80,18 +90,19 @@ impl ConnectMechanism for TcpMechanism {
node_info: &control_plane::CachedNodeInfo,
config: &ComputeConfig,
) -> Result<PostgresConnection, Self::Error> {
let permit = self.locks.get_permit(&node_info.conn_info.host).await?;
permit.release_result(
node_info
.connect(ctx, &self.auth, config, self.user_info.clone())
.await,
)
let host = node_info.config.get_host();
let permit = self.locks.get_permit(&host).await?;
permit.release_result(node_info.connect(ctx, config, self.user_info.clone()).await)
}
fn update_connect_config(&self, config: &mut compute::ConnCfg) {
config.set_startup_params(self.params, self.params_compat);
}
}
/// Try to connect to the compute node, retrying if necessary.
#[tracing::instrument(skip_all)]
pub(crate) async fn connect_to_compute<M: ConnectMechanism, B: WakeComputeBackend>(
pub(crate) async fn connect_to_compute<M: ConnectMechanism, B: ComputeConnectBackend>(
ctx: &RequestContext,
mechanism: &M,
user_info: &B,
@@ -103,9 +114,12 @@ where
M::Error: From<WakeComputeError>,
{
let mut num_retries = 0;
let node_info =
let mut node_info =
wake_compute(&mut num_retries, ctx, user_info, wake_compute_retry_config).await?;
node_info.set_keys(user_info.get_keys());
mechanism.update_connect_config(&mut node_info.config);
// try once
let err = match mechanism.connect_once(ctx, &node_info, compute).await {
Ok(res) => {
@@ -141,9 +155,14 @@ where
} else {
// if we failed to connect, it's likely that the compute node was suspended, wake a new compute node
debug!("compute node's state has likely changed; requesting a wake-up");
invalidate_cache(node_info);
let old_node_info = invalidate_cache(node_info);
// TODO: increment num_retries?
wake_compute(&mut num_retries, ctx, user_info, wake_compute_retry_config).await?
let mut node_info =
wake_compute(&mut num_retries, ctx, user_info, wake_compute_retry_config).await?;
node_info.reuse_settings(old_node_info);
mechanism.update_connect_config(&mut node_info.config);
node_info
};
// now that we have a new node, try connect to it repeatedly.

View File

@@ -8,7 +8,7 @@ use std::io::{self, Cursor};
use bytes::{Buf, BufMut};
use itertools::Itertools;
use rand::distributions::{Distribution, Standard};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio::io::{AsyncRead, AsyncReadExt};
use zerocopy::{FromBytes, Immutable, IntoBytes, big_endian};
pub type ErrorCode = [u8; 5];
@@ -53,28 +53,6 @@ impl fmt::Debug for ProtocolVersion {
}
}
/// <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/include/libpq/pqcomm.h#L118>
const MAX_STARTUP_PACKET_LENGTH: usize = 10000;
const RESERVED_INVALID_MAJOR_VERSION: u16 = 1234;
/// <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/include/libpq/pqcomm.h#L132>
const CANCEL_REQUEST_CODE: ProtocolVersion = ProtocolVersion::new(1234, 5678);
/// <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/include/libpq/pqcomm.h#L166>
const NEGOTIATE_SSL_CODE: ProtocolVersion = ProtocolVersion::new(1234, 5679);
/// <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/include/libpq/pqcomm.h#L167>
const NEGOTIATE_GSS_CODE: ProtocolVersion = ProtocolVersion::new(1234, 5680);
/// This first reads the startup message header, is 8 bytes.
/// The first 4 bytes is a big-endian message length, and the next 4 bytes is a version number.
///
/// The length value is inclusive of the header. For example,
/// an empty message will always have length 8.
#[derive(Clone, Copy, FromBytes, IntoBytes, Immutable)]
#[repr(C)]
struct StartupHeader {
len: big_endian::U32,
version: ProtocolVersion,
}
/// read the type from the stream using zerocopy.
///
/// not cancel safe.
@@ -88,38 +66,32 @@ macro_rules! read {
}};
}
/// Returns true if TLS is supported.
///
/// This is not cancel safe.
pub async fn request_tls<S>(stream: &mut S) -> io::Result<bool>
where
S: AsyncRead + AsyncWrite + Unpin,
{
let payload = StartupHeader {
len: 8.into(),
version: NEGOTIATE_SSL_CODE,
};
stream.write_all(payload.as_bytes()).await?;
stream.flush().await?;
// we expect back either `S` or `N` as a single byte.
let mut res = *b"0";
stream.read_exact(&mut res).await?;
debug_assert!(
res == *b"S" || res == *b"N",
"unexpected SSL negotiation response: {}",
char::from(res[0]),
);
// S for SSL.
Ok(res == *b"S")
}
pub async fn read_startup<S>(stream: &mut S) -> io::Result<FeStartupPacket>
where
S: AsyncRead + Unpin,
{
/// <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/include/libpq/pqcomm.h#L118>
const MAX_STARTUP_PACKET_LENGTH: usize = 10000;
const RESERVED_INVALID_MAJOR_VERSION: u16 = 1234;
/// <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/include/libpq/pqcomm.h#L132>
const CANCEL_REQUEST_CODE: ProtocolVersion = ProtocolVersion::new(1234, 5678);
/// <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/include/libpq/pqcomm.h#L166>
const NEGOTIATE_SSL_CODE: ProtocolVersion = ProtocolVersion::new(1234, 5679);
/// <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/include/libpq/pqcomm.h#L167>
const NEGOTIATE_GSS_CODE: ProtocolVersion = ProtocolVersion::new(1234, 5680);
/// This first reads the startup message header, is 8 bytes.
/// The first 4 bytes is a big-endian message length, and the next 4 bytes is a version number.
///
/// The length value is inclusive of the header. For example,
/// an empty message will always have length 8.
#[derive(Clone, Copy, FromBytes, IntoBytes, Immutable)]
#[repr(C)]
struct StartupHeader {
len: big_endian::U32,
version: ProtocolVersion,
}
let header = read!(stream => StartupHeader);
// <https://github.com/postgres/postgres/blob/04bcf9e19a4261fe9c7df37c777592c2e10c32a7/src/backend/tcop/backend_startup.c#L378-L382>
@@ -592,9 +564,10 @@ mod tests {
use tokio::io::{AsyncWriteExt, duplex};
use zerocopy::IntoBytes;
use super::ProtocolVersion;
use crate::pqproto::{FeStartupPacket, read_message, read_startup};
use super::ProtocolVersion;
#[tokio::test]
async fn reject_large_startup() {
// we're going to define a v3.0 startup message with far too many parameters.

View File

@@ -358,22 +358,24 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
}
};
let (cplane, creds) = match user_info {
auth::Backend::ControlPlane(cplane, creds) => (cplane, creds),
let compute_user_info = match &user_info {
auth::Backend::ControlPlane(_, info) => &info.info,
auth::Backend::Local(_) => unreachable!("local proxy does not run tcp proxy service"),
};
let params_compat = creds.info.options.get(NeonOptions::PARAMS_COMPAT).is_some();
let mut auth_info = compute::AuthInfo::with_auth_keys(creds.keys);
auth_info.set_startup_params(&params, params_compat);
let params_compat = compute_user_info
.options
.get(NeonOptions::PARAMS_COMPAT)
.is_some();
let res = connect_to_compute(
ctx,
&TcpMechanism {
user_info: creds.info.clone(),
auth: auth_info,
user_info: compute_user_info.clone(),
params_compat,
params: &params,
locks: &config.connect_compute_locks,
},
&auth::Backend::ControlPlane(cplane, creds.info),
&user_info,
config.wake_compute_retry_config,
&config.connect_to_compute,
)

View File

@@ -100,9 +100,9 @@ impl CouldRetry for compute::ConnectionError {
fn could_retry(&self) -> bool {
match self {
compute::ConnectionError::Postgres(err) => err.could_retry(),
compute::ConnectionError::TlsError(err) => err.could_retry(),
compute::ConnectionError::CouldNotConnect(err) => err.could_retry(),
compute::ConnectionError::WakeComputeError(err) => err.could_retry(),
compute::ConnectionError::TooManyConnectionAttempts(_) => false,
_ => false,
}
}
}

View File

@@ -19,7 +19,9 @@ use tracing_test::traced_test;
use super::retry::CouldRetry;
use super::*;
use crate::auth::backend::{ComputeUserInfo, MaybeOwned};
use crate::auth::backend::{
ComputeCredentialKeys, ComputeCredentials, ComputeUserInfo, MaybeOwned,
};
use crate::config::{ComputeConfig, RetryConfig};
use crate::control_plane::client::{ControlPlaneClient, TestControlPlaneClient};
use crate::control_plane::messages::{ControlPlaneErrorMessage, Details, MetricsAuxInfo, Status};
@@ -27,6 +29,7 @@ use crate::control_plane::{self, CachedNodeInfo, NodeInfo, NodeInfoCache};
use crate::error::ErrorKind;
use crate::pglb::connect_compute::ConnectMechanism;
use crate::tls::client_config::compute_client_config_with_certs;
use crate::tls::postgres_rustls::MakeRustlsConnect;
use crate::tls::server_config::CertResolver;
use crate::types::{BranchId, EndpointId, ProjectId};
use crate::{sasl, scram};
@@ -69,14 +72,13 @@ struct ClientConfig<'a> {
hostname: &'a str,
}
type TlsConnect<S> = <ComputeConfig as MakeTlsConnect<S>>::TlsConnect;
type TlsConnect<S> = <MakeRustlsConnect as MakeTlsConnect<S>>::TlsConnect;
impl ClientConfig<'_> {
fn make_tls_connect(self) -> anyhow::Result<TlsConnect<DuplexStream>> {
Ok(crate::tls::postgres_rustls::make_tls_connect(
&self.config,
self.hostname,
)?)
let mut mk = MakeRustlsConnect::new(self.config);
let tls = MakeTlsConnect::<DuplexStream>::make_tls_connect(&mut mk, self.hostname)?;
Ok(tls)
}
}
@@ -495,6 +497,8 @@ impl ConnectMechanism for TestConnectMechanism {
x => panic!("expecting action {x:?}, connect is called instead"),
}
}
fn update_connect_config(&self, _conf: &mut compute::ConnCfg) {}
}
impl TestControlPlaneClient for TestConnectMechanism {
@@ -553,12 +557,7 @@ impl TestControlPlaneClient for TestConnectMechanism {
fn helper_create_cached_node_info(cache: &'static NodeInfoCache) -> CachedNodeInfo {
let node = NodeInfo {
conn_info: compute::ConnectInfo {
host: "test".into(),
port: 5432,
ssl_mode: SslMode::Disable,
host_addr: None,
},
config: compute::ConnCfg::new("test".to_owned(), 5432),
aux: MetricsAuxInfo {
endpoint_id: (&EndpointId::from("endpoint")).into(),
project_id: (&ProjectId::from("project")).into(),
@@ -573,13 +572,16 @@ fn helper_create_cached_node_info(cache: &'static NodeInfoCache) -> CachedNodeIn
fn helper_create_connect_info(
mechanism: &TestConnectMechanism,
) -> auth::Backend<'static, ComputeUserInfo> {
) -> auth::Backend<'static, ComputeCredentials> {
auth::Backend::ControlPlane(
MaybeOwned::Owned(ControlPlaneClient::Test(Box::new(mechanism.clone()))),
ComputeUserInfo {
endpoint: "endpoint".into(),
user: "user".into(),
options: NeonOptions::parse_options_raw(""),
ComputeCredentials {
info: ComputeUserInfo {
endpoint: "endpoint".into(),
user: "user".into(),
options: NeonOptions::parse_options_raw(""),
},
keys: ComputeCredentialKeys::Password("password".into()),
},
)
}

View File

@@ -8,7 +8,7 @@ use crate::error::ReportableError;
use crate::metrics::{
ConnectOutcome, ConnectionFailuresBreakdownGroup, Metrics, RetriesMetricGroup, RetryType,
};
use crate::pglb::connect_compute::WakeComputeBackend;
use crate::pglb::connect_compute::ComputeConnectBackend;
use crate::proxy::retry::{retry_after, should_retry};
// Use macro to retain original callsite.
@@ -23,7 +23,7 @@ macro_rules! log_wake_compute_error {
};
}
pub(crate) async fn wake_compute<B: WakeComputeBackend>(
pub(crate) async fn wake_compute<B: ComputeConnectBackend>(
num_retries: &mut u32,
ctx: &RequestContext,
api: &B,

View File

@@ -3,12 +3,12 @@ use std::sync::Arc;
use futures::StreamExt;
use redis::aio::PubSub;
use serde::Deserialize;
use serde::{Deserialize, Serialize};
use tokio_util::sync::CancellationToken;
use super::connection_with_credentials_provider::ConnectionWithCredentialsProvider;
use crate::cache::project_info::ProjectInfoCache;
use crate::intern::{AccountIdInt, EndpointIdInt, ProjectIdInt, RoleNameInt};
use crate::intern::{AccountIdInt, ProjectIdInt, RoleNameInt};
use crate::metrics::{Metrics, RedisErrors, RedisEventsCount};
const CPLANE_CHANNEL_NAME: &str = "neondb-proxy-ws-updates";
@@ -27,37 +27,42 @@ struct NotificationHeader<'a> {
topic: &'a str,
}
#[derive(Clone, Debug, Deserialize, Eq, PartialEq)]
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
#[serde(tag = "topic", content = "data")]
enum Notification {
pub(crate) enum Notification {
#[serde(
rename = "/account_settings_update",
alias = "/allowed_vpc_endpoints_updated_for_org",
rename = "/allowed_ips_updated",
deserialize_with = "deserialize_json_string"
)]
AccountSettingsUpdate(InvalidateAccount),
AllowedIpsUpdate {
allowed_ips_update: AllowedIpsUpdate,
},
#[serde(
rename = "/endpoint_settings_update",
rename = "/block_public_or_vpc_access_updated",
deserialize_with = "deserialize_json_string"
)]
EndpointSettingsUpdate(InvalidateEndpoint),
BlockPublicOrVpcAccessUpdated {
block_public_or_vpc_access_updated: BlockPublicOrVpcAccessUpdated,
},
#[serde(
rename = "/project_settings_update",
alias = "/allowed_ips_updated",
alias = "/block_public_or_vpc_access_updated",
alias = "/allowed_vpc_endpoints_updated_for_projects",
rename = "/allowed_vpc_endpoints_updated_for_org",
deserialize_with = "deserialize_json_string"
)]
ProjectSettingsUpdate(InvalidateProject),
AllowedVpcEndpointsUpdatedForOrg {
allowed_vpc_endpoints_updated_for_org: AllowedVpcEndpointsUpdatedForOrg,
},
#[serde(
rename = "/role_setting_update",
alias = "/password_updated",
rename = "/allowed_vpc_endpoints_updated_for_projects",
deserialize_with = "deserialize_json_string"
)]
RoleSettingUpdate(InvalidateRole),
AllowedVpcEndpointsUpdatedForProjects {
allowed_vpc_endpoints_updated_for_projects: AllowedVpcEndpointsUpdatedForProjects,
},
#[serde(
rename = "/password_updated",
deserialize_with = "deserialize_json_string"
)]
PasswordUpdate { password_update: PasswordUpdate },
#[serde(
other,
@@ -67,56 +72,28 @@ enum Notification {
UnknownTopic,
}
#[derive(Clone, Debug, Deserialize, Eq, PartialEq)]
#[serde(rename_all = "snake_case")]
enum InvalidateEndpoint {
EndpointId(EndpointIdInt),
EndpointIds(Vec<EndpointIdInt>),
}
impl std::ops::Deref for InvalidateEndpoint {
type Target = [EndpointIdInt];
fn deref(&self) -> &Self::Target {
match self {
Self::EndpointId(id) => std::slice::from_ref(id),
Self::EndpointIds(ids) => ids,
}
}
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
pub(crate) struct AllowedIpsUpdate {
project_id: ProjectIdInt,
}
#[derive(Clone, Debug, Deserialize, Eq, PartialEq)]
#[serde(rename_all = "snake_case")]
enum InvalidateProject {
ProjectId(ProjectIdInt),
ProjectIds(Vec<ProjectIdInt>),
}
impl std::ops::Deref for InvalidateProject {
type Target = [ProjectIdInt];
fn deref(&self) -> &Self::Target {
match self {
Self::ProjectId(id) => std::slice::from_ref(id),
Self::ProjectIds(ids) => ids,
}
}
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
pub(crate) struct BlockPublicOrVpcAccessUpdated {
project_id: ProjectIdInt,
}
#[derive(Clone, Debug, Deserialize, Eq, PartialEq)]
#[serde(rename_all = "snake_case")]
enum InvalidateAccount {
AccountId(AccountIdInt),
AccountIds(Vec<AccountIdInt>),
}
impl std::ops::Deref for InvalidateAccount {
type Target = [AccountIdInt];
fn deref(&self) -> &Self::Target {
match self {
Self::AccountId(id) => std::slice::from_ref(id),
Self::AccountIds(ids) => ids,
}
}
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
pub(crate) struct AllowedVpcEndpointsUpdatedForOrg {
account_id: AccountIdInt,
}
#[derive(Clone, Debug, Deserialize, Eq, PartialEq)]
struct InvalidateRole {
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
pub(crate) struct AllowedVpcEndpointsUpdatedForProjects {
project_ids: Vec<ProjectIdInt>,
}
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
pub(crate) struct PasswordUpdate {
project_id: ProjectIdInt,
role_name: RoleNameInt,
}
@@ -200,29 +177,41 @@ impl<C: ProjectInfoCache + Send + Sync + 'static> MessageHandler<C> {
tracing::debug!(?msg, "received a message");
match msg {
Notification::RoleSettingUpdate { .. }
| Notification::EndpointSettingsUpdate { .. }
| Notification::ProjectSettingsUpdate { .. }
| Notification::AccountSettingsUpdate { .. } => {
Notification::AllowedIpsUpdate { .. }
| Notification::PasswordUpdate { .. }
| Notification::BlockPublicOrVpcAccessUpdated { .. }
| Notification::AllowedVpcEndpointsUpdatedForOrg { .. }
| Notification::AllowedVpcEndpointsUpdatedForProjects { .. } => {
invalidate_cache(self.cache.clone(), msg.clone());
let m = &Metrics::get().proxy.redis_events_count;
match msg {
Notification::RoleSettingUpdate { .. } => {
m.inc(RedisEventsCount::InvalidateRole);
}
Notification::EndpointSettingsUpdate { .. } => {
m.inc(RedisEventsCount::InvalidateEndpoint);
}
Notification::ProjectSettingsUpdate { .. } => {
m.inc(RedisEventsCount::InvalidateProject);
}
Notification::AccountSettingsUpdate { .. } => {
m.inc(RedisEventsCount::InvalidateOrg);
}
Notification::UnknownTopic => {}
if matches!(msg, Notification::AllowedIpsUpdate { .. }) {
Metrics::get()
.proxy
.redis_events_count
.inc(RedisEventsCount::AllowedIpsUpdate);
} else if matches!(msg, Notification::PasswordUpdate { .. }) {
Metrics::get()
.proxy
.redis_events_count
.inc(RedisEventsCount::PasswordUpdate);
} else if matches!(
msg,
Notification::AllowedVpcEndpointsUpdatedForProjects { .. }
) {
Metrics::get()
.proxy
.redis_events_count
.inc(RedisEventsCount::AllowedVpcEndpointIdsUpdateForProjects);
} else if matches!(msg, Notification::AllowedVpcEndpointsUpdatedForOrg { .. }) {
Metrics::get()
.proxy
.redis_events_count
.inc(RedisEventsCount::AllowedVpcEndpointIdsUpdateForAllProjectsInOrg);
} else if matches!(msg, Notification::BlockPublicOrVpcAccessUpdated { .. }) {
Metrics::get()
.proxy
.redis_events_count
.inc(RedisEventsCount::BlockPublicOrVpcAccessUpdate);
}
// TODO: add additional metrics for the other event types.
// It might happen that the invalid entry is on the way to be cached.
@@ -244,23 +233,30 @@ impl<C: ProjectInfoCache + Send + Sync + 'static> MessageHandler<C> {
fn invalidate_cache<C: ProjectInfoCache>(cache: Arc<C>, msg: Notification) {
match msg {
Notification::EndpointSettingsUpdate(ids) => ids
.iter()
.for_each(|&id| cache.invalidate_endpoint_access(id)),
Notification::AccountSettingsUpdate(ids) => ids
.iter()
.for_each(|&id| cache.invalidate_endpoint_access_for_org(id)),
Notification::ProjectSettingsUpdate(ids) => ids
.iter()
.for_each(|&id| cache.invalidate_endpoint_access_for_project(id)),
Notification::RoleSettingUpdate(InvalidateRole {
project_id,
role_name,
}) => cache.invalidate_role_secret_for_project(project_id, role_name),
Notification::AllowedIpsUpdate {
allowed_ips_update: AllowedIpsUpdate { project_id },
}
| Notification::BlockPublicOrVpcAccessUpdated {
block_public_or_vpc_access_updated: BlockPublicOrVpcAccessUpdated { project_id },
} => cache.invalidate_endpoint_access_for_project(project_id),
Notification::AllowedVpcEndpointsUpdatedForOrg {
allowed_vpc_endpoints_updated_for_org: AllowedVpcEndpointsUpdatedForOrg { account_id },
} => cache.invalidate_endpoint_access_for_org(account_id),
Notification::AllowedVpcEndpointsUpdatedForProjects {
allowed_vpc_endpoints_updated_for_projects:
AllowedVpcEndpointsUpdatedForProjects { project_ids },
} => {
for project in project_ids {
cache.invalidate_endpoint_access_for_project(project);
}
}
Notification::PasswordUpdate {
password_update:
PasswordUpdate {
project_id,
role_name,
},
} => cache.invalidate_role_secret_for_project(project_id, role_name),
Notification::UnknownTopic => unreachable!(),
}
}
@@ -357,32 +353,11 @@ mod tests {
let result: Notification = serde_json::from_str(&text)?;
assert_eq!(
result,
Notification::ProjectSettingsUpdate(InvalidateProject::ProjectId((&project_id).into()))
);
Ok(())
}
#[test]
fn parse_multiple_projects() -> anyhow::Result<()> {
let project_id1: ProjectId = "new_project1".into();
let project_id2: ProjectId = "new_project2".into();
let data = format!("{{\"project_ids\": [\"{project_id1}\",\"{project_id2}\"]}}");
let text = json!({
"type": "message",
"topic": "/allowed_vpc_endpoints_updated_for_projects",
"data": data,
"extre_fields": "something"
})
.to_string();
let result: Notification = serde_json::from_str(&text)?;
assert_eq!(
result,
Notification::ProjectSettingsUpdate(InvalidateProject::ProjectIds(vec![
(&project_id1).into(),
(&project_id2).into()
]))
Notification::AllowedIpsUpdate {
allowed_ips_update: AllowedIpsUpdate {
project_id: (&project_id).into()
}
}
);
Ok(())
@@ -404,10 +379,12 @@ mod tests {
let result: Notification = serde_json::from_str(&text)?;
assert_eq!(
result,
Notification::RoleSettingUpdate(InvalidateRole {
project_id: (&project_id).into(),
role_name: (&role_name).into(),
})
Notification::PasswordUpdate {
password_update: PasswordUpdate {
project_id: (&project_id).into(),
role_name: (&role_name).into(),
}
}
);
Ok(())

View File

@@ -21,8 +21,9 @@ use super::conn_pool_lib::{Client, ConnInfo, EndpointConnPool, GlobalConnPool};
use super::http_conn_pool::{self, HttpConnPool, Send, poll_http2_client};
use super::local_conn_pool::{self, EXT_NAME, EXT_SCHEMA, EXT_VERSION, LocalConnPool};
use crate::auth::backend::local::StaticAuthRules;
use crate::auth::backend::{ComputeCredentialKeys, ComputeCredentials, ComputeUserInfo};
use crate::auth::backend::{ComputeCredentials, ComputeUserInfo};
use crate::auth::{self, AuthError};
use crate::compute;
use crate::compute_ctl::{
ComputeCtlError, ExtensionInstallRequest, Privilege, SetRoleGrantsRequest,
};
@@ -180,7 +181,7 @@ impl PoolingBackend {
let conn_id = uuid::Uuid::new_v4();
tracing::Span::current().record("conn_id", display(conn_id));
info!(%conn_id, "pool: opening a new connection '{conn_info}'");
let backend = self.auth_backend.as_ref().map(|()| keys.info);
let backend = self.auth_backend.as_ref().map(|()| keys);
crate::pglb::connect_compute::connect_to_compute(
ctx,
&TokioMechanism {
@@ -188,7 +189,6 @@ impl PoolingBackend {
conn_info,
pool: self.pool.clone(),
locks: &self.config.connect_compute_locks,
keys: keys.keys,
},
&backend,
self.config.wake_compute_retry_config,
@@ -215,13 +215,16 @@ impl PoolingBackend {
let conn_id = uuid::Uuid::new_v4();
tracing::Span::current().record("conn_id", display(conn_id));
debug!(%conn_id, "pool: opening a new connection '{conn_info}'");
let backend = self.auth_backend.as_ref().map(|()| ComputeUserInfo {
user: conn_info.user_info.user.clone(),
endpoint: EndpointId::from(format!(
"{}{LOCAL_PROXY_SUFFIX}",
conn_info.user_info.endpoint.normalize()
)),
options: conn_info.user_info.options.clone(),
let backend = self.auth_backend.as_ref().map(|()| ComputeCredentials {
info: ComputeUserInfo {
user: conn_info.user_info.user.clone(),
endpoint: EndpointId::from(format!(
"{}{LOCAL_PROXY_SUFFIX}",
conn_info.user_info.endpoint.normalize()
)),
options: conn_info.user_info.options.clone(),
},
keys: crate::auth::backend::ComputeCredentialKeys::None,
});
crate::pglb::connect_compute::connect_to_compute(
ctx,
@@ -302,13 +305,12 @@ impl PoolingBackend {
tracing::Span::current().record("conn_id", display(conn_id));
info!(%conn_id, "local_pool: opening a new connection '{conn_info}'");
let mut node_info = local_backend.node_info.clone();
let (key, jwk) = create_random_jwk();
let mut config = local_backend
.node_info
.conn_info
.to_postgres_client_config();
config
let config = node_info
.config
.user(&conn_info.user_info.user)
.dbname(&conn_info.dbname)
.set_param(
@@ -320,7 +322,7 @@ impl PoolingBackend {
);
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
let (client, connection) = config.connect(&postgres_client::NoTls).await?;
let (client, connection) = config.connect(postgres_client::NoTls).await?;
drop(pause);
let pid = client.get_process_id();
@@ -334,7 +336,7 @@ impl PoolingBackend {
connection,
key,
conn_id,
local_backend.node_info.aux.clone(),
node_info.aux.clone(),
);
{
@@ -493,7 +495,6 @@ struct TokioMechanism {
pool: Arc<GlobalConnPool<postgres_client::Client, EndpointConnPool<postgres_client::Client>>>,
conn_info: ConnInfo,
conn_id: uuid::Uuid,
keys: ComputeCredentialKeys,
/// connect_to_compute concurrency lock
locks: &'static ApiLocks<Host>,
@@ -511,20 +512,19 @@ impl ConnectMechanism for TokioMechanism {
node_info: &CachedNodeInfo,
compute_config: &ComputeConfig,
) -> Result<Self::Connection, Self::ConnectError> {
let permit = self.locks.get_permit(&node_info.conn_info.host).await?;
let host = node_info.config.get_host();
let permit = self.locks.get_permit(&host).await?;
let mut config = node_info.conn_info.to_postgres_client_config();
let mut config = (*node_info.config).clone();
let config = config
.user(&self.conn_info.user_info.user)
.dbname(&self.conn_info.dbname)
.connect_timeout(compute_config.timeout);
if let ComputeCredentialKeys::AuthKeys(auth_keys) = self.keys {
config.auth_keys(auth_keys);
}
let mk_tls =
crate::tls::postgres_rustls::MakeRustlsConnect::new(compute_config.tls.clone());
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
let res = config.connect(compute_config).await;
let res = config.connect(mk_tls).await;
drop(pause);
let (client, connection) = permit.release_result(res)?;
@@ -548,6 +548,8 @@ impl ConnectMechanism for TokioMechanism {
node_info.aux.clone(),
))
}
fn update_connect_config(&self, _config: &mut compute::ConnCfg) {}
}
struct HyperMechanism {
@@ -571,20 +573,20 @@ impl ConnectMechanism for HyperMechanism {
node_info: &CachedNodeInfo,
config: &ComputeConfig,
) -> Result<Self::Connection, Self::ConnectError> {
let host_addr = node_info.conn_info.host_addr;
let host = &node_info.conn_info.host;
let permit = self.locks.get_permit(host).await?;
let host_addr = node_info.config.get_host_addr();
let host = node_info.config.get_host();
let permit = self.locks.get_permit(&host).await?;
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
let tls = if node_info.conn_info.ssl_mode == SslMode::Disable {
let tls = if node_info.config.get_ssl_mode() == SslMode::Disable {
None
} else {
Some(&config.tls)
};
let port = node_info.conn_info.port;
let res = connect_http2(host_addr, host, port, config.timeout, tls).await;
let port = node_info.config.get_port();
let res = connect_http2(host_addr, &host, port, config.timeout, tls).await;
drop(pause);
let (client, connection) = permit.release_result(res)?;
@@ -607,6 +609,8 @@ impl ConnectMechanism for HyperMechanism {
node_info.aux.clone(),
))
}
fn update_connect_config(&self, _config: &mut compute::ConnCfg) {}
}
async fn connect_http2(

View File

@@ -23,12 +23,12 @@ use super::conn_pool_lib::{
Client, ClientDataEnum, ClientInnerCommon, ClientInnerExt, ConnInfo, EndpointConnPool,
GlobalConnPool,
};
use crate::config::ComputeConfig;
use crate::context::RequestContext;
use crate::control_plane::messages::MetricsAuxInfo;
use crate::metrics::Metrics;
use crate::tls::postgres_rustls::MakeRustlsConnect;
type TlsStream = <ComputeConfig as MakeTlsConnect<TcpStream>>::Stream;
type TlsStream = <MakeRustlsConnect as MakeTlsConnect<TcpStream>>::Stream;
#[derive(Debug, Clone)]
pub(crate) struct ConnInfoWithAuth {

View File

@@ -2,11 +2,10 @@ use std::convert::TryFrom;
use std::sync::Arc;
use postgres_client::tls::MakeTlsConnect;
use rustls::pki_types::{InvalidDnsNameError, ServerName};
use rustls::ClientConfig;
use rustls::pki_types::ServerName;
use tokio::io::{AsyncRead, AsyncWrite};
use crate::config::ComputeConfig;
mod private {
use std::future::Future;
use std::io;
@@ -124,27 +123,36 @@ mod private {
}
}
impl<S> MakeTlsConnect<S> for ComputeConfig
/// A `MakeTlsConnect` implementation using `rustls`.
///
/// That way you can connect to PostgreSQL using `rustls` as the TLS stack.
#[derive(Clone)]
pub struct MakeRustlsConnect {
pub config: Arc<ClientConfig>,
}
impl MakeRustlsConnect {
/// Creates a new `MakeRustlsConnect` from the provided `ClientConfig`.
#[must_use]
pub fn new(config: Arc<ClientConfig>) -> Self {
Self { config }
}
}
impl<S> MakeTlsConnect<S> for MakeRustlsConnect
where
S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
type Stream = private::RustlsStream<S>;
type TlsConnect = private::RustlsConnect;
type Error = InvalidDnsNameError;
type Error = rustls::pki_types::InvalidDnsNameError;
fn make_tls_connect(&self, hostname: &str) -> Result<Self::TlsConnect, Self::Error> {
make_tls_connect(&self.tls, hostname)
fn make_tls_connect(&mut self, hostname: &str) -> Result<Self::TlsConnect, Self::Error> {
ServerName::try_from(hostname).map(|dns_name| {
private::RustlsConnect(private::RustlsConnectData {
hostname: dns_name.to_owned(),
connector: Arc::clone(&self.config).into(),
})
})
}
}
pub fn make_tls_connect(
tls: &Arc<rustls::ClientConfig>,
hostname: &str,
) -> Result<private::RustlsConnect, InvalidDnsNameError> {
ServerName::try_from(hostname).map(|dns_name| {
private::RustlsConnect(private::RustlsConnectData {
hostname: dns_name.to_owned(),
connector: tls.clone().into(),
})
})
}

View File

@@ -8,8 +8,8 @@ use std::error::Error as _;
use http_utils::error::HttpErrorBody;
use reqwest::{IntoUrl, Method, StatusCode};
use safekeeper_api::models::{
self, PullTimelineRequest, PullTimelineResponse, SafekeeperStatus, SafekeeperUtilization,
TimelineCreateRequest, TimelineStatus,
self, PullTimelineRequest, PullTimelineResponse, SafekeeperUtilization, TimelineCreateRequest,
TimelineStatus,
};
use utils::id::{NodeId, TenantId, TimelineId};
use utils::logging::SecretString;
@@ -183,12 +183,6 @@ impl Client {
self.get(&uri).await
}
pub async fn status(&self) -> Result<SafekeeperStatus> {
let uri = format!("{}/v1/status", self.mgmt_api_endpoint);
let resp = self.get(&uri).await?;
resp.json().await.map_err(Error::ReceiveBody)
}
pub async fn utilization(&self) -> Result<SafekeeperUtilization> {
let uri = format!("{}/v1/utilization", self.mgmt_api_endpoint);
let resp = self.get(&uri).await?;

View File

@@ -395,8 +395,6 @@ pub enum TimelineError {
Cancelled(TenantTimelineId),
#[error("Timeline {0} was not found in global map")]
NotFound(TenantTimelineId),
#[error("Timeline {0} has been deleted")]
Deleted(TenantTimelineId),
#[error("Timeline {0} creation is in progress")]
CreationInProgress(TenantTimelineId),
#[error("Timeline {0} exists on disk, but wasn't loaded on startup")]

View File

@@ -78,13 +78,7 @@ impl GlobalTimelinesState {
Some(GlobalMapTimeline::CreationInProgress) => {
Err(TimelineError::CreationInProgress(*ttid))
}
None => {
if self.has_tombstone(ttid) {
Err(TimelineError::Deleted(*ttid))
} else {
Err(TimelineError::NotFound(*ttid))
}
}
None => Err(TimelineError::NotFound(*ttid)),
}
}

View File

@@ -1 +0,0 @@
ALTER TABLE nodes DROP COLUMN lifecycle;

View File

@@ -1 +0,0 @@
ALTER TABLE nodes ADD COLUMN lifecycle VARCHAR NOT NULL DEFAULT 'active';

View File

@@ -907,42 +907,6 @@ async fn handle_node_delete(req: Request<Body>) -> Result<Response<Body>, ApiErr
json_response(StatusCode::OK, state.service.node_delete(node_id).await?)
}
async fn handle_tombstone_list(req: Request<Body>) -> Result<Response<Body>, ApiError> {
check_permissions(&req, Scope::Admin)?;
let req = match maybe_forward(req).await {
ForwardOutcome::Forwarded(res) => {
return res;
}
ForwardOutcome::NotForwarded(req) => req,
};
let state = get_state(&req);
let mut nodes = state.service.tombstone_list().await?;
nodes.sort_by_key(|n| n.get_id());
let api_nodes = nodes.into_iter().map(|n| n.describe()).collect::<Vec<_>>();
json_response(StatusCode::OK, api_nodes)
}
async fn handle_tombstone_delete(req: Request<Body>) -> Result<Response<Body>, ApiError> {
check_permissions(&req, Scope::Admin)?;
let req = match maybe_forward(req).await {
ForwardOutcome::Forwarded(res) => {
return res;
}
ForwardOutcome::NotForwarded(req) => req,
};
let state = get_state(&req);
let node_id: NodeId = parse_request_param(&req, "node_id")?;
json_response(
StatusCode::OK,
state.service.tombstone_delete(node_id).await?,
)
}
async fn handle_node_configure(req: Request<Body>) -> Result<Response<Body>, ApiError> {
check_permissions(&req, Scope::Admin)?;
@@ -2098,20 +2062,6 @@ pub fn make_router(
.post("/debug/v1/node/:node_id/drop", |r| {
named_request_span(r, handle_node_drop, RequestName("debug_v1_node_drop"))
})
.delete("/debug/v1/tombstone/:node_id", |r| {
named_request_span(
r,
handle_tombstone_delete,
RequestName("debug_v1_tombstone_delete"),
)
})
.get("/debug/v1/tombstone", |r| {
named_request_span(
r,
handle_tombstone_list,
RequestName("debug_v1_tombstone_list"),
)
})
.post("/debug/v1/tenant/:tenant_id/import", |r| {
named_request_span(
r,

View File

@@ -2,7 +2,7 @@ use std::str::FromStr;
use std::time::Duration;
use pageserver_api::controller_api::{
AvailabilityZone, NodeAvailability, NodeDescribeResponse, NodeLifecycle, NodeRegisterRequest,
AvailabilityZone, NodeAvailability, NodeDescribeResponse, NodeRegisterRequest,
NodeSchedulingPolicy, TenantLocateResponseShard,
};
use pageserver_api::shard::TenantShardId;
@@ -29,7 +29,6 @@ pub(crate) struct Node {
availability: NodeAvailability,
scheduling: NodeSchedulingPolicy,
lifecycle: NodeLifecycle,
listen_http_addr: String,
listen_http_port: u16,
@@ -229,7 +228,6 @@ impl Node {
listen_pg_addr,
listen_pg_port,
scheduling: NodeSchedulingPolicy::Active,
lifecycle: NodeLifecycle::Active,
availability: NodeAvailability::Offline,
availability_zone_id,
use_https,
@@ -241,7 +239,6 @@ impl Node {
NodePersistence {
node_id: self.id.0 as i64,
scheduling_policy: self.scheduling.into(),
lifecycle: self.lifecycle.into(),
listen_http_addr: self.listen_http_addr.clone(),
listen_http_port: self.listen_http_port as i32,
listen_https_port: self.listen_https_port.map(|x| x as i32),
@@ -266,7 +263,6 @@ impl Node {
availability: NodeAvailability::Offline,
scheduling: NodeSchedulingPolicy::from_str(&np.scheduling_policy)
.expect("Bad scheduling policy in DB"),
lifecycle: NodeLifecycle::from_str(&np.lifecycle).expect("Bad lifecycle in DB"),
listen_http_addr: np.listen_http_addr,
listen_http_port: np.listen_http_port as u16,
listen_https_port: np.listen_https_port.map(|x| x as u16),

View File

@@ -19,7 +19,7 @@ use futures::FutureExt;
use futures::future::BoxFuture;
use itertools::Itertools;
use pageserver_api::controller_api::{
AvailabilityZone, MetadataHealthRecord, NodeLifecycle, NodeSchedulingPolicy, PlacementPolicy,
AvailabilityZone, MetadataHealthRecord, NodeSchedulingPolicy, PlacementPolicy,
SafekeeperDescribeResponse, ShardSchedulingPolicy, SkSchedulingPolicy,
};
use pageserver_api::models::{ShardImportStatus, TenantConfig};
@@ -102,7 +102,6 @@ pub(crate) enum DatabaseOperation {
UpdateNode,
DeleteNode,
ListNodes,
ListTombstones,
BeginShardSplit,
CompleteShardSplit,
AbortShardSplit,
@@ -358,8 +357,6 @@ impl Persistence {
}
/// When a node is first registered, persist it before using it for anything
/// If the provided node_id already exists, it will be error.
/// The common case is when a node marked for deletion wants to register.
pub(crate) async fn insert_node(&self, node: &Node) -> DatabaseResult<()> {
let np = &node.to_persistent();
self.with_measured_conn(DatabaseOperation::InsertNode, move |conn| {
@@ -376,41 +373,19 @@ impl Persistence {
/// At startup, populate the list of nodes which our shards may be placed on
pub(crate) async fn list_nodes(&self) -> DatabaseResult<Vec<NodePersistence>> {
use crate::schema::nodes::dsl::*;
let result: Vec<NodePersistence> = self
let nodes: Vec<NodePersistence> = self
.with_measured_conn(DatabaseOperation::ListNodes, move |conn| {
Box::pin(async move {
Ok(crate::schema::nodes::table
.filter(lifecycle.ne(String::from(NodeLifecycle::Deleted)))
.load::<NodePersistence>(conn)
.await?)
})
})
.await?;
tracing::info!("list_nodes: loaded {} nodes", result.len());
tracing::info!("list_nodes: loaded {} nodes", nodes.len());
Ok(result)
}
pub(crate) async fn list_tombstones(&self) -> DatabaseResult<Vec<NodePersistence>> {
use crate::schema::nodes::dsl::*;
let result: Vec<NodePersistence> = self
.with_measured_conn(DatabaseOperation::ListTombstones, move |conn| {
Box::pin(async move {
Ok(crate::schema::nodes::table
.filter(lifecycle.eq(String::from(NodeLifecycle::Deleted)))
.load::<NodePersistence>(conn)
.await?)
})
})
.await?;
tracing::info!("list_tombstones: loaded {} nodes", result.len());
Ok(result)
Ok(nodes)
}
pub(crate) async fn update_node<V>(
@@ -429,7 +404,6 @@ impl Persistence {
Box::pin(async move {
let updated = diesel::update(nodes)
.filter(node_id.eq(input_node_id.0 as i64))
.filter(lifecycle.ne(String::from(NodeLifecycle::Deleted)))
.set(values)
.execute(conn)
.await?;
@@ -473,57 +447,6 @@ impl Persistence {
.await
}
/// Tombstone is a special state where the node is not deleted from the database,
/// but it is not available for usage.
/// The main reason for it is to prevent the flaky node to register.
pub(crate) async fn set_tombstone(&self, del_node_id: NodeId) -> DatabaseResult<()> {
use crate::schema::nodes::dsl::*;
self.update_node(
del_node_id,
lifecycle.eq(String::from(NodeLifecycle::Deleted)),
)
.await
}
pub(crate) async fn delete_node(&self, del_node_id: NodeId) -> DatabaseResult<()> {
use crate::schema::nodes::dsl::*;
self.with_measured_conn(DatabaseOperation::DeleteNode, move |conn| {
Box::pin(async move {
// You can hard delete a node only if it has a tombstone.
// So we need to check if the node has lifecycle set to deleted.
let node_to_delete = nodes
.filter(node_id.eq(del_node_id.0 as i64))
.first::<NodePersistence>(conn)
.await
.optional()?;
if let Some(np) = node_to_delete {
let lc = NodeLifecycle::from_str(&np.lifecycle).map_err(|e| {
DatabaseError::Logical(format!(
"Node {} has invalid lifecycle: {}",
del_node_id, e
))
})?;
if lc != NodeLifecycle::Deleted {
return Err(DatabaseError::Logical(format!(
"Node {} was not soft deleted before, cannot hard delete it",
del_node_id
)));
}
diesel::delete(nodes)
.filter(node_id.eq(del_node_id.0 as i64))
.execute(conn)
.await?;
}
Ok(())
})
})
.await
}
/// At startup, load the high level state for shards, such as their config + policy. This will
/// be enriched at runtime with state discovered on pageservers.
///
@@ -620,6 +543,21 @@ impl Persistence {
.await
}
pub(crate) async fn delete_node(&self, del_node_id: NodeId) -> DatabaseResult<()> {
use crate::schema::nodes::dsl::*;
self.with_measured_conn(DatabaseOperation::DeleteNode, move |conn| {
Box::pin(async move {
diesel::delete(nodes)
.filter(node_id.eq(del_node_id.0 as i64))
.execute(conn)
.await?;
Ok(())
})
})
.await
}
/// When a tenant invokes the /re-attach API, this function is responsible for doing an efficient
/// batched increment of the generations of all tenants whose generation_pageserver is equal to
/// the node that called /re-attach.
@@ -633,20 +571,6 @@ impl Persistence {
let updated = self
.with_measured_conn(DatabaseOperation::ReAttach, move |conn| {
Box::pin(async move {
// Check if the node is not marked as deleted
let deleted_node: i64 = nodes
.filter(node_id.eq(input_node_id.0 as i64))
.filter(lifecycle.eq(String::from(NodeLifecycle::Deleted)))
.count()
.get_result(conn)
.await?;
if deleted_node > 0 {
return Err(DatabaseError::Logical(format!(
"Node {} is marked as deleted, re-attach is not allowed",
input_node_id
)));
}
let rows_updated = diesel::update(tenant_shards)
.filter(generation_pageserver.eq(input_node_id.0 as i64))
.set(generation.eq(generation + 1))
@@ -2124,7 +2048,6 @@ pub(crate) struct NodePersistence {
pub(crate) listen_pg_port: i32,
pub(crate) availability_zone_id: String,
pub(crate) listen_https_port: Option<i32>,
pub(crate) lifecycle: String,
}
/// Tenant metadata health status that are stored durably.

View File

@@ -33,7 +33,6 @@ diesel::table! {
listen_pg_port -> Int4,
availability_zone_id -> Varchar,
listen_https_port -> Nullable<Int4>,
lifecycle -> Varchar,
}
}

View File

@@ -166,7 +166,6 @@ enum NodeOperations {
Register,
Configure,
Delete,
DeleteTombstone,
}
/// The leadership status for the storage controller process.
@@ -1108,8 +1107,7 @@ impl Service {
observed
}
/// Used during [`Self::startup_reconcile`] and shard splits: detach a list of unknown-to-us
/// tenants from pageservers.
/// Used during [`Self::startup_reconcile`]: detach a list of unknown-to-us tenants from pageservers.
///
/// This is safe to run in the background, because if we don't have this TenantShardId in our map of
/// tenants, then it is probably something incompletely deleted before: we will not fight with any
@@ -6212,11 +6210,7 @@ impl Service {
}
}
fail::fail_point!("shard-split-pre-complete", |_| Err(ApiError::Conflict(
"failpoint".to_string()
)));
pausable_failpoint!("shard-split-pre-complete-pause");
pausable_failpoint!("shard-split-pre-complete");
// TODO: if the pageserver restarted concurrently with our split API call,
// the actual generation of the child shard might differ from the generation
@@ -6238,15 +6232,6 @@ impl Service {
let (response, child_locations, waiters) =
self.tenant_shard_split_commit_inmem(tenant_id, new_shard_count, new_stripe_size);
// Notify all page servers to detach and clean up the old shards because they will no longer
// be needed. This is best-effort: if it fails, it will be cleaned up on a subsequent
// Pageserver re-attach/startup.
let shards_to_cleanup = targets
.iter()
.map(|target| (target.parent_id, target.node.get_id()))
.collect();
self.cleanup_locations(shards_to_cleanup).await;
// Send compute notifications for all the new shards
let mut failed_notifications = Vec::new();
for (child_id, child_ps, stripe_size) in child_locations {
@@ -6924,7 +6909,7 @@ impl Service {
/// detaching or deleting it on pageservers. We do not try and re-schedule any
/// tenants that were on this node.
pub(crate) async fn node_drop(&self, node_id: NodeId) -> Result<(), ApiError> {
self.persistence.set_tombstone(node_id).await?;
self.persistence.delete_node(node_id).await?;
let mut locked = self.inner.write().unwrap();
@@ -7048,10 +7033,9 @@ impl Service {
// That is safe because in Service::spawn we only use generation_pageserver if it refers to a node
// that exists.
// 2. Actually delete the node from in-memory state and set tombstone to the database
// for preventing the node to register again.
// 2. Actually delete the node from the database and from in-memory state
tracing::info!("Deleting node from database");
self.persistence.set_tombstone(node_id).await?;
self.persistence.delete_node(node_id).await?;
Ok(())
}
@@ -7070,35 +7054,6 @@ impl Service {
Ok(nodes)
}
pub(crate) async fn tombstone_list(&self) -> Result<Vec<Node>, ApiError> {
self.persistence
.list_tombstones()
.await?
.into_iter()
.map(|np| Node::from_persistent(np, false))
.collect::<Result<Vec<_>, _>>()
.map_err(ApiError::InternalServerError)
}
pub(crate) async fn tombstone_delete(&self, node_id: NodeId) -> Result<(), ApiError> {
let _node_lock = trace_exclusive_lock(
&self.node_op_locks,
node_id,
NodeOperations::DeleteTombstone,
)
.await;
if matches!(self.get_node(node_id).await, Err(ApiError::NotFound(_))) {
self.persistence.delete_node(node_id).await?;
Ok(())
} else {
Err(ApiError::Conflict(format!(
"Node {} is in use, consider using tombstone API first",
node_id
)))
}
}
pub(crate) async fn get_node(&self, node_id: NodeId) -> Result<Node, ApiError> {
self.inner
.read()
@@ -7269,25 +7224,7 @@ impl Service {
};
match registration_status {
RegistrationStatus::New => {
self.persistence.insert_node(&new_node).await.map_err(|e| {
if matches!(
e,
crate::persistence::DatabaseError::Query(
diesel::result::Error::DatabaseError(
diesel::result::DatabaseErrorKind::UniqueViolation,
_,
)
)
) {
// The node can be deleted by tombstone API, and not show up in the list of nodes.
// If you see this error, check tombstones first.
ApiError::Conflict(format!("Node {} is already exists", new_node.get_id()))
} else {
ApiError::from(e)
}
})?;
}
RegistrationStatus::New => self.persistence.insert_node(&new_node).await?,
RegistrationStatus::NeedUpdate => {
self.persistence
.update_node_on_registration(

View File

@@ -77,7 +77,7 @@ class EndpointHttpClient(requests.Session):
status, err = json["status"], json.get("error")
assert status == "completed", f"{status}, error {err}"
wait_until(prewarmed, timeout=60)
wait_until(prewarmed)
def offload_lfc(self):
url = f"http://localhost:{self.external_port}/lfc/offload"

View File

@@ -2054,14 +2054,6 @@ class NeonStorageController(MetricsGetter, LogUtils):
headers=self.headers(TokenScope.ADMIN),
)
def tombstone_delete(self, node_id):
log.info(f"tombstone_delete({node_id})")
self.request(
"DELETE",
f"{self.api}/debug/v1/tombstone/{node_id}",
headers=self.headers(TokenScope.ADMIN),
)
def node_drain(self, node_id):
log.info(f"node_drain({node_id})")
self.request(
@@ -2118,14 +2110,6 @@ class NeonStorageController(MetricsGetter, LogUtils):
)
return response.json()
def tombstone_list(self):
response = self.request(
"GET",
f"{self.api}/debug/v1/tombstone",
headers=self.headers(TokenScope.ADMIN),
)
return response.json()
def tenant_shard_dump(self):
"""
Debug listing API: dumps the internal map of tenant shards
@@ -4046,16 +4030,6 @@ def static_proxy(
"CREATE TABLE neon_control_plane.endpoints (endpoint_id VARCHAR(255) PRIMARY KEY, allowed_ips VARCHAR(255))"
)
vanilla_pg.stop()
vanilla_pg.edit_hba(
[
"local all all trust",
"host all all 127.0.0.1/32 scram-sha-256",
"host all all ::1/128 scram-sha-256",
]
)
vanilla_pg.start()
proxy_port = port_distributor.get_port()
mgmt_port = port_distributor.get_port()
http_port = port_distributor.get_port()

View File

@@ -197,7 +197,8 @@ def print_gc_result(row: dict[str, Any]):
log.info("GC duration {elapsed} ms".format_map(row))
log.info(
(
" eligible: {layers_eligible}, not_updated: {layers_not_updated}, removed: {layers_removed}"
" total: {layers_total}, needed_by_cutoff {layers_needed_by_cutoff}, needed_by_pitr {layers_needed_by_pitr}"
" needed_by_branches: {layers_needed_by_branches}, not_updated: {layers_not_updated}, removed: {layers_removed}"
).format_map(row)
)

View File

@@ -87,9 +87,6 @@ def test_import_from_vanilla(test_output_dir, pg_bin, vanilla_pg, neon_env_build
# Set up pageserver for import
neon_env_builder.enable_pageserver_remote_storage(RemoteStorageKind.LOCAL_FS)
neon_env_builder.storage_controller_config = {
"timelines_onto_safekeepers": True,
}
env = neon_env_builder.init_start()
env.pageserver.tenant_create(tenant)

View File

@@ -16,7 +16,7 @@ if TYPE_CHECKING:
# Test restarting page server, while safekeeper and compute node keep
# running.
def test_pageserver_restarts_under_workload(neon_simple_env: NeonEnv, pg_bin: PgBin):
def test_pageserver_restarts_under_worload(neon_simple_env: NeonEnv, pg_bin: PgBin):
env = neon_simple_env
env.create_branch("test_pageserver_restarts")
endpoint = env.endpoints.create_start("test_pageserver_restarts")
@@ -28,11 +28,7 @@ def test_pageserver_restarts_under_workload(neon_simple_env: NeonEnv, pg_bin: Pg
pg_bin.run_capture(["pgbench", "-i", "-I", "dtGvp", f"-s{scale}", connstr])
pg_bin.run_capture(["pgbench", f"-T{n_restarts}", connstr])
thread = threading.Thread(
target=run_pgbench,
args=(endpoint.connstr(options="-cstatement_timeout=360s"),),
daemon=True,
)
thread = threading.Thread(target=run_pgbench, args=(endpoint.connstr(),), daemon=True)
thread.start()
for _ in range(n_restarts):

View File

@@ -19,15 +19,11 @@ TABLE_NAME = "neon_control_plane.endpoints"
async def test_proxy_psql_allowed_ips(static_proxy: NeonProxy, vanilla_pg: VanillaPostgres):
# Shouldn't be able to connect to this project
vanilla_pg.safe_psql(
f"INSERT INTO {TABLE_NAME} (endpoint_id, allowed_ips) VALUES ('private-project', '8.8.8.8')",
user="proxy",
password="password",
f"INSERT INTO {TABLE_NAME} (endpoint_id, allowed_ips) VALUES ('private-project', '8.8.8.8')"
)
# Should be able to connect to this project
vanilla_pg.safe_psql(
f"INSERT INTO {TABLE_NAME} (endpoint_id, allowed_ips) VALUES ('generic-project', '::1,127.0.0.1')",
user="proxy",
password="password",
f"INSERT INTO {TABLE_NAME} (endpoint_id, allowed_ips) VALUES ('generic-project', '::1,127.0.0.1')"
)
def check_cannot_connect(**kwargs):
@@ -64,9 +60,7 @@ async def test_proxy_http_allowed_ips(static_proxy: NeonProxy, vanilla_pg: Vanil
# Shouldn't be able to connect to this project
vanilla_pg.safe_psql(
f"INSERT INTO {TABLE_NAME} (endpoint_id, allowed_ips) VALUES ('proxy', '8.8.8.8')",
user="proxy",
password="password",
f"INSERT INTO {TABLE_NAME} (endpoint_id, allowed_ips) VALUES ('proxy', '8.8.8.8')"
)
def query(status: int, query: str, *args):
@@ -81,8 +75,6 @@ async def test_proxy_http_allowed_ips(static_proxy: NeonProxy, vanilla_pg: Vanil
query(400, "select 1;") # ip address is not allowed
# Should be able to connect to this project
vanilla_pg.safe_psql(
f"UPDATE {TABLE_NAME} SET allowed_ips = '8.8.8.8,127.0.0.1' WHERE endpoint_id = 'proxy'",
user="proxy",
password="password",
f"UPDATE {TABLE_NAME} SET allowed_ips = '8.8.8.8,127.0.0.1' WHERE endpoint_id = 'proxy'"
)
query(200, "select 1;") # should work now

View File

@@ -30,7 +30,6 @@ def test_safekeeper_delete_timeline(neon_env_builder: NeonEnvBuilder, auth_enabl
env.pageserver.allowed_errors.extend(
[
".*Timeline .* was not found in global map.*",
".*Timeline .* has been deleted.*",
".*Timeline .* was cancelled and cannot be used anymore.*",
]
)
@@ -199,7 +198,6 @@ def test_safekeeper_delete_timeline_under_load(neon_env_builder: NeonEnvBuilder)
env.pageserver.allowed_errors.extend(
[
".*Timeline.*was cancelled.*",
".*Timeline.*has been deleted.*",
".*Timeline.*was not found.*",
]
)

View File

@@ -1836,90 +1836,3 @@ def test_sharding_gc(
shard_gc_cutoff_lsn = Lsn(shard_index["metadata_bytes"]["latest_gc_cutoff_lsn"])
log.info(f"Shard {shard_number} cutoff LSN: {shard_gc_cutoff_lsn}")
assert shard_gc_cutoff_lsn == shard_0_gc_cutoff_lsn
def test_split_ps_delete_old_shard_after_commit(neon_env_builder: NeonEnvBuilder):
"""
Check that PageServer only deletes old shards after the split is committed such that it doesn't
have to download a lot of files during abort.
"""
DBNAME = "regression"
init_shard_count = 4
neon_env_builder.num_pageservers = init_shard_count
stripe_size = 32
env = neon_env_builder.init_start(
initial_tenant_shard_count=init_shard_count, initial_tenant_shard_stripe_size=stripe_size
)
env.storage_controller.allowed_errors.extend(
[
# All split failures log a warning when they enqueue the abort operation
".*Enqueuing background abort.*",
# Tolerate any error logs that mention a failpoint
".*failpoint.*",
]
)
endpoint = env.endpoints.create("main")
endpoint.respec(skip_pg_catalog_updates=False)
endpoint.start()
# Write some initial data.
endpoint.safe_psql(f"CREATE DATABASE {DBNAME}")
endpoint.safe_psql("CREATE TABLE usertable ( YCSB_KEY INT, FIELD0 TEXT);")
for _ in range(1000):
endpoint.safe_psql(
"INSERT INTO usertable SELECT random(), repeat('a', 1000);", log_query=False
)
# Record how many bytes we've downloaded before the split.
def collect_downloaded_bytes() -> list[float | None]:
downloaded_bytes = []
for page_server in env.pageservers:
metric = page_server.http_client().get_metric_value(
"pageserver_remote_ondemand_downloaded_bytes_total"
)
downloaded_bytes.append(metric)
return downloaded_bytes
downloaded_bytes_before = collect_downloaded_bytes()
# Attempt to split the tenant, but fail the split before it completes.
env.storage_controller.configure_failpoints(("shard-split-pre-complete", "return(1)"))
with pytest.raises(StorageControllerApiException):
env.storage_controller.tenant_shard_split(env.initial_tenant, shard_count=16)
# Wait until split is aborted.
def check_split_is_aborted():
tenants = env.storage_controller.tenant_list()
assert len(tenants) == 1
shards = tenants[0]["shards"]
assert len(shards) == 4
for shard in shards:
assert not shard["is_splitting"]
assert not shard["is_reconciling"]
# Make sure all new shards have been deleted.
valid_shards = 0
for ps in env.pageservers:
for tenant_dir in os.listdir(ps.workdir / "tenants"):
try:
tenant_shard_id = TenantShardId.parse(tenant_dir)
valid_shards += 1
assert tenant_shard_id.shard_count == 4
except ValueError:
log.info(f"{tenant_dir} is not valid tenant shard id")
assert valid_shards >= 4
wait_until(check_split_is_aborted)
endpoint.safe_psql("SELECT count(*) from usertable;", log_query=False)
# Make sure we didn't download anything following the aborted split.
downloaded_bytes_after = collect_downloaded_bytes()
assert downloaded_bytes_before == downloaded_bytes_after
endpoint.stop_and_destroy()

View File

@@ -2956,7 +2956,7 @@ def test_storage_controller_leadership_transfer_during_split(
env.storage_controller.allowed_errors.extend(
[".*Unexpected child shard count.*", ".*Enqueuing background abort.*"]
)
pause_failpoint = "shard-split-pre-complete-pause"
pause_failpoint = "shard-split-pre-complete"
env.storage_controller.configure_failpoints((pause_failpoint, "pause"))
split_fut = executor.submit(
@@ -3003,7 +3003,7 @@ def test_storage_controller_leadership_transfer_during_split(
env.storage_controller.request(
"PUT",
f"http://127.0.0.1:{storage_controller_1_port}/debug/v1/failpoints",
json=[{"name": pause_failpoint, "actions": "off"}],
json=[{"name": "shard-split-pre-complete", "actions": "off"}],
headers=env.storage_controller.headers(TokenScope.ADMIN),
)
@@ -3093,58 +3093,6 @@ def test_storage_controller_ps_restarted_during_drain(neon_env_builder: NeonEnvB
wait_until(reconfigure_node_again)
def test_ps_unavailable_after_delete(neon_env_builder: NeonEnvBuilder):
neon_env_builder.num_pageservers = 3
env = neon_env_builder.init_start()
def assert_nodes_count(n: int):
nodes = env.storage_controller.node_list()
assert len(nodes) == n
# Nodes count must remain the same before deletion
assert_nodes_count(3)
ps = env.pageservers[0]
env.storage_controller.node_delete(ps.id)
# After deletion, the node count must be reduced
assert_nodes_count(2)
# Running pageserver CLI init in a separate thread
with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
log.info("Restarting tombstoned pageserver...")
ps.stop()
ps_start_fut = executor.submit(lambda: ps.start(await_active=False))
# After deleted pageserver restart, the node count must remain the same
assert_nodes_count(2)
tombstones = env.storage_controller.tombstone_list()
assert len(tombstones) == 1 and tombstones[0]["id"] == ps.id
env.storage_controller.tombstone_delete(ps.id)
tombstones = env.storage_controller.tombstone_list()
assert len(tombstones) == 0
# Wait for the pageserver start operation to complete.
# If it fails with an exception, we try restarting the pageserver since the failure
# may be due to the storage controller refusing to register the node.
# However, if we get a TimeoutError that means the pageserver is completely hung,
# which is an unexpected failure mode that we'll let propagate up.
try:
ps_start_fut.result(timeout=20)
except TimeoutError:
raise
except Exception:
log.info("Restarting deleted pageserver...")
ps.restart()
# Finally, the node can be registered again after tombstone is deleted
wait_until(lambda: assert_nodes_count(3))
def test_storage_controller_timeline_crud_race(neon_env_builder: NeonEnvBuilder):
"""
The storage controller is meant to handle the case where a timeline CRUD operation races

View File

@@ -12,7 +12,7 @@ from typing import TYPE_CHECKING
import pytest
import requests
from fixtures.common_types import Lsn, TenantId, TimelineArchivalState, TimelineId
from fixtures.common_types import Lsn, TenantId, TimelineId
from fixtures.log_helper import log
from fixtures.metrics import (
PAGESERVER_GLOBAL_METRICS,
@@ -299,65 +299,6 @@ def test_pageserver_metrics_removed_after_detach(neon_env_builder: NeonEnvBuilde
assert post_detach_samples == set()
def test_pageserver_metrics_removed_after_offload(neon_env_builder: NeonEnvBuilder):
"""Tests that when a timeline is offloaded, the tenant specific metrics are not left behind"""
neon_env_builder.enable_pageserver_remote_storage(RemoteStorageKind.MOCK_S3)
neon_env_builder.num_safekeepers = 3
env = neon_env_builder.init_start()
tenant_1, _ = env.create_tenant()
timeline_1 = env.create_timeline("test_metrics_removed_after_offload_1", tenant_id=tenant_1)
timeline_2 = env.create_timeline("test_metrics_removed_after_offload_2", tenant_id=tenant_1)
endpoint_tenant1 = env.endpoints.create_start(
"test_metrics_removed_after_offload_1", tenant_id=tenant_1
)
endpoint_tenant2 = env.endpoints.create_start(
"test_metrics_removed_after_offload_2", tenant_id=tenant_1
)
for endpoint in [endpoint_tenant1, endpoint_tenant2]:
with closing(endpoint.connect()) as conn:
with conn.cursor() as cur:
cur.execute("CREATE TABLE t(key int primary key, value text)")
cur.execute("INSERT INTO t SELECT generate_series(1,100000), 'payload'")
cur.execute("SELECT sum(key) FROM t")
assert cur.fetchone() == (5000050000,)
endpoint.stop()
def get_ps_metric_samples_for_timeline(
tenant_id: TenantId, timeline_id: TimelineId
) -> list[Sample]:
ps_metrics = env.pageserver.http_client().get_metrics()
samples = []
for metric_name in ps_metrics.metrics:
for sample in ps_metrics.query_all(
name=metric_name,
filter={"tenant_id": str(tenant_id), "timeline_id": str(timeline_id)},
):
samples.append(sample)
return samples
for timeline in [timeline_1, timeline_2]:
pre_offload_samples = set(
[x.name for x in get_ps_metric_samples_for_timeline(tenant_1, timeline)]
)
assert len(pre_offload_samples) > 0, f"expected at least one sample for {timeline}"
env.pageserver.http_client().timeline_archival_config(
tenant_1,
timeline,
state=TimelineArchivalState.ARCHIVED,
)
env.pageserver.http_client().timeline_offload(tenant_1, timeline)
post_offload_samples = set(
[x.name for x in get_ps_metric_samples_for_timeline(tenant_1, timeline)]
)
assert post_offload_samples == set()
def test_pageserver_with_empty_tenants(neon_env_builder: NeonEnvBuilder):
env = neon_env_builder.init_start()

View File

@@ -433,7 +433,6 @@ def test_wal_backup(neon_env_builder: NeonEnvBuilder):
env.pageserver.allowed_errors.extend(
[
".*Timeline .* was not found in global map.*",
".*Timeline .* has been deleted.*",
".*Timeline .* was cancelled and cannot be used anymore.*",
]
)
@@ -1935,7 +1934,6 @@ def test_membership_api(neon_env_builder: NeonEnvBuilder):
env.pageserver.allowed_errors.extend(
[
".*Timeline .* was not found in global map.*",
".*Timeline .* has been deleted.*",
".*Timeline .* was cancelled and cannot be used anymore.*",
]
)