Compare commits

..

18 Commits

Author SHA1 Message Date
Conrad Ludgate
4ada80d915 stash changes for proxy configuration 2025-06-07 21:10:41 +01:00
Conrad Ludgate
fd263a0c23 remove unused args and rearrange usage metrics args to be together 2025-06-07 20:24:44 +01:00
Conrad Ludgate
26dc39053e fixup! remove values that are never read for redis notifications 2025-06-07 20:11:02 +01:00
Conrad Ludgate
1f62ee5f5c fixup! move region to the parquet upload task, and not as part of the request context itself 2025-06-07 19:54:03 +01:00
Conrad Ludgate
e78254657a don't rate limit redis since we have switched to the batch queue system instead 2025-06-07 19:51:43 +01:00
Conrad Ludgate
640500aa6d remove some legacy from the early days of our redis support. 2025-06-07 19:46:46 +01:00
Conrad Ludgate
b0c712f63f move region to the parquet upload task, and not as part of the request context itself 2025-06-07 17:25:18 +01:00
Conrad Ludgate
f84e73c323 remove values that are never read for redis notifications 2025-06-07 17:10:34 +01:00
Erik Grinaker
3c7235669a pageserver: don't delete parent shard files until split is committed (#12146)
## Problem

If a shard split fails and must roll back, the tenant may hit a cold
start as the parent shard's files have already been removed from local
disk.

External contribution with minor adjustments, see
https://neondb.slack.com/archives/C08TE3203RQ/p1748246398269309.

## Summary of changes

Keep the parent shard's files on local disk until the split has been
committed, such that they are available if the spilt is rolled back. If
all else fails, the files will be removed on the next Pageserver
restart.

This should also be fine in a mixed version:

* New storcon, old Pageserver: the Pageserver will delete the files
during the split, storcon will log an error when the cleanup detach
fails.

* Old storcon, new Pageserver: the Pageserver will leave the parent's
files around until the next Pageserver restart.

