Merge branch 'communicator-rewrite' into quantumish/lfc-resize-static-shmem

This commit is contained in:
quantumish
2025-07-24 19:44:21 -07:00
212 changed files with 5705 additions and 6624 deletions

View File

@@ -21,13 +21,14 @@ platforms = [
# "x86_64-apple-darwin",
# "x86_64-pc-windows-msvc",
]
[final-excludes]
workspace-members = [
# vm_monitor benefits from the same Cargo.lock as the rest of our artifacts, but
# it is built primarily in a separate repo, neondatabase/autoscaling, and thus is excluded
# from depending on workspace-hack because most of the dependencies are not used.
"vm_monitor",
# subzero-core is a stub crate that should be excluded from workspace-hack
"subzero-core",
# All of these exist in libs and are not usually built independently.
# Putting workspace hack there adds a bottleneck for cargo builds.
"compute_api",

View File

@@ -0,0 +1,28 @@
name: 'Prepare current job for subzero'
description: >
Set a git token to access `neondatabase/subzero` from the cargo build,
and set the `CARGO_NET_GIT_FETCH_WITH_CLI=true` env variable to use the git CLI
inputs:
token:
description: 'GitHub token with access to neondatabase/subzero'
required: true
runs:
using: "composite"
steps:
- name: Set git token for neondatabase/subzero
uses: pyTooling/Actions/with-post-step@2307b526df64d55e95884e072e49aac2a00a9afa # v5.1.0
env:
SUBZERO_ACCESS_TOKEN: ${{ inputs.token }}
with:
main: |
git config --global url."https://x-access-token:${SUBZERO_ACCESS_TOKEN}@github.com/neondatabase/subzero".insteadOf "https://github.com/neondatabase/subzero"
cargo add -p proxy subzero-core --git https://github.com/neondatabase/subzero --rev 396264617e78e8be428682f87469bb25429af88a
post: |
git config --global --unset url."https://x-access-token:${SUBZERO_ACCESS_TOKEN}@github.com/neondatabase/subzero".insteadOf "https://github.com/neondatabase/subzero"
- name: Set `CARGO_NET_GIT_FETCH_WITH_CLI=true` env variable
shell: bash -euxo pipefail {0}
run: echo "CARGO_NET_GIT_FETCH_WITH_CLI=true" >> ${GITHUB_ENV}

View File

@@ -86,6 +86,10 @@ jobs:
with:
submodules: true
- uses: ./.github/actions/prepare-for-subzero
with:
token: ${{ secrets.CI_ACCESS_TOKEN }}
- name: Set pg 14 revision for caching
id: pg_v14_rev
run: echo pg_rev=$(git rev-parse HEAD:vendor/postgres-v14) >> $GITHUB_OUTPUT
@@ -116,7 +120,7 @@ jobs:
ARCH: ${{ inputs.arch }}
SANITIZERS: ${{ inputs.sanitizers }}
run: |
CARGO_FLAGS="--locked --features testing"
CARGO_FLAGS="--locked --features testing,rest_broker"
if [[ $BUILD_TYPE == "debug" && $ARCH == 'x64' ]]; then
cov_prefix="scripts/coverage --profraw-prefix=$GITHUB_JOB --dir=/tmp/coverage run"
CARGO_PROFILE=""

View File

@@ -46,6 +46,10 @@ jobs:
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
with:
submodules: true
- uses: ./.github/actions/prepare-for-subzero
with:
token: ${{ secrets.CI_ACCESS_TOKEN }}
- name: Cache cargo deps
uses: tespkg/actions-cache@b7bf5fcc2f98a52ac6080eb0fd282c2f752074b1 # v1.8.0

View File

@@ -54,6 +54,10 @@ jobs:
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
with:
submodules: true
- uses: ./.github/actions/prepare-for-subzero
with:
token: ${{ secrets.CI_ACCESS_TOKEN }}
- name: Install build dependencies
run: |

View File

@@ -632,6 +632,8 @@ jobs:
BUILD_TAG=${{ needs.meta.outputs.release-tag || needs.meta.outputs.build-tag }}
TAG=${{ needs.build-build-tools-image.outputs.image-tag }}-bookworm
DEBIAN_VERSION=bookworm
secrets: |
SUBZERO_ACCESS_TOKEN=${{ secrets.CI_ACCESS_TOKEN }}
provenance: false
push: true
pull: true

View File

@@ -72,6 +72,7 @@ jobs:
check-macos-build:
needs: [ check-permissions, files-changed ]
uses: ./.github/workflows/build-macos.yml
secrets: inherit
with:
pg_versions: ${{ needs.files-changed.outputs.postgres_changes }}
rebuild_rust_code: ${{ fromJSON(needs.files-changed.outputs.rebuild_rust_code) }}

5
.gitignore vendored
View File

@@ -27,9 +27,14 @@ docker-compose/docker-compose-parallel.yml
*.o
*.so
*.Po
*.pid
# pgindent typedef lists
*.list
# Node
**/node_modules/
# various files for local testing
/proxy/.subzero
local_proxy.json

553
Cargo.lock generated

File diff suppressed because it is too large.

View File

@@ -35,7 +35,6 @@ members = [
"libs/pq_proto",
"libs/tenant_size_model",
"libs/metrics",
"libs/neonart",
"libs/postgres_connection",
"libs/remote_storage",
"libs/tracing-utils",
@@ -50,6 +49,7 @@ members = [
"libs/proxy/tokio-postgres2",
"endpoint_storage",
"pgxn/neon/communicator",
"proxy/subzero_core",
]
[workspace.package]
@@ -144,10 +144,10 @@ notify = "6.0.0"
num_cpus = "1.15"
num-traits = "0.2.19"
once_cell = "1.13"
opentelemetry = "0.27"
opentelemetry_sdk = "0.27"
opentelemetry-otlp = { version = "0.27", default-features = false, features = ["http-proto", "trace", "http", "reqwest-client"] }
opentelemetry-semantic-conventions = "0.27"
opentelemetry = "0.30"
opentelemetry_sdk = "0.30"
opentelemetry-otlp = { version = "0.30", default-features = false, features = ["http-proto", "trace", "http", "reqwest-blocking-client"] }
opentelemetry-semantic-conventions = "0.30"
parking_lot = "0.12"
parquet = { version = "53", default-features = false, features = ["zstd"] }
parquet_derive = "53"
@@ -160,11 +160,13 @@ procfs = "0.16"
prometheus = {version = "0.13", default-features=false, features = ["process"]} # removes protobuf dependency
prost = "0.13.5"
prost-types = "0.13.5"
rand = "0.8"
rand = "0.9"
# Remove after p256 is updated to 0.14.
rand_core = "=0.6"
redis = { version = "0.29.2", features = ["tokio-rustls-comp", "keep-alive"] }
regex = "1.10.2"
reqwest = { version = "0.12", default-features = false, features = ["rustls-tls"] }
reqwest-tracing = { version = "0.5", features = ["opentelemetry_0_27"] }
reqwest-tracing = { version = "0.5", features = ["opentelemetry_0_30"] }
reqwest-middleware = "0.4"
reqwest-retry = "0.7"
routerify = "3"
@@ -214,15 +216,12 @@ tonic = { version = "0.13.1", default-features = false, features = ["channel", "
tonic-reflection = { version = "0.13.1", features = ["server"] }
tower = { version = "0.5.2", default-features = false }
tower-http = { version = "0.6.2", features = ["auth", "request-id", "trace"] }
# This revision uses opentelemetry 0.27. There's no tag for it.
tower-otel = { git = "https://github.com/mattiapenati/tower-otel", rev = "56a7321053bcb72443888257b622ba0d43a11fcd" }
tower-otel = { version = "0.6", features = ["axum"] }
tower-service = "0.3.3"
tracing = "0.1"
tracing-error = "0.2"
tracing-log = "0.2"
tracing-opentelemetry = "0.28"
tracing-opentelemetry = "0.31"
tracing-serde = "0.2.0"
tracing-subscriber = { version = "0.3", default-features = false, features = ["smallvec", "fmt", "tracing-log", "std", "env-filter", "json"] }
try-lock = "0.2.5"

View File

@@ -63,7 +63,14 @@ WORKDIR /home/nonroot
COPY --chown=nonroot . .
RUN cargo chef prepare --recipe-path recipe.json
RUN --mount=type=secret,uid=1000,id=SUBZERO_ACCESS_TOKEN \
set -e \
&& if [ -s /run/secrets/SUBZERO_ACCESS_TOKEN ]; then \
export CARGO_NET_GIT_FETCH_WITH_CLI=true && \
git config --global url."https://$(cat /run/secrets/SUBZERO_ACCESS_TOKEN)@github.com/neondatabase/subzero".insteadOf "https://github.com/neondatabase/subzero" && \
cargo add -p proxy subzero-core --git https://github.com/neondatabase/subzero --rev 396264617e78e8be428682f87469bb25429af88a; \
fi \
&& cargo chef prepare --recipe-path recipe.json
# Main build image
FROM $REPOSITORY/$IMAGE:$TAG AS build
@@ -71,20 +78,33 @@ WORKDIR /home/nonroot
ARG GIT_VERSION=local
ARG BUILD_TAG
ARG ADDITIONAL_RUSTFLAGS=""
ENV CARGO_FEATURES="default"
# 3. Build cargo dependencies. Note that this step doesn't depend on anything else than
# `recipe.json`, so the layer can be reused as long as none of the dependencies change.
COPY --from=plan /home/nonroot/recipe.json recipe.json
RUN set -e \
RUN --mount=type=secret,uid=1000,id=SUBZERO_ACCESS_TOKEN \
set -e \
&& if [ -s /run/secrets/SUBZERO_ACCESS_TOKEN ]; then \
export CARGO_NET_GIT_FETCH_WITH_CLI=true && \
git config --global url."https://$(cat /run/secrets/SUBZERO_ACCESS_TOKEN)@github.com/neondatabase/subzero".insteadOf "https://github.com/neondatabase/subzero"; \
fi \
&& RUSTFLAGS="-Clinker=clang -Clink-arg=-fuse-ld=mold -Clink-arg=-Wl,--no-rosegment -Cforce-frame-pointers=yes ${ADDITIONAL_RUSTFLAGS}" cargo chef cook --locked --release --recipe-path recipe.json
# Perform the main build. We reuse the Postgres build artifacts from the intermediate 'pg-build'
# layer, and the cargo dependencies built in the previous step.
COPY --chown=nonroot --from=pg-build /home/nonroot/pg_install/ pg_install
COPY --chown=nonroot . .
COPY --chown=nonroot --from=plan /home/nonroot/proxy/Cargo.toml proxy/Cargo.toml
COPY --chown=nonroot --from=plan /home/nonroot/Cargo.lock Cargo.lock
RUN set -e \
RUN --mount=type=secret,uid=1000,id=SUBZERO_ACCESS_TOKEN \
set -e \
&& if [ -s /run/secrets/SUBZERO_ACCESS_TOKEN ]; then \
export CARGO_FEATURES="rest_broker"; \
fi \
&& RUSTFLAGS="-Clinker=clang -Clink-arg=-fuse-ld=mold -Clink-arg=-Wl,--no-rosegment -Cforce-frame-pointers=yes ${ADDITIONAL_RUSTFLAGS}" cargo build \
--features $CARGO_FEATURES \
--bin pg_sni_router \
--bin pageserver \
--bin pagectl \

View File

@@ -27,7 +27,10 @@ fail.workspace = true
flate2.workspace = true
futures.workspace = true
http.workspace = true
http-body-util.workspace = true
hostname-validator = "1.1"
hyper.workspace = true
hyper-util.workspace = true
indexmap.workspace = true
itertools.workspace = true
jsonwebtoken.workspace = true
@@ -44,6 +47,7 @@ postgres.workspace = true
regex.workspace = true
reqwest = { workspace = true, features = ["json"] }
ring = "0.17"
scopeguard.workspace = true
serde.workspace = true
serde_with.workspace = true
serde_json.workspace = true

View File

@@ -138,6 +138,12 @@ struct Cli {
/// Run in development mode, skipping VM-specific operations like process termination
#[arg(long, action = clap::ArgAction::SetTrue)]
pub dev: bool,
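/// Timeout (in seconds) for Postgres to finish startup in the Init state; if exceeded,
/// the compute monitor exits the process so it can restart with the latest spec.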
#[arg(long)]
pub pg_init_timeout: Option<u64>,
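/// Run in Lakebase (Databricks) mode, enabling Databricks-specific behavior such as the
/// `databricks_superuser` substitution in migrations.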
#[arg(long, default_value_t = false, action = clap::ArgAction::Set)]
pub lakebase_mode: bool,
}
impl Cli {
@@ -188,7 +194,7 @@ fn main() -> Result<()> {
.build()?;
let _rt_guard = runtime.enter();
runtime.block_on(init(cli.dev))?;
let tracing_provider = init(cli.dev)?;
// enable core dumping for all child processes
setrlimit(Resource::CORE, rlimit::INFINITY, rlimit::INFINITY)?;
@@ -219,6 +225,8 @@ fn main() -> Result<()> {
installed_extensions_collection_interval: Arc::new(AtomicU64::new(
cli.installed_extensions_collection_interval,
)),
pg_init_timeout: cli.pg_init_timeout.map(Duration::from_secs),
lakebase_mode: cli.lakebase_mode,
},
config,
)?;
@@ -227,11 +235,11 @@ fn main() -> Result<()> {
scenario.teardown();
deinit_and_exit(exit_code);
deinit_and_exit(tracing_provider, exit_code);
}
async fn init(dev_mode: bool) -> Result<()> {
init_tracing_and_logging(DEFAULT_LOG_LEVEL).await?;
fn init(dev_mode: bool) -> Result<Option<tracing_utils::Provider>> {
let provider = init_tracing_and_logging(DEFAULT_LOG_LEVEL)?;
let mut signals = Signals::new([SIGINT, SIGTERM, SIGQUIT])?;
thread::spawn(move || {
@@ -242,7 +250,7 @@ async fn init(dev_mode: bool) -> Result<()> {
info!("compute build_tag: {}", &BUILD_TAG.to_string());
Ok(())
Ok(provider)
}
fn get_config(cli: &Cli) -> Result<ComputeConfig> {
@@ -267,25 +275,27 @@ fn get_config(cli: &Cli) -> Result<ComputeConfig> {
}
}
fn deinit_and_exit(exit_code: Option<i32>) -> ! {
// Shutdown trace pipeline gracefully, so that it has a chance to send any
// pending traces before we exit. Shutting down OTEL tracing provider may
// hang for quite some time, see, for example:
// - https://github.com/open-telemetry/opentelemetry-rust/issues/868
// - and our problems with staging https://github.com/neondatabase/cloud/issues/3707#issuecomment-1493983636
//
// Yet, we want computes to shut down fast enough, as we may need a new one
// for the same timeline ASAP. So wait no longer than 2s for the shutdown to
// complete, then just error out and exit the main thread.
info!("shutting down tracing");
let (sender, receiver) = mpsc::channel();
let _ = thread::spawn(move || {
tracing_utils::shutdown_tracing();
sender.send(()).ok()
});
let shutdown_res = receiver.recv_timeout(Duration::from_millis(2000));
if shutdown_res.is_err() {
error!("timed out while shutting down tracing, exiting anyway");
fn deinit_and_exit(tracing_provider: Option<tracing_utils::Provider>, exit_code: Option<i32>) -> ! {
if let Some(p) = tracing_provider {
// Shutdown trace pipeline gracefully, so that it has a chance to send any
// pending traces before we exit. Shutting down OTEL tracing provider may
// hang for quite some time, see, for example:
// - https://github.com/open-telemetry/opentelemetry-rust/issues/868
// - and our problems with staging https://github.com/neondatabase/cloud/issues/3707#issuecomment-1493983636
//
// Yet, we want computes to shut down fast enough, as we may need a new one
// for the same timeline ASAP. So wait no longer than 2s for the shutdown to
// complete, then just error out and exit the main thread.
info!("shutting down tracing");
let (sender, receiver) = mpsc::channel();
let _ = thread::spawn(move || {
_ = p.shutdown();
sender.send(()).ok()
});
let shutdown_res = receiver.recv_timeout(Duration::from_millis(2000));
if shutdown_res.is_err() {
error!("timed out while shutting down tracing, exiting anyway");
}
}
info!("shutting down");

View File

@@ -0,0 +1,98 @@
//! Client for making request to a running Postgres server's communicator control socket.
//!
//! The storage communicator process that runs inside Postgres exposes an HTTP endpoint on
//! a Unix domain socket in the Postgres data directory. This module provides access to it.
use std::path::Path;
use anyhow::Context;
use hyper::client::conn::http1::SendRequest;
use hyper_util::rt::TokioIo;
/// Name of the socket within the Postgres data directory. This must match the name used in
/// `pgxn/neon/communicator/src/lib.rs`.
const NEON_COMMUNICATOR_SOCKET_NAME: &str = "neon-communicator.socket";
/// Open a connection to the communicator's control socket, prepare to send requests to it
/// with hyper.
pub async fn connect_communicator_socket<B>(pgdata: &Path) -> anyhow::Result<SendRequest<B>>
where
B: hyper::body::Body + 'static + Send,
B::Data: Send,
B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
{
let socket_path = pgdata.join(NEON_COMMUNICATOR_SOCKET_NAME);
let socket_path_len = socket_path.display().to_string().len();
// There is a limit of around 100 bytes (108 on Linux?) on the length of the path to a
// Unix Domain socket. The limit is on the connect(2) function used to open the
// socket, not on the absolute path itself. Postgres changes the current directory to
// the data directory and uses a relative path to bind to the socket, and the relative
// path "./neon-communicator.socket" is always short, but when compute_ctl needs to
// open the socket, we need to use a full path, which can be arbitrarily long.
//
// There are a few ways we could work around this:
//
// 1. Change the current directory to the Postgres data directory and use a relative
// path in the connect(2) call. That's problematic because the current directory
// applies to the whole process. We could change the current directory early in
// compute_ctl startup, and that might be a good idea anyway for other reasons too:
// it would be more robust if the data directory is moved around or unlinked for
// some reason, and you would be less likely to accidentally litter other parts of
// the filesystem with e.g. temporary files. However, that's a pretty invasive
// change.
//
// 2. On Linux, you could open() the data directory, and refer to the socket
// inside it as "/proc/self/fd/<fd>/neon-communicator.socket". But that's
// Linux-only.
//
// 3. Create a symbolic link to the socket with a shorter path, and use that.
//
// We use the symbolic link approach here. Hopefully the paths we use in production
// are short enough to open the socket directly, so this hack is only needed
// in development.
let connect_result = if socket_path_len < 100 {
// We can open the path directly with no hacks.
tokio::net::UnixStream::connect(socket_path).await
} else {
// The path to the socket is too long. Create a symlink to it with a shorter path.
let short_path = std::env::temp_dir().join(format!(
"compute_ctl.short-socket.{}.{}",
std::process::id(),
tokio::task::id()
));
std::os::unix::fs::symlink(&socket_path, &short_path)?;
// Delete the symlink as soon as we have connected to it. There's a small chance
// of leaking if the process dies before we remove it, so try to keep that window
// as small as possible.
scopeguard::defer! {
if let Err(err) = std::fs::remove_file(&short_path) {
tracing::warn!("could not remove symlink \"{}\" created for socket: {}",
short_path.display(), err);
}
}
tracing::info!(
"created symlink \"{}\" for socket \"{}\", opening it now",
short_path.display(),
socket_path.display()
);
tokio::net::UnixStream::connect(&short_path).await
};
let stream = connect_result.context("connecting to communicator control socket")?;
let io = TokioIo::new(stream);
let (request_sender, connection) = hyper::client::conn::http1::handshake(io).await?;
// spawn a task to poll the connection and drive the HTTP state
tokio::spawn(async move {
if let Err(err) = connection.await {
eprintln!("Error in connection: {err}");
}
});
Ok(request_sender)
}
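A minimal usage sketch (not part of the diff) of the helper above, mirroring what the compute_ctl metrics handler does; the request target, the function name, and the `Empty<Bytes>` body type are illustrative choices:

use std::path::Path;

use http_body_util::Empty;
use hyper::Request;
use hyper::body::Bytes;

// Hedged sketch: open the control socket under `pgdata` and issue a GET for `uri`.
async fn fetch_from_communicator(
    pgdata: &Path,
    uri: &str,
) -> anyhow::Result<hyper::Response<hyper::body::Incoming>> {
    let mut sender = connect_communicator_socket::<Empty<Bytes>>(pgdata).await?;
    let request = Request::builder()
        .method("GET")
        .uri(uri)
        .header("Host", "localhost") // hyper requires a Host header for HTTP/1.1
        .body(Empty::new())?;
    Ok(sender.send_request(request).await?)
}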

View File

@@ -114,6 +114,11 @@ pub struct ComputeNodeParams {
/// Interval for installed extensions collection
pub installed_extensions_collection_interval: Arc<AtomicU64>,
/// Timeout of PG compute startup in the Init state.
pub pg_init_timeout: Option<Duration>,
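/// Whether this compute runs in Lakebase (Databricks) mode.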
pub lakebase_mode: bool,
}
type TaskHandle = Mutex<Option<JoinHandle<()>>>;
@@ -155,6 +160,7 @@ pub struct RemoteExtensionMetrics {
#[derive(Clone, Debug)]
pub struct ComputeState {
pub start_time: DateTime<Utc>,
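/// Hadron: time when we last attempted to start Postgres; used to enforce `pg_init_timeout`.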
pub pg_start_time: Option<DateTime<Utc>>,
pub status: ComputeStatus,
/// Timestamp of the last Postgres activity. It could be `None` if
/// compute wasn't used since start.
@@ -192,6 +198,7 @@ impl ComputeState {
pub fn new() -> Self {
Self {
start_time: Utc::now(),
pg_start_time: None,
status: ComputeStatus::Empty,
last_active: None,
error: None,
@@ -737,6 +744,9 @@ impl ComputeNode {
};
_this_entered = start_compute_span.enter();
// Hadron: Record postgres start time (used to enforce pg_init_timeout).
state_guard.pg_start_time.replace(Utc::now());
state_guard.set_status(ComputeStatus::Init, &self.state_changed);
compute_state = state_guard.clone()
}
@@ -1544,7 +1554,7 @@ impl ComputeNode {
.with_context(|| format!("failed to get basebackup@{lsn}"))?;
// Update pg_hba.conf received with basebackup.
update_pg_hba(pgdata_path)?;
update_pg_hba(pgdata_path, None)?;
// Place pg_dynshmem under /dev/shm. This allows us to use
// 'dynamic_shared_memory_type = mmap' so that the files are placed in
@@ -1849,6 +1859,7 @@ impl ComputeNode {
}
// Run migrations separately to not hold up cold starts
let lakebase_mode = self.params.lakebase_mode;
let params = self.params.clone();
tokio::spawn(async move {
let mut conf = conf.as_ref().clone();
@@ -1861,7 +1872,7 @@ impl ComputeNode {
eprintln!("connection error: {e}");
}
});
if let Err(e) = handle_migrations(params, &mut client).await {
if let Err(e) = handle_migrations(params, &mut client, lakebase_mode).await {
error!("Failed to run migrations: {}", e);
}
}

View File

@@ -1,10 +1,18 @@
use std::path::Path;
use std::sync::Arc;
use anyhow::Context;
use axum::body::Body;
use axum::extract::State;
use axum::response::Response;
use http::StatusCode;
use http::header::CONTENT_TYPE;
use http_body_util::BodyExt;
use hyper::{Request, StatusCode};
use metrics::proto::MetricFamily;
use metrics::{Encoder, TextEncoder};
use crate::communicator_socket_client::connect_communicator_socket;
use crate::compute::ComputeNode;
use crate::http::JsonResponse;
use crate::metrics::collect;
@@ -31,3 +39,42 @@ pub(in crate::http) async fn get_metrics() -> Response {
.body(Body::from(buffer))
.unwrap()
}
/// Fetch and forward the metrics from the Postgres neon extension's metrics
/// exporter that are used by the autoscaling-agent.
///
/// The neon extension exposes these metrics over a Unix domain socket
/// in the data directory. That's not accessible directly from the outside
/// world, so we have this endpoint in compute_ctl to expose it.
pub(in crate::http) async fn get_autoscaling_metrics(
State(compute): State<Arc<ComputeNode>>,
) -> Result<Response, Response> {
let pgdata = Path::new(&compute.params.pgdata);
// Connect to the communicator process's metrics socket
let mut metrics_client = connect_communicator_socket(pgdata)
.await
.map_err(|e| JsonResponse::error(StatusCode::INTERNAL_SERVER_ERROR, format!("{e:#}")))?;
// Make a request for /autoscaling_metrics
let request = Request::builder()
.method("GET")
.uri("/autoscaling_metrics")
.header("Host", "localhost") // hyper requires Host, even though the server won't care
.body(Body::from(""))
.unwrap();
let resp = metrics_client
.send_request(request)
.await
.context("fetching metrics from Postgres metrics service")
.map_err(|e| JsonResponse::error(StatusCode::INTERNAL_SERVER_ERROR, format!("{e:#}")))?;
// Build a response that just forwards the response we got.
let mut response = Response::builder();
response = response.status(resp.status());
if let Some(content_type) = resp.headers().get(CONTENT_TYPE) {
response = response.header(CONTENT_TYPE, content_type);
}
let body = tonic::service::AxumBody::from_stream(resp.into_body().into_data_stream());
Ok(response.body(body).unwrap())
}

View File

@@ -81,8 +81,12 @@ impl From<&Server> for Router<Arc<ComputeNode>> {
Server::External {
config, compute_id, ..
} => {
let unauthenticated_router =
Router::<Arc<ComputeNode>>::new().route("/metrics", get(metrics::get_metrics));
let unauthenticated_router = Router::<Arc<ComputeNode>>::new()
.route("/metrics", get(metrics::get_metrics))
.route(
"/autoscaling_metrics",
get(metrics::get_autoscaling_metrics),
);
let authenticated_router = Router::<Arc<ComputeNode>>::new()
.route("/lfc/prewarm", get(lfc::prewarm_state).post(lfc::prewarm))

View File

@@ -4,6 +4,7 @@
#![deny(clippy::undocumented_unsafe_blocks)]
pub mod checker;
pub mod communicator_socket_client;
pub mod config;
pub mod configurator;
pub mod http;

View File

@@ -13,7 +13,9 @@ use tracing_subscriber::prelude::*;
/// set `OTEL_EXPORTER_OTLP_ENDPOINT=http://jaeger:4318`. See
/// `tracing-utils` package description.
///
pub async fn init_tracing_and_logging(default_log_level: &str) -> anyhow::Result<()> {
pub fn init_tracing_and_logging(
default_log_level: &str,
) -> anyhow::Result<Option<tracing_utils::Provider>> {
// Initialize Logging
let env_filter = tracing_subscriber::EnvFilter::try_from_default_env()
.unwrap_or_else(|_| tracing_subscriber::EnvFilter::new(default_log_level));
@@ -24,8 +26,9 @@ pub async fn init_tracing_and_logging(default_log_level: &str) -> anyhow::Result
.with_writer(std::io::stderr);
// Initialize OpenTelemetry
let otlp_layer =
tracing_utils::init_tracing("compute_ctl", tracing_utils::ExportConfig::default()).await;
let provider =
tracing_utils::init_tracing("compute_ctl", tracing_utils::ExportConfig::default());
let otlp_layer = provider.as_ref().map(tracing_utils::layer);
// Put it all together
tracing_subscriber::registry()
@@ -37,7 +40,7 @@ pub async fn init_tracing_and_logging(default_log_level: &str) -> anyhow::Result
utils::logging::replace_panic_hook_with_tracing_panic_hook().forget();
Ok(())
Ok(provider)
}
/// Replace all newline characters with a special character to make it

View File

@@ -9,15 +9,20 @@ use crate::metrics::DB_MIGRATION_FAILED;
pub(crate) struct MigrationRunner<'m> {
client: &'m mut Client,
migrations: &'m [&'m str],
lakebase_mode: bool,
}
impl<'m> MigrationRunner<'m> {
/// Create a new migration runner
pub fn new(client: &'m mut Client, migrations: &'m [&'m str]) -> Self {
pub fn new(client: &'m mut Client, migrations: &'m [&'m str], lakebase_mode: bool) -> Self {
// The neon_migration.migration_id::id column is a bigint, which is equivalent to an i64
assert!(migrations.len() + 1 < i64::MAX as usize);
Self { client, migrations }
Self {
client,
migrations,
lakebase_mode,
}
}
/// Get the current value neon_migration.migration_id
@@ -130,8 +135,13 @@ impl<'m> MigrationRunner<'m> {
// ID is also the next index
let migration_id = (current_migration + 1) as i64;
let migration = self.migrations[current_migration];
let migration = if self.lakebase_mode {
migration.replace("neon_superuser", "databricks_superuser")
} else {
migration.to_string()
};
match Self::run_migration(self.client, migration_id, migration).await {
match Self::run_migration(self.client, migration_id, &migration).await {
Ok(_) => {
info!("Finished migration id={}", migration_id);
}

View File

@@ -11,6 +11,7 @@ use tracing::{Level, error, info, instrument, span};
use crate::compute::ComputeNode;
use crate::metrics::{PG_CURR_DOWNTIME_MS, PG_TOTAL_DOWNTIME_MS};
const PG_DEFAULT_INIT_TIMEOUT: Duration = Duration::from_secs(60);
const MONITOR_CHECK_INTERVAL: Duration = Duration::from_millis(500);
/// Struct to store runtime state of the compute monitor thread.
@@ -352,13 +353,47 @@ impl ComputeMonitor {
// Hang on condition variable waiting until the compute status is `Running`.
fn wait_for_postgres_start(compute: &ComputeNode) {
let mut state = compute.state.lock().unwrap();
let pg_init_timeout = compute
.params
.pg_init_timeout
.unwrap_or(PG_DEFAULT_INIT_TIMEOUT);
while state.status != ComputeStatus::Running {
info!("compute is not running, waiting before monitoring activity");
state = compute.state_changed.wait(state).unwrap();
if !compute.params.lakebase_mode {
state = compute.state_changed.wait(state).unwrap();
if state.status == ComputeStatus::Running {
break;
if state.status == ComputeStatus::Running {
break;
}
continue;
}
if state.pg_start_time.is_some()
&& Utc::now()
.signed_duration_since(state.pg_start_time.unwrap())
.to_std()
.unwrap_or_default()
> pg_init_timeout
{
// If Postgres isn't up and running with working PS/SK connections within pg_init_timeout, it is
// possible that we started Postgres with a wrong spec (so it is talking to the wrong PS/SK nodes). To prevent
// dead ends we simply exit the compute node so it can restart with the latest spec.
//
// NB: We skip this check if we have not attempted to start PG yet (indicated by state.pg_start_time == None).
// This is to make sure that more appropriate errors are surfaced if we encounter issues before we even attempt
// to start PG (e.g., if we can't pull the spec, can't sync safekeepers, or can't get the basebackup).
error!(
"compute did not enter Running state in {} seconds, exiting",
pg_init_timeout.as_secs()
);
std::process::exit(1);
}
state = compute
.state_changed
.wait_timeout(state, Duration::from_secs(5))
.unwrap()
.0;
}
}

View File

@@ -11,7 +11,9 @@ use std::time::{Duration, Instant};
use anyhow::{Result, bail};
use compute_api::responses::TlsConfig;
use compute_api::spec::{Database, GenericOption, GenericOptions, PgIdent, Role};
use compute_api::spec::{
Database, DatabricksSettings, GenericOption, GenericOptions, PgIdent, Role,
};
use futures::StreamExt;
use indexmap::IndexMap;
use ini::Ini;
@@ -184,6 +186,42 @@ impl DatabaseExt for Database {
}
}
pub trait DatabricksSettingsExt {
fn as_pg_settings(&self) -> String;
}
impl DatabricksSettingsExt for DatabricksSettings {
fn as_pg_settings(&self) -> String {
// Postgres GUCs rendered from DatabricksSettings
vec![
// ssl_ca_file
Some(format!(
"ssl_ca_file = '{}'",
self.pg_compute_tls_settings.ca_file
)),
// [Optional] databricks.workspace_url
Some(format!(
"databricks.workspace_url = '{}'",
&self.databricks_workspace_host
)),
// todo(vikas.jain): these are not required anymore as they are moved to static
// conf but keeping these to avoid image mismatch between hcc and pg.
// Once hcc and pg are in sync, we can remove these.
//
// databricks.enable_databricks_identity_login
Some("databricks.enable_databricks_identity_login = true".to_string()),
// databricks.enable_sql_restrictions
Some("databricks.enable_sql_restrictions = true".to_string()),
]
.into_iter()
// Removes `None`s
.flatten()
.collect::<Vec<String>>()
.join("\n")
+ "\n"
}
}
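A hedged illustration (not part of the diff) of what `as_pg_settings()` above renders; the paths and hostname are made up:

#[cfg(test)]
mod databricks_settings_render_sketch {
    use compute_api::spec::{DatabricksSettings, PgComputeTlsSettings};

    use super::DatabricksSettingsExt;

    #[test]
    fn renders_expected_gucs() {
        // Hypothetical input values, chosen only for illustration.
        let settings = DatabricksSettings {
            pg_compute_tls_settings: PgComputeTlsSettings {
                cert_file: "/tls/server.crt".to_string(),
                key_file: "/tls/server.key".to_string(),
                ca_file: "/tls/ca.crt".to_string(),
            },
            databricks_pg_hba: "/conf/databricks_pg_hba.conf".to_string(),
            databricks_pg_ident: "/conf/databricks_pg_ident.conf".to_string(),
            databricks_workspace_host: "example.cloud.databricks.com".to_string(),
        };
        let rendered = settings.as_pg_settings();
        assert!(rendered.contains("ssl_ca_file = '/tls/ca.crt'"));
        assert!(rendered.contains("databricks.workspace_url = 'example.cloud.databricks.com'"));
        assert!(rendered.ends_with('\n'));
    }
}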
/// Generic trait used to provide quoting / encoding for strings used in the
/// Postgres SQL queries and DATABASE_URL.
pub trait Escaping {

View File

@@ -1,4 +1,6 @@
use std::fs::File;
use std::fs::{self, Permissions};
use std::os::unix::fs::PermissionsExt;
use std::path::Path;
use anyhow::{Result, anyhow, bail};
@@ -133,10 +135,25 @@ pub fn get_config_from_control_plane(base_uri: &str, compute_id: &str) -> Result
}
/// Check `pg_hba.conf` and update if needed to allow external connections.
pub fn update_pg_hba(pgdata_path: &Path) -> Result<()> {
pub fn update_pg_hba(pgdata_path: &Path, databricks_pg_hba: Option<&String>) -> Result<()> {
// XXX: consider making it a part of config.json
let pghba_path = pgdata_path.join("pg_hba.conf");
// Update pg_hba to contain databricks-specific settings before adding the neon settings.
// PG uses the first record that matches to perform authentication, so we need to have
// our rules before the default ones from neon.
// See https://www.postgresql.org/docs/16/auth-pg-hba-conf.html
if let Some(databricks_pg_hba) = databricks_pg_hba {
if config::line_in_file(
&pghba_path,
&format!("include_if_exists {}\n", *databricks_pg_hba),
)? {
info!("updated pg_hba.conf to include databricks_pg_hba.conf");
} else {
info!("pg_hba.conf already included databricks_pg_hba.conf");
}
}
if config::line_in_file(&pghba_path, PG_HBA_ALL_MD5)? {
info!("updated pg_hba.conf to allow external connections");
} else {
@@ -146,6 +163,59 @@ pub fn update_pg_hba(pgdata_path: &Path) -> Result<()> {
Ok(())
}
/// Check `pg_ident.conf` and update if needed to allow databricks config.
pub fn update_pg_ident(pgdata_path: &Path, databricks_pg_ident: Option<&String>) -> Result<()> {
info!("checking pg_ident.conf");
let pgident_path = pgdata_path.join("pg_ident.conf");
// Update pg_ident to contain databricks-specific settings
if let Some(databricks_pg_ident) = databricks_pg_ident {
if config::line_in_file(
&pgident_path,
&format!("include_if_exists {}\n", *databricks_pg_ident),
)? {
info!("updated pg_ident.conf to include databricks_pg_ident.conf");
} else {
info!("pg_ident.conf already included databricks_pg_ident.conf");
}
}
Ok(())
}
/// Copy the TLS key_file and cert_file from the k8s secret mount directory
/// to pgdata, and set the private key file permissions as expected by Postgres.
/// See this doc for the expected permissions: <https://www.postgresql.org/docs/current/ssl-tcp.html>
/// The k8s secret mount on the dblet does not honor the permissions and ownership
/// specified in the Volume or VolumeMount, so we need to explicitly copy the files and set the permissions.
pub fn copy_tls_certificates(
key_file: &String,
cert_file: &String,
pgdata_path: &Path,
) -> Result<()> {
let files = [cert_file, key_file];
for file in files.iter() {
let source = Path::new(file);
let dest = pgdata_path.join(source.file_name().unwrap());
if !dest.exists() {
std::fs::copy(source, &dest)?;
info!(
"Copying tls file: {} to {}",
&source.display(),
&dest.display()
);
}
if *file == key_file {
// Postgres requires private key to be readable only by the owner by having
// chmod 600 permissions.
let permissions = Permissions::from_mode(0o600);
fs::set_permissions(&dest, permissions)?;
info!("Setting permission on {}.", &dest.display());
}
}
Ok(())
}
/// Create a standby.signal file
pub fn add_standby_signal(pgdata_path: &Path) -> Result<()> {
// XXX: consider making it a part of config.json
@@ -170,7 +240,11 @@ pub async fn handle_neon_extension_upgrade(client: &mut Client) -> Result<()> {
}
#[instrument(skip_all)]
pub async fn handle_migrations(params: ComputeNodeParams, client: &mut Client) -> Result<()> {
pub async fn handle_migrations(
params: ComputeNodeParams,
client: &mut Client,
lakebase_mode: bool,
) -> Result<()> {
info!("handle migrations");
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
@@ -234,7 +308,7 @@ pub async fn handle_migrations(params: ComputeNodeParams, client: &mut Client) -
),
];
MigrationRunner::new(client, &migrations)
MigrationRunner::new(client, &migrations, lakebase_mode)
.run_migrations()
.await?;

View File

@@ -411,7 +411,8 @@ impl ComputeNode {
.map(|limit| match limit {
0..10 => limit,
10..30 => 10,
30.. => limit / 3,
30..300 => limit / 3,
300.. => 100,
})
// If we didn't find max_connections, default to 10 concurrent connections.
.unwrap_or(10)
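A hedged standalone sketch of the new clamping above (the function name and u32 type are assumptions; the real logic is the closure in the map): values below 10 pass through, 10-29 clamp to 10, 30-299 use a third of the limit, and 300 or more now cap at 100 instead of growing without bound.

// Exclusive range patterns as used above (Rust 1.80+).
fn concurrent_conn_limit(max_connections: u32) -> u32 {
    match max_connections {
        0..10 => max_connections,
        10..30 => 10,
        30..300 => max_connections / 3,
        300.. => 100,
    }
}

// concurrent_conn_limit(8) == 8
// concurrent_conn_limit(25) == 10
// concurrent_conn_limit(90) == 30
// concurrent_conn_limit(900) == 100 (previously 900 / 3 == 300)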

View File

@@ -411,6 +411,12 @@ struct StorageControllerStartCmdArgs {
help = "Base port for the storage controller instance idenfified by instance-id (defaults to pageserver cplane api)"
)]
base_port: Option<u16>,
#[clap(
long,
help = "Whether the storage controller should handle pageserver-reported local disk loss events."
)]
handle_ps_local_disk_loss: Option<bool>,
}
#[derive(clap::Args)]
@@ -1800,6 +1806,7 @@ async fn handle_storage_controller(
instance_id: args.instance_id,
base_port: args.base_port,
start_timeout: args.start_timeout,
handle_ps_local_disk_loss: args.handle_ps_local_disk_loss,
};
if let Err(e) = svc.start(start_args).await {

View File

@@ -728,12 +728,9 @@ impl Endpoint {
// For the sake of backwards-compatibility, also fill in 'pageserver_connstring'
//
// XXX: I believe this is not really needed, except to make
// test_forward_compatibility happy.
//
// Use a closure so that we can conveniently return None in the middle of the
// loop.
let pageserver_connstring = (|| {
let pageserver_connstring: Option<String> = (|| {
let num_shards = if args.pageserver_conninfo.shard_count.is_unsharded() {
1
} else {
@@ -749,22 +746,24 @@ impl Endpoint {
.pageserver_conninfo
.shards
.get(&shard_index)
.expect(&format!(
"shard {} not found in pageserver_connection_info",
shard_index
));
.ok_or_else(|| {
anyhow!(
"shard {} not found in pageserver_connection_info",
shard_index
)
})?;
let pageserver = shard
.pageservers
.first()
.expect("must have at least one pageserver");
.ok_or(anyhow!("must have at least one pageserver"))?;
if let Some(libpq_url) = &pageserver.libpq_url {
connstrings.push(libpq_url.clone());
} else {
return None;
return Ok::<_, anyhow::Error>(None);
}
}
Some(connstrings.join(","))
})();
Ok(Some(connstrings.join(",")))
})()?;
// Create config file
let config = {

View File

@@ -56,6 +56,7 @@ pub struct NeonStorageControllerStartArgs {
pub instance_id: u8,
pub base_port: Option<u16>,
pub start_timeout: humantime::Duration,
pub handle_ps_local_disk_loss: Option<bool>,
}
impl NeonStorageControllerStartArgs {
@@ -64,6 +65,7 @@ impl NeonStorageControllerStartArgs {
instance_id: 1,
base_port: None,
start_timeout,
handle_ps_local_disk_loss: None,
}
}
}
@@ -669,6 +671,10 @@ impl StorageController {
println!("Starting storage controller at {scheme}://{host}:{listen_port}");
if start_args.handle_ps_local_disk_loss.unwrap_or_default() {
args.push("--handle-ps-local-disk-loss".to_string());
}
background_process::start_process(
COMMAND,
&instance_dir,

View File

@@ -35,6 +35,7 @@ reason = "The paste crate is a build-only dependency with no runtime components.
# More documentation for the licenses section can be found here:
# https://embarkstudios.github.io/cargo-deny/checks/licenses/cfg.html
[licenses]
version = 2
allow = [
"0BSD",
"Apache-2.0",

View File

@@ -233,7 +233,7 @@ mod tests {
.unwrap()
.as_millis();
use rand::Rng;
let random = rand::thread_rng().r#gen::<u32>();
let random = rand::rng().random::<u32>();
let s3_config = remote_storage::S3Config {
bucket_name: var(REAL_S3_BUCKET).unwrap(),

View File

@@ -460,6 +460,32 @@ pub struct GenericOption {
pub vartype: String,
}
/// Postgres compute TLS settings.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
pub struct PgComputeTlsSettings {
// Absolute path to the certificate file for server-side TLS.
pub cert_file: String,
// Absolute path to the private key file for server-side TLS.
pub key_file: String,
// Absolute path to the certificate authority file for verifying client certificates.
pub ca_file: String,
}
/// Databricks-specific options for a compute instance.
/// This is used to store any other settings that need to be propagated to the compute
/// but should not be persisted to the ComputeSpec in the database.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
pub struct DatabricksSettings {
pub pg_compute_tls_settings: PgComputeTlsSettings,
// Absolute file path to databricks_pg_hba.conf file.
pub databricks_pg_hba: String,
// Absolute file path to databricks_pg_ident.conf file.
pub databricks_pg_ident: String,
// Hostname portion of the Databricks workspace URL of the endpoint, or empty string if not known.
// A valid hostname is required for the compute instance to support PAT logins.
pub databricks_workspace_host: String,
}
/// Optional collection of `GenericOption`'s. Type alias allows us to
/// declare a `trait` on it.
pub type GenericOptions = Option<Vec<GenericOption>>;

View File

@@ -90,7 +90,7 @@ impl<'a> IdempotencyKey<'a> {
IdempotencyKey {
now: Utc::now(),
node_id,
nonce: rand::thread_rng().gen_range(0..=9999),
nonce: rand::rng().random_range(0..=9999),
}
}

View File

@@ -41,7 +41,7 @@ impl NodeOs {
/// Generate a random number in range [0, max).
pub fn random(&self, max: u64) -> u64 {
self.internal.rng.lock().gen_range(0..max)
self.internal.rng.lock().random_range(0..max)
}
/// Append a new event to the world event log.

View File

@@ -32,10 +32,10 @@ impl Delay {
/// Generate a random delay in range [min, max]. Return None if the
/// message should be dropped.
pub fn delay(&self, rng: &mut StdRng) -> Option<u64> {
if rng.gen_bool(self.fail_prob) {
if rng.random_bool(self.fail_prob) {
return None;
}
Some(rng.gen_range(self.min..=self.max))
Some(rng.random_range(self.min..=self.max))
}
}
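The rand hunks in this commit all follow the rand 0.8 to 0.9 rename: `thread_rng()` becomes `rng()`, the `gen*` methods become `random*`, `StdRng::from_rng` is now infallible, and `rand_distr` 0.5's `Zipf::new` takes an `f64` element count. A minimal sketch of the new names:

use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};

fn rand_09_sketch() {
    let mut rng = rand::rng(); // was rand::thread_rng()
    let x: u32 = rng.random(); // was rng.gen::<u32>()
    let n = rng.random_range(0..100); // was rng.gen_range(0..100)
    let coin = rng.random_bool(0.5); // was rng.gen_bool(0.5)

    // from_rng no longer returns a Result in rand 0.9; try_from_rng does.
    let _seeded = StdRng::from_rng(&mut rng);
    let _ = (x, n, coin);
}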

View File

@@ -69,7 +69,7 @@ impl World {
/// Create a new random number generator.
pub fn new_rng(&self) -> StdRng {
let mut rng = self.rng.lock();
StdRng::from_rng(rng.deref_mut()).unwrap()
StdRng::from_rng(rng.deref_mut())
}
/// Create a new node.

View File

@@ -17,5 +17,5 @@ procfs.workspace = true
measured-process.workspace = true
[dev-dependencies]
rand = "0.8"
rand_distr = "0.4.3"
rand.workspace = true
rand_distr = "0.5"

View File

@@ -260,7 +260,7 @@ mod tests {
#[test]
fn test_cardinality_small() {
let (actual, estimate) = test_cardinality(100, Zipf::new(100, 1.2f64).unwrap());
let (actual, estimate) = test_cardinality(100, Zipf::new(100.0, 1.2f64).unwrap());
assert_eq!(actual, [46, 30, 32]);
assert!(51.3 < estimate[0] && estimate[0] < 51.4);
@@ -270,7 +270,7 @@ mod tests {
#[test]
fn test_cardinality_medium() {
let (actual, estimate) = test_cardinality(10000, Zipf::new(10000, 1.2f64).unwrap());
let (actual, estimate) = test_cardinality(10000, Zipf::new(10000.0, 1.2f64).unwrap());
assert_eq!(actual, [2529, 1618, 1629]);
assert!(2309.1 < estimate[0] && estimate[0] < 2309.2);
@@ -280,7 +280,8 @@ mod tests {
#[test]
fn test_cardinality_large() {
let (actual, estimate) = test_cardinality(1_000_000, Zipf::new(1_000_000, 1.2f64).unwrap());
let (actual, estimate) =
test_cardinality(1_000_000, Zipf::new(1_000_000.0, 1.2f64).unwrap());
assert_eq!(actual, [129077, 79579, 79630]);
assert!(126067.2 < estimate[0] && estimate[0] < 126067.3);
@@ -290,7 +291,7 @@ mod tests {
#[test]
fn test_cardinality_small2() {
let (actual, estimate) = test_cardinality(100, Zipf::new(200, 0.8f64).unwrap());
let (actual, estimate) = test_cardinality(100, Zipf::new(200.0, 0.8f64).unwrap());
assert_eq!(actual, [92, 58, 60]);
assert!(116.1 < estimate[0] && estimate[0] < 116.2);
@@ -300,7 +301,7 @@ mod tests {
#[test]
fn test_cardinality_medium2() {
let (actual, estimate) = test_cardinality(10000, Zipf::new(20000, 0.8f64).unwrap());
let (actual, estimate) = test_cardinality(10000, Zipf::new(20000.0, 0.8f64).unwrap());
assert_eq!(actual, [8201, 5131, 5051]);
assert!(6846.4 < estimate[0] && estimate[0] < 6846.5);
@@ -310,7 +311,8 @@ mod tests {
#[test]
fn test_cardinality_large2() {
let (actual, estimate) = test_cardinality(1_000_000, Zipf::new(2_000_000, 0.8f64).unwrap());
let (actual, estimate) =
test_cardinality(1_000_000, Zipf::new(2_000_000.0, 0.8f64).unwrap());
assert_eq!(actual, [777847, 482069, 482246]);
assert!(699437.4 < estimate[0] && estimate[0] < 699437.5);

View File

@@ -188,14 +188,14 @@ fn real_benchs(c: &mut Criterion) {
let key: FileCacheKey = rng.random();
let val = FileCacheEntry::dummy();
let entry = writer.entry(key);
std::hint::black_box(match entry {
match entry {
Entry::Occupied(mut e) => {
e.insert(val);
std::hint::black_box(e.insert(val));
}
Entry::Vacant(e) => {
_ = e.insert(val).unwrap();
let _ = std::hint::black_box(e.insert(val).unwrap());
}
})
}
}
},
BatchSize::SmallInput,
@@ -220,12 +220,12 @@ fn real_benchs(c: &mut Criterion) {
let ideal_filled = 100_000_000;
let mut writer = hashbrown::raw::RawTable::new();
let mut rng = rand::rng();
let hasher = rustc_hash::FxBuildHasher::default();
let hasher = rustc_hash::FxBuildHasher;
unsafe {
writer
.resize(
size,
|(k, _)| hasher.hash_one(&k),
|(k, _)| hasher.hash_one(k),
hashbrown::raw::Fallibility::Infallible,
)
.unwrap();
@@ -234,7 +234,7 @@ fn real_benchs(c: &mut Criterion) {
let key: FileCacheKey = rng.random();
let val = FileCacheEntry::dummy();
writer.insert(hasher.hash_one(&key), (key, val), |(k, _)| {
hasher.hash_one(&k)
hasher.hash_one(k)
});
}
b.iter(|| unsafe {
@@ -282,12 +282,12 @@ fn real_benchs(c: &mut Criterion) {
let size = 125_000_000;
let mut writer = hashbrown::raw::RawTable::new();
let mut rng = rand::rng();
let hasher = rustc_hash::FxBuildHasher::default();
let hasher = rustc_hash::FxBuildHasher;
unsafe {
writer
.resize(
size,
|(k, _)| hasher.hash_one(&k),
|(k, _)| hasher.hash_one(k),
hashbrown::raw::Fallibility::Infallible,
)
.unwrap();
@@ -296,7 +296,7 @@ fn real_benchs(c: &mut Criterion) {
let key: FileCacheKey = rng.random();
let val = FileCacheEntry::dummy();
writer.insert(hasher.hash_one(&key), (key, val), |(k, _)| {
hasher.hash_one(&k)
hasher.hash_one(k)
});
}
b.iter(|| unsafe {

View File

@@ -61,6 +61,10 @@ impl<K: 'static, V: 'static> OccupiedEntry<'_, K, V> {
///
/// This may result in multiple bucket accesses if the entry was obtained by index as the
/// previous chain entry needs to be discovered in this case.
///
/// # Panics
/// Panics if the `prev_pos` field is equal to [`PrevPos::Unknown`]. In practice, this means
/// the entry was obtained via calling something like [`super::HashMapAccess::entry_at_bucket`].
pub fn remove(mut self) -> V {
// If this bucket was queried by index, go ahead and follow its chain from the start.
let prev = if let PrevPos::Unknown(hash) = self.prev_pos {

View File

@@ -1,14 +0,0 @@
[package]
name = "neonart"
version = "0.1.0"
edition.workspace = true
license.workspace = true
[dependencies]
crossbeam-utils.workspace = true
spin.workspace = true
tracing.workspace = true
[dev-dependencies]
rand = "0.9.1"
rand_distr = "0.5.1"

View File

@@ -1,599 +0,0 @@
mod lock_and_version;
pub(crate) mod node_ptr;
mod node_ref;
use std::vec::Vec;
use crate::algorithm::lock_and_version::ConcurrentUpdateError;
use crate::algorithm::node_ptr::MAX_PREFIX_LEN;
use crate::algorithm::node_ref::{NewNodeRef, NodeRef, ReadLockedNodeRef, WriteLockedNodeRef};
use crate::allocator::OutOfMemoryError;
use crate::TreeWriteGuard;
use crate::UpdateAction;
use crate::allocator::ArtAllocator;
use crate::epoch::EpochPin;
use crate::{Key, Value};
pub(crate) type RootPtr<V> = node_ptr::NodePtr<V>;
#[derive(Debug)]
pub enum ArtError {
ConcurrentUpdate, // need to retry
OutOfMemory,
}
impl From<ConcurrentUpdateError> for ArtError {
fn from(_: ConcurrentUpdateError) -> ArtError {
ArtError::ConcurrentUpdate
}
}
impl From<OutOfMemoryError> for ArtError {
fn from(_: OutOfMemoryError) -> ArtError {
ArtError::OutOfMemory
}
}
pub fn new_root<V: Value>(
allocator: &impl ArtAllocator<V>,
) -> Result<RootPtr<V>, OutOfMemoryError> {
node_ptr::new_root(allocator)
}
pub(crate) fn search<'e, K: Key, V: Value>(
key: &K,
root: RootPtr<V>,
epoch_pin: &'e EpochPin,
) -> Option<&'e V> {
loop {
let root_ref = NodeRef::from_root_ptr(root);
if let Ok(result) = lookup_recurse(key.as_bytes(), root_ref, None, epoch_pin) {
break result;
}
// retry
}
}
pub(crate) fn iter_next<'e, V: Value>(
key: &[u8],
root: RootPtr<V>,
epoch_pin: &'e EpochPin,
) -> Option<(Vec<u8>, &'e V)> {
loop {
let mut path = Vec::new();
let root_ref = NodeRef::from_root_ptr(root);
match next_recurse(key, &mut path, root_ref, epoch_pin) {
Ok(Some(v)) => {
assert_eq!(path.len(), key.len());
break Some((path, v));
}
Ok(None) => break None,
Err(ConcurrentUpdateError()) => {
// retry
continue;
}
}
}
}
pub(crate) fn update_fn<'e, 'g, K: Key, V: Value, A: ArtAllocator<V>, F>(
key: &K,
value_fn: F,
root: RootPtr<V>,
guard: &'g mut TreeWriteGuard<'e, K, V, A>,
) -> Result<(), OutOfMemoryError>
where
F: FnOnce(Option<&V>) -> UpdateAction<V>,
{
let value_fn_cell = std::cell::Cell::new(Some(value_fn));
loop {
let root_ref = NodeRef::from_root_ptr(root);
let this_value_fn = |arg: Option<&V>| value_fn_cell.take().unwrap()(arg);
let key_bytes = key.as_bytes();
match update_recurse(
key_bytes,
this_value_fn,
root_ref,
None,
None,
guard,
0,
key_bytes,
) {
Ok(()) => break Ok(()),
Err(ArtError::ConcurrentUpdate) => {
continue; // retry
}
Err(ArtError::OutOfMemory) => break Err(OutOfMemoryError()),
}
}
}
// Error means you must retry.
//
// This corresponds to the 'lookupOpt' function in the paper
#[allow(clippy::only_used_in_recursion)]
fn lookup_recurse<'e, V: Value>(
key: &[u8],
node: NodeRef<'e, V>,
parent: Option<ReadLockedNodeRef<V>>,
epoch_pin: &'e EpochPin,
) -> Result<Option<&'e V>, ConcurrentUpdateError> {
let rnode = node.read_lock_or_restart()?;
if let Some(parent) = parent {
parent.read_unlock_or_restart()?;
}
// check if the prefix matches, may increment level
let prefix_len = if let Some(prefix_len) = rnode.prefix_matches(key) {
prefix_len
} else {
rnode.read_unlock_or_restart()?;
return Ok(None);
};
if rnode.is_leaf() {
assert_eq!(key.len(), prefix_len);
let vptr = rnode.get_leaf_value_ptr()?;
// safety: It's OK to return a ref of the pointer because we checked the version
// and the lifetime of 'epoch_pin' enforces that the reference is only accessible
// as long as the epoch is pinned.
let v = unsafe { vptr.as_ref().unwrap() };
return Ok(Some(v));
}
let key = &key[prefix_len..];
// find child (or leaf value)
let next_node = rnode.find_child_or_restart(key[0])?;
match next_node {
None => Ok(None), // key not found
Some(child) => lookup_recurse(&key[1..], child, Some(rnode), epoch_pin),
}
}
#[allow(clippy::only_used_in_recursion)]
fn next_recurse<'e, V: Value>(
min_key: &[u8],
path: &mut Vec<u8>,
node: NodeRef<'e, V>,
epoch_pin: &'e EpochPin,
) -> Result<Option<&'e V>, ConcurrentUpdateError> {
let rnode = node.read_lock_or_restart()?;
let prefix = rnode.get_prefix();
if !prefix.is_empty() {
path.extend_from_slice(prefix);
}
use std::cmp::Ordering;
let comparison = path.as_slice().cmp(&min_key[0..path.len()]);
if comparison == Ordering::Less {
rnode.read_unlock_or_restart()?;
return Ok(None);
}
if rnode.is_leaf() {
assert_eq!(path.len(), min_key.len());
let vptr = rnode.get_leaf_value_ptr()?;
// safety: It's OK to return a ref of the pointer because we checked the version
// and the lifetime of 'epoch_pin' enforces that the reference is only accessible
// as long as the epoch is pinned.
let v = unsafe { vptr.as_ref().unwrap() };
return Ok(Some(v));
}
let mut min_key_byte = match comparison {
Ordering::Less => unreachable!(), // checked this above already
Ordering::Equal => min_key[path.len()],
Ordering::Greater => 0,
};
loop {
match rnode.find_next_child_or_restart(min_key_byte)? {
None => {
return Ok(None);
}
Some((key_byte, child_ref)) => {
let path_len = path.len();
path.push(key_byte);
let result = next_recurse(min_key, path, child_ref, epoch_pin)?;
if result.is_some() {
return Ok(result);
}
if key_byte == u8::MAX {
return Ok(None);
}
path.truncate(path_len);
min_key_byte = key_byte + 1;
}
}
}
}
// This corresponds to the 'insertOpt' function in the paper
#[allow(clippy::only_used_in_recursion)]
#[allow(clippy::too_many_arguments)]
pub(crate) fn update_recurse<'e, K: Key, V: Value, A: ArtAllocator<V>, F>(
key: &[u8],
value_fn: F,
node: NodeRef<'e, V>,
rparent: Option<(ReadLockedNodeRef<V>, u8)>,
rgrandparent: Option<(ReadLockedNodeRef<V>, u8)>,
guard: &'_ mut TreeWriteGuard<'e, K, V, A>,
level: usize,
orig_key: &[u8],
) -> Result<(), ArtError>
where
F: FnOnce(Option<&V>) -> UpdateAction<V>,
{
let rnode = node.read_lock_or_restart()?;
let prefix_match_len = rnode.prefix_matches(key);
if prefix_match_len.is_none() {
let (rparent, parent_key) = rparent.expect("direct children of the root have no prefix");
let mut wparent = rparent.upgrade_to_write_lock_or_restart()?;
let mut wnode = rnode.upgrade_to_write_lock_or_restart()?;
match value_fn(None) {
UpdateAction::Nothing => {}
UpdateAction::Insert(new_value) => {
insert_split_prefix(key, new_value, &mut wnode, &mut wparent, parent_key, guard)?;
}
UpdateAction::Remove => {
panic!("unexpected Remove action on insertion");
}
}
wnode.write_unlock();
wparent.write_unlock();
return Ok(());
}
let prefix_match_len = prefix_match_len.unwrap();
let key = &key[prefix_match_len..];
let level = level + prefix_match_len;
if rnode.is_leaf() {
assert_eq!(key.len(), 0);
let (rparent, parent_key) = rparent.expect("root cannot be leaf");
let mut wparent = rparent.upgrade_to_write_lock_or_restart()?;
let mut wnode = rnode.upgrade_to_write_lock_or_restart()?;
// safety: Now that we have acquired the write lock, we have exclusive access to the
// value. XXX: There might be concurrent reads though?
let value_mut = wnode.get_leaf_value_mut();
match value_fn(Some(value_mut)) {
UpdateAction::Nothing => {
wparent.write_unlock();
wnode.write_unlock();
}
UpdateAction::Insert(_) => panic!("cannot insert over existing value"),
UpdateAction::Remove => {
guard.remember_obsolete_node(wnode.as_ptr());
wparent.delete_child(parent_key);
wnode.write_unlock_obsolete();
if let Some(rgrandparent) = rgrandparent {
// FIXME: Ignore concurrency error. It doesn't lead to
// corruption, but it means we might leak something. Until
// another update cleans it up.
let _ = cleanup_parent(wparent, rgrandparent, guard);
}
}
}
return Ok(());
}
let next_node = rnode.find_child_or_restart(key[0])?;
if next_node.is_none() {
if rnode.is_full() {
let (rparent, parent_key) = rparent.expect("root node cannot become full");
let mut wparent = rparent.upgrade_to_write_lock_or_restart()?;
let wnode = rnode.upgrade_to_write_lock_or_restart()?;
match value_fn(None) {
UpdateAction::Nothing => {
wnode.write_unlock();
wparent.write_unlock();
}
UpdateAction::Insert(new_value) => {
insert_and_grow(key, new_value, wnode, &mut wparent, parent_key, guard)?;
wparent.write_unlock();
}
UpdateAction::Remove => {
panic!("unexpected Remove action on insertion");
}
};
} else {
let mut wnode = rnode.upgrade_to_write_lock_or_restart()?;
if let Some((rparent, _)) = rparent {
rparent.read_unlock_or_restart()?;
}
match value_fn(None) {
UpdateAction::Nothing => {}
UpdateAction::Insert(new_value) => {
insert_to_node(&mut wnode, key, new_value, guard)?;
}
UpdateAction::Remove => {
panic!("unexpected Remove action on insertion");
}
};
wnode.write_unlock();
}
Ok(())
} else {
let next_child = next_node.unwrap(); // checked above it's not None
if let Some((ref rparent, _)) = rparent {
rparent.check_or_restart()?;
}
// recurse to next level
update_recurse(
&key[1..],
value_fn,
next_child,
Some((rnode, key[0])),
rparent,
guard,
level + 1,
orig_key,
)
}
}
#[derive(Clone)]
enum PathElement {
Prefix(Vec<u8>),
KeyByte(u8),
}
impl std::fmt::Debug for PathElement {
fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
match self {
PathElement::Prefix(prefix) => write!(fmt, "{prefix:?}"),
PathElement::KeyByte(key_byte) => write!(fmt, "{key_byte}"),
}
}
}
pub(crate) fn dump_tree<V: Value + std::fmt::Debug>(
root: RootPtr<V>,
epoch_pin: &'_ EpochPin,
dst: &mut dyn std::io::Write,
) {
let root_ref = NodeRef::from_root_ptr(root);
let _ = dump_recurse(&[], root_ref, epoch_pin, 0, dst);
}
// TODO: return an Err if writeln!() returns error, instead of unwrapping
#[allow(clippy::only_used_in_recursion)]
fn dump_recurse<'e, V: Value + std::fmt::Debug>(
path: &[PathElement],
node: NodeRef<'e, V>,
epoch_pin: &'e EpochPin,
level: usize,
dst: &mut dyn std::io::Write,
) -> Result<(), ConcurrentUpdateError> {
let indent = str::repeat(" ", level);
let rnode = node.read_lock_or_restart()?;
let mut path = Vec::from(path);
let prefix = rnode.get_prefix();
if !prefix.is_empty() {
path.push(PathElement::Prefix(Vec::from(prefix)));
}
if rnode.is_leaf() {
let vptr = rnode.get_leaf_value_ptr()?;
// safety: It's OK to return a ref of the pointer because we checked the version
// and the lifetime of 'epoch_pin' enforces that the reference is only accessible
// as long as the epoch is pinned.
let val = unsafe { vptr.as_ref().unwrap() };
writeln!(dst, "{indent} {path:?}: {val:?}").unwrap();
return Ok(());
}
for key_byte in 0..=u8::MAX {
match rnode.find_child_or_restart(key_byte)? {
None => continue,
Some(child_ref) => {
let rchild = child_ref.read_lock_or_restart()?;
writeln!(
dst,
"{} {:?}, {}: prefix {:?}",
indent,
&path,
key_byte,
rchild.get_prefix()
)
.unwrap();
let mut child_path = path.clone();
child_path.push(PathElement::KeyByte(key_byte));
dump_recurse(&child_path, child_ref, epoch_pin, level + 1, dst)?;
}
}
}
Ok(())
}
///```text
/// [fooba]r -> value
///
/// [foo]b -> [a]r -> value
/// e -> [ls]e -> value
///```
fn insert_split_prefix<K: Key, V: Value, A: ArtAllocator<V>>(
key: &[u8],
value: V,
node: &mut WriteLockedNodeRef<V>,
parent: &mut WriteLockedNodeRef<V>,
parent_key: u8,
guard: &'_ TreeWriteGuard<K, V, A>,
) -> Result<(), OutOfMemoryError> {
let old_node = node;
let old_prefix = old_node.get_prefix();
let common_prefix_len = common_prefix(key, old_prefix);
// Allocate a node for the new value.
let new_value_node = allocate_node_for_value(
&key[common_prefix_len + 1..],
value,
guard.tree_writer.allocator,
)?;
// Allocate a new internal node with the common prefix
// FIXME: deallocate 'new_value_node' on OOM
let mut prefix_node =
node_ref::new_internal(&key[..common_prefix_len], guard.tree_writer.allocator)?;
// Add the old node and the new nodes to the new internal node
prefix_node.insert_old_child(old_prefix[common_prefix_len], old_node);
prefix_node.insert_new_child(key[common_prefix_len], new_value_node);
// Modify the prefix of the old child in place
old_node.truncate_prefix(old_prefix.len() - common_prefix_len - 1);
// replace the pointer in the parent
parent.replace_child(parent_key, prefix_node.into_ptr());
Ok(())
}
fn insert_to_node<K: Key, V: Value, A: ArtAllocator<V>>(
wnode: &mut WriteLockedNodeRef<V>,
key: &[u8],
value: V,
guard: &'_ TreeWriteGuard<K, V, A>,
) -> Result<(), OutOfMemoryError> {
let value_child = allocate_node_for_value(&key[1..], value, guard.tree_writer.allocator)?;
wnode.insert_child(key[0], value_child.into_ptr());
Ok(())
}
// On entry: 'parent' and 'node' are locked
fn insert_and_grow<'e, 'g, K: Key, V: Value, A: ArtAllocator<V>>(
key: &[u8],
value: V,
wnode: WriteLockedNodeRef<V>,
parent: &mut WriteLockedNodeRef<V>,
parent_key_byte: u8,
guard: &'g mut TreeWriteGuard<'e, K, V, A>,
) -> Result<(), ArtError> {
let mut bigger_node = wnode.grow(guard.tree_writer.allocator)?;
// FIXME: deallocate 'bigger_node' on OOM
let value_child = allocate_node_for_value(&key[1..], value, guard.tree_writer.allocator)?;
bigger_node.insert_new_child(key[0], value_child);
// Replace the pointer in the parent
parent.replace_child(parent_key_byte, bigger_node.into_ptr());
guard.remember_obsolete_node(wnode.as_ptr());
wnode.write_unlock_obsolete();
Ok(())
}
fn cleanup_parent<'e, 'g, K: Key, V: Value, A: ArtAllocator<V>>(
wparent: WriteLockedNodeRef<V>,
rgrandparent: (ReadLockedNodeRef<V>, u8),
guard: &'g mut TreeWriteGuard<'e, K, V, A>,
) -> Result<(), ArtError> {
let (rgrandparent, grandparent_key_byte) = rgrandparent;
// If the parent becomes completely empty after the deletion, remove the parent from the
// grandparent. (This case is possible because we reserve only 8 bytes for the prefix.)
// TODO: not implemented.
// If the parent has only one child, replace the parent with the remaining child. (This is not
// possible if the child's prefix field cannot absorb the parent's)
if wparent.num_children() == 1 {
// Try to lock the remaining child. This can fail if the child is updated
// concurrently.
let (key_byte, remaining_child) = wparent.find_remaining_child();
let mut wremaining_child = remaining_child.write_lock_or_restart()?;
if 1 + wremaining_child.get_prefix().len() + wparent.get_prefix().len() <= MAX_PREFIX_LEN {
let mut wgrandparent = rgrandparent.upgrade_to_write_lock_or_restart()?;
// Ok, we have locked the leaf, the parent, the grandparent, and the parent's only
// remaining leaf. Proceed with the updates.
// Update the prefix on the remaining leaf
wremaining_child.prepend_prefix(wparent.get_prefix(), key_byte);
// Replace the pointer in the grandparent to point directly to the remaining leaf
wgrandparent.replace_child(grandparent_key_byte, wremaining_child.as_ptr());
// Mark the parent as deleted.
guard.remember_obsolete_node(wparent.as_ptr());
wparent.write_unlock_obsolete();
return Ok(());
}
}
// If the parent's children would fit on a smaller node type after the deletion, replace it with
// a smaller node.
if wparent.can_shrink() {
let mut wgrandparent = rgrandparent.upgrade_to_write_lock_or_restart()?;
let smaller_node = wparent.shrink(guard.tree_writer.allocator)?;
// Replace the pointer in the grandparent
wgrandparent.replace_child(grandparent_key_byte, smaller_node.into_ptr());
guard.remember_obsolete_node(wparent.as_ptr());
wparent.write_unlock_obsolete();
return Ok(());
}
// nothing to do
wparent.write_unlock();
Ok(())
}
// Allocate a new leaf node to hold 'value'. If the key is long, we
// may need to allocate new internal nodes to hold it too
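// For example (assuming MAX_PREFIX_LEN == 8): for a 12-byte key, prefix_off starts at 4, so
// the leaf stores key[4..12] as its prefix; the loop then builds one internal node with
// prefix key[0..3], whose child entry key[3] points at that leaf.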
fn allocate_node_for_value<'a, V: Value, A: ArtAllocator<V>>(
key: &[u8],
value: V,
allocator: &'a A,
) -> Result<NewNodeRef<'a, V, A>, OutOfMemoryError> {
let mut prefix_off = key.len().saturating_sub(MAX_PREFIX_LEN);
let leaf_node = node_ref::new_leaf(&key[prefix_off..key.len()], value, allocator)?;
let mut node = leaf_node;
while prefix_off > 0 {
// Need another internal node
let remain_prefix = &key[0..prefix_off];
prefix_off = remain_prefix.len().saturating_sub(MAX_PREFIX_LEN + 1);
let mut internal_node = node_ref::new_internal(
&remain_prefix[prefix_off..remain_prefix.len() - 1],
allocator,
)?;
internal_node.insert_new_child(*remain_prefix.last().unwrap(), node);
node = internal_node;
}
Ok(node)
}
fn common_prefix(a: &[u8], b: &[u8]) -> usize {
for i in 0..MAX_PREFIX_LEN {
if a[i] != b[i] {
return i;
}
}
panic!("prefixes are equal");
}

View File

@@ -1,117 +0,0 @@
//! Each node in the tree contains one atomic word that stores three things:
//!
//! Bit 0: set if the node is "obsolete". An obsolete node has been removed from the tree,
//! but might still be accessed by concurrent readers until the epoch expires.
//! Bit 1: set if the node is currently write-locked. Used as a spinlock.
//! Bits 2-63: Version number, incremented every time the node is modified.
//!
//! AtomicLockAndVersion represents that.
use std::sync::atomic::{AtomicU64, Ordering};
pub(crate) struct ConcurrentUpdateError();
pub(crate) struct AtomicLockAndVersion {
inner: AtomicU64,
}
impl AtomicLockAndVersion {
pub(crate) fn new() -> AtomicLockAndVersion {
AtomicLockAndVersion {
inner: AtomicU64::new(0),
}
}
}
impl AtomicLockAndVersion {
pub(crate) fn read_lock_or_restart(&self) -> Result<u64, ConcurrentUpdateError> {
let version = self.await_node_unlocked();
if is_obsolete(version) {
return Err(ConcurrentUpdateError());
}
Ok(version)
}
pub(crate) fn check_or_restart(&self, version: u64) -> Result<(), ConcurrentUpdateError> {
self.read_unlock_or_restart(version)
}
pub(crate) fn read_unlock_or_restart(&self, version: u64) -> Result<(), ConcurrentUpdateError> {
if self.inner.load(Ordering::Acquire) != version {
return Err(ConcurrentUpdateError());
}
Ok(())
}
pub(crate) fn upgrade_to_write_lock_or_restart(
&self,
version: u64,
) -> Result<(), ConcurrentUpdateError> {
if self
.inner
.compare_exchange(
version,
set_locked_bit(version),
Ordering::Acquire,
Ordering::Relaxed,
)
.is_err()
{
return Err(ConcurrentUpdateError());
}
Ok(())
}
pub(crate) fn write_lock_or_restart(&self) -> Result<(), ConcurrentUpdateError> {
let old = self.inner.load(Ordering::Relaxed);
if is_obsolete(old) || is_locked(old) {
return Err(ConcurrentUpdateError());
}
if self
.inner
.compare_exchange(
old,
set_locked_bit(old),
Ordering::Acquire,
Ordering::Relaxed,
)
.is_err()
{
return Err(ConcurrentUpdateError());
}
Ok(())
}
pub(crate) fn write_unlock(&self) {
// reset locked bit and overflow into version
self.inner.fetch_add(2, Ordering::Release);
}
pub(crate) fn write_unlock_obsolete(&self) {
// set obsolete, reset locked, overflow into version
self.inner.fetch_add(3, Ordering::Release);
}
// Helper functions
fn await_node_unlocked(&self) -> u64 {
let mut version = self.inner.load(Ordering::Acquire);
while is_locked(version) {
// spinlock
std::thread::yield_now();
version = self.inner.load(Ordering::Acquire)
}
version
}
}
fn set_locked_bit(version: u64) -> u64 {
version + 2
}
fn is_obsolete(version: u64) -> bool {
(version & 1) == 1
}
fn is_locked(version: u64) -> bool {
(version & 2) == 2
}
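// A minimal sketch (not part of the original file) exercising the lockword protocol described
// in the module doc: taking and releasing the write lock bumps the version, which forces
// optimistic readers that still hold the old version to restart.
#[cfg(test)]
mod lockword_sketch_tests {
    use super::*;

    #[test]
    fn write_unlock_bumps_version() {
        let lw = AtomicLockAndVersion::new();
        let v0 = lw.read_lock_or_restart().ok().unwrap();
        assert!(lw.write_lock_or_restart().is_ok());
        // While write-locked, validating the previously read version fails.
        assert!(lw.read_unlock_or_restart(v0).is_err());
        lw.write_unlock();
        // The unlock incremented the version, so the stale version still fails validation...
        assert!(lw.read_unlock_or_restart(v0).is_err());
        // ...but a fresh optimistic read against the new version succeeds.
        let v1 = lw.read_lock_or_restart().ok().unwrap();
        assert!(lw.read_unlock_or_restart(v1).is_ok());
    }
}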

File diff suppressed because it is too large

View File

@@ -1,349 +0,0 @@
use std::fmt::Debug;
use std::marker::PhantomData;
use super::node_ptr;
use super::node_ptr::NodePtr;
use crate::EpochPin;
use crate::Value;
use crate::algorithm::lock_and_version::AtomicLockAndVersion;
use crate::algorithm::lock_and_version::ConcurrentUpdateError;
use crate::allocator::ArtAllocator;
use crate::allocator::OutOfMemoryError;
pub struct NodeRef<'e, V> {
ptr: NodePtr<V>,
phantom: PhantomData<&'e EpochPin<'e>>,
}
impl<'e, V> Debug for NodeRef<'e, V> {
fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
write!(fmt, "{:?}", self.ptr)
}
}
impl<'e, V: Value> NodeRef<'e, V> {
pub(crate) fn from_root_ptr(root_ptr: NodePtr<V>) -> NodeRef<'e, V> {
NodeRef {
ptr: root_ptr,
phantom: PhantomData,
}
}
pub(crate) fn read_lock_or_restart(
&self,
) -> Result<ReadLockedNodeRef<'e, V>, ConcurrentUpdateError> {
let version = self.lockword().read_lock_or_restart()?;
Ok(ReadLockedNodeRef {
ptr: self.ptr,
version,
phantom: self.phantom,
})
}
pub(crate) fn write_lock_or_restart(
&self,
) -> Result<WriteLockedNodeRef<'e, V>, ConcurrentUpdateError> {
self.lockword().write_lock_or_restart()?;
Ok(WriteLockedNodeRef {
ptr: self.ptr,
phantom: self.phantom,
})
}
fn lockword(&self) -> &AtomicLockAndVersion {
self.ptr.lockword()
}
}
/// A reference to a node that has been optimistically read-locked. The functions re-check
/// the version after each read.
pub struct ReadLockedNodeRef<'e, V> {
ptr: NodePtr<V>,
version: u64,
phantom: PhantomData<&'e EpochPin<'e>>,
}
impl<'e, V: Value> ReadLockedNodeRef<'e, V> {
pub(crate) fn is_leaf(&self) -> bool {
self.ptr.is_leaf()
}
pub(crate) fn is_full(&self) -> bool {
self.ptr.is_full()
}
pub(crate) fn get_prefix(&self) -> &[u8] {
self.ptr.get_prefix()
}
/// Note: because we're only holding a read lock, the prefix can change concurrently.
/// You must be prepared to restart if read_unlock() returns an error later.
///
/// Returns the length of the prefix, or None if it's not a match
pub(crate) fn prefix_matches(&self, key: &[u8]) -> Option<usize> {
self.ptr.prefix_matches(key)
}
pub(crate) fn find_child_or_restart(
&self,
key_byte: u8,
) -> Result<Option<NodeRef<'e, V>>, ConcurrentUpdateError> {
let child_or_value = self.ptr.find_child(key_byte);
self.ptr.lockword().check_or_restart(self.version)?;
match child_or_value {
None => Ok(None),
Some(child_ptr) => Ok(Some(NodeRef {
ptr: child_ptr,
phantom: self.phantom,
})),
}
}
pub(crate) fn find_next_child_or_restart(
&self,
min_key_byte: u8,
) -> Result<Option<(u8, NodeRef<'e, V>)>, ConcurrentUpdateError> {
let child_or_value = self.ptr.find_next_child(min_key_byte);
self.ptr.lockword().check_or_restart(self.version)?;
match child_or_value {
None => Ok(None),
Some((k, child_ptr)) => Ok(Some((
k,
NodeRef {
ptr: child_ptr,
phantom: self.phantom,
},
))),
}
}
pub(crate) fn get_leaf_value_ptr(&self) -> Result<*const V, ConcurrentUpdateError> {
let result = self.ptr.get_leaf_value();
self.ptr.lockword().check_or_restart(self.version)?;
// Extend the lifetime.
let result = std::ptr::from_ref(result);
Ok(result)
}
pub(crate) fn upgrade_to_write_lock_or_restart(
self,
) -> Result<WriteLockedNodeRef<'e, V>, ConcurrentUpdateError> {
self.ptr
.lockword()
.upgrade_to_write_lock_or_restart(self.version)?;
Ok(WriteLockedNodeRef {
ptr: self.ptr,
phantom: self.phantom,
})
}
pub(crate) fn read_unlock_or_restart(self) -> Result<(), ConcurrentUpdateError> {
self.ptr.lockword().check_or_restart(self.version)?;
Ok(())
}
pub(crate) fn check_or_restart(&self) -> Result<(), ConcurrentUpdateError> {
self.ptr.lockword().check_or_restart(self.version)?;
Ok(())
}
}
/// A reference to a node that is exclusively write-locked. The lock is released when the
/// reference is dropped, or explicitly via write_unlock() / write_unlock_obsolete().
pub struct WriteLockedNodeRef<'e, V> {
ptr: NodePtr<V>,
phantom: PhantomData<&'e EpochPin<'e>>,
}
impl<'e, V: Value> WriteLockedNodeRef<'e, V> {
pub(crate) fn can_shrink(&self) -> bool {
self.ptr.can_shrink()
}
pub(crate) fn num_children(&self) -> usize {
self.ptr.num_children()
}
pub(crate) fn write_unlock(mut self) {
self.ptr.lockword().write_unlock();
self.ptr = NodePtr::null();
}
pub(crate) fn write_unlock_obsolete(mut self) {
self.ptr.lockword().write_unlock_obsolete();
self.ptr = NodePtr::null();
}
pub(crate) fn get_prefix(&self) -> &[u8] {
self.ptr.get_prefix()
}
pub(crate) fn truncate_prefix(&mut self, new_prefix_len: usize) {
self.ptr.truncate_prefix(new_prefix_len)
}
pub(crate) fn prepend_prefix(&mut self, prefix: &[u8], prefix_byte: u8) {
self.ptr.prepend_prefix(prefix, prefix_byte)
}
pub(crate) fn insert_child(&mut self, key_byte: u8, child: NodePtr<V>) {
self.ptr.insert_child(key_byte, child)
}
pub(crate) fn get_leaf_value_mut(&mut self) -> &mut V {
self.ptr.get_leaf_value_mut()
}
pub(crate) fn grow<'a, A>(
&self,
allocator: &'a A,
) -> Result<NewNodeRef<'a, V, A>, OutOfMemoryError>
where
A: ArtAllocator<V>,
{
let new_node = self.ptr.grow(allocator)?;
Ok(NewNodeRef {
ptr: new_node,
allocator,
extra_nodes: Vec::new(),
})
}
pub(crate) fn shrink<'a, A>(
&self,
allocator: &'a A,
) -> Result<NewNodeRef<'a, V, A>, OutOfMemoryError>
where
A: ArtAllocator<V>,
{
let new_node = self.ptr.shrink(allocator)?;
Ok(NewNodeRef {
ptr: new_node,
allocator,
extra_nodes: Vec::new(),
})
}
pub(crate) fn as_ptr(&self) -> NodePtr<V> {
self.ptr
}
pub(crate) fn replace_child(&mut self, key_byte: u8, replacement: NodePtr<V>) {
self.ptr.replace_child(key_byte, replacement);
}
pub(crate) fn delete_child(&mut self, key_byte: u8) {
self.ptr.delete_child(key_byte);
}
pub(crate) fn find_remaining_child(&self) -> (u8, NodeRef<'e, V>) {
assert_eq!(self.num_children(), 1);
let child_or_value = self.ptr.find_next_child(0);
match child_or_value {
None => panic!("could not find only child in node"),
Some((k, child_ptr)) => (
k,
NodeRef {
ptr: child_ptr,
phantom: self.phantom,
},
),
}
}
}
impl<'e, V> Drop for WriteLockedNodeRef<'e, V> {
fn drop(&mut self) {
if !self.ptr.is_null() {
self.ptr.lockword().write_unlock();
}
}
}
pub(crate) struct NewNodeRef<'a, V, A>
where
V: Value,
A: ArtAllocator<V>,
{
ptr: NodePtr<V>,
allocator: &'a A,
extra_nodes: Vec<NodePtr<V>>,
}
impl<'a, V, A> NewNodeRef<'a, V, A>
where
V: Value,
A: ArtAllocator<V>,
{
pub(crate) fn insert_old_child(&mut self, key_byte: u8, child: &WriteLockedNodeRef<V>) {
self.ptr.insert_child(key_byte, child.as_ptr())
}
pub(crate) fn into_ptr(mut self) -> NodePtr<V> {
let ptr = self.ptr;
self.ptr = NodePtr::null();
ptr
}
pub(crate) fn insert_new_child(&mut self, key_byte: u8, child: NewNodeRef<'a, V, A>) {
let child_ptr = child.into_ptr();
self.ptr.insert_child(key_byte, child_ptr);
self.extra_nodes.push(child_ptr);
}
}
impl<'a, V, A> Drop for NewNodeRef<'a, V, A>
where
V: Value,
A: ArtAllocator<V>,
{
/// This drop implementation deallocates the newly allocated node, if into_ptr() was not called.
fn drop(&mut self) {
if !self.ptr.is_null() {
self.ptr.deallocate(self.allocator);
for p in self.extra_nodes.iter() {
p.deallocate(self.allocator);
}
}
}
}
pub(crate) fn new_internal<'a, V, A>(
prefix: &[u8],
allocator: &'a A,
) -> Result<NewNodeRef<'a, V, A>, OutOfMemoryError>
where
V: Value,
A: ArtAllocator<V>,
{
Ok(NewNodeRef {
ptr: node_ptr::new_internal(prefix, allocator)?,
allocator,
extra_nodes: Vec::new(),
})
}
pub(crate) fn new_leaf<'a, V, A>(
prefix: &[u8],
value: V,
allocator: &'a A,
) -> Result<NewNodeRef<'a, V, A>, OutOfMemoryError>
where
V: Value,
A: ArtAllocator<V>,
{
Ok(NewNodeRef {
ptr: node_ptr::new_leaf(prefix, value, allocator)?,
allocator,
extra_nodes: Vec::new(),
})
}

View File

@@ -1,156 +0,0 @@
pub mod block;
mod multislab;
mod slab;
pub mod r#static;
use std::alloc::Layout;
use std::marker::PhantomData;
use std::mem::MaybeUninit;
use std::sync::atomic::Ordering;
use crate::allocator::multislab::MultiSlabAllocator;
use crate::allocator::r#static::alloc_from_slice;
use spin;
use crate::Tree;
pub use crate::algorithm::node_ptr::{
NodeInternal4, NodeInternal16, NodeInternal48, NodeInternal256, NodeLeaf,
};
#[derive(Debug)]
pub struct OutOfMemoryError();
pub trait ArtAllocator<V: crate::Value> {
fn alloc_tree(&self) -> *mut Tree<V>;
fn alloc_node_internal4(&self) -> *mut NodeInternal4<V>;
fn alloc_node_internal16(&self) -> *mut NodeInternal16<V>;
fn alloc_node_internal48(&self) -> *mut NodeInternal48<V>;
fn alloc_node_internal256(&self) -> *mut NodeInternal256<V>;
fn alloc_node_leaf(&self) -> *mut NodeLeaf<V>;
fn dealloc_node_internal4(&self, ptr: *mut NodeInternal4<V>);
fn dealloc_node_internal16(&self, ptr: *mut NodeInternal16<V>);
fn dealloc_node_internal48(&self, ptr: *mut NodeInternal48<V>);
fn dealloc_node_internal256(&self, ptr: *mut NodeInternal256<V>);
fn dealloc_node_leaf(&self, ptr: *mut NodeLeaf<V>);
}
pub struct ArtMultiSlabAllocator<'t, V>
where
V: crate::Value,
{
tree_area: spin::Mutex<Option<&'t mut MaybeUninit<Tree<V>>>>,
pub(crate) inner: MultiSlabAllocator<'t, 5>,
phantom_val: PhantomData<V>,
}
impl<'t, V: crate::Value> ArtMultiSlabAllocator<'t, V> {
const LAYOUTS: [Layout; 5] = [
Layout::new::<NodeInternal4<V>>(),
Layout::new::<NodeInternal16<V>>(),
Layout::new::<NodeInternal48<V>>(),
Layout::new::<NodeInternal256<V>>(),
Layout::new::<NodeLeaf<V>>(),
];
pub fn new(area: &'t mut [MaybeUninit<u8>]) -> &'t mut ArtMultiSlabAllocator<'t, V> {
let (allocator_area, remain) = alloc_from_slice::<ArtMultiSlabAllocator<V>>(area);
let (tree_area, remain) = alloc_from_slice::<Tree<V>>(remain);
allocator_area.write(ArtMultiSlabAllocator {
tree_area: spin::Mutex::new(Some(tree_area)),
inner: MultiSlabAllocator::new(remain, &Self::LAYOUTS),
phantom_val: PhantomData,
})
}
}
impl<'t, V: crate::Value> ArtAllocator<V> for ArtMultiSlabAllocator<'t, V> {
fn alloc_tree(&self) -> *mut Tree<V> {
let mut t = self.tree_area.lock();
if let Some(tree_area) = t.take() {
return tree_area.as_mut_ptr().cast();
}
panic!("cannot allocate more than one tree");
}
fn alloc_node_internal4(&self) -> *mut NodeInternal4<V> {
self.inner.alloc_slab(0).cast()
}
fn alloc_node_internal16(&self) -> *mut NodeInternal16<V> {
self.inner.alloc_slab(1).cast()
}
fn alloc_node_internal48(&self) -> *mut NodeInternal48<V> {
self.inner.alloc_slab(2).cast()
}
fn alloc_node_internal256(&self) -> *mut NodeInternal256<V> {
self.inner.alloc_slab(3).cast()
}
fn alloc_node_leaf(&self) -> *mut NodeLeaf<V> {
self.inner.alloc_slab(4).cast()
}
fn dealloc_node_internal4(&self, ptr: *mut NodeInternal4<V>) {
self.inner.dealloc_slab(0, ptr.cast())
}
fn dealloc_node_internal16(&self, ptr: *mut NodeInternal16<V>) {
self.inner.dealloc_slab(1, ptr.cast())
}
fn dealloc_node_internal48(&self, ptr: *mut NodeInternal48<V>) {
self.inner.dealloc_slab(2, ptr.cast())
}
fn dealloc_node_internal256(&self, ptr: *mut NodeInternal256<V>) {
self.inner.dealloc_slab(3, ptr.cast())
}
fn dealloc_node_leaf(&self, ptr: *mut NodeLeaf<V>) {
self.inner.dealloc_slab(4, ptr.cast())
}
}
impl<'t, V: crate::Value> ArtMultiSlabAllocator<'t, V> {
pub(crate) fn get_statistics(&self) -> ArtMultiSlabStats {
ArtMultiSlabStats {
num_internal4: self.inner.slab_descs[0]
.num_allocated
.load(Ordering::Relaxed),
num_internal16: self.inner.slab_descs[1]
.num_allocated
.load(Ordering::Relaxed),
num_internal48: self.inner.slab_descs[2]
.num_allocated
.load(Ordering::Relaxed),
num_internal256: self.inner.slab_descs[3]
.num_allocated
.load(Ordering::Relaxed),
num_leaf: self.inner.slab_descs[4]
.num_allocated
.load(Ordering::Relaxed),
num_blocks_internal4: self.inner.slab_descs[0].num_blocks.load(Ordering::Relaxed),
num_blocks_internal16: self.inner.slab_descs[1].num_blocks.load(Ordering::Relaxed),
num_blocks_internal48: self.inner.slab_descs[2].num_blocks.load(Ordering::Relaxed),
num_blocks_internal256: self.inner.slab_descs[3].num_blocks.load(Ordering::Relaxed),
num_blocks_leaf: self.inner.slab_descs[4].num_blocks.load(Ordering::Relaxed),
}
}
}
#[derive(Clone, Debug)]
pub struct ArtMultiSlabStats {
pub num_internal4: u64,
pub num_internal16: u64,
pub num_internal48: u64,
pub num_internal256: u64,
pub num_leaf: u64,
pub num_blocks_internal4: u64,
pub num_blocks_internal16: u64,
pub num_blocks_internal48: u64,
pub num_blocks_internal256: u64,
pub num_blocks_leaf: u64,
}
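// A minimal sketch (not part of the original file): the multi-slab allocator serves typed node
// allocations out of a caller-provided memory area. `usize` is used as the value type purely
// for illustration (it implements Value in this crate's test module).
#[cfg(test)]
mod allocator_sketch_tests {
    use super::*;

    #[test]
    fn leaf_alloc_roundtrip() {
        let mut area = Box::new_uninit_slice(1 << 20);
        let allocator: &ArtMultiSlabAllocator<usize> = ArtMultiSlabAllocator::new(&mut area);
        let leaf = allocator.alloc_node_leaf();
        assert!(!leaf.is_null());
        assert_eq!(allocator.get_statistics().num_leaf, 1);
        allocator.dealloc_node_leaf(leaf);
        assert_eq!(allocator.get_statistics().num_leaf, 0);
    }
}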

View File

@@ -1,191 +0,0 @@
//! Simple allocator of fixed-size blocks
use std::mem::MaybeUninit;
use std::sync::atomic::{AtomicU64, Ordering};
use spin;
pub const BLOCK_SIZE: usize = 16 * 1024;
const INVALID_BLOCK: u64 = u64::MAX;
pub(crate) struct BlockAllocator<'t> {
blocks_ptr: &'t [MaybeUninit<u8>],
num_blocks: u64,
num_initialized: AtomicU64,
freelist_head: spin::Mutex<u64>,
}
struct FreeListBlock {
inner: spin::Mutex<FreeListBlockInner>,
}
struct FreeListBlockInner {
next: u64,
num_free_blocks: u64,
free_blocks: [u64; 100], // FIXME: fill the rest of the block
}
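// Free blocks are tracked in a linked list of "freelist blocks": freelist_head names a block
// that itself stores up to free_blocks.len() free block numbers; once those are used up, the
// freelist block itself is handed out and the head advances to the next one.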
impl<'t> BlockAllocator<'t> {
pub(crate) fn new(area: &'t mut [MaybeUninit<u8>]) -> Self {
// Use all the space for the blocks
let padding = area.as_ptr().align_offset(BLOCK_SIZE);
let remain = &mut area[padding..];
let num_blocks = (remain.len() / BLOCK_SIZE) as u64;
BlockAllocator {
blocks_ptr: remain,
num_blocks,
num_initialized: AtomicU64::new(0),
freelist_head: spin::Mutex::new(INVALID_BLOCK),
}
}
/// safety: you must hold a lock on the pointer to this block, otherwise it might get
/// reused for another kind of block
fn read_freelist_block(&self, blkno: u64) -> &FreeListBlock {
let ptr: *const FreeListBlock = self.get_block_ptr(blkno).cast();
unsafe { ptr.as_ref().unwrap() }
}
fn get_block_ptr(&self, blkno: u64) -> *mut u8 {
assert!(blkno < self.num_blocks);
unsafe {
self.blocks_ptr
.as_ptr()
.byte_offset(blkno as isize * BLOCK_SIZE as isize)
}
.cast_mut()
.cast()
}
#[allow(clippy::mut_from_ref)]
pub(crate) fn alloc_block(&self) -> &mut [MaybeUninit<u8>] {
// FIXME: handle OOM
let blkno = self.alloc_block_internal();
if blkno == INVALID_BLOCK {
panic!("out of memory");
}
let ptr: *mut MaybeUninit<u8> = self.get_block_ptr(blkno).cast();
unsafe { std::slice::from_raw_parts_mut(ptr, BLOCK_SIZE) }
}
fn alloc_block_internal(&self) -> u64 {
// check the free list.
{
let mut freelist_head = self.freelist_head.lock();
if *freelist_head != INVALID_BLOCK {
let freelist_block = self.read_freelist_block(*freelist_head);
// acquire lock on the freelist block before releasing the lock on the parent (i.e. lock coupling)
let mut g = freelist_block.inner.lock();
if g.num_free_blocks > 0 {
g.num_free_blocks -= 1;
let result = g.free_blocks[g.num_free_blocks as usize];
return result;
} else {
// consume the freelist block itself
let result = *freelist_head;
*freelist_head = g.next;
// This freelist block is now unlinked and can be repurposed
drop(g);
return result;
}
}
}
// If there are some blocks left that we've never used, pick next such block
let mut next_uninitialized = self.num_initialized.load(Ordering::Relaxed);
while next_uninitialized < self.num_blocks {
match self.num_initialized.compare_exchange(
next_uninitialized,
next_uninitialized + 1,
Ordering::Relaxed,
Ordering::Relaxed,
) {
Ok(_) => {
return next_uninitialized;
}
Err(old) => {
next_uninitialized = old;
continue;
}
}
}
// out of blocks
INVALID_BLOCK
}
// TODO: this is currently unused. The slab allocator never releases blocks
#[allow(dead_code)]
pub(crate) fn release_block(&self, block_ptr: *mut u8) {
let blockno = unsafe { block_ptr.byte_offset_from(self.blocks_ptr) / BLOCK_SIZE as isize };
self.release_block_internal(blockno as u64);
}
fn release_block_internal(&self, blockno: u64) {
let mut freelist_head = self.freelist_head.lock();
if *freelist_head != INVALID_BLOCK {
let freelist_block = self.read_freelist_block(*freelist_head);
// acquire lock on the freelist block before releasing the lock on the parent (i.e. lock coupling)
let mut g = freelist_block.inner.lock();
let num_free_blocks = g.num_free_blocks;
if num_free_blocks < g.free_blocks.len() as u64 {
g.free_blocks[num_free_blocks as usize] = blockno;
g.num_free_blocks += 1;
return;
}
}
// Convert the block into a new freelist block
let block_ptr: *mut FreeListBlock = self.get_block_ptr(blockno).cast();
let init = FreeListBlock {
inner: spin::Mutex::new(FreeListBlockInner {
next: *freelist_head,
num_free_blocks: 0,
free_blocks: [INVALID_BLOCK; 100],
}),
};
unsafe { (*block_ptr) = init };
*freelist_head = blockno;
}
// for debugging
pub(crate) fn get_statistics(&self) -> BlockAllocatorStats {
let mut num_free_blocks = 0;
let mut _prev_lock = None;
let head_lock = self.freelist_head.lock();
let mut next_blk = *head_lock;
let mut _head_lock = Some(head_lock);
while next_blk != INVALID_BLOCK {
let freelist_block = self.read_freelist_block(next_blk);
let lock = freelist_block.inner.lock();
num_free_blocks += lock.num_free_blocks;
next_blk = lock.next;
_prev_lock = Some(lock); // hold the lock until we've read the next block
_head_lock = None;
}
BlockAllocatorStats {
num_blocks: self.num_blocks,
num_initialized: self.num_initialized.load(Ordering::Relaxed),
num_free_blocks,
}
}
}
#[derive(Clone, Debug)]
pub struct BlockAllocatorStats {
pub num_blocks: u64,
pub num_initialized: u64,
pub num_free_blocks: u64,
}
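// A minimal sketch (not part of the original file): a released block is turned into a freelist
// block and can be reused by a later allocation without initializing fresh memory.
#[cfg(test)]
mod block_sketch_tests {
    use super::*;

    #[test]
    fn release_puts_block_on_free_list() {
        let mut area = Box::new_uninit_slice(BLOCK_SIZE * 8);
        let allocator = BlockAllocator::new(&mut area);
        let block = allocator.alloc_block();
        let block_ptr = block.as_mut_ptr().cast::<u8>();
        allocator.release_block(block_ptr);
        // Only one block has ever been initialized; the released one now heads the free list.
        assert_eq!(allocator.get_statistics().num_initialized, 1);
        // The next allocation consumes the freelist block instead of touching fresh memory.
        let _reused = allocator.alloc_block();
        assert_eq!(allocator.get_statistics().num_initialized, 1);
    }
}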

View File

@@ -1,33 +0,0 @@
use std::alloc::Layout;
use std::mem::MaybeUninit;
use crate::allocator::block::BlockAllocator;
use crate::allocator::slab::SlabDesc;
pub struct MultiSlabAllocator<'t, const N: usize> {
pub(crate) block_allocator: BlockAllocator<'t>,
pub(crate) slab_descs: [SlabDesc; N],
}
impl<'t, const N: usize> MultiSlabAllocator<'t, N> {
pub(crate) fn new(
area: &'t mut [MaybeUninit<u8>],
layouts: &[Layout; N],
) -> MultiSlabAllocator<'t, N> {
let block_allocator = BlockAllocator::new(area);
MultiSlabAllocator {
block_allocator,
slab_descs: std::array::from_fn(|i| SlabDesc::new(&layouts[i])),
}
}
pub(crate) fn alloc_slab(&self, slab_idx: usize) -> *mut u8 {
self.slab_descs[slab_idx].alloc_chunk(&self.block_allocator)
}
pub(crate) fn dealloc_slab(&self, slab_idx: usize, ptr: *mut u8) {
self.slab_descs[slab_idx].dealloc_chunk(ptr, &self.block_allocator)
}
}

View File

@@ -1,433 +0,0 @@
//! A slab allocator that carves out fixed-size chunks from larger blocks.
//!
//!
use std::alloc::Layout;
use std::mem::MaybeUninit;
use std::ops::Deref;
use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
use spin;
use super::alloc_from_slice;
use super::block::BlockAllocator;
use crate::allocator::block::BLOCK_SIZE;
pub(crate) struct SlabDesc {
pub(crate) layout: Layout,
block_lists: spin::RwLock<BlockLists>,
pub(crate) num_blocks: AtomicU64,
pub(crate) num_allocated: AtomicU64,
}
// FIXME: Not sure if SlabDesc is really Sync or Send. It probably is when it's empty, but
// 'block_lists' contains pointers when it's not empty. In the current use as part of the
// ART tree, SlabDescs are only moved during initialization.
unsafe impl Sync for SlabDesc {}
unsafe impl Send for SlabDesc {}
#[derive(Default, Debug)]
struct BlockLists {
full_blocks: BlockList,
nonfull_blocks: BlockList,
}
impl BlockLists {
// Unlink a node. It must be in either one of the two lists.
unsafe fn unlink(&mut self, elem: *mut SlabBlockHeader) {
let list = unsafe {
if (*elem).next.is_null() {
if self.full_blocks.tail == elem {
Some(&mut self.full_blocks)
} else {
Some(&mut self.nonfull_blocks)
}
} else if (*elem).prev.is_null() {
if self.full_blocks.head == elem {
Some(&mut self.full_blocks)
} else {
Some(&mut self.nonfull_blocks)
}
} else {
None
}
};
unsafe { unlink_slab_block(list, elem) };
}
}
unsafe fn unlink_slab_block(mut list: Option<&mut BlockList>, elem: *mut SlabBlockHeader) {
unsafe {
if (*elem).next.is_null() {
assert_eq!(list.as_ref().unwrap().tail, elem);
list.as_mut().unwrap().tail = (*elem).prev;
} else {
assert_eq!((*(*elem).next).prev, elem);
(*(*elem).next).prev = (*elem).prev;
}
if (*elem).prev.is_null() {
assert_eq!(list.as_ref().unwrap().head, elem);
list.as_mut().unwrap().head = (*elem).next;
} else {
assert_eq!((*(*elem).prev).next, elem);
(*(*elem).prev).next = (*elem).next;
}
}
}
#[derive(Debug)]
struct BlockList {
head: *mut SlabBlockHeader,
tail: *mut SlabBlockHeader,
}
impl Default for BlockList {
fn default() -> Self {
BlockList {
head: std::ptr::null_mut(),
tail: std::ptr::null_mut(),
}
}
}
impl BlockList {
unsafe fn push_head(&mut self, elem: *mut SlabBlockHeader) {
unsafe {
if self.is_empty() {
self.tail = elem;
(*elem).next = std::ptr::null_mut();
} else {
(*elem).next = self.head;
(*self.head).prev = elem;
}
(*elem).prev = std::ptr::null_mut();
self.head = elem;
}
}
fn is_empty(&self) -> bool {
self.head.is_null()
}
unsafe fn unlink(&mut self, elem: *mut SlabBlockHeader) {
unsafe { unlink_slab_block(Some(self), elem) }
}
#[cfg(test)]
fn dump(&self) {
let mut next = self.head;
while !next.is_null() {
let n = unsafe { next.as_ref() }.unwrap();
eprintln!(
" blk {:?} (free {}/{})",
next,
n.num_free_chunks.load(Ordering::Relaxed),
n.num_chunks
);
next = n.next;
}
}
}
impl SlabDesc {
pub(crate) fn new(layout: &Layout) -> SlabDesc {
SlabDesc {
layout: *layout,
block_lists: spin::RwLock::new(BlockLists::default()),
num_allocated: AtomicU64::new(0),
num_blocks: AtomicU64::new(0),
}
}
}
#[derive(Debug)]
struct SlabBlockHeader {
free_chunks_head: spin::Mutex<*mut FreeChunk>,
num_free_chunks: AtomicU32,
num_chunks: u32, // this is really a constant for a given Layout
// these fields are protected by the lock on the BlockLists
prev: *mut SlabBlockHeader,
next: *mut SlabBlockHeader,
}
struct FreeChunk {
next: *mut FreeChunk,
}
enum ReadOrWriteGuard<'a, T> {
Read(spin::RwLockReadGuard<'a, T>),
Write(spin::RwLockWriteGuard<'a, T>),
}
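// Used by alloc_chunk() below: start with a shared read lock on the block lists and retry
// with an exclusive write lock only when a block has to be moved between the lists.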
impl<'a, T> Deref for ReadOrWriteGuard<'a, T> {
type Target = T;
fn deref(&self) -> &<Self as Deref>::Target {
match self {
ReadOrWriteGuard::Read(g) => g.deref(),
ReadOrWriteGuard::Write(g) => g.deref(),
}
}
}
impl SlabDesc {
pub fn alloc_chunk(&self, block_allocator: &BlockAllocator) -> *mut u8 {
// Are there any free chunks?
let mut acquire_write = false;
'outer: loop {
let mut block_lists_guard = if acquire_write {
ReadOrWriteGuard::Write(self.block_lists.write())
} else {
ReadOrWriteGuard::Read(self.block_lists.read())
};
'inner: loop {
let block_ptr = block_lists_guard.nonfull_blocks.head;
if block_ptr.is_null() {
break 'outer;
}
unsafe {
let mut free_chunks_head = (*block_ptr).free_chunks_head.lock();
if !(*free_chunks_head).is_null() {
let result = *free_chunks_head;
(*free_chunks_head) = (*result).next;
let _old = (*block_ptr).num_free_chunks.fetch_sub(1, Ordering::Relaxed);
self.num_allocated.fetch_add(1, Ordering::Relaxed);
return result.cast();
}
}
// The block at the head of the list was full. Grab write lock and retry
match block_lists_guard {
ReadOrWriteGuard::Read(_) => {
acquire_write = true;
continue 'outer;
}
ReadOrWriteGuard::Write(ref mut g) => {
// move the node to the list of full blocks
unsafe {
g.nonfull_blocks.unlink(block_ptr);
g.full_blocks.push_head(block_ptr);
};
continue 'inner;
}
}
}
}
// no free chunks. Allocate a new block (and the chunk from that)
let (new_block, new_chunk) = self.alloc_block_and_chunk(block_allocator);
self.num_blocks.fetch_add(1, Ordering::Relaxed);
// Add the block to the list in the SlabDesc
unsafe {
let mut block_lists_guard = self.block_lists.write();
block_lists_guard.nonfull_blocks.push_head(new_block);
}
self.num_allocated.fetch_add(1, Ordering::Relaxed);
new_chunk
}
pub fn dealloc_chunk(&self, chunk_ptr: *mut u8, _block_allocator: &BlockAllocator) {
// Find the block it belongs to. You can find the block from the address. (And knowing the
// layout, you could calculate the chunk number too.)
let block_ptr: *mut SlabBlockHeader = {
let block_addr = (chunk_ptr.addr() / BLOCK_SIZE) * BLOCK_SIZE;
chunk_ptr.with_addr(block_addr).cast()
};
let chunk_ptr: *mut FreeChunk = chunk_ptr.cast();
// Mark the chunk as free in 'freechunks' list
let num_chunks;
let num_free_chunks;
unsafe {
let mut free_chunks_head = (*block_ptr).free_chunks_head.lock();
(*chunk_ptr).next = *free_chunks_head;
*free_chunks_head = chunk_ptr;
num_free_chunks = (*block_ptr).num_free_chunks.fetch_add(1, Ordering::Relaxed) + 1;
num_chunks = (*block_ptr).num_chunks;
}
if num_free_chunks == 1 {
// If the block was full previously, add it to the nonfull blocks list. Note that
// we're not holding the lock anymore, so it can immediately become full again.
// That's harmless; it will be moved back to the full list when a call
// to alloc_chunk() sees it.
let mut block_lists = self.block_lists.write();
unsafe {
block_lists.unlink(block_ptr);
block_lists.nonfull_blocks.push_head(block_ptr);
};
} else if num_free_chunks == num_chunks {
// If the block became completely empty, move it to the free list
// TODO
// FIXME: we're still holding the spinlock. It's not exactly safe to return it to
// the free blocks list, is it? Defer it as garbage to wait out concurrent updates?
//block_allocator.release_block()
}
// update stats
self.num_allocated.fetch_sub(1, Ordering::Relaxed);
}
fn alloc_block_and_chunk(
&self,
block_allocator: &BlockAllocator,
) -> (*mut SlabBlockHeader, *mut u8) {
// fixme: handle OOM
let block_slice: &mut [MaybeUninit<u8>] = block_allocator.alloc_block();
let (block_header, remain) = alloc_from_slice::<SlabBlockHeader>(block_slice);
let padding = remain.as_ptr().align_offset(self.layout.align());
let num_chunks = (remain.len() - padding) / self.layout.size();
let first_chunk_ptr: *mut FreeChunk = remain[padding..].as_mut_ptr().cast();
unsafe {
let mut chunk_ptr = first_chunk_ptr;
for _ in 0..num_chunks - 1 {
let next_chunk_ptr = chunk_ptr.byte_add(self.layout.size());
(*chunk_ptr).next = next_chunk_ptr;
chunk_ptr = next_chunk_ptr;
}
(*chunk_ptr).next = std::ptr::null_mut();
let result_chunk = first_chunk_ptr;
let block_header = block_header.write(SlabBlockHeader {
free_chunks_head: spin::Mutex::new((*first_chunk_ptr).next),
prev: std::ptr::null_mut(),
next: std::ptr::null_mut(),
num_chunks: num_chunks as u32,
num_free_chunks: AtomicU32::new(num_chunks as u32 - 1),
});
(block_header, result_chunk.cast())
}
}
#[cfg(test)]
fn dump(&self) {
eprintln!(
"slab dump ({} blocks, {} allocated chunks)",
self.num_blocks.load(Ordering::Relaxed),
self.num_allocated.load(Ordering::Relaxed)
);
let lists = self.block_lists.read();
eprintln!("nonfull blocks:");
lists.nonfull_blocks.dump();
eprintln!("full blocks:");
lists.full_blocks.dump();
}
}
#[cfg(test)]
mod tests {
use super::*;
use rand::Rng;
use rand_distr::Zipf;
struct TestObject {
val: usize,
_dummy: [u8; BLOCK_SIZE / 4],
}
struct TestObjectSlab<'a>(SlabDesc, BlockAllocator<'a>);
impl<'a> TestObjectSlab<'a> {
fn new(block_allocator: BlockAllocator) -> TestObjectSlab {
TestObjectSlab(SlabDesc::new(&Layout::new::<TestObject>()), block_allocator)
}
fn alloc(&self, val: usize) -> *mut TestObject {
let obj: *mut TestObject = self.0.alloc_chunk(&self.1).cast();
unsafe { (*obj).val = val };
obj
}
fn dealloc(&self, obj: *mut TestObject) {
self.0.dealloc_chunk(obj.cast(), &self.1)
}
}
#[test]
fn test_slab_alloc() {
const MEM_SIZE: usize = 100000000;
let mut area = Box::new_uninit_slice(MEM_SIZE);
let block_allocator = BlockAllocator::new(&mut area);
let slab = TestObjectSlab::new(block_allocator);
let mut all: Vec<*mut TestObject> = Vec::new();
for i in 0..11 {
all.push(slab.alloc(i));
}
#[allow(clippy::needless_range_loop)]
for i in 0..11 {
assert!(unsafe { (*all[i]).val == i });
}
let distribution = Zipf::new(10.0, 1.1).unwrap();
let mut rng = rand::rng();
for _ in 0..100000 {
slab.0.dump();
let idx = rng.sample(distribution) as usize;
let ptr: *mut TestObject = all[idx];
if !ptr.is_null() {
assert_eq!(unsafe { (*ptr).val }, idx);
slab.dealloc(ptr);
all[idx] = std::ptr::null_mut();
} else {
all[idx] = slab.alloc(idx);
}
}
}
fn new_test_blk(i: u32) -> *mut SlabBlockHeader {
Box::into_raw(Box::new(SlabBlockHeader {
free_chunks_head: spin::Mutex::new(std::ptr::null_mut()),
num_free_chunks: AtomicU32::new(0),
num_chunks: i,
prev: std::ptr::null_mut(),
next: std::ptr::null_mut(),
}))
}
#[test]
fn test_block_linked_list() {
// note: these are leaked, but that's OK for tests
let a = new_test_blk(0);
let b = new_test_blk(1);
let mut list = BlockList::default();
assert!(list.is_empty());
unsafe {
list.push_head(a);
assert!(!list.is_empty());
list.unlink(a);
}
assert!(list.is_empty());
unsafe {
list.push_head(b);
list.push_head(a);
assert_eq!(list.head, a);
assert_eq!((*a).next, b);
assert_eq!((*b).prev, a);
assert_eq!(list.tail, b);
list.unlink(a);
list.unlink(b);
assert!(list.is_empty());
}
}
}

View File

@@ -1,44 +0,0 @@
use std::mem::MaybeUninit;
pub fn alloc_from_slice<T>(
area: &mut [MaybeUninit<u8>],
) -> (&mut MaybeUninit<T>, &mut [MaybeUninit<u8>]) {
let layout = std::alloc::Layout::new::<T>();
let area_start = area.as_mut_ptr();
// pad to satisfy alignment requirements
let padding = area_start.align_offset(layout.align());
if padding + layout.size() > area.len() {
panic!("out of memory");
}
let area = &mut area[padding..];
let (result_area, remain) = area.split_at_mut(layout.size());
let result_ptr: *mut MaybeUninit<T> = result_area.as_mut_ptr().cast();
let result = unsafe { result_ptr.as_mut().unwrap() };
(result, remain)
}
pub fn alloc_array_from_slice<T>(
area: &mut [MaybeUninit<u8>],
len: usize,
) -> (&mut [MaybeUninit<T>], &mut [MaybeUninit<u8>]) {
let layout = std::alloc::Layout::new::<T>();
let area_start = area.as_mut_ptr();
// pad to satisfy alignment requirements
let padding = area_start.align_offset(layout.align());
if padding + layout.size() * len > area.len() {
panic!("out of memory");
}
let area = &mut area[padding..];
let (result_area, remain) = area.split_at_mut(layout.size() * len);
let result_ptr: *mut MaybeUninit<T> = result_area.as_mut_ptr().cast();
let result = unsafe { std::slice::from_raw_parts_mut(result_ptr.as_mut().unwrap(), len) };
(result, remain)
}
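// A minimal sketch (not part of the original file): alloc_from_slice returns a properly
// aligned slot for T plus the remainder of the area.
#[cfg(test)]
mod static_sketch_tests {
    use super::*;

    #[test]
    fn alloc_u64_is_aligned() {
        let mut area = Box::new_uninit_slice(64);
        let (slot, remain) = alloc_from_slice::<u64>(&mut area);
        assert_eq!(slot.as_ptr() as usize % std::mem::align_of::<u64>(), 0);
        assert!(remain.len() <= 64 - std::mem::size_of::<u64>());
        assert_eq!(*slot.write(42), 42);
    }
}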

View File

@@ -1,142 +0,0 @@
//! This is similar to the crossbeam_epoch crate, but works in shared memory
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
use crossbeam_utils::CachePadded;
const NUM_SLOTS: usize = 1000;
/// This is the struct that is stored in shared memory.
///
/// Each participant slot packs two things into one u64:
/// bit 0: whether the slot is currently pinned;
/// the remaining bits: the epoch last observed by that participant.
/// (The global epoch advances in steps of 2, so its bit 0 stays clear.)
pub struct EpochShared {
global_epoch: AtomicU64,
participants: [CachePadded<AtomicU64>; NUM_SLOTS],
broadcast_lock: spin::Mutex<()>,
}
impl EpochShared {
pub fn new() -> EpochShared {
EpochShared {
global_epoch: AtomicU64::new(2),
participants: [const { CachePadded::new(AtomicU64::new(2)) }; NUM_SLOTS],
broadcast_lock: spin::Mutex::new(()),
}
}
pub fn register(&self) -> LocalHandle {
LocalHandle {
global: self,
last_slot: AtomicUsize::new(0), // todo: choose more intelligently
}
}
fn release_pin(&self, slot: usize, _epoch: u64) {
let global_epoch = self.global_epoch.load(Ordering::Relaxed);
self.participants[slot].store(global_epoch, Ordering::Relaxed);
}
fn pin_internal(&self, slot_hint: usize) -> (usize, u64) {
// pick a slot
let mut slot = slot_hint;
let epoch = loop {
let old = self.participants[slot].fetch_or(1, Ordering::Relaxed);
if old & 1 == 0 {
// Got this slot
break old;
}
// the slot is in use by another thread / process; try a different slot
slot += 1;
if slot == NUM_SLOTS {
slot = 0;
}
continue;
};
(slot, epoch)
}
pub(crate) fn advance(&self) -> u64 {
// Advance the global epoch
let old_epoch = self.global_epoch.fetch_add(2, Ordering::Relaxed);
// Anyone who releases their pin after this will update their slot.
old_epoch + 2
}
pub(crate) fn broadcast(&self) {
let Some(_guard) = self.broadcast_lock.try_lock() else {
return;
};
let epoch = self.global_epoch.load(Ordering::Relaxed);
let old_epoch = epoch.wrapping_sub(2);
// Update all free slots.
for i in 0..NUM_SLOTS {
// TODO: check result, as a sanity check. It should either be the old epoch, or pinned
let _ = self.participants[i].compare_exchange(
old_epoch,
epoch,
Ordering::Relaxed,
Ordering::Relaxed,
);
}
// FIXME: memory fence here, since we used Relaxed?
}
pub(crate) fn get_oldest(&self) -> u64 {
// Read all slots.
let now = self.global_epoch.load(Ordering::Relaxed);
let mut oldest = now;
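// Wrapping comparison: a slot that appears to be "ahead" of 'now' (delta in the upper half
// of the u64 range) was updated very recently; otherwise the slot lagging furthest behind
// 'now' is the oldest.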
for i in 0..NUM_SLOTS {
let this_epoch = self.participants[i].load(Ordering::Relaxed);
let delta = now.wrapping_sub(this_epoch);
if delta > u64::MAX / 2 {
// this is very recent
} else if delta > now.wrapping_sub(oldest) {
oldest = this_epoch;
}
}
oldest
}
pub(crate) fn get_current(&self) -> u64 {
self.global_epoch.load(Ordering::Relaxed)
}
}
pub(crate) struct EpochPin<'e> {
slot: usize,
pub(crate) epoch: u64,
handle: &'e LocalHandle<'e>,
}
impl<'e> Drop for EpochPin<'e> {
fn drop(&mut self) {
self.handle.global.release_pin(self.slot, self.epoch);
}
}
pub struct LocalHandle<'g> {
global: &'g EpochShared,
last_slot: AtomicUsize,
}
impl<'g> LocalHandle<'g> {
pub fn pin(&self) -> EpochPin {
let (slot, epoch) = self
.global
.pin_internal(self.last_slot.load(Ordering::Relaxed));
self.last_slot.store(slot, Ordering::Relaxed);
EpochPin {
handle: self,
epoch,
slot,
}
}
}
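// A minimal sketch (not part of the original file): a pinned participant holds the oldest
// observed epoch back; once the pin is dropped and broadcast() runs, every slot catches up.
#[cfg(test)]
mod epoch_sketch_tests {
    use super::*;

    #[test]
    fn pinned_reader_holds_back_oldest_epoch() {
        let shared = EpochShared::new();
        let handle = shared.register();
        let pin = handle.pin();
        shared.advance();
        shared.broadcast();
        // The pinned slot still carries the older epoch, so the oldest lags behind current.
        assert!(shared.get_oldest() < shared.get_current());
        drop(pin);
        shared.advance();
        shared.broadcast();
        // With no pins outstanding, every slot has caught up to the current epoch.
        assert_eq!(shared.get_oldest(), shared.get_current());
    }
}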

View File

@@ -1,583 +0,0 @@
//! Adaptive Radix Tree (ART) implementation, with Optimistic Lock Coupling.
//!
//! The data structure is described in these two papers:
//!
//! [1] Leis, V. & Kemper, Alfons & Neumann, Thomas. (2013).
//! The adaptive radix tree: ARTful indexing for main-memory databases.
//! Proceedings - International Conference on Data Engineering. 38-49. 10.1109/ICDE.2013.6544812.
//! https://db.in.tum.de/~leis/papers/ART.pdf
//!
//! [2] Leis, Viktor & Scheibner, Florian & Kemper, Alfons & Neumann, Thomas. (2016).
//! The ART of practical synchronization.
//! 1-8. 10.1145/2933349.2933352.
//! https://db.in.tum.de/~leis/papers/artsync.pdf
//!
//! [1] describes the base data structure, and [2] describes the Optimistic Lock Coupling that we
//! use.
//!
//! The papers mention a few different variants. We have made the following choices in this
//! implementation:
//!
//! - All keys have the same length
//!
//! - Single-value leaves.
//!
//! - For collapsing inner nodes, we use the Pessimistic approach, where each inner node stores a
//! variable length "prefix", which stores the keys of all the one-way nodes which have been
//! removed. However, similar to the "hybrid" approach described in the paper, each node only has
//! space for a constant-size prefix of 8 bytes. If a node would have a longer prefix, then we
//! create one-way nodes to store them. (There was no particular reason for this choice,
//! the "hybrid" approach described in the paper might be better.)
//!
//! - For concurrency, we use Optimistic Lock Coupling. The paper [2] also describes another method,
//! ROWEX, which generally performs better when there is contention, but that is not important
//! for us, and Optimistic Lock Coupling is simpler to implement.
//!
//! ## Requirements
//!
//! This data structure is currently used for the integrated LFC, relsize and last-written LSN cache
//! in the compute communicator, part of the 'neon' Postgres extension. We have some unique
//! requirements, which is why we had to write our own. Namely:
//!
//! - The data structure has to live in fixed-sized shared memory segment. That rules out any
//! built-in Rust collections and most crates. (Except possibly with the 'allocator_api' rust
//! feature, which is still nightly-only and experimental as of this writing).
//!
//! - The data structure is accessed from multiple processes. Only one process updates the data
//! structure, but other processes perform reads. That rules out using built-in Rust locking
//! primitives like Mutex and RwLock, and most crates too.
//!
//! - Within the one process with write-access, multiple threads can perform updates concurrently.
//! That rules out using PostgreSQL LWLocks for the locking.
//!
//! The implementation is generic, and doesn't depend on any PostgreSQL specifics, but it has been
//! written with that usage and the above constraints in mind. Some noteworthy assumptions:
//!
//! - Contention is assumed to be rare. In the integrated cache in PostgreSQL, there's higher level
//! locking in the PostgreSQL buffer manager, which ensures that two backends should not try to
//! read / write the same page at the same time. (Prefetching can conflict with actual reads,
//! however.)
//!
//! - The keys in the integrated cache are 17 bytes long.
//!
//! ## Usage
//!
//! Because this is designed to be used as a Postgres shared memory data structure, initialization
//! happens in three stages:
//!
//! 0. A fixed area of shared memory is allocated at postmaster startup.
//!
//! 1. TreeInitStruct::new() is called to initialize it, still in Postmaster process, before any
//! other process or thread is running. It returns a TreeInitStruct, which is inherited by all
//! the processes through fork().
//!
//! 2. One process may have write-access to the struct, by calling
//! [TreeInitStruct::attach_writer]. (That process is the communicator process.)
//!
//! 3. Other processes get read-access to the struct, by calling [TreeInitStruct::attach_reader]
//!
//! "Write access" means that you can insert / update / delete values in the tree.
//!
//! NOTE: The Values stored in the tree are sometimes moved, when a leaf node fills up and a new
//! larger node needs to be allocated. The versioning and epoch-based allocator ensure that the data
//! structure stays consistent, but if the Value has interior mutability, like atomic fields,
//! updates to such fields might be lost if the leaf node is concurrently moved! If that becomes a
//! problem, the version check could be passed up to the caller, so that the caller could detect the
//! lost updates and retry the operation.
//!
//! ## Implementation
//!
//! node_ptr.rs: Provides low-level implementations of the four different node types (eight actually,
//! since there is an Internal and Leaf variant of each)
//!
//! lock_and_version.rs: Provides an abstraction for the combined lock and version counter on each
//! node.
//!
//! node_ref.rs: The code in node_ptr.rs deals with raw pointers. node_ref.rs provides more type-safe
//! abstractions on top.
//!
//! algorithm.rs: Contains the functions to implement lookups and updates in the tree
//!
//! allocator.rs: Provides a facility to allocate memory for the tree nodes. (We must provide our
//! own abstraction for that because we need the data structure to live in a pre-allocated shared
//! memory segment).
//!
//! epoch.rs: The data structure requires that when a node is removed from the tree, it is not
//! immediately deallocated, but stays around for as long as concurrent readers might still have
//! pointers to them. This is enforced by an epoch system. This is similar to
//! e.g. crossbeam_epoch, but we couldn't use that either because it has to work across processes
//! communicating over the shared memory segment.
//!
//! ## See also
//!
//! There are some existing Rust ART implementations out there, but none of them filled all
//! the requirements:
//!
//! - https://github.com/XiangpengHao/congee
//! - https://github.com/declanvk/blart
//!
//! ## TODO
//!
//! - Removing values has not been implemented
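//!
//! A minimal usage sketch of the three stages above (the key and value types and the shared
//! memory area are hypothetical placeholders, not part of this crate):
//!
//! ```ignore
//! // Stage 1: initialize the tree in the pre-allocated shared memory area (postmaster).
//! let allocator = ArtMultiSlabAllocator::new(shmem_area);
//! let init = TreeInitStruct::<MyKey, MyValue, _>::new(allocator);
//!
//! // Stage 2: the single writer process attaches write access and can modify the tree.
//! let writer = init.attach_writer();
//! let _ = writer.start_write().insert(&key, value);
//!
//! // Stage 3: every other process calls attach_reader() on its own (fork-inherited) copy
//! // of the init struct and gets read-only access:
//! //     let reader = init.attach_reader();
//! //     let hit = reader.start_read().get(&key).is_some();
//! ```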
mod algorithm;
pub mod allocator;
mod epoch;
use algorithm::RootPtr;
use algorithm::node_ptr::NodePtr;
use std::collections::VecDeque;
use std::fmt::Debug;
use std::marker::PhantomData;
use std::ptr::NonNull;
use std::sync::atomic::{AtomicBool, Ordering};
use crate::epoch::EpochPin;
#[cfg(test)]
mod tests;
use allocator::ArtAllocator;
pub use allocator::ArtMultiSlabAllocator;
pub use allocator::OutOfMemoryError;
/// Fixed-length key type.
///
pub trait Key: Debug {
const KEY_LEN: usize;
fn as_bytes(&self) -> &[u8];
}
/// Values stored in the tree
///
/// Values no longer need to be Clone: when a node "grows", the value is carried over to a new
/// node and the old node sticks around until all readers that might still see it are gone.
pub trait Value {}
const MAX_GARBAGE: usize = 1024;
/// The root of the tree, plus other tree-wide data. This is stored in the shared memory.
pub struct Tree<V: Value> {
/// For simplicity, so that we never need to grow or shrink the root, the root node is always an
/// Internal256 node. Also, it never has a prefix (that's actually a bit wasteful, incurring one
/// extra indirection on every lookup).
root: RootPtr<V>,
writer_attached: AtomicBool,
epoch: epoch::EpochShared,
}
unsafe impl<V: Value + Sync> Sync for Tree<V> {}
unsafe impl<V: Value + Send> Send for Tree<V> {}
struct GarbageQueue<V>(VecDeque<(NodePtr<V>, u64)>);
unsafe impl<V: Value + Sync> Sync for GarbageQueue<V> {}
unsafe impl<V: Value + Send> Send for GarbageQueue<V> {}
impl<V> GarbageQueue<V> {
fn new() -> GarbageQueue<V> {
GarbageQueue(VecDeque::with_capacity(MAX_GARBAGE))
}
fn remember_obsolete_node(&mut self, ptr: NodePtr<V>, epoch: u64) {
self.0.push_front((ptr, epoch));
}
fn next_obsolete(&mut self, cutoff_epoch: u64) -> Option<NodePtr<V>> {
if let Some(back) = self.0.back() {
if back.1 < cutoff_epoch {
return Some(self.0.pop_back().unwrap().0);
}
}
None
}
}
/// Struct created at postmaster startup
pub struct TreeInitStruct<'t, K: Key, V: Value, A: ArtAllocator<V>> {
tree: &'t Tree<V>,
allocator: &'t A,
phantom_key: PhantomData<K>,
}
/// The worker process has a reference to this. The write operations are only safe
/// from the worker process
pub struct TreeWriteAccess<'t, K: Key, V: Value, A: ArtAllocator<V>>
where
K: Key,
V: Value,
{
tree: &'t Tree<V>,
pub allocator: &'t A,
epoch_handle: epoch::LocalHandle<'t>,
phantom_key: PhantomData<K>,
/// Obsolete nodes that cannot be recycled until their epoch expires.
garbage: spin::Mutex<GarbageQueue<V>>,
}
/// The backends have a reference to this. It cannot be used to modify the tree
pub struct TreeReadAccess<'t, K: Key, V: Value>
where
K: Key,
V: Value,
{
tree: &'t Tree<V>,
epoch_handle: epoch::LocalHandle<'t>,
phantom_key: PhantomData<K>,
}
impl<'t, K: Key, V: Value, A: ArtAllocator<V>> TreeInitStruct<'t, K, V, A> {
pub fn new(allocator: &'t A) -> TreeInitStruct<'t, K, V, A> {
let tree_ptr = allocator.alloc_tree();
let tree_ptr = NonNull::new(tree_ptr).expect("out of memory");
let init = Tree {
root: algorithm::new_root(allocator).expect("out of memory"),
writer_attached: AtomicBool::new(false),
epoch: epoch::EpochShared::new(),
};
unsafe { tree_ptr.write(init) };
TreeInitStruct {
tree: unsafe { tree_ptr.as_ref() },
allocator,
phantom_key: PhantomData,
}
}
pub fn attach_writer(self) -> TreeWriteAccess<'t, K, V, A> {
let previously_attached = self.tree.writer_attached.swap(true, Ordering::Relaxed);
if previously_attached {
panic!("writer already attached");
}
TreeWriteAccess {
tree: self.tree,
allocator: self.allocator,
phantom_key: PhantomData,
epoch_handle: self.tree.epoch.register(),
garbage: spin::Mutex::new(GarbageQueue::new()),
}
}
pub fn attach_reader(self) -> TreeReadAccess<'t, K, V> {
TreeReadAccess {
tree: self.tree,
phantom_key: PhantomData,
epoch_handle: self.tree.epoch.register(),
}
}
}
impl<'t, K: Key, V: Value, A: ArtAllocator<V>> TreeWriteAccess<'t, K, V, A> {
pub fn start_write<'g>(&'t self) -> TreeWriteGuard<'g, K, V, A>
where
't: 'g,
{
TreeWriteGuard {
tree_writer: self,
epoch_pin: self.epoch_handle.pin(),
phantom_key: PhantomData,
created_garbage: false,
}
}
pub fn start_read(&'t self) -> TreeReadGuard<'t, K, V> {
TreeReadGuard {
tree: self.tree,
epoch_pin: self.epoch_handle.pin(),
phantom_key: PhantomData,
}
}
}
impl<'t, K: Key, V: Value> TreeReadAccess<'t, K, V> {
pub fn start_read(&'t self) -> TreeReadGuard<'t, K, V> {
TreeReadGuard {
tree: self.tree,
epoch_pin: self.epoch_handle.pin(),
phantom_key: PhantomData,
}
}
}
pub struct TreeReadGuard<'e, K, V>
where
K: Key,
V: Value,
{
tree: &'e Tree<V>,
epoch_pin: EpochPin<'e>,
phantom_key: PhantomData<K>,
}
impl<'e, K: Key, V: Value> TreeReadGuard<'e, K, V> {
pub fn get(&'e self, key: &K) -> Option<&'e V> {
algorithm::search(key, self.tree.root, &self.epoch_pin)
}
}
pub struct TreeWriteGuard<'e, K, V, A>
where
K: Key,
V: Value,
A: ArtAllocator<V>,
{
tree_writer: &'e TreeWriteAccess<'e, K, V, A>,
epoch_pin: EpochPin<'e>,
phantom_key: PhantomData<K>,
created_garbage: bool,
}
pub enum UpdateAction<V> {
Nothing,
Insert(V),
Remove,
}
impl<'e, K: Key, V: Value, A: ArtAllocator<V>> TreeWriteGuard<'e, K, V, A> {
/// Get a value
pub fn get(&'e mut self, key: &K) -> Option<&'e V> {
algorithm::search(key, self.tree_writer.tree.root, &self.epoch_pin)
}
/// Insert a value
pub fn insert(self, key: &K, value: V) -> Result<bool, OutOfMemoryError> {
let mut success = None;
self.update_with_fn(key, |existing| {
if existing.is_some() {
success = Some(false);
UpdateAction::Nothing
} else {
success = Some(true);
UpdateAction::Insert(value)
}
})?;
Ok(success.expect("value_fn not called"))
}
/// Remove value. Returns true if it existed
pub fn remove(self, key: &K) -> bool {
let mut result = false;
// FIXME: It's not clear whether OOM should be possible while removing. It seems
// wrong, but shrinking a node can OOM. Then again, we could opt not to shrink a
// node when we cannot allocate, to live a little longer.
self.update_with_fn(key, |existing| match existing {
Some(_) => {
result = true;
UpdateAction::Remove
}
None => UpdateAction::Nothing,
})
.expect("out of memory while removing");
result
}
/// Try to remove value and return the old value.
pub fn remove_and_return(self, key: &K) -> Option<V>
where
V: Clone,
{
let mut old = None;
self.update_with_fn(key, |existing| {
old = existing.cloned();
UpdateAction::Remove
})
.expect("out of memory while removing");
old
}
/// Update key using the given function. All the other modifying operations are based on this.
///
/// The function is passed a reference to the existing value, if any, and returns an
/// UpdateAction: Nothing leaves the tree unchanged, Insert(v) inserts v for the key, and
/// Remove deletes the existing entry (a no-op if there was none).
pub fn update_with_fn<F>(mut self, key: &K, value_fn: F) -> Result<(), OutOfMemoryError>
where
F: FnOnce(Option<&V>) -> UpdateAction<V>,
{
algorithm::update_fn(key, value_fn, self.tree_writer.tree.root, &mut self)?;
if self.created_garbage {
let _ = self.collect_garbage();
}
Ok(())
}
fn remember_obsolete_node(&mut self, ptr: NodePtr<V>) {
self.tree_writer
.garbage
.lock()
.remember_obsolete_node(ptr, self.epoch_pin.epoch);
self.created_garbage = true;
}
// returns number of nodes recycled
fn collect_garbage(&self) -> usize {
self.tree_writer.tree.epoch.advance();
self.tree_writer.tree.epoch.broadcast();
let cutoff_epoch = self.tree_writer.tree.epoch.get_oldest();
let mut result = 0;
let mut garbage_queue = self.tree_writer.garbage.lock();
while let Some(ptr) = garbage_queue.next_obsolete(cutoff_epoch) {
ptr.deallocate(self.tree_writer.allocator);
result += 1;
}
result
}
}
pub struct TreeIterator<K>
where
K: Key + for<'a> From<&'a [u8]>,
{
done: bool,
pub next_key: Vec<u8>,
max_key: Option<Vec<u8>>,
phantom_key: PhantomData<K>,
}
impl<K> TreeIterator<K>
where
K: Key + for<'a> From<&'a [u8]>,
{
pub fn new_wrapping() -> TreeIterator<K> {
TreeIterator {
done: false,
next_key: vec![0; K::KEY_LEN],
max_key: None,
phantom_key: PhantomData,
}
}
pub fn new(range: &std::ops::Range<K>) -> TreeIterator<K> {
let result = TreeIterator {
done: false,
next_key: Vec::from(range.start.as_bytes()),
max_key: Some(Vec::from(range.end.as_bytes())),
phantom_key: PhantomData,
};
assert_eq!(result.next_key.len(), K::KEY_LEN);
assert_eq!(result.max_key.as_ref().unwrap().len(), K::KEY_LEN);
result
}
pub fn next<'g, V>(&mut self, read_guard: &'g TreeReadGuard<'g, K, V>) -> Option<(K, &'g V)>
where
V: Value,
{
if self.done {
return None;
}
let mut wrapped_around = false;
loop {
assert_eq!(self.next_key.len(), K::KEY_LEN);
if let Some((k, v)) =
algorithm::iter_next(&self.next_key, read_guard.tree.root, &read_guard.epoch_pin)
{
assert_eq!(k.len(), K::KEY_LEN);
assert_eq!(self.next_key.len(), K::KEY_LEN);
// Check if we reached the end of the range
if let Some(max_key) = &self.max_key {
if k.as_slice() >= max_key.as_slice() {
self.done = true;
break None;
}
}
// increment the key
self.next_key = k.clone();
increment_key(self.next_key.as_mut_slice());
let k = k.as_slice().into();
break Some((k, v));
} else {
if self.max_key.is_some() {
self.done = true;
} else {
// Start from beginning
if !wrapped_around {
for i in 0..K::KEY_LEN {
self.next_key[i] = 0;
}
wrapped_around = true;
continue;
} else {
// The tree is completely empty
// FIXME: perhaps we should remember the starting point instead.
// Currently this will scan some ranges twice.
break None;
}
}
break None;
}
}
}
}
fn increment_key(key: &mut [u8]) -> bool {
for i in (0..key.len()).rev() {
let (byte, overflow) = key[i].overflowing_add(1);
key[i] = byte;
if !overflow {
return false;
}
}
true
}
// Debugging functions
impl<'e, K: Key, V: Value + Debug, A: ArtAllocator<V>> TreeWriteGuard<'e, K, V, A> {
pub fn dump(&mut self, dst: &mut dyn std::io::Write) {
algorithm::dump_tree(self.tree_writer.tree.root, &self.epoch_pin, dst)
}
}
impl<'e, K: Key, V: Value + Debug> TreeReadGuard<'e, K, V> {
pub fn dump(&mut self, dst: &mut dyn std::io::Write) {
algorithm::dump_tree(self.tree.root, &self.epoch_pin, dst)
}
}
impl<'e, K: Key, V: Value> TreeWriteAccess<'e, K, V, ArtMultiSlabAllocator<'e, V>> {
pub fn get_statistics(&self) -> ArtTreeStatistics {
self.allocator.get_statistics();
ArtTreeStatistics {
blocks: self.allocator.inner.block_allocator.get_statistics(),
slabs: self.allocator.get_statistics(),
epoch: self.tree.epoch.get_current(),
oldest_epoch: self.tree.epoch.get_oldest(),
num_garbage: self.garbage.lock().0.len() as u64,
}
}
}
#[derive(Clone, Debug)]
pub struct ArtTreeStatistics {
pub blocks: allocator::block::BlockAllocatorStats,
pub slabs: allocator::ArtMultiSlabStats,
pub epoch: u64,
pub oldest_epoch: u64,
pub num_garbage: u64,
}

View File

@@ -1,236 +0,0 @@
use std::collections::BTreeMap;
use std::collections::HashSet;
use std::fmt::{Debug, Formatter};
use std::sync::atomic::{AtomicUsize, Ordering};
use crate::ArtAllocator;
use crate::ArtMultiSlabAllocator;
use crate::TreeInitStruct;
use crate::TreeIterator;
use crate::TreeWriteAccess;
use crate::UpdateAction;
use crate::{Key, Value};
use rand::Rng;
use rand::seq::SliceRandom;
use rand_distr::Zipf;
const TEST_KEY_LEN: usize = 16;
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
struct TestKey([u8; TEST_KEY_LEN]);
impl TestKey {
const MIN: TestKey = TestKey([0; TEST_KEY_LEN]);
const MAX: TestKey = TestKey([u8::MAX; TEST_KEY_LEN]);
}
impl Key for TestKey {
const KEY_LEN: usize = TEST_KEY_LEN;
fn as_bytes(&self) -> &[u8] {
&self.0
}
}
impl From<&TestKey> for u128 {
fn from(val: &TestKey) -> u128 {
u128::from_be_bytes(val.0)
}
}
impl From<u128> for TestKey {
fn from(val: u128) -> TestKey {
TestKey(val.to_be_bytes())
}
}
impl<'a> From<&'a [u8]> for TestKey {
fn from(bytes: &'a [u8]) -> TestKey {
TestKey(bytes.try_into().unwrap())
}
}
impl Value for usize {}
fn test_inserts<K: Into<TestKey> + Copy>(keys: &[K]) {
const MEM_SIZE: usize = 10000000;
let mut area = Box::new_uninit_slice(MEM_SIZE);
let allocator = ArtMultiSlabAllocator::new(&mut area);
let init_struct = TreeInitStruct::<TestKey, usize, _>::new(allocator);
let tree_writer = init_struct.attach_writer();
for (idx, k) in keys.iter().enumerate() {
let w = tree_writer.start_write();
let res = w.insert(&(*k).into(), idx);
assert!(res.is_ok());
}
for (idx, k) in keys.iter().enumerate() {
let r = tree_writer.start_read();
let value = r.get(&(*k).into());
assert_eq!(value, Some(idx).as_ref());
}
eprintln!("stats: {:?}", tree_writer.get_statistics());
}
#[test]
fn dense() {
// This exercises splitting a node with prefix
let keys: &[u128] = &[0, 1, 2, 3, 256];
test_inserts(keys);
// Dense keys
let mut keys: Vec<u128> = (0..10000).collect();
test_inserts(&keys);
// Do the same in random orders
for _ in 1..10 {
keys.shuffle(&mut rand::rng());
test_inserts(&keys);
}
}
#[test]
fn sparse() {
// sparse keys
let mut keys: Vec<TestKey> = Vec::new();
let mut used_keys = HashSet::new();
for _ in 0..10000 {
loop {
let key = rand::random::<u128>();
if used_keys.contains(&key) {
continue;
}
used_keys.insert(key);
keys.push(key.into());
break;
}
}
test_inserts(&keys);
}
struct TestValue(AtomicUsize);
impl TestValue {
fn new(val: usize) -> TestValue {
TestValue(AtomicUsize::new(val))
}
fn load(&self) -> usize {
self.0.load(Ordering::Relaxed)
}
}
impl Value for TestValue {}
impl Clone for TestValue {
fn clone(&self) -> TestValue {
TestValue::new(self.load())
}
}
impl Debug for TestValue {
fn fmt(&self, fmt: &mut Formatter<'_>) -> Result<(), std::fmt::Error> {
write!(fmt, "{:?}", self.load())
}
}
#[derive(Clone, Debug)]
struct TestOp(TestKey, Option<usize>);
fn apply_op<A: ArtAllocator<TestValue>>(
op: &TestOp,
tree: &TreeWriteAccess<TestKey, TestValue, A>,
shadow: &mut BTreeMap<TestKey, usize>,
) {
eprintln!("applying op: {op:?}");
// apply the change to the shadow tree first
let shadow_existing = if let Some(v) = op.1 {
shadow.insert(op.0, v)
} else {
shadow.remove(&op.0)
};
// apply to Art tree
let w = tree.start_write();
w.update_with_fn(&op.0, |existing| {
assert_eq!(existing.map(TestValue::load), shadow_existing);
match (existing, op.1) {
(None, None) => UpdateAction::Nothing,
(None, Some(new_val)) => UpdateAction::Insert(TestValue::new(new_val)),
(Some(_old_val), None) => UpdateAction::Remove,
(Some(old_val), Some(new_val)) => {
old_val.0.store(new_val, Ordering::Relaxed);
UpdateAction::Nothing
}
}
})
.expect("out of memory");
}
fn test_iter<A: ArtAllocator<TestValue>>(
tree: &TreeWriteAccess<TestKey, TestValue, A>,
shadow: &BTreeMap<TestKey, usize>,
) {
let mut shadow_iter = shadow.iter();
let mut iter = TreeIterator::new(&(TestKey::MIN..TestKey::MAX));
loop {
let shadow_item = shadow_iter.next().map(|(k, v)| (*k, *v));
let r = tree.start_read();
let item = iter.next(&r);
if shadow_item != item.map(|(k, v)| (k, v.load())) {
eprintln!("FAIL: iterator returned {item:?}, expected {shadow_item:?}");
tree.start_read().dump(&mut std::io::stderr());
eprintln!("SHADOW:");
for si in shadow {
eprintln!("key: {:?}, val: {}", si.0, si.1);
}
panic!("FAIL: iterator returned {item:?}, expected {shadow_item:?}");
}
if item.is_none() {
break;
}
}
}
#[test]
fn random_ops() {
const MEM_SIZE: usize = 10000000;
let mut area = Box::new_uninit_slice(MEM_SIZE);
let allocator = ArtMultiSlabAllocator::new(&mut area);
let init_struct = TreeInitStruct::<TestKey, TestValue, _>::new(allocator);
let tree_writer = init_struct.attach_writer();
let mut shadow: std::collections::BTreeMap<TestKey, usize> = BTreeMap::new();
let distribution = Zipf::new(u128::MAX as f64, 1.1).unwrap();
let mut rng = rand::rng();
for i in 0..100000 {
let mut key: TestKey = (rng.sample(distribution) as u128).into();
if rng.random_bool(0.10) {
key = TestKey::from(u128::from(&key) | 0xffffffff);
}
let op = TestOp(key, if rng.random_bool(0.75) { Some(i) } else { None });
apply_op(&op, &tree_writer, &mut shadow);
if i % 1000 == 0 {
eprintln!("{i} ops processed");
eprintln!("stats: {:?}", tree_writer.get_statistics());
test_iter(&tree_writer, &shadow);
}
}
}

View File

@@ -394,7 +394,7 @@ impl From<&OtelExporterConfig> for tracing_utils::ExportConfig {
tracing_utils::ExportConfig {
endpoint: Some(val.endpoint.clone()),
protocol: val.protocol.into(),
timeout: val.timeout,
timeout: Some(val.timeout),
}
}
}

View File

@@ -596,6 +596,7 @@ pub struct TimelineImportRequest {
pub timeline_id: TimelineId,
pub start_lsn: Lsn,
pub sk_set: Vec<NodeId>,
pub force_upsert: bool,
}
#[derive(serde::Serialize, serde::Deserialize, Clone)]

View File

@@ -981,12 +981,12 @@ mod tests {
let mut rng = rand::rngs::StdRng::seed_from_u64(42);
let key = Key {
field1: rng.r#gen(),
field2: rng.r#gen(),
field3: rng.r#gen(),
field4: rng.r#gen(),
field5: rng.r#gen(),
field6: rng.r#gen(),
field1: rng.random(),
field2: rng.random(),
field3: rng.random(),
field4: rng.random(),
field5: rng.random(),
field6: rng.random(),
};
assert_eq!(key, Key::from_str(&format!("{key}")).unwrap());

View File

@@ -443,9 +443,9 @@ pub struct ImportPgdataIdempotencyKey(pub String);
impl ImportPgdataIdempotencyKey {
pub fn random() -> Self {
use rand::Rng;
use rand::distributions::Alphanumeric;
use rand::distr::Alphanumeric;
Self(
rand::thread_rng()
rand::rng()
.sample_iter(&Alphanumeric)
.take(20)
.map(char::from)

View File

@@ -69,22 +69,6 @@ impl Hash for ShardIdentity {
}
}
/// Stripe size in number of pages
#[derive(Clone, Copy, Serialize, Deserialize, Eq, PartialEq, Debug)]
pub struct ShardStripeSize(pub u32);
impl Default for ShardStripeSize {
fn default() -> Self {
DEFAULT_STRIPE_SIZE
}
}
impl std::fmt::Display for ShardStripeSize {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
self.0.fmt(f)
}
}
/// Layout version: for future upgrades where we might change how the key->shard mapping works
#[derive(Clone, Copy, Serialize, Deserialize, Eq, PartialEq, Hash, Debug)]
pub struct ShardLayout(u8);

View File

@@ -21,6 +21,14 @@ pub struct ReAttachRequest {
/// if the node already has a node_id set.
#[serde(skip_serializing_if = "Option::is_none", default)]
pub register: Option<NodeRegisterRequest>,
/// Hadron: Optional flag to indicate whether the node is starting with an empty local disk.
    /// Set to true if the node couldn't find any local tenant data on startup, which can happen
    /// when the node starts for the first time or after a local SSD failure/disk wipe event.
    /// The storage controller may use the flag to update its observed state of the world and
    /// make sure it sends explicit location_config calls to the node following the re-attach
    /// request.
pub empty_local_disk: Option<bool>,
}
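
A minimal sketch of how a pageserver might populate the new field, mirroring the call site changed later in this commit; it assumes `ReAttachRequest` and `NodeId` are in scope, and the node id is arbitrary:

```rust
fn example_re_attach_request(found_local_tenant_data: bool) -> ReAttachRequest {
    ReAttachRequest {
        node_id: NodeId(1), // arbitrary
        register: None,
        // true => the node booted with an empty local disk (first boot or disk wipe),
        // so the storage controller should refresh its observed state for this node.
        empty_local_disk: Some(!found_local_tenant_data),
    }
}
```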
#[derive(Serialize, Deserialize, Debug)]

View File

@@ -203,12 +203,12 @@ impl fmt::Display for CancelKeyData {
}
}
use rand::distributions::{Distribution, Standard};
impl Distribution<CancelKeyData> for Standard {
use rand::distr::{Distribution, StandardUniform};
impl Distribution<CancelKeyData> for StandardUniform {
fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> CancelKeyData {
CancelKeyData {
backend_pid: rng.r#gen(),
cancel_key: rng.r#gen(),
backend_pid: rng.random(),
cancel_key: rng.random(),
}
}
}
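
The same rand 0.8 to 0.9 renames recur throughout this commit; a minimal sketch collecting them in one place (values are arbitrary):

```rust
use rand::Rng;
use rand::distr::{Alphanumeric, StandardUniform};

fn rand_09_renames() {
    let mut rng = rand::rng();                // was rand::thread_rng()
    let a: u32 = rng.random();                // was rng.r#gen()
    let b = rng.random_range(0..=100);        // was rng.gen_range(0..=100)
    let c = rng.random_bool(0.5);             // was rng.gen_bool(0.5)
    let d: u64 = rng.sample(StandardUniform); // StandardUniform was distributions::Standard
    let s: String = rand::rng()
        .sample_iter(&Alphanumeric)           // rand::distr was rand::distributions
        .take(8)
        .map(char::from)
        .collect();
    let _ = (a, b, c, d, s);
}
```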

View File

@@ -155,10 +155,10 @@ pub struct ScramSha256 {
fn nonce() -> String {
// rand 0.5's ThreadRng is cryptographically secure
let mut rng = rand::thread_rng();
let mut rng = rand::rng();
(0..NONCE_LENGTH)
.map(|_| {
let mut v = rng.gen_range(0x21u8..0x7e);
let mut v = rng.random_range(0x21u8..0x7e);
if v == 0x2c {
v = 0x7e
}

View File

@@ -74,7 +74,6 @@ impl Header {
}
/// An enum representing Postgres backend messages.
#[non_exhaustive]
pub enum Message {
AuthenticationCleartextPassword,
AuthenticationGss,
@@ -145,16 +144,7 @@ impl Message {
PARSE_COMPLETE_TAG => Message::ParseComplete,
BIND_COMPLETE_TAG => Message::BindComplete,
CLOSE_COMPLETE_TAG => Message::CloseComplete,
NOTIFICATION_RESPONSE_TAG => {
let process_id = buf.read_i32::<BigEndian>()?;
let channel = buf.read_cstr()?;
let message = buf.read_cstr()?;
Message::NotificationResponse(NotificationResponseBody {
process_id,
channel,
message,
})
}
NOTIFICATION_RESPONSE_TAG => Message::NotificationResponse(NotificationResponseBody {}),
COPY_DONE_TAG => Message::CopyDone,
COMMAND_COMPLETE_TAG => {
let tag = buf.read_cstr()?;
@@ -543,28 +533,7 @@ impl NoticeResponseBody {
}
}
pub struct NotificationResponseBody {
process_id: i32,
channel: Bytes,
message: Bytes,
}
impl NotificationResponseBody {
#[inline]
pub fn process_id(&self) -> i32 {
self.process_id
}
#[inline]
pub fn channel(&self) -> io::Result<&str> {
get_str(&self.channel)
}
#[inline]
pub fn message(&self) -> io::Result<&str> {
get_str(&self.message)
}
}
pub struct NotificationResponseBody {}
pub struct ParameterDescriptionBody {
storage: Bytes,

View File

@@ -28,7 +28,7 @@ const SCRAM_DEFAULT_SALT_LEN: usize = 16;
/// special characters that would require escaping in an SQL command.
pub async fn scram_sha_256(password: &[u8]) -> String {
let mut salt: [u8; SCRAM_DEFAULT_SALT_LEN] = [0; SCRAM_DEFAULT_SALT_LEN];
let mut rng = rand::thread_rng();
let mut rng = rand::rng();
rng.fill_bytes(&mut salt);
scram_sha_256_salt(password, salt).await
}

View File

@@ -13,7 +13,7 @@ use serde::{Deserialize, Serialize};
use tokio::sync::mpsc;
use crate::cancel_token::RawCancelToken;
use crate::codec::{BackendMessages, FrontendMessage};
use crate::codec::{BackendMessages, FrontendMessage, RecordNotices};
use crate::config::{Host, SslMode};
use crate::query::RowStream;
use crate::simple_query::SimpleQueryStream;
@@ -221,6 +221,18 @@ impl Client {
&mut self.inner
}
pub fn record_notices(&mut self, limit: usize) -> mpsc::UnboundedReceiver<Box<str>> {
let (tx, rx) = mpsc::unbounded_channel();
let notices = RecordNotices { sender: tx, limit };
self.inner
.sender
.send(FrontendMessage::RecordNotices(notices))
.ok();
rx
}
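
A minimal sketch of the caller side, assuming a `Client` in scope; the 16 KiB budget is arbitrary:

```rust
// Record up to 16 KiB of notice message text. Notices whose message no longer fits in
// the remaining budget are dropped, and recording stops if the receiver is dropped.
fn capture_notices(client: &mut Client) -> tokio::sync::mpsc::UnboundedReceiver<Box<str>> {
    client.record_notices(16 * 1024)
}

// After running queries on `client`, drain whatever was recorded:
// while let Ok(msg) = notices.try_recv() { tracing::info!("notice: {msg}"); }
```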
/// Pass text directly to the Postgres backend to allow it to sort out typing itself and
/// to save a roundtrip
pub async fn query_raw_txt<S, I>(

View File

@@ -3,10 +3,17 @@ use std::io;
use bytes::{Bytes, BytesMut};
use fallible_iterator::FallibleIterator;
use postgres_protocol2::message::backend;
use tokio::sync::mpsc::UnboundedSender;
use tokio_util::codec::{Decoder, Encoder};
pub enum FrontendMessage {
Raw(Bytes),
RecordNotices(RecordNotices),
}
pub struct RecordNotices {
pub sender: UnboundedSender<Box<str>>,
pub limit: usize,
}
pub enum BackendMessage {
@@ -33,14 +40,11 @@ impl FallibleIterator for BackendMessages {
pub struct PostgresCodec;
impl Encoder<FrontendMessage> for PostgresCodec {
impl Encoder<Bytes> for PostgresCodec {
type Error = io::Error;
fn encode(&mut self, item: FrontendMessage, dst: &mut BytesMut) -> io::Result<()> {
match item {
FrontendMessage::Raw(buf) => dst.extend_from_slice(&buf),
}
fn encode(&mut self, item: Bytes, dst: &mut BytesMut) -> io::Result<()> {
dst.extend_from_slice(&item);
Ok(())
}
}

View File

@@ -1,11 +1,9 @@
use std::net::IpAddr;
use postgres_protocol2::message::backend::Message;
use tokio::net::TcpStream;
use tokio::sync::mpsc;
use crate::client::SocketConfig;
use crate::codec::BackendMessage;
use crate::config::Host;
use crate::connect_raw::connect_raw;
use crate::connect_socket::connect_socket;
@@ -48,8 +46,8 @@ where
let stream = connect_tls(socket, config.ssl_mode, tls).await?;
let RawConnection {
stream,
parameters,
delayed_notice,
parameters: _,
delayed_notice: _,
process_id,
secret_key,
} = connect_raw(stream, config).await?;
@@ -72,13 +70,7 @@ where
secret_key,
);
// delayed notices are always sent as "Async" messages.
let delayed = delayed_notice
.into_iter()
.map(|m| BackendMessage::Async(Message::NoticeResponse(m)))
.collect();
let connection = Connection::new(stream, delayed, parameters, conn_tx, conn_rx);
let connection = Connection::new(stream, conn_tx, conn_rx);
Ok((client, connection))
}

View File

@@ -3,7 +3,7 @@ use std::io;
use std::pin::Pin;
use std::task::{Context, Poll};
use bytes::BytesMut;
use bytes::{Bytes, BytesMut};
use fallible_iterator::FallibleIterator;
use futures_util::{Sink, SinkExt, Stream, TryStreamExt, ready};
use postgres_protocol2::authentication::sasl;
@@ -14,7 +14,7 @@ use tokio::io::{AsyncRead, AsyncWrite};
use tokio_util::codec::Framed;
use crate::Error;
use crate::codec::{BackendMessage, BackendMessages, FrontendMessage, PostgresCodec};
use crate::codec::{BackendMessage, BackendMessages, PostgresCodec};
use crate::config::{self, AuthKeys, Config};
use crate::maybe_tls_stream::MaybeTlsStream;
use crate::tls::TlsStream;
@@ -25,7 +25,7 @@ pub struct StartupStream<S, T> {
delayed_notice: Vec<NoticeResponseBody>,
}
impl<S, T> Sink<FrontendMessage> for StartupStream<S, T>
impl<S, T> Sink<Bytes> for StartupStream<S, T>
where
S: AsyncRead + AsyncWrite + Unpin,
T: AsyncRead + AsyncWrite + Unpin,
@@ -36,7 +36,7 @@ where
Pin::new(&mut self.inner).poll_ready(cx)
}
fn start_send(mut self: Pin<&mut Self>, item: FrontendMessage) -> io::Result<()> {
fn start_send(mut self: Pin<&mut Self>, item: Bytes) -> io::Result<()> {
Pin::new(&mut self.inner).start_send(item)
}
@@ -120,10 +120,7 @@ where
let mut buf = BytesMut::new();
frontend::startup_message(&config.server_params, &mut buf).map_err(Error::encode)?;
stream
.send(FrontendMessage::Raw(buf.freeze()))
.await
.map_err(Error::io)
stream.send(buf.freeze()).await.map_err(Error::io)
}
async fn authenticate<S, T>(stream: &mut StartupStream<S, T>, config: &Config) -> Result<(), Error>
@@ -191,10 +188,7 @@ where
let mut buf = BytesMut::new();
frontend::password_message(password, &mut buf).map_err(Error::encode)?;
stream
.send(FrontendMessage::Raw(buf.freeze()))
.await
.map_err(Error::io)
stream.send(buf.freeze()).await.map_err(Error::io)
}
async fn authenticate_sasl<S, T>(
@@ -253,10 +247,7 @@ where
let mut buf = BytesMut::new();
frontend::sasl_initial_response(mechanism, scram.message(), &mut buf).map_err(Error::encode)?;
stream
.send(FrontendMessage::Raw(buf.freeze()))
.await
.map_err(Error::io)?;
stream.send(buf.freeze()).await.map_err(Error::io)?;
let body = match stream.try_next().await.map_err(Error::io)? {
Some(Message::AuthenticationSaslContinue(body)) => body,
@@ -272,10 +263,7 @@ where
let mut buf = BytesMut::new();
frontend::sasl_response(scram.message(), &mut buf).map_err(Error::encode)?;
stream
.send(FrontendMessage::Raw(buf.freeze()))
.await
.map_err(Error::io)?;
stream.send(buf.freeze()).await.map_err(Error::io)?;
let body = match stream.try_next().await.map_err(Error::io)? {
Some(Message::AuthenticationSaslFinal(body)) => body,

View File

@@ -1,22 +1,23 @@
use std::collections::{HashMap, VecDeque};
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use bytes::BytesMut;
use futures_util::{Sink, Stream, ready};
use postgres_protocol2::message::backend::Message;
use fallible_iterator::FallibleIterator;
use futures_util::{Sink, StreamExt, ready};
use postgres_protocol2::message::backend::{Message, NoticeResponseBody};
use postgres_protocol2::message::frontend;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::sync::mpsc;
use tokio_util::codec::Framed;
use tokio_util::sync::PollSender;
use tracing::{info, trace};
use tracing::trace;
use crate::codec::{BackendMessage, BackendMessages, FrontendMessage, PostgresCodec};
use crate::error::DbError;
use crate::Error;
use crate::codec::{
BackendMessage, BackendMessages, FrontendMessage, PostgresCodec, RecordNotices,
};
use crate::maybe_tls_stream::MaybeTlsStream;
use crate::{AsyncMessage, Error, Notification};
#[derive(PartialEq, Debug)]
enum State {
@@ -33,18 +34,18 @@ enum State {
/// occurred, or because its associated `Client` has dropped and all outstanding work has completed.
#[must_use = "futures do nothing unless polled"]
pub struct Connection<S, T> {
/// HACK: we need this in the Neon Proxy.
pub stream: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
/// HACK: we need this in the Neon Proxy to forward params.
pub parameters: HashMap<String, String>,
stream: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
sender: PollSender<BackendMessages>,
receiver: mpsc::UnboundedReceiver<FrontendMessage>,
notices: Option<RecordNotices>,
pending_responses: VecDeque<BackendMessage>,
pending_response: Option<BackendMessages>,
state: State,
}
pub enum Never {}
impl<S, T> Connection<S, T>
where
S: AsyncRead + AsyncWrite + Unpin,
@@ -52,70 +53,42 @@ where
{
pub(crate) fn new(
stream: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
pending_responses: VecDeque<BackendMessage>,
parameters: HashMap<String, String>,
sender: mpsc::Sender<BackendMessages>,
receiver: mpsc::UnboundedReceiver<FrontendMessage>,
) -> Connection<S, T> {
Connection {
stream,
parameters,
sender: PollSender::new(sender),
receiver,
pending_responses,
notices: None,
pending_response: None,
state: State::Active,
}
}
fn poll_response(
&mut self,
cx: &mut Context<'_>,
) -> Poll<Option<Result<BackendMessage, Error>>> {
if let Some(message) = self.pending_responses.pop_front() {
trace!("retrying pending response");
return Poll::Ready(Some(Ok(message)));
}
Pin::new(&mut self.stream)
.poll_next(cx)
.map(|o| o.map(|r| r.map_err(Error::io)))
}
/// Read and process messages from the connection to postgres.
/// client <- postgres
fn poll_read(&mut self, cx: &mut Context<'_>) -> Poll<Result<AsyncMessage, Error>> {
fn poll_read(&mut self, cx: &mut Context<'_>) -> Poll<Result<Never, Error>> {
loop {
let message = match self.poll_response(cx)? {
Poll::Ready(Some(message)) => message,
Poll::Ready(None) => return Poll::Ready(Err(Error::closed())),
Poll::Pending => {
trace!("poll_read: waiting on response");
return Poll::Pending;
}
};
let messages = match message {
BackendMessage::Async(Message::NoticeResponse(body)) => {
let error = DbError::parse(&mut body.fields()).map_err(Error::parse)?;
return Poll::Ready(Ok(AsyncMessage::Notice(error)));
}
BackendMessage::Async(Message::NotificationResponse(body)) => {
let notification = Notification {
process_id: body.process_id(),
channel: body.channel().map_err(Error::parse)?.to_string(),
payload: body.message().map_err(Error::parse)?.to_string(),
let messages = match self.pending_response.take() {
Some(messages) => messages,
None => {
let message = match self.stream.poll_next_unpin(cx) {
Poll::Pending => return Poll::Pending,
Poll::Ready(None) => return Poll::Ready(Err(Error::closed())),
Poll::Ready(Some(Err(e))) => return Poll::Ready(Err(Error::io(e))),
Poll::Ready(Some(Ok(message))) => message,
};
return Poll::Ready(Ok(AsyncMessage::Notification(notification)));
match message {
BackendMessage::Async(Message::NoticeResponse(body)) => {
self.handle_notice(body)?;
continue;
}
BackendMessage::Async(_) => continue,
BackendMessage::Normal { messages } => messages,
}
}
BackendMessage::Async(Message::ParameterStatus(body)) => {
self.parameters.insert(
body.name().map_err(Error::parse)?.to_string(),
body.value().map_err(Error::parse)?.to_string(),
);
continue;
}
BackendMessage::Async(_) => unreachable!(),
BackendMessage::Normal { messages } => messages,
};
match self.sender.poll_reserve(cx) {
@@ -126,8 +99,7 @@ where
return Poll::Ready(Err(Error::closed()));
}
Poll::Pending => {
self.pending_responses
.push_back(BackendMessage::Normal { messages });
self.pending_response = Some(messages);
trace!("poll_read: waiting on sender");
return Poll::Pending;
}
@@ -135,6 +107,31 @@ where
}
}
fn handle_notice(&mut self, body: NoticeResponseBody) -> Result<(), Error> {
let Some(notices) = &mut self.notices else {
return Ok(());
};
let mut fields = body.fields();
while let Some(field) = fields.next().map_err(Error::parse)? {
// loop until we find the message field
if field.type_() == b'M' {
// if the message field is within the limit, send it.
if let Some(new_limit) = notices.limit.checked_sub(field.value().len()) {
match notices.sender.send(field.value().into()) {
// set the new limit.
Ok(()) => notices.limit = new_limit,
// closed.
Err(_) => self.notices = None,
}
}
break;
}
}
Ok(())
}
/// Fetch the next client request and enqueue the response sender.
fn poll_request(&mut self, cx: &mut Context<'_>) -> Poll<Option<FrontendMessage>> {
if self.receiver.is_closed() {
@@ -168,21 +165,23 @@ where
match self.poll_request(cx) {
// send the message to postgres
Poll::Ready(Some(request)) => {
Poll::Ready(Some(FrontendMessage::Raw(request))) => {
Pin::new(&mut self.stream)
.start_send(request)
.map_err(Error::io)?;
}
Poll::Ready(Some(FrontendMessage::RecordNotices(notices))) => {
self.notices = Some(notices)
}
// No more messages from the client, and no more responses to wait for.
// Send a terminate message to postgres
Poll::Ready(None) => {
trace!("poll_write: at eof, terminating");
let mut request = BytesMut::new();
frontend::terminate(&mut request);
let request = FrontendMessage::Raw(request.freeze());
Pin::new(&mut self.stream)
.start_send(request)
.start_send(request.freeze())
.map_err(Error::io)?;
trace!("poll_write: sent eof, closing");
@@ -231,34 +230,17 @@ where
}
}
/// Returns the value of a runtime parameter for this connection.
pub fn parameter(&self, name: &str) -> Option<&str> {
self.parameters.get(name).map(|s| &**s)
}
/// Polls for asynchronous messages from the server.
///
/// The server can send notices as well as notifications asynchronously to the client. Applications that wish to
/// examine those messages should use this method to drive the connection rather than its `Future` implementation.
pub fn poll_message(
&mut self,
cx: &mut Context<'_>,
) -> Poll<Option<Result<AsyncMessage, Error>>> {
fn poll_message(&mut self, cx: &mut Context<'_>) -> Poll<Option<Result<Never, Error>>> {
if self.state != State::Closing {
// if the state is still active, try read from and write to postgres.
let message = self.poll_read(cx)?;
let closing = self.poll_write(cx)?;
if let Poll::Ready(()) = closing {
let Poll::Pending = self.poll_read(cx)?;
if self.poll_write(cx)?.is_ready() {
self.state = State::Closing;
}
if let Poll::Ready(message) = message {
return Poll::Ready(Some(Ok(message)));
}
// poll_read returned Pending.
// poll_write returned Pending or Ready(WriteReady::WaitingOnRead).
// if poll_write returned Ready(WriteReady::WaitingOnRead), then we are waiting to read more data from postgres.
// poll_write returned Pending or Ready(()).
// if poll_write returned Ready(()), then we are waiting to read more data from postgres.
if self.state != State::Closing {
return Poll::Pending;
}
@@ -280,11 +262,9 @@ where
type Output = Result<(), Error>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
while let Some(message) = ready!(self.poll_message(cx)?) {
if let AsyncMessage::Notice(notice) = message {
info!("{}: {}", notice.severity(), notice.message());
}
match self.poll_message(cx)? {
Poll::Ready(None) => Poll::Ready(Ok(())),
Poll::Pending => Poll::Pending,
}
Poll::Ready(Ok(()))
}
}

View File

@@ -8,7 +8,6 @@ pub use crate::client::{Client, SocketConfig};
pub use crate::config::Config;
pub use crate::connect_raw::RawConnection;
pub use crate::connection::Connection;
use crate::error::DbError;
pub use crate::error::Error;
pub use crate::generic_client::GenericClient;
pub use crate::query::RowStream;
@@ -93,21 +92,6 @@ impl Notification {
}
}
/// An asynchronous message from the server.
#[allow(clippy::large_enum_variant)]
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum AsyncMessage {
/// A notice.
///
/// Notices use the same format as errors, but aren't "errors" per-se.
Notice(DbError),
/// A notification.
///
/// Connections can subscribe to notifications with the `LISTEN` command.
Notification(Notification),
}
/// Message returned by the `SimpleQuery` stream.
#[derive(Debug)]
#[non_exhaustive]

View File

@@ -43,7 +43,7 @@ itertools.workspace = true
sync_wrapper = { workspace = true, features = ["futures"] }
byteorder = "1.4"
rand = "0.8.5"
rand.workspace = true
[dev-dependencies]
camino-tempfile.workspace = true

View File

@@ -81,7 +81,7 @@ impl UnreliableWrapper {
///
fn attempt(&self, op: RemoteOp) -> anyhow::Result<u64> {
let mut attempts = self.attempts.lock().unwrap();
let mut rng = rand::thread_rng();
let mut rng = rand::rng();
match attempts.entry(op) {
Entry::Occupied(mut e) => {
@@ -94,7 +94,7 @@ impl UnreliableWrapper {
/* BEGIN_HADRON */
// If there are more attempts to fail, fail the request by probability.
if (attempts_before_this < self.attempts_to_fail)
&& (rng.gen_range(0..=100) < self.attempt_failure_probability)
&& (rng.random_range(0..=100) < self.attempt_failure_probability)
{
let error =
anyhow::anyhow!("simulated failure of remote operation {:?}", e.key());

View File

@@ -208,7 +208,7 @@ async fn create_azure_client(
.as_millis();
    // because nanos can be the same for two threads, and so can millis, add randomness
let random = rand::thread_rng().r#gen::<u32>();
let random = rand::rng().random::<u32>();
let remote_storage_config = RemoteStorageConfig {
storage: RemoteStorageKind::AzureContainer(AzureConfig {

View File

@@ -385,7 +385,7 @@ async fn create_s3_client(
.as_millis();
    // because nanos can be the same for two threads, and so can millis, add randomness
let random = rand::thread_rng().r#gen::<u32>();
let random = rand::rng().random::<u32>();
let remote_storage_config = RemoteStorageConfig {
storage: RemoteStorageKind::AwsS3(S3Config {

View File

@@ -8,7 +8,7 @@ license.workspace = true
hyper0.workspace = true
opentelemetry = { workspace = true, features = ["trace"] }
opentelemetry_sdk = { workspace = true, features = ["rt-tokio"] }
opentelemetry-otlp = { workspace = true, default-features = false, features = ["http-proto", "trace", "http", "reqwest-client"] }
opentelemetry-otlp = { workspace = true, default-features = false, features = ["http-proto", "trace", "http", "reqwest-blocking-client"] }
opentelemetry-semantic-conventions.workspace = true
tokio = { workspace = true, features = ["rt", "rt-multi-thread"] }
tracing.workspace = true

View File

@@ -1,11 +1,5 @@
//! Helper functions to set up OpenTelemetry tracing.
//!
//! This comes in two variants, depending on whether you have a Tokio runtime available.
//! If you do, call `init_tracing()`. It sets up the trace processor and exporter to use
//! the current tokio runtime. If you don't have a runtime available, or you don't want
//! to share the runtime with the tracing tasks, call `init_tracing_without_runtime()`
//! instead. It sets up a dedicated single-threaded Tokio runtime for the tracing tasks.
//!
//! Example:
//!
//! ```rust,no_run
@@ -21,7 +15,8 @@
//! .with_writer(std::io::stderr);
//!
//! // Initialize OpenTelemetry. Exports tracing spans as OpenTelemetry traces
//! let otlp_layer = tracing_utils::init_tracing("my_application", tracing_utils::ExportConfig::default()).await;
//! let provider = tracing_utils::init_tracing("my_application", tracing_utils::ExportConfig::default());
//! let otlp_layer = provider.as_ref().map(tracing_utils::layer);
//!
//! // Put it all together
//! tracing_subscriber::registry()
@@ -36,16 +31,18 @@
pub mod http;
pub mod perf_span;
use opentelemetry::KeyValue;
use opentelemetry::trace::TracerProvider;
use opentelemetry_otlp::WithExportConfig;
pub use opentelemetry_otlp::{ExportConfig, Protocol};
use opentelemetry_sdk::trace::SdkTracerProvider;
use tracing::level_filters::LevelFilter;
use tracing::{Dispatch, Subscriber};
use tracing_subscriber::Layer;
use tracing_subscriber::layer::SubscriberExt;
use tracing_subscriber::registry::LookupSpan;
pub type Provider = SdkTracerProvider;
/// Set up OpenTelemetry exporter, using configuration from environment variables.
///
/// `service_name` is set as the OpenTelemetry 'service.name' resource (see
@@ -70,16 +67,7 @@ use tracing_subscriber::registry::LookupSpan;
/// If you need some other setting, please test if it works first. And perhaps
/// add a comment in the list above to save the effort of testing for the next
/// person.
///
/// This doesn't block, but is marked as 'async' to hint that this must be called in
/// asynchronous execution context.
pub async fn init_tracing<S>(
service_name: &str,
export_config: ExportConfig,
) -> Option<impl Layer<S>>
where
S: Subscriber + for<'span> LookupSpan<'span>,
{
pub fn init_tracing(service_name: &str, export_config: ExportConfig) -> Option<Provider> {
if std::env::var("OTEL_SDK_DISABLED") == Ok("true".to_string()) {
return None;
};
@@ -89,52 +77,14 @@ where
))
}
/// Like `init_tracing`, but creates a separate tokio Runtime for the tracing
/// tasks.
pub fn init_tracing_without_runtime<S>(
service_name: &str,
export_config: ExportConfig,
) -> Option<impl Layer<S>>
pub fn layer<S>(p: &Provider) -> impl Layer<S>
where
S: Subscriber + for<'span> LookupSpan<'span>,
{
if std::env::var("OTEL_SDK_DISABLED") == Ok("true".to_string()) {
return None;
};
// The opentelemetry batch processor and the OTLP exporter needs a Tokio
// runtime. Create a dedicated runtime for them. One thread should be
// enough.
//
// (Alternatively, instead of batching, we could use the "simple
// processor", which doesn't need Tokio, and use "reqwest-blocking"
// feature for the OTLP exporter, which also doesn't need Tokio. However,
// batching is considered best practice, and also I have the feeling that
// the non-Tokio codepaths in the opentelemetry crate are less used and
// might be more buggy, so better to stay on the well-beaten path.)
//
// We leak the runtime so that it keeps running after we exit the
// function.
let runtime = Box::leak(Box::new(
tokio::runtime::Builder::new_multi_thread()
.enable_all()
.thread_name("otlp runtime thread")
.worker_threads(1)
.build()
.unwrap(),
));
let _guard = runtime.enter();
Some(init_tracing_internal(
service_name.to_string(),
export_config,
))
tracing_opentelemetry::layer().with_tracer(p.tracer("global"))
}
fn init_tracing_internal<S>(service_name: String, export_config: ExportConfig) -> impl Layer<S>
where
S: Subscriber + for<'span> LookupSpan<'span>,
{
fn init_tracing_internal(service_name: String, export_config: ExportConfig) -> Provider {
// Sets up exporter from the provided [`ExportConfig`] parameter.
// If the endpoint is not specified, it is loaded from the
// OTEL_EXPORTER_OTLP_ENDPOINT environment variable.
@@ -153,22 +103,14 @@ where
opentelemetry_sdk::propagation::TraceContextPropagator::new(),
);
let tracer = opentelemetry_sdk::trace::TracerProvider::builder()
.with_batch_exporter(exporter, opentelemetry_sdk::runtime::Tokio)
.with_resource(opentelemetry_sdk::Resource::new(vec![KeyValue::new(
opentelemetry_semantic_conventions::resource::SERVICE_NAME,
service_name,
)]))
Provider::builder()
.with_batch_exporter(exporter)
.with_resource(
opentelemetry_sdk::Resource::builder()
.with_service_name(service_name)
.build(),
)
.build()
.tracer("global");
tracing_opentelemetry::layer().with_tracer(tracer)
}
// Shutdown trace pipeline gracefully, so that it has a chance to send any
// pending traces before we exit.
pub fn shutdown_tracing() {
opentelemetry::global::shutdown_tracer_provider();
}
pub enum OtelEnablement {
@@ -176,17 +118,17 @@ pub enum OtelEnablement {
Enabled {
service_name: String,
export_config: ExportConfig,
runtime: &'static tokio::runtime::Runtime,
},
}
pub struct OtelGuard {
provider: Provider,
pub dispatch: Dispatch,
}
impl Drop for OtelGuard {
fn drop(&mut self) {
shutdown_tracing();
_ = self.provider.shutdown();
}
}
@@ -199,22 +141,19 @@ impl Drop for OtelGuard {
/// The lifetime of the guard should match that of the application. On drop, it tears down the
/// OTEL infra.
pub fn init_performance_tracing(otel_enablement: OtelEnablement) -> Option<OtelGuard> {
let otel_subscriber = match otel_enablement {
match otel_enablement {
OtelEnablement::Disabled => None,
OtelEnablement::Enabled {
service_name,
export_config,
runtime,
} => {
let otel_layer = runtime
.block_on(init_tracing(&service_name, export_config))
.with_filter(LevelFilter::INFO);
let provider = init_tracing(&service_name, export_config)?;
let otel_layer = layer(&provider).with_filter(LevelFilter::INFO);
let otel_subscriber = tracing_subscriber::registry().with(otel_layer);
let otel_dispatch = Dispatch::new(otel_subscriber);
let dispatch = Dispatch::new(otel_subscriber);
Some(otel_dispatch)
Some(OtelGuard { dispatch, provider })
}
};
otel_subscriber.map(|dispatch| OtelGuard { dispatch })
}
}
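
A minimal sketch of the new runtime-free setup path, with arbitrary values; the guard must stay alive for the process lifetime so that `Drop` can flush and shut down the provider:

```rust
fn setup_perf_tracing() -> Option<tracing_utils::OtelGuard> {
    let enablement = tracing_utils::OtelEnablement::Enabled {
        service_name: "pageserver".to_string(),
        export_config: tracing_utils::ExportConfig::default(),
    };
    let guard = tracing_utils::init_performance_tracing(enablement)?;
    // Spans recorded under `guard.dispatch` (e.g. via tracing::dispatcher::with_default)
    // are exported through the embedded provider.
    Some(guard)
}
```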

View File

@@ -104,7 +104,7 @@ impl Id {
pub fn generate() -> Self {
let mut tli_buf = [0u8; 16];
rand::thread_rng().fill(&mut tli_buf);
rand::rng().fill(&mut tli_buf);
Id::from(tli_buf)
}

View File

@@ -364,42 +364,37 @@ impl MonotonicCounter<Lsn> for RecordLsn {
}
}
/// Implements [`rand::distributions::uniform::UniformSampler`] so we can sample [`Lsn`]s.
/// Implements [`rand::distr::uniform::UniformSampler`] so we can sample [`Lsn`]s.
///
/// This is used by the `pagebench` pageserver benchmarking tool.
pub struct LsnSampler(<u64 as rand::distributions::uniform::SampleUniform>::Sampler);
pub struct LsnSampler(<u64 as rand::distr::uniform::SampleUniform>::Sampler);
impl rand::distributions::uniform::SampleUniform for Lsn {
impl rand::distr::uniform::SampleUniform for Lsn {
type Sampler = LsnSampler;
}
impl rand::distributions::uniform::UniformSampler for LsnSampler {
impl rand::distr::uniform::UniformSampler for LsnSampler {
type X = Lsn;
fn new<B1, B2>(low: B1, high: B2) -> Self
fn new<B1, B2>(low: B1, high: B2) -> Result<Self, rand::distr::uniform::Error>
where
B1: rand::distributions::uniform::SampleBorrow<Self::X> + Sized,
B2: rand::distributions::uniform::SampleBorrow<Self::X> + Sized,
B1: rand::distr::uniform::SampleBorrow<Self::X> + Sized,
B2: rand::distr::uniform::SampleBorrow<Self::X> + Sized,
{
Self(
<u64 as rand::distributions::uniform::SampleUniform>::Sampler::new(
low.borrow().0,
high.borrow().0,
),
)
<u64 as rand::distr::uniform::SampleUniform>::Sampler::new(low.borrow().0, high.borrow().0)
.map(Self)
}
fn new_inclusive<B1, B2>(low: B1, high: B2) -> Self
fn new_inclusive<B1, B2>(low: B1, high: B2) -> Result<Self, rand::distr::uniform::Error>
where
B1: rand::distributions::uniform::SampleBorrow<Self::X> + Sized,
B2: rand::distributions::uniform::SampleBorrow<Self::X> + Sized,
B1: rand::distr::uniform::SampleBorrow<Self::X> + Sized,
B2: rand::distr::uniform::SampleBorrow<Self::X> + Sized,
{
Self(
<u64 as rand::distributions::uniform::SampleUniform>::Sampler::new_inclusive(
low.borrow().0,
high.borrow().0,
),
<u64 as rand::distr::uniform::SampleUniform>::Sampler::new_inclusive(
low.borrow().0,
high.borrow().0,
)
.map(Self)
}
fn sample<R: rand::prelude::Rng + ?Sized>(&self, rng: &mut R) -> Self::X {

View File

@@ -25,6 +25,12 @@ pub struct ShardIndex {
pub shard_count: ShardCount,
}
/// Stripe size as number of pages.
///
/// NB: don't implement Default, so callers don't lazily use it by mistake. See DEFAULT_STRIPE_SIZE.
#[derive(Clone, Copy, Serialize, Deserialize, Eq, PartialEq, Debug)]
pub struct ShardStripeSize(pub u32);
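
Without a `Default` impl, call sites spell out the default explicitly, as the `pagectl` change later in this commit does; a minimal sketch:

```rust
use pageserver_api::shard::{DEFAULT_STRIPE_SIZE, ShardStripeSize};

fn stripe_size_or_default(arg: Option<u32>) -> ShardStripeSize {
    arg.map(ShardStripeSize).unwrap_or(DEFAULT_STRIPE_SIZE)
}
```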
/// Formatting helper, for generating the `shard_id` label in traces.
pub struct ShardSlug<'a>(&'a TenantShardId);
@@ -181,6 +187,12 @@ impl std::fmt::Display for ShardCount {
}
}
impl std::fmt::Display for ShardStripeSize {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
self.0.fmt(f)
}
}
impl std::fmt::Display for ShardSlug<'_> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(

View File

@@ -11,7 +11,8 @@ use pageserver::tenant::layer_map::LayerMap;
use pageserver::tenant::storage_layer::{LayerName, PersistentLayerDesc};
use pageserver_api::key::Key;
use pageserver_api::shard::TenantShardId;
use rand::prelude::{SeedableRng, SliceRandom, StdRng};
use rand::prelude::{SeedableRng, StdRng};
use rand::seq::IndexedRandom;
use utils::id::{TenantId, TimelineId};
use utils::lsn::Lsn;

View File

@@ -14,12 +14,11 @@ use utils::logging::warn_slow;
use crate::pool::{ChannelPool, ClientGuard, ClientPool, StreamGuard, StreamPool};
use crate::retry::Retry;
use crate::split::GetPageSplitter;
use compute_api::spec::PageserverProtocol;
use pageserver_api::shard::ShardStripeSize;
use pageserver_page_api as page_api;
use pageserver_page_api::GetPageSplitter;
use utils::id::{TenantId, TimelineId};
use utils::shard::{ShardCount, ShardIndex, ShardNumber};
use utils::shard::{ShardCount, ShardIndex, ShardNumber, ShardStripeSize};
/// Max number of concurrent clients per channel (i.e. TCP connection). New channels will be spun up
/// when full.
@@ -141,8 +140,8 @@ impl PageserverClient {
if !old.count.is_unsharded() && shard_spec.stripe_size != old.stripe_size {
return Err(anyhow!(
"can't change stripe size from {} to {}",
old.stripe_size,
shard_spec.stripe_size
old.stripe_size.expect("always Some when sharded"),
shard_spec.stripe_size.expect("always Some when sharded")
));
}
@@ -157,23 +156,6 @@ impl PageserverClient {
Ok(())
}
/// Returns whether a relation exists.
#[instrument(skip_all, fields(rel=%req.rel, lsn=%req.read_lsn))]
pub async fn check_rel_exists(
&self,
req: page_api::CheckRelExistsRequest,
) -> tonic::Result<page_api::CheckRelExistsResponse> {
debug!("sending request: {req:?}");
let resp = Self::with_retries(CALL_TIMEOUT, async |_| {
// Relation metadata is only available on shard 0.
let mut client = self.shards.load_full().get_zero().client().await?;
Self::with_timeout(REQUEST_TIMEOUT, client.check_rel_exists(req)).await
})
.await?;
debug!("received response: {resp:?}");
Ok(resp)
}
/// Returns the total size of a database, as # of bytes.
#[instrument(skip_all, fields(db_oid=%req.db_oid, lsn=%req.read_lsn))]
pub async fn get_db_size(
@@ -249,13 +231,15 @@ impl PageserverClient {
// Fast path: request is for a single shard.
if let Some(shard_id) =
GetPageSplitter::for_single_shard(&req, shards.count, shards.stripe_size)
.map_err(|err| tonic::Status::internal(err.to_string()))?
{
return Self::get_page_with_shard(req, shards.get(shard_id)?).await;
}
// Request spans multiple shards. Split it, dispatch concurrent per-shard requests, and
// reassemble the responses.
let mut splitter = GetPageSplitter::split(req, shards.count, shards.stripe_size);
let mut splitter = GetPageSplitter::split(req, shards.count, shards.stripe_size)
.map_err(|err| tonic::Status::internal(err.to_string()))?;
let mut shard_requests = FuturesUnordered::new();
for (shard_id, shard_req) in splitter.drain_requests() {
@@ -265,10 +249,14 @@ impl PageserverClient {
}
while let Some((shard_id, shard_response)) = shard_requests.next().await.transpose()? {
splitter.add_response(shard_id, shard_response)?;
splitter
.add_response(shard_id, shard_response)
.map_err(|err| tonic::Status::internal(err.to_string()))?;
}
splitter.get_response()
splitter
.get_response()
.map_err(|err| tonic::Status::internal(err.to_string()))
}
/// Fetches pages on the given shard. Does not retry internally.
@@ -396,12 +384,14 @@ pub struct ShardSpec {
/// NB: this is 0 for unsharded tenants, following `ShardIndex::unsharded()` convention.
count: ShardCount,
/// The stripe size for these shards.
stripe_size: ShardStripeSize,
///
/// INVARIANT: None for unsharded tenants, Some for sharded.
stripe_size: Option<ShardStripeSize>,
}
impl ShardSpec {
/// Creates a new shard spec with the given URLs and stripe size. All shards must be given.
/// The stripe size may be omitted for unsharded tenants.
/// The stripe size must be Some for sharded tenants, or None for unsharded tenants.
pub fn new(
urls: HashMap<ShardIndex, String>,
stripe_size: Option<ShardStripeSize>,
@@ -414,11 +404,13 @@ impl ShardSpec {
n => ShardCount::new(n as u8),
};
// Determine the stripe size. It doesn't matter for unsharded tenants.
// Validate the stripe size.
if stripe_size.is_none() && !count.is_unsharded() {
return Err(anyhow!("stripe size must be given for sharded tenants"));
}
let stripe_size = stripe_size.unwrap_or_default();
if stripe_size.is_some() && count.is_unsharded() {
return Err(anyhow!("stripe size can't be given for unsharded tenants"));
}
// Validate the shard spec.
for (shard_id, url) in &urls {
@@ -458,8 +450,10 @@ struct Shards {
///
/// NB: this is 0 for unsharded tenants, following `ShardIndex::unsharded()` convention.
count: ShardCount,
/// The stripe size. Only used for sharded tenants.
stripe_size: ShardStripeSize,
/// The stripe size.
///
/// INVARIANT: None for unsharded tenants, Some for sharded.
stripe_size: Option<ShardStripeSize>,
}
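
A minimal sketch of the invariant from a caller's perspective, written as if inside this client crate and assuming `ShardSpec::new` returns `anyhow::Result`; URLs and the stripe size are arbitrary:

```rust
use std::collections::HashMap;

use utils::shard::{ShardCount, ShardIndex, ShardNumber, ShardStripeSize};

use crate::ShardSpec;

fn build_specs() -> anyhow::Result<()> {
    // Unsharded tenant: a single unsharded entry, stripe size must be None.
    let unsharded = HashMap::from([(ShardIndex::unsharded(), "grpc://ps-0:51051".to_string())]);
    let _spec = ShardSpec::new(unsharded, None)?;

    // Sharded tenant: one URL per shard, stripe size must be Some.
    let count = ShardCount::new(2);
    let sharded = HashMap::from([
        (ShardIndex::new(ShardNumber(0), count), "grpc://ps-0:51051".to_string()),
        (ShardIndex::new(ShardNumber(1), count), "grpc://ps-1:51051".to_string()),
    ]);
    let _spec = ShardSpec::new(sharded, Some(ShardStripeSize(2048)))?;
    Ok(())
}
```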
impl Shards {

View File

@@ -1,7 +1,6 @@
mod client;
mod pool;
mod retry;
mod split;
pub use client::{PageserverClient, ShardSpec};
pub use pageserver_api::shard::ShardStripeSize; // used in ShardSpec

View File

@@ -89,7 +89,7 @@ async fn simulate(cmd: &SimulateCmd, results_path: &Path) -> anyhow::Result<()>
let cold_key_range = splitpoint..key_range.end;
for i in 0..cmd.num_records {
let chosen_range = if rand::thread_rng().gen_bool(0.9) {
let chosen_range = if rand::rng().random_bool(0.9) {
&hot_key_range
} else {
&cold_key_range

View File

@@ -300,9 +300,9 @@ impl MockTimeline {
key_range: &Range<Key>,
) -> anyhow::Result<()> {
crate::helpers::union_to_keyspace(&mut self.keyspace, vec![key_range.clone()]);
let mut rng = rand::thread_rng();
let mut rng = rand::rng();
for _ in 0..num_records {
self.ingest_record(rng.gen_range(key_range.clone()), len);
self.ingest_record(rng.random_range(key_range.clone()), len);
self.wal_ingested += len;
}
Ok(())

View File

@@ -4,7 +4,7 @@ use anyhow::Context;
use clap::Parser;
use pageserver_api::key::Key;
use pageserver_api::reltag::{BlockNumber, RelTag, SlruKind};
use pageserver_api::shard::{ShardCount, ShardStripeSize};
use pageserver_api::shard::{DEFAULT_STRIPE_SIZE, ShardCount, ShardStripeSize};
#[derive(Parser)]
pub(super) struct DescribeKeyCommand {
@@ -128,7 +128,9 @@ impl DescribeKeyCommand {
// seeing the sharding placement might be confusing, so leave it out unless shard
// count was given.
let stripe_size = stripe_size.map(ShardStripeSize).unwrap_or_default();
let stripe_size = stripe_size
.map(ShardStripeSize)
.unwrap_or(DEFAULT_STRIPE_SIZE);
println!(
"# placement with shard_count: {} and stripe_size: {}:",
shard_count.0, stripe_size.0

View File

@@ -17,11 +17,11 @@
// grpcurl \
// -plaintext \
// -H "neon-tenant-id: 7c4a1f9e3bd6470c8f3e21a65bd2e980" \
// -H "neon-shard-id: 0b10" \
// -H "neon-shard-id: 0000" \
// -H "neon-timeline-id: f08c4e9a2d5f76b1e3a7c2d8910f4b3e" \
// -H "authorization: Bearer $JWT" \
// -d '{"read_lsn": {"request_lsn": 1234567890}, "rel": {"spc_oid": 1663, "db_oid": 1234, "rel_number": 5678, "fork_number": 0}}'
// localhost:51051 page_api.PageService/CheckRelExists
// -d '{"read_lsn": {"request_lsn": 100000000, "not_modified_since_lsn": 1}, "db_oid": 1}' \
// localhost:51051 page_api.PageService/GetDbSize
// ```
//
// TODO: consider adding neon-compute-mode ("primary", "static", "replica").
@@ -38,8 +38,8 @@ package page_api;
import "google/protobuf/timestamp.proto";
service PageService {
// Returns whether a relation exists.
rpc CheckRelExists(CheckRelExistsRequest) returns (CheckRelExistsResponse);
// NB: unlike libpq, there is no CheckRelExists in gRPC, at the compute team's request. Instead,
// use GetRelSize with allow_missing=true to check existence.
// Fetches a base backup.
rpc GetBaseBackup (GetBaseBackupRequest) returns (stream GetBaseBackupResponseChunk);
@@ -97,17 +97,6 @@ message RelTag {
uint32 fork_number = 4;
}
// Checks whether a relation exists, at the given LSN. Only valid on shard 0,
// other shards will error.
message CheckRelExistsRequest {
ReadLsn read_lsn = 1;
RelTag rel = 2;
}
message CheckRelExistsResponse {
bool exists = 1;
}
// Requests a base backup.
message GetBaseBackupRequest {
// The LSN to fetch the base backup at. 0 or absent means the latest LSN known to the Pageserver.
@@ -260,10 +249,15 @@ enum GetPageStatusCode {
message GetRelSizeRequest {
ReadLsn read_lsn = 1;
RelTag rel = 2;
// If true, return missing=true for missing relations instead of a NotFound error.
bool allow_missing = 3;
}
message GetRelSizeResponse {
// The number of blocks in the relation.
uint32 num_blocks = 1;
// If allow_missing=true, this is true for missing relations.
bool missing = 2;
}
// Requests an SLRU segment. Only valid on shard 0, other shards will error.

View File

@@ -69,16 +69,6 @@ impl Client {
Ok(Self { inner })
}
/// Returns whether a relation exists.
pub async fn check_rel_exists(
&mut self,
req: CheckRelExistsRequest,
) -> tonic::Result<CheckRelExistsResponse> {
let req = proto::CheckRelExistsRequest::from(req);
let resp = self.inner.check_rel_exists(req).await?.into_inner();
Ok(resp.into())
}
/// Fetches a base backup.
pub async fn get_base_backup(
&mut self,
@@ -114,7 +104,8 @@ impl Client {
Ok(resps.and_then(|resp| ready(GetPageResponse::try_from(resp).map_err(|err| err.into()))))
}
/// Returns the size of a relation, as # of blocks.
/// Returns the size of a relation as # of blocks, or None if allow_missing=true and the
/// relation does not exist.
pub async fn get_rel_size(
&mut self,
req: GetRelSizeRequest,

View File

@@ -19,7 +19,9 @@ pub mod proto {
}
mod client;
pub use client::Client;
mod model;
mod split;
pub use client::Client;
pub use model::*;
pub use split::GetPageSplitter;

View File

@@ -141,50 +141,6 @@ impl From<RelTag> for proto::RelTag {
}
}
/// Checks whether a relation exists, at the given LSN. Only valid on shard 0, other shards error.
#[derive(Clone, Copy, Debug)]
pub struct CheckRelExistsRequest {
pub read_lsn: ReadLsn,
pub rel: RelTag,
}
impl TryFrom<proto::CheckRelExistsRequest> for CheckRelExistsRequest {
type Error = ProtocolError;
fn try_from(pb: proto::CheckRelExistsRequest) -> Result<Self, Self::Error> {
Ok(Self {
read_lsn: pb
.read_lsn
.ok_or(ProtocolError::Missing("read_lsn"))?
.try_into()?,
rel: pb.rel.ok_or(ProtocolError::Missing("rel"))?.try_into()?,
})
}
}
impl From<CheckRelExistsRequest> for proto::CheckRelExistsRequest {
fn from(request: CheckRelExistsRequest) -> Self {
Self {
read_lsn: Some(request.read_lsn.into()),
rel: Some(request.rel.into()),
}
}
}
pub type CheckRelExistsResponse = bool;
impl From<proto::CheckRelExistsResponse> for CheckRelExistsResponse {
fn from(pb: proto::CheckRelExistsResponse) -> Self {
pb.exists
}
}
impl From<CheckRelExistsResponse> for proto::CheckRelExistsResponse {
fn from(exists: CheckRelExistsResponse) -> Self {
Self { exists }
}
}
/// Requests a base backup.
#[derive(Clone, Copy, Debug)]
pub struct GetBaseBackupRequest {
@@ -709,6 +665,8 @@ impl From<GetPageStatusCode> for tonic::Code {
pub struct GetRelSizeRequest {
pub read_lsn: ReadLsn,
pub rel: RelTag,
/// If true, return missing=true for missing relations instead of a NotFound error.
pub allow_missing: bool,
}
impl TryFrom<proto::GetRelSizeRequest> for GetRelSizeRequest {
@@ -721,6 +679,7 @@ impl TryFrom<proto::GetRelSizeRequest> for GetRelSizeRequest {
.ok_or(ProtocolError::Missing("read_lsn"))?
.try_into()?,
rel: proto.rel.ok_or(ProtocolError::Missing("rel"))?.try_into()?,
allow_missing: proto.allow_missing,
})
}
}
@@ -730,21 +689,29 @@ impl From<GetRelSizeRequest> for proto::GetRelSizeRequest {
Self {
read_lsn: Some(request.read_lsn.into()),
rel: Some(request.rel.into()),
allow_missing: request.allow_missing,
}
}
}
pub type GetRelSizeResponse = u32;
/// The size of a relation as number of blocks, or None if `allow_missing=true` and the relation
/// does not exist.
///
/// INVARIANT: never None if `allow_missing=false` (returns `NotFound` error instead).
pub type GetRelSizeResponse = Option<u32>;
impl From<proto::GetRelSizeResponse> for GetRelSizeResponse {
fn from(proto: proto::GetRelSizeResponse) -> Self {
proto.num_blocks
fn from(pb: proto::GetRelSizeResponse) -> Self {
(!pb.missing).then_some(pb.num_blocks)
}
}
impl From<GetRelSizeResponse> for proto::GetRelSizeResponse {
fn from(num_blocks: GetRelSizeResponse) -> Self {
Self { num_blocks }
fn from(resp: GetRelSizeResponse) -> Self {
Self {
num_blocks: resp.unwrap_or_default(),
missing: resp.is_none(),
}
}
}
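
A minimal sketch of the existence check that replaces `CheckRelExists`, assuming `Client::get_rel_size` returns `tonic::Result<GetRelSizeResponse>` and that the model types are re-exported at the crate root:

```rust
use pageserver_page_api as page_api;

async fn rel_exists(
    client: &mut page_api::Client,
    read_lsn: page_api::ReadLsn,
    rel: page_api::RelTag,
) -> tonic::Result<bool> {
    let num_blocks = client
        .get_rel_size(page_api::GetRelSizeRequest {
            read_lsn,
            rel,
            allow_missing: true, // missing relation => Ok(None) instead of a NotFound error
        })
        .await?;
    Ok(num_blocks.is_some())
}
```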

View File

@@ -1,19 +1,20 @@
use std::collections::HashMap;
use anyhow::anyhow;
use bytes::Bytes;
use crate::model::*;
use pageserver_api::key::rel_block_to_key;
use pageserver_api::shard::{ShardStripeSize, key_to_shard_number};
use pageserver_page_api as page_api;
use utils::shard::{ShardCount, ShardIndex, ShardNumber};
use pageserver_api::shard::key_to_shard_number;
use utils::shard::{ShardCount, ShardIndex, ShardStripeSize};
/// Splits GetPageRequests that straddle shard boundaries and assembles the responses.
/// TODO: add tests for this.
pub struct GetPageSplitter {
/// Split requests by shard index.
requests: HashMap<ShardIndex, page_api::GetPageRequest>,
requests: HashMap<ShardIndex, GetPageRequest>,
/// The response being assembled. Preallocated with empty pages, to be filled in.
response: page_api::GetPageResponse,
response: GetPageResponse,
/// Maps the offset in `request.block_numbers` and `response.pages` to the owning shard. Used
/// to assemble the response pages in the same order as the original request.
block_shards: Vec<ShardIndex>,
@@ -23,45 +24,56 @@ impl GetPageSplitter {
/// Checks if the given request only touches a single shard, and returns the shard ID. This is
/// the common case, so we check first in order to avoid unnecessary allocations and overhead.
pub fn for_single_shard(
req: &page_api::GetPageRequest,
req: &GetPageRequest,
count: ShardCount,
stripe_size: ShardStripeSize,
) -> Option<ShardIndex> {
stripe_size: Option<ShardStripeSize>,
) -> anyhow::Result<Option<ShardIndex>> {
// Fast path: unsharded tenant.
if count.is_unsharded() {
return Some(ShardIndex::unsharded());
return Ok(Some(ShardIndex::unsharded()));
}
// Find the first page's shard, for comparison. If there are no pages, just return the first
// shard (caller likely checked already, otherwise the server will reject it).
let Some(stripe_size) = stripe_size else {
return Err(anyhow!("stripe size must be given for sharded tenants"));
};
// Find the first page's shard, for comparison.
let Some(&first_page) = req.block_numbers.first() else {
return Some(ShardIndex::new(ShardNumber(0), count));
return Err(anyhow!("no block numbers in request"));
};
let key = rel_block_to_key(req.rel, first_page);
let shard_number = key_to_shard_number(count, stripe_size, &key);
req.block_numbers
Ok(req
.block_numbers
.iter()
.skip(1) // computed above
.all(|&blkno| {
let key = rel_block_to_key(req.rel, blkno);
key_to_shard_number(count, stripe_size, &key) == shard_number
})
.then_some(ShardIndex::new(shard_number, count))
.then_some(ShardIndex::new(shard_number, count)))
}
/// Splits the given request.
pub fn split(
req: page_api::GetPageRequest,
req: GetPageRequest,
count: ShardCount,
stripe_size: ShardStripeSize,
) -> Self {
stripe_size: Option<ShardStripeSize>,
) -> anyhow::Result<Self> {
// The caller should make sure we don't split requests unnecessarily.
debug_assert!(
Self::for_single_shard(&req, count, stripe_size).is_none(),
Self::for_single_shard(&req, count, stripe_size)?.is_none(),
"unnecessary request split"
);
if count.is_unsharded() {
return Err(anyhow!("unsharded tenant, no point in splitting request"));
}
let Some(stripe_size) = stripe_size else {
return Err(anyhow!("stripe size must be given for sharded tenants"));
};
// Split the requests by shard index.
let mut requests = HashMap::with_capacity(2); // common case
let mut block_shards = Vec::with_capacity(req.block_numbers.len());
@@ -72,7 +84,7 @@ impl GetPageSplitter {
requests
.entry(shard_id)
.or_insert_with(|| page_api::GetPageRequest {
.or_insert_with(|| GetPageRequest {
request_id: req.request_id,
request_class: req.request_class,
rel: req.rel,
@@ -86,16 +98,16 @@ impl GetPageSplitter {
// Construct a response to be populated by shard responses. Preallocate empty page slots
// with the expected block numbers.
let response = page_api::GetPageResponse {
let response = GetPageResponse {
request_id: req.request_id,
status_code: page_api::GetPageStatusCode::Ok,
status_code: GetPageStatusCode::Ok,
reason: None,
rel: req.rel,
pages: req
.block_numbers
.into_iter()
.map(|block_number| {
page_api::Page {
Page {
block_number,
image: Bytes::new(), // empty page slot to be filled in
}
@@ -103,17 +115,15 @@ impl GetPageSplitter {
.collect(),
};
Self {
Ok(Self {
requests,
response,
block_shards,
}
})
}
/// Drains the per-shard requests, moving them out of the splitter to avoid extra allocations.
pub fn drain_requests(
&mut self,
) -> impl Iterator<Item = (ShardIndex, page_api::GetPageRequest)> {
pub fn drain_requests(&mut self) -> impl Iterator<Item = (ShardIndex, GetPageRequest)> {
self.requests.drain()
}
@@ -123,22 +133,31 @@ impl GetPageSplitter {
pub fn add_response(
&mut self,
shard_id: ShardIndex,
response: page_api::GetPageResponse,
) -> tonic::Result<()> {
response: GetPageResponse,
) -> anyhow::Result<()> {
// The caller should already have converted status codes into tonic::Status.
if response.status_code != page_api::GetPageStatusCode::Ok {
return Err(tonic::Status::internal(format!(
if response.status_code != GetPageStatusCode::Ok {
return Err(anyhow!(
"unexpected non-OK response for shard {shard_id}: {} {}",
response.status_code,
response.reason.unwrap_or_default()
)));
));
}
if response.request_id != self.response.request_id {
return Err(tonic::Status::internal(format!(
return Err(anyhow!(
"response ID mismatch for shard {shard_id}: expected {}, got {}",
self.response.request_id, response.request_id
)));
self.response.request_id,
response.request_id
));
}
// Place the shard response pages into the assembled response, in request order.
@@ -150,27 +169,26 @@ impl GetPageSplitter {
}
let Some(slot) = self.response.pages.get_mut(i) else {
return Err(tonic::Status::internal(format!(
"no block_shards slot {i} for shard {shard_id}"
)));
return Err(anyhow!("no block_shards slot {i} for shard {shard_id}"));
};
let Some(page) = pages.next() else {
return Err(tonic::Status::internal(format!(
return Err(anyhow!(
"missing page {} in shard {shard_id} response",
slot.block_number
)));
));
};
if page.block_number != slot.block_number {
return Err(tonic::Status::internal(format!(
return Err(anyhow!(
"shard {shard_id} returned wrong page at index {i}, expected {} got {}",
slot.block_number, page.block_number
)));
slot.block_number,
page.block_number
));
}
if !slot.image.is_empty() {
return Err(tonic::Status::internal(format!(
return Err(anyhow!(
"shard {shard_id} returned duplicate page {} at index {i}",
slot.block_number
)));
));
}
*slot = page;
@@ -178,10 +196,10 @@ impl GetPageSplitter {
// Make sure we've consumed all pages from the shard response.
if let Some(extra_page) = pages.next() {
return Err(tonic::Status::internal(format!(
return Err(anyhow!(
"shard {shard_id} returned extra page: {}",
extra_page.block_number
)));
));
}
Ok(())
@@ -189,18 +207,18 @@ impl GetPageSplitter {
/// Fetches the final, assembled response.
#[allow(clippy::result_large_err)]
pub fn get_response(self) -> tonic::Result<page_api::GetPageResponse> {
pub fn get_response(self) -> anyhow::Result<GetPageResponse> {
// Check that the response is complete.
for (i, page) in self.response.pages.iter().enumerate() {
if page.image.is_empty() {
return Err(tonic::Status::internal(format!(
return Err(anyhow!(
"missing page {} for shard {}",
page.block_number,
self.block_shards
.get(i)
.map(|s| s.to_string())
.unwrap_or_else(|| "?".to_string())
)));
));
}
}

View File

@@ -188,9 +188,9 @@ async fn main_impl(
start_work_barrier.wait().await;
loop {
let (timeline, work) = {
let mut rng = rand::thread_rng();
let mut rng = rand::rng();
let target = all_targets.choose(&mut rng).unwrap();
let lsn = target.lsn_range.clone().map(|r| rng.gen_range(r));
let lsn = target.lsn_range.clone().map(|r| rng.random_range(r));
(target.timeline, Work { lsn })
};
let sender = work_senders.get(&timeline).unwrap();

View File

@@ -354,8 +354,7 @@ async fn main_impl(
.cloned()
.collect();
let weights =
rand::distributions::weighted::WeightedIndex::new(ranges.iter().map(|v| v.len()))
.unwrap();
rand::distr::weighted::WeightedIndex::new(ranges.iter().map(|v| v.len())).unwrap();
Box::pin(async move {
let scheme = match Url::parse(&args.page_service_connstring) {
@@ -455,7 +454,7 @@ async fn run_worker(
cancel: CancellationToken,
rps_period: Option<Duration>,
ranges: Vec<KeyRange>,
weights: rand::distributions::weighted::WeightedIndex<i128>,
weights: rand::distr::weighted::WeightedIndex<i128>,
) {
shared_state.start_work_barrier.wait().await;
let client_start = Instant::now();
@@ -497,9 +496,9 @@ async fn run_worker(
}
// Pick a random page from a random relation.
let mut rng = rand::thread_rng();
let mut rng = rand::rng();
let r = &ranges[weights.sample(&mut rng)];
let key: i128 = rng.gen_range(r.start..r.end);
let key: i128 = rng.random_range(r.start..r.end);
let (rel_tag, block_no) = key_to_block(key);
let mut blks = VecDeque::with_capacity(batch_size);
@@ -530,7 +529,7 @@ async fn run_worker(
// We assume that the entire batch can fit within the relation.
assert_eq!(blks.len(), batch_size, "incomplete batch");
let req_lsn = if rng.gen_bool(args.req_latest_probability) {
let req_lsn = if rng.random_bool(args.req_latest_probability) {
Lsn::MAX
} else {
r.timeline_lsn

View File

@@ -7,7 +7,7 @@ use std::time::{Duration, Instant};
use pageserver_api::models::HistoricLayerInfo;
use pageserver_api::shard::TenantShardId;
use pageserver_client::mgmt_api;
use rand::seq::SliceRandom;
use rand::seq::IndexedMutRandom;
use tokio::sync::{OwnedSemaphorePermit, mpsc};
use tokio::task::JoinSet;
use tokio_util::sync::CancellationToken;
@@ -260,7 +260,7 @@ async fn timeline_actor(
loop {
let layer_tx = {
let mut rng = rand::thread_rng();
let mut rng = rand::rng();
timeline.layers.choose_mut(&mut rng).expect("no layers")
};
match layer_tx.try_send(permit.take().unwrap()) {

View File

@@ -126,7 +126,6 @@ fn main() -> anyhow::Result<()> {
Some(cfg) => tracing_utils::OtelEnablement::Enabled {
service_name: "pageserver".to_string(),
export_config: (&cfg.export_config).into(),
runtime: *COMPUTE_REQUEST_RUNTIME,
},
None => tracing_utils::OtelEnablement::Disabled,
};

View File

@@ -1,4 +1,5 @@
use std::collections::HashMap;
use std::net::IpAddr;
use futures::Future;
use pageserver_api::config::NodeMetadata;
@@ -16,7 +17,7 @@ use tokio_util::sync::CancellationToken;
use url::Url;
use utils::generation::Generation;
use utils::id::{NodeId, TimelineId};
use utils::{backoff, failpoint_support};
use utils::{backoff, failpoint_support, ip_address};
use crate::config::PageServerConf;
use crate::virtual_file::on_fatal_io_error;
@@ -27,6 +28,7 @@ pub struct StorageControllerUpcallClient {
http_client: reqwest::Client,
base_url: Url,
node_id: NodeId,
node_ip_addr: Option<IpAddr>,
cancel: CancellationToken,
}
@@ -40,6 +42,7 @@ pub trait StorageControllerUpcallApi {
fn re_attach(
&self,
conf: &PageServerConf,
empty_local_disk: bool,
) -> impl Future<
Output = Result<HashMap<TenantShardId, ReAttachResponseTenant>, RetryForeverError>,
> + Send;
@@ -91,11 +94,18 @@ impl StorageControllerUpcallClient {
);
}
// Intentionally panics if we encounter any errors parsing or reading the IP address.
// Note that if the required environment variable is not set, `read_node_ip_addr_from_env` returns `Ok(None)`
// instead of an error.
let node_ip_addr =
ip_address::read_node_ip_addr_from_env().expect("Error reading node IP address.");
Self {
http_client: client.build().expect("Failed to construct HTTP client"),
base_url: url,
node_id: conf.id,
cancel: cancel.clone(),
node_ip_addr,
}
}
@@ -146,6 +156,7 @@ impl StorageControllerUpcallApi for StorageControllerUpcallClient {
async fn re_attach(
&self,
conf: &PageServerConf,
empty_local_disk: bool,
) -> Result<HashMap<TenantShardId, ReAttachResponseTenant>, RetryForeverError> {
let url = self
.base_url
@@ -193,8 +204,8 @@ impl StorageControllerUpcallApi for StorageControllerUpcallClient {
listen_http_addr: m.http_host,
listen_http_port: m.http_port,
listen_https_port: m.https_port,
node_ip_addr: self.node_ip_addr,
availability_zone_id: az_id.expect("Checked above"),
node_ip_addr: None,
})
}
Err(e) => {
@@ -217,6 +228,7 @@ impl StorageControllerUpcallApi for StorageControllerUpcallClient {
let request = ReAttachRequest {
node_id: self.node_id,
register: register.clone(),
empty_local_disk: Some(empty_local_disk),
};
let response: ReAttachResponse = self

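The re-attach hunks above make the upcall client read an optional node IP from the environment via `ip_address::read_node_ip_addr_from_env`, treating an unset variable as `Ok(None)` and panicking on read/parse errors at the call site. A rough standard-library-only sketch of that shape; the `NODE_IP` variable name here is hypothetical, the real lookup lives in `utils::ip_address`:

```rust
use std::env::{self, VarError};
use std::net::IpAddr;

/// Reads an optional node IP address from a (hypothetical) NODE_IP env variable.
/// Unset variable => Ok(None); present but unreadable or unparsable => Err.
fn read_node_ip_addr_from_env() -> Result<Option<IpAddr>, String> {
    match env::var("NODE_IP") {
        Ok(raw) => raw
            .parse::<IpAddr>()
            .map(Some)
            .map_err(|e| format!("invalid NODE_IP {raw:?}: {e}")),
        Err(VarError::NotPresent) => Ok(None),
        Err(e) => Err(format!("could not read NODE_IP: {e}")),
    }
}

fn main() {
    // Mirrors the call site above: panic on errors, tolerate an unset variable.
    let node_ip_addr = read_node_ip_addr_from_env().expect("Error reading node IP address.");
    println!("node_ip_addr = {node_ip_addr:?}");
}
```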
View File

@@ -768,6 +768,7 @@ mod test {
async fn re_attach(
&self,
_conf: &PageServerConf,
_empty_local_disk: bool,
) -> Result<HashMap<TenantShardId, ReAttachResponseTenant>, RetryForeverError> {
unimplemented!()
}

View File

@@ -155,7 +155,7 @@ impl FeatureResolver {
);
let tenant_properties = PerTenantProperties {
remote_size_mb: Some(rand::thread_rng().gen_range(100.0..1000000.00)),
remote_size_mb: Some(rand::rng().random_range(100.0..1000000.00)),
}
.into_posthog_properties();

View File

@@ -16,7 +16,8 @@ use anyhow::{Context as _, bail};
use bytes::{Buf as _, BufMut as _, BytesMut};
use chrono::Utc;
use futures::future::BoxFuture;
use futures::{FutureExt, Stream};
use futures::stream::FuturesUnordered;
use futures::{FutureExt, Stream, StreamExt as _};
use itertools::Itertools;
use jsonwebtoken::TokenData;
use once_cell::sync::OnceCell;
@@ -35,8 +36,8 @@ use pageserver_api::pagestream_api::{
};
use pageserver_api::reltag::SlruKind;
use pageserver_api::shard::TenantShardId;
use pageserver_page_api as page_api;
use pageserver_page_api::proto;
use pageserver_page_api::{self as page_api, GetPageSplitter};
use postgres_backend::{
AuthType, PostgresBackend, PostgresBackendReader, QueryError, is_expected_io_error,
};
@@ -443,6 +444,7 @@ impl TimelineHandles {
handles: Default::default(),
}
}
async fn get(
&mut self,
tenant_id: TenantId,
@@ -469,6 +471,13 @@ impl TimelineHandles {
fn tenant_id(&self) -> Option<TenantId> {
self.wrapper.tenant_id.get().copied()
}
/// Returns whether a child shard exists locally for the given shard.
fn has_child_shard(&self, tenant_id: TenantId, shard_index: ShardIndex) -> bool {
self.wrapper
.tenant_manager
.has_child_shard(tenant_id, shard_index)
}
}
pub(crate) struct TenantManagerWrapper {
@@ -1636,9 +1645,10 @@ impl PageServerHandler {
let (shard, ctx) = upgrade_handle_and_set_context!(shard);
(
vec![
Self::handle_get_nblocks_request(&shard, &req, &ctx)
Self::handle_get_nblocks_request(&shard, &req, false, &ctx)
.instrument(span.clone())
.await
.map(|msg| msg.expect("allow_missing=false"))
.map(|msg| (PagestreamBeMessage::Nblocks(msg), timer, ctx))
.map_err(|err| BatchedPageStreamError { err, req: req.hdr }),
],
@@ -2303,12 +2313,16 @@ impl PageServerHandler {
Ok(PagestreamExistsResponse { req: *req, exists })
}
/// If `allow_missing` is true, returns None instead of Err on missing relations. Otherwise,
/// never returns None. It is only supported by the gRPC protocol, so we pass it separately to
/// avoid changing the libpq protocol types.
#[instrument(skip_all, fields(shard_id))]
async fn handle_get_nblocks_request(
timeline: &Timeline,
req: &PagestreamNblocksRequest,
allow_missing: bool,
ctx: &RequestContext,
) -> Result<PagestreamNblocksResponse, PageStreamError> {
) -> Result<Option<PagestreamNblocksResponse>, PageStreamError> {
let latest_gc_cutoff_lsn = timeline.get_applied_gc_cutoff_lsn();
let lsn = Self::wait_or_get_last_lsn(
timeline,
@@ -2320,20 +2334,25 @@ impl PageServerHandler {
.await?;
let n_blocks = timeline
.get_rel_size(
.get_rel_size_in_reldir(
req.rel,
Version::LsnRange(LsnRange {
effective_lsn: lsn,
request_lsn: req.hdr.request_lsn,
}),
None,
allow_missing,
ctx,
)
.await?;
let Some(n_blocks) = n_blocks else {
return Ok(None);
};
Ok(PagestreamNblocksResponse {
Ok(Some(PagestreamNblocksResponse {
req: *req,
n_blocks,
})
}))
}
#[instrument(skip_all, fields(shard_id))]
@@ -3368,17 +3387,9 @@ impl GrpcPageServiceHandler {
}
}
/// Acquires a timeline handle for the given request.
/// Acquires a timeline handle for the given request. The shard index must match a local shard.
///
/// TODO: during shard splits, the compute may still be sending requests to the parent shard
/// until the entire split is committed and the compute is notified. Consider installing a
/// temporary shard router from the parent to the children while the split is in progress.
///
/// TODO: consider moving this to a middleware layer; all requests need it. Needs to manage
/// the TimelineHandles lifecycle.
///
/// TODO: untangle acquisition from TenantManagerWrapper::resolve() and Cache::get(), to avoid
/// the unnecessary overhead.
/// NB: this will fail during shard splits, see comment on [`Self::maybe_split_get_page`].
async fn get_request_timeline(
&self,
req: &tonic::Request<impl Any>,
@@ -3387,11 +3398,62 @@ impl GrpcPageServiceHandler {
let shard_index = *extract::<ShardIndex>(req);
let shard_selector = ShardSelector::Known(shard_index);
// TODO: untangle acquisition from TenantManagerWrapper::resolve() and Cache::get(), to
// avoid the unnecessary overhead.
TimelineHandles::new(self.tenant_manager.clone())
.get(ttid.tenant_id, ttid.timeline_id, shard_selector)
.await
}
/// Acquires a timeline handle for the given request, which must be for shard zero.
///
/// NB: during an ongoing shard split, the compute will keep talking to the parent shard until
/// the split is committed, but the parent shard may have been removed in the meanwhile. In that
/// case, we reroute the request to the new child shard. See [`Self::maybe_split_get_page`].
///
/// TODO: revamp the split protocol to avoid this child routing.
async fn get_shard_zero_request_timeline(
&self,
req: &tonic::Request<impl Any>,
) -> Result<Handle<TenantManagerTypes>, tonic::Status> {
let ttid = *extract::<TenantTimelineId>(req);
let shard_index = *extract::<ShardIndex>(req);
if shard_index.shard_number.0 != 0 {
return Err(tonic::Status::invalid_argument(format!(
"request must use shard zero (requested shard {shard_index})",
)));
}
// TODO: untangle acquisition from TenantManagerWrapper::resolve() and Cache::get(), to
// avoid the unnecessary overhead.
//
// TODO: this does internal retries, which will delay requests during shard splits (we won't
// look for the child until the parent's retries are exhausted). Don't do that.
let mut handles = TimelineHandles::new(self.tenant_manager.clone());
match handles
.get(
ttid.tenant_id,
ttid.timeline_id,
ShardSelector::Known(shard_index),
)
.await
{
Ok(timeline) => Ok(timeline),
Err(err) => {
// We may be in the middle of a shard split. Try to find a child shard 0.
if let Ok(timeline) = handles
.get(ttid.tenant_id, ttid.timeline_id, ShardSelector::Zero)
.await
&& timeline.get_shard_index().shard_count > shard_index.shard_count
{
return Ok(timeline);
}
Err(err.into())
}
}
}
/// Starts a SmgrOpTimer at received_at, throttles the request, and records execution start.
/// Only errors if the timeline is shutting down.
///
@@ -3423,28 +3485,22 @@ impl GrpcPageServiceHandler {
/// TODO: get_vectored() currently enforces a batch limit of 32. Postgres will typically send
/// batches up to effective_io_concurrency = 100. Either we have to accept large batches, or
/// split them up in the client or server.
#[instrument(skip_all, fields(req_id, rel, blkno, blks, req_lsn, mod_lsn))]
#[instrument(skip_all, fields(
req_id = %req.request_id,
rel = %req.rel,
blkno = %req.block_numbers[0],
blks = %req.block_numbers.len(),
lsn = %req.read_lsn,
))]
async fn get_page(
ctx: &RequestContext,
timeline: &WeakHandle<TenantManagerTypes>,
req: proto::GetPageRequest,
timeline: Handle<TenantManagerTypes>,
req: page_api::GetPageRequest,
io_concurrency: IoConcurrency,
) -> Result<proto::GetPageResponse, tonic::Status> {
let received_at = Instant::now();
let timeline = timeline.upgrade()?;
received_at: Instant,
) -> Result<page_api::GetPageResponse, tonic::Status> {
let ctx = ctx.with_scope_page_service_pagestream(&timeline);
// Validate the request, decorate the span, and convert it to a Pagestream request.
let req = page_api::GetPageRequest::try_from(req)?;
span_record!(
req_id = %req.request_id,
rel = %req.rel,
blkno = %req.block_numbers[0],
blks = %req.block_numbers.len(),
lsn = %req.read_lsn,
);
let latest_gc_cutoff_lsn = timeline.get_applied_gc_cutoff_lsn(); // hold guard
let effective_lsn = PageServerHandler::effective_request_lsn(
&timeline,
@@ -3519,14 +3575,103 @@ impl GrpcPageServiceHandler {
};
}
Ok(resp.into())
Ok(resp)
}
/// Processes a GetPage request when there is a potential shard split in progress. We have to
/// reroute the request to any local child shards, and split batch requests that straddle multiple
/// child shards.
///
/// Parent shards are split and removed incrementally, but the compute is only notified once the
/// entire split commits, which can take several minutes. In the meanwhile, the compute will be
/// sending requests to the parent shard.
///
/// TODO: add test infrastructure to provoke this situation frequently and for long periods of
/// time, to properly exercise it.
///
/// TODO: revamp the split protocol to avoid this, e.g.:
/// * Keep the parent shard until the split commits and the compute is notified.
/// * Notify the compute about each subsplit.
/// * Return an error that updates the compute's shard map.
#[instrument(skip_all)]
async fn maybe_split_get_page(
ctx: &RequestContext,
handles: &mut TimelineHandles,
ttid: TenantTimelineId,
parent: ShardIndex,
req: page_api::GetPageRequest,
io_concurrency: IoConcurrency,
received_at: Instant,
) -> Result<page_api::GetPageResponse, tonic::Status> {
// Check the first page to see if we have any child shards at all. Otherwise, the compute is
// just talking to the wrong Pageserver. If the parent has been split, the shard now owning
// the page must have a higher shard count.
let timeline = handles
.get(
ttid.tenant_id,
ttid.timeline_id,
ShardSelector::Page(rel_block_to_key(req.rel, req.block_numbers[0])),
)
.await?;
let shard_id = timeline.get_shard_identity();
if shard_id.count <= parent.shard_count {
return Err(HandleUpgradeError::ShutDown.into()); // emulate original error
}
// Fast path: the request fits in a single shard.
if let Some(shard_index) =
GetPageSplitter::for_single_shard(&req, shard_id.count, Some(shard_id.stripe_size))
.map_err(|err| tonic::Status::internal(err.to_string()))?
{
// We got the shard ID from the first page, so these must be equal.
assert_eq!(shard_index.shard_number, shard_id.number);
assert_eq!(shard_index.shard_count, shard_id.count);
return Self::get_page(ctx, timeline, req, io_concurrency, received_at).await;
}
// The request spans multiple shards; split it and dispatch parallel requests. All pages
// were originally in the parent shard, and during a split all children are local, so we
// expect to find local shards for all pages.
let mut splitter = GetPageSplitter::split(req, shard_id.count, Some(shard_id.stripe_size))
.map_err(|err| tonic::Status::internal(err.to_string()))?;
let mut shard_requests = FuturesUnordered::new();
for (shard_index, shard_req) in splitter.drain_requests() {
let timeline = handles
.get(
ttid.tenant_id,
ttid.timeline_id,
ShardSelector::Known(shard_index),
)
.await?;
let future = Self::get_page(
ctx,
timeline,
shard_req,
io_concurrency.clone(),
received_at,
)
.map(move |result| result.map(|resp| (shard_index, resp)));
shard_requests.push(future);
}
while let Some((shard_index, shard_response)) = shard_requests.next().await.transpose()? {
splitter
.add_response(shard_index, shard_response)
.map_err(|err| tonic::Status::internal(err.to_string()))?;
}
splitter
.get_response()
.map_err(|err| tonic::Status::internal(err.to_string()))
}
}
/// Implements the gRPC page service.
///
/// Tonic will drop the request handler futures if the client goes away (e.g. due to a timeout or
/// cancellation), so the read path must be cancellation-safe. On shutdown, Tonic will wait for
/// On client disconnect (e.g. timeout or client shutdown), Tonic will drop the request handler
/// futures, so the read path must be cancellation-safe. On server shutdown, Tonic will wait for
/// in-flight requests to complete.
///
/// TODO: when the libpq impl is removed, remove the Pagestream types and inline the handler code.
@@ -3539,39 +3684,6 @@ impl proto::PageService for GrpcPageServiceHandler {
type GetPagesStream =
Pin<Box<dyn Stream<Item = Result<proto::GetPageResponse, tonic::Status>> + Send>>;
#[instrument(skip_all, fields(rel, lsn))]
async fn check_rel_exists(
&self,
req: tonic::Request<proto::CheckRelExistsRequest>,
) -> Result<tonic::Response<proto::CheckRelExistsResponse>, tonic::Status> {
let received_at = extract::<ReceivedAt>(&req).0;
let timeline = self.get_request_timeline(&req).await?;
let ctx = self.ctx.with_scope_page_service_pagestream(&timeline);
// Validate the request, decorate the span, and convert it to a Pagestream request.
Self::ensure_shard_zero(&timeline)?;
let req: page_api::CheckRelExistsRequest = req.into_inner().try_into()?;
span_record!(rel=%req.rel, lsn=%req.read_lsn);
let req = PagestreamExistsRequest {
hdr: Self::make_hdr(req.read_lsn, None),
rel: req.rel,
};
// Execute the request and convert the response.
let _timer = Self::record_op_start_and_throttle(
&timeline,
metrics::SmgrQueryType::GetRelExists,
received_at,
)
.await?;
let resp = PageServerHandler::handle_get_rel_exists_request(&timeline, &req, &ctx).await?;
let resp: page_api::CheckRelExistsResponse = resp.exists;
Ok(tonic::Response::new(resp.into()))
}
#[instrument(skip_all, fields(lsn))]
async fn get_base_backup(
&self,
@@ -3581,11 +3693,10 @@ impl proto::PageService for GrpcPageServiceHandler {
// to be the sweet spot where throughput is saturated.
const CHUNK_SIZE: usize = 256 * 1024;
let timeline = self.get_request_timeline(&req).await?;
let timeline = self.get_shard_zero_request_timeline(&req).await?;
let ctx = self.ctx.with_scope_timeline(&timeline);
// Validate the request and decorate the span.
Self::ensure_shard_zero(&timeline)?;
if timeline.is_archived() == Some(true) {
return Err(tonic::Status::failed_precondition("timeline is archived"));
}
@@ -3701,11 +3812,10 @@ impl proto::PageService for GrpcPageServiceHandler {
req: tonic::Request<proto::GetDbSizeRequest>,
) -> Result<tonic::Response<proto::GetDbSizeResponse>, tonic::Status> {
let received_at = extract::<ReceivedAt>(&req).0;
let timeline = self.get_request_timeline(&req).await?;
let timeline = self.get_shard_zero_request_timeline(&req).await?;
let ctx = self.ctx.with_scope_page_service_pagestream(&timeline);
// Validate the request, decorate the span, and convert it to a Pagestream request.
Self::ensure_shard_zero(&timeline)?;
let req: page_api::GetDbSizeRequest = req.into_inner().try_into()?;
span_record!(db_oid=%req.db_oid, lsn=%req.read_lsn);
@@ -3734,14 +3844,33 @@ impl proto::PageService for GrpcPageServiceHandler {
req: tonic::Request<tonic::Streaming<proto::GetPageRequest>>,
) -> Result<tonic::Response<Self::GetPagesStream>, tonic::Status> {
// Extract the timeline from the request and check that it exists.
//
// NB: during shard splits, the compute may still send requests to the parent shard. We'll
// reroute requests to the child shards below, but we also detect the common cases here
// where either the shard exists or no shards exist at all. If we have a child shard, we
// can't acquire a weak handle because we don't know which child shard to use yet.
//
// TODO: TimelineHandles.get() does internal retries, which will delay requests during shard
// splits. It shouldn't.
let ttid = *extract::<TenantTimelineId>(&req);
let shard_index = *extract::<ShardIndex>(&req);
let shard_selector = ShardSelector::Known(shard_index);
let mut handles = TimelineHandles::new(self.tenant_manager.clone());
handles
.get(ttid.tenant_id, ttid.timeline_id, shard_selector)
.await?;
let timeline = match handles
.get(
ttid.tenant_id,
ttid.timeline_id,
ShardSelector::Known(shard_index),
)
.await
{
// The timeline shard exists. Keep a weak handle to reuse for each request.
Ok(timeline) => Some(timeline.downgrade()),
// The shard doesn't exist, but a child shard does. We'll reroute requests later.
Err(_) if handles.has_child_shard(ttid.tenant_id, shard_index) => None,
// Failed to fetch the timeline, and no child shard exists. Error out.
Err(err) => return Err(err.into()),
};
// Spawn an IoConcurrency sidecar, if enabled.
let gate_guard = self
@@ -3758,31 +3887,58 @@ impl proto::PageService for GrpcPageServiceHandler {
let mut reqs = req.into_inner();
let resps = async_stream::try_stream! {
let timeline = handles
.get(ttid.tenant_id, ttid.timeline_id, shard_selector)
.await?
.downgrade();
loop {
// Wait for the next client request.
//
// NB: Tonic considers the entire stream to be an in-flight request and will wait
// for it to complete before shutting down. React to cancellation between requests.
let req = tokio::select! {
biased;
_ = cancel.cancelled() => Err(tonic::Status::unavailable("shutting down")),
result = reqs.message() => match result {
Ok(Some(req)) => Ok(req),
Ok(None) => break, // client closed the stream
Err(err) => Err(err),
},
_ = cancel.cancelled() => Err(tonic::Status::unavailable("shutting down")),
}?;
let received_at = Instant::now();
let req_id = req.request_id.map(page_api::RequestID::from).unwrap_or_default();
let result = Self::get_page(&ctx, &timeline, req, io_concurrency.clone())
// Process the request, using a closure to capture errors.
let process_request = async || {
let req = page_api::GetPageRequest::try_from(req)?;
// Fast path: use the pre-acquired timeline handle.
if let Some(Ok(timeline)) = timeline.as_ref().map(|t| t.upgrade()) {
return Self::get_page(&ctx, timeline, req, io_concurrency.clone(), received_at)
.instrument(span.clone()) // propagate request span
.await
}
// The timeline handle is stale. During shard splits, the compute may still be
// sending requests to the parent shard. Try to re-route requests to the child
// shards, and split any batch requests that straddle multiple child shards.
Self::maybe_split_get_page(
&ctx,
&mut handles,
ttid,
shard_index,
req,
io_concurrency.clone(),
received_at,
)
.instrument(span.clone()) // propagate request span
.await;
yield match result {
Ok(resp) => resp,
// Convert per-request errors to GetPageResponses as appropriate, or terminate
// the stream with a tonic::Status. Log the error regardless, since
// ObservabilityLayer can't automatically log stream errors.
.await
};
// Return the response. Convert per-request errors to GetPageResponses if
// appropriate, or terminate the stream with a tonic::Status.
yield match process_request().await {
Ok(resp) => resp.into(),
Err(status) => {
// Log the error, since ObservabilityLayer won't see stream errors.
// TODO: it would be nice if we could propagate the get_page() fields here.
span.in_scope(|| {
warn!("request failed with {:?}: {}", status.code(), status.message());
@@ -3796,20 +3952,20 @@ impl proto::PageService for GrpcPageServiceHandler {
Ok(tonic::Response::new(Box::pin(resps)))
}
#[instrument(skip_all, fields(rel, lsn))]
#[instrument(skip_all, fields(rel, lsn, allow_missing))]
async fn get_rel_size(
&self,
req: tonic::Request<proto::GetRelSizeRequest>,
) -> Result<tonic::Response<proto::GetRelSizeResponse>, tonic::Status> {
let received_at = extract::<ReceivedAt>(&req).0;
let timeline = self.get_request_timeline(&req).await?;
let timeline = self.get_shard_zero_request_timeline(&req).await?;
let ctx = self.ctx.with_scope_page_service_pagestream(&timeline);
// Validate the request, decorate the span, and convert it to a Pagestream request.
Self::ensure_shard_zero(&timeline)?;
let req: page_api::GetRelSizeRequest = req.into_inner().try_into()?;
let allow_missing = req.allow_missing;
span_record!(rel=%req.rel, lsn=%req.read_lsn);
span_record!(rel=%req.rel, lsn=%req.read_lsn, allow_missing=%req.allow_missing);
let req = PagestreamNblocksRequest {
hdr: Self::make_hdr(req.read_lsn, None),
@@ -3824,8 +3980,11 @@ impl proto::PageService for GrpcPageServiceHandler {
)
.await?;
let resp = PageServerHandler::handle_get_nblocks_request(&timeline, &req, &ctx).await?;
let resp: page_api::GetRelSizeResponse = resp.n_blocks;
let resp =
PageServerHandler::handle_get_nblocks_request(&timeline, &req, allow_missing, &ctx)
.await?;
let resp: page_api::GetRelSizeResponse = resp.map(|resp| resp.n_blocks);
Ok(tonic::Response::new(resp.into()))
}
@@ -3835,7 +3994,7 @@ impl proto::PageService for GrpcPageServiceHandler {
req: tonic::Request<proto::GetSlruSegmentRequest>,
) -> Result<tonic::Response<proto::GetSlruSegmentResponse>, tonic::Status> {
let received_at = extract::<ReceivedAt>(&req).0;
let timeline = self.get_request_timeline(&req).await?;
let timeline = self.get_shard_zero_request_timeline(&req).await?;
let ctx = self.ctx.with_scope_page_service_pagestream(&timeline);
// Validate the request, decorate the span, and convert it to a Pagestream request.
@@ -3869,6 +4028,10 @@ impl proto::PageService for GrpcPageServiceHandler {
&self,
req: tonic::Request<proto::LeaseLsnRequest>,
) -> Result<tonic::Response<proto::LeaseLsnResponse>, tonic::Status> {
// TODO: this won't work during shard splits, as the request is directed at a specific shard
// but the parent shard is removed before the split commits and the compute is notified
// (which can take several minutes for large tenants). That's also the case for the libpq
// implementation, so we keep the behavior for now.
let timeline = self.get_request_timeline(&req).await?;
let ctx = self.ctx.with_scope_timeline(&timeline);

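`maybe_split_get_page` above splits a batch that straddles child shards, dispatches one sub-request per shard through a `FuturesUnordered`, and reassembles the responses. A minimal self-contained sketch of that fan-out/collect pattern with hypothetical request and response types (the real splitting and reassembly lives in `page_api::GetPageSplitter`):

```rust
use std::collections::HashMap;

use futures::StreamExt as _;
use futures::stream::FuturesUnordered;

// Hypothetical stand-ins for per-shard requests/responses.
type ShardIndex = u8;
type ShardRequest = Vec<u32>; // block numbers routed to one shard
type ShardResponse = Vec<u32>; // pages returned by one shard

// Stand-in for dispatching one sub-request to a child shard.
async fn get_page_on_shard(shard: ShardIndex, req: ShardRequest) -> Result<ShardResponse, String> {
    Ok(req.into_iter().map(|blkno| blkno + shard as u32).collect())
}

async fn fan_out(
    split: HashMap<ShardIndex, ShardRequest>,
) -> Result<HashMap<ShardIndex, ShardResponse>, String> {
    // Dispatch all sub-requests in parallel; completion order doesn't matter,
    // the caller reassembles pages into the original request order afterwards.
    let mut shard_requests = FuturesUnordered::new();
    for (shard_index, shard_req) in split {
        shard_requests.push(async move {
            get_page_on_shard(shard_index, shard_req)
                .await
                .map(|resp| (shard_index, resp))
        });
    }

    // Collect responses as they complete; the first error terminates the whole batch.
    let mut responses = HashMap::new();
    while let Some((shard_index, shard_response)) = shard_requests.next().await.transpose()? {
        responses.insert(shard_index, shard_response);
    }
    Ok(responses)
}

#[tokio::main]
async fn main() -> Result<(), String> {
    let split = HashMap::from([(0, vec![1, 2]), (1, vec![3])]);
    let responses = fan_out(split).await?;
    println!("{responses:?}");
    Ok(())
}
```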
View File

@@ -504,8 +504,9 @@ impl Timeline {
for rel in rels {
let n_blocks = self
.get_rel_size_in_reldir(rel, version, Some((reldir_key, &reldir)), ctx)
.await?;
.get_rel_size_in_reldir(rel, version, Some((reldir_key, &reldir)), false, ctx)
.await?
.expect("allow_missing=false");
total_blocks += n_blocks as usize;
}
Ok(total_blocks)
@@ -521,10 +522,16 @@ impl Timeline {
version: Version<'_>,
ctx: &RequestContext,
) -> Result<BlockNumber, PageReconstructError> {
self.get_rel_size_in_reldir(tag, version, None, ctx).await
Ok(self
.get_rel_size_in_reldir(tag, version, None, false, ctx)
.await?
.expect("allow_missing=false"))
}
/// Get size of a relation file. The relation must exist, otherwise an error is returned.
/// Get size of a relation file. If `allow_missing` is true, returns None for missing relations,
/// otherwise errors.
///
/// INVARIANT: never returns None if `allow_missing=false`.
///
/// See [`Self::get_rel_exists_in_reldir`] on why we need `deserialized_reldir_v1`.
pub(crate) async fn get_rel_size_in_reldir(
@@ -532,8 +539,9 @@ impl Timeline {
tag: RelTag,
version: Version<'_>,
deserialized_reldir_v1: Option<(Key, &RelDirectory)>,
allow_missing: bool,
ctx: &RequestContext,
) -> Result<BlockNumber, PageReconstructError> {
) -> Result<Option<BlockNumber>, PageReconstructError> {
if tag.relnode == 0 {
return Err(PageReconstructError::Other(
RelationError::InvalidRelnode.into(),
@@ -541,7 +549,15 @@ impl Timeline {
}
if let Some(nblocks) = self.get_cached_rel_size(&tag, version) {
return Ok(nblocks);
return Ok(Some(nblocks));
}
if allow_missing
&& !self
.get_rel_exists_in_reldir(tag, version, deserialized_reldir_v1, ctx)
.await?
{
return Ok(None);
}
if (tag.forknum == FSM_FORKNUM || tag.forknum == VISIBILITYMAP_FORKNUM)
@@ -553,7 +569,7 @@ impl Timeline {
// FSM, and smgrnblocks() on it immediately afterwards,
// without extending it. Tolerate that by claiming that
// any non-existent FSM fork has size 0.
return Ok(0);
return Ok(Some(0));
}
let key = rel_size_to_key(tag);
@@ -562,7 +578,7 @@ impl Timeline {
self.update_cached_rel_size(tag, version, nblocks);
Ok(nblocks)
Ok(Some(nblocks))
}
/// Does the relation exist?
@@ -2912,9 +2928,8 @@ static ZERO_PAGE: Bytes = Bytes::from_static(&[0u8; BLCKSZ as usize]);
mod tests {
use hex_literal::hex;
use pageserver_api::models::ShardParameters;
use pageserver_api::shard::ShardStripeSize;
use utils::id::TimelineId;
use utils::shard::{ShardCount, ShardNumber};
use utils::shard::{ShardCount, ShardNumber, ShardStripeSize};
use super::*;
use crate::DEFAULT_PG_VERSION;

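The `get_rel_size_in_reldir` change above makes the return type `Result<Option<BlockNumber>, _>`: with `allow_missing=true` a missing relation yields `Ok(None)`, while `allow_missing=false` preserves the old behavior and never returns `None`, so callers may `.expect("allow_missing=false")`. A small sketch of that calling convention with hypothetical lookup types:

```rust
type BlockNumber = u32;

#[derive(Debug)]
struct RelationError(&'static str);

// Hypothetical lookup table standing in for the relation directory.
fn rel_exists(sizes: &[(u32, BlockNumber)], relnode: u32) -> bool {
    sizes.iter().any(|(r, _)| *r == relnode)
}

/// INVARIANT: never returns Ok(None) if `allow_missing == false`.
fn get_rel_size(
    sizes: &[(u32, BlockNumber)],
    relnode: u32,
    allow_missing: bool,
) -> Result<Option<BlockNumber>, RelationError> {
    if relnode == 0 {
        return Err(RelationError("invalid relnode"));
    }
    if allow_missing && !rel_exists(sizes, relnode) {
        return Ok(None); // missing relations are tolerated only when asked for
    }
    sizes
        .iter()
        .find(|(r, _)| *r == relnode)
        .map(|(_, n_blocks)| Some(*n_blocks))
        .ok_or(RelationError("relation not found"))
}

fn main() -> Result<(), RelationError> {
    let sizes = [(16384, 8), (16385, 0)];

    // allow_missing=false: the Option is always Some, mirroring the
    // `.expect("allow_missing=false")` call sites above.
    let n_blocks = get_rel_size(&sizes, 16384, false)?.expect("allow_missing=false");
    assert_eq!(n_blocks, 8);

    // allow_missing=true: a missing relation becomes Ok(None) instead of an error.
    assert_eq!(get_rel_size(&sizes, 99999, true)?, None);
    Ok(())
}
```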
View File

@@ -6161,11 +6161,11 @@ mod tests {
use pageserver_api::keyspace::KeySpaceRandomAccum;
use pageserver_api::models::{CompactionAlgorithm, CompactionAlgorithmSettings, LsnLease};
use pageserver_compaction::helpers::overlaps_with;
use rand::Rng;
#[cfg(feature = "testing")]
use rand::SeedableRng;
#[cfg(feature = "testing")]
use rand::rngs::StdRng;
use rand::{Rng, thread_rng};
#[cfg(feature = "testing")]
use std::ops::Range;
use storage_layer::{IoConcurrency, PersistentLayerKey};
@@ -6286,8 +6286,8 @@ mod tests {
while lsn < lsn_range.end {
let mut key = key_range.start;
while key < key_range.end {
let gap = random.gen_range(1..=100) <= spec.gap_chance;
let will_init = random.gen_range(1..=100) <= spec.will_init_chance;
let gap = random.random_range(1..=100) <= spec.gap_chance;
let will_init = random.random_range(1..=100) <= spec.will_init_chance;
if gap {
continue;
@@ -6330,8 +6330,8 @@ mod tests {
while lsn < lsn_range.end {
let mut key = key_range.start;
while key < key_range.end {
let gap = random.gen_range(1..=100) <= spec.gap_chance;
let will_init = random.gen_range(1..=100) <= spec.will_init_chance;
let gap = random.random_range(1..=100) <= spec.gap_chance;
let will_init = random.random_range(1..=100) <= spec.will_init_chance;
if gap {
continue;
@@ -7808,7 +7808,7 @@ mod tests {
for _ in 0..50 {
for _ in 0..NUM_KEYS {
lsn = Lsn(lsn.0 + 0x10);
let blknum = thread_rng().gen_range(0..NUM_KEYS);
let blknum = rand::rng().random_range(0..NUM_KEYS);
test_key.field6 = blknum as u32;
let mut writer = tline.writer().await;
writer
@@ -7897,7 +7897,7 @@ mod tests {
for _ in 0..NUM_KEYS {
lsn = Lsn(lsn.0 + 0x10);
let blknum = thread_rng().gen_range(0..NUM_KEYS);
let blknum = rand::rng().random_range(0..NUM_KEYS);
test_key.field6 = blknum as u32;
let mut writer = tline.writer().await;
writer
@@ -7965,7 +7965,7 @@ mod tests {
for _ in 0..NUM_KEYS {
lsn = Lsn(lsn.0 + 0x10);
let blknum = thread_rng().gen_range(0..NUM_KEYS);
let blknum = rand::rng().random_range(0..NUM_KEYS);
test_key.field6 = blknum as u32;
let mut writer = tline.writer().await;
writer
@@ -8229,7 +8229,7 @@ mod tests {
for _ in 0..NUM_KEYS {
lsn = Lsn(lsn.0 + 0x10);
let blknum = thread_rng().gen_range(0..NUM_KEYS);
let blknum = rand::rng().random_range(0..NUM_KEYS);
test_key.field6 = (blknum * STEP) as u32;
let mut writer = tline.writer().await;
writer
@@ -8502,7 +8502,7 @@ mod tests {
for iter in 1..=10 {
for _ in 0..NUM_KEYS {
lsn = Lsn(lsn.0 + 0x10);
let blknum = thread_rng().gen_range(0..NUM_KEYS);
let blknum = rand::rng().random_range(0..NUM_KEYS);
test_key.field6 = (blknum * STEP) as u32;
let mut writer = tline.writer().await;
writer
@@ -11291,10 +11291,10 @@ mod tests {
#[cfg(feature = "testing")]
#[tokio::test]
async fn test_read_path() -> anyhow::Result<()> {
use rand::seq::SliceRandom;
use rand::seq::IndexedRandom;
let seed = if cfg!(feature = "fuzz-read-path") {
let seed: u64 = thread_rng().r#gen();
let seed: u64 = rand::rng().random();
seed
} else {
// Use a hard-coded seed when not in fuzzing mode.
@@ -11308,8 +11308,8 @@ mod tests {
let (queries, will_init_chance, gap_chance) = if cfg!(feature = "fuzz-read-path") {
const QUERIES: u64 = 5000;
let will_init_chance: u8 = random.gen_range(0..=10);
let gap_chance: u8 = random.gen_range(0..=50);
let will_init_chance: u8 = random.random_range(0..=10);
let gap_chance: u8 = random.random_range(0..=50);
(QUERIES, will_init_chance, gap_chance)
} else {
@@ -11410,7 +11410,8 @@ mod tests {
while used_keys.len() < tenant.conf.max_get_vectored_keys.get() {
let selected_lsn = interesting_lsns.choose(&mut random).expect("not empty");
let mut selected_key = start_key.add(random.gen_range(0..KEY_DIMENSION_SIZE));
let mut selected_key =
start_key.add(random.random_range(0..KEY_DIMENSION_SIZE));
while used_keys.len() < tenant.conf.max_get_vectored_keys.get() {
if used_keys.contains(&selected_key)
@@ -11425,7 +11426,7 @@ mod tests {
.add_key(selected_key);
used_keys.insert(selected_key);
let pick_next = random.gen_range(0..=100) <= PICK_NEXT_CHANCE;
let pick_next = random.random_range(0..=100) <= PICK_NEXT_CHANCE;
if pick_next {
selected_key = selected_key.next();
} else {

View File

@@ -535,8 +535,8 @@ pub(crate) mod tests {
}
pub(crate) fn random_array(len: usize) -> Vec<u8> {
let mut rng = rand::thread_rng();
(0..len).map(|_| rng.r#gen()).collect::<_>()
let mut rng = rand::rng();
(0..len).map(|_| rng.random()).collect::<_>()
}
#[tokio::test]
@@ -588,9 +588,9 @@ pub(crate) mod tests {
let mut rng = rand::rngs::StdRng::seed_from_u64(42);
let blobs = (0..1024)
.map(|_| {
let mut sz: u16 = rng.r#gen();
let mut sz: u16 = rng.random();
// Make 50% of the arrays small
if rng.r#gen() {
if rng.random() {
sz &= 63;
}
random_array(sz.into())

View File

@@ -1090,7 +1090,7 @@ pub(crate) mod tests {
const NUM_KEYS: usize = 100000;
let mut all_data: BTreeMap<u128, u64> = BTreeMap::new();
for idx in 0..NUM_KEYS {
let u: f64 = rand::thread_rng().gen_range(0.0..1.0);
let u: f64 = rand::rng().random_range(0.0..1.0);
let t = -(f64::ln(u));
let key_int = (t * 1000000.0) as u128;
@@ -1116,7 +1116,7 @@ pub(crate) mod tests {
// Test get() operations on random keys, most of which will not exist
for _ in 0..100000 {
let key_int = rand::thread_rng().r#gen::<u128>();
let key_int = rand::rng().random::<u128>();
let search_key = u128::to_be_bytes(key_int);
assert!(reader.get(&search_key, &ctx).await? == all_data.get(&key_int).cloned());
}

View File

@@ -508,8 +508,8 @@ mod tests {
let write_nbytes = cap * 2 + cap / 2;
let content: Vec<u8> = rand::thread_rng()
.sample_iter(rand::distributions::Standard)
let content: Vec<u8> = rand::rng()
.sample_iter(rand::distr::StandardUniform)
.take(write_nbytes)
.collect();
@@ -565,8 +565,8 @@ mod tests {
let cap = writer.mutable().capacity();
drop(writer);
let content: Vec<u8> = rand::thread_rng()
.sample_iter(rand::distributions::Standard)
let content: Vec<u8> = rand::rng()
.sample_iter(rand::distr::StandardUniform)
.take(cap * 2 + cap / 2)
.collect();
@@ -614,8 +614,8 @@ mod tests {
let cap = mutable.capacity();
let align = mutable.align();
drop(writer);
let content: Vec<u8> = rand::thread_rng()
.sample_iter(rand::distributions::Standard)
let content: Vec<u8> = rand::rng()
.sample_iter(rand::distr::StandardUniform)
.take(cap * 2 + cap / 2)
.collect();

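The test hunks above also switch random buffer generation to the rand 0.9 names: `rand::rng().sample_iter(rand::distr::StandardUniform)` replaces `rand::thread_rng().sample_iter(rand::distributions::Standard)`. A tiny sketch, again assuming `rand = "0.9"`:

```rust
use rand::Rng as _;
use rand::distr::StandardUniform;

fn random_array(len: usize) -> Vec<u8> {
    // rand 0.9: `rand::rng()` + `StandardUniform` replace `thread_rng()` + `Standard`.
    rand::rng().sample_iter(StandardUniform).take(len).collect()
}

fn main() {
    let content = random_array(16);
    assert_eq!(content.len(), 16);
    println!("{content:?}");
}
```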
Some files were not shown because too many files have changed in this diff