The change looks good to me, but shard splits are delicate so I'd like
some extra eyes on this.
2025-06-06 15:55:14 +00:00
Conrad Ludgate
6dd84041a1 refactor and simplify the invalidation notification structure (#12154)
The current cache invalidation messages are far too specific. They
should be more generic since it only ends up triggering a
`GetEndpointAccessControl` message anyway.

Mappings:
* `/allowed_ips_updated`, `/block_public_or_vpc_access_updated`, and
`/allowed_vpc_endpoints_updated_for_projects` ->
`/project_settings_update`.
* `/allowed_vpc_endpoints_updated_for_org` ->
`/account_settings_update`.
* `/password_updated` -> `/role_setting_update`.

I've also introduced `/endpoint_settings_update`.

All message types support singular or multiple entries, which allows us
to simplify things both on our side and on cplane side.

I'm opening a PR to cplane to apply the above mappings, but for now
using the old phrases to allow both to roll out independently.

This change is inspired by my need to add yet another cached entry to
`GetEndpointAccessControl` for
https://github.com/neondatabase/cloud/issues/28333
2025-06-06 12:49:29 +00:00
Arpad Müller
df7e301a54 safekeeper: special error if a timeline has been deleted (#12155)
We might delete timelines on safekeepers before we are deleting them on
pageservers. This should be an exceptional situation, but can occur. As
the first step to improve behaviour here, emit a special error that is
less scary/obscure than "was not found in global map".

It is for example emitted when the pageserver tries to run
`IDENTIFY_SYSTEM` on a timeline that has been deleted on the safekeeper.

Found when analyzing the failure of
`test_scrubber_physical_gc_timeline_deletion` when enabling
`--timelines-onto-safekeepers` on the pytests.

Due to safekeeper restarts, there is no hard guarantee that we will keep
issuing this error, so we need to think of something better if we start
encountering this in staging/prod. But I would say that the introduction
of `--timelines-onto-safekeepers` in the pytests and into staging won't
change much about this: we are already deleting timelines from there. In
`test_scrubber_physical_gc_timeline_deletion`, we'd just be leaking the
timeline before on the safekeepers.

Part of #11712
2025-06-06 11:54:07 +00:00
Mikhail
470c7d5e0e endpoint_storage: default listen port, allow inline config (#12152)
Related: https://github.com/neondatabase/cloud/issues/27195
2025-06-06 11:48:01 +00:00
Conrad Ludgate
4d99b6ff4d [proxy] separate compute connect from compute authentication (#12145)
## Problem

PGLB/Neonkeeper needs to separate the concerns of connecting to compute,
and authenticating to compute.

Additionally, the code within `connect_to_compute` is rather messy,
spending effort on recovering the authentication info after
wake_compute.

## Summary of changes

Split `ConnCfg` into `ConnectInfo` and `AuthInfo`. `wake_compute` only
returns `ConnectInfo` and `AuthInfo` is determined separately from the
`handshake`/`authenticate` process.

Additionally, `ConnectInfo::connect_raw` is in-charge or establishing
the TLS connection, and the `postgres_client::Config::connect_raw` is
configured to use `NoTls` which will force it to skip the TLS
negotiation. This should just work.
2025-06-06 10:29:55 +00:00
Alexander Sarantcev
590301df08 storcon: Introduce deletion tombstones to support flaky node scenario (#12096)
## Problem

Removed nodes can re-add themselves on restart if not properly
tombstoned. We need a mechanism (e.g. soft-delete flag) to prevent this,
especially in cases where the node is unreachable.

More details there: #12036

## Summary of changes

- Introduced `NodeLifecycle` enum to represent node lifecycle states.
- Added a string representation of `NodeLifecycle` to the `nodes` table.
- Implemented node removal using a tombstone mechanism.
- Introduced `/debug/v1/tombstone*` handlers to manage the tombstone
state.
2025-06-06 10:16:55 +00:00
Erik Grinaker
c511786548 pageserver: move spawn_grpc to GrpcPageServiceHandler::spawn (#12147)
Mechanical move, no logic changes.
2025-06-06 10:01:58 +00:00
Alex Chi Z.
fe31baf985 feat(build): add aws cli into the docker image (#12161)
## Problem

Makes it easier to debug AWS permission issues (i.e., storage scrubber)

## Summary of changes

Install awscliv2 into the docker image.

Signed-off-by: Alex Chi Z <chi@neon.tech>
2025-06-06 09:38:58 +00:00
Alex Chi Z.
b23e75ebfe test(pageserver): ensure offload cleans up metrics (#12127)
Add a test to ensure timeline metrics are fully cleaned up after
offloading.

Signed-off-by: Alex Chi Z <chi@neon.tech>
2025-06-06 06:50:54 +00:00
Arpad Müller
24d7c37e6e neon_local timeline import: create timelines on safekeepers (#12138)
neon_local's timeline import subcommand creates timelines manually, but
doesn't create them on the safekeepers. If a test then tries to open an
endpoint to read from the timeline, it will error in the new world with
`--timelines-onto-safekeepers`.

Therefore, if that flag is enabled, create the timelines on the
safekeepers.

Note that this import functionality is different from the fast import
feature (https://github.com/neondatabase/neon/issues/10188, #11801).

Part of #11670
As well as part of #11712
2025-06-05 18:53:14 +00:00
86 changed files with 2485 additions and 3442 deletions

130
Cargo.lock generated
View File

@@ -1086,25 +1086,6 @@ version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5"
[[package]]
name = "cbindgen"
version = "0.28.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "eadd868a2ce9ca38de7eeafdcec9c7065ef89b42b32f0839278d55f35c54d1ff"
dependencies = [
"clap",
"heck 0.4.1",
"indexmap 2.9.0",
"log",
"proc-macro2",
"quote",
"serde",
"serde_json",
"syn 2.0.100",
"tempfile",
"toml",
]
[[package]]
name = "cc"
version = "1.2.16"
@@ -1231,7 +1212,7 @@ version = "4.5.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4ac6a0c7b1a9e9a5186361f67dfa1b88213572f427fb9ab038efb2bd8c582dab"
dependencies = [
"heck 0.5.0",
"heck",
"proc-macro2",
"quote",
"syn 2.0.100",
@@ -1289,14 +1270,6 @@ dependencies = [
"unicode-width",
]
[[package]]
name = "communicator"
version = "0.1.0"
dependencies = [
"cbindgen",
"neon-shmem",
]
[[package]]
name = "compute_api"
version = "0.1.0"
@@ -1472,6 +1445,7 @@ dependencies = [
"regex",
"reqwest",
"safekeeper_api",
"safekeeper_client",
"scopeguard",
"serde",
"serde_json",
@@ -1963,7 +1937,7 @@ checksum = "0892a17df262a24294c382f0d5997571006e7a4348b4327557c4ff1cd4a8bccc"
dependencies = [
"darling",
"either",
"heck 0.5.0",
"heck",
"proc-macro2",
"quote",
"syn 2.0.100",
@@ -2527,18 +2501,6 @@ dependencies = [
"wasm-bindgen",
]
[[package]]
name = "getrandom"
version = "0.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "26145e563e54f2cadc477553f1ec5ee650b00862f0a58bcd12cbdc5f0ea2d2f4"
dependencies = [
"cfg-if",
"libc",
"r-efi",
"wasi 0.14.2+wasi-0.2.4",
]
[[package]]
name = "gettid"
version = "0.1.3"
@@ -2751,12 +2713,6 @@ dependencies = [
"http 1.1.0",
]
[[package]]
name = "heck"
version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8"
[[package]]
name = "heck"
version = "0.5.0"
@@ -3693,7 +3649,7 @@ version = "0.0.22"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b9e6777fc80a575f9503d908c8b498782a6c3ee88a06cb416dc3941401e43b94"
dependencies = [
"heck 0.5.0",
"heck",
"proc-macro2",
"quote",
"syn 2.0.100",
@@ -3755,7 +3711,7 @@ dependencies = [
"procfs",
"prometheus",
"rand 0.8.5",
"rand_distr 0.4.3",
"rand_distr",
"twox-hash",
]
@@ -3843,11 +3799,7 @@ checksum = "e5ce46fe64a9d73be07dcbe690a38ce1b293be448fd8ce1e6c1b8062c9f72c6a"
name = "neon-shmem"
version = "0.1.0"
dependencies = [
"criterion",
"nix 0.30.1",
"rand 0.9.1",
"rand_distr 0.5.1",
"rustc-hash 1.1.0",
"tempfile",
"thiserror 1.0.69",
"workspace_hack",
@@ -5141,7 +5093,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "22505a5c94da8e3b7c2996394d1c933236c4d743e81a410bcca4e6989fc066a4"
dependencies = [
"bytes",
"heck 0.5.0",
"heck",
"itertools 0.12.1",
"log",
"multimap",
@@ -5162,7 +5114,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0c1318b19085f08681016926435853bbf7858f9c082d0999b80550ff5d9abe15"
dependencies = [
"bytes",
"heck 0.5.0",
"heck",
"itertools 0.12.1",
"log",
"multimap",
@@ -5287,7 +5239,7 @@ dependencies = [
"postgres_backend",
"pq_proto",
"rand 0.8.5",
"rand_distr 0.4.3",
"rand_distr",
"rcgen",
"redis",
"regex",
@@ -5321,6 +5273,7 @@ dependencies = [
"tokio-rustls 0.26.2",
"tokio-tungstenite 0.21.0",
"tokio-util",
"toml",
"tracing",
"tracing-log",
"tracing-opentelemetry",
@@ -5391,12 +5344,6 @@ dependencies = [
"proc-macro2",
]
[[package]]
name = "r-efi"
version = "5.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "74765f6d916ee2faa39bc8e68e4f3ed8949b48cccdac59983d287a7cb71ce9c5"
[[package]]
name = "rand"
version = "0.7.3"
@@ -5421,16 +5368,6 @@ dependencies = [
"rand_core 0.6.4",
]
[[package]]
name = "rand"
version = "0.9.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9fbfd9d094a40bf3ae768db9361049ace4c0e04a4fd6b359518bd7b73a73dd97"
dependencies = [
"rand_chacha 0.9.0",
"rand_core 0.9.3",
]
[[package]]
name = "rand_chacha"
version = "0.2.2"
@@ -5451,16 +5388,6 @@ dependencies = [
"rand_core 0.6.4",
]
[[package]]
name = "rand_chacha"
version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb"
dependencies = [
"ppv-lite86",
"rand_core 0.9.3",
]
[[package]]
name = "rand_core"
version = "0.5.1"
@@ -5479,15 +5406,6 @@ dependencies = [
"getrandom 0.2.11",
]
[[package]]
name = "rand_core"
version = "0.9.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "99d9a13982dcf210057a8a78572b2217b667c3beacbf3a0d8b454f6f82837d38"
dependencies = [
"getrandom 0.3.3",
]
[[package]]
name = "rand_distr"
version = "0.4.3"
@@ -5498,16 +5416,6 @@ dependencies = [
"rand 0.8.5",
]
[[package]]
name = "rand_distr"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6a8615d50dcf34fa31f7ab52692afec947c4dd0ab803cc87cb3b0b4570ff7463"
dependencies = [
"num-traits",
"rand 0.9.1",
]
[[package]]
name = "rand_hc"
version = "0.2.0"
@@ -6994,7 +6902,7 @@ version = "0.26.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4c6bee85a5a24955dc440386795aa378cd9cf82acd5f764469152d2270e581be"
dependencies = [
"heck 0.5.0",
"heck",
"proc-macro2",
"quote",
"rustversion",
@@ -8293,15 +8201,6 @@ version = "0.11.0+wasi-snapshot-preview1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423"
[[package]]
name = "wasi"
version = "0.14.2+wasi-0.2.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9683f9a5a998d873c0d21fcbe3c083009670149a8fab228644b8bd36b2c48cb3"
dependencies = [
"wit-bindgen-rt",
]
[[package]]
name = "wasite"
version = "0.1.0"
@@ -8659,15 +8558,6 @@ dependencies = [
"windows-sys 0.48.0",
]
[[package]]
name = "wit-bindgen-rt"
version = "0.39.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6f42320e61fe2cfd34354ecb597f86f413484a798ba44a8ca1165c58d42da6c1"
dependencies = [
"bitflags 2.8.0",
]
[[package]]
name = "workspace_hack"
version = "0.1.0"

View File

@@ -44,7 +44,6 @@ members = [
"libs/proxy/postgres-types2",
"libs/proxy/tokio-postgres2",
"endpoint_storage",
"pgxn/neon/communicator",
]
[workspace.package]
@@ -252,7 +251,6 @@ desim = { version = "0.1", path = "./libs/desim" }
endpoint_storage = { version = "0.0.1", path = "./endpoint_storage/" }
http-utils = { version = "0.1", path = "./libs/http-utils/" }
metrics = { version = "0.1", path = "./libs/metrics/" }
neon-shmem = { version = "0.1", path = "./libs/neon-shmem/" }
pageserver = { path = "./pageserver" }
pageserver_api = { version = "0.1", path = "./libs/pageserver_api/" }
pageserver_client = { path = "./pageserver/client" }
@@ -280,7 +278,6 @@ walproposer = { version = "0.1", path = "./libs/walproposer/" }
workspace_hack = { version = "0.1", path = "./workspace_hack/" }
## Build dependencies
cbindgen = "0.28.0"
criterion = "0.5.1"
rcgen = "0.13"
rstest = "0.18"

View File

@@ -110,6 +110,19 @@ RUN set -e \
# System postgres for use with client libraries (e.g. in storage controller)
postgresql-15 \
openssl \
unzip \
curl \
&& ARCH=$(uname -m) \
&& if [ "$ARCH" = "x86_64" ]; then \
curl "https://awscli.amazonaws.com/awscli-exe-linux-x86_64.zip" -o "awscliv2.zip"; \
elif [ "$ARCH" = "aarch64" ]; then \
curl "https://awscli.amazonaws.com/awscli-exe-linux-aarch64.zip" -o "awscliv2.zip"; \
else \
echo "Unsupported architecture: $ARCH" && exit 1; \
fi \
&& unzip awscliv2.zip \
&& ./aws/install \
&& rm -rf aws awscliv2.zip \
&& rm -f /etc/apt/apt.conf.d/80-retries \
&& rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/* \
&& useradd -d /data neon \

View File

@@ -18,12 +18,10 @@ ifeq ($(BUILD_TYPE),release)
PG_LDFLAGS = $(LDFLAGS)
# Unfortunately, `--profile=...` is a nightly feature
CARGO_BUILD_FLAGS += --release
NEON_CARGO_ARTIFACT_TARGET_DIR = $(ROOT_PROJECT_DIR)/target/release
else ifeq ($(BUILD_TYPE),debug)
PG_CONFIGURE_OPTS = --enable-debug --with-openssl --enable-cassert --enable-depend
PG_CFLAGS += -O0 -g3 $(CFLAGS)
PG_LDFLAGS = $(LDFLAGS)
NEON_CARGO_ARTIFACT_TARGET_DIR = $(ROOT_PROJECT_DIR)/target/debug
else
$(error Bad build type '$(BUILD_TYPE)', see Makefile for options)
endif
@@ -182,16 +180,11 @@ postgres-check-%: postgres-%
.PHONY: neon-pg-ext-%
neon-pg-ext-%: postgres-%
+@echo "Compiling communicator $*"
$(CARGO_CMD_PREFIX) cargo build -p communicator $(CARGO_BUILD_FLAGS)
+@echo "Compiling neon $*"
mkdir -p $(POSTGRES_INSTALL_DIR)/build/neon-$*
$(MAKE) PG_CONFIG=$(POSTGRES_INSTALL_DIR)/$*/bin/pg_config COPT='$(COPT)' \
LIBCOMMUNICATOR_PATH=$(NEON_CARGO_ARTIFACT_TARGET_DIR) \
-C $(POSTGRES_INSTALL_DIR)/build/neon-$* \
-f $(ROOT_PROJECT_DIR)/pgxn/neon/Makefile install
+@echo "Compiling neon_walredo $*"
mkdir -p $(POSTGRES_INSTALL_DIR)/build/neon-walredo-$*
$(MAKE) PG_CONFIG=$(POSTGRES_INSTALL_DIR)/$*/bin/pg_config COPT='$(COPT)' \

View File

@@ -36,6 +36,7 @@ pageserver_api.workspace = true
pageserver_client.workspace = true
postgres_backend.workspace = true
safekeeper_api.workspace = true
safekeeper_client.workspace = true
postgres_connection.workspace = true
storage_broker.workspace = true
http-utils.workspace = true

View File

@@ -45,7 +45,7 @@ use pageserver_api::models::{
use pageserver_api::shard::{DEFAULT_STRIPE_SIZE, ShardCount, ShardStripeSize, TenantShardId};
use postgres_backend::AuthType;
use postgres_connection::parse_host_port;
use safekeeper_api::membership::SafekeeperGeneration;
use safekeeper_api::membership::{SafekeeperGeneration, SafekeeperId};
use safekeeper_api::{
DEFAULT_HTTP_LISTEN_PORT as DEFAULT_SAFEKEEPER_HTTP_PORT,
DEFAULT_PG_LISTEN_PORT as DEFAULT_SAFEKEEPER_PG_PORT,
@@ -1255,6 +1255,45 @@ async fn handle_timeline(cmd: &TimelineCmd, env: &mut local_env::LocalEnv) -> Re
pageserver
.timeline_import(tenant_id, timeline_id, base, pg_wal, args.pg_version)
.await?;
if env.storage_controller.timelines_onto_safekeepers {
println!("Creating timeline on safekeeper ...");
let timeline_info = pageserver
.timeline_info(
TenantShardId::unsharded(tenant_id),
timeline_id,
pageserver_client::mgmt_api::ForceAwaitLogicalSize::No,
)
.await?;
let default_sk = SafekeeperNode::from_env(env, env.safekeepers.first().unwrap());
let default_host = default_sk
.conf
.listen_addr
.clone()
.unwrap_or_else(|| "localhost".to_string());
let mconf = safekeeper_api::membership::Configuration {
generation: SafekeeperGeneration::new(1),
members: safekeeper_api::membership::MemberSet {
m: vec![SafekeeperId {
host: default_host,
id: default_sk.conf.id,
pg_port: default_sk.conf.pg_port,
}],
},
new_members: None,
};
let pg_version = args.pg_version * 10000;
let req = safekeeper_api::models::TimelineCreateRequest {
tenant_id,
timeline_id,
mconf,
pg_version,
system_id: None,
wal_seg_size: None,
start_lsn: timeline_info.last_record_lsn,
commit_lsn: None,
};
default_sk.create_timeline(&req).await?;
}
env.register_branch_mapping(branch_name.to_string(), tenant_id, timeline_id)?;
println!("Done");
}

View File

@@ -635,4 +635,16 @@ impl PageServerNode {
Ok(())
}
pub async fn timeline_info(
&self,
tenant_shard_id: TenantShardId,
timeline_id: TimelineId,
force_await_logical_size: mgmt_api::ForceAwaitLogicalSize,
) -> anyhow::Result<TimelineInfo> {
let timeline_info = self
.http_client
.timeline_info(tenant_shard_id, timeline_id, force_await_logical_size)
.await?;
Ok(timeline_info)
}
}

View File

@@ -6,7 +6,6 @@
//! .neon/safekeepers/<safekeeper id>
//! ```
use std::error::Error as _;
use std::future::Future;
use std::io::Write;
use std::path::PathBuf;
use std::time::Duration;
@@ -14,9 +13,9 @@ use std::{io, result};
use anyhow::Context;
use camino::Utf8PathBuf;
use http_utils::error::HttpErrorBody;
use postgres_connection::PgConnectionConfig;
use reqwest::{IntoUrl, Method};
use safekeeper_api::models::TimelineCreateRequest;
use safekeeper_client::mgmt_api;
use thiserror::Error;
use utils::auth::{Claims, Scope};
use utils::id::NodeId;
@@ -35,25 +34,14 @@ pub enum SafekeeperHttpError {
type Result<T> = result::Result<T, SafekeeperHttpError>;
pub(crate) trait ResponseErrorMessageExt: Sized {
fn error_from_body(self) -> impl Future<Output = Result<Self>> + Send;
}
impl ResponseErrorMessageExt for reqwest::Response {
async fn error_from_body(self) -> Result<Self> {
let status = self.status();
if !(status.is_client_error() || status.is_server_error()) {
return Ok(self);
}
// reqwest does not export its error construction utility functions, so let's craft the message ourselves
let url = self.url().to_owned();
Err(SafekeeperHttpError::Response(
match self.json::<HttpErrorBody>().await {
Ok(err_body) => format!("Error: {}", err_body.msg),
Err(_) => format!("Http error ({}) at {}.", status.as_u16(), url),
},
))
fn err_from_client_err(err: mgmt_api::Error) -> SafekeeperHttpError {
use mgmt_api::Error::*;
match err {
ApiError(_, str) => SafekeeperHttpError::Response(str),
Cancelled => SafekeeperHttpError::Response("Cancelled".to_owned()),
ReceiveBody(err) => SafekeeperHttpError::Transport(err),
ReceiveErrorBody(err) => SafekeeperHttpError::Response(err),
Timeout(str) => SafekeeperHttpError::Response(format!("timeout: {str}")),
}
}
@@ -70,9 +58,8 @@ pub struct SafekeeperNode {
pub pg_connection_config: PgConnectionConfig,
pub env: LocalEnv,
pub http_client: reqwest::Client,
pub http_client: mgmt_api::Client,
pub listen_addr: String,
pub http_base_url: String,
}
impl SafekeeperNode {
@@ -82,13 +69,14 @@ impl SafekeeperNode {
} else {
"127.0.0.1".to_string()
};
let jwt = None;
let http_base_url = format!("http://{}:{}", listen_addr, conf.http_port);
SafekeeperNode {
id: conf.id,
conf: conf.clone(),
pg_connection_config: Self::safekeeper_connection_config(&listen_addr, conf.pg_port),
env: env.clone(),
http_client: env.create_http_client(),
http_base_url: format!("http://{}:{}/v1", listen_addr, conf.http_port),
http_client: mgmt_api::Client::new(env.create_http_client(), http_base_url, jwt),
listen_addr,
}
}
@@ -278,20 +266,19 @@ impl SafekeeperNode {
)
}
fn http_request<U: IntoUrl>(&self, method: Method, url: U) -> reqwest::RequestBuilder {
// TODO: authentication
//if self.env.auth_type == AuthType::NeonJWT {
// builder = builder.bearer_auth(&self.env.safekeeper_auth_token)
//}
self.http_client.request(method, url)
pub async fn check_status(&self) -> Result<()> {
self.http_client
.status()
.await
.map_err(err_from_client_err)?;
Ok(())
}
pub async fn check_status(&self) -> Result<()> {
self.http_request(Method::GET, format!("{}/{}", self.http_base_url, "status"))
.send()
.await?
.error_from_body()
.await?;
pub async fn create_timeline(&self, req: &TimelineCreateRequest) -> Result<()> {
self.http_client
.create_timeline(req)
.await
.map_err(err_from_client_err)?;
Ok(())
}
}

View File

@@ -61,10 +61,16 @@ enum Command {
#[arg(long)]
scheduling: Option<NodeSchedulingPolicy>,
},
// Set a node status as deleted.
NodeDelete {
#[arg(long)]
node_id: NodeId,
},
/// Delete a tombstone of node from the storage controller.
NodeDeleteTombstone {
#[arg(long)]
node_id: NodeId,
},
/// Modify a tenant's policies in the storage controller
TenantPolicy {
#[arg(long)]
@@ -82,6 +88,8 @@ enum Command {
},
/// List nodes known to the storage controller
Nodes {},
/// List soft deleted nodes known to the storage controller
NodeTombstones {},
/// List tenants known to the storage controller
Tenants {
/// If this field is set, it will list the tenants on a specific node
@@ -900,6 +908,39 @@ async fn main() -> anyhow::Result<()> {
.dispatch::<(), ()>(Method::DELETE, format!("control/v1/node/{node_id}"), None)
.await?;
}
Command::NodeDeleteTombstone { node_id } => {
storcon_client
.dispatch::<(), ()>(
Method::DELETE,
format!("debug/v1/tombstone/{node_id}"),
None,
)
.await?;
}
Command::NodeTombstones {} => {
let mut resp = storcon_client
.dispatch::<(), Vec<NodeDescribeResponse>>(
Method::GET,
"debug/v1/tombstone".to_string(),
None,
)
.await?;
resp.sort_by(|a, b| a.listen_http_addr.cmp(&b.listen_http_addr));
let mut table = comfy_table::Table::new();
table.set_header(["Id", "Hostname", "AZ", "Scheduling", "Availability"]);
for node in resp {
table.add_row([
format!("{}", node.id),
node.listen_http_addr,
node.availability_zone_id,
format!("{:?}", node.scheduling),
format!("{:?}", node.availability),
]);
}
println!("{table}");
}
Command::TenantSetTimeBasedEviction {
tenant_id,
period,

View File

@@ -3,7 +3,8 @@
//! This service is deployed either as a separate component or as part of compute image
//! for large computes.
mod app;
use anyhow::Context;
use anyhow::{Context, bail};
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use tracing::info;
use utils::logging;
@@ -12,9 +13,14 @@ const fn max_upload_file_limit() -> usize {
100 * 1024 * 1024
}
const fn listen() -> SocketAddr {
SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 51243)
}
#[derive(serde::Deserialize)]
#[serde(tag = "type")]
struct Config {
#[serde(default = "listen")]
listen: std::net::SocketAddr,
pemfile: camino::Utf8PathBuf,
#[serde(flatten)]
@@ -31,13 +37,21 @@ async fn main() -> anyhow::Result<()> {
logging::Output::Stdout,
)?;
let config: String = std::env::args().skip(1).take(1).collect();
if config.is_empty() {
anyhow::bail!("Usage: endpoint_storage config.json")
}
info!("Reading config from {config}");
let config = std::fs::read_to_string(config.clone())?;
let config: Config = serde_json::from_str(&config).context("parsing config")?;
// Allow either passing filename or inline config (for k8s helm chart)
let args: Vec<String> = std::env::args().skip(1).collect();
let config: Config = if args.len() == 1 && args[0].ends_with(".json") {
info!("Reading config from {}", args[0]);
let config = std::fs::read_to_string(args[0].clone())?;
serde_json::from_str(&config).context("parsing config")?
} else if !args.is_empty() && args[0].starts_with("--config=") {
info!("Reading inline config");
let config = args.join(" ");
let config = config.strip_prefix("--config=").unwrap();
serde_json::from_str(config).context("parsing config")?
} else {
bail!("Usage: endpoint_storage config.json or endpoint_storage --config=JSON");
};
info!("Reading pemfile from {}", config.pemfile.clone());
let pemfile = std::fs::read(config.pemfile.clone())?;
info!("Loading public key from {}", config.pemfile.clone());

View File

@@ -6,20 +6,8 @@ license.workspace = true
[dependencies]
thiserror.workspace = true
nix.workspace = true
nix.workspace=true
workspace_hack = { version = "0.1", path = "../../workspace_hack" }
rustc-hash = { version = "2.1.1" }
[dev-dependencies]
criterion = { workspace = true, features = ["html_reports"] }
rand = "0.9.1"
rand_distr = "0.5.1"
xxhash-rust = { version = "0.8.15", features = ["xxh3"] }
ahash.workspace = true
[target.'cfg(target_os = "macos")'.dependencies]
tempfile = "3.14.0"
[[bench]]
name = "hmap_resize"
harness = false

View File

@@ -1,438 +0,0 @@
//! Hash table implementation on top of 'shmem'
//!
//! Features required in the long run by the communicator project:
//!
//! [X] Accessible from both Postgres processes and rust threads in the communicator process
//! [X] Low latency
//! [ ] Scalable to lots of concurrent accesses (currently relies on caller for locking)
//! [ ] Resizable
use std::fmt::Debug;
use std::hash::{Hash, Hasher, BuildHasher};
use std::mem::MaybeUninit;
use rustc_hash::FxBuildHasher;
use crate::shmem::ShmemHandle;
mod core;
pub mod entry;
#[cfg(test)]
mod tests;
mod optim;
use core::{CoreHashMap, INVALID_POS};
use entry::{Entry, OccupiedEntry};
pub struct HashMapInit<'a, K, V, S = rustc_hash::FxBuildHasher> {
// Hash table can be allocated in a fixed memory area, or in a resizeable ShmemHandle.
shmem_handle: Option<ShmemHandle>,
shared_ptr: *mut HashMapShared<'a, K, V>,
shared_size: usize,
hasher: S,
num_buckets: u32,
}
pub struct HashMapAccess<'a, K, V, S = rustc_hash::FxBuildHasher> {
shmem_handle: Option<ShmemHandle>,
shared_ptr: *mut HashMapShared<'a, K, V>,
hasher: S,
}
unsafe impl<'a, K: Sync, V: Sync, S> Sync for HashMapAccess<'a, K, V, S> {}
unsafe impl<'a, K: Send, V: Send, S> Send for HashMapAccess<'a, K, V, S> {}
impl<'a, K: Clone + Hash + Eq, V, S> HashMapInit<'a, K, V, S> {
pub fn with_hasher(self, hasher: S) -> HashMapInit<'a, K, V, S> {
Self { hasher, ..self }
}
pub fn estimate_size(num_buckets: u32) -> usize {
// add some margin to cover alignment etc.
CoreHashMap::<K, V>::estimate_size(num_buckets) + size_of::<HashMapShared<K, V>>() + 1000
}
pub fn attach_writer(self) -> HashMapAccess<'a, K, V, S> {
let mut ptr: *mut u8 = self.shared_ptr.cast();
let end_ptr: *mut u8 = unsafe { ptr.add(self.shared_size) };
ptr = unsafe { ptr.add(ptr.align_offset(align_of::<HashMapShared<K, V>>())) };
let shared_ptr: *mut HashMapShared<K, V> = ptr.cast();
ptr = unsafe { ptr.add(size_of::<HashMapShared<K, V>>()) };
// carve out the buckets
ptr = unsafe { ptr.byte_add(ptr.align_offset(align_of::<core::LinkedKey<K>>())) };
let keys_ptr = ptr;
ptr = unsafe { ptr.add(size_of::<core::LinkedKey<K>>() * self.num_buckets as usize) };
ptr = unsafe { ptr.byte_add(ptr.align_offset(align_of::<Option<V>>())) };
let vals_ptr = ptr;
ptr = unsafe { ptr.add(size_of::<Option<V>>() * self.num_buckets as usize) };
// use remaining space for the dictionary
ptr = unsafe { ptr.byte_add(ptr.align_offset(align_of::<u32>())) };
assert!(ptr.addr() < end_ptr.addr());
let dictionary_ptr = ptr;
let dictionary_size = unsafe { end_ptr.byte_offset_from(ptr) / size_of::<u32>() as isize };
assert!(dictionary_size > 0);
let keys =
unsafe { std::slice::from_raw_parts_mut(keys_ptr.cast(), self.num_buckets as usize) };
let vals =
unsafe { std::slice::from_raw_parts_mut(vals_ptr.cast(), self.num_buckets as usize) };
let dictionary = unsafe {
std::slice::from_raw_parts_mut(dictionary_ptr.cast(), dictionary_size as usize)
};
let hashmap = CoreHashMap::new(keys, vals, dictionary);
unsafe {
std::ptr::write(shared_ptr, HashMapShared { inner: hashmap });
}
HashMapAccess {
shmem_handle: self.shmem_handle,
shared_ptr: self.shared_ptr,
hasher: self.hasher,
}
}
pub fn attach_reader(self) -> HashMapAccess<'a, K, V, S> {
// no difference to attach_writer currently
self.attach_writer()
}
}
/// This is stored in the shared memory area
///
/// NOTE: We carve out the parts from a contiguous chunk. Growing and shrinking the hash table
/// relies on the memory layout! The data structures are laid out in the contiguous shared memory
/// area as follows:
///
/// HashMapShared
/// [buckets]
/// [dictionary]
///
/// In between the above parts, there can be padding bytes to align the parts correctly.
struct HashMapShared<'a, K, V> {
inner: CoreHashMap<'a, K, V>
}
impl<'a, K, V> HashMapInit<'a, K, V, rustc_hash::FxBuildHasher>
where
K: Clone + Hash + Eq
{
pub fn with_fixed(
num_buckets: u32,
area: &'a mut [MaybeUninit<u8>],
) -> HashMapInit<'a, K, V> {
Self {
num_buckets,
shmem_handle: None,
shared_ptr: area.as_mut_ptr().cast(),
shared_size: area.len(),
hasher: rustc_hash::FxBuildHasher::default(),
}
}
/// Initialize a new hash map in the given shared memory area
pub fn with_shmem(num_buckets: u32, shmem: ShmemHandle) -> HashMapInit<'a, K, V> {
let size = Self::estimate_size(num_buckets);
shmem
.set_size(size)
.expect("could not resize shared memory area");
Self {
num_buckets,
shared_ptr: shmem.data_ptr.as_ptr().cast(),
shmem_handle: Some(shmem),
shared_size: size,
hasher: rustc_hash::FxBuildHasher::default()
}
}
pub fn new_resizeable_named(num_buckets: u32, max_buckets: u32, name: &str) -> HashMapInit<'a, K, V> {
let size = Self::estimate_size(num_buckets);
let max_size = Self::estimate_size(max_buckets);
let shmem = ShmemHandle::new(name, size, max_size)
.expect("failed to make shared memory area");
Self {
num_buckets,
shared_ptr: shmem.data_ptr.as_ptr().cast(),
shmem_handle: Some(shmem),
shared_size: size,
hasher: rustc_hash::FxBuildHasher::default()
}
}
pub fn new_resizeable(num_buckets: u32, max_buckets: u32) -> HashMapInit<'a, K, V> {
use std::sync::atomic::{AtomicUsize, Ordering};
const COUNTER: AtomicUsize = AtomicUsize::new(0);
let val = COUNTER.fetch_add(1, Ordering::Relaxed);
let name = format!("neon_shmem_hmap{}", val);
Self::new_resizeable_named(num_buckets, max_buckets, &name)
}
}
impl<'a, K, V, S: BuildHasher> HashMapAccess<'a, K, V, S>
where
K: Clone + Hash + Eq,
{
pub fn get_hash_value(&self, key: &K) -> u64 {
self.hasher.hash_one(key)
}
pub fn get_with_hash<'e>(&'e self, key: &K, hash: u64) -> Option<&'e V> {
let map = unsafe { self.shared_ptr.as_ref() }.unwrap();
map.inner.get_with_hash(key, hash)
}
pub fn entry_with_hash(&mut self, key: K, hash: u64) -> Entry<'a, '_, K, V> {
let map = unsafe { self.shared_ptr.as_mut() }.unwrap();
map.inner.entry_with_hash(key, hash)
}
pub fn remove_with_hash(&mut self, key: &K, hash: u64) {
let map = unsafe { self.shared_ptr.as_mut() }.unwrap();
match map.inner.entry_with_hash(key.clone(), hash) {
Entry::Occupied(e) => {
e.remove();
}
Entry::Vacant(_) => {}
};
}
pub fn entry_at_bucket(&mut self, pos: usize) -> Option<OccupiedEntry<'a, '_, K, V>> {
let map = unsafe { self.shared_ptr.as_mut() }.unwrap();
map.inner.entry_at_bucket(pos)
}
pub fn get_num_buckets(&self) -> usize {
let map = unsafe { self.shared_ptr.as_ref() }.unwrap();
map.inner.get_num_buckets()
}
/// Return the key and value stored in bucket with given index. This can be used to
/// iterate through the hash map. (An Iterator might be nicer. The communicator's
/// clock algorithm needs to _slowly_ iterate through all buckets with its clock hand,
/// without holding a lock. If we switch to an Iterator, it must not hold the lock.)
pub fn get_at_bucket(&self, pos: usize) -> Option<(&K, &V)> {
let map = unsafe { self.shared_ptr.as_ref() }.unwrap();
if pos >= map.inner.keys.len() {
return None;
}
let key = &map.inner.keys[pos];
key.inner.as_ref().map(|k| (k, map.inner.vals[pos].as_ref().unwrap()))
}
pub fn get_bucket_for_value(&self, val_ptr: *const V) -> usize {
let map = unsafe { self.shared_ptr.as_ref() }.unwrap();
let origin = map.inner.vals.as_ptr();
let idx = (val_ptr as usize - origin as usize) / (size_of::<V>() as usize);
assert!(idx < map.inner.vals.len());
idx
}
// for metrics
pub fn get_num_buckets_in_use(&self) -> usize {
let map = unsafe { self.shared_ptr.as_ref() }.unwrap();
map.inner.buckets_in_use as usize
}
pub fn clear(&mut self) {
let map = unsafe { self.shared_ptr.as_mut() }.unwrap();
let inner = &mut map.inner;
inner.clear()
}
/// Helper function that abstracts the common logic between growing and shrinking.
/// The only significant difference in the rehashing step is how many buckets to rehash.
fn rehash_dict(
&mut self,
inner: &mut CoreHashMap<'a, K, V>,
keys_ptr: *mut core::LinkedKey<K>,
end_ptr: *mut u8,
num_buckets: u32,
rehash_buckets: u32,
) {
inner.free_head = INVALID_POS;
// Recalculate the dictionary
let keys;
let dictionary;
unsafe {
let keys_end_ptr = keys_ptr.add(num_buckets as usize);
let buckets_end_ptr: *mut u8 = (keys_end_ptr as *mut u8)
.add(size_of::<Option<V>>() * num_buckets as usize);
let dictionary_ptr: *mut u32 = buckets_end_ptr
.byte_add(buckets_end_ptr.align_offset(align_of::<u32>()))
.cast();
let dictionary_size: usize =
end_ptr.byte_offset_from(buckets_end_ptr) as usize / size_of::<u32>();
keys = std::slice::from_raw_parts_mut(keys_ptr, num_buckets as usize);
dictionary = std::slice::from_raw_parts_mut(dictionary_ptr, dictionary_size);
}
for i in 0..dictionary.len() {
dictionary[i] = INVALID_POS;
}
for i in 0..rehash_buckets as usize {
if keys[i].inner.is_none() {
keys[i].next = inner.free_head;
inner.free_head = i as u32;
continue;
}
let hash = self.hasher.hash_one(&keys[i].inner.as_ref().unwrap());
let pos: usize = (hash % dictionary.len() as u64) as usize;
keys[i].next = dictionary[pos];
dictionary[pos] = i as u32;
}
// Finally, update the CoreHashMap struct
inner.dictionary = dictionary;
inner.keys = keys;
}
/// Rehash the map. Intended for benchmarking only.
pub fn shuffle(&mut self) {
let map = unsafe { self.shared_ptr.as_mut() }.unwrap();
let inner = &mut map.inner;
let num_buckets = inner.get_num_buckets() as u32;
let size_bytes = HashMapInit::<K, V, S>::estimate_size(num_buckets);
let end_ptr: *mut u8 = unsafe { (self.shared_ptr as *mut u8).add(size_bytes) };
let keys_ptr = inner.keys.as_mut_ptr();
self.rehash_dict(inner, keys_ptr, end_ptr, num_buckets, num_buckets);
}
// /// Grow
// ///
// /// 1. grow the underlying shared memory area
// /// 2. Initialize new buckets. This overwrites the current dictionary
// /// 3. Recalculate the dictionary
// pub fn grow(&mut self, num_buckets: u32) -> Result<(), crate::shmem::Error> {
// let map = unsafe { self.shared_ptr.as_mut() }.unwrap();
// let inner = &mut map.inner;
// let old_num_buckets = inner.buckets.len() as u32;
// if num_buckets < old_num_buckets {
// panic!("grow called with a smaller number of buckets");
// }
// if num_buckets == old_num_buckets {
// return Ok(());
// }
// let shmem_handle = self
// .shmem_handle
// .as_ref()
// .expect("grow called on a fixed-size hash table");
// let size_bytes = HashMapInit::<K, V, S>::estimate_size(num_buckets);
// shmem_handle.set_size(size_bytes)?;
// let end_ptr: *mut u8 = unsafe { shmem_handle.data_ptr.as_ptr().add(size_bytes) };
// // Initialize new buckets. The new buckets are linked to the free list. NB: This overwrites
// // the dictionary!
// let keys_ptr = inner.keys.as_mut_ptr();
// unsafe {
// for i in old_num_buckets..num_buckets {
// let bucket_ptr = buckets_ptr.add(i as usize);
// bucket_ptr.write(core::Bucket {
// next: if i < num_buckets-1 {
// i as u32 + 1
// } else {
// inner.free_head
// },
// prev: if i > 0 {
// PrevPos::Chained(i as u32 - 1)
// } else {
// PrevPos::First(INVALID_POS)
// },
// inner: None,
// });
// }
// }
// self.rehash_dict(inner, keys_ptr, end_ptr, num_buckets, old_num_buckets);
// inner.free_head = old_num_buckets;
// Ok(())
// }
// /// Begin a shrink, limiting all new allocations to be in buckets with index less than `num_buckets`.
// pub fn begin_shrink(&mut self, num_buckets: u32) {
// let map = unsafe { self.shared_ptr.as_mut() }.unwrap();
// if num_buckets > map.inner.get_num_buckets() as u32 {
// panic!("shrink called with a larger number of buckets");
// }
// _ = self
// .shmem_handle
// .as_ref()
// .expect("shrink called on a fixed-size hash table");
// map.inner.alloc_limit = num_buckets;
// }
// /// Complete a shrink after caller has evicted entries, removing the unused buckets and rehashing.
// pub fn finish_shrink(&mut self) -> Result<(), crate::shmem::Error> {
// let map = unsafe { self.shared_ptr.as_mut() }.unwrap();
// let inner = &mut map.inner;
// if !inner.is_shrinking() {
// panic!("called finish_shrink when no shrink is in progress");
// }
// let num_buckets = inner.alloc_limit;
// if inner.get_num_buckets() == num_buckets as usize {
// return Ok(());
// }
// for i in (num_buckets as usize)..inner.buckets.len() {
// if inner.buckets[i].inner.is_some() {
// // TODO(quantumish) Do we want to treat this as a violation of an invariant
// // or a legitimate error the caller can run into? Originally I thought this
// // could return something like a UnevictedError(index) as soon as it runs
// // into something (that way a caller could clear their soon-to-be-shrinked
// // buckets by repeatedly trying to call `finish_shrink`).
// //
// // Would require making a wider error type enum with this and shmem errors.
// panic!("unevicted entries in shrinked space")
// }
// match inner.buckets[i].prev {
// PrevPos::First(_) => {
// let next_pos = inner.buckets[i].next;
// inner.free_head = next_pos;
// if next_pos != INVALID_POS {
// inner.buckets[next_pos as usize].prev = PrevPos::First(INVALID_POS);
// }
// },
// PrevPos::Chained(j) => {
// let next_pos = inner.buckets[i].next;
// inner.buckets[j as usize].next = next_pos;
// if next_pos != INVALID_POS {
// inner.buckets[next_pos as usize].prev = PrevPos::Chained(j);
// }
// }
// }
// }
// let shmem_handle = self
// .shmem_handle
// .as_ref()
// .expect("shrink called on a fixed-size hash table");
// let size_bytes = HashMapInit::<K, V, S>::estimate_size(num_buckets);
// shmem_handle.set_size(size_bytes)?;
// let end_ptr: *mut u8 = unsafe { shmem_handle.data_ptr.as_ptr().add(size_bytes) };
// let buckets_ptr = inner.buckets.as_mut_ptr();
// self.rehash_dict(inner, buckets_ptr, end_ptr, num_buckets, num_buckets);
// inner.alloc_limit = INVALID_POS;
// Ok(())
// }
}

View File

@@ -1,247 +0,0 @@
//! Simple hash table with chaining
//!
//! # Resizing
//!
use std::hash::Hash;
use std::mem::MaybeUninit;
use crate::hash::entry::{Entry, OccupiedEntry, PrevPos, VacantEntry};
pub(crate) const INVALID_POS: u32 = u32::MAX;
pub(crate) struct LinkedKey<K> {
pub(crate) inner: Option<K>,
pub(crate) next: u32,
}
pub(crate) struct CoreHashMap<'a, K, V> {
/// Dictionary used to map hashes to bucket indices.
pub(crate) dictionary: &'a mut [u32],
pub(crate) keys: &'a mut [LinkedKey<K>],
pub(crate) vals: &'a mut [Option<V>],
/// Head of the freelist.
pub(crate) free_head: u32,
pub(crate) _user_list_head: u32,
/// Maximum index of a bucket allowed to be allocated. INVALID_POS if no limit.
pub(crate) alloc_limit: u32,
// metrics
pub(crate) buckets_in_use: u32,
}
#[derive(Debug)]
pub struct FullError();
impl<'a, K: Hash + Eq, V> CoreHashMap<'a, K, V>
where
K: Clone + Hash + Eq,
{
const FILL_FACTOR: f32 = 0.60;
pub fn estimate_size(num_buckets: u32) -> usize {
let mut size = 0;
// buckets
size += (size_of::<LinkedKey<K>>() + size_of::<Option<V>>())
* num_buckets as usize;
// dictionary
size += (f32::ceil((size_of::<u32>() * num_buckets as usize) as f32 / Self::FILL_FACTOR))
as usize;
size
}
pub fn new(
keys: &'a mut [MaybeUninit<LinkedKey<K>>],
vals: &'a mut [MaybeUninit<Option<V>>],
dictionary: &'a mut [MaybeUninit<u32>],
) -> CoreHashMap<'a, K, V> {
// Initialize the buckets
for i in 0..keys.len() {
keys[i].write(LinkedKey {
next: if i < keys.len() - 1 {
i as u32 + 1
} else {
INVALID_POS
},
inner: None,
});
}
for i in 0..vals.len() {
vals[i].write(None);
}
// Initialize the dictionary
for i in 0..dictionary.len() {
dictionary[i].write(INVALID_POS);
}
// TODO: use std::slice::assume_init_mut() once it stabilizes
let keys =
unsafe { std::slice::from_raw_parts_mut(keys.as_mut_ptr().cast(), keys.len()) };
let vals =
unsafe { std::slice::from_raw_parts_mut(vals.as_mut_ptr().cast(), vals.len()) };
let dictionary = unsafe {
std::slice::from_raw_parts_mut(dictionary.as_mut_ptr().cast(), dictionary.len())
};
CoreHashMap {
dictionary,
keys,
vals,
free_head: 0,
buckets_in_use: 0,
_user_list_head: INVALID_POS,
alloc_limit: INVALID_POS,
}
}
pub fn get_with_hash(&self, key: &K, hash: u64) -> Option<&V> {
let mut next = self.dictionary[hash as usize % self.dictionary.len()];
loop {
if next == INVALID_POS {
return None;
}
let keylink = &self.keys[next as usize];
let bucket_key = keylink.inner.as_ref().expect("entry is in use");
if bucket_key == key {
return Some(self.vals[next as usize].as_ref().unwrap());
}
next = keylink.next;
}
}
// all updates are done through Entry
pub fn entry_with_hash(&mut self, key: K, hash: u64) -> Entry<'a, '_, K, V> {
let dict_pos = hash as usize % self.dictionary.len();
let first = self.dictionary[dict_pos];
if first == INVALID_POS {
// no existing entry
return Entry::Vacant(VacantEntry {
map: self,
key,
dict_pos: dict_pos as u32,
});
}
let mut prev_pos = PrevPos::First(dict_pos as u32);
let mut next = first;
loop {
let keylink = &mut self.keys[next as usize];
let bucket_key = keylink.inner.as_mut().expect("entry is in use");
if *bucket_key == key {
// found existing entry
return Entry::Occupied(OccupiedEntry {
map: self,
_key: key,
prev_pos,
bucket_pos: next,
});
}
if keylink.next == INVALID_POS {
// No existing entry
return Entry::Vacant(VacantEntry {
map: self,
key,
dict_pos: dict_pos as u32,
});
}
prev_pos = PrevPos::Chained(next);
next = keylink.next;
}
}
pub fn get_num_buckets(&self) -> usize {
self.keys.len()
}
pub fn is_shrinking(&self) -> bool {
self.alloc_limit != INVALID_POS
}
/// Clears all entries from the hashmap.
/// Does not reset any allocation limits, but does clear any entries beyond them.
pub fn clear(&mut self) {
for i in 0..self.keys.len() {
self.keys[i] = LinkedKey {
next: if i < self.keys.len() - 1 {
i as u32 + 1
} else {
INVALID_POS
},
inner: None,
}
}
for i in 0..self.vals.len() {
self.vals[i] = None;
}
for i in 0..self.dictionary.len() {
self.dictionary[i] = INVALID_POS;
}
self.buckets_in_use = 0;
}
pub fn entry_at_bucket(&mut self, pos: usize) -> Option<OccupiedEntry<'a, '_, K, V>> {
if pos >= self.keys.len() {
return None;
}
let entry = self.keys[pos].inner.as_ref();
match entry {
Some(key) => Some(OccupiedEntry {
_key: key.clone(),
bucket_pos: pos as u32,
prev_pos: PrevPos::Unknown,
map: self,
}),
_ => None,
}
}
/// Find the position of an unused bucket via the freelist and initialize it.
pub(crate) fn alloc_bucket(&mut self, key: K, value: V) -> Result<u32, FullError> {
let mut pos = self.free_head;
// Find the first bucket we're *allowed* to use.
let mut prev = PrevPos::First(self.free_head);
while pos != INVALID_POS && pos >= self.alloc_limit {
let keylink = &mut self.keys[pos as usize];
prev = PrevPos::Chained(pos);
pos = keylink.next;
}
if pos == INVALID_POS {
return Err(FullError());
}
// Repair the freelist.
match prev {
PrevPos::First(_) => {
let next_pos = self.keys[pos as usize].next;
self.free_head = next_pos;
}
PrevPos::Chained(p) => if p != INVALID_POS {
let next_pos = self.keys[pos as usize].next;
self.keys[p as usize].next = next_pos;
},
PrevPos::Unknown => unreachable!()
}
// Initialize the bucket.
let keylink = &mut self.keys[pos as usize];
self.buckets_in_use += 1;
keylink.next = INVALID_POS;
keylink.inner = Some(key);
self.vals[pos as usize] = Some(value);
return Ok(pos);
}
}

View File

@@ -1,107 +0,0 @@
//! Like std::collections::hash_map::Entry;
use crate::hash::core::{CoreHashMap, FullError, INVALID_POS};
use std::hash::Hash;
use std::mem;
pub enum Entry<'a, 'b, K, V> {
Occupied(OccupiedEntry<'a, 'b, K, V>),
Vacant(VacantEntry<'a, 'b, K, V>),
}
/// Helper enum representing the previous position within a hashmap chain.
#[derive(Clone, Copy)]
pub(crate) enum PrevPos {
/// Starting index within the dictionary.
First(u32),
/// Regular index within the buckets.
Chained(u32),
/// Unknown - e.g. the associated entry was retrieved by index instead of chain.
Unknown,
}
impl PrevPos {
/// Unwrap an index from a `PrevPos::First`, panicking otherwise.
pub fn unwrap_first(&self) -> u32 {
match self {
Self::First(i) => *i,
_ => panic!("not first entry in chain")
}
}
}
pub struct OccupiedEntry<'a, 'b, K, V> {
pub(crate) map: &'b mut CoreHashMap<'a, K, V>,
/// The key of the occupied entry
pub(crate) _key: K,
/// The index of the previous entry in the chain.
pub(crate) prev_pos: PrevPos,
/// The position of the bucket in the CoreHashMap's buckets array.
pub(crate) bucket_pos: u32,
}
impl<'a, 'b, K, V> OccupiedEntry<'a, 'b, K, V> {
pub fn get(&self) -> &V {
self.map.vals[self.bucket_pos as usize]
.as_ref()
.unwrap()
}
pub fn get_mut(&mut self) -> &mut V {
self.map.vals[self.bucket_pos as usize]
.as_mut()
.unwrap()
}
pub fn insert(&mut self, value: V) -> V {
let bucket = &mut self.map.vals[self.bucket_pos as usize];
// This assumes inner is Some, which it must be for an OccupiedEntry
let old_value = mem::replace(bucket.as_mut().unwrap(), value);
old_value
}
pub fn remove(self) -> V {
// CoreHashMap::remove returns Option<(K, V)>. We know it's Some for an OccupiedEntry.
let keylink = &mut self.map.keys[self.bucket_pos as usize];
// unlink it from the chain
match self.prev_pos {
PrevPos::First(dict_pos) => self.map.dictionary[dict_pos as usize] = keylink.next,
PrevPos::Chained(bucket_pos) => {
self.map.keys[bucket_pos as usize].next = keylink.next
},
PrevPos::Unknown => panic!("can't safely remove entry with unknown previous entry"),
}
// and add it to the freelist
let keylink = &mut self.map.keys[self.bucket_pos as usize];
keylink.inner = None;
keylink.next = self.map.free_head;
let old_value = self.map.vals[self.bucket_pos as usize].take();
self.map.free_head = self.bucket_pos;
self.map.buckets_in_use -= 1;
return old_value.unwrap();
}
}
pub struct VacantEntry<'a, 'b, K, V> {
pub(crate) map: &'b mut CoreHashMap<'a, K, V>,
pub(crate) key: K, // The key to insert
pub(crate) dict_pos: u32,
}
impl<'a, 'b, K: Clone + Hash + Eq, V> VacantEntry<'a, 'b, K, V> {
pub fn insert(self, value: V) -> Result<&'b mut V, FullError> {
let pos = self.map.alloc_bucket(self.key, value)?;
if pos == INVALID_POS {
return Err(FullError());
}
self.map.keys[pos as usize].next = self.map.dictionary[self.dict_pos as usize];
self.map.dictionary[self.dict_pos as usize] = pos;
let result = self.map.vals[pos as usize].as_mut().unwrap();
return Ok(result);
}
}

View File

@@ -1,85 +0,0 @@
//! Adapted from https://github.com/jsnell/parallel-xxhash (TODO: license?)
use core::arch::x86::*;
const PRIME32_1: u32 = 2654435761;
const PRIME32_2: u32 = 2246822519;
const PRIME32_3: u32 = 3266489917;
const PRIME32_4: u32 = 668265263;
const PRIME32_5: u32 = 374761393;
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
fn mm256_rol32<const r: u32>(x: __m256i) -> __m256i {
return _mm256_or_si256(_mm256_slli_epi32(x, r),
_mm256_srli_epi32(x, 32 - r));
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
fn mm256_fmix32(mut h: __m256i) -> __m256i {
h = _mm256_xor_si256(h, _mm256_srli_epi32(h, 15));
h = _mm256_mullo_epi32(h, _mm256_set1_epi32(PRIME32_2));
h = _mm256_xor_si256(h, _mm256_srli_epi32(h, 13));
h = _mm256_mullo_epi32(h, _mm256_set1_epi32(PRIME32_3));
h = _mm256_xor_si256(h, _mm256_srli_epi32(h, 16));
h
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
fn mm256_round(mut seed: __m256i, input: __m256i) -> __m256i {
seed = _mm256_add_epi32(
seed,
_mm256_mullo_epi32(input, _mm256_set1_epi32(PRIME32_2))
);
seed = mm256_rol32::<13>(seed);
seed = _mm256_mullo_epi32(seed, _mm256_set1_epi32(PRIME32_1));
seed
}
/// Computes xxHash for 8 keys of size 4*N bytes in column-major order.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
fn xxhash_many<const N: usize>(keys: *const u32, seed: u32) -> [u32; 8] {
let mut res = [0; 8];
let mut h = _mm256_set1_epi32(seed + PRIME32_5);
if (N >= 4) {
let mut v1 = _mm256_set1_epi32(seed + PRIME32_1 + PRIME32_2);
let mut v2 = _mm256_set1_epi32(seed + PRIME32_2);
let mut v3 = _mm256_set1_epi32(seed);
let mut v4 = _mm256_set1_eip32(seed - PRIME32_1);
let mut i = 0;
while i < (N & !3) {
let k1 = _mm256_loadu_si256(keys.add((i + 0) * 8).cast());
let k2 = _mm256_loadu_si256(keys.add((i + 1) * 8).cast());
let k3 = _mm256_loadu_si256(keys.add((i + 2) * 8).cast());
let k4 = _mm256_loadu_si256(keys.add((i + 3) * 8).cast());
v1 = mm256_round(v1, k1);
v2 = mm256_round(v2, k2);
v3 = mm256_round(v3, k3);
v4 = mm256_round(v4, k4);
i += 4;
}
h = mm256_rol32::<1>(v1) + mm256_rol32::<7>(v2) +
mm256_rol32::<12>(v3) + mm256_rol32::<18>(v4);
}
// Unneeded, keeps bitwise parity with xxhash though.
h = _m256_add_epi32(h, _mm256_set1_eip32(N * 4));
for i in -(N & 3)..0 {
let v = _mm256_loadu_si256(keys.add((N + i) * 8));
h = _mm256_add_epi32(
h,
_mm256_mullo_epi32(v, _mm256_set1_epi32(PRIME32_3))
);
h = _mm256_mullo_epi32(
mm256_rol32::<17>(h),
_mm256_set1_epi32(PRIME32_4)
);
}
_mm256_storeu_si256((&mut res as *mut _).cast(), mm256_fmix32(h));
res
}

View File

@@ -1,382 +0,0 @@
use std::collections::BTreeMap;
use std::collections::HashSet;
use std::fmt::{Debug, Formatter};
use std::mem::uninitialized;
use std::mem::MaybeUninit;
use std::sync::atomic::{AtomicUsize, Ordering};
use crate::hash::HashMapAccess;
use crate::hash::HashMapInit;
use crate::hash::Entry;
use crate::shmem::ShmemHandle;
use rand::seq::SliceRandom;
use rand::{Rng, RngCore};
use rand_distr::Zipf;
const TEST_KEY_LEN: usize = 16;
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
struct TestKey([u8; TEST_KEY_LEN]);
impl From<&TestKey> for u128 {
fn from(val: &TestKey) -> u128 {
u128::from_be_bytes(val.0)
}
}
impl From<u128> for TestKey {
fn from(val: u128) -> TestKey {
TestKey(val.to_be_bytes())
}
}
impl<'a> From<&'a [u8]> for TestKey {
fn from(bytes: &'a [u8]) -> TestKey {
TestKey(bytes.try_into().unwrap())
}
}
fn test_inserts<K: Into<TestKey> + Copy>(keys: &[K]) {
let mut w = HashMapInit::<TestKey, usize>::new_resizeable_named(
100000, 120000, "test_inserts"
).attach_writer();
for (idx, k) in keys.iter().enumerate() {
let hash = w.get_hash_value(&(*k).into());
let res = w.entry_with_hash((*k).into(), hash);
match res {
Entry::Occupied(mut e) => { e.insert(idx); }
Entry::Vacant(e) => {
let res = e.insert(idx);
assert!(res.is_ok());
},
};
}
for (idx, k) in keys.iter().enumerate() {
let hash = w.get_hash_value(&(*k).into());
let x = w.get_with_hash(&(*k).into(), hash);
let value = x.as_deref().copied();
assert_eq!(value, Some(idx));
}
}
#[test]
fn dense() {
// This exercises splitting a node with prefix
let keys: &[u128] = &[0, 1, 2, 3, 256];
test_inserts(keys);
// Dense keys
let mut keys: Vec<u128> = (0..10000).collect();
test_inserts(&keys);
// Do the same in random orders
for _ in 1..10 {
keys.shuffle(&mut rand::rng());
test_inserts(&keys);
}
}
#[test]
fn sparse() {
// sparse keys
let mut keys: Vec<TestKey> = Vec::new();
let mut used_keys = HashSet::new();
for _ in 0..10000 {
loop {
let key = rand::random::<u128>();
if used_keys.get(&key).is_some() {
continue;
}
used_keys.insert(key);
keys.push(key.into());
break;
}
}
test_inserts(&keys);
}
struct TestValue(AtomicUsize);
impl TestValue {
fn new(val: usize) -> TestValue {
TestValue(AtomicUsize::new(val))
}
fn load(&self) -> usize {
self.0.load(Ordering::Relaxed)
}
}
impl Clone for TestValue {
fn clone(&self) -> TestValue {
TestValue::new(self.load())
}
}
impl Debug for TestValue {
fn fmt(&self, fmt: &mut Formatter<'_>) -> Result<(), std::fmt::Error> {
write!(fmt, "{:?}", self.load())
}
}
#[derive(Clone, Debug)]
struct TestOp(TestKey, Option<usize>);
fn apply_op(
op: &TestOp,
map: &mut HashMapAccess<TestKey, usize>,
shadow: &mut BTreeMap<TestKey, usize>,
) {
// apply the change to the shadow tree first
let shadow_existing = if let Some(v) = op.1 {
shadow.insert(op.0, v)
} else {
shadow.remove(&op.0)
};
let hash = map.get_hash_value(&op.0);
let entry = map.entry_with_hash(op.0, hash);
let hash_existing = match op.1 {
Some(new) => {
match entry {
Entry::Occupied(mut e) => Some(e.insert(new)),
Entry::Vacant(e) => { e.insert(new).unwrap(); None },
}
},
None => {
match entry {
Entry::Occupied(e) => Some(e.remove()),
Entry::Vacant(_) => None,
}
},
};
assert_eq!(shadow_existing, hash_existing);
}
fn do_random_ops(
num_ops: usize,
size: u32,
del_prob: f64,
writer: &mut HashMapAccess<TestKey, usize>,
shadow: &mut BTreeMap<TestKey, usize>,
rng: &mut rand::rngs::ThreadRng,
) {
for i in 0..num_ops {
let key: TestKey = ((rng.next_u32() % size) as u128).into();
let op = TestOp(key, if rng.random_bool(del_prob) { Some(i) } else { None });
apply_op(&op, writer, shadow);
}
}
fn do_deletes(
num_ops: usize,
writer: &mut HashMapAccess<TestKey, usize>,
shadow: &mut BTreeMap<TestKey, usize>,
) {
for i in 0..num_ops {
let (k, _) = shadow.pop_first().unwrap();
let hash = writer.get_hash_value(&k);
writer.remove_with_hash(&k, hash);
}
}
fn do_shrink(
writer: &mut HashMapAccess<TestKey, usize>,
shadow: &mut BTreeMap<TestKey, usize>,
from: u32,
to: u32
) {
writer.begin_shrink(to);
while writer.get_num_buckets_in_use() > to as usize {
let (k, _) = shadow.pop_first().unwrap();
let hash = writer.get_hash_value(&k);
let entry = writer.entry_with_hash(k, hash);
if let Entry::Occupied(mut e) = entry {
e.remove();
}
}
writer.finish_shrink().unwrap();
}
#[test]
fn random_ops() {
let mut writer = HashMapInit::<TestKey, usize>::new_resizeable_named(
100000, 120000, "test_random"
).attach_writer();
let mut shadow: std::collections::BTreeMap<TestKey, usize> = BTreeMap::new();
let distribution = Zipf::new(u128::MAX as f64, 1.1).unwrap();
let mut rng = rand::rng();
for i in 0..100000 {
let key: TestKey = (rng.sample(distribution) as u128).into();
let op = TestOp(key, if rng.random_bool(0.75) { Some(i) } else { None });
apply_op(&op, &mut writer, &mut shadow);
if i % 1000 == 0 {
eprintln!("{i} ops processed");
}
}
}
#[test]
fn test_shuffle() {
let mut writer = HashMapInit::<TestKey, usize>::new_resizeable_named(
1000, 1200, "test_shuf"
).attach_writer();
let mut shadow: std::collections::BTreeMap<TestKey, usize> = BTreeMap::new();
let mut rng = rand::rng();
do_random_ops(10000, 1000, 0.75, &mut writer, &mut shadow, &mut rng);
writer.shuffle();
do_random_ops(10000, 1000, 0.75, &mut writer, &mut shadow, &mut rng);
}
#[test]
fn test_grow() {
let mut writer = HashMapInit::<TestKey, usize>::new_resizeable_named(
1000, 2000, "test_grow"
).attach_writer();
let mut shadow: std::collections::BTreeMap<TestKey, usize> = BTreeMap::new();
let mut rng = rand::rng();
do_random_ops(10000, 1000, 0.75, &mut writer, &mut shadow, &mut rng);
writer.grow(1500).unwrap();
do_random_ops(10000, 1500, 0.75, &mut writer, &mut shadow, &mut rng);
}
#[test]
fn test_shrink() {
let mut writer = HashMapInit::<TestKey, usize>::new_resizeable_named(
1500, 2000, "test_shrink"
).attach_writer();
let mut shadow: std::collections::BTreeMap<TestKey, usize> = BTreeMap::new();
let mut rng = rand::rng();
do_random_ops(10000, 1500, 0.75, &mut writer, &mut shadow, &mut rng);
do_shrink(&mut writer, &mut shadow, 1500, 1000);
do_deletes(500, &mut writer, &mut shadow);
do_random_ops(10000, 500, 0.75, &mut writer, &mut shadow, &mut rng);
assert!(writer.get_num_buckets_in_use() <= 1000);
}
#[test]
fn test_shrink_grow_seq() {
let mut writer = HashMapInit::<TestKey, usize>::new_resizeable_named(
1000, 20000, "test_grow_seq"
).attach_writer();
let mut shadow: std::collections::BTreeMap<TestKey, usize> = BTreeMap::new();
let mut rng = rand::rng();
do_random_ops(500, 1000, 0.1, &mut writer, &mut shadow, &mut rng);
eprintln!("Shrinking to 750");
do_shrink(&mut writer, &mut shadow, 1000, 750);
do_random_ops(200, 1000, 0.5, &mut writer, &mut shadow, &mut rng);
eprintln!("Growing to 1500");
writer.grow(1500).unwrap();
do_random_ops(600, 1500, 0.1, &mut writer, &mut shadow, &mut rng);
eprintln!("Shrinking to 200");
do_shrink(&mut writer, &mut shadow, 1500, 200);
do_deletes(100, &mut writer, &mut shadow);
do_random_ops(50, 1500, 0.25, &mut writer, &mut shadow, &mut rng);
eprintln!("Growing to 10k");
writer.grow(10000).unwrap();
do_random_ops(10000, 5000, 0.25, &mut writer, &mut shadow, &mut rng);
}
#[test]
fn test_bucket_ops() {
let mut writer = HashMapInit::<TestKey, usize>::new_resizeable_named(
1000, 1200, "test_bucket_ops"
).attach_writer();
let hash = writer.get_hash_value(&1.into());
match writer.entry_with_hash(1.into(), hash) {
Entry::Occupied(mut e) => { e.insert(2); },
Entry::Vacant(e) => { e.insert(2).unwrap(); },
}
assert_eq!(writer.get_num_buckets_in_use(), 1);
assert_eq!(writer.get_num_buckets(), 1000);
assert_eq!(writer.get_with_hash(&1.into(), hash), Some(&2));
match writer.entry_with_hash(1.into(), hash) {
Entry::Occupied(e) => {
assert_eq!(e._key, 1.into());
let pos = e.bucket_pos as usize;
assert_eq!(writer.entry_at_bucket(pos).unwrap()._key, 1.into());
assert_eq!(writer.get_at_bucket(pos), Some(&(1.into(), 2)));
},
Entry::Vacant(_) => { panic!("Insert didn't affect entry"); },
}
writer.remove_with_hash(&1.into(), hash);
assert_eq!(writer.get_with_hash(&1.into(), hash), None);
}
#[test]
fn test_shrink_zero() {
let mut writer = HashMapInit::<TestKey, usize>::new_resizeable_named(
1500, 2000, "test_shrink_zero"
).attach_writer();
writer.begin_shrink(0);
for i in 0..1500 {
writer.entry_at_bucket(i).map(|x| x.remove());
}
writer.finish_shrink().unwrap();
assert_eq!(writer.get_num_buckets_in_use(), 0);
let hash = writer.get_hash_value(&1.into());
let entry = writer.entry_with_hash(1.into(), hash);
if let Entry::Vacant(v) = entry {
assert!(v.insert(2).is_err());
} else {
panic!("Somehow got non-vacant entry in empty map.")
}
writer.grow(50).unwrap();
let entry = writer.entry_with_hash(1.into(), hash);
if let Entry::Vacant(v) = entry {
assert!(v.insert(2).is_ok());
} else {
panic!("Somehow got non-vacant entry in empty map.")
}
assert_eq!(writer.get_num_buckets_in_use(), 1);
}
#[test]
#[should_panic]
fn test_grow_oom() {
let mut writer = HashMapInit::<TestKey, usize>::new_resizeable_named(
1500, 2000, "test_grow_oom"
).attach_writer();
writer.grow(20000).unwrap();
}
#[test]
#[should_panic]
fn test_shrink_bigger() {
let mut writer = HashMapInit::<TestKey, usize>::new_resizeable_named(
1500, 2500, "test_shrink_bigger"
).attach_writer();
writer.begin_shrink(2000);
}
#[test]
#[should_panic]
fn test_shrink_early_finish() {
let mut writer = HashMapInit::<TestKey, usize>::new_resizeable_named(
1500, 2500, "test_shrink_early_finish"
).attach_writer();
writer.finish_shrink().unwrap();
}
#[test]
#[should_panic]
fn test_shrink_fixed_size() {
let mut area = [MaybeUninit::uninit(); 10000];
let init_struct = HashMapInit::<TestKey, usize>::with_fixed(3, &mut area);
let mut writer = init_struct.attach_writer();
writer.begin_shrink(1);
}

View File

@@ -1,4 +1,418 @@
//! Shared memory utilities for neon communicator
pub mod hash;
pub mod shmem;
use std::num::NonZeroUsize;
use std::os::fd::{AsFd, BorrowedFd, OwnedFd};
use std::ptr::NonNull;
use std::sync::atomic::{AtomicUsize, Ordering};
use nix::errno::Errno;
use nix::sys::mman::MapFlags;
use nix::sys::mman::ProtFlags;
use nix::sys::mman::mmap as nix_mmap;
use nix::sys::mman::munmap as nix_munmap;
use nix::unistd::ftruncate as nix_ftruncate;
/// ShmemHandle represents a shared memory area that can be shared by processes over fork().
/// Unlike shared memory allocated by Postgres, this area is resizable, up to 'max_size' that's
/// specified at creation.
///
/// The area is backed by an anonymous file created with memfd_create(). The full address space for
/// 'max_size' is reserved up-front with mmap(), but whenever you call [`ShmemHandle::set_size`],
/// the underlying file is resized. Do not access the area beyond the current size. Currently, that
/// will cause the file to be expanded, but we might use mprotect() etc. to enforce that in the
/// future.
pub struct ShmemHandle {
/// memfd file descriptor
fd: OwnedFd,
max_size: usize,
// Pointer to the beginning of the shared memory area. The header is stored there.
shared_ptr: NonNull<SharedStruct>,
// Pointer to the beginning of the user data
pub data_ptr: NonNull<u8>,
}
/// This is stored at the beginning in the shared memory area.
struct SharedStruct {
max_size: usize,
/// Current size of the backing file. The high-order bit is used for the RESIZE_IN_PROGRESS flag
current_size: AtomicUsize,
}
const RESIZE_IN_PROGRESS: usize = 1 << 63;
const HEADER_SIZE: usize = std::mem::size_of::<SharedStruct>();
/// Error type returned by the ShmemHandle functions.
#[derive(thiserror::Error, Debug)]
#[error("{msg}: {errno}")]
pub struct Error {
pub msg: String,
pub errno: Errno,
}
impl Error {
fn new(msg: &str, errno: Errno) -> Error {
Error {
msg: msg.to_string(),
errno,
}
}
}
impl ShmemHandle {
/// Create a new shared memory area. To communicate between processes, the processes need to be
/// fork()'d after calling this, so that the ShmemHandle is inherited by all processes.
///
/// If the ShmemHandle is dropped, the memory is unmapped from the current process. Other
/// processes can continue using it, however.
pub fn new(name: &str, initial_size: usize, max_size: usize) -> Result<ShmemHandle, Error> {
// create the backing anonymous file.
let fd = create_backing_file(name)?;
Self::new_with_fd(fd, initial_size, max_size)
}
fn new_with_fd(
fd: OwnedFd,
initial_size: usize,
max_size: usize,
) -> Result<ShmemHandle, Error> {
// We reserve the high-order bit for the RESIZE_IN_PROGRESS flag, and the actual size
// is a little larger than this because of the SharedStruct header. Make the upper limit
// somewhat smaller than that, because with anything close to that, you'll run out of
// memory anyway.
if max_size >= 1 << 48 {
panic!("max size {} too large", max_size);
}
if initial_size > max_size {
panic!("initial size {initial_size} larger than max size {max_size}");
}
// The actual initial / max size is the one given by the caller, plus the size of
// 'SharedStruct'.
let initial_size = HEADER_SIZE + initial_size;
let max_size = NonZeroUsize::new(HEADER_SIZE + max_size).unwrap();
// Reserve address space for it with mmap
//
// TODO: Use MAP_HUGETLB if possible
let start_ptr = unsafe {
nix_mmap(
None,
max_size,
ProtFlags::PROT_READ | ProtFlags::PROT_WRITE,
MapFlags::MAP_SHARED,
&fd,
0,
)
}
.map_err(|e| Error::new("mmap failed: {e}", e))?;
// Reserve space for the initial size
enlarge_file(fd.as_fd(), initial_size as u64)?;
// Initialize the header
let shared: NonNull<SharedStruct> = start_ptr.cast();
unsafe {
shared.write(SharedStruct {
max_size: max_size.into(),
current_size: AtomicUsize::new(initial_size),
})
};
// The user data begins after the header
let data_ptr = unsafe { start_ptr.cast().add(HEADER_SIZE) };
Ok(ShmemHandle {
fd,
max_size: max_size.into(),
shared_ptr: shared,
data_ptr,
})
}
// return reference to the header
fn shared(&self) -> &SharedStruct {
unsafe { self.shared_ptr.as_ref() }
}
/// Resize the shared memory area. 'new_size' must not be larger than the 'max_size' specified
/// when creating the area.
///
/// This may only be called from one process/thread concurrently. We detect that case
/// and return an Error.
pub fn set_size(&self, new_size: usize) -> Result<(), Error> {
let new_size = new_size + HEADER_SIZE;
let shared = self.shared();
if new_size > self.max_size {
panic!(
"new size ({} is greater than max size ({})",
new_size, self.max_size
);
}
assert_eq!(self.max_size, shared.max_size);
// Lock the area by setting the bit in 'current_size'
//
// Ordering::Relaxed would probably be sufficient here, as we don't access any other memory
// and the posix_fallocate/ftruncate call is surely a synchronization point anyway. But
// since this is not performance-critical, better safe than sorry .
let mut old_size = shared.current_size.load(Ordering::Acquire);
loop {
if (old_size & RESIZE_IN_PROGRESS) != 0 {
return Err(Error::new(
"concurrent resize detected",
Errno::UnknownErrno,
));
}
match shared.current_size.compare_exchange(
old_size,
new_size,
Ordering::Acquire,
Ordering::Relaxed,
) {
Ok(_) => break,
Err(x) => old_size = x,
}
}
// Ok, we got the lock.
//
// NB: If anything goes wrong, we *must* clear the bit!
let result = {
use std::cmp::Ordering::{Equal, Greater, Less};
match new_size.cmp(&old_size) {
Less => nix_ftruncate(&self.fd, new_size as i64).map_err(|e| {
Error::new("could not shrink shmem segment, ftruncate failed: {e}", e)
}),
Equal => Ok(()),
Greater => enlarge_file(self.fd.as_fd(), new_size as u64),
}
};
// Unlock
shared.current_size.store(
if result.is_ok() { new_size } else { old_size },
Ordering::Release,
);
result
}
/// Returns the current user-visible size of the shared memory segment.
///
/// NOTE: a concurrent set_size() call can change the size at any time. It is the caller's
/// responsibility not to access the area beyond the current size.
pub fn current_size(&self) -> usize {
let total_current_size =
self.shared().current_size.load(Ordering::Relaxed) & !RESIZE_IN_PROGRESS;
total_current_size - HEADER_SIZE
}
}
impl Drop for ShmemHandle {
fn drop(&mut self) {
// SAFETY: The pointer was obtained from mmap() with the given size.
// We unmap the entire region.
let _ = unsafe { nix_munmap(self.shared_ptr.cast(), self.max_size) };
// The fd is dropped automatically by OwnedFd.
}
}
/// Create a "backing file" for the shared memory area. On Linux, use memfd_create(), to create an
/// anonymous in-memory file. One macos, fall back to a regular file. That's good enough for
/// development and testing, but in production we want the file to stay in memory.
///
/// disable 'unused_variables' warnings, because in the macos path, 'name' is unused.
#[allow(unused_variables)]
fn create_backing_file(name: &str) -> Result<OwnedFd, Error> {
#[cfg(not(target_os = "macos"))]
{
nix::sys::memfd::memfd_create(name, nix::sys::memfd::MFdFlags::empty())
.map_err(|e| Error::new("memfd_create failed: {e}", e))
}
#[cfg(target_os = "macos")]
{
let file = tempfile::tempfile().map_err(|e| {
Error::new(
"could not create temporary file to back shmem area: {e}",
nix::errno::Errno::from_raw(e.raw_os_error().unwrap_or(0)),
)
})?;
Ok(OwnedFd::from(file))
}
}
fn enlarge_file(fd: BorrowedFd, size: u64) -> Result<(), Error> {
// Use posix_fallocate() to enlarge the file. It reserves the space correctly, so that
// we don't get a segfault later when trying to actually use it.
#[cfg(not(target_os = "macos"))]
{
nix::fcntl::posix_fallocate(fd, 0, size as i64).map_err(|e| {
Error::new(
"could not grow shmem segment, posix_fallocate failed: {e}",
e,
)
})
}
// As a fallback on macos, which doesn't have posix_fallocate, use plain 'fallocate'
#[cfg(target_os = "macos")]
{
nix::unistd::ftruncate(fd, size as i64)
.map_err(|e| Error::new("could not grow shmem segment, ftruncate failed: {e}", e))
}
}
#[cfg(test)]
mod tests {
use super::*;
use nix::unistd::ForkResult;
use std::ops::Range;
/// check that all bytes in given range have the expected value.
fn assert_range(ptr: *const u8, expected: u8, range: Range<usize>) {
for i in range {
let b = unsafe { *(ptr.add(i)) };
assert_eq!(expected, b, "unexpected byte at offset {}", i);
}
}
/// Write 'b' to all bytes in the given range
fn write_range(ptr: *mut u8, b: u8, range: Range<usize>) {
unsafe { std::ptr::write_bytes(ptr.add(range.start), b, range.end - range.start) };
}
// simple single-process test of growing and shrinking
#[test]
fn test_shmem_resize() -> Result<(), Error> {
let max_size = 1024 * 1024;
let init_struct = ShmemHandle::new("test_shmem_resize", 0, max_size)?;
assert_eq!(init_struct.current_size(), 0);
// Initial grow
let size1 = 10000;
init_struct.set_size(size1).unwrap();
assert_eq!(init_struct.current_size(), size1);
// Write some data
let data_ptr = init_struct.data_ptr.as_ptr();
write_range(data_ptr, 0xAA, 0..size1);
assert_range(data_ptr, 0xAA, 0..size1);
// Shrink
let size2 = 5000;
init_struct.set_size(size2).unwrap();
assert_eq!(init_struct.current_size(), size2);
// Grow again
let size3 = 20000;
init_struct.set_size(size3).unwrap();
assert_eq!(init_struct.current_size(), size3);
// Try to read it. The area that was shrunk and grown again should read as all zeros now
assert_range(data_ptr, 0xAA, 0..5000);
assert_range(data_ptr, 0, 5000..size1);
// Try to grow beyond max_size
//let size4 = max_size + 1;
//assert!(init_struct.set_size(size4).is_err());
// Dropping init_struct should unmap the memory
drop(init_struct);
Ok(())
}
/// This is used in tests to coordinate between test processes. It's like std::sync::Barrier,
/// but is stored in the shared memory area and works across processes. It's implemented by
/// polling, because e.g. standard rust mutexes are not guaranteed to work across processes.
struct SimpleBarrier {
num_procs: usize,
count: AtomicUsize,
}
impl SimpleBarrier {
unsafe fn init(ptr: *mut SimpleBarrier, num_procs: usize) {
unsafe {
*ptr = SimpleBarrier {
num_procs,
count: AtomicUsize::new(0),
}
}
}
pub fn wait(&self) {
let old = self.count.fetch_add(1, Ordering::Relaxed);
let generation = old / self.num_procs;
let mut current = old + 1;
while current < (generation + 1) * self.num_procs {
std::thread::sleep(std::time::Duration::from_millis(10));
current = self.count.load(Ordering::Relaxed);
}
}
}
#[test]
fn test_multi_process() {
// Initialize
let max_size = 1_000_000_000_000;
let init_struct = ShmemHandle::new("test_multi_process", 0, max_size).unwrap();
let ptr = init_struct.data_ptr.as_ptr();
// Store the SimpleBarrier in the first 1k of the area.
init_struct.set_size(10000).unwrap();
let barrier_ptr: *mut SimpleBarrier = unsafe {
ptr.add(ptr.align_offset(std::mem::align_of::<SimpleBarrier>()))
.cast()
};
unsafe { SimpleBarrier::init(barrier_ptr, 2) };
let barrier = unsafe { barrier_ptr.as_ref().unwrap() };
// Fork another test process. The code after this runs in both processes concurrently.
let fork_result = unsafe { nix::unistd::fork().unwrap() };
// In the parent, fill bytes between 1000..2000. In the child, between 2000..3000
if fork_result.is_parent() {
write_range(ptr, 0xAA, 1000..2000);
} else {
write_range(ptr, 0xBB, 2000..3000);
}
barrier.wait();
// Verify the contents. (in both processes)
assert_range(ptr, 0xAA, 1000..2000);
assert_range(ptr, 0xBB, 2000..3000);
// Grow, from the child this time
let size = 10_000_000;
if !fork_result.is_parent() {
init_struct.set_size(size).unwrap();
}
barrier.wait();
// make some writes at the end
if fork_result.is_parent() {
write_range(ptr, 0xAA, (size - 10)..size);
} else {
write_range(ptr, 0xBB, (size - 20)..(size - 10));
}
barrier.wait();
// Verify the contents. (This runs in both processes)
assert_range(ptr, 0, (size - 1000)..(size - 20));
assert_range(ptr, 0xBB, (size - 20)..(size - 10));
assert_range(ptr, 0xAA, (size - 10)..size);
if let ForkResult::Parent { child } = fork_result {
nix::sys::wait::waitpid(child, None).unwrap();
}
}
}

View File

@@ -1,418 +0,0 @@
//! Dynamically resizable contiguous chunk of shared memory
use std::num::NonZeroUsize;
use std::os::fd::{AsFd, BorrowedFd, OwnedFd};
use std::ptr::NonNull;
use std::sync::atomic::{AtomicUsize, Ordering};
use nix::errno::Errno;
use nix::sys::mman::MapFlags;
use nix::sys::mman::ProtFlags;
use nix::sys::mman::mmap as nix_mmap;
use nix::sys::mman::munmap as nix_munmap;
use nix::unistd::ftruncate as nix_ftruncate;
/// ShmemHandle represents a shared memory area that can be shared by processes over fork().
/// Unlike shared memory allocated by Postgres, this area is resizable, up to 'max_size' that's
/// specified at creation.
///
/// The area is backed by an anonymous file created with memfd_create(). The full address space for
/// 'max_size' is reserved up-front with mmap(), but whenever you call [`ShmemHandle::set_size`],
/// the underlying file is resized. Do not access the area beyond the current size. Currently, that
/// will cause the file to be expanded, but we might use mprotect() etc. to enforce that in the
/// future.
pub struct ShmemHandle {
/// memfd file descriptor
fd: OwnedFd,
max_size: usize,
// Pointer to the beginning of the shared memory area. The header is stored there.
shared_ptr: NonNull<SharedStruct>,
// Pointer to the beginning of the user data
pub data_ptr: NonNull<u8>,
}
/// This is stored at the beginning in the shared memory area.
struct SharedStruct {
max_size: usize,
/// Current size of the backing file. The high-order bit is used for the RESIZE_IN_PROGRESS flag
current_size: AtomicUsize,
}
const RESIZE_IN_PROGRESS: usize = 1 << 63;
const HEADER_SIZE: usize = std::mem::size_of::<SharedStruct>();
/// Error type returned by the ShmemHandle functions.
#[derive(thiserror::Error, Debug)]
#[error("{msg}: {errno}")]
pub struct Error {
pub msg: String,
pub errno: Errno,
}
impl Error {
fn new(msg: &str, errno: Errno) -> Error {
Error {
msg: msg.to_string(),
errno,
}
}
}
impl ShmemHandle {
/// Create a new shared memory area. To communicate between processes, the processes need to be
/// fork()'d after calling this, so that the ShmemHandle is inherited by all processes.
///
/// If the ShmemHandle is dropped, the memory is unmapped from the current process. Other
/// processes can continue using it, however.
pub fn new(name: &str, initial_size: usize, max_size: usize) -> Result<ShmemHandle, Error> {
// create the backing anonymous file.
let fd = create_backing_file(name)?;
Self::new_with_fd(fd, initial_size, max_size)
}
fn new_with_fd(
fd: OwnedFd,
initial_size: usize,
max_size: usize,
) -> Result<ShmemHandle, Error> {
// We reserve the high-order bit for the RESIZE_IN_PROGRESS flag, and the actual size
// is a little larger than this because of the SharedStruct header. Make the upper limit
// somewhat smaller than that, because with anything close to that, you'll run out of
// memory anyway.
if max_size >= 1 << 48 {
panic!("max size {} too large", max_size);
}
if initial_size > max_size {
panic!("initial size {initial_size} larger than max size {max_size}");
}
// The actual initial / max size is the one given by the caller, plus the size of
// 'SharedStruct'.
let initial_size = HEADER_SIZE + initial_size;
let max_size = NonZeroUsize::new(HEADER_SIZE + max_size).unwrap();
// Reserve address space for it with mmap
//
// TODO: Use MAP_HUGETLB if possible
let start_ptr = unsafe {
nix_mmap(
None,
max_size,
ProtFlags::PROT_READ | ProtFlags::PROT_WRITE,
MapFlags::MAP_SHARED,
&fd,
0,
)
}
.map_err(|e| Error::new("mmap failed: {e}", e))?;
// Reserve space for the initial size
enlarge_file(fd.as_fd(), initial_size as u64)?;
// Initialize the header
let shared: NonNull<SharedStruct> = start_ptr.cast();
unsafe {
shared.write(SharedStruct {
max_size: max_size.into(),
current_size: AtomicUsize::new(initial_size),
})
};
// The user data begins after the header
let data_ptr = unsafe { start_ptr.cast().add(HEADER_SIZE) };
Ok(ShmemHandle {
fd,
max_size: max_size.into(),
shared_ptr: shared,
data_ptr,
})
}
// return reference to the header
fn shared(&self) -> &SharedStruct {
unsafe { self.shared_ptr.as_ref() }
}
/// Resize the shared memory area. 'new_size' must not be larger than the 'max_size' specified
/// when creating the area.
///
/// This may only be called from one process/thread concurrently. We detect that case
/// and return an Error.
pub fn set_size(&self, new_size: usize) -> Result<(), Error> {
let new_size = new_size + HEADER_SIZE;
let shared = self.shared();
if new_size > self.max_size {
panic!(
"new size ({} is greater than max size ({})",
new_size, self.max_size
);
}
assert_eq!(self.max_size, shared.max_size);
// Lock the area by setting the bit in 'current_size'
//
// Ordering::Relaxed would probably be sufficient here, as we don't access any other memory
// and the posix_fallocate/ftruncate call is surely a synchronization point anyway. But
// since this is not performance-critical, better safe than sorry .
let mut old_size = shared.current_size.load(Ordering::Acquire);
loop {
if (old_size & RESIZE_IN_PROGRESS) != 0 {
return Err(Error::new(
"concurrent resize detected",
Errno::UnknownErrno,
));
}
match shared.current_size.compare_exchange(
old_size,
new_size,
Ordering::Acquire,
Ordering::Relaxed,
) {
Ok(_) => break,
Err(x) => old_size = x,
}
}
// Ok, we got the lock.
//
// NB: If anything goes wrong, we *must* clear the bit!
let result = {
use std::cmp::Ordering::{Equal, Greater, Less};
match new_size.cmp(&old_size) {
Less => nix_ftruncate(&self.fd, new_size as i64).map_err(|e| {
Error::new("could not shrink shmem segment, ftruncate failed: {e}", e)
}),
Equal => Ok(()),
Greater => enlarge_file(self.fd.as_fd(), new_size as u64),
}
};
// Unlock
shared.current_size.store(
if result.is_ok() { new_size } else { old_size },
Ordering::Release,
);
result
}
/// Returns the current user-visible size of the shared memory segment.
///
/// NOTE: a concurrent set_size() call can change the size at any time. It is the caller's
/// responsibility not to access the area beyond the current size.
pub fn current_size(&self) -> usize {
let total_current_size =
self.shared().current_size.load(Ordering::Relaxed) & !RESIZE_IN_PROGRESS;
total_current_size - HEADER_SIZE
}
}
impl Drop for ShmemHandle {
fn drop(&mut self) {
// SAFETY: The pointer was obtained from mmap() with the given size.
// We unmap the entire region.
let _ = unsafe { nix_munmap(self.shared_ptr.cast(), self.max_size) };
// The fd is dropped automatically by OwnedFd.
}
}
/// Create a "backing file" for the shared memory area. On Linux, use memfd_create(), to create an
/// anonymous in-memory file. One macos, fall back to a regular file. That's good enough for
/// development and testing, but in production we want the file to stay in memory.
///
/// disable 'unused_variables' warnings, because in the macos path, 'name' is unused.
#[allow(unused_variables)]
fn create_backing_file(name: &str) -> Result<OwnedFd, Error> {
#[cfg(not(target_os = "macos"))]
{
nix::sys::memfd::memfd_create(name, nix::sys::memfd::MFdFlags::empty())
.map_err(|e| Error::new("memfd_create failed: {e}", e))
}
#[cfg(target_os = "macos")]
{
let file = tempfile::tempfile().map_err(|e| {
Error::new(
"could not create temporary file to back shmem area: {e}",
nix::errno::Errno::from_raw(e.raw_os_error().unwrap_or(0)),
)
})?;
Ok(OwnedFd::from(file))
}
}
fn enlarge_file(fd: BorrowedFd, size: u64) -> Result<(), Error> {
// Use posix_fallocate() to enlarge the file. It reserves the space correctly, so that
// we don't get a segfault later when trying to actually use it.
#[cfg(not(target_os = "macos"))]
{
nix::fcntl::posix_fallocate(fd, 0, size as i64).map_err(|e| {
Error::new(
"could not grow shmem segment, posix_fallocate failed: {e}",
e,
)
})
}
// As a fallback on macos, which doesn't have posix_fallocate, use plain 'fallocate'
#[cfg(target_os = "macos")]
{
nix::unistd::ftruncate(fd, size as i64)
.map_err(|e| Error::new("could not grow shmem segment, ftruncate failed: {e}", e))
}
}
#[cfg(test)]
mod tests {
use super::*;
use nix::unistd::ForkResult;
use std::ops::Range;
/// check that all bytes in given range have the expected value.
fn assert_range(ptr: *const u8, expected: u8, range: Range<usize>) {
for i in range {
let b = unsafe { *(ptr.add(i)) };
assert_eq!(expected, b, "unexpected byte at offset {}", i);
}
}
/// Write 'b' to all bytes in the given range
fn write_range(ptr: *mut u8, b: u8, range: Range<usize>) {
unsafe { std::ptr::write_bytes(ptr.add(range.start), b, range.end - range.start) };
}
// simple single-process test of growing and shrinking
#[test]
fn test_shmem_resize() -> Result<(), Error> {
let max_size = 1024 * 1024;
let init_struct = ShmemHandle::new("test_shmem_resize", 0, max_size)?;
assert_eq!(init_struct.current_size(), 0);
// Initial grow
let size1 = 10000;
init_struct.set_size(size1).unwrap();
assert_eq!(init_struct.current_size(), size1);
// Write some data
let data_ptr = init_struct.data_ptr.as_ptr();
write_range(data_ptr, 0xAA, 0..size1);
assert_range(data_ptr, 0xAA, 0..size1);
// Shrink
let size2 = 5000;
init_struct.set_size(size2).unwrap();
assert_eq!(init_struct.current_size(), size2);
// Grow again
let size3 = 20000;
init_struct.set_size(size3).unwrap();
assert_eq!(init_struct.current_size(), size3);
// Try to read it. The area that was shrunk and grown again should read as all zeros now
assert_range(data_ptr, 0xAA, 0..5000);
assert_range(data_ptr, 0, 5000..size1);
// Try to grow beyond max_size
//let size4 = max_size + 1;
//assert!(init_struct.set_size(size4).is_err());
// Dropping init_struct should unmap the memory
drop(init_struct);
Ok(())
}
/// This is used in tests to coordinate between test processes. It's like std::sync::Barrier,
/// but is stored in the shared memory area and works across processes. It's implemented by
/// polling, because e.g. standard rust mutexes are not guaranteed to work across processes.
struct SimpleBarrier {
num_procs: usize,
count: AtomicUsize,
}
impl SimpleBarrier {
unsafe fn init(ptr: *mut SimpleBarrier, num_procs: usize) {
unsafe {
*ptr = SimpleBarrier {
num_procs,
count: AtomicUsize::new(0),
}
}
}
pub fn wait(&self) {
let old = self.count.fetch_add(1, Ordering::Relaxed);
let generation = old / self.num_procs;
let mut current = old + 1;
while current < (generation + 1) * self.num_procs {
std::thread::sleep(std::time::Duration::from_millis(10));
current = self.count.load(Ordering::Relaxed);
}
}
}
#[test]
fn test_multi_process() {
// Initialize
let max_size = 1_000_000_000_000;
let init_struct = ShmemHandle::new("test_multi_process", 0, max_size).unwrap();
let ptr = init_struct.data_ptr.as_ptr();
// Store the SimpleBarrier in the first 1k of the area.
init_struct.set_size(10000).unwrap();
let barrier_ptr: *mut SimpleBarrier = unsafe {
ptr.add(ptr.align_offset(std::mem::align_of::<SimpleBarrier>()))
.cast()
};
unsafe { SimpleBarrier::init(barrier_ptr, 2) };
let barrier = unsafe { barrier_ptr.as_ref().unwrap() };
// Fork another test process. The code after this runs in both processes concurrently.
let fork_result = unsafe { nix::unistd::fork().unwrap() };
// In the parent, fill bytes between 1000..2000. In the child, between 2000..3000
if fork_result.is_parent() {
write_range(ptr, 0xAA, 1000..2000);
} else {
write_range(ptr, 0xBB, 2000..3000);
}
barrier.wait();
// Verify the contents. (in both processes)
assert_range(ptr, 0xAA, 1000..2000);
assert_range(ptr, 0xBB, 2000..3000);
// Grow, from the child this time
let size = 10_000_000;
if !fork_result.is_parent() {
init_struct.set_size(size).unwrap();
}
barrier.wait();
// make some writes at the end
if fork_result.is_parent() {
write_range(ptr, 0xAA, (size - 10)..size);
} else {
write_range(ptr, 0xBB, (size - 20)..(size - 10));
}
barrier.wait();
// Verify the contents. (This runs in both processes)
assert_range(ptr, 0, (size - 1000)..(size - 20));
assert_range(ptr, 0xBB, (size - 20)..(size - 10));
assert_range(ptr, 0xAA, (size - 10)..size);
if let ForkResult::Parent { child } = fork_result {
nix::sys::wait::waitpid(child, None).unwrap();
}
}
}

View File

@@ -344,6 +344,35 @@ impl Default for ShardSchedulingPolicy {
}
}
#[derive(Serialize, Deserialize, Clone, Copy, Eq, PartialEq, Debug)]
pub enum NodeLifecycle {
Active,
Deleted,
}
impl FromStr for NodeLifecycle {
type Err = anyhow::Error;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"active" => Ok(Self::Active),
"deleted" => Ok(Self::Deleted),
_ => Err(anyhow::anyhow!("Unknown node lifecycle '{s}'")),
}
}
}
impl From<NodeLifecycle> for String {
fn from(value: NodeLifecycle) -> String {
use NodeLifecycle::*;
match value {
Active => "active",
Deleted => "deleted",
}
.to_string()
}
}
#[derive(Serialize, Deserialize, Clone, Copy, Eq, PartialEq, Debug)]
pub enum NodeSchedulingPolicy {
Active,

View File

@@ -10,7 +10,7 @@ use crate::{Error, cancel_query_raw, connect_socket};
pub(crate) async fn cancel_query<T>(
config: Option<SocketConfig>,
ssl_mode: SslMode,
mut tls: T,
tls: T,
process_id: i32,
secret_key: i32,
) -> Result<(), Error>

View File

@@ -17,7 +17,6 @@ use crate::{Client, Connection, Error};
/// TLS configuration.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[non_exhaustive]
pub enum SslMode {
/// Do not use TLS.
Disable,
@@ -231,7 +230,7 @@ impl Config {
/// Requires the `runtime` Cargo feature (enabled by default).
pub async fn connect<T>(
&self,
tls: T,
tls: &T,
) -> Result<(Client, Connection<TcpStream, T::Stream>), Error>
where
T: MakeTlsConnect<TcpStream>,

View File

@@ -13,7 +13,7 @@ use crate::tls::{MakeTlsConnect, TlsConnect};
use crate::{Client, Config, Connection, Error, RawConnection};
pub async fn connect<T>(
mut tls: T,
tls: &T,
config: &Config,
) -> Result<(Client, Connection<TcpStream, T::Stream>), Error>
where

View File

@@ -47,7 +47,7 @@ pub trait MakeTlsConnect<S> {
/// Creates a new `TlsConnect`or.
///
/// The domain name is provided for certificate verification and SNI.
fn make_tls_connect(&mut self, domain: &str) -> Result<Self::TlsConnect, Self::Error>;
fn make_tls_connect(&self, domain: &str) -> Result<Self::TlsConnect, Self::Error>;
}
/// An asynchronous function wrapping a stream in a TLS session.
@@ -85,7 +85,7 @@ impl<S> MakeTlsConnect<S> for NoTls {
type TlsConnect = NoTls;
type Error = NoTlsError;
fn make_tls_connect(&mut self, _: &str) -> Result<NoTls, NoTlsError> {
fn make_tls_connect(&self, _: &str) -> Result<NoTls, NoTlsError> {
Ok(NoTls)
}
}

View File

@@ -13,7 +13,7 @@ use utils::pageserver_feedback::PageserverFeedback;
use crate::membership::Configuration;
use crate::{ServerInfo, Term};
#[derive(Debug, Serialize)]
#[derive(Debug, Serialize, Deserialize)]
pub struct SafekeeperStatus {
pub id: NodeId,
}

View File

@@ -23,6 +23,7 @@ use pageserver::deletion_queue::DeletionQueue;
use pageserver::disk_usage_eviction_task::{self, launch_disk_usage_global_eviction_task};
use pageserver::feature_resolver::FeatureResolver;
use pageserver::metrics::{STARTUP_DURATION, STARTUP_IS_LOADING};
use pageserver::page_service::GrpcPageServiceHandler;
use pageserver::task_mgr::{
BACKGROUND_RUNTIME, COMPUTE_REQUEST_RUNTIME, MGMT_REQUEST_RUNTIME, WALRECEIVER_RUNTIME,
};
@@ -814,7 +815,7 @@ fn start_pageserver(
// necessary?
let mut page_service_grpc = None;
if let Some(grpc_listener) = grpc_listener {
page_service_grpc = Some(page_service::spawn_grpc(
page_service_grpc = Some(GrpcPageServiceHandler::spawn(
tenant_manager.clone(),
grpc_auth,
otel_guard.as_ref().map(|g| g.dispatch.clone()),

View File

@@ -169,99 +169,6 @@ pub fn spawn(
Listener { cancel, task }
}
/// Spawns a gRPC server for the page service.
///
/// TODO: move this onto GrpcPageServiceHandler::spawn().
/// TODO: this doesn't support TLS. We need TLS reloading via ReloadingCertificateResolver, so we
/// need to reimplement the TCP+TLS accept loop ourselves.
pub fn spawn_grpc(
tenant_manager: Arc<TenantManager>,
auth: Option<Arc<SwappableJwtAuth>>,
perf_trace_dispatch: Option<Dispatch>,
get_vectored_concurrent_io: GetVectoredConcurrentIo,
listener: std::net::TcpListener,
) -> anyhow::Result<CancellableTask> {
let cancel = CancellationToken::new();
let ctx = RequestContextBuilder::new(TaskKind::PageRequestHandler)
.download_behavior(DownloadBehavior::Download)
.perf_span_dispatch(perf_trace_dispatch)
.detached_child();
let gate = Gate::default();
// Set up the TCP socket. We take a preconfigured TcpListener to bind the
// port early during startup.
let incoming = {
let _runtime = COMPUTE_REQUEST_RUNTIME.enter(); // required by TcpListener::from_std
listener.set_nonblocking(true)?;
tonic::transport::server::TcpIncoming::from(tokio::net::TcpListener::from_std(listener)?)
.with_nodelay(Some(GRPC_TCP_NODELAY))
.with_keepalive(Some(GRPC_TCP_KEEPALIVE_TIME))
};
// Set up the gRPC server.
//
// TODO: consider tuning window sizes.
let mut server = tonic::transport::Server::builder()
.http2_keepalive_interval(Some(GRPC_HTTP2_KEEPALIVE_INTERVAL))
.http2_keepalive_timeout(Some(GRPC_HTTP2_KEEPALIVE_TIMEOUT))
.max_concurrent_streams(Some(GRPC_MAX_CONCURRENT_STREAMS));
// Main page service stack. Uses a mix of Tonic interceptors and Tower layers:
//
// * Interceptors: can inspect and modify the gRPC request. Sync code only, runs before service.
//
// * Layers: allow async code, can run code after the service response. However, only has access
// to the raw HTTP request/response, not the gRPC types.
let page_service_handler = GrpcPageServiceHandler {
tenant_manager,
ctx,
gate_guard: gate.enter().expect("gate was just created"),
get_vectored_concurrent_io,
};
let observability_layer = ObservabilityLayer;
let mut tenant_interceptor = TenantMetadataInterceptor;
let mut auth_interceptor = TenantAuthInterceptor::new(auth);
let page_service = tower::ServiceBuilder::new()
// Create tracing span and record request start time.
.layer(observability_layer)
// Intercept gRPC requests.
.layer(tonic::service::InterceptorLayer::new(move |mut req| {
// Extract tenant metadata.
req = tenant_interceptor.call(req)?;
// Authenticate tenant JWT token.
req = auth_interceptor.call(req)?;
Ok(req)
}))
.service(proto::PageServiceServer::new(page_service_handler));
let server = server.add_service(page_service);
// Reflection service for use with e.g. grpcurl.
let reflection_service = tonic_reflection::server::Builder::configure()
.register_encoded_file_descriptor_set(proto::FILE_DESCRIPTOR_SET)
.build_v1()?;
let server = server.add_service(reflection_service);
// Spawn server task.
let task_cancel = cancel.clone();
let task = COMPUTE_REQUEST_RUNTIME.spawn(task_mgr::exit_on_panic_or_error(
"grpc listener",
async move {
let result = server
.serve_with_incoming_shutdown(incoming, task_cancel.cancelled())
.await;
if result.is_ok() {
// TODO: revisit shutdown logic once page service is implemented.
gate.close().await;
}
result
},
));
Ok(CancellableTask { task, cancel })
}
impl Listener {
pub async fn stop_accepting(self) -> Connections {
self.cancel.cancel();
@@ -3366,6 +3273,101 @@ pub struct GrpcPageServiceHandler {
}
impl GrpcPageServiceHandler {
/// Spawns a gRPC server for the page service.
///
/// TODO: this doesn't support TLS. We need TLS reloading via ReloadingCertificateResolver, so we
/// need to reimplement the TCP+TLS accept loop ourselves.
pub fn spawn(
tenant_manager: Arc<TenantManager>,
auth: Option<Arc<SwappableJwtAuth>>,
perf_trace_dispatch: Option<Dispatch>,
get_vectored_concurrent_io: GetVectoredConcurrentIo,
listener: std::net::TcpListener,
) -> anyhow::Result<CancellableTask> {
let cancel = CancellationToken::new();
let ctx = RequestContextBuilder::new(TaskKind::PageRequestHandler)
.download_behavior(DownloadBehavior::Download)
.perf_span_dispatch(perf_trace_dispatch)
.detached_child();
let gate = Gate::default();
// Set up the TCP socket. We take a preconfigured TcpListener to bind the
// port early during startup.
let incoming = {
let _runtime = COMPUTE_REQUEST_RUNTIME.enter(); // required by TcpListener::from_std
listener.set_nonblocking(true)?;
tonic::transport::server::TcpIncoming::from(tokio::net::TcpListener::from_std(
listener,
)?)
.with_nodelay(Some(GRPC_TCP_NODELAY))
.with_keepalive(Some(GRPC_TCP_KEEPALIVE_TIME))
};
// Set up the gRPC server.
//
// TODO: consider tuning window sizes.
let mut server = tonic::transport::Server::builder()
.http2_keepalive_interval(Some(GRPC_HTTP2_KEEPALIVE_INTERVAL))
.http2_keepalive_timeout(Some(GRPC_HTTP2_KEEPALIVE_TIMEOUT))
.max_concurrent_streams(Some(GRPC_MAX_CONCURRENT_STREAMS));
// Main page service stack. Uses a mix of Tonic interceptors and Tower layers:
//
// * Interceptors: can inspect and modify the gRPC request. Sync code only, runs before service.
//
// * Layers: allow async code, can run code after the service response. However, only has access
// to the raw HTTP request/response, not the gRPC types.
let page_service_handler = GrpcPageServiceHandler {
tenant_manager,
ctx,
gate_guard: gate.enter().expect("gate was just created"),
get_vectored_concurrent_io,
};
let observability_layer = ObservabilityLayer;
let mut tenant_interceptor = TenantMetadataInterceptor;
let mut auth_interceptor = TenantAuthInterceptor::new(auth);
let page_service = tower::ServiceBuilder::new()
// Create tracing span and record request start time.
.layer(observability_layer)
// Intercept gRPC requests.
.layer(tonic::service::InterceptorLayer::new(move |mut req| {
// Extract tenant metadata.
req = tenant_interceptor.call(req)?;
// Authenticate tenant JWT token.
req = auth_interceptor.call(req)?;
Ok(req)
}))
// Run the page service.
.service(proto::PageServiceServer::new(page_service_handler));
let server = server.add_service(page_service);
// Reflection service for use with e.g. grpcurl.
let reflection_service = tonic_reflection::server::Builder::configure()
.register_encoded_file_descriptor_set(proto::FILE_DESCRIPTOR_SET)
.build_v1()?;
let server = server.add_service(reflection_service);
// Spawn server task.
let task_cancel = cancel.clone();
let task = COMPUTE_REQUEST_RUNTIME.spawn(task_mgr::exit_on_panic_or_error(
"grpc listener",
async move {
let result = server
.serve_with_incoming_shutdown(incoming, task_cancel.cancelled())
.await;
if result.is_ok() {
// TODO: revisit shutdown logic once page service is implemented.
gate.close().await;
}
result
},
));
Ok(CancellableTask { task, cancel })
}
/// Errors if the request is executed on a non-zero shard. Only shard 0 has a complete view of
/// relations and their sizes, as well as SLRU segments and similar data.
#[allow(clippy::result_large_err)]

View File

@@ -1671,7 +1671,12 @@ impl TenantManager {
}
}
// Phase 5: Shut down the parent shard, and erase it from disk
// Phase 5: Shut down the parent shard. We leave it on disk in case the split fails and we
// have to roll back to the parent shard, avoiding a cold start. It will be cleaned up once
// the storage controller commits the split, or if all else fails, on the next restart.
//
// TODO: We don't flush the ephemeral layer here, because the split is likely to succeed and
// catching up the parent should be reasonably quick. Consider using FreezeAndFlush instead.
let (_guard, progress) = completion::channel();
match parent.shutdown(progress, ShutdownMode::Hard).await {
Ok(()) => {}
@@ -1679,11 +1684,6 @@ impl TenantManager {
other.wait().await;
}
}
let local_tenant_directory = self.conf.tenant_path(&tenant_shard_id);
let tmp_path = safe_rename_tenant_dir(&local_tenant_directory)
.await
.with_context(|| format!("local tenant directory {local_tenant_directory:?} rename"))?;
self.background_purges.spawn(tmp_path);
fail::fail_point!("shard-split-pre-finish", |_| Err(anyhow::anyhow!(
"failpoint"
@@ -1846,42 +1846,70 @@ impl TenantManager {
shutdown_all_tenants0(self.tenants).await
}
/// Detaches a tenant, and removes its local files asynchronously.
///
/// File removal is idempotent: even if the tenant has already been removed, this will still
/// remove any local files. This is used during shard splits, where we leave the parent shard's
/// files around in case we have to roll back the split.
pub(crate) async fn detach_tenant(
&self,
conf: &'static PageServerConf,
tenant_shard_id: TenantShardId,
deletion_queue_client: &DeletionQueueClient,
) -> Result<(), TenantStateError> {
let tmp_path = self
if let Some(tmp_path) = self
.detach_tenant0(conf, tenant_shard_id, deletion_queue_client)
.await?;
self.background_purges.spawn(tmp_path);
.await?
{
self.background_purges.spawn(tmp_path);
}
Ok(())
}
/// Detaches a tenant. This renames the tenant directory to a temporary path and returns it,
/// allowing the caller to delete it asynchronously. Returns None if the dir is already removed.
async fn detach_tenant0(
&self,
conf: &'static PageServerConf,
tenant_shard_id: TenantShardId,
deletion_queue_client: &DeletionQueueClient,
) -> Result<Utf8PathBuf, TenantStateError> {
) -> Result<Option<Utf8PathBuf>, TenantStateError> {
let tenant_dir_rename_operation = |tenant_id_to_clean: TenantShardId| async move {
let local_tenant_directory = conf.tenant_path(&tenant_id_to_clean);
if !tokio::fs::try_exists(&local_tenant_directory).await? {
// If the tenant directory doesn't exist, it's already cleaned up.
return Ok(None);
}
safe_rename_tenant_dir(&local_tenant_directory)
.await
.with_context(|| {
format!("local tenant directory {local_tenant_directory:?} rename")
})
.map(Some)
};
let removal_result = remove_tenant_from_memory(
let mut removal_result = remove_tenant_from_memory(
self.tenants,
tenant_shard_id,
tenant_dir_rename_operation(tenant_shard_id),
)
.await;
// If the tenant was not found, it was likely already removed. Attempt to remove the tenant
// directory on disk anyway. For example, during shard splits, we shut down and remove the
// parent shard, but leave its directory on disk in case we have to roll back the split.
//
// TODO: it would be better to leave the parent shard attached until the split is committed.
// This will be needed by the gRPC page service too, such that a compute can continue to
// read from the parent shard until it's notified about the new child shards. See:
// <https://github.com/neondatabase/neon/issues/11728>.
if let Err(TenantStateError::SlotError(TenantSlotError::NotFound(_))) = removal_result {
removal_result = tenant_dir_rename_operation(tenant_shard_id)
.await
.map_err(TenantStateError::Other);
}
// Flush pending deletions, so that they have a good chance of passing validation
// before this tenant is potentially re-attached elsewhere.
deletion_queue_client.flush_advisory();

View File

@@ -1055,8 +1055,8 @@ pub(crate) enum WaitLsnWaiter<'a> {
/// Argument to [`Timeline::shutdown`].
#[derive(Debug, Clone, Copy)]
pub(crate) enum ShutdownMode {
/// Graceful shutdown, may do a lot of I/O as we flush any open layers to disk and then
/// also to remote storage. This method can easily take multiple seconds for a busy timeline.
/// Graceful shutdown, may do a lot of I/O as we flush any open layers to disk. This method can
/// take multiple seconds for a busy timeline.
///
/// While we are flushing, we continue to accept read I/O for LSNs ingested before
/// the call to [`Timeline::shutdown`].

View File

@@ -1,5 +1,6 @@
# pgxs/neon/Makefile
MODULE_big = neon
OBJS = \
$(WIN32RES) \
@@ -21,8 +22,7 @@ OBJS = \
walproposer.o \
walproposer_pg.o \
control_plane_connector.o \
walsender_hooks.o \
$(LIBCOMMUNICATOR_PATH)/libcommunicator.a
walsender_hooks.o
PG_CPPFLAGS = -I$(libpq_srcdir)
SHLIB_LINK_INTERNAL = $(libpq)

View File

@@ -1,13 +0,0 @@
[package]
name = "communicator"
version = "0.1.0"
edition = "2024"
[lib]
crate-type = ["staticlib"]
[dependencies]
neon-shmem.workspace = true
[build-dependencies]
cbindgen.workspace = true

View File

@@ -1,8 +0,0 @@
This package will evolve into a "compute-pageserver communicator"
process and machinery. For now, it just provides wrappers on the
neon-shmem Rust crate, to allow using it in the C implementation of
the LFC.
At compilation time, pgxn/neon/communicator/ produces a static
library, libcommunicator.a. It is linked to the neon.so extension
library.

View File

@@ -1,22 +0,0 @@
use std::env;
fn main() -> Result<(), Box<dyn std::error::Error>> {
let crate_dir = env::var("CARGO_MANIFEST_DIR").unwrap();
cbindgen::generate(crate_dir).map_or_else(
|error| match error {
cbindgen::Error::ParseSyntaxError { .. } => {
// This means there was a syntax error in the Rust sources. Don't panic, because
// we want the build to continue and the Rust compiler to hit the error. The
// Rust compiler produces a better error message than cbindgen.
eprintln!("Generating C bindings failed because of a Rust syntax error");
}
e => panic!("Unable to generate C bindings: {:?}", e),
},
|bindings| {
bindings.write_to_file("communicator_bindings.h");
},
);
Ok(())
}

View File

@@ -1,4 +0,0 @@
language = "C"
[enum]
prefix_with_name = true

View File

@@ -1,240 +0,0 @@
//! Glue code to allow using the Rust shmem hash map implementation from C code
//!
//! For convience of adapting existing code, the interface provided somewhat resembles the dynahash
//! interface.
//!
//! NOTE: The caller is responsible for locking! The caller is expected to hold the PostgreSQL
//! LWLock, 'lfc_lock', while accessing the hash table, in shared or exclusive mode as appropriate.
use std::ffi::c_void;
use std::marker::PhantomData;
use neon_shmem::hash::entry::Entry;
use neon_shmem::hash::{HashMapAccess, HashMapInit};
use neon_shmem::shmem::ShmemHandle;
/// NB: This must match the definition of BufferTag in Postgres C headers. We could use bindgen to
/// generate this from the C headers, but prefer to not introduce dependency on bindgen for now.
///
/// Note that there are no padding bytes. If the corresponding C struct has padding bytes, the C C
/// code must clear them.
#[derive(Clone, Debug, Hash, Eq, PartialEq)]
#[repr(C)]
pub struct FileCacheKey {
pub _spc_id: u32,
pub _db_id: u32,
pub _rel_number: u32,
pub _fork_num: u32,
pub _block_num: u32,
}
/// Like with FileCacheKey, this must match the definition of FileCacheEntry in file_cache.c. We
/// don't look at the contents here though, it's sufficent that the size and alignment matches.
#[derive(Clone, Debug, Default)]
#[repr(C)]
pub struct FileCacheEntry {
pub _offset: u32,
pub _access_count: u32,
pub _prev: *mut FileCacheEntry,
pub _next: *mut FileCacheEntry,
pub _state: [u32; 8],
}
/// XXX: This could be just:
///
/// ```ignore
/// type FileCacheHashMapHandle = HashMapInit<'a, FileCacheKey, FileCacheEntry>
/// ```
///
/// but with that, cbindgen generates a broken typedef in the C header file which doesn't
/// compile. It apparently gets confused by the generics.
#[repr(transparent)]
pub struct FileCacheHashMapHandle<'a>(
pub *mut c_void,
PhantomData<HashMapInit<'a, FileCacheKey, FileCacheEntry>>,
);
impl<'a> From<Box<HashMapInit<'a, FileCacheKey, FileCacheEntry>>> for FileCacheHashMapHandle<'a> {
fn from(x: Box<HashMapInit<'a, FileCacheKey, FileCacheEntry>>) -> Self {
FileCacheHashMapHandle(Box::into_raw(x) as *mut c_void, PhantomData::default())
}
}
impl<'a> From<FileCacheHashMapHandle<'a>> for Box<HashMapInit<'a, FileCacheKey, FileCacheEntry>> {
fn from(x: FileCacheHashMapHandle) -> Self {
unsafe { Box::from_raw(x.0.cast()) }
}
}
/// XXX: same for this
#[repr(transparent)]
pub struct FileCacheHashMapAccess<'a>(
pub *mut c_void,
PhantomData<HashMapAccess<'a, FileCacheKey, FileCacheEntry>>,
);
impl<'a> From<Box<HashMapAccess<'a, FileCacheKey, FileCacheEntry>>> for FileCacheHashMapAccess<'a> {
fn from(x: Box<HashMapAccess<'a, FileCacheKey, FileCacheEntry>>) -> Self {
// Convert the Box into a raw mutable pointer to the HashMapAccess itself.
// This transfers ownership of the HashMapAccess (and its contained ShmemHandle)
// to the raw pointer. The C caller is now responsible for managing this memory.
FileCacheHashMapAccess(Box::into_raw(x) as *mut c_void, PhantomData::default())
}
}
impl<'a> FileCacheHashMapAccess<'a> {
fn as_ref(self) -> &'a HashMapAccess<'a, FileCacheKey, FileCacheEntry> {
let ptr: *mut HashMapAccess<'_, FileCacheKey, FileCacheEntry> = self.0.cast();
unsafe { ptr.as_ref().unwrap() }
}
fn as_mut(self) -> &'a mut HashMapAccess<'a, FileCacheKey, FileCacheEntry> {
let ptr: *mut HashMapAccess<'_, FileCacheKey, FileCacheEntry> = self.0.cast();
unsafe { ptr.as_mut().unwrap() }
}
}
/// Initialize the shared memory area at postmaster startup. The returned handle is inherited
/// by all the backend processes across fork()
#[unsafe(no_mangle)]
pub extern "C" fn bcomm_file_cache_shmem_init<'a>(
initial_num_buckets: u32,
max_num_buckets: u32,
) -> FileCacheHashMapHandle<'a> {
let max_bytes = HashMapInit::<FileCacheKey, FileCacheEntry>::estimate_size(max_num_buckets);
let shmem_handle =
ShmemHandle::new("lfc mapping", 0, max_bytes).expect("shmem initialization failed");
let handle = HashMapInit::<FileCacheKey, FileCacheEntry>::init_in_shmem(
initial_num_buckets,
shmem_handle,
);
Box::new(handle).into()
}
/// Initialize the access to the shared memory area in a backend process.
///
/// XXX: I'm not sure if this actually gets called in each process, or if the returned struct
/// is also inherited across fork(). It currently works either way but if this did more
/// initialization that needed to be done after fork(), then it would matter.
#[unsafe(no_mangle)]
pub extern "C" fn bcomm_file_cache_shmem_access<'a>(
handle: FileCacheHashMapHandle<'a>,
) -> FileCacheHashMapAccess<'a> {
let handle: Box<HashMapInit<'_, FileCacheKey, FileCacheEntry>> = handle.into();
Box::new(handle.attach_writer()).into()
}
/// Return the current number of buckets in the hash table
#[unsafe(no_mangle)]
pub extern "C" fn bcomm_file_cache_get_num_buckets<'a>(
map: FileCacheHashMapAccess<'static>,
) -> u32 {
let map = map.as_ref();
map.get_num_buckets().try_into().unwrap()
}
/// Look up the entry with given key and hash.
///
/// This is similar to dynahash's hash_search(... , HASH_FIND)
#[unsafe(no_mangle)]
pub extern "C" fn bcomm_file_cache_hash_find<'a>(
map: FileCacheHashMapAccess<'static>,
key: &FileCacheKey,
hash: u64,
) -> Option<&'static FileCacheEntry> {
let map = map.as_ref();
map.get_with_hash(key, hash)
}
/// Look up the entry at given bucket position
///
/// This has no direct equivalent in the dynahash interface, but can be used to
/// iterate through all entries in the hash table.
#[unsafe(no_mangle)]
pub extern "C" fn bcomm_file_cache_hash_get_at_pos<'a>(
map: FileCacheHashMapAccess<'static>,
pos: u32,
) -> Option<&'static FileCacheEntry> {
let map = map.as_ref();
map.get_at_bucket(pos as usize).map(|(_k, v)| v)
}
/// Remove entry, given a pointer to the value.
///
/// This is equivalent to dynahash hash_search(entry->key, HASH_REMOVE), where 'entry'
/// is an entry you have previously looked up
#[unsafe(no_mangle)]
pub extern "C" fn bcomm_file_cache_hash_remove_entry<'a, 'b>(
map: FileCacheHashMapAccess,
entry: *mut FileCacheEntry,
) {
let map = map.as_mut();
let pos = map.get_bucket_for_value(entry);
match map.entry_at_bucket(pos) {
Some(e) => {
e.remove();
}
None => {
// todo: shouldn't happen, panic?
}
}
}
/// Compute the hash for given key
///
/// This is equivalent to dynahash get_hash_value() function. We use Rust's default hasher
/// for calculating the hash though.
#[unsafe(no_mangle)]
pub extern "C" fn bcomm_file_cache_get_hash_value<'a, 'b>(
map: FileCacheHashMapAccess<'static>,
key: &FileCacheKey,
) -> u64 {
map.as_ref().get_hash_value(key)
}
/// Insert a new entry to the hash table
///
/// This is equivalent to dynahash hash_search(..., HASH_ENTER).
#[unsafe(no_mangle)]
pub extern "C" fn bcomm_file_cache_hash_enter<'a, 'b>(
map: FileCacheHashMapAccess,
key: &FileCacheKey,
hash: u64,
found: &mut bool,
) -> *mut FileCacheEntry {
match map.as_mut().entry_with_hash(key.clone(), hash) {
Entry::Occupied(mut e) => {
*found = true;
e.get_mut()
}
Entry::Vacant(e) => {
*found = false;
let initial_value = FileCacheEntry::default();
e.insert(initial_value).expect("TODO: hash table full")
}
}
}
/// Get the key for a given entry, which must be present in the hash table.
///
/// Dynahash requires the key to be part of the "value" struct, so you can always
/// access the key with something like `entry->key`. The Rust implementation however
/// stores the key separately. This function extracts the separately stored key.
#[unsafe(no_mangle)]
pub extern "C" fn bcomm_file_cache_hash_get_key_for_entry<'a, 'b>(
map: FileCacheHashMapAccess,
entry: *const FileCacheEntry,
) -> Option<&FileCacheKey> {
let map = map.as_ref();
let pos = map.get_bucket_for_value(entry);
map.get_at_bucket(pos as usize).map(|(k, _v)| k)
}
/// Remove all entries from the hash table
#[unsafe(no_mangle)]
pub extern "C" fn bcomm_file_cache_hash_reset<'a, 'b>(map: FileCacheHashMapAccess) {
let map = map.as_mut();
let num_buckets = map.get_num_buckets();
for i in 0..num_buckets {
if let Some(e) = map.entry_at_bucket(i) {
e.remove();
}
}
}

View File

@@ -1 +0,0 @@
pub mod file_cache_hashmap;

View File

@@ -21,7 +21,7 @@
#include "access/xlog.h"
#include "funcapi.h"
#include "miscadmin.h"
#include "common/file_utils.h"
#include "common/hashfn.h"
#include "pgstat.h"
#include "port/pg_iovec.h"
#include "postmaster/bgworker.h"
@@ -36,6 +36,7 @@
#include "storage/procsignal.h"
#include "tcop/tcopprot.h"
#include "utils/builtins.h"
#include "utils/dynahash.h"
#include "utils/guc.h"
#if PG_VERSION_NUM >= 150000
@@ -45,7 +46,6 @@
#include "hll.h"
#include "bitmap.h"
#include "file_cache.h"
#include "file_cache_rust_hash.h"
#include "neon.h"
#include "neon_lwlsncache.h"
#include "neon_perf_counters.h"
@@ -64,7 +64,7 @@
*
* Cache is always reconstructed at node startup, so we do not need to save mapping somewhere and worry about
* its consistency.
*
*
* ## Holes
*
@@ -76,15 +76,13 @@
* fallocate(FALLOC_FL_PUNCH_HOLE) call. The nominal size of the file doesn't
* shrink, but the disk space it uses does.
*
* Each hole is tracked in a freelist. The freelist consists of two parts: a
* fixed-size array in shared memory, and a linked chain of on-disk
* blocks. When the in-memory array fills up, it's flushed to a new on-disk
* chunk. If the soft limit is raised again, we reuse the holes before
* extending the nominal size of the file.
*
* The in-memory freelist array is protected by 'lfc_lock', while the on-disk
* chain is protected by a separate 'lfc_freelist_lock'. Locking rule to
* avoid deadlocks: always acquire lfc_freelist_lock first, then lfc_lock.
* Each hole is tracked by a dummy FileCacheEntry, which are kept in the
* 'holes' linked list. They are entered into the chunk hash table, with a
* special key where the blockNumber is used to store the 'offset' of the
* hole, and all other fields are zero. Holes are never looked up in the hash
* table, we only enter them there to have a FileCacheEntry that we can keep
* in the linked list. If the soft limit is raised again, we reuse the holes
* before extending the nominal size of the file.
*/
/* Local file storage allocation chunk.
@@ -94,15 +92,13 @@
* 1Mb chunks can reduce hash map size to 320Mb.
* 2. Improve access locality, subsequent pages will be allocated together improving seqscan speed
*/
#define BLOCKS_PER_CHUNK_LOG 7 /* 1Mb chunk */
#define BLOCKS_PER_CHUNK (1 << BLOCKS_PER_CHUNK_LOG)
#define MAX_BLOCKS_PER_CHUNK_LOG 7 /* 1Mb chunk */
#define MAX_BLOCKS_PER_CHUNK (1 << MAX_BLOCKS_PER_CHUNK_LOG)
#define MB ((uint64)1024*1024)
#define SIZE_MB_TO_CHUNKS(size) ((uint32)((size) * MB / BLCKSZ >> BLOCKS_PER_CHUNK_LOG))
#define BLOCK_TO_CHUNK_OFF(blkno) ((blkno) & (BLOCKS_PER_CHUNK-1))
#define INVALID_OFFSET (0xffffffff)
#define SIZE_MB_TO_CHUNKS(size) ((uint32)((size) * MB / BLCKSZ >> lfc_chunk_size_log))
#define BLOCK_TO_CHUNK_OFF(blkno) ((blkno) & (lfc_blocks_per_chunk-1))
/*
* Blocks are read or written to LFC file outside LFC critical section.
@@ -123,18 +119,15 @@ typedef enum FileCacheBlockState
typedef struct FileCacheEntry
{
BufferTag key;
uint32 hash;
uint32 offset;
uint32 access_count;
dlist_node list_node; /* LRU list node */
uint32 state[(BLOCKS_PER_CHUNK * 2 + 31) / 32]; /* two bits per block */
dlist_node list_node; /* LRU/holes list node */
uint32 state[FLEXIBLE_ARRAY_MEMBER]; /* two bits per block */
} FileCacheEntry;
/* Todo: alignment must be the same too */
StaticAssertDecl(sizeof(FileCacheEntry) == sizeof(RustFileCacheEntry),
"Rust and C declarations of FileCacheEntry are incompatible");
StaticAssertDecl(sizeof(BufferTag) == sizeof(RustFileCacheKey),
"Rust and C declarations of FileCacheKey are incompatible");
#define FILE_CACHE_ENRTY_SIZE MAXALIGN(offsetof(FileCacheEntry, state) + (lfc_blocks_per_chunk*2+31)/32*4)
#define GET_STATE(entry, i) (((entry)->state[(i) / 16] >> ((i) % 16 * 2)) & 3)
#define SET_STATE(entry, i, new_state) (entry)->state[(i) / 16] = ((entry)->state[(i) / 16] & ~(3 << ((i) % 16 * 2))) | ((new_state) << ((i) % 16 * 2))
@@ -143,9 +136,6 @@ StaticAssertDecl(sizeof(BufferTag) == sizeof(RustFileCacheKey),
#define MAX_PREWARM_WORKERS 8
#define FREELIST_ENTRIES_PER_CHUNK (BLOCKS_PER_CHUNK * BLCKSZ / sizeof(uint32) - 2)
typedef struct PrewarmWorkerState
{
uint32 prewarmed_pages;
@@ -171,6 +161,7 @@ typedef struct FileCacheControl
uint64 evicted_pages; /* number of evicted pages */
dlist_head lru; /* double linked list for LRU replacement
* algorithm */
dlist_head holes; /* double linked list of punched holes */
HyperLogLogState wss_estimation; /* estimation of working set size */
ConditionVariable cv[N_COND_VARS]; /* turnstile of condition variables */
PrewarmWorkerState prewarm_workers[MAX_PREWARM_WORKERS];
@@ -181,39 +172,23 @@ typedef struct FileCacheControl
bool prewarm_active;
bool prewarm_canceled;
dsm_handle prewarm_lfc_state_handle;
/*
* Free list. This is large enough to hold one chunks worth of entries.
*/
uint32 freelist_size;
uint32 freelist_head;
uint32 num_free_pages;
uint32 free_pages[FREELIST_ENTRIES_PER_CHUNK];
} FileCacheControl;
typedef struct FreeListChunk
{
uint32 next;
uint32 num_free_pages;
uint32 free_pages[FREELIST_ENTRIES_PER_CHUNK];
} FreeListChunk;
#define FILE_CACHE_STATE_MAGIC 0xfcfcfcfc
#define FILE_CACHE_STATE_BITMAP(fcs) ((uint8*)&(fcs)->chunks[(fcs)->n_chunks])
#define FILE_CACHE_STATE_SIZE_FOR_CHUNKS(n_chunks) (sizeof(FileCacheState) + (n_chunks)*sizeof(BufferTag) + (((n_chunks) * BLOCKS_PER_CHUNK)+7)/8)
#define FILE_CACHE_STATE_SIZE_FOR_CHUNKS(n_chunks) (sizeof(FileCacheState) + (n_chunks)*sizeof(BufferTag) + (((n_chunks) * lfc_blocks_per_chunk)+7)/8)
#define FILE_CACHE_STATE_SIZE(fcs) (sizeof(FileCacheState) + (fcs->n_chunks)*sizeof(BufferTag) + (((fcs->n_chunks) << fcs->chunk_size_log)+7)/8)
static FileCacheHashMapHandle lfc_hash_handle;
static FileCacheHashMapAccess lfc_hash;
static HTAB *lfc_hash;
static int lfc_desc = -1;
static LWLockId lfc_lock;
static LWLockId lfc_freelist_lock;
static int lfc_max_size;
static int lfc_size_limit;
static int lfc_prewarm_limit;
static int lfc_prewarm_batch;
static int lfc_blocks_per_chunk_ro = BLOCKS_PER_CHUNK;
static int lfc_chunk_size_log = MAX_BLOCKS_PER_CHUNK_LOG;
static int lfc_blocks_per_chunk = MAX_BLOCKS_PER_CHUNK;
static char *lfc_path;
static uint64 lfc_generation;
static FileCacheControl *lfc_ctl;
@@ -230,11 +205,6 @@ bool AmPrewarmWorker;
#define LFC_ENABLED() (lfc_ctl->limit != 0)
static bool freelist_push(uint32 offset);
static bool freelist_prepare_pop(void);
static uint32 freelist_pop(void);
static bool freelist_is_empty(void);
/*
* Close LFC file if opened.
* All backends should close their LFC files once LFC is disabled.
@@ -262,9 +232,15 @@ lfc_switch_off(void)
if (LFC_ENABLED())
{
/* Invalidate hash */
file_cache_hash_reset(lfc_hash);
HASH_SEQ_STATUS status;
FileCacheEntry *entry;
/* Invalidate hash */
hash_seq_init(&status, lfc_hash);
while ((entry = hash_seq_search(&status)) != NULL)
{
hash_search_with_hash_value(lfc_hash, &entry->key, entry->hash, HASH_REMOVE, NULL);
}
lfc_ctl->generation += 1;
lfc_ctl->size = 0;
lfc_ctl->pinned = 0;
@@ -272,9 +248,7 @@ lfc_switch_off(void)
lfc_ctl->used_pages = 0;
lfc_ctl->limit = 0;
dlist_init(&lfc_ctl->lru);
lfc_ctl->freelist_head = INVALID_OFFSET;
lfc_ctl->num_free_pages = 0;
dlist_init(&lfc_ctl->holes);
/*
* We need to use unlink to to avoid races in LFC write, because it is not
@@ -343,8 +317,8 @@ lfc_ensure_opened(void)
static void
lfc_shmem_startup(void)
{
size_t size;
bool found;
static HASHCTL info;
if (prev_shmem_startup_hook)
{
@@ -353,29 +327,27 @@ lfc_shmem_startup(void)
LWLockAcquire(AddinShmemInitLock, LW_EXCLUSIVE);
size = sizeof(FileCacheControl);
lfc_ctl = (FileCacheControl *) ShmemInitStruct("lfc", size, &found);
lfc_ctl = (FileCacheControl *) ShmemInitStruct("lfc", sizeof(FileCacheControl), &found);
if (!found)
{
int fd;
uint32 n_chunks = SIZE_MB_TO_CHUNKS(lfc_max_size);
lfc_lock = (LWLockId) GetNamedLWLockTranche("lfc_lock");
lfc_freelist_lock = (LWLockId) GetNamedLWLockTranche("lfc_freelist_lock");
info.keysize = sizeof(BufferTag);
info.entrysize = FILE_CACHE_ENRTY_SIZE;
/*
* n_chunks+1 because we add new element to hash table before eviction
* of victim
*/
lfc_hash_handle = file_cache_hash_shmem_init(n_chunks + 1, n_chunks + 1);
memset(lfc_ctl, 0, offsetof(FileCacheControl, free_pages));
lfc_hash = ShmemInitHash("lfc_hash",
n_chunks + 1, n_chunks + 1,
&info,
HASH_ELEM | HASH_BLOBS);
memset(lfc_ctl, 0, sizeof(FileCacheControl));
dlist_init(&lfc_ctl->lru);
lfc_ctl->freelist_size = FREELIST_ENTRIES_PER_CHUNK;
lfc_ctl->freelist_head = INVALID_OFFSET;
lfc_ctl->num_free_pages = 0;
dlist_init(&lfc_ctl->holes);
/* Initialize hyper-log-log structure for estimating working set size */
initSHLL(&lfc_ctl->wss_estimation);
@@ -399,25 +371,18 @@ lfc_shmem_startup(void)
}
LWLockRelease(AddinShmemInitLock);
lfc_hash = file_cache_hash_shmem_access(lfc_hash_handle);
}
static void
lfc_shmem_request(void)
{
size_t size;
#if PG_VERSION_NUM>=150000
if (prev_shmem_request_hook)
prev_shmem_request_hook();
#endif
size = sizeof(FileCacheControl);
RequestAddinShmemSpace(size);
RequestAddinShmemSpace(sizeof(FileCacheControl) + hash_estimate_size(SIZE_MB_TO_CHUNKS(lfc_max_size) + 1, FILE_CACHE_ENRTY_SIZE));
RequestNamedLWLockTranche("lfc_lock", 1);
RequestNamedLWLockTranche("lfc_freelist_lock", 2);
}
static bool
@@ -433,6 +398,24 @@ is_normal_backend(void)
return lfc_ctl && MyProc && UsedShmemSegAddr && !IsParallelWorker();
}
static bool
lfc_check_chunk_size(int *newval, void **extra, GucSource source)
{
if (*newval & (*newval - 1))
{
elog(ERROR, "LFC chunk size should be power of two");
return false;
}
return true;
}
static void
lfc_change_chunk_size(int newval, void* extra)
{
lfc_chunk_size_log = pg_ceil_log2_32(newval);
}
static bool
lfc_check_limit_hook(int *newval, void **extra, GucSource source)
{
@@ -452,14 +435,12 @@ lfc_change_limit_hook(int newval, void *extra)
if (!lfc_ctl || !is_normal_backend())
return;
LWLockAcquire(lfc_freelist_lock, LW_EXCLUSIVE);
LWLockAcquire(lfc_lock, LW_EXCLUSIVE);
/* Open LFC file only if LFC was enabled or we are going to reenable it */
if (newval == 0 && !LFC_ENABLED())
{
LWLockRelease(lfc_lock);
LWLockRelease(lfc_freelist_lock);
/* File should be reopened if LFC is reenabled */
lfc_close_file();
return;
@@ -468,7 +449,6 @@ lfc_change_limit_hook(int newval, void *extra)
if (!lfc_ensure_opened())
{
LWLockRelease(lfc_lock);
LWLockRelease(lfc_freelist_lock);
return;
}
@@ -484,30 +464,35 @@ lfc_change_limit_hook(int newval, void *extra)
* returning their space to file system
*/
FileCacheEntry *victim = dlist_container(FileCacheEntry, list_node, dlist_pop_head_node(&lfc_ctl->lru));
FileCacheEntry *hole;
uint32 offset = victim->offset;
uint32 hash;
bool found;
BufferTag holetag;
CriticalAssert(victim->access_count == 0);
#ifdef FALLOC_FL_PUNCH_HOLE
if (fallocate(lfc_desc, FALLOC_FL_PUNCH_HOLE | FALLOC_FL_KEEP_SIZE, (off_t) victim->offset * BLOCKS_PER_CHUNK * BLCKSZ, BLOCKS_PER_CHUNK * BLCKSZ) < 0)
if (fallocate(lfc_desc, FALLOC_FL_PUNCH_HOLE | FALLOC_FL_KEEP_SIZE, (off_t) victim->offset * lfc_blocks_per_chunk * BLCKSZ, lfc_blocks_per_chunk * BLCKSZ) < 0)
neon_log(LOG, "Failed to punch hole in file: %m");
#endif
/* We remove the entry, and enter a hole to the freelist */
for (int i = 0; i < BLOCKS_PER_CHUNK; i++)
/* We remove the old entry, and re-enter a hole to the hash table */
for (int i = 0; i < lfc_blocks_per_chunk; i++)
{
bool is_page_cached = GET_STATE(victim, i) == AVAILABLE;
lfc_ctl->used_pages -= is_page_cached;
lfc_ctl->evicted_pages += is_page_cached;
}
file_cache_hash_remove_entry(lfc_hash, victim);
hash_search_with_hash_value(lfc_hash, &victim->key, victim->hash, HASH_REMOVE, NULL);
if (!freelist_push(offset))
{
/* freelist_push already logged the error */
lfc_switch_off();
LWLockRelease(lfc_lock);
LWLockRelease(lfc_freelist_lock);
return;
}
memset(&holetag, 0, sizeof(holetag));
holetag.blockNum = offset;
hash = get_hash_value(lfc_hash, &holetag);
hole = hash_search_with_hash_value(lfc_hash, &holetag, hash, HASH_ENTER, &found);
hole->hash = hash;
hole->offset = offset;
hole->access_count = 0;
CriticalAssert(!found);
dlist_push_tail(&lfc_ctl->holes, &hole->list_node);
lfc_ctl->used -= 1;
}
@@ -519,7 +504,6 @@ lfc_change_limit_hook(int newval, void *extra)
neon_log(DEBUG1, "set local file cache limit to %d", new_size);
LWLockRelease(lfc_lock);
LWLockRelease(lfc_freelist_lock);
}
void
@@ -595,14 +579,14 @@ lfc_init(void)
DefineCustomIntVariable("neon.file_cache_chunk_size",
"LFC chunk size in blocks (should be power of two)",
NULL,
&lfc_blocks_per_chunk_ro,
BLOCKS_PER_CHUNK,
BLOCKS_PER_CHUNK,
BLOCKS_PER_CHUNK,
PGC_INTERNAL,
&lfc_blocks_per_chunk,
MAX_BLOCKS_PER_CHUNK,
1,
MAX_BLOCKS_PER_CHUNK,
PGC_POSTMASTER,
GUC_UNIT_BLOCKS,
NULL,
NULL,
lfc_check_chunk_size,
lfc_change_chunk_size,
NULL);
DefineCustomIntVariable("neon.file_cache_prewarm_limit",
@@ -665,19 +649,19 @@ lfc_get_state(size_t max_entries)
fcs = (FileCacheState*)palloc0(state_size);
SET_VARSIZE(fcs, state_size);
fcs->magic = FILE_CACHE_STATE_MAGIC;
fcs->chunk_size_log = BLOCKS_PER_CHUNK_LOG;
fcs->chunk_size_log = lfc_chunk_size_log;
fcs->n_chunks = n_entries;
bitmap = FILE_CACHE_STATE_BITMAP(fcs);
dlist_reverse_foreach(iter, &lfc_ctl->lru)
{
FileCacheEntry *entry = dlist_container(FileCacheEntry, list_node, iter.cur);
fcs->chunks[i] = *file_cache_hash_get_key_for_entry(lfc_hash, entry);
for (int j = 0; j < BLOCKS_PER_CHUNK; j++)
fcs->chunks[i] = entry->key;
for (int j = 0; j < lfc_blocks_per_chunk; j++)
{
if (GET_STATE(entry, j) != UNAVAILABLE)
{
BITMAP_SET(bitmap, i*BLOCKS_PER_CHUNK + j);
BITMAP_SET(bitmap, i*lfc_blocks_per_chunk + j);
n_pages += 1;
}
}
@@ -686,7 +670,7 @@ lfc_get_state(size_t max_entries)
}
Assert(i == n_entries);
fcs->n_pages = n_pages;
Assert(pg_popcount((char*)bitmap, ((n_entries << BLOCKS_PER_CHUNK_LOG) + 7)/8) == n_pages);
Assert(pg_popcount((char*)bitmap, ((n_entries << lfc_chunk_size_log) + 7)/8) == n_pages);
elog(LOG, "LFC: save state of %d chunks %d pages", (int)n_entries, (int)n_pages);
}
@@ -742,7 +726,7 @@ lfc_prewarm(FileCacheState* fcs, uint32 n_workers)
}
fcs_chunk_size_log = fcs->chunk_size_log;
if (fcs_chunk_size_log > BLOCKS_PER_CHUNK_LOG)
if (fcs_chunk_size_log > MAX_BLOCKS_PER_CHUNK_LOG)
{
elog(ERROR, "LFC: Invalid chunk size log: %u", fcs->chunk_size_log);
}
@@ -961,7 +945,7 @@ lfc_invalidate(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber nblocks)
{
BufferTag tag;
FileCacheEntry *entry;
uint64 hash;
uint32 hash;
if (lfc_maybe_disabled()) /* fast exit if file cache is disabled */
return;
@@ -974,14 +958,14 @@ lfc_invalidate(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber nblocks)
LWLockAcquire(lfc_lock, LW_EXCLUSIVE);
if (LFC_ENABLED())
{
for (BlockNumber blkno = 0; blkno < nblocks; blkno += BLOCKS_PER_CHUNK)
for (BlockNumber blkno = 0; blkno < nblocks; blkno += lfc_blocks_per_chunk)
{
tag.blockNum = blkno;
hash = file_cache_hash_get_hash_value(lfc_hash, &tag);
entry = file_cache_hash_find(lfc_hash, &tag, hash);
hash = get_hash_value(lfc_hash, &tag);
entry = hash_search_with_hash_value(lfc_hash, &tag, hash, HASH_FIND, NULL);
if (entry != NULL)
{
for (int i = 0; i < BLOCKS_PER_CHUNK; i++)
for (int i = 0; i < lfc_blocks_per_chunk; i++)
{
if (GET_STATE(entry, i) == AVAILABLE)
{
@@ -1006,7 +990,7 @@ lfc_cache_contains(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno)
FileCacheEntry *entry;
int chunk_offs = BLOCK_TO_CHUNK_OFF(blkno);
bool found = false;
uint64 hash;
uint32 hash;
if (lfc_maybe_disabled()) /* fast exit if file cache is disabled */
return false;
@@ -1016,12 +1000,12 @@ lfc_cache_contains(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno)
tag.blockNum = blkno - chunk_offs;
CriticalAssert(BufTagGetRelNumber(&tag) != InvalidRelFileNumber);
hash = file_cache_hash_get_hash_value(lfc_hash, &tag);
hash = get_hash_value(lfc_hash, &tag);
LWLockAcquire(lfc_lock, LW_SHARED);
if (LFC_ENABLED())
{
entry = file_cache_hash_find(lfc_hash, &tag, hash);
entry = hash_search_with_hash_value(lfc_hash, &tag, hash, HASH_FIND, NULL);
found = entry != NULL && GET_STATE(entry, chunk_offs) != UNAVAILABLE;
}
LWLockRelease(lfc_lock);
@@ -1040,7 +1024,7 @@ lfc_cache_containsv(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno,
FileCacheEntry *entry;
uint32 chunk_offs;
int found = 0;
uint64 hash;
uint32 hash;
int i = 0;
if (lfc_maybe_disabled()) /* fast exit if file cache is disabled */
@@ -1053,7 +1037,7 @@ lfc_cache_containsv(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno,
chunk_offs = BLOCK_TO_CHUNK_OFF(blkno);
tag.blockNum = blkno - chunk_offs;
hash = file_cache_hash_get_hash_value(lfc_hash, &tag);
hash = get_hash_value(lfc_hash, &tag);
LWLockAcquire(lfc_lock, LW_SHARED);
@@ -1064,12 +1048,12 @@ lfc_cache_containsv(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno,
}
while (true)
{
int this_chunk = Min(nblocks - i, BLOCKS_PER_CHUNK - chunk_offs);
entry = file_cache_hash_find(lfc_hash, &tag, hash);
int this_chunk = Min(nblocks - i, lfc_blocks_per_chunk - chunk_offs);
entry = hash_search_with_hash_value(lfc_hash, &tag, hash, HASH_FIND, NULL);
if (entry != NULL)
{
for (; chunk_offs < BLOCKS_PER_CHUNK && i < nblocks; chunk_offs++, i++)
for (; chunk_offs < lfc_blocks_per_chunk && i < nblocks; chunk_offs++, i++)
{
if (GET_STATE(entry, chunk_offs) != UNAVAILABLE)
{
@@ -1095,7 +1079,7 @@ lfc_cache_containsv(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno,
*/
chunk_offs = BLOCK_TO_CHUNK_OFF(blkno + i);
tag.blockNum = (blkno + i) - chunk_offs;
hash = file_cache_hash_get_hash_value(lfc_hash, &tag);
hash = get_hash_value(lfc_hash, &tag);
}
LWLockRelease(lfc_lock);
@@ -1144,7 +1128,7 @@ lfc_readv_select(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno,
BufferTag tag;
FileCacheEntry *entry;
ssize_t rc;
uint64 hash;
uint32 hash;
uint64 generation;
uint32 entry_offset;
int blocks_read = 0;
@@ -1170,9 +1154,9 @@ lfc_readv_select(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno,
while (nblocks > 0)
{
struct iovec iov[PG_IOV_MAX];
uint8 chunk_mask[BLOCKS_PER_CHUNK / 8] = {0};
uint8 chunk_mask[MAX_BLOCKS_PER_CHUNK / 8] = {0};
int chunk_offs = BLOCK_TO_CHUNK_OFF(blkno);
int blocks_in_chunk = Min(nblocks, BLOCKS_PER_CHUNK - chunk_offs);
int blocks_in_chunk = Min(nblocks, lfc_blocks_per_chunk - chunk_offs);
int iteration_hits = 0;
int iteration_misses = 0;
uint64 io_time_us = 0;
@@ -1222,7 +1206,7 @@ lfc_readv_select(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno,
Assert(iov_last_used - first_block_in_chunk_read >= n_blocks_to_read);
tag.blockNum = blkno - chunk_offs;
hash = file_cache_hash_get_hash_value(lfc_hash, &tag);
hash = get_hash_value(lfc_hash, &tag);
cv = &lfc_ctl->cv[hash % N_COND_VARS];
LWLockAcquire(lfc_lock, LW_EXCLUSIVE);
@@ -1235,13 +1219,13 @@ lfc_readv_select(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno,
return blocks_read;
}
entry = file_cache_hash_find(lfc_hash, &tag, hash);
entry = hash_search_with_hash_value(lfc_hash, &tag, hash, HASH_FIND, NULL);
/* Approximate working set for the blocks assumed in this entry */
for (int i = 0; i < blocks_in_chunk; i++)
{
tag.blockNum = blkno + i;
addSHLL(&lfc_ctl->wss_estimation, file_cache_hash_get_hash_value(lfc_hash, &tag));
addSHLL(&lfc_ctl->wss_estimation, hash_bytes((uint8_t const*)&tag, sizeof(tag)));
}
if (entry == NULL)
@@ -1312,7 +1296,7 @@ lfc_readv_select(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno,
if (iteration_hits != 0)
{
/* chunk offset (# of pages) into the LFC file */
off_t first_read_offset = (off_t) entry_offset * BLOCKS_PER_CHUNK;
off_t first_read_offset = (off_t) entry_offset * lfc_blocks_per_chunk;
int nwrite = iov_last_used - first_block_in_chunk_read;
/* offset of first IOV */
first_read_offset += chunk_offs + first_block_in_chunk_read;
@@ -1389,14 +1373,14 @@ lfc_readv_select(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno,
* Returns false if there are no unpinned entries and chunk can not be added.
*/
static bool
lfc_init_new_entry(FileCacheEntry *entry)
lfc_init_new_entry(FileCacheEntry* entry, uint32 hash)
{
/*-----------
* If the chunk wasn't already in the LFC then we have these
* options, in order of preference:
*
* Unless there is no space available, we can:
* 1. Use an entry from the freelist, and
* 1. Use an entry from the `holes` list, and
* 2. Create a new entry.
* We can always, regardless of space in the LFC:
* 3. evict an entry from LRU, and
@@ -1404,10 +1388,17 @@ lfc_init_new_entry(FileCacheEntry *entry)
*/
if (lfc_ctl->used < lfc_ctl->limit)
{
if (!freelist_is_empty())
if (!dlist_is_empty(&lfc_ctl->holes))
{
/* We can reuse a hole that was left behind when the LFC was shrunk previously */
uint32 offset = freelist_pop();
FileCacheEntry *hole = dlist_container(FileCacheEntry, list_node,
dlist_pop_head_node(&lfc_ctl->holes));
uint32 offset = hole->offset;
bool hole_found;
hash_search_with_hash_value(lfc_hash, &hole->key,
hole->hash, HASH_REMOVE, &hole_found);
CriticalAssert(hole_found);
lfc_ctl->used += 1;
entry->offset = offset; /* reuse the hole */
@@ -1436,7 +1427,7 @@ lfc_init_new_entry(FileCacheEntry *entry)
FileCacheEntry *victim = dlist_container(FileCacheEntry, list_node,
dlist_pop_head_node(&lfc_ctl->lru));
for (int i = 0; i < BLOCKS_PER_CHUNK; i++)
for (int i = 0; i < lfc_blocks_per_chunk; i++)
{
bool is_page_cached = GET_STATE(victim, i) == AVAILABLE;
lfc_ctl->used_pages -= is_page_cached;
@@ -1445,21 +1436,24 @@ lfc_init_new_entry(FileCacheEntry *entry)
CriticalAssert(victim->access_count == 0);
entry->offset = victim->offset; /* grab victim's chunk */
file_cache_hash_remove_entry(lfc_hash, victim);
hash_search_with_hash_value(lfc_hash, &victim->key,
victim->hash, HASH_REMOVE, NULL);
neon_log(DEBUG2, "Swap file cache page");
}
else
{
/* Can't add this chunk - we don't have the space for it */
file_cache_hash_remove_entry(lfc_hash, entry);
hash_search_with_hash_value(lfc_hash, &entry->key, hash,
HASH_REMOVE, NULL);
lfc_ctl->prewarm_canceled = true; /* cancel prewarm if LFC limit is reached */
return false;
}
entry->access_count = 1;
entry->hash = hash;
lfc_ctl->pinned += 1;
for (int i = 0; i < BLOCKS_PER_CHUNK; i++)
for (int i = 0; i < lfc_blocks_per_chunk; i++)
SET_STATE(entry, i, UNAVAILABLE);
return true;
@@ -1496,7 +1490,7 @@ lfc_prefetch(NRelFileInfo rinfo, ForkNumber forknum, BlockNumber blkno,
FileCacheEntry *entry;
ssize_t rc;
bool found;
uint64 hash;
uint32 hash;
uint64 generation;
uint32 entry_offset;
instr_time io_start, io_end;
@@ -1515,10 +1509,9 @@ lfc_prefetch(NRelFileInfo rinfo, ForkNumber forknum, BlockNumber blkno,
CriticalAssert(BufTagGetRelNumber(&tag) != InvalidRelFileNumber);
tag.blockNum = blkno - chunk_offs;
hash = file_cache_hash_get_hash_value(lfc_hash, &tag);
hash = get_hash_value(lfc_hash, &tag);
cv = &lfc_ctl->cv[hash % N_COND_VARS];
retry:
LWLockAcquire(lfc_lock, LW_EXCLUSIVE);
if (!LFC_ENABLED() || !lfc_ensure_opened())
@@ -1527,9 +1520,6 @@ lfc_prefetch(NRelFileInfo rinfo, ForkNumber forknum, BlockNumber blkno,
return false;
}
if (!freelist_prepare_pop())
goto retry;
lwlsn = neon_get_lwlsn(rinfo, forknum, blkno);
if (lwlsn > lsn)
@@ -1540,12 +1530,12 @@ lfc_prefetch(NRelFileInfo rinfo, ForkNumber forknum, BlockNumber blkno,
return false;
}
entry = file_cache_hash_enter(lfc_hash, &tag, hash, &found);
entry = hash_search_with_hash_value(lfc_hash, &tag, hash, HASH_ENTER, &found);
if (lfc_prewarm_update_ws_estimation)
{
tag.blockNum = blkno;
addSHLL(&lfc_ctl->wss_estimation, file_cache_hash_get_hash_value(lfc_hash, &tag));
addSHLL(&lfc_ctl->wss_estimation, hash_bytes((uint8_t const*)&tag, sizeof(tag)));
}
if (found)
{
@@ -1567,7 +1557,7 @@ lfc_prefetch(NRelFileInfo rinfo, ForkNumber forknum, BlockNumber blkno,
}
else
{
if (!lfc_init_new_entry(entry))
if (!lfc_init_new_entry(entry, hash))
{
/*
* We can't process this chunk due to lack of space in LFC,
@@ -1588,7 +1578,7 @@ lfc_prefetch(NRelFileInfo rinfo, ForkNumber forknum, BlockNumber blkno,
pgstat_report_wait_start(WAIT_EVENT_NEON_LFC_WRITE);
INSTR_TIME_SET_CURRENT(io_start);
rc = pwrite(lfc_desc, buffer, BLCKSZ,
((off_t) entry_offset * BLOCKS_PER_CHUNK + chunk_offs) * BLCKSZ);
((off_t) entry_offset * lfc_blocks_per_chunk + chunk_offs) * BLCKSZ);
INSTR_TIME_SET_CURRENT(io_end);
pgstat_report_wait_end();
@@ -1650,7 +1640,7 @@ lfc_writev(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno,
FileCacheEntry *entry;
ssize_t rc;
bool found;
uint64 hash;
uint32 hash;
uint64 generation;
uint32 entry_offset;
int buf_offset = 0;
@@ -1663,7 +1653,6 @@ lfc_writev(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno,
CriticalAssert(BufTagGetRelNumber(&tag) != InvalidRelFileNumber);
retry:
LWLockAcquire(lfc_lock, LW_EXCLUSIVE);
if (!LFC_ENABLED() || !lfc_ensure_opened())
@@ -1673,9 +1662,6 @@ lfc_writev(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno,
}
generation = lfc_ctl->generation;
if (!freelist_prepare_pop())
goto retry;
/*
* For every chunk that has blocks we're interested in, we
* 1. get the chunk header
@@ -1689,7 +1675,7 @@ lfc_writev(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno,
{
struct iovec iov[PG_IOV_MAX];
int chunk_offs = BLOCK_TO_CHUNK_OFF(blkno);
int blocks_in_chunk = Min(nblocks, BLOCKS_PER_CHUNK - chunk_offs);
int blocks_in_chunk = Min(nblocks, lfc_blocks_per_chunk - chunk_offs);
instr_time io_start, io_end;
ConditionVariable* cv;
@@ -1702,16 +1688,16 @@ lfc_writev(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno,
}
tag.blockNum = blkno - chunk_offs;
hash = file_cache_hash_get_hash_value(lfc_hash, &tag);
hash = get_hash_value(lfc_hash, &tag);
cv = &lfc_ctl->cv[hash % N_COND_VARS];
entry = file_cache_hash_enter(lfc_hash, &tag, hash, &found);
entry = hash_search_with_hash_value(lfc_hash, &tag, hash, HASH_ENTER, &found);
/* Approximate working set for the blocks assumed in this entry */
for (int i = 0; i < blocks_in_chunk; i++)
{
tag.blockNum = blkno + i;
addSHLL(&lfc_ctl->wss_estimation, file_cache_hash_get_hash_value(lfc_hash, &tag));
addSHLL(&lfc_ctl->wss_estimation, hash_bytes((uint8_t const*)&tag, sizeof(tag)));
}
if (found)
@@ -1728,7 +1714,7 @@ lfc_writev(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno,
}
else
{
if (!lfc_init_new_entry(entry))
if (!lfc_init_new_entry(entry, hash))
{
/*
* We can't process this chunk due to lack of space in LFC,
@@ -1777,7 +1763,7 @@ lfc_writev(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno,
pgstat_report_wait_start(WAIT_EVENT_NEON_LFC_WRITE);
INSTR_TIME_SET_CURRENT(io_start);
rc = pwritev(lfc_desc, iov, blocks_in_chunk,
((off_t) entry_offset * BLOCKS_PER_CHUNK + chunk_offs) * BLCKSZ);
((off_t) entry_offset * lfc_blocks_per_chunk + chunk_offs) * BLCKSZ);
INSTR_TIME_SET_CURRENT(io_end);
pgstat_report_wait_end();
@@ -1837,140 +1823,6 @@ lfc_writev(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno,
LWLockRelease(lfc_lock);
}
/**** freelist management ****/
/*
* Prerequisites:
* - The caller is holding 'lfc_lock'. XXX
*/
static bool
freelist_prepare_pop(void)
{
/*
* If the in-memory freelist is empty, but there are more blocks available, load them.
*
* TODO: if there
*/
if (lfc_ctl->num_free_pages == 0 && lfc_ctl->freelist_head != INVALID_OFFSET)
{
uint32 freelist_head;
FreeListChunk *freelist_chunk;
size_t bytes_read;
LWLockRelease(lfc_lock);
LWLockAcquire(lfc_freelist_lock, LW_EXCLUSIVE);
if (!(lfc_ctl->num_free_pages == 0 && lfc_ctl->freelist_head != INVALID_OFFSET))
{
/* someone else did the work for us while we were not holding the lock */
LWLockRelease(lfc_freelist_lock);
return false;
}
freelist_head = lfc_ctl->freelist_head;
freelist_chunk = palloc(BLOCKS_PER_CHUNK * BLCKSZ);
bytes_read = 0;
while (bytes_read < BLOCKS_PER_CHUNK * BLCKSZ)
{
ssize_t rc;
rc = pread(lfc_desc, freelist_chunk, BLOCKS_PER_CHUNK * BLCKSZ - bytes_read, (off_t) freelist_head * BLOCKS_PER_CHUNK * BLCKSZ + bytes_read);
if (rc < 0)
{
lfc_disable("read freelist page");
return false;
}
bytes_read += rc;
}
LWLockAcquire(lfc_lock, LW_EXCLUSIVE);
if (lfc_generation != lfc_ctl->generation)
{
LWLockRelease(lfc_lock);
return false;
}
Assert(lfc_ctl->freelist_head == freelist_head);
Assert(lfc_ctl->num_free_pages == 0);
lfc_ctl->freelist_head = freelist_chunk->next;
lfc_ctl->num_free_pages = freelist_chunk->num_free_pages;
memcpy(lfc_ctl->free_pages, freelist_chunk->free_pages, lfc_ctl->num_free_pages * sizeof(uint32));
pfree(freelist_chunk);
LWLockRelease(lfc_lock);
LWLockRelease(lfc_freelist_lock);
return false;
}
return true;
}
/*
* Prerequisites:
* - The caller is holding 'lfc_lock' and 'lfc_freelist_lock'.
*
* Returns 'false' on error.
*/
static bool
freelist_push(uint32 offset)
{
Assert(lfc_ctl->freelist_size == FREELIST_ENTRIES_PER_CHUNK);
if (lfc_ctl->num_free_pages == lfc_ctl->freelist_size)
{
FreeListChunk *freelist_chunk;
struct iovec iov;
ssize_t rc;
freelist_chunk = palloc(BLOCKS_PER_CHUNK * BLCKSZ);
/* write the existing entries to the chunk on disk */
freelist_chunk->next = lfc_ctl->freelist_head;
freelist_chunk->num_free_pages = lfc_ctl->num_free_pages;
memcpy(freelist_chunk->free_pages, lfc_ctl->free_pages, lfc_ctl->num_free_pages * sizeof(uint32));
/* Use the passed-in offset to hold the freelist chunk itself */
iov.iov_base = freelist_chunk;
iov.iov_len = BLOCKS_PER_CHUNK * BLCKSZ;
rc = pg_pwritev_with_retry(lfc_desc, &iov, 1, (off_t) offset * BLOCKS_PER_CHUNK * BLCKSZ);
pfree(freelist_chunk);
if (rc < 0)
return false;
lfc_ctl->freelist_head = offset;
lfc_ctl->num_free_pages = 0;
}
else
{
lfc_ctl->free_pages[lfc_ctl->num_free_pages] = offset;
lfc_ctl->num_free_pages++;
}
return true;
}
static uint32
freelist_pop(void)
{
uint32 result;
/* The caller should've checked that the list is not empty */
Assert(lfc_ctl->num_free_pages > 0);
result = lfc_ctl->free_pages[lfc_ctl->num_free_pages - 1];
lfc_ctl->num_free_pages--;
return result;
}
static bool
freelist_is_empty(void)
{
return lfc_ctl->num_free_pages == 0;
}
typedef struct
{
TupleDesc tupdesc;
@@ -2067,7 +1919,7 @@ neon_get_lfc_stats(PG_FUNCTION_ARGS)
break;
case 8:
key = "file_cache_chunk_size_pages";
value = BLOCKS_PER_CHUNK;
value = lfc_blocks_per_chunk;
break;
case 9:
key = "file_cache_chunks_pinned";
@@ -2138,6 +1990,7 @@ local_cache_pages(PG_FUNCTION_ARGS)
if (SRF_IS_FIRSTCALL())
{
HASH_SEQ_STATUS status;
FileCacheEntry *entry;
uint32 n_pages = 0;
@@ -2193,16 +2046,15 @@ local_cache_pages(PG_FUNCTION_ARGS)
if (LFC_ENABLED())
{
uint32 num_buckets = file_cache_hash_get_num_buckets(lfc_hash);
for (uint32 pos = 0; pos < num_buckets; pos++)
hash_seq_init(&status, lfc_hash);
while ((entry = hash_seq_search(&status)) != NULL)
{
entry = file_cache_hash_get_at_pos(lfc_hash, pos);
if (entry == NULL)
continue;
for (int i = 0; i < BLOCKS_PER_CHUNK; i++)
n_pages += GET_STATE(entry, i) == AVAILABLE;
/* Skip hole tags */
if (NInfoGetRelNumber(BufTagGetNRelFileInfo(entry->key)) != 0)
{
for (int i = 0; i < lfc_blocks_per_chunk; i++)
n_pages += GET_STATE(entry, i) == AVAILABLE;
}
}
}
}
@@ -2224,28 +2076,25 @@ local_cache_pages(PG_FUNCTION_ARGS)
* in the fctx->record structure.
*/
uint32 n = 0;
uint32 num_buckets = file_cache_hash_get_num_buckets(lfc_hash);
for (uint32 pos = 0; pos < num_buckets; pos++)
hash_seq_init(&status, lfc_hash);
while ((entry = hash_seq_search(&status)) != NULL)
{
entry = file_cache_hash_get_at_pos(lfc_hash, pos);
if (entry == NULL)
continue;
for (int i = 0; i < BLOCKS_PER_CHUNK; i++)
for (int i = 0; i < lfc_blocks_per_chunk; i++)
{
const BufferTag *key = file_cache_hash_get_key_for_entry(lfc_hash, entry);
if (GET_STATE(entry, i) == AVAILABLE)
if (NInfoGetRelNumber(BufTagGetNRelFileInfo(entry->key)) != 0)
{
fctx->record[n].pageoffs = entry->offset * BLOCKS_PER_CHUNK + i;
fctx->record[n].relfilenode = NInfoGetRelNumber(BufTagGetNRelFileInfo(*key));
fctx->record[n].reltablespace = NInfoGetSpcOid(BufTagGetNRelFileInfo(*key));
fctx->record[n].reldatabase = NInfoGetDbOid(BufTagGetNRelFileInfo(*key));
fctx->record[n].forknum = key->forkNum;
fctx->record[n].blocknum = key->blockNum + i;
fctx->record[n].accesscount = entry->access_count;
n += 1;
if (GET_STATE(entry, i) == AVAILABLE)
{
fctx->record[n].pageoffs = entry->offset * lfc_blocks_per_chunk + i;
fctx->record[n].relfilenode = NInfoGetRelNumber(BufTagGetNRelFileInfo(entry->key));
fctx->record[n].reltablespace = NInfoGetSpcOid(BufTagGetNRelFileInfo(entry->key));
fctx->record[n].reldatabase = NInfoGetDbOid(BufTagGetNRelFileInfo(entry->key));
fctx->record[n].forknum = entry->key.forkNum;
fctx->record[n].blocknum = entry->key.blockNum + i;
fctx->record[n].accesscount = entry->access_count;
n += 1;
}
}
}
}

View File

@@ -89,6 +89,7 @@ tokio-postgres = { workspace = true, optional = true }
tokio-rustls.workspace = true
tokio-util.workspace = true
tokio = { workspace = true, features = ["signal"] }
toml.workspace = true
tracing-subscriber.workspace = true
tracing-utils.workspace = true
tracing.workspace = true

View File

@@ -18,11 +18,6 @@ pub(super) async fn authenticate(
secret: AuthSecret,
) -> auth::Result<ComputeCredentials> {
let scram_keys = match secret {
#[cfg(any(test, feature = "testing"))]
AuthSecret::Md5(_) => {
debug!("auth endpoint chooses MD5");
return Err(auth::AuthError::MalformedPassword("MD5 not supported"));
}
AuthSecret::Scram(secret) => {
debug!("auth endpoint chooses SCRAM");

View File

@@ -6,10 +6,9 @@ use thiserror::Error;
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::{info, info_span};
use super::ComputeCredentialKeys;
use crate::auth::IpPattern;
use crate::auth::backend::ComputeUserInfo;
use crate::cache::Cached;
use crate::compute::AuthInfo;
use crate::config::AuthenticationConfig;
use crate::context::RequestContext;
use crate::control_plane::client::cplane_proxy_v1;
@@ -98,15 +97,11 @@ impl ConsoleRedirectBackend {
ctx: &RequestContext,
auth_config: &'static AuthenticationConfig,
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
) -> auth::Result<(
ConsoleRedirectNodeInfo,
ComputeUserInfo,
Option<Vec<IpPattern>>,
)> {
) -> auth::Result<(ConsoleRedirectNodeInfo, AuthInfo, ComputeUserInfo)> {
authenticate(ctx, auth_config, &self.console_uri, client)
.await
.map(|(node_info, user_info, ip_allowlist)| {
(ConsoleRedirectNodeInfo(node_info), user_info, ip_allowlist)
.map(|(node_info, auth_info, user_info)| {
(ConsoleRedirectNodeInfo(node_info), auth_info, user_info)
})
}
}
@@ -121,10 +116,6 @@ impl ComputeConnectBackend for ConsoleRedirectNodeInfo {
) -> Result<CachedNodeInfo, control_plane::errors::WakeComputeError> {
Ok(Cached::new_uncached(self.0.clone()))
}
fn get_keys(&self) -> &ComputeCredentialKeys {
&ComputeCredentialKeys::None
}
}
async fn authenticate(
@@ -132,7 +123,7 @@ async fn authenticate(
auth_config: &'static AuthenticationConfig,
link_uri: &reqwest::Url,
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
) -> auth::Result<(NodeInfo, ComputeUserInfo, Option<Vec<IpPattern>>)> {
) -> auth::Result<(NodeInfo, AuthInfo, ComputeUserInfo)> {
ctx.set_auth_method(crate::context::AuthMethod::ConsoleRedirect);
// registering waiter can fail if we get unlucky with rng.
@@ -192,10 +183,24 @@ async fn authenticate(
client.write_message(BeMessage::NoticeResponse("Connecting to database."));
// This config should be self-contained, because we won't
// take username or dbname from client's startup message.
let mut config = compute::ConnCfg::new(db_info.host.to_string(), db_info.port);
config.dbname(&db_info.dbname).user(&db_info.user);
// Backwards compatibility. pg_sni_proxy uses "--" in domain names
// while direct connections do not. Once we migrate to pg_sni_proxy
// everywhere, we can remove this.
let ssl_mode = if db_info.host.contains("--") {
// we need TLS connection with SNI info to properly route it
SslMode::Require
} else {
SslMode::Disable
};
let conn_info = compute::ConnectInfo {
host: db_info.host.into(),
port: db_info.port,
ssl_mode,
host_addr: None,
};
let auth_info =
AuthInfo::for_console_redirect(&db_info.dbname, &db_info.user, db_info.password.as_deref());
let user: RoleName = db_info.user.into();
let user_info = ComputeUserInfo {
@@ -209,26 +214,12 @@ async fn authenticate(
ctx.set_project(db_info.aux.clone());
info!("woken up a compute node");
// Backwards compatibility. pg_sni_proxy uses "--" in domain names
// while direct connections do not. Once we migrate to pg_sni_proxy
// everywhere, we can remove this.
if db_info.host.contains("--") {
// we need TLS connection with SNI info to properly route it
config.ssl_mode(SslMode::Require);
} else {
config.ssl_mode(SslMode::Disable);
}
if let Some(password) = db_info.password {
config.password(password.as_ref());
}
Ok((
NodeInfo {
config,
conn_info,
aux: db_info.aux,
},
auth_info,
user_info,
db_info.allowed_ips,
))
}

View File

@@ -1,11 +1,12 @@
use std::net::SocketAddr;
use arc_swap::ArcSwapOption;
use postgres_client::config::SslMode;
use tokio::sync::Semaphore;
use super::jwt::{AuthRule, FetchAuthRules};
use crate::auth::backend::jwt::FetchAuthRulesError;
use crate::compute::ConnCfg;
use crate::compute::ConnectInfo;
use crate::compute_ctl::ComputeCtlApi;
use crate::context::RequestContext;
use crate::control_plane::NodeInfo;
@@ -29,7 +30,12 @@ impl LocalBackend {
api: http::Endpoint::new(compute_ctl, http::new_client()),
},
node_info: NodeInfo {
config: ConnCfg::new(postgres_addr.ip().to_string(), postgres_addr.port()),
conn_info: ConnectInfo {
host_addr: Some(postgres_addr.ip()),
host: postgres_addr.ip().to_string().into(),
port: postgres_addr.port(),
ssl_mode: SslMode::Disable,
},
// TODO(conrad): make this better reflect compute info rather than endpoint info.
aux: MetricsAuxInfo {
endpoint_id: EndpointIdTag::get_interner().get_or_intern("local"),

View File

@@ -168,8 +168,6 @@ impl ComputeUserInfo {
#[cfg_attr(test, derive(Debug))]
pub(crate) enum ComputeCredentialKeys {
#[cfg(any(test, feature = "testing"))]
Password(Vec<u8>),
AuthKeys(AuthKeys),
JwtPayload(Vec<u8>),
None,
@@ -419,13 +417,6 @@ impl ComputeConnectBackend for Backend<'_, ComputeCredentials> {
Self::Local(local) => Ok(Cached::new_uncached(local.node_info.clone())),
}
}
fn get_keys(&self) -> &ComputeCredentialKeys {
match self {
Self::ControlPlane(_, creds) => &creds.keys,
Self::Local(_) => &ComputeCredentialKeys::None,
}
}
}
#[cfg(test)]

View File

@@ -169,13 +169,6 @@ pub(crate) async fn validate_password_and_exchange(
secret: AuthSecret,
) -> super::Result<sasl::Outcome<ComputeCredentialKeys>> {
match secret {
#[cfg(any(test, feature = "testing"))]
AuthSecret::Md5(_) => {
// test only
Ok(sasl::Outcome::Success(ComputeCredentialKeys::Password(
password.to_owned(),
)))
}
// perform scram authentication as both client and server to validate the keys
AuthSecret::Scram(scram_secret) => {
let outcome = crate::scram::exchange(pool, endpoint, &scram_secret, password).await?;

View File

@@ -279,7 +279,6 @@ fn build_config(args: &LocalProxyCliArgs) -> anyhow::Result<&'static ProxyConfig
},
proxy_protocol_v2: config::ProxyProtocolV2::Rejected,
handshake_timeout: Duration::from_secs(10),
region: "local".into(),
wake_compute_retry_config: RetryConfig::parse(RetryConfig::WAKE_COMPUTE_DEFAULT_VALUES)?,
connect_compute_locks,
connect_to_compute: compute_config,

View File

@@ -237,7 +237,6 @@ pub(super) async fn task_main(
extra: None,
},
crate::metrics::Protocol::SniRouter,
"sni",
);
handle_client(ctx, dest_suffix, tls_config, compute_tls_config, socket).await
}

View File

@@ -8,14 +8,15 @@ use std::time::Duration;
#[cfg(any(test, feature = "testing"))]
use anyhow::Context;
use anyhow::{bail, ensure};
use anyhow::{bail, anyhow};
use arc_swap::ArcSwapOption;
use futures::future::Either;
use remote_storage::RemoteStorageConfig;
use serde::Deserialize;
use tokio::net::TcpListener;
use tokio::task::JoinSet;
use tokio_util::sync::CancellationToken;
use tracing::{Instrument, info, warn};
use tracing::{Instrument, info};
use utils::sentry_init::init_sentry;
use utils::{project_build_tag, project_git_version};
@@ -39,7 +40,7 @@ use crate::serverless::cancel_set::CancelSet;
use crate::tls::client_config::compute_client_config_with_root_certs;
#[cfg(any(test, feature = "testing"))]
use crate::url::ApiUrl;
use crate::{auth, control_plane, http, serverless, usage_metrics};
use crate::{auth, control_plane, http, pglb, serverless, usage_metrics};
project_git_version!(GIT_VERSION);
project_build_tag!(BUILD_TAG);
@@ -59,6 +60,262 @@ enum AuthBackendType {
Postgres,
}
#[derive(Deserialize)]
struct Root {
#[serde(flatten)]
legacy: LegacyModes,
introspection: Introspection,
}
#[derive(Deserialize)]
#[serde(untagged)]
enum LegacyModes {
Proxy {
pglb: Pglb,
neonkeeper: NeonKeeper,
http: Option<Http>,
pg_sni_router: Option<PgSniRouter>,
},
AuthBroker {
neonkeeper: NeonKeeper,
http: Http,
},
ConsoleRedirect {
console_redirect: ConsoleRedirect,
},
}
#[derive(Deserialize)]
struct Pglb {
listener: Listener,
}
#[derive(Deserialize)]
struct Listener {
/// address to bind to
addr: SocketAddr,
/// which header should we expect to see on this socket
/// from the load balancer
header: Option<ProxyHeader>,
/// certificates used for TLS.
/// first cert is the default.
/// TLS not used if no certs provided.
certs: Vec<KeyPair>,
/// Timeout to use for TLS handshake
timeout: Option<Duration>,
}
#[derive(Deserialize)]
enum ProxyHeader {
/// Accept the PROXY! protocol V2.
ProxyProtocolV2(ProxyProtocolV2Kind),
}
#[derive(Deserialize)]
enum ProxyProtocolV2Kind {
/// Expect AWS TLVs in the header.
Aws,
/// Expect Azure TLVs in the header.
Azure,
}
#[derive(Deserialize)]
struct KeyPair {
key: PathBuf,
cert: PathBuf,
}
#[derive(Deserialize)]
/// The service that authenticates all incoming connection attempts,
/// provides monitoring and also wakes computes.
struct NeonKeeper {
cplane: ControlPlaneBackend,
redis: Option<Redis>,
auth: Vec<AuthMechanism>,
/// map of endpoint->computeinfo
compute: Cache,
/// cache for GetEndpointAccessControls.
project_info_cache: config::ProjectInfoCacheOptions,
/// cache for all valid endpoints
endpoint_cache_config: config::EndpointCacheConfig,
request_log_export: Option<RequestLogExport>,
data_transfer_export: Option<DataTransferExport>,
}
#[derive(Deserialize)]
struct Redis {
/// Cancellation channel size (max queue size for redis kv client)
cancellation_ch_size: usize,
/// Cancellation ops batch size for redis
cancellation_batch_size: usize,
auth: RedisAuthentication,
}
#[derive(Deserialize)]
enum RedisAuthentication {
/// i don't remember what this stands for.
/// IAM roles for service accounts?
Irsa {
host: String,
port: u16,
cluster_name: Option<String>,
user_id: Option<String>,
aws_region: String,
},
Basic {
url: url::Url,
},
}
#[derive(Deserialize)]
struct PgSniRouter {
/// The listener to use to proxy connections to compute,
/// assuming the compute does not support TLS.
listener: Listener,
/// The listener to use to proxy connections to compute,
/// assuming the compute requires TLS.
listener_tls: Listener,
/// append this domain zone to the SNI hostname to get the destination address
dest: String,
}
#[derive(Deserialize)]
/// `psql -h pg.neon.tech`.
struct ConsoleRedirect {
/// Connection requests from clients.
listener: Listener,
/// Messages from control plane to accept the connection.
cplane: Listener,
/// The base url to use for redirects.
console: url::Url,
timeout: Duration,
}
#[derive(Deserialize)]
enum ControlPlaneBackend {
/// Use the HTTP API to access the control plane.
Http(url::Url),
/// Stub the control plane with a postgres instance.
#[cfg(feature = "testing")]
PostgresMock(url::Url),
}
#[derive(Deserialize)]
struct Http {
listener: Listener,
sql_over_http: SqlOverHttp,
// todo: move into Pglb.
websockets: Option<Websockets>,
}
#[derive(Deserialize)]
struct SqlOverHttp {
pool_max_conns_per_endpoint: usize,
pool_max_total_conns: usize,
pool_idle_timeout: Duration,
pool_gc_epoch: Duration,
pool_shards: usize,
client_conn_threshold: u64,
cancel_set_shards: usize,
timeout: Duration,
max_request_size_bytes: usize,
max_response_size_bytes: usize,
auth: Vec<AuthMechanism>,
}
#[derive(Deserialize)]
enum AuthMechanism {
Sasl {
/// timeout for SASL handshake
timeout: Duration,
},
CleartextPassword {
/// number of threads for the thread pool
threads: usize,
},
// add something about the jwks cache i guess.
Jwt {},
}
#[derive(Deserialize)]
struct Websockets {
auth: Vec<AuthMechanism>,
}
#[derive(Deserialize)]
/// The HTTP API used for internal monitoring.
struct Introspection {
listener: Listener,
}
#[derive(Deserialize)]
enum RequestLogExport {
Parquet {
location: RemoteStorageConfig,
disconnect: RemoteStorageConfig,
/// The region identifier to tag the entries with.
region: String,
/// How many rows to include in a row group
row_group_size: usize,
/// How large each column page should be in bytes
page_size: usize,
/// How large the total parquet file should be in bytes
size: i64,
/// How long to wait before forcing a file upload
maximum_duration: tokio::time::Duration,
// /// What level of compression to use
// compression: Compression,
},
}
#[derive(Deserialize)]
enum Cache {
/// Expire by LRU or by idle.
/// Note: "live" in "time-to-live" actually means idle here.
LruTtl {
/// Max number of entries.
size: usize,
/// Entry's time-to-live.
ttl: Duration,
},
}
#[derive(Deserialize)]
struct DataTransferExport {
/// http endpoint to receive periodic metric updates
endpoint: Option<String>,
/// how often metrics should be sent to a collection endpoint
interval: Option<String>,
/// interval for backup metric collection
backup_interval: std::time::Duration,
/// remote storage configuration for backup metric collection
/// Encoded as toml (same format as pageservers), eg
/// `{bucket_name='the-bucket',bucket_region='us-east-1',prefix_in_bucket='proxy',endpoint='http://minio:9000'}`
backup_remote_storage: Option<RemoteStorageConfig>,
/// chunk size for backup metric collection
/// Size of each event is no more than 400 bytes, so 2**22 is about 200MB before the compression.
backup_chunk_size: usize,
}
/// Neon proxy/router
#[derive(Parser)]
#[command(version = GIT_VERSION, about)]
@@ -120,12 +377,6 @@ struct ProxyCliArgs {
/// timeout for the TLS handshake
#[clap(long, default_value = "15s", value_parser = humantime::parse_duration)]
handshake_timeout: tokio::time::Duration,
/// http endpoint to receive periodic metric updates
#[clap(long)]
metric_collection_endpoint: Option<String>,
/// how often metrics should be sent to a collection endpoint
#[clap(long)]
metric_collection_interval: Option<String>,
/// cache for `wake_compute` api method (use `size=0` to disable)
#[clap(long, default_value = config::CacheOptions::CACHE_DEFAULT_OPTIONS)]
wake_compute_cache: String,
@@ -152,40 +403,31 @@ struct ProxyCliArgs {
/// Wake compute rate limiter max number of requests per second.
#[clap(long, default_values_t = RateBucketInfo::DEFAULT_SET)]
wake_compute_limit: Vec<RateBucketInfo>,
/// Redis rate limiter max number of requests per second.
#[clap(long, default_values_t = RateBucketInfo::DEFAULT_REDIS_SET)]
redis_rps_limit: Vec<RateBucketInfo>,
/// Cancellation channel size (max queue size for redis kv client)
#[clap(long, default_value_t = 1024)]
cancellation_ch_size: usize,
/// Cancellation ops batch size for redis
#[clap(long, default_value_t = 8)]
cancellation_batch_size: usize,
/// cache for `allowed_ips` (use `size=0` to disable)
#[clap(long, default_value = config::CacheOptions::CACHE_DEFAULT_OPTIONS)]
allowed_ips_cache: String,
/// cache for `role_secret` (use `size=0` to disable)
#[clap(long, default_value = config::CacheOptions::CACHE_DEFAULT_OPTIONS)]
role_secret_cache: String,
/// redis url for notifications (if empty, redis_host:port will be used for both notifications and streaming connections)
/// redis url for plain authentication
#[clap(long, alias("redis-notifications"))]
redis_plain: Option<String>,
/// what from the available authentications type to use for redis. Supported are "irsa" and "plain".
#[clap(long)]
redis_notifications: Option<String>,
/// what from the available authentications type to use for the regional redis we have. Supported are "irsa" and "plain".
#[clap(long, default_value = "irsa")]
redis_auth_type: String,
/// redis host for streaming connections (might be different from the notifications host)
redis_auth_type: Option<String>,
/// redis host for irsa authentication
#[clap(long)]
redis_host: Option<String>,
/// redis port for streaming connections (might be different from the notifications host)
/// redis port for irsa authentication
#[clap(long)]
redis_port: Option<u16>,
/// redis cluster name, used in aws elasticache
/// redis cluster name for irsa authentication
#[clap(long)]
redis_cluster_name: Option<String>,
/// redis user_id, used in aws elasticache
/// redis user_id for irsa authentication
#[clap(long)]
redis_user_id: Option<String>,
/// aws region to retrieve credentials
/// aws region for irsa authentication
#[clap(long, default_value_t = String::new())]
aws_region: String,
/// cache for `project_info` (use `size=0` to disable)
@@ -197,6 +439,12 @@ struct ProxyCliArgs {
#[clap(flatten)]
parquet_upload: ParquetUploadArgs,
/// http endpoint to receive periodic metric updates
#[clap(long)]
metric_collection_endpoint: Option<String>,
/// how often metrics should be sent to a collection endpoint
#[clap(long)]
metric_collection_interval: Option<String>,
/// interval for backup metric collection
#[clap(long, default_value = "10m", value_parser = humantime::parse_duration)]
metric_backup_collection_interval: std::time::Duration,
@@ -209,6 +457,7 @@ struct ProxyCliArgs {
/// Size of each event is no more than 400 bytes, so 2**22 is about 200MB before the compression.
#[clap(long, default_value = "4194304")]
metric_backup_collection_chunk_size: usize,
/// Whether to retry the connection to the compute node
#[clap(long, default_value = config::RetryConfig::CONNECT_TO_COMPUTE_DEFAULT_VALUES)]
connect_to_compute_retry: String,
@@ -319,208 +568,120 @@ pub async fn run() -> anyhow::Result<()> {
}
};
let args = ProxyCliArgs::parse();
let config = build_config(&args)?;
let auth_backend = build_auth_backend(&args)?;
match auth_backend {
Either::Left(auth_backend) => info!("Authentication backend: {auth_backend}"),
Either::Right(auth_backend) => info!("Authentication backend: {auth_backend:?}"),
}
info!("Using region: {}", args.aws_region);
let (regional_redis_client, redis_notifications_client) = configure_redis(&args).await?;
// Check that we can bind to address before further initialization
info!("Starting http on {}", args.http);
let http_listener = TcpListener::bind(args.http).await?.into_std()?;
info!("Starting mgmt on {}", args.mgmt);
let mgmt_listener = TcpListener::bind(args.mgmt).await?;
let proxy_listener = if args.is_auth_broker {
None
} else {
info!("Starting proxy on {}", args.proxy);
Some(TcpListener::bind(args.proxy).await?)
};
let sni_router_listeners = {
let args = &args.pg_sni_router;
if args.dest.is_some() {
ensure!(
args.tls_key.is_some(),
"sni-router-tls-key must be provided"
);
ensure!(
args.tls_cert.is_some(),
"sni-router-tls-cert must be provided"
);
info!(
"Starting pg-sni-router on {} and {}",
args.listen, args.listen_tls
);
Some((
TcpListener::bind(args.listen).await?,
TcpListener::bind(args.listen_tls).await?,
))
} else {
None
}
};
// TODO: rename the argument to something like serverless.
// It now covers more than just websockets, it also covers SQL over HTTP.
let serverless_listener = if let Some(serverless_address) = args.wss {
info!("Starting wss on {serverless_address}");
Some(TcpListener::bind(serverless_address).await?)
} else if args.is_auth_broker {
bail!("wss arg must be present for auth-broker")
} else {
None
};
let cancellation_token = CancellationToken::new();
let redis_rps_limit = Vec::leak(args.redis_rps_limit.clone());
RateBucketInfo::validate(redis_rps_limit)?;
let redis_kv_client = regional_redis_client
.as_ref()
.map(|redis_publisher| RedisKVClient::new(redis_publisher.clone(), redis_rps_limit));
// channel size should be higher than redis client limit to avoid blocking
let cancel_ch_size = args.cancellation_ch_size;
let (tx_cancel, rx_cancel) = tokio::sync::mpsc::channel(cancel_ch_size);
let cancellation_handler = Arc::new(CancellationHandler::new(
&config.connect_to_compute,
Some(tx_cancel),
));
let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new_with_shards(
RateBucketInfo::to_leaky_bucket(&args.endpoint_rps_limit)
.unwrap_or(EndpointRateLimiter::DEFAULT),
64,
));
let config: Root = toml::from_str(&tokio::fs::read_to_string("proxy.toml").await?)?;
// client facing tasks. these will exit on error or on cancellation
// cancellation returns Ok(())
let mut client_tasks = JoinSet::new();
match auth_backend {
Either::Left(auth_backend) => {
if let Some(proxy_listener) = proxy_listener {
client_tasks.spawn(crate::proxy::task_main(
config,
auth_backend,
proxy_listener,
cancellation_token.clone(),
cancellation_handler.clone(),
endpoint_rate_limiter.clone(),
));
}
if let Some(serverless_listener) = serverless_listener {
client_tasks.spawn(serverless::task_main(
config,
auth_backend,
serverless_listener,
cancellation_token.clone(),
cancellation_handler.clone(),
endpoint_rate_limiter.clone(),
));
}
}
Either::Right(auth_backend) => {
if let Some(proxy_listener) = proxy_listener {
client_tasks.spawn(crate::console_redirect_proxy::task_main(
config,
auth_backend,
proxy_listener,
cancellation_token.clone(),
cancellation_handler.clone(),
));
}
}
}
// spawn pg-sni-router mode.
if let Some((listen, listen_tls)) = sni_router_listeners {
let args = args.pg_sni_router;
let dest = args.dest.expect("already asserted it is set");
let key_path = args.tls_key.expect("already asserted it is set");
let cert_path = args.tls_cert.expect("already asserted it is set");
let tls_config = super::pg_sni_router::parse_tls(&key_path, &cert_path)?;
let dest = Arc::new(dest);
client_tasks.spawn(super::pg_sni_router::task_main(
dest.clone(),
tls_config.clone(),
None,
listen,
cancellation_token.clone(),
));
client_tasks.spawn(super::pg_sni_router::task_main(
dest,
tls_config,
Some(config.connect_to_compute.tls.clone()),
listen_tls,
cancellation_token.clone(),
));
}
client_tasks.spawn(crate::context::parquet::worker(
cancellation_token.clone(),
args.parquet_upload,
));
// maintenance tasks. these never return unless there's an error
let mut maintenance_tasks = JoinSet::new();
maintenance_tasks.spawn(crate::signals::handle(cancellation_token.clone(), || {}));
maintenance_tasks.spawn(http::health_server::task_main(
http_listener,
AppMetrics {
jemalloc,
neon_metrics,
proxy: crate::metrics::Metrics::get(),
},
));
maintenance_tasks.spawn(control_plane::mgmt::task_main(mgmt_listener));
if let Some(metrics_config) = &config.metric_collection {
// TODO: Add gc regardles of the metric collection being enabled.
maintenance_tasks.spawn(usage_metrics::task_main(metrics_config));
}
let cancellation_token = CancellationToken::new();
#[cfg_attr(not(any(test, feature = "testing")), expect(irrefutable_let_patterns))]
if let Either::Left(auth::Backend::ControlPlane(api, ())) = &auth_backend {
if let crate::control_plane::client::ControlPlaneClient::ProxyV1(api) = &**api {
match (redis_notifications_client, regional_redis_client.clone()) {
(None, None) => {}
(client1, client2) => {
let cache = api.caches.project_info.clone();
if let Some(client) = client1 {
maintenance_tasks.spawn(notifications::task_main(
client,
cache.clone(),
args.region.clone(),
));
}
if let Some(client) = client2 {
maintenance_tasks.spawn(notifications::task_main(
client,
cache.clone(),
args.region.clone(),
));
}
maintenance_tasks.spawn(async move { cache.clone().gc_worker().await });
}
match config.legacy {
LegacyModes::Proxy {
pglb,
neonkeeper,
http,
pg_sni_router,
} => {
let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new_with_shards(
// todo: use neonkeeper config.
EndpointRateLimiter::DEFAULT,
64,
));
info!("Starting proxy on {}", pglb.listener.addr);
let proxy_listener = TcpListener::bind(pglb.listener.addr).await?;
client_tasks.spawn(crate::proxy::task_main(
config,
auth_backend,
proxy_listener,
cancellation_token.clone(),
cancellation_handler.clone(),
endpoint_rate_limiter.clone(),
));
if let Some(http) = http {
info!("Starting wss on {}", http.listener.addr);
let http_listener = TcpListener::bind(http.listener.addr).await?;
client_tasks.spawn(serverless::task_main(
config,
auth_backend,
http_listener,
cancellation_token.clone(),
cancellation_handler.clone(),
endpoint_rate_limiter.clone(),
));
};
if let Some(redis) = neonkeeper.redis {
let client = configure_redis(redis.auth);
}
if let Some(mut redis_kv_client) = redis_kv_client {
if let Some(sni_router) = pg_sni_router {
let listen = TcpListener::bind(sni_router.listener.addr).await?;
let listen_tls = TcpListener::bind(sni_router.listener_tls.addr).await?;
let [KeyPair { key, cert }] = sni_router
.listener
.certs
.try_into()
.map_err(|_| anyhow!("only 1 keypair is supported for pg-sni-router"))?;
let tls_config = super::pg_sni_router::parse_tls(&key, &cert)?;
let dest = Arc::new(sni_router.dest);
client_tasks.spawn(super::pg_sni_router::task_main(
dest.clone(),
tls_config.clone(),
None,
listen,
cancellation_token.clone(),
));
client_tasks.spawn(super::pg_sni_router::task_main(
dest,
tls_config,
Some(config.connect_to_compute.tls.clone()),
listen_tls,
cancellation_token.clone(),
));
}
match neonkeeper.request_log_export {
Some(RequestLogExport::Parquet {
location,
disconnect,
region,
row_group_size,
page_size,
size,
maximum_duration,
}) => {
client_tasks.spawn(crate::context::parquet::worker(
cancellation_token.clone(),
args.parquet_upload,
args.region,
));
}
None => {}
}
if let (ControlPlaneBackend::Http(api), Some(redis)) =
(neonkeeper.cplane, neonkeeper.redis)
{
// project info cache and invalidation of that cache.
let cache = api.caches.project_info.clone();
maintenance_tasks.spawn(notifications::task_main(client.clone(), cache.clone()));
maintenance_tasks.spawn(async move { cache.clone().gc_worker().await });
// cancellation key management
let mut redis_kv_client = RedisKVClient::new(client.clone());
maintenance_tasks.spawn(async move {
redis_kv_client.try_connect().await?;
handle_cancel_messages(
@@ -537,18 +698,139 @@ pub async fn run() -> anyhow::Result<()> {
// so let's wait forever instead.
std::future::pending().await
});
}
if let Some(regional_redis_client) = regional_redis_client {
// listen for notifications of new projects/endpoints/branches
let cache = api.caches.endpoints_cache.clone();
let con = regional_redis_client;
let span = tracing::info_span!("endpoints_cache");
maintenance_tasks.spawn(
async move { cache.do_read(con, cancellation_token.clone()).await }
async move { cache.do_read(client, cancellation_token.clone()).await }
.instrument(span),
);
}
}
LegacyModes::AuthBroker { neonkeeper, http } => {
let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new_with_shards(
// todo: use neonkeeper config.
EndpointRateLimiter::DEFAULT,
64,
));
info!("Starting wss on {}", http.listener.addr);
let http_listener = TcpListener::bind(http.listener.addr).await?;
if let Some(redis) = neonkeeper.redis {
let client = configure_redis(redis.auth);
}
client_tasks.spawn(serverless::task_main(
config,
auth_backend,
serverless_listener,
cancellation_token.clone(),
cancellation_handler.clone(),
endpoint_rate_limiter.clone(),
));
match neonkeeper.request_log_export {
Some(RequestLogExport::Parquet {
location,
disconnect,
region,
row_group_size,
page_size,
size,
maximum_duration,
}) => {
client_tasks.spawn(crate::context::parquet::worker(
cancellation_token.clone(),
args.parquet_upload,
args.region,
));
}
None => {}
}
if let (ControlPlaneBackend::Http(api), Some(redis)) =
(neonkeeper.cplane, neonkeeper.redis)
{
// project info cache and invalidation of that cache.
let cache = api.caches.project_info.clone();
maintenance_tasks.spawn(notifications::task_main(client.clone(), cache.clone()));
maintenance_tasks.spawn(async move { cache.clone().gc_worker().await });
// cancellation key management
let mut redis_kv_client = RedisKVClient::new(client.clone());
maintenance_tasks.spawn(async move {
redis_kv_client.try_connect().await?;
handle_cancel_messages(
&mut redis_kv_client,
rx_cancel,
args.cancellation_batch_size,
)
.await?;
drop(redis_kv_client);
// `handle_cancel_messages` was terminated due to the tx_cancel
// being dropped. this is not worthy of an error, and this task can only return `Err`,
// so let's wait forever instead.
std::future::pending().await
});
// listen for notifications of new projects/endpoints/branches
let cache = api.caches.endpoints_cache.clone();
let span = tracing::info_span!("endpoints_cache");
maintenance_tasks.spawn(
async move { cache.do_read(client, cancellation_token.clone()).await }
.instrument(span),
);
}
}
LegacyModes::ConsoleRedirect { console_redirect } => {
info!("Starting proxy on {}", console_redirect.listener.addr);
let proxy_listener = TcpListener::bind(console_redirect.listener.addr).await?;
info!("Starting mgmt on {}", console_redirect.listener.addr);
let mgmt_listener = TcpListener::bind(console_redirect.listener.addr).await?;
client_tasks.spawn(crate::console_redirect_proxy::task_main(
config,
auth_backend,
proxy_listener,
cancellation_token.clone(),
cancellation_handler.clone(),
));
maintenance_tasks.spawn(control_plane::mgmt::task_main(mgmt_listener));
}
}
// Check that we can bind to address before further initialization
info!("Starting http on {}", config.introspection.listener.addr);
let http_listener = TcpListener::bind(config.introspection.listener.addr)
.await?
.into_std()?;
// channel size should be higher than redis client limit to avoid blocking
let cancel_ch_size = args.cancellation_ch_size;
let (tx_cancel, rx_cancel) = tokio::sync::mpsc::channel(cancel_ch_size);
let cancellation_handler = Arc::new(CancellationHandler::new(
&config.connect_to_compute,
Some(tx_cancel),
));
maintenance_tasks.spawn(crate::signals::handle(cancellation_token.clone(), || {}));
maintenance_tasks.spawn(http::health_server::task_main(
http_listener,
AppMetrics {
jemalloc,
neon_metrics,
proxy: crate::metrics::Metrics::get(),
},
));
if let Some(metrics_config) = &config.metric_collection {
// TODO: Add gc regardles of the metric collection being enabled.
maintenance_tasks.spawn(usage_metrics::task_main(metrics_config));
}
let maintenance = loop {
@@ -673,7 +955,6 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
authentication_config,
proxy_protocol_v2: args.proxy_protocol_v2,
handshake_timeout: args.handshake_timeout,
region: args.region.clone(),
wake_compute_retry_config: config::RetryConfig::parse(&args.wake_compute_retry)?,
connect_compute_locks,
connect_to_compute: compute_config,
@@ -833,58 +1114,45 @@ fn build_auth_backend(
}
}
async fn configure_redis(
args: &ProxyCliArgs,
) -> anyhow::Result<(
Option<ConnectionWithCredentialsProvider>,
Option<ConnectionWithCredentialsProvider>,
)> {
// TODO: untangle the config args
let regional_redis_client = match (args.redis_auth_type.as_str(), &args.redis_notifications) {
("plain", redis_url) => match redis_url {
None => {
bail!("plain auth requires redis_notifications to be set");
}
Some(url) => {
Some(ConnectionWithCredentialsProvider::new_with_static_credentials(url.clone()))
}
},
("irsa", _) => match (&args.redis_host, args.redis_port) {
(Some(host), Some(port)) => Some(
ConnectionWithCredentialsProvider::new_with_credentials_provider(
host.clone(),
port,
elasticache::CredentialsProvider::new(
args.aws_region.clone(),
args.redis_cluster_name.clone(),
args.redis_user_id.clone(),
)
.await,
),
),
(None, None) => {
// todo: upgrade to error?
warn!(
"irsa auth requires redis-host and redis-port to be set, continuing without regional_redis_client"
);
None
}
_ => {
bail!("redis-host and redis-port must be specified together");
}
},
_ => {
bail!("unknown auth type given");
async fn configure_redis(auth: RedisAuthentication) -> ConnectionWithCredentialsProvider {
match auth {
RedisAuthentication::Irsa {
host,
port,
cluster_name,
user_id,
aws_region,
} => ConnectionWithCredentialsProvider::new_with_credentials_provider(
host,
port,
elasticache::CredentialsProvider::new(aws_region, cluster_name, user_id).await,
),
RedisAuthentication::Basic { url } => {
ConnectionWithCredentialsProvider::new_with_static_credentials(url.clone())
}
}
}
None => None,
};
let redis_notifications_client = if let Some(url) = &args.redis_notifications {
Some(ConnectionWithCredentialsProvider::new_with_static_credentials(&**url))
} else {
regional_redis_client.clone()
// let redis_notifications_client = if let Some(url) = &args.redis_notifications {
// Some(ConnectionWithCredentialsProvider::new_with_static_credentials(&**url))
// } else {
// regional_redis_client.clone()
// };
Ok(redis_client)
}
None => None,
};
Ok((regional_redis_client, redis_notifications_client))
// let redis_notifications_client = if let Some(url) = &args.redis_notifications {
// Some(ConnectionWithCredentialsProvider::new_with_static_credentials(&**url))
// } else {
// regional_redis_client.clone()
// };
Ok(redis_client)
}
#[cfg(test)]

View File

@@ -18,6 +18,7 @@ use crate::types::{EndpointId, RoleName};
#[async_trait]
pub(crate) trait ProjectInfoCache {
fn invalidate_endpoint_access(&self, endpoint_id: EndpointIdInt);
fn invalidate_endpoint_access_for_project(&self, project_id: ProjectIdInt);
fn invalidate_endpoint_access_for_org(&self, account_id: AccountIdInt);
fn invalidate_role_secret_for_project(&self, project_id: ProjectIdInt, role_name: RoleNameInt);
@@ -100,6 +101,13 @@ pub struct ProjectInfoCacheImpl {
#[async_trait]
impl ProjectInfoCache for ProjectInfoCacheImpl {
fn invalidate_endpoint_access(&self, endpoint_id: EndpointIdInt) {
info!("invalidating endpoint access for `{endpoint_id}`");
if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) {
endpoint_info.invalidate_endpoint();
}
}
fn invalidate_endpoint_access_for_project(&self, project_id: ProjectIdInt) {
info!("invalidating endpoint access for project `{project_id}`");
let endpoints = self

View File

@@ -24,7 +24,6 @@ use crate::pqproto::CancelKeyData;
use crate::rate_limiter::LeakyBucketRateLimiter;
use crate::redis::keys::KeyPrefix;
use crate::redis::kv_ops::RedisKVClient;
use crate::tls::postgres_rustls::MakeRustlsConnect;
type IpSubnetKey = IpNet;
@@ -497,10 +496,8 @@ impl CancelClosure {
) -> Result<(), CancelError> {
let socket = TcpStream::connect(self.socket_addr).await?;
let mut mk_tls =
crate::tls::postgres_rustls::MakeRustlsConnect::new(compute_config.tls.clone());
let tls = <MakeRustlsConnect as MakeTlsConnect<tokio::net::TcpStream>>::make_tls_connect(
&mut mk_tls,
let tls = <_ as MakeTlsConnect<tokio::net::TcpStream>>::make_tls_connect(
compute_config,
&self.hostname,
)
.map_err(|e| CancelError::IO(std::io::Error::other(e.to_string())))?;

View File

@@ -1,21 +1,24 @@
mod tls;
use std::fmt::Debug;
use std::io;
use std::net::SocketAddr;
use std::time::Duration;
use std::net::{IpAddr, SocketAddr};
use futures::{FutureExt, TryFutureExt};
use itertools::Itertools;
use postgres_client::config::{AuthKeys, SslMode};
use postgres_client::maybe_tls_stream::MaybeTlsStream;
use postgres_client::tls::MakeTlsConnect;
use postgres_client::{CancelToken, RawConnection};
use postgres_client::{CancelToken, NoTls, RawConnection};
use postgres_protocol::message::backend::NoticeResponseBody;
use rustls::pki_types::InvalidDnsNameError;
use thiserror::Error;
use tokio::net::{TcpStream, lookup_host};
use tracing::{debug, error, info, warn};
use crate::auth::backend::ComputeUserInfo;
use crate::auth::backend::{ComputeCredentialKeys, ComputeUserInfo};
use crate::auth::parse_endpoint_param;
use crate::cancellation::CancelClosure;
use crate::compute::tls::TlsError;
use crate::config::ComputeConfig;
use crate::context::RequestContext;
use crate::control_plane::client::ApiLockError;
@@ -25,7 +28,6 @@ use crate::error::{ReportableError, UserFacingError};
use crate::metrics::{Metrics, NumDbConnectionsGuard};
use crate::pqproto::StartupMessageParams;
use crate::proxy::neon_option;
use crate::tls::postgres_rustls::MakeRustlsConnect;
use crate::types::Host;
pub const COULD_NOT_CONNECT: &str = "Couldn't connect to compute node";
@@ -38,10 +40,7 @@ pub(crate) enum ConnectionError {
Postgres(#[from] postgres_client::Error),
#[error("{COULD_NOT_CONNECT}: {0}")]
CouldNotConnect(#[from] io::Error),
#[error("{COULD_NOT_CONNECT}: {0}")]
TlsError(#[from] InvalidDnsNameError),
TlsError(#[from] TlsError),
#[error("{COULD_NOT_CONNECT}: {0}")]
WakeComputeError(#[from] WakeComputeError),
@@ -73,7 +72,7 @@ impl UserFacingError for ConnectionError {
ConnectionError::TooManyConnectionAttempts(_) => {
"Failed to acquire permit to connect to the database. Too many database connection attempts are currently ongoing.".to_owned()
}
_ => COULD_NOT_CONNECT.to_owned(),
ConnectionError::TlsError(_) => COULD_NOT_CONNECT.to_owned(),
}
}
}
@@ -85,7 +84,6 @@ impl ReportableError for ConnectionError {
crate::error::ErrorKind::Postgres
}
ConnectionError::Postgres(_) => crate::error::ErrorKind::Compute,
ConnectionError::CouldNotConnect(_) => crate::error::ErrorKind::Compute,
ConnectionError::TlsError(_) => crate::error::ErrorKind::Compute,
ConnectionError::WakeComputeError(e) => e.get_error_kind(),
ConnectionError::TooManyConnectionAttempts(e) => e.get_error_kind(),
@@ -96,34 +94,85 @@ impl ReportableError for ConnectionError {
/// A pair of `ClientKey` & `ServerKey` for `SCRAM-SHA-256`.
pub(crate) type ScramKeys = postgres_client::config::ScramKeys<32>;
/// A config for establishing a connection to compute node.
/// Eventually, `postgres_client` will be replaced with something better.
/// Newtype allows us to implement methods on top of it.
#[derive(Clone)]
pub(crate) struct ConnCfg(Box<postgres_client::Config>);
pub enum Auth {
/// Only used during console-redirect.
Password(Vec<u8>),
/// Used by sql-over-http, ws, tcp.
Scram(Box<ScramKeys>),
}
/// A config for authenticating to the compute node.
pub(crate) struct AuthInfo {
/// None for local-proxy, as we use trust-based localhost auth.
/// Some for sql-over-http, ws, tcp, and in most cases for console-redirect.
/// Might be None for console-redirect, but that's only a consequence of testing environments ATM.
auth: Option<Auth>,
server_params: StartupMessageParams,
/// Console redirect sets user and database, we shouldn't re-use those from the params.
skip_db_user: bool,
}
/// Contains only the data needed to establish a secure connection to compute.
#[derive(Clone)]
pub struct ConnectInfo {
pub host_addr: Option<IpAddr>,
pub host: Host,
pub port: u16,
pub ssl_mode: SslMode,
}
/// Creation and initialization routines.
impl ConnCfg {
pub(crate) fn new(host: String, port: u16) -> Self {
Self(Box::new(postgres_client::Config::new(host, port)))
}
/// Reuse password or auth keys from the other config.
pub(crate) fn reuse_password(&mut self, other: Self) {
if let Some(password) = other.get_password() {
self.password(password);
}
if let Some(keys) = other.get_auth_keys() {
self.auth_keys(keys);
impl AuthInfo {
pub(crate) fn for_console_redirect(db: &str, user: &str, pw: Option<&str>) -> Self {
let mut server_params = StartupMessageParams::default();
server_params.insert("database", db);
server_params.insert("user", user);
Self {
auth: pw.map(|pw| Auth::Password(pw.as_bytes().to_owned())),
server_params,
skip_db_user: true,
}
}
pub(crate) fn get_host(&self) -> Host {
match self.0.get_host() {
postgres_client::config::Host::Tcp(s) => s.into(),
pub(crate) fn with_auth_keys(keys: &ComputeCredentialKeys) -> Self {
Self {
auth: match keys {
ComputeCredentialKeys::AuthKeys(AuthKeys::ScramSha256(auth_keys)) => {
Some(Auth::Scram(Box::new(*auth_keys)))
}
ComputeCredentialKeys::JwtPayload(_) | ComputeCredentialKeys::None => None,
},
server_params: StartupMessageParams::default(),
skip_db_user: false,
}
}
}
impl ConnectInfo {
pub fn to_postgres_client_config(&self) -> postgres_client::Config {
let mut config = postgres_client::Config::new(self.host.to_string(), self.port);
config.ssl_mode(self.ssl_mode);
if let Some(host_addr) = self.host_addr {
config.set_host_addr(host_addr);
}
config
}
}
impl AuthInfo {
fn enrich(&self, mut config: postgres_client::Config) -> postgres_client::Config {
match &self.auth {
Some(Auth::Scram(keys)) => config.auth_keys(AuthKeys::ScramSha256(**keys)),
Some(Auth::Password(pw)) => config.password(pw),
None => &mut config,
};
for (k, v) in self.server_params.iter() {
config.set_param(k, v);
}
config
}
/// Apply startup message params to the connection config.
pub(crate) fn set_startup_params(
@@ -132,27 +181,26 @@ impl ConnCfg {
arbitrary_params: bool,
) {
if !arbitrary_params {
self.set_param("client_encoding", "UTF8");
self.server_params.insert("client_encoding", "UTF8");
}
for (k, v) in params.iter() {
match k {
// Only set `user` if it's not present in the config.
// Console redirect auth flow takes username from the console's response.
"user" if self.user_is_set() => {}
"database" if self.db_is_set() => {}
"user" | "database" if self.skip_db_user => {}
"options" => {
if let Some(options) = filtered_options(v) {
self.set_param(k, &options);
self.server_params.insert(k, &options);
}
}
"user" | "database" | "application_name" | "replication" => {
self.set_param(k, v);
self.server_params.insert(k, v);
}
// if we allow arbitrary params, then we forward them through.
// this is a flag for a period of backwards compatibility
k if arbitrary_params => {
self.set_param(k, v);
self.server_params.insert(k, v);
}
_ => {}
}
@@ -160,25 +208,13 @@ impl ConnCfg {
}
}
impl std::ops::Deref for ConnCfg {
type Target = postgres_client::Config;
fn deref(&self) -> &Self::Target {
&self.0
}
}
/// For now, let's make it easier to setup the config.
impl std::ops::DerefMut for ConnCfg {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}
impl ConnCfg {
/// Establish a raw TCP connection to the compute node.
async fn connect_raw(&self, timeout: Duration) -> io::Result<(SocketAddr, TcpStream, &str)> {
use postgres_client::config::Host;
impl ConnectInfo {
/// Establish a raw TCP+TLS connection to the compute node.
async fn connect_raw(
&self,
config: &ComputeConfig,
) -> Result<(SocketAddr, MaybeTlsStream<TcpStream, RustlsStream>), TlsError> {
let timeout = config.timeout;
// wrap TcpStream::connect with timeout
let connect_with_timeout = |addrs| {
@@ -208,34 +244,32 @@ impl ConnCfg {
// We can't reuse connection establishing logic from `postgres_client` here,
// because it has no means for extracting the underlying socket which we
// require for our business.
let port = self.0.get_port();
let host = self.0.get_host();
let port = self.port;
let host = &*self.host;
let host = match host {
Host::Tcp(host) => host.as_str(),
};
let addrs = match self.0.get_host_addr() {
let addrs = match self.host_addr {
Some(addr) => vec![SocketAddr::new(addr, port)],
None => lookup_host((host, port)).await?.collect(),
};
match connect_once(&*addrs).await {
Ok((sockaddr, stream)) => Ok((sockaddr, stream, host)),
Ok((sockaddr, stream)) => Ok((
sockaddr,
tls::connect_tls(stream, self.ssl_mode, config, host).await?,
)),
Err(err) => {
warn!("couldn't connect to compute node at {host}:{port}: {err}");
Err(err)
Err(TlsError::Connection(err))
}
}
}
}
type RustlsStream = <MakeRustlsConnect as MakeTlsConnect<tokio::net::TcpStream>>::Stream;
type RustlsStream = <ComputeConfig as MakeTlsConnect<tokio::net::TcpStream>>::Stream;
pub(crate) struct PostgresConnection {
/// Socket connected to a compute node.
pub(crate) stream:
postgres_client::maybe_tls_stream::MaybeTlsStream<tokio::net::TcpStream, RustlsStream>,
pub(crate) stream: MaybeTlsStream<tokio::net::TcpStream, RustlsStream>,
/// PostgreSQL connection parameters.
pub(crate) params: std::collections::HashMap<String, String>,
/// Query cancellation token.
@@ -248,28 +282,23 @@ pub(crate) struct PostgresConnection {
_guage: NumDbConnectionsGuard<'static>,
}
impl ConnCfg {
impl ConnectInfo {
/// Connect to a corresponding compute node.
pub(crate) async fn connect(
&self,
ctx: &RequestContext,
aux: MetricsAuxInfo,
auth: &AuthInfo,
config: &ComputeConfig,
user_info: ComputeUserInfo,
) -> Result<PostgresConnection, ConnectionError> {
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
let (socket_addr, stream, host) = self.connect_raw(config.timeout).await?;
drop(pause);
let mut tmp_config = auth.enrich(self.to_postgres_client_config());
// we setup SSL early in `ConnectInfo::connect_raw`.
tmp_config.ssl_mode(SslMode::Disable);
let mut mk_tls = crate::tls::postgres_rustls::MakeRustlsConnect::new(config.tls.clone());
let tls = <MakeRustlsConnect as MakeTlsConnect<tokio::net::TcpStream>>::make_tls_connect(
&mut mk_tls,
host,
)?;
// connect_raw() will not use TLS if sslmode is "disable"
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
let connection = self.0.connect_raw(stream, tls).await?;
let (socket_addr, stream) = self.connect_raw(config).await?;
let connection = tmp_config.connect_raw(stream, NoTls).await?;
drop(pause);
let RawConnection {
@@ -282,13 +311,14 @@ impl ConnCfg {
tracing::Span::current().record("pid", tracing::field::display(process_id));
tracing::Span::current().record("compute_id", tracing::field::display(&aux.compute_id));
let stream = stream.into_inner();
let MaybeTlsStream::Raw(stream) = stream.into_inner();
// TODO: lots of useful info but maybe we can move it elsewhere (eg traces?)
info!(
cold_start_info = ctx.cold_start_info().as_str(),
"connected to compute node at {host} ({socket_addr}) sslmode={:?}, latency={}, query_id={}",
self.0.get_ssl_mode(),
"connected to compute node at {} ({socket_addr}) sslmode={:?}, latency={}, query_id={}",
self.host,
self.ssl_mode,
ctx.get_proxy_latency(),
ctx.get_testodrome_id().unwrap_or_default(),
);
@@ -299,11 +329,11 @@ impl ConnCfg {
socket_addr,
CancelToken {
socket_config: None,
ssl_mode: self.0.get_ssl_mode(),
ssl_mode: self.ssl_mode,
process_id,
secret_key,
},
host.to_string(),
self.host.to_string(),
user_info,
);

63
proxy/src/compute/tls.rs Normal file
View File

@@ -0,0 +1,63 @@
use futures::FutureExt;
use postgres_client::config::SslMode;
use postgres_client::maybe_tls_stream::MaybeTlsStream;
use postgres_client::tls::{MakeTlsConnect, TlsConnect};
use rustls::pki_types::InvalidDnsNameError;
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncWrite};
use crate::pqproto::request_tls;
use crate::proxy::retry::CouldRetry;
#[derive(Debug, Error)]
pub enum TlsError {
#[error(transparent)]
Dns(#[from] InvalidDnsNameError),
#[error(transparent)]
Connection(#[from] std::io::Error),
#[error("TLS required but not provided")]
Required,
}
impl CouldRetry for TlsError {
fn could_retry(&self) -> bool {
match self {
TlsError::Dns(_) => false,
TlsError::Connection(err) => err.could_retry(),
// perhaps compute didn't realise it supports TLS?
TlsError::Required => true,
}
}
}
pub async fn connect_tls<S, T>(
mut stream: S,
mode: SslMode,
tls: &T,
host: &str,
) -> Result<MaybeTlsStream<S, T::Stream>, TlsError>
where
S: AsyncRead + AsyncWrite + Unpin + Send,
T: MakeTlsConnect<
S,
Error = InvalidDnsNameError,
TlsConnect: TlsConnect<S, Error = std::io::Error, Future: Send>,
>,
{
match mode {
SslMode::Disable => return Ok(MaybeTlsStream::Raw(stream)),
SslMode::Prefer | SslMode::Require => {}
}
if !request_tls(&mut stream).await? {
if SslMode::Require == mode {
return Err(TlsError::Required);
}
return Ok(MaybeTlsStream::Raw(stream));
}
Ok(MaybeTlsStream::Tls(
tls.make_tls_connect(host)?.connect(stream).boxed().await?,
))
}

View File

@@ -22,7 +22,6 @@ pub struct ProxyConfig {
pub http_config: HttpConfig,
pub authentication_config: AuthenticationConfig,
pub proxy_protocol_v2: ProxyProtocolV2,
pub region: String,
pub handshake_timeout: Duration,
pub wake_compute_retry_config: RetryConfig,
pub connect_compute_locks: ApiLocks<Host>,
@@ -70,7 +69,7 @@ pub struct AuthenticationConfig {
pub console_redirect_confirmation_timeout: tokio::time::Duration,
}
#[derive(Debug)]
#[derive(Debug, serde::Deserialize)]
pub struct EndpointCacheConfig {
/// Batch size to receive all endpoints on the startup.
pub initial_batch_size: usize,
@@ -206,7 +205,7 @@ impl FromStr for CacheOptions {
}
/// Helper for cmdline cache options parsing.
#[derive(Debug)]
#[derive(Debug, serde::Deserialize)]
pub struct ProjectInfoCacheOptions {
/// Max number of entries.
pub size: usize,

View File

@@ -90,12 +90,7 @@ pub async fn task_main(
}
}
let ctx = RequestContext::new(
session_id,
conn_info,
crate::metrics::Protocol::Tcp,
&config.region,
);
let ctx = RequestContext::new(session_id, conn_info, crate::metrics::Protocol::Tcp);
let res = handle_client(
config,
@@ -210,20 +205,20 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
ctx.set_db_options(params.clone());
let (node_info, user_info, _ip_allowlist) = match backend
let (node_info, mut auth_info, user_info) = match backend
.authenticate(ctx, &config.authentication_config, &mut stream)
.await
{
Ok(auth_result) => auth_result,
Err(e) => Err(stream.throw_error(e, Some(ctx)).await)?,
};
auth_info.set_startup_params(&params, true);
let node = connect_to_compute(
ctx,
&TcpMechanism {
user_info,
params_compat: true,
params: &params,
auth: auth_info,
locks: &config.connect_compute_locks,
},
&node_info,

View File

@@ -46,7 +46,6 @@ struct RequestContextInner {
pub(crate) session_id: Uuid,
pub(crate) protocol: Protocol,
first_packet: chrono::DateTime<Utc>,
region: &'static str,
pub(crate) span: Span,
// filled in as they are discovered
@@ -94,7 +93,6 @@ impl Clone for RequestContext {
session_id: inner.session_id,
protocol: inner.protocol,
first_packet: inner.first_packet,
region: inner.region,
span: info_span!("background_task"),
project: inner.project,
@@ -124,12 +122,7 @@ impl Clone for RequestContext {
}
impl RequestContext {
pub fn new(
session_id: Uuid,
conn_info: ConnectionInfo,
protocol: Protocol,
region: &'static str,
) -> Self {
pub fn new(session_id: Uuid, conn_info: ConnectionInfo, protocol: Protocol) -> Self {
// TODO: be careful with long lived spans
let span = info_span!(
"connect_request",
@@ -145,7 +138,6 @@ impl RequestContext {
session_id,
protocol,
first_packet: Utc::now(),
region,
span,
project: None,
@@ -179,7 +171,7 @@ impl RequestContext {
let ip = IpAddr::from([127, 0, 0, 1]);
let addr = SocketAddr::new(ip, 5432);
let conn_info = ConnectionInfo { addr, extra: None };
RequestContext::new(Uuid::now_v7(), conn_info, Protocol::Tcp, "test")
RequestContext::new(Uuid::now_v7(), conn_info, Protocol::Tcp)
}
pub(crate) fn console_application_name(&self) -> String {

View File

@@ -74,7 +74,7 @@ pub(crate) const FAILED_UPLOAD_MAX_RETRIES: u32 = 10;
#[derive(parquet_derive::ParquetRecordWriter)]
pub(crate) struct RequestData {
region: &'static str,
region: String,
protocol: &'static str,
/// Must be UTC. The derive macro doesn't like the timezones
timestamp: chrono::NaiveDateTime,
@@ -147,7 +147,7 @@ impl From<&RequestContextInner> for RequestData {
}),
jwt_issuer: value.jwt_issuer.clone(),
protocol: value.protocol.as_str(),
region: value.region,
region: String::new(),
error: value.error_kind.as_ref().map(|e| e.to_metric_label()),
success: value.success,
cold_start_info: value.cold_start_info.as_str(),
@@ -167,6 +167,7 @@ impl From<&RequestContextInner> for RequestData {
pub async fn worker(
cancellation_token: CancellationToken,
config: ParquetUploadArgs,
region: String,
) -> anyhow::Result<()> {
let Some(remote_storage_config) = config.parquet_upload_remote_storage else {
tracing::warn!("parquet request upload: no s3 bucket configured");
@@ -232,12 +233,17 @@ pub async fn worker(
.context("remote storage for disconnect events init")?;
let parquet_config_disconnect = parquet_config.clone();
tokio::try_join!(
worker_inner(storage, rx, parquet_config),
worker_inner(storage_disconnect, rx_disconnect, parquet_config_disconnect)
worker_inner(storage, rx, parquet_config, &region),
worker_inner(
storage_disconnect,
rx_disconnect,
parquet_config_disconnect,
&region
)
)
.map(|_| ())
} else {
worker_inner(storage, rx, parquet_config).await
worker_inner(storage, rx, parquet_config, &region).await
}
}
@@ -257,6 +263,7 @@ async fn worker_inner(
storage: GenericRemoteStorage,
rx: impl Stream<Item = RequestData>,
config: ParquetConfig,
region: &str,
) -> anyhow::Result<()> {
#[cfg(any(test, feature = "testing"))]
let storage = if config.test_remote_failures > 0 {
@@ -277,7 +284,8 @@ async fn worker_inner(
let mut last_upload = time::Instant::now();
let mut len = 0;
while let Some(row) = rx.next().await {
while let Some(mut row) = rx.next().await {
region.clone_into(&mut row.region);
rows.push(row);
let force = last_upload.elapsed() > config.max_duration;
if rows.len() == config.rows_per_group || force {
@@ -533,7 +541,7 @@ mod tests {
auth_method: None,
jwt_issuer: None,
protocol: ["tcp", "ws", "http"][rng.gen_range(0..3)],
region: "us-east-1",
region: String::new(),
error: None,
success: rng.r#gen(),
cold_start_info: "no",
@@ -565,7 +573,9 @@ mod tests {
.await
.unwrap();
worker_inner(storage, rx, config).await.unwrap();
worker_inner(storage, rx, config, "us-east-1")
.await
.unwrap();
let mut files = WalkDir::new(tmpdir.as_std_path())
.into_iter()

View File

@@ -261,24 +261,18 @@ impl NeonControlPlaneClient {
Some(_) => SslMode::Require,
None => SslMode::Disable,
};
let host_name = match body.server_name {
Some(host) => host,
None => host.to_owned(),
let host = match body.server_name {
Some(host) => host.into(),
None => host.into(),
};
// Don't set anything but host and port! This config will be cached.
// We'll set username and such later using the startup message.
// TODO: add more type safety (in progress).
let mut config = compute::ConnCfg::new(host_name, port);
if let Some(addr) = host_addr {
config.set_host_addr(addr);
}
config.ssl_mode(ssl_mode);
let node = NodeInfo {
config,
conn_info: compute::ConnectInfo {
host_addr,
host,
port,
ssl_mode,
},
aux: body.aux,
};

View File

@@ -6,6 +6,7 @@ use std::str::FromStr;
use std::sync::Arc;
use futures::TryFutureExt;
use postgres_client::config::SslMode;
use thiserror::Error;
use tokio_postgres::Client;
use tracing::{Instrument, error, info, info_span, warn};
@@ -14,6 +15,7 @@ use crate::auth::IpPattern;
use crate::auth::backend::ComputeUserInfo;
use crate::auth::backend::jwt::AuthRule;
use crate::cache::Cached;
use crate::compute::ConnectInfo;
use crate::context::RequestContext;
use crate::control_plane::errors::{
ControlPlaneError, GetAuthInfoError, GetEndpointJwksError, WakeComputeError,
@@ -24,9 +26,9 @@ use crate::control_plane::{
RoleAccessControl,
};
use crate::intern::RoleNameInt;
use crate::scram;
use crate::types::{BranchId, EndpointId, ProjectId, RoleName};
use crate::url::ApiUrl;
use crate::{compute, scram};
#[derive(Debug, Error)]
enum MockApiError {
@@ -87,8 +89,7 @@ impl MockControlPlane {
.await?
{
info!("got a secret: {entry}"); // safe since it's not a prod scenario
let secret = scram::ServerSecret::parse(&entry).map(AuthSecret::Scram);
secret.or_else(|| parse_md5(&entry).map(AuthSecret::Md5))
scram::ServerSecret::parse(&entry).map(AuthSecret::Scram)
} else {
warn!("user '{role}' does not exist");
None
@@ -170,25 +171,23 @@ impl MockControlPlane {
async fn do_wake_compute(&self) -> Result<NodeInfo, WakeComputeError> {
let port = self.endpoint.port().unwrap_or(5432);
let mut config = match self.endpoint.host_str() {
None => {
let mut config = compute::ConnCfg::new("localhost".to_string(), port);
config.set_host_addr(IpAddr::V4(Ipv4Addr::LOCALHOST));
config
}
Some(host) => {
let mut config = compute::ConnCfg::new(host.to_string(), port);
if let Ok(addr) = IpAddr::from_str(host) {
config.set_host_addr(addr);
}
config
}
let conn_info = match self.endpoint.host_str() {
None => ConnectInfo {
host_addr: Some(IpAddr::V4(Ipv4Addr::LOCALHOST)),
host: "localhost".into(),
port,
ssl_mode: SslMode::Disable,
},
Some(host) => ConnectInfo {
host_addr: IpAddr::from_str(host).ok(),
host: host.into(),
port,
ssl_mode: SslMode::Disable,
},
};
config.ssl_mode(postgres_client::config::SslMode::Disable);
let node = NodeInfo {
config,
conn_info,
aux: MetricsAuxInfo {
endpoint_id: (&EndpointId::from("endpoint")).into(),
project_id: (&ProjectId::from("project")).into(),
@@ -266,12 +265,3 @@ impl super::ControlPlaneApi for MockControlPlane {
self.do_wake_compute().map_ok(Cached::new_uncached).await
}
}
fn parse_md5(input: &str) -> Option<[u8; 16]> {
let text = input.strip_prefix("md5")?;
let mut bytes = [0u8; 16];
hex::decode_to_slice(text, &mut bytes).ok()?;
Some(bytes)
}

View File

@@ -11,8 +11,8 @@ pub(crate) mod errors;
use std::sync::Arc;
use crate::auth::backend::ComputeUserInfo;
use crate::auth::backend::jwt::AuthRule;
use crate::auth::backend::{ComputeCredentialKeys, ComputeUserInfo};
use crate::auth::{AuthError, IpPattern, check_peer_addr_is_in_list};
use crate::cache::{Cached, TimedLru};
use crate::config::ComputeConfig;
@@ -39,10 +39,6 @@ pub mod mgmt;
/// Auth secret which is managed by the cloud.
#[derive(Clone, Eq, PartialEq, Debug)]
pub(crate) enum AuthSecret {
#[cfg(any(test, feature = "testing"))]
/// Md5 hash of user's password.
Md5([u8; 16]),
/// [SCRAM](crate::scram) authentication info.
Scram(scram::ServerSecret),
}
@@ -63,13 +59,9 @@ pub(crate) struct AuthInfo {
}
/// Info for establishing a connection to a compute node.
/// This is what we get after auth succeeded, but not before!
#[derive(Clone)]
pub(crate) struct NodeInfo {
/// Compute node connection params.
/// It's sad that we have to clone this, but this will improve
/// once we migrate to a bespoke connection logic.
pub(crate) config: compute::ConnCfg,
pub(crate) conn_info: compute::ConnectInfo,
/// Labels for proxy's metrics.
pub(crate) aux: MetricsAuxInfo,
@@ -79,26 +71,14 @@ impl NodeInfo {
pub(crate) async fn connect(
&self,
ctx: &RequestContext,
auth: &compute::AuthInfo,
config: &ComputeConfig,
user_info: ComputeUserInfo,
) -> Result<compute::PostgresConnection, compute::ConnectionError> {
self.config
.connect(ctx, self.aux.clone(), config, user_info)
self.conn_info
.connect(ctx, self.aux.clone(), auth, config, user_info)
.await
}
pub(crate) fn reuse_settings(&mut self, other: Self) {
self.config.reuse_password(other.config);
}
pub(crate) fn set_keys(&mut self, keys: &ComputeCredentialKeys) {
match keys {
#[cfg(any(test, feature = "testing"))]
ComputeCredentialKeys::Password(password) => self.config.password(password),
ComputeCredentialKeys::AuthKeys(auth_keys) => self.config.auth_keys(*auth_keys),
ComputeCredentialKeys::JwtPayload(_) | ComputeCredentialKeys::None => &mut self.config,
};
}
}
#[derive(Copy, Clone, Default)]

View File

@@ -610,11 +610,11 @@ pub enum RedisEventsCount {
BranchCreated,
ProjectCreated,
CancelSession,
PasswordUpdate,
AllowedIpsUpdate,
AllowedVpcEndpointIdsUpdateForProjects,
AllowedVpcEndpointIdsUpdateForAllProjectsInOrg,
BlockPublicOrVpcAccessUpdate,
InvalidateRole,
InvalidateEndpoint,
InvalidateProject,
InvalidateProjects,
InvalidateOrg,
}
pub struct ThreadPoolWorkers(usize);

View File

@@ -2,8 +2,8 @@ use async_trait::async_trait;
use tokio::time;
use tracing::{debug, info, warn};
use crate::auth::backend::{ComputeCredentialKeys, ComputeUserInfo};
use crate::compute::{self, COULD_NOT_CONNECT, PostgresConnection};
use crate::auth::backend::ComputeUserInfo;
use crate::compute::{self, AuthInfo, COULD_NOT_CONNECT, PostgresConnection};
use crate::config::{ComputeConfig, RetryConfig};
use crate::context::RequestContext;
use crate::control_plane::errors::WakeComputeError;
@@ -13,7 +13,6 @@ use crate::error::ReportableError;
use crate::metrics::{
ConnectOutcome, ConnectionFailureKind, Metrics, RetriesMetricGroup, RetryType,
};
use crate::pqproto::StartupMessageParams;
use crate::proxy::retry::{CouldRetry, ShouldRetryWakeCompute, retry_after, should_retry};
use crate::proxy::wake_compute::wake_compute;
use crate::types::Host;
@@ -48,8 +47,6 @@ pub(crate) trait ConnectMechanism {
node_info: &control_plane::CachedNodeInfo,
config: &ComputeConfig,
) -> Result<Self::Connection, Self::ConnectError>;
fn update_connect_config(&self, conf: &mut compute::ConnCfg);
}
#[async_trait]
@@ -58,24 +55,17 @@ pub(crate) trait ComputeConnectBackend {
&self,
ctx: &RequestContext,
) -> Result<CachedNodeInfo, control_plane::errors::WakeComputeError>;
fn get_keys(&self) -> &ComputeCredentialKeys;
}
pub(crate) struct TcpMechanism<'a> {
pub(crate) params_compat: bool,
/// KV-dictionary with PostgreSQL connection params.
pub(crate) params: &'a StartupMessageParams,
pub(crate) struct TcpMechanism {
pub(crate) auth: AuthInfo,
/// connect_to_compute concurrency lock
pub(crate) locks: &'static ApiLocks<Host>,
pub(crate) user_info: ComputeUserInfo,
}
#[async_trait]
impl ConnectMechanism for TcpMechanism<'_> {
impl ConnectMechanism for TcpMechanism {
type Connection = PostgresConnection;
type ConnectError = compute::ConnectionError;
type Error = compute::ConnectionError;
@@ -90,13 +80,12 @@ impl ConnectMechanism for TcpMechanism<'_> {
node_info: &control_plane::CachedNodeInfo,
config: &ComputeConfig,
) -> Result<PostgresConnection, Self::Error> {
let host = node_info.config.get_host();
let permit = self.locks.get_permit(&host).await?;
permit.release_result(node_info.connect(ctx, config, self.user_info.clone()).await)
}
fn update_connect_config(&self, config: &mut compute::ConnCfg) {
config.set_startup_params(self.params, self.params_compat);
let permit = self.locks.get_permit(&node_info.conn_info.host).await?;
permit.release_result(
node_info
.connect(ctx, &self.auth, config, self.user_info.clone())
.await,
)
}
}
@@ -114,12 +103,9 @@ where
M::Error: From<WakeComputeError>,
{
let mut num_retries = 0;
let mut node_info =
let node_info =
wake_compute(&mut num_retries, ctx, user_info, wake_compute_retry_config).await?;
node_info.set_keys(user_info.get_keys());
mechanism.update_connect_config(&mut node_info.config);
// try once
let err = match mechanism.connect_once(ctx, &node_info, compute).await {
Ok(res) => {
@@ -155,14 +141,9 @@ where
} else {
// if we failed to connect, it's likely that the compute node was suspended, wake a new compute node
debug!("compute node's state has likely changed; requesting a wake-up");
let old_node_info = invalidate_cache(node_info);
invalidate_cache(node_info);
// TODO: increment num_retries?
let mut node_info =
wake_compute(&mut num_retries, ctx, user_info, wake_compute_retry_config).await?;
node_info.reuse_settings(old_node_info);
mechanism.update_connect_config(&mut node_info.config);
node_info
wake_compute(&mut num_retries, ctx, user_info, wake_compute_retry_config).await?
};
// now that we have a new node, try connect to it repeatedly.

View File

@@ -8,7 +8,7 @@ use std::io::{self, Cursor};
use bytes::{Buf, BufMut};
use itertools::Itertools;
use rand::distributions::{Distribution, Standard};
use tokio::io::{AsyncRead, AsyncReadExt};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use zerocopy::{FromBytes, Immutable, IntoBytes, big_endian};
pub type ErrorCode = [u8; 5];
@@ -53,6 +53,28 @@ impl fmt::Debug for ProtocolVersion {
}
}
/// <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/include/libpq/pqcomm.h#L118>
const MAX_STARTUP_PACKET_LENGTH: usize = 10000;
const RESERVED_INVALID_MAJOR_VERSION: u16 = 1234;
/// <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/include/libpq/pqcomm.h#L132>
const CANCEL_REQUEST_CODE: ProtocolVersion = ProtocolVersion::new(1234, 5678);
/// <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/include/libpq/pqcomm.h#L166>
const NEGOTIATE_SSL_CODE: ProtocolVersion = ProtocolVersion::new(1234, 5679);
/// <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/include/libpq/pqcomm.h#L167>
const NEGOTIATE_GSS_CODE: ProtocolVersion = ProtocolVersion::new(1234, 5680);
/// This first reads the startup message header, is 8 bytes.
/// The first 4 bytes is a big-endian message length, and the next 4 bytes is a version number.
///
/// The length value is inclusive of the header. For example,
/// an empty message will always have length 8.
#[derive(Clone, Copy, FromBytes, IntoBytes, Immutable)]
#[repr(C)]
struct StartupHeader {
len: big_endian::U32,
version: ProtocolVersion,
}
/// read the type from the stream using zerocopy.
///
/// not cancel safe.
@@ -66,32 +88,38 @@ macro_rules! read {
}};
}
/// Returns true if TLS is supported.
///
/// This is not cancel safe.
pub async fn request_tls<S>(stream: &mut S) -> io::Result<bool>
where
S: AsyncRead + AsyncWrite + Unpin,
{
let payload = StartupHeader {
len: 8.into(),
version: NEGOTIATE_SSL_CODE,
};
stream.write_all(payload.as_bytes()).await?;
stream.flush().await?;
// we expect back either `S` or `N` as a single byte.
let mut res = *b"0";
stream.read_exact(&mut res).await?;
debug_assert!(
res == *b"S" || res == *b"N",
"unexpected SSL negotiation response: {}",
char::from(res[0]),
);
// S for SSL.
Ok(res == *b"S")
}
pub async fn read_startup<S>(stream: &mut S) -> io::Result<FeStartupPacket>
where
S: AsyncRead + Unpin,
{
/// <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/include/libpq/pqcomm.h#L118>
const MAX_STARTUP_PACKET_LENGTH: usize = 10000;
const RESERVED_INVALID_MAJOR_VERSION: u16 = 1234;
/// <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/include/libpq/pqcomm.h#L132>
const CANCEL_REQUEST_CODE: ProtocolVersion = ProtocolVersion::new(1234, 5678);
/// <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/include/libpq/pqcomm.h#L166>
const NEGOTIATE_SSL_CODE: ProtocolVersion = ProtocolVersion::new(1234, 5679);
/// <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/include/libpq/pqcomm.h#L167>
const NEGOTIATE_GSS_CODE: ProtocolVersion = ProtocolVersion::new(1234, 5680);
/// This first reads the startup message header, is 8 bytes.
/// The first 4 bytes is a big-endian message length, and the next 4 bytes is a version number.
///
/// The length value is inclusive of the header. For example,
/// an empty message will always have length 8.
#[derive(Clone, Copy, FromBytes, IntoBytes, Immutable)]
#[repr(C)]
struct StartupHeader {
len: big_endian::U32,
version: ProtocolVersion,
}
let header = read!(stream => StartupHeader);
// <https://github.com/postgres/postgres/blob/04bcf9e19a4261fe9c7df37c777592c2e10c32a7/src/backend/tcop/backend_startup.c#L378-L382>
@@ -564,9 +592,8 @@ mod tests {
use tokio::io::{AsyncWriteExt, duplex};
use zerocopy::IntoBytes;
use crate::pqproto::{FeStartupPacket, read_message, read_startup};
use super::ProtocolVersion;
use crate::pqproto::{FeStartupPacket, read_message, read_startup};
#[tokio::test]
async fn reject_large_startup() {

View File

@@ -134,12 +134,7 @@ pub async fn task_main(
}
}
let ctx = RequestContext::new(
session_id,
conn_info,
crate::metrics::Protocol::Tcp,
&config.region,
);
let ctx = RequestContext::new(session_id, conn_info, crate::metrics::Protocol::Tcp);
let res = handle_client(
config,
@@ -358,21 +353,19 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
}
};
let compute_user_info = match &user_info {
auth::Backend::ControlPlane(_, info) => &info.info,
let creds = match &user_info {
auth::Backend::ControlPlane(_, creds) => creds,
auth::Backend::Local(_) => unreachable!("local proxy does not run tcp proxy service"),
};
let params_compat = compute_user_info
.options
.get(NeonOptions::PARAMS_COMPAT)
.is_some();
let params_compat = creds.info.options.get(NeonOptions::PARAMS_COMPAT).is_some();
let mut auth_info = compute::AuthInfo::with_auth_keys(&creds.keys);
auth_info.set_startup_params(&params, params_compat);
let res = connect_to_compute(
ctx,
&TcpMechanism {
user_info: compute_user_info.clone(),
params_compat,
params: &params,
user_info: creds.info.clone(),
auth: auth_info,
locks: &config.connect_compute_locks,
},
&user_info,

View File

@@ -100,9 +100,9 @@ impl CouldRetry for compute::ConnectionError {
fn could_retry(&self) -> bool {
match self {
compute::ConnectionError::Postgres(err) => err.could_retry(),
compute::ConnectionError::CouldNotConnect(err) => err.could_retry(),
compute::ConnectionError::TlsError(err) => err.could_retry(),
compute::ConnectionError::WakeComputeError(err) => err.could_retry(),
_ => false,
compute::ConnectionError::TooManyConnectionAttempts(_) => false,
}
}
}

View File

@@ -8,7 +8,7 @@ use std::time::Duration;
use anyhow::{Context, bail};
use async_trait::async_trait;
use http::StatusCode;
use postgres_client::config::SslMode;
use postgres_client::config::{AuthKeys, ScramKeys, SslMode};
use postgres_client::tls::{MakeTlsConnect, NoTls};
use retry::{ShouldRetryWakeCompute, retry_after};
use rstest::rstest;
@@ -29,7 +29,6 @@ use crate::control_plane::{self, CachedNodeInfo, NodeInfo, NodeInfoCache};
use crate::error::ErrorKind;
use crate::pglb::connect_compute::ConnectMechanism;
use crate::tls::client_config::compute_client_config_with_certs;
use crate::tls::postgres_rustls::MakeRustlsConnect;
use crate::tls::server_config::CertResolver;
use crate::types::{BranchId, EndpointId, ProjectId};
use crate::{sasl, scram};
@@ -72,13 +71,14 @@ struct ClientConfig<'a> {
hostname: &'a str,
}
type TlsConnect<S> = <MakeRustlsConnect as MakeTlsConnect<S>>::TlsConnect;
type TlsConnect<S> = <ComputeConfig as MakeTlsConnect<S>>::TlsConnect;
impl ClientConfig<'_> {
fn make_tls_connect(self) -> anyhow::Result<TlsConnect<DuplexStream>> {
let mut mk = MakeRustlsConnect::new(self.config);
let tls = MakeTlsConnect::<DuplexStream>::make_tls_connect(&mut mk, self.hostname)?;
Ok(tls)
Ok(crate::tls::postgres_rustls::make_tls_connect(
&self.config,
self.hostname,
)?)
}
}
@@ -497,8 +497,6 @@ impl ConnectMechanism for TestConnectMechanism {
x => panic!("expecting action {x:?}, connect is called instead"),
}
}
fn update_connect_config(&self, _conf: &mut compute::ConnCfg) {}
}
impl TestControlPlaneClient for TestConnectMechanism {
@@ -557,7 +555,12 @@ impl TestControlPlaneClient for TestConnectMechanism {
fn helper_create_cached_node_info(cache: &'static NodeInfoCache) -> CachedNodeInfo {
let node = NodeInfo {
config: compute::ConnCfg::new("test".to_owned(), 5432),
conn_info: compute::ConnectInfo {
host: "test".into(),
port: 5432,
ssl_mode: SslMode::Disable,
host_addr: None,
},
aux: MetricsAuxInfo {
endpoint_id: (&EndpointId::from("endpoint")).into(),
project_id: (&ProjectId::from("project")).into(),
@@ -581,7 +584,10 @@ fn helper_create_connect_info(
user: "user".into(),
options: NeonOptions::parse_options_raw(""),
},
keys: ComputeCredentialKeys::Password("password".into()),
keys: ComputeCredentialKeys::AuthKeys(AuthKeys::ScramSha256(ScramKeys {
client_key: [0; 32],
server_key: [0; 32],
})),
},
)
}

View File

@@ -140,12 +140,6 @@ impl RateBucketInfo {
Self::new(200, Duration::from_secs(600)),
];
// For all the sessions will be cancel key. So this limit is essentially global proxy limit.
pub const DEFAULT_REDIS_SET: [Self; 2] = [
Self::new(100_000, Duration::from_secs(1)),
Self::new(50_000, Duration::from_secs(10)),
];
pub fn rps(&self) -> f64 {
(self.max_rpi as f64) / self.interval.as_secs_f64()
}

View File

@@ -2,11 +2,9 @@ use redis::aio::ConnectionLike;
use redis::{Cmd, FromRedisValue, Pipeline, RedisResult};
use super::connection_with_credentials_provider::ConnectionWithCredentialsProvider;
use crate::rate_limiter::{GlobalRateLimiter, RateBucketInfo};
pub struct RedisKVClient {
client: ConnectionWithCredentialsProvider,
limiter: GlobalRateLimiter,
}
#[allow(async_fn_in_trait)]
@@ -27,11 +25,8 @@ impl Queryable for Cmd {
}
impl RedisKVClient {
pub fn new(client: ConnectionWithCredentialsProvider, info: &'static [RateBucketInfo]) -> Self {
Self {
client,
limiter: GlobalRateLimiter::new(info.into()),
}
pub fn new(client: ConnectionWithCredentialsProvider) -> Self {
Self { client }
}
pub async fn try_connect(&mut self) -> anyhow::Result<()> {
@@ -49,11 +44,6 @@ impl RedisKVClient {
&mut self,
q: &impl Queryable,
) -> anyhow::Result<T> {
if !self.limiter.check() {
tracing::info!("Rate limit exceeded. Skipping query");
return Err(anyhow::anyhow!("Rate limit exceeded"));
}
match q.query(&mut self.client).await {
Ok(t) => return Ok(t),
Err(e) => {

View File

@@ -3,12 +3,12 @@ use std::sync::Arc;
use futures::StreamExt;
use redis::aio::PubSub;
use serde::{Deserialize, Serialize};
use serde::Deserialize;
use tokio_util::sync::CancellationToken;
use super::connection_with_credentials_provider::ConnectionWithCredentialsProvider;
use crate::cache::project_info::ProjectInfoCache;
use crate::intern::{AccountIdInt, ProjectIdInt, RoleNameInt};
use crate::intern::{AccountIdInt, EndpointIdInt, ProjectIdInt, RoleNameInt};
use crate::metrics::{Metrics, RedisErrors, RedisEventsCount};
const CPLANE_CHANNEL_NAME: &str = "neondb-proxy-ws-updates";
@@ -27,42 +27,37 @@ struct NotificationHeader<'a> {
topic: &'a str,
}
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
#[derive(Clone, Debug, Deserialize, Eq, PartialEq)]
#[serde(tag = "topic", content = "data")]
pub(crate) enum Notification {
enum Notification {
#[serde(
rename = "/allowed_ips_updated",
rename = "/account_settings_update",
alias = "/allowed_vpc_endpoints_updated_for_org",
deserialize_with = "deserialize_json_string"
)]
AllowedIpsUpdate {
allowed_ips_update: AllowedIpsUpdate,
},
AccountSettingsUpdate(InvalidateAccount),
#[serde(
rename = "/block_public_or_vpc_access_updated",
rename = "/endpoint_settings_update",
deserialize_with = "deserialize_json_string"
)]
BlockPublicOrVpcAccessUpdated {
block_public_or_vpc_access_updated: BlockPublicOrVpcAccessUpdated,
},
EndpointSettingsUpdate(InvalidateEndpoint),
#[serde(
rename = "/allowed_vpc_endpoints_updated_for_org",
rename = "/project_settings_update",
alias = "/allowed_ips_updated",
alias = "/block_public_or_vpc_access_updated",
alias = "/allowed_vpc_endpoints_updated_for_projects",
deserialize_with = "deserialize_json_string"
)]
AllowedVpcEndpointsUpdatedForOrg {
allowed_vpc_endpoints_updated_for_org: AllowedVpcEndpointsUpdatedForOrg,
},
ProjectSettingsUpdate(InvalidateProject),
#[serde(
rename = "/allowed_vpc_endpoints_updated_for_projects",
rename = "/role_setting_update",
alias = "/password_updated",
deserialize_with = "deserialize_json_string"
)]
AllowedVpcEndpointsUpdatedForProjects {
allowed_vpc_endpoints_updated_for_projects: AllowedVpcEndpointsUpdatedForProjects,
},
#[serde(
rename = "/password_updated",
deserialize_with = "deserialize_json_string"
)]
PasswordUpdate { password_update: PasswordUpdate },
RoleSettingUpdate(InvalidateRole),
#[serde(
other,
@@ -72,28 +67,56 @@ pub(crate) enum Notification {
UnknownTopic,
}
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
pub(crate) struct AllowedIpsUpdate {
project_id: ProjectIdInt,
#[derive(Clone, Debug, Deserialize, Eq, PartialEq)]
#[serde(rename_all = "snake_case")]
enum InvalidateEndpoint {
EndpointId(EndpointIdInt),
EndpointIds(Vec<EndpointIdInt>),
}
impl std::ops::Deref for InvalidateEndpoint {
type Target = [EndpointIdInt];
fn deref(&self) -> &Self::Target {
match self {
Self::EndpointId(id) => std::slice::from_ref(id),
Self::EndpointIds(ids) => ids,
}
}
}
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
pub(crate) struct BlockPublicOrVpcAccessUpdated {
project_id: ProjectIdInt,
#[derive(Clone, Debug, Deserialize, Eq, PartialEq)]
#[serde(rename_all = "snake_case")]
enum InvalidateProject {
ProjectId(ProjectIdInt),
ProjectIds(Vec<ProjectIdInt>),
}
impl std::ops::Deref for InvalidateProject {
type Target = [ProjectIdInt];
fn deref(&self) -> &Self::Target {
match self {
Self::ProjectId(id) => std::slice::from_ref(id),
Self::ProjectIds(ids) => ids,
}
}
}
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
pub(crate) struct AllowedVpcEndpointsUpdatedForOrg {
account_id: AccountIdInt,
#[derive(Clone, Debug, Deserialize, Eq, PartialEq)]
#[serde(rename_all = "snake_case")]
enum InvalidateAccount {
AccountId(AccountIdInt),
AccountIds(Vec<AccountIdInt>),
}
impl std::ops::Deref for InvalidateAccount {
type Target = [AccountIdInt];
fn deref(&self) -> &Self::Target {
match self {
Self::AccountId(id) => std::slice::from_ref(id),
Self::AccountIds(ids) => ids,
}
}
}
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
pub(crate) struct AllowedVpcEndpointsUpdatedForProjects {
project_ids: Vec<ProjectIdInt>,
}
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
pub(crate) struct PasswordUpdate {
#[derive(Clone, Debug, Deserialize, Eq, PartialEq)]
struct InvalidateRole {
project_id: ProjectIdInt,
role_name: RoleNameInt,
}
@@ -118,29 +141,19 @@ where
struct MessageHandler<C: ProjectInfoCache + Send + Sync + 'static> {
cache: Arc<C>,
region_id: String,
}
impl<C: ProjectInfoCache + Send + Sync + 'static> Clone for MessageHandler<C> {
fn clone(&self) -> Self {
Self {
cache: self.cache.clone(),
region_id: self.region_id.clone(),
}
}
}
impl<C: ProjectInfoCache + Send + Sync + 'static> MessageHandler<C> {
pub(crate) fn new(cache: Arc<C>, region_id: String) -> Self {
Self { cache, region_id }
}
pub(crate) async fn increment_active_listeners(&self) {
self.cache.increment_active_listeners().await;
}
pub(crate) async fn decrement_active_listeners(&self) {
self.cache.decrement_active_listeners().await;
pub(crate) fn new(cache: Arc<C>) -> Self {
Self { cache }
}
#[tracing::instrument(skip(self, msg), fields(session_id = tracing::field::Empty))]
@@ -177,41 +190,29 @@ impl<C: ProjectInfoCache + Send + Sync + 'static> MessageHandler<C> {
tracing::debug!(?msg, "received a message");
match msg {
Notification::AllowedIpsUpdate { .. }
| Notification::PasswordUpdate { .. }
| Notification::BlockPublicOrVpcAccessUpdated { .. }
| Notification::AllowedVpcEndpointsUpdatedForOrg { .. }
| Notification::AllowedVpcEndpointsUpdatedForProjects { .. } => {
Notification::RoleSettingUpdate { .. }
| Notification::EndpointSettingsUpdate { .. }
| Notification::ProjectSettingsUpdate { .. }
| Notification::AccountSettingsUpdate { .. } => {
invalidate_cache(self.cache.clone(), msg.clone());
if matches!(msg, Notification::AllowedIpsUpdate { .. }) {
Metrics::get()
.proxy
.redis_events_count
.inc(RedisEventsCount::AllowedIpsUpdate);
} else if matches!(msg, Notification::PasswordUpdate { .. }) {
Metrics::get()
.proxy
.redis_events_count
.inc(RedisEventsCount::PasswordUpdate);
} else if matches!(
msg,
Notification::AllowedVpcEndpointsUpdatedForProjects { .. }
) {
Metrics::get()
.proxy
.redis_events_count
.inc(RedisEventsCount::AllowedVpcEndpointIdsUpdateForProjects);
} else if matches!(msg, Notification::AllowedVpcEndpointsUpdatedForOrg { .. }) {
Metrics::get()
.proxy
.redis_events_count
.inc(RedisEventsCount::AllowedVpcEndpointIdsUpdateForAllProjectsInOrg);
} else if matches!(msg, Notification::BlockPublicOrVpcAccessUpdated { .. }) {
Metrics::get()
.proxy
.redis_events_count
.inc(RedisEventsCount::BlockPublicOrVpcAccessUpdate);
let m = &Metrics::get().proxy.redis_events_count;
match msg {
Notification::RoleSettingUpdate { .. } => {
m.inc(RedisEventsCount::InvalidateRole);
}
Notification::EndpointSettingsUpdate { .. } => {
m.inc(RedisEventsCount::InvalidateEndpoint);
}
Notification::ProjectSettingsUpdate { .. } => {
m.inc(RedisEventsCount::InvalidateProject);
}
Notification::AccountSettingsUpdate { .. } => {
m.inc(RedisEventsCount::InvalidateOrg);
}
Notification::UnknownTopic => {}
}
// TODO: add additional metrics for the other event types.
// It might happen that the invalid entry is on the way to be cached.
@@ -233,30 +234,23 @@ impl<C: ProjectInfoCache + Send + Sync + 'static> MessageHandler<C> {
fn invalidate_cache<C: ProjectInfoCache>(cache: Arc<C>, msg: Notification) {
match msg {
Notification::AllowedIpsUpdate {
allowed_ips_update: AllowedIpsUpdate { project_id },
}
| Notification::BlockPublicOrVpcAccessUpdated {
block_public_or_vpc_access_updated: BlockPublicOrVpcAccessUpdated { project_id },
} => cache.invalidate_endpoint_access_for_project(project_id),
Notification::AllowedVpcEndpointsUpdatedForOrg {
allowed_vpc_endpoints_updated_for_org: AllowedVpcEndpointsUpdatedForOrg { account_id },
} => cache.invalidate_endpoint_access_for_org(account_id),
Notification::AllowedVpcEndpointsUpdatedForProjects {
allowed_vpc_endpoints_updated_for_projects:
AllowedVpcEndpointsUpdatedForProjects { project_ids },
} => {
for project in project_ids {
cache.invalidate_endpoint_access_for_project(project);
}
}
Notification::PasswordUpdate {
password_update:
PasswordUpdate {
project_id,
role_name,
},
} => cache.invalidate_role_secret_for_project(project_id, role_name),
Notification::EndpointSettingsUpdate(ids) => ids
.iter()
.for_each(|&id| cache.invalidate_endpoint_access(id)),
Notification::AccountSettingsUpdate(ids) => ids
.iter()
.for_each(|&id| cache.invalidate_endpoint_access_for_org(id)),
Notification::ProjectSettingsUpdate(ids) => ids
.iter()
.for_each(|&id| cache.invalidate_endpoint_access_for_project(id)),
Notification::RoleSettingUpdate(InvalidateRole {
project_id,
role_name,
}) => cache.invalidate_role_secret_for_project(project_id, role_name),
Notification::UnknownTopic => unreachable!(),
}
}
@@ -272,7 +266,7 @@ async fn handle_messages<C: ProjectInfoCache + Send + Sync + 'static>(
}
let mut conn = match try_connect(&redis).await {
Ok(conn) => {
handler.increment_active_listeners().await;
handler.cache.increment_active_listeners().await;
conn
}
Err(e) => {
@@ -293,11 +287,11 @@ async fn handle_messages<C: ProjectInfoCache + Send + Sync + 'static>(
}
}
if cancellation_token.is_cancelled() {
handler.decrement_active_listeners().await;
handler.cache.decrement_active_listeners().await;
return Ok(());
}
}
handler.decrement_active_listeners().await;
handler.cache.decrement_active_listeners().await;
}
}
@@ -306,12 +300,11 @@ async fn handle_messages<C: ProjectInfoCache + Send + Sync + 'static>(
pub async fn task_main<C>(
redis: ConnectionWithCredentialsProvider,
cache: Arc<C>,
region_id: String,
) -> anyhow::Result<Infallible>
where
C: ProjectInfoCache + Send + Sync + 'static,
{
let handler = MessageHandler::new(cache, region_id);
let handler = MessageHandler::new(cache);
// 6h - 1m.
// There will be 1 minute overlap between two tasks. But at least we can be sure that no message is lost.
let mut interval = tokio::time::interval(std::time::Duration::from_secs(6 * 60 * 60 - 60));
@@ -353,11 +346,32 @@ mod tests {
let result: Notification = serde_json::from_str(&text)?;
assert_eq!(
result,
Notification::AllowedIpsUpdate {
allowed_ips_update: AllowedIpsUpdate {
project_id: (&project_id).into()
}
}
Notification::ProjectSettingsUpdate(InvalidateProject::ProjectId((&project_id).into()))
);
Ok(())
}
#[test]
fn parse_multiple_projects() -> anyhow::Result<()> {
let project_id1: ProjectId = "new_project1".into();
let project_id2: ProjectId = "new_project2".into();
let data = format!("{{\"project_ids\": [\"{project_id1}\",\"{project_id2}\"]}}");
let text = json!({
"type": "message",
"topic": "/allowed_vpc_endpoints_updated_for_projects",
"data": data,
"extre_fields": "something"
})
.to_string();
let result: Notification = serde_json::from_str(&text)?;
assert_eq!(
result,
Notification::ProjectSettingsUpdate(InvalidateProject::ProjectIds(vec![
(&project_id1).into(),
(&project_id2).into()
]))
);
Ok(())
@@ -379,12 +393,10 @@ mod tests {
let result: Notification = serde_json::from_str(&text)?;
assert_eq!(
result,
Notification::PasswordUpdate {
password_update: PasswordUpdate {
project_id: (&project_id).into(),
role_name: (&role_name).into(),
}
}
Notification::RoleSettingUpdate(InvalidateRole {
project_id: (&project_id).into(),
role_name: (&role_name).into(),
})
);
Ok(())

View File

@@ -23,7 +23,6 @@ use super::local_conn_pool::{self, EXT_NAME, EXT_SCHEMA, EXT_VERSION, LocalConnP
use crate::auth::backend::local::StaticAuthRules;
use crate::auth::backend::{ComputeCredentials, ComputeUserInfo};
use crate::auth::{self, AuthError};
use crate::compute;
use crate::compute_ctl::{
ComputeCtlError, ExtensionInstallRequest, Privilege, SetRoleGrantsRequest,
};
@@ -305,12 +304,13 @@ impl PoolingBackend {
tracing::Span::current().record("conn_id", display(conn_id));
info!(%conn_id, "local_pool: opening a new connection '{conn_info}'");
let mut node_info = local_backend.node_info.clone();
let (key, jwk) = create_random_jwk();
let config = node_info
.config
let mut config = local_backend
.node_info
.conn_info
.to_postgres_client_config();
config
.user(&conn_info.user_info.user)
.dbname(&conn_info.dbname)
.set_param(
@@ -322,7 +322,7 @@ impl PoolingBackend {
);
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
let (client, connection) = config.connect(postgres_client::NoTls).await?;
let (client, connection) = config.connect(&postgres_client::NoTls).await?;
drop(pause);
let pid = client.get_process_id();
@@ -336,7 +336,7 @@ impl PoolingBackend {
connection,
key,
conn_id,
node_info.aux.clone(),
local_backend.node_info.aux.clone(),
);
{
@@ -512,19 +512,16 @@ impl ConnectMechanism for TokioMechanism {
node_info: &CachedNodeInfo,
compute_config: &ComputeConfig,
) -> Result<Self::Connection, Self::ConnectError> {
let host = node_info.config.get_host();
let permit = self.locks.get_permit(&host).await?;
let permit = self.locks.get_permit(&node_info.conn_info.host).await?;
let mut config = (*node_info.config).clone();
let mut config = node_info.conn_info.to_postgres_client_config();
let config = config
.user(&self.conn_info.user_info.user)
.dbname(&self.conn_info.dbname)
.connect_timeout(compute_config.timeout);
let mk_tls =
crate::tls::postgres_rustls::MakeRustlsConnect::new(compute_config.tls.clone());
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
let res = config.connect(mk_tls).await;
let res = config.connect(compute_config).await;
drop(pause);
let (client, connection) = permit.release_result(res)?;
@@ -548,8 +545,6 @@ impl ConnectMechanism for TokioMechanism {
node_info.aux.clone(),
))
}
fn update_connect_config(&self, _config: &mut compute::ConnCfg) {}
}
struct HyperMechanism {
@@ -573,20 +568,20 @@ impl ConnectMechanism for HyperMechanism {
node_info: &CachedNodeInfo,
config: &ComputeConfig,
) -> Result<Self::Connection, Self::ConnectError> {
let host_addr = node_info.config.get_host_addr();
let host = node_info.config.get_host();
let permit = self.locks.get_permit(&host).await?;
let host_addr = node_info.conn_info.host_addr;
let host = &node_info.conn_info.host;
let permit = self.locks.get_permit(host).await?;
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
let tls = if node_info.config.get_ssl_mode() == SslMode::Disable {
let tls = if node_info.conn_info.ssl_mode == SslMode::Disable {
None
} else {
Some(&config.tls)
};
let port = node_info.config.get_port();
let res = connect_http2(host_addr, &host, port, config.timeout, tls).await;
let port = node_info.conn_info.port;
let res = connect_http2(host_addr, host, port, config.timeout, tls).await;
drop(pause);
let (client, connection) = permit.release_result(res)?;
@@ -609,8 +604,6 @@ impl ConnectMechanism for HyperMechanism {
node_info.aux.clone(),
))
}
fn update_connect_config(&self, _config: &mut compute::ConnCfg) {}
}
async fn connect_http2(

View File

@@ -23,12 +23,12 @@ use super::conn_pool_lib::{
Client, ClientDataEnum, ClientInnerCommon, ClientInnerExt, ConnInfo, EndpointConnPool,
GlobalConnPool,
};
use crate::config::ComputeConfig;
use crate::context::RequestContext;
use crate::control_plane::messages::MetricsAuxInfo;
use crate::metrics::Metrics;
use crate::tls::postgres_rustls::MakeRustlsConnect;
type TlsStream = <MakeRustlsConnect as MakeTlsConnect<TcpStream>>::Stream;
type TlsStream = <ComputeConfig as MakeTlsConnect<TcpStream>>::Stream;
#[derive(Debug, Clone)]
pub(crate) struct ConnInfoWithAuth {

View File

@@ -417,12 +417,7 @@ async fn request_handler(
if config.http_config.accept_websockets
&& framed_websockets::upgrade::is_upgrade_request(&request)
{
let ctx = RequestContext::new(
session_id,
conn_info,
crate::metrics::Protocol::Ws,
&config.region,
);
let ctx = RequestContext::new(session_id, conn_info, crate::metrics::Protocol::Ws);
ctx.set_user_agent(
request
@@ -462,12 +457,7 @@ async fn request_handler(
// Return the response so the spawned future can continue.
Ok(response.map(|b| b.map_err(|x| match x {}).boxed()))
} else if request.uri().path() == "/sql" && *request.method() == Method::POST {
let ctx = RequestContext::new(
session_id,
conn_info,
crate::metrics::Protocol::Http,
&config.region,
);
let ctx = RequestContext::new(session_id, conn_info, crate::metrics::Protocol::Http);
let span = ctx.span();
let testodrome_id = request

View File

@@ -2,10 +2,11 @@ use std::convert::TryFrom;
use std::sync::Arc;
use postgres_client::tls::MakeTlsConnect;
use rustls::ClientConfig;
use rustls::pki_types::ServerName;
use rustls::pki_types::{InvalidDnsNameError, ServerName};
use tokio::io::{AsyncRead, AsyncWrite};
use crate::config::ComputeConfig;
mod private {
use std::future::Future;
use std::io;
@@ -123,36 +124,27 @@ mod private {
}
}
/// A `MakeTlsConnect` implementation using `rustls`.
///
/// That way you can connect to PostgreSQL using `rustls` as the TLS stack.
#[derive(Clone)]
pub struct MakeRustlsConnect {
pub config: Arc<ClientConfig>,
}
impl MakeRustlsConnect {
/// Creates a new `MakeRustlsConnect` from the provided `ClientConfig`.
#[must_use]
pub fn new(config: Arc<ClientConfig>) -> Self {
Self { config }
}
}
impl<S> MakeTlsConnect<S> for MakeRustlsConnect
impl<S> MakeTlsConnect<S> for ComputeConfig
where
S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
type Stream = private::RustlsStream<S>;
type TlsConnect = private::RustlsConnect;
type Error = rustls::pki_types::InvalidDnsNameError;
type Error = InvalidDnsNameError;
fn make_tls_connect(&mut self, hostname: &str) -> Result<Self::TlsConnect, Self::Error> {
ServerName::try_from(hostname).map(|dns_name| {
private::RustlsConnect(private::RustlsConnectData {
hostname: dns_name.to_owned(),
connector: Arc::clone(&self.config).into(),
})
})
fn make_tls_connect(&self, hostname: &str) -> Result<Self::TlsConnect, Self::Error> {
make_tls_connect(&self.tls, hostname)
}
}
pub fn make_tls_connect(
tls: &Arc<rustls::ClientConfig>,
hostname: &str,
) -> Result<private::RustlsConnect, InvalidDnsNameError> {
ServerName::try_from(hostname).map(|dns_name| {
private::RustlsConnect(private::RustlsConnectData {
hostname: dns_name.to_owned(),
connector: tls.clone().into(),
})
})
}

View File

@@ -8,8 +8,8 @@ use std::error::Error as _;
use http_utils::error::HttpErrorBody;
use reqwest::{IntoUrl, Method, StatusCode};
use safekeeper_api::models::{
self, PullTimelineRequest, PullTimelineResponse, SafekeeperUtilization, TimelineCreateRequest,
TimelineStatus,
self, PullTimelineRequest, PullTimelineResponse, SafekeeperStatus, SafekeeperUtilization,
TimelineCreateRequest, TimelineStatus,
};
use utils::id::{NodeId, TenantId, TimelineId};
use utils::logging::SecretString;
@@ -183,6 +183,12 @@ impl Client {
self.get(&uri).await
}
pub async fn status(&self) -> Result<SafekeeperStatus> {
let uri = format!("{}/v1/status", self.mgmt_api_endpoint);
let resp = self.get(&uri).await?;
resp.json().await.map_err(Error::ReceiveBody)
}
pub async fn utilization(&self) -> Result<SafekeeperUtilization> {
let uri = format!("{}/v1/utilization", self.mgmt_api_endpoint);
let resp = self.get(&uri).await?;

View File

@@ -395,6 +395,8 @@ pub enum TimelineError {
Cancelled(TenantTimelineId),
#[error("Timeline {0} was not found in global map")]
NotFound(TenantTimelineId),
#[error("Timeline {0} has been deleted")]
Deleted(TenantTimelineId),
#[error("Timeline {0} creation is in progress")]
CreationInProgress(TenantTimelineId),
#[error("Timeline {0} exists on disk, but wasn't loaded on startup")]

View File

@@ -78,7 +78,13 @@ impl GlobalTimelinesState {
Some(GlobalMapTimeline::CreationInProgress) => {
Err(TimelineError::CreationInProgress(*ttid))
}
None => Err(TimelineError::NotFound(*ttid)),
None => {
if self.has_tombstone(ttid) {
Err(TimelineError::Deleted(*ttid))
} else {
Err(TimelineError::NotFound(*ttid))
}
}
}
}

View File

@@ -0,0 +1 @@
ALTER TABLE nodes DROP COLUMN lifecycle;

View File

@@ -0,0 +1 @@
ALTER TABLE nodes ADD COLUMN lifecycle VARCHAR NOT NULL DEFAULT 'active';

View File

@@ -907,6 +907,42 @@ async fn handle_node_delete(req: Request<Body>) -> Result<Response<Body>, ApiErr
json_response(StatusCode::OK, state.service.node_delete(node_id).await?)
}
async fn handle_tombstone_list(req: Request<Body>) -> Result<Response<Body>, ApiError> {
check_permissions(&req, Scope::Admin)?;
let req = match maybe_forward(req).await {
ForwardOutcome::Forwarded(res) => {
return res;
}
ForwardOutcome::NotForwarded(req) => req,
};
let state = get_state(&req);
let mut nodes = state.service.tombstone_list().await?;
nodes.sort_by_key(|n| n.get_id());
let api_nodes = nodes.into_iter().map(|n| n.describe()).collect::<Vec<_>>();
json_response(StatusCode::OK, api_nodes)
}
async fn handle_tombstone_delete(req: Request<Body>) -> Result<Response<Body>, ApiError> {
check_permissions(&req, Scope::Admin)?;
let req = match maybe_forward(req).await {
ForwardOutcome::Forwarded(res) => {
return res;
}
ForwardOutcome::NotForwarded(req) => req,
};
let state = get_state(&req);
let node_id: NodeId = parse_request_param(&req, "node_id")?;
json_response(
StatusCode::OK,
state.service.tombstone_delete(node_id).await?,
)
}
async fn handle_node_configure(req: Request<Body>) -> Result<Response<Body>, ApiError> {
check_permissions(&req, Scope::Admin)?;
@@ -2062,6 +2098,20 @@ pub fn make_router(
.post("/debug/v1/node/:node_id/drop", |r| {
named_request_span(r, handle_node_drop, RequestName("debug_v1_node_drop"))
})
.delete("/debug/v1/tombstone/:node_id", |r| {
named_request_span(
r,
handle_tombstone_delete,
RequestName("debug_v1_tombstone_delete"),
)
})
.get("/debug/v1/tombstone", |r| {
named_request_span(
r,
handle_tombstone_list,
RequestName("debug_v1_tombstone_list"),
)
})
.post("/debug/v1/tenant/:tenant_id/import", |r| {
named_request_span(
r,

View File

@@ -2,7 +2,7 @@ use std::str::FromStr;
use std::time::Duration;
use pageserver_api::controller_api::{
AvailabilityZone, NodeAvailability, NodeDescribeResponse, NodeRegisterRequest,
AvailabilityZone, NodeAvailability, NodeDescribeResponse, NodeLifecycle, NodeRegisterRequest,
NodeSchedulingPolicy, TenantLocateResponseShard,
};
use pageserver_api::shard::TenantShardId;
@@ -29,6 +29,7 @@ pub(crate) struct Node {
availability: NodeAvailability,
scheduling: NodeSchedulingPolicy,
lifecycle: NodeLifecycle,
listen_http_addr: String,
listen_http_port: u16,
@@ -228,6 +229,7 @@ impl Node {
listen_pg_addr,
listen_pg_port,
scheduling: NodeSchedulingPolicy::Active,
lifecycle: NodeLifecycle::Active,
availability: NodeAvailability::Offline,
availability_zone_id,
use_https,
@@ -239,6 +241,7 @@ impl Node {
NodePersistence {
node_id: self.id.0 as i64,
scheduling_policy: self.scheduling.into(),
lifecycle: self.lifecycle.into(),
listen_http_addr: self.listen_http_addr.clone(),
listen_http_port: self.listen_http_port as i32,
listen_https_port: self.listen_https_port.map(|x| x as i32),
@@ -263,6 +266,7 @@ impl Node {
availability: NodeAvailability::Offline,
scheduling: NodeSchedulingPolicy::from_str(&np.scheduling_policy)
.expect("Bad scheduling policy in DB"),
lifecycle: NodeLifecycle::from_str(&np.lifecycle).expect("Bad lifecycle in DB"),
listen_http_addr: np.listen_http_addr,
listen_http_port: np.listen_http_port as u16,
listen_https_port: np.listen_https_port.map(|x| x as u16),

View File

@@ -19,7 +19,7 @@ use futures::FutureExt;
use futures::future::BoxFuture;
use itertools::Itertools;
use pageserver_api::controller_api::{
AvailabilityZone, MetadataHealthRecord, NodeSchedulingPolicy, PlacementPolicy,
AvailabilityZone, MetadataHealthRecord, NodeLifecycle, NodeSchedulingPolicy, PlacementPolicy,
SafekeeperDescribeResponse, ShardSchedulingPolicy, SkSchedulingPolicy,
};
use pageserver_api::models::{ShardImportStatus, TenantConfig};
@@ -102,6 +102,7 @@ pub(crate) enum DatabaseOperation {
UpdateNode,
DeleteNode,
ListNodes,
ListTombstones,
BeginShardSplit,
CompleteShardSplit,
AbortShardSplit,
@@ -357,6 +358,8 @@ impl Persistence {
}
/// When a node is first registered, persist it before using it for anything
/// If the provided node_id already exists, it will be error.
/// The common case is when a node marked for deletion wants to register.
pub(crate) async fn insert_node(&self, node: &Node) -> DatabaseResult<()> {
let np = &node.to_persistent();
self.with_measured_conn(DatabaseOperation::InsertNode, move |conn| {
@@ -373,19 +376,41 @@ impl Persistence {
/// At startup, populate the list of nodes which our shards may be placed on
pub(crate) async fn list_nodes(&self) -> DatabaseResult<Vec<NodePersistence>> {
let nodes: Vec<NodePersistence> = self
use crate::schema::nodes::dsl::*;
let result: Vec<NodePersistence> = self
.with_measured_conn(DatabaseOperation::ListNodes, move |conn| {
Box::pin(async move {
Ok(crate::schema::nodes::table
.filter(lifecycle.ne(String::from(NodeLifecycle::Deleted)))
.load::<NodePersistence>(conn)
.await?)
})
})
.await?;
tracing::info!("list_nodes: loaded {} nodes", nodes.len());
tracing::info!("list_nodes: loaded {} nodes", result.len());
Ok(nodes)
Ok(result)
}
pub(crate) async fn list_tombstones(&self) -> DatabaseResult<Vec<NodePersistence>> {
use crate::schema::nodes::dsl::*;
let result: Vec<NodePersistence> = self
.with_measured_conn(DatabaseOperation::ListTombstones, move |conn| {
Box::pin(async move {
Ok(crate::schema::nodes::table
.filter(lifecycle.eq(String::from(NodeLifecycle::Deleted)))
.load::<NodePersistence>(conn)
.await?)
})
})
.await?;
tracing::info!("list_tombstones: loaded {} nodes", result.len());
Ok(result)
}
pub(crate) async fn update_node<V>(
@@ -404,6 +429,7 @@ impl Persistence {
Box::pin(async move {
let updated = diesel::update(nodes)
.filter(node_id.eq(input_node_id.0 as i64))
.filter(lifecycle.ne(String::from(NodeLifecycle::Deleted)))
.set(values)
.execute(conn)
.await?;
@@ -447,6 +473,57 @@ impl Persistence {
.await
}
/// Tombstone is a special state where the node is not deleted from the database,
/// but it is not available for usage.
/// The main reason for it is to prevent the flaky node to register.
pub(crate) async fn set_tombstone(&self, del_node_id: NodeId) -> DatabaseResult<()> {
use crate::schema::nodes::dsl::*;
self.update_node(
del_node_id,
lifecycle.eq(String::from(NodeLifecycle::Deleted)),
)
.await
}
pub(crate) async fn delete_node(&self, del_node_id: NodeId) -> DatabaseResult<()> {
use crate::schema::nodes::dsl::*;
self.with_measured_conn(DatabaseOperation::DeleteNode, move |conn| {
Box::pin(async move {
// You can hard delete a node only if it has a tombstone.
// So we need to check if the node has lifecycle set to deleted.
let node_to_delete = nodes
.filter(node_id.eq(del_node_id.0 as i64))
.first::<NodePersistence>(conn)
.await
.optional()?;
if let Some(np) = node_to_delete {
let lc = NodeLifecycle::from_str(&np.lifecycle).map_err(|e| {
DatabaseError::Logical(format!(
"Node {} has invalid lifecycle: {}",
del_node_id, e
))
})?;
if lc != NodeLifecycle::Deleted {
return Err(DatabaseError::Logical(format!(
"Node {} was not soft deleted before, cannot hard delete it",
del_node_id
)));
}
diesel::delete(nodes)
.filter(node_id.eq(del_node_id.0 as i64))
.execute(conn)
.await?;
}
Ok(())
})
})
.await
}
/// At startup, load the high level state for shards, such as their config + policy. This will
/// be enriched at runtime with state discovered on pageservers.
///
@@ -543,21 +620,6 @@ impl Persistence {
.await
}
pub(crate) async fn delete_node(&self, del_node_id: NodeId) -> DatabaseResult<()> {
use crate::schema::nodes::dsl::*;
self.with_measured_conn(DatabaseOperation::DeleteNode, move |conn| {
Box::pin(async move {
diesel::delete(nodes)
.filter(node_id.eq(del_node_id.0 as i64))
.execute(conn)
.await?;
Ok(())
})
})
.await
}
/// When a tenant invokes the /re-attach API, this function is responsible for doing an efficient
/// batched increment of the generations of all tenants whose generation_pageserver is equal to
/// the node that called /re-attach.
@@ -571,6 +633,20 @@ impl Persistence {
let updated = self
.with_measured_conn(DatabaseOperation::ReAttach, move |conn| {
Box::pin(async move {
// Check if the node is not marked as deleted
let deleted_node: i64 = nodes
.filter(node_id.eq(input_node_id.0 as i64))
.filter(lifecycle.eq(String::from(NodeLifecycle::Deleted)))
.count()
.get_result(conn)
.await?;
if deleted_node > 0 {
return Err(DatabaseError::Logical(format!(
"Node {} is marked as deleted, re-attach is not allowed",
input_node_id
)));
}
let rows_updated = diesel::update(tenant_shards)
.filter(generation_pageserver.eq(input_node_id.0 as i64))
.set(generation.eq(generation + 1))
@@ -2048,6 +2124,7 @@ pub(crate) struct NodePersistence {
pub(crate) listen_pg_port: i32,
pub(crate) availability_zone_id: String,
pub(crate) listen_https_port: Option<i32>,
pub(crate) lifecycle: String,
}
/// Tenant metadata health status that are stored durably.

View File

@@ -33,6 +33,7 @@ diesel::table! {
listen_pg_port -> Int4,
availability_zone_id -> Varchar,
listen_https_port -> Nullable<Int4>,
lifecycle -> Varchar,
}
}

View File

@@ -166,6 +166,7 @@ enum NodeOperations {
Register,
Configure,
Delete,
DeleteTombstone,
}
/// The leadership status for the storage controller process.
@@ -1107,7 +1108,8 @@ impl Service {
observed
}
/// Used during [`Self::startup_reconcile`]: detach a list of unknown-to-us tenants from pageservers.
/// Used during [`Self::startup_reconcile`] and shard splits: detach a list of unknown-to-us
/// tenants from pageservers.
///
/// This is safe to run in the background, because if we don't have this TenantShardId in our map of
/// tenants, then it is probably something incompletely deleted before: we will not fight with any
@@ -6210,7 +6212,11 @@ impl Service {
}
}
pausable_failpoint!("shard-split-pre-complete");
fail::fail_point!("shard-split-pre-complete", |_| Err(ApiError::Conflict(
"failpoint".to_string()
)));
pausable_failpoint!("shard-split-pre-complete-pause");
// TODO: if the pageserver restarted concurrently with our split API call,
// the actual generation of the child shard might differ from the generation
@@ -6232,6 +6238,15 @@ impl Service {
let (response, child_locations, waiters) =
self.tenant_shard_split_commit_inmem(tenant_id, new_shard_count, new_stripe_size);
// Notify all page servers to detach and clean up the old shards because they will no longer
// be needed. This is best-effort: if it fails, it will be cleaned up on a subsequent
// Pageserver re-attach/startup.
let shards_to_cleanup = targets
.iter()
.map(|target| (target.parent_id, target.node.get_id()))
.collect();
self.cleanup_locations(shards_to_cleanup).await;
// Send compute notifications for all the new shards
let mut failed_notifications = Vec::new();
for (child_id, child_ps, stripe_size) in child_locations {
@@ -6909,7 +6924,7 @@ impl Service {
/// detaching or deleting it on pageservers. We do not try and re-schedule any
/// tenants that were on this node.
pub(crate) async fn node_drop(&self, node_id: NodeId) -> Result<(), ApiError> {
self.persistence.delete_node(node_id).await?;
self.persistence.set_tombstone(node_id).await?;
let mut locked = self.inner.write().unwrap();
@@ -7033,9 +7048,10 @@ impl Service {
// That is safe because in Service::spawn we only use generation_pageserver if it refers to a node
// that exists.
// 2. Actually delete the node from the database and from in-memory state
// 2. Actually delete the node from in-memory state and set tombstone to the database
// for preventing the node to register again.
tracing::info!("Deleting node from database");
self.persistence.delete_node(node_id).await?;
self.persistence.set_tombstone(node_id).await?;
Ok(())
}
@@ -7054,6 +7070,35 @@ impl Service {
Ok(nodes)
}
pub(crate) async fn tombstone_list(&self) -> Result<Vec<Node>, ApiError> {
self.persistence
.list_tombstones()
.await?
.into_iter()
.map(|np| Node::from_persistent(np, false))
.collect::<Result<Vec<_>, _>>()
.map_err(ApiError::InternalServerError)
}
pub(crate) async fn tombstone_delete(&self, node_id: NodeId) -> Result<(), ApiError> {
let _node_lock = trace_exclusive_lock(
&self.node_op_locks,
node_id,
NodeOperations::DeleteTombstone,
)
.await;
if matches!(self.get_node(node_id).await, Err(ApiError::NotFound(_))) {
self.persistence.delete_node(node_id).await?;
Ok(())
} else {
Err(ApiError::Conflict(format!(
"Node {} is in use, consider using tombstone API first",
node_id
)))
}
}
pub(crate) async fn get_node(&self, node_id: NodeId) -> Result<Node, ApiError> {
self.inner
.read()
@@ -7224,7 +7269,25 @@ impl Service {
};
match registration_status {
RegistrationStatus::New => self.persistence.insert_node(&new_node).await?,
RegistrationStatus::New => {
self.persistence.insert_node(&new_node).await.map_err(|e| {
if matches!(
e,
crate::persistence::DatabaseError::Query(
diesel::result::Error::DatabaseError(
diesel::result::DatabaseErrorKind::UniqueViolation,
_,
)
)
) {
// The node can be deleted by tombstone API, and not show up in the list of nodes.
// If you see this error, check tombstones first.
ApiError::Conflict(format!("Node {} is already exists", new_node.get_id()))
} else {
ApiError::from(e)
}
})?;
}
RegistrationStatus::NeedUpdate => {
self.persistence
.update_node_on_registration(

View File

@@ -2054,6 +2054,14 @@ class NeonStorageController(MetricsGetter, LogUtils):
headers=self.headers(TokenScope.ADMIN),
)
def tombstone_delete(self, node_id):
log.info(f"tombstone_delete({node_id})")
self.request(
"DELETE",
f"{self.api}/debug/v1/tombstone/{node_id}",
headers=self.headers(TokenScope.ADMIN),
)
def node_drain(self, node_id):
log.info(f"node_drain({node_id})")
self.request(
@@ -2110,6 +2118,14 @@ class NeonStorageController(MetricsGetter, LogUtils):
)
return response.json()
def tombstone_list(self):
response = self.request(
"GET",
f"{self.api}/debug/v1/tombstone",
headers=self.headers(TokenScope.ADMIN),
)
return response.json()
def tenant_shard_dump(self):
"""
Debug listing API: dumps the internal map of tenant shards

View File

@@ -87,6 +87,9 @@ def test_import_from_vanilla(test_output_dir, pg_bin, vanilla_pg, neon_env_build
# Set up pageserver for import
neon_env_builder.enable_pageserver_remote_storage(RemoteStorageKind.LOCAL_FS)
neon_env_builder.storage_controller_config = {
"timelines_onto_safekeepers": True,
}
env = neon_env_builder.init_start()
env.pageserver.tenant_create(tenant)

View File

@@ -30,6 +30,7 @@ def test_safekeeper_delete_timeline(neon_env_builder: NeonEnvBuilder, auth_enabl
env.pageserver.allowed_errors.extend(
[
".*Timeline .* was not found in global map.*",
".*Timeline .* has been deleted.*",
".*Timeline .* was cancelled and cannot be used anymore.*",
]
)
@@ -198,6 +199,7 @@ def test_safekeeper_delete_timeline_under_load(neon_env_builder: NeonEnvBuilder)
env.pageserver.allowed_errors.extend(
[
".*Timeline.*was cancelled.*",
".*Timeline.*has been deleted.*",
".*Timeline.*was not found.*",
]
)

View File

@@ -1836,3 +1836,90 @@ def test_sharding_gc(
shard_gc_cutoff_lsn = Lsn(shard_index["metadata_bytes"]["latest_gc_cutoff_lsn"])
log.info(f"Shard {shard_number} cutoff LSN: {shard_gc_cutoff_lsn}")
assert shard_gc_cutoff_lsn == shard_0_gc_cutoff_lsn
def test_split_ps_delete_old_shard_after_commit(neon_env_builder: NeonEnvBuilder):
"""
Check that PageServer only deletes old shards after the split is committed such that it doesn't
have to download a lot of files during abort.
"""
DBNAME = "regression"
init_shard_count = 4
neon_env_builder.num_pageservers = init_shard_count
stripe_size = 32
env = neon_env_builder.init_start(
initial_tenant_shard_count=init_shard_count, initial_tenant_shard_stripe_size=stripe_size
)
env.storage_controller.allowed_errors.extend(
[
# All split failures log a warning when they enqueue the abort operation
".*Enqueuing background abort.*",
# Tolerate any error logs that mention a failpoint
".*failpoint.*",
]
)
endpoint = env.endpoints.create("main")
endpoint.respec(skip_pg_catalog_updates=False)
endpoint.start()
# Write some initial data.
endpoint.safe_psql(f"CREATE DATABASE {DBNAME}")
endpoint.safe_psql("CREATE TABLE usertable ( YCSB_KEY INT, FIELD0 TEXT);")
for _ in range(1000):
endpoint.safe_psql(
"INSERT INTO usertable SELECT random(), repeat('a', 1000);", log_query=False
)
# Record how many bytes we've downloaded before the split.
def collect_downloaded_bytes() -> list[float | None]:
downloaded_bytes = []
for page_server in env.pageservers:
metric = page_server.http_client().get_metric_value(
"pageserver_remote_ondemand_downloaded_bytes_total"
)
downloaded_bytes.append(metric)
return downloaded_bytes
downloaded_bytes_before = collect_downloaded_bytes()
# Attempt to split the tenant, but fail the split before it completes.
env.storage_controller.configure_failpoints(("shard-split-pre-complete", "return(1)"))
with pytest.raises(StorageControllerApiException):
env.storage_controller.tenant_shard_split(env.initial_tenant, shard_count=16)
# Wait until split is aborted.
def check_split_is_aborted():
tenants = env.storage_controller.tenant_list()
assert len(tenants) == 1
shards = tenants[0]["shards"]
assert len(shards) == 4
for shard in shards:
assert not shard["is_splitting"]
assert not shard["is_reconciling"]
# Make sure all new shards have been deleted.
valid_shards = 0
for ps in env.pageservers:
for tenant_dir in os.listdir(ps.workdir / "tenants"):
try:
tenant_shard_id = TenantShardId.parse(tenant_dir)
valid_shards += 1
assert tenant_shard_id.shard_count == 4
except ValueError:
log.info(f"{tenant_dir} is not valid tenant shard id")
assert valid_shards >= 4
wait_until(check_split_is_aborted)
endpoint.safe_psql("SELECT count(*) from usertable;", log_query=False)
# Make sure we didn't download anything following the aborted split.
downloaded_bytes_after = collect_downloaded_bytes()
assert downloaded_bytes_before == downloaded_bytes_after
endpoint.stop_and_destroy()

View File

@@ -2956,7 +2956,7 @@ def test_storage_controller_leadership_transfer_during_split(
env.storage_controller.allowed_errors.extend(
[".*Unexpected child shard count.*", ".*Enqueuing background abort.*"]
)
pause_failpoint = "shard-split-pre-complete"
pause_failpoint = "shard-split-pre-complete-pause"
env.storage_controller.configure_failpoints((pause_failpoint, "pause"))
split_fut = executor.submit(
@@ -3003,7 +3003,7 @@ def test_storage_controller_leadership_transfer_during_split(
env.storage_controller.request(
"PUT",
f"http://127.0.0.1:{storage_controller_1_port}/debug/v1/failpoints",
json=[{"name": "shard-split-pre-complete", "actions": "off"}],
json=[{"name": pause_failpoint, "actions": "off"}],
headers=env.storage_controller.headers(TokenScope.ADMIN),
)
@@ -3093,6 +3093,58 @@ def test_storage_controller_ps_restarted_during_drain(neon_env_builder: NeonEnvB
wait_until(reconfigure_node_again)
def test_ps_unavailable_after_delete(neon_env_builder: NeonEnvBuilder):
neon_env_builder.num_pageservers = 3
env = neon_env_builder.init_start()
def assert_nodes_count(n: int):
nodes = env.storage_controller.node_list()
assert len(nodes) == n
# Nodes count must remain the same before deletion
assert_nodes_count(3)
ps = env.pageservers[0]
env.storage_controller.node_delete(ps.id)
# After deletion, the node count must be reduced
assert_nodes_count(2)
# Running pageserver CLI init in a separate thread
with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
log.info("Restarting tombstoned pageserver...")
ps.stop()
ps_start_fut = executor.submit(lambda: ps.start(await_active=False))
# After deleted pageserver restart, the node count must remain the same
assert_nodes_count(2)
tombstones = env.storage_controller.tombstone_list()
assert len(tombstones) == 1 and tombstones[0]["id"] == ps.id
env.storage_controller.tombstone_delete(ps.id)
tombstones = env.storage_controller.tombstone_list()
assert len(tombstones) == 0
# Wait for the pageserver start operation to complete.
# If it fails with an exception, we try restarting the pageserver since the failure
# may be due to the storage controller refusing to register the node.
# However, if we get a TimeoutError that means the pageserver is completely hung,
# which is an unexpected failure mode that we'll let propagate up.
try:
ps_start_fut.result(timeout=20)
except TimeoutError:
raise
except Exception:
log.info("Restarting deleted pageserver...")
ps.restart()
# Finally, the node can be registered again after tombstone is deleted
wait_until(lambda: assert_nodes_count(3))
def test_storage_controller_timeline_crud_race(neon_env_builder: NeonEnvBuilder):
"""
The storage controller is meant to handle the case where a timeline CRUD operation races

View File

@@ -12,7 +12,7 @@ from typing import TYPE_CHECKING
import pytest
import requests
from fixtures.common_types import Lsn, TenantId, TimelineId
from fixtures.common_types import Lsn, TenantId, TimelineArchivalState, TimelineId
from fixtures.log_helper import log
from fixtures.metrics import (
PAGESERVER_GLOBAL_METRICS,
@@ -299,6 +299,65 @@ def test_pageserver_metrics_removed_after_detach(neon_env_builder: NeonEnvBuilde
assert post_detach_samples == set()
def test_pageserver_metrics_removed_after_offload(neon_env_builder: NeonEnvBuilder):
"""Tests that when a timeline is offloaded, the tenant specific metrics are not left behind"""
neon_env_builder.enable_pageserver_remote_storage(RemoteStorageKind.MOCK_S3)
neon_env_builder.num_safekeepers = 3
env = neon_env_builder.init_start()
tenant_1, _ = env.create_tenant()
timeline_1 = env.create_timeline("test_metrics_removed_after_offload_1", tenant_id=tenant_1)
timeline_2 = env.create_timeline("test_metrics_removed_after_offload_2", tenant_id=tenant_1)
endpoint_tenant1 = env.endpoints.create_start(
"test_metrics_removed_after_offload_1", tenant_id=tenant_1
)
endpoint_tenant2 = env.endpoints.create_start(
"test_metrics_removed_after_offload_2", tenant_id=tenant_1
)
for endpoint in [endpoint_tenant1, endpoint_tenant2]:
with closing(endpoint.connect()) as conn:
with conn.cursor() as cur:
cur.execute("CREATE TABLE t(key int primary key, value text)")
cur.execute("INSERT INTO t SELECT generate_series(1,100000), 'payload'")
cur.execute("SELECT sum(key) FROM t")
assert cur.fetchone() == (5000050000,)
endpoint.stop()
def get_ps_metric_samples_for_timeline(
tenant_id: TenantId, timeline_id: TimelineId
) -> list[Sample]:
ps_metrics = env.pageserver.http_client().get_metrics()
samples = []
for metric_name in ps_metrics.metrics:
for sample in ps_metrics.query_all(
name=metric_name,
filter={"tenant_id": str(tenant_id), "timeline_id": str(timeline_id)},
):
samples.append(sample)
return samples
for timeline in [timeline_1, timeline_2]:
pre_offload_samples = set(
[x.name for x in get_ps_metric_samples_for_timeline(tenant_1, timeline)]
)
assert len(pre_offload_samples) > 0, f"expected at least one sample for {timeline}"
env.pageserver.http_client().timeline_archival_config(
tenant_1,
timeline,
state=TimelineArchivalState.ARCHIVED,
)
env.pageserver.http_client().timeline_offload(tenant_1, timeline)
post_offload_samples = set(
[x.name for x in get_ps_metric_samples_for_timeline(tenant_1, timeline)]
)
assert post_offload_samples == set()
def test_pageserver_with_empty_tenants(neon_env_builder: NeonEnvBuilder):
env = neon_env_builder.init_start()

View File

@@ -433,6 +433,7 @@ def test_wal_backup(neon_env_builder: NeonEnvBuilder):
env.pageserver.allowed_errors.extend(
[
".*Timeline .* was not found in global map.*",
".*Timeline .* has been deleted.*",
".*Timeline .* was cancelled and cannot be used anymore.*",
]
)
@@ -1934,6 +1935,7 @@ def test_membership_api(neon_env_builder: NeonEnvBuilder):
env.pageserver.allowed_errors.extend(
[
".*Timeline .* was not found in global map.*",
".*Timeline .* has been deleted.*",
".*Timeline .* was cancelled and cannot be used anymore.*",
]
)