Merge remote-tracking branch 'origin/main' into communicator-rewrite

This commit is contained in:
Heikki Linnakangas
2025-07-23 00:00:10 +03:00
190 changed files with 5161 additions and 1877 deletions

View File

@@ -21,13 +21,14 @@ platforms = [
# "x86_64-apple-darwin",
# "x86_64-pc-windows-msvc",
]
[final-excludes]
workspace-members = [
# vm_monitor benefits from the same Cargo.lock as the rest of our artifacts, but
# it is built primarly in separate repo neondatabase/autoscaling and thus is excluded
# from depending on workspace-hack because most of the dependencies are not used.
"vm_monitor",
# subzero-core is a stub crate that should be excluded from workspace-hack
"subzero-core",
# All of these exist in libs and are not usually built independently.
# Putting workspace hack there adds a bottleneck for cargo builds.
"compute_api",

View File

@@ -0,0 +1,28 @@
name: 'Prepare current job for subzero'
description: >
Set git token to access `neondatabase/subzero` from cargo build,
and set `CARGO_NET_GIT_FETCH_WITH_CLI=true` env variable to use git CLI
inputs:
token:
description: 'GitHub token with access to neondatabase/subzero'
required: true
runs:
using: "composite"
steps:
- name: Set git token for neondatabase/subzero
uses: pyTooling/Actions/with-post-step@2307b526df64d55e95884e072e49aac2a00a9afa # v5.1.0
env:
SUBZERO_ACCESS_TOKEN: ${{ inputs.token }}
with:
main: |
git config --global url."https://x-access-token:${SUBZERO_ACCESS_TOKEN}@github.com/neondatabase/subzero".insteadOf "https://github.com/neondatabase/subzero"
cargo add -p proxy subzero-core --git https://github.com/neondatabase/subzero --rev 396264617e78e8be428682f87469bb25429af88a
post: |
git config --global --unset url."https://x-access-token:${SUBZERO_ACCESS_TOKEN}@github.com/neondatabase/subzero".insteadOf "https://github.com/neondatabase/subzero"
- name: Set `CARGO_NET_GIT_FETCH_WITH_CLI=true` env variable
shell: bash -euxo pipefail {0}
run: echo "CARGO_NET_GIT_FETCH_WITH_CLI=true" >> ${GITHUB_ENV}

View File

@@ -86,6 +86,10 @@ jobs:
with:
submodules: true
- uses: ./.github/actions/prepare-for-subzero
with:
token: ${{ secrets.CI_ACCESS_TOKEN }}
- name: Set pg 14 revision for caching
id: pg_v14_rev
run: echo pg_rev=$(git rev-parse HEAD:vendor/postgres-v14) >> $GITHUB_OUTPUT
@@ -116,7 +120,7 @@ jobs:
ARCH: ${{ inputs.arch }}
SANITIZERS: ${{ inputs.sanitizers }}
run: |
CARGO_FLAGS="--locked --features testing"
CARGO_FLAGS="--locked --features testing,rest_broker"
if [[ $BUILD_TYPE == "debug" && $ARCH == 'x64' ]]; then
cov_prefix="scripts/coverage --profraw-prefix=$GITHUB_JOB --dir=/tmp/coverage run"
CARGO_PROFILE=""

View File

@@ -46,6 +46,10 @@ jobs:
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
with:
submodules: true
- uses: ./.github/actions/prepare-for-subzero
with:
token: ${{ secrets.CI_ACCESS_TOKEN }}
- name: Cache cargo deps
uses: tespkg/actions-cache@b7bf5fcc2f98a52ac6080eb0fd282c2f752074b1 # v1.8.0

View File

@@ -54,6 +54,10 @@ jobs:
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
with:
submodules: true
- uses: ./.github/actions/prepare-for-subzero
with:
token: ${{ secrets.CI_ACCESS_TOKEN }}
- name: Install build dependencies
run: |

View File

@@ -632,6 +632,8 @@ jobs:
BUILD_TAG=${{ needs.meta.outputs.release-tag || needs.meta.outputs.build-tag }}
TAG=${{ needs.build-build-tools-image.outputs.image-tag }}-bookworm
DEBIAN_VERSION=bookworm
secrets: |
SUBZERO_ACCESS_TOKEN=${{ secrets.CI_ACCESS_TOKEN }}
provenance: false
push: true
pull: true

View File

@@ -72,6 +72,7 @@ jobs:
check-macos-build:
needs: [ check-permissions, files-changed ]
uses: ./.github/workflows/build-macos.yml
secrets: inherit
with:
pg_versions: ${{ needs.files-changed.outputs.postgres_changes }}
rebuild_rust_code: ${{ fromJSON(needs.files-changed.outputs.rebuild_rust_code) }}

5
.gitignore vendored
View File

@@ -27,9 +27,14 @@ docker-compose/docker-compose-parallel.yml
*.o
*.so
*.Po
*.pid
# pgindent typedef lists
*.list
# Node
**/node_modules/
# various files for local testing
/proxy/.subzero
local_proxy.json

528
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -50,6 +50,7 @@ members = [
"libs/proxy/tokio-postgres2",
"endpoint_storage",
"pgxn/neon/communicator",
"proxy/subzero_core",
]
[workspace.package]
@@ -144,10 +145,10 @@ notify = "6.0.0"
num_cpus = "1.15"
num-traits = "0.2.19"
once_cell = "1.13"
opentelemetry = "0.27"
opentelemetry_sdk = "0.27"
opentelemetry-otlp = { version = "0.27", default-features = false, features = ["http-proto", "trace", "http", "reqwest-client"] }
opentelemetry-semantic-conventions = "0.27"
opentelemetry = "0.30"
opentelemetry_sdk = "0.30"
opentelemetry-otlp = { version = "0.30", default-features = false, features = ["http-proto", "trace", "http", "reqwest-client"] }
opentelemetry-semantic-conventions = "0.30"
parking_lot = "0.12"
parquet = { version = "53", default-features = false, features = ["zstd"] }
parquet_derive = "53"
@@ -160,11 +161,13 @@ procfs = "0.16"
prometheus = {version = "0.13", default-features=false, features = ["process"]} # removes protobuf dependency
prost = "0.13.5"
prost-types = "0.13.5"
rand = "0.8"
rand = "0.9"
# Remove after p256 is updated to 0.14.
rand_core = "=0.6"
redis = { version = "0.29.2", features = ["tokio-rustls-comp", "keep-alive"] }
regex = "1.10.2"
reqwest = { version = "0.12", default-features = false, features = ["rustls-tls"] }
reqwest-tracing = { version = "0.5", features = ["opentelemetry_0_27"] }
reqwest-tracing = { version = "0.5", features = ["opentelemetry_0_30"] }
reqwest-middleware = "0.4"
reqwest-retry = "0.7"
routerify = "3"
@@ -214,15 +217,12 @@ tonic = { version = "0.13.1", default-features = false, features = ["channel", "
tonic-reflection = { version = "0.13.1", features = ["server"] }
tower = { version = "0.5.2", default-features = false }
tower-http = { version = "0.6.2", features = ["auth", "request-id", "trace"] }
# This revision uses opentelemetry 0.27. There's no tag for it.
tower-otel = { git = "https://github.com/mattiapenati/tower-otel", rev = "56a7321053bcb72443888257b622ba0d43a11fcd" }
tower-otel = { version = "0.6", features = ["axum"] }
tower-service = "0.3.3"
tracing = "0.1"
tracing-error = "0.2"
tracing-log = "0.2"
tracing-opentelemetry = "0.28"
tracing-opentelemetry = "0.31"
tracing-serde = "0.2.0"
tracing-subscriber = { version = "0.3", default-features = false, features = ["smallvec", "fmt", "tracing-log", "std", "env-filter", "json"] }
try-lock = "0.2.5"

View File

@@ -63,7 +63,14 @@ WORKDIR /home/nonroot
COPY --chown=nonroot . .
RUN cargo chef prepare --recipe-path recipe.json
RUN --mount=type=secret,uid=1000,id=SUBZERO_ACCESS_TOKEN \
set -e \
&& if [ -s /run/secrets/SUBZERO_ACCESS_TOKEN ]; then \
export CARGO_NET_GIT_FETCH_WITH_CLI=true && \
git config --global url."https://$(cat /run/secrets/SUBZERO_ACCESS_TOKEN)@github.com/neondatabase/subzero".insteadOf "https://github.com/neondatabase/subzero" && \
cargo add -p proxy subzero-core --git https://github.com/neondatabase/subzero --rev 396264617e78e8be428682f87469bb25429af88a; \
fi \
&& cargo chef prepare --recipe-path recipe.json
# Main build image
FROM $REPOSITORY/$IMAGE:$TAG AS build
@@ -71,20 +78,33 @@ WORKDIR /home/nonroot
ARG GIT_VERSION=local
ARG BUILD_TAG
ARG ADDITIONAL_RUSTFLAGS=""
ENV CARGO_FEATURES="default"
# 3. Build cargo dependencies. Note that this step doesn't depend on anything else than
# `recipe.json`, so the layer can be reused as long as none of the dependencies change.
COPY --from=plan /home/nonroot/recipe.json recipe.json
RUN set -e \
RUN --mount=type=secret,uid=1000,id=SUBZERO_ACCESS_TOKEN \
set -e \
&& if [ -s /run/secrets/SUBZERO_ACCESS_TOKEN ]; then \
export CARGO_NET_GIT_FETCH_WITH_CLI=true && \
git config --global url."https://$(cat /run/secrets/SUBZERO_ACCESS_TOKEN)@github.com/neondatabase/subzero".insteadOf "https://github.com/neondatabase/subzero"; \
fi \
&& RUSTFLAGS="-Clinker=clang -Clink-arg=-fuse-ld=mold -Clink-arg=-Wl,--no-rosegment -Cforce-frame-pointers=yes ${ADDITIONAL_RUSTFLAGS}" cargo chef cook --locked --release --recipe-path recipe.json
# Perform the main build. We reuse the Postgres build artifacts from the intermediate 'pg-build'
# layer, and the cargo dependencies built in the previous step.
COPY --chown=nonroot --from=pg-build /home/nonroot/pg_install/ pg_install
COPY --chown=nonroot . .
COPY --chown=nonroot --from=plan /home/nonroot/proxy/Cargo.toml proxy/Cargo.toml
COPY --chown=nonroot --from=plan /home/nonroot/Cargo.lock Cargo.lock
RUN set -e \
RUN --mount=type=secret,uid=1000,id=SUBZERO_ACCESS_TOKEN \
set -e \
&& if [ -s /run/secrets/SUBZERO_ACCESS_TOKEN ]; then \
export CARGO_FEATURES="rest_broker"; \
fi \
&& RUSTFLAGS="-Clinker=clang -Clink-arg=-fuse-ld=mold -Clink-arg=-Wl,--no-rosegment -Cforce-frame-pointers=yes ${ADDITIONAL_RUSTFLAGS}" cargo build \
--features $CARGO_FEATURES \
--bin pg_sni_router \
--bin pageserver \
--bin pagectl \

View File

@@ -27,7 +27,10 @@ fail.workspace = true
flate2.workspace = true
futures.workspace = true
http.workspace = true
http-body-util.workspace = true
hostname-validator = "1.1"
hyper.workspace = true
hyper-util.workspace = true
indexmap.workspace = true
itertools.workspace = true
jsonwebtoken.workspace = true
@@ -44,6 +47,7 @@ postgres.workspace = true
regex.workspace = true
reqwest = { workspace = true, features = ["json"] }
ring = "0.17"
scopeguard.workspace = true
serde.workspace = true
serde_with.workspace = true
serde_json.workspace = true

View File

@@ -138,6 +138,12 @@ struct Cli {
/// Run in development mode, skipping VM-specific operations like process termination
#[arg(long, action = clap::ArgAction::SetTrue)]
pub dev: bool,
#[arg(long)]
pub pg_init_timeout: Option<u64>,
#[arg(long, default_value_t = false, action = clap::ArgAction::Set)]
pub lakebase_mode: bool,
}
impl Cli {
@@ -188,7 +194,7 @@ fn main() -> Result<()> {
.build()?;
let _rt_guard = runtime.enter();
runtime.block_on(init(cli.dev))?;
let tracing_provider = init(cli.dev)?;
// enable core dumping for all child processes
setrlimit(Resource::CORE, rlimit::INFINITY, rlimit::INFINITY)?;
@@ -219,6 +225,8 @@ fn main() -> Result<()> {
installed_extensions_collection_interval: Arc::new(AtomicU64::new(
cli.installed_extensions_collection_interval,
)),
pg_init_timeout: cli.pg_init_timeout.map(Duration::from_secs),
lakebase_mode: cli.lakebase_mode,
},
config,
)?;
@@ -227,11 +235,11 @@ fn main() -> Result<()> {
scenario.teardown();
deinit_and_exit(exit_code);
deinit_and_exit(tracing_provider, exit_code);
}
async fn init(dev_mode: bool) -> Result<()> {
init_tracing_and_logging(DEFAULT_LOG_LEVEL).await?;
fn init(dev_mode: bool) -> Result<Option<tracing_utils::Provider>> {
let provider = init_tracing_and_logging(DEFAULT_LOG_LEVEL)?;
let mut signals = Signals::new([SIGINT, SIGTERM, SIGQUIT])?;
thread::spawn(move || {
@@ -242,7 +250,7 @@ async fn init(dev_mode: bool) -> Result<()> {
info!("compute build_tag: {}", &BUILD_TAG.to_string());
Ok(())
Ok(provider)
}
fn get_config(cli: &Cli) -> Result<ComputeConfig> {
@@ -267,25 +275,27 @@ fn get_config(cli: &Cli) -> Result<ComputeConfig> {
}
}
fn deinit_and_exit(exit_code: Option<i32>) -> ! {
// Shutdown trace pipeline gracefully, so that it has a chance to send any
// pending traces before we exit. Shutting down OTEL tracing provider may
// hang for quite some time, see, for example:
// - https://github.com/open-telemetry/opentelemetry-rust/issues/868
// - and our problems with staging https://github.com/neondatabase/cloud/issues/3707#issuecomment-1493983636
//
// Yet, we want computes to shut down fast enough, as we may need a new one
// for the same timeline ASAP. So wait no longer than 2s for the shutdown to
// complete, then just error out and exit the main thread.
info!("shutting down tracing");
let (sender, receiver) = mpsc::channel();
let _ = thread::spawn(move || {
tracing_utils::shutdown_tracing();
sender.send(()).ok()
});
let shutdown_res = receiver.recv_timeout(Duration::from_millis(2000));
if shutdown_res.is_err() {
error!("timed out while shutting down tracing, exiting anyway");
fn deinit_and_exit(tracing_provider: Option<tracing_utils::Provider>, exit_code: Option<i32>) -> ! {
if let Some(p) = tracing_provider {
// Shutdown trace pipeline gracefully, so that it has a chance to send any
// pending traces before we exit. Shutting down OTEL tracing provider may
// hang for quite some time, see, for example:
// - https://github.com/open-telemetry/opentelemetry-rust/issues/868
// - and our problems with staging https://github.com/neondatabase/cloud/issues/3707#issuecomment-1493983636
//
// Yet, we want computes to shut down fast enough, as we may need a new one
// for the same timeline ASAP. So wait no longer than 2s for the shutdown to
// complete, then just error out and exit the main thread.
info!("shutting down tracing");
let (sender, receiver) = mpsc::channel();
let _ = thread::spawn(move || {
_ = p.shutdown();
sender.send(()).ok()
});
let shutdown_res = receiver.recv_timeout(Duration::from_millis(2000));
if shutdown_res.is_err() {
error!("timed out while shutting down tracing, exiting anyway");
}
}
info!("shutting down");

View File

@@ -0,0 +1,98 @@
//! Client for making request to a running Postgres server's communicator control socket.
//!
//! The storage communicator process that runs inside Postgres exposes an HTTP endpoint in
//! a Unix Domain Socket in the Postgres data directory. This provides access to it.
use std::path::Path;
use anyhow::Context;
use hyper::client::conn::http1::SendRequest;
use hyper_util::rt::TokioIo;
/// Name of the socket within the Postgres data directory. This better match that in
/// `pgxn/neon/communicator/src/lib.rs`.
const NEON_COMMUNICATOR_SOCKET_NAME: &str = "neon-communicator.socket";
/// Open a connection to the communicator's control socket, prepare to send requests to it
/// with hyper.
pub async fn connect_communicator_socket<B>(pgdata: &Path) -> anyhow::Result<SendRequest<B>>
where
B: hyper::body::Body + 'static + Send,
B::Data: Send,
B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
{
let socket_path = pgdata.join(NEON_COMMUNICATOR_SOCKET_NAME);
let socket_path_len = socket_path.display().to_string().len();
// There is a limit of around 100 bytes (108 on Linux?) on the length of the path to a
// Unix Domain socket. The limit is on the connect(2) function used to open the
// socket, not on the absolute path itself. Postgres changes the current directory to
// the data directory and uses a relative path to bind to the socket, and the relative
// path "./neon-communicator.socket" is always short, but when compute_ctl needs to
// open the socket, we need to use a full path, which can be arbitrarily long.
//
// There are a few ways we could work around this:
//
// 1. Change the current directory to the Postgres data directory and use a relative
// path in the connect(2) call. That's problematic because the current directory
// applies to the whole process. We could change the current directory early in
// compute_ctl startup, and that might be a good idea anyway for other reasons too:
// it would be more robust if the data directory is moved around or unlinked for
// some reason, and you would be less likely to accidentally litter other parts of
// the filesystem with e.g. temporary files. However, that's a pretty invasive
// change.
//
// 2. On Linux, you could open() the data directory, and refer to the the socket
// inside it as "/proc/self/fd/<fd>/neon-communicator.socket". But that's
// Linux-only.
//
// 3. Create a symbolic link to the socket with a shorter path, and use that.
//
// We use the symbolic link approach here. Hopefully the paths we use in production
// are shorter, so that we can open the socket directly, so that this hack is needed
// only in development.
let connect_result = if socket_path_len < 100 {
// We can open the path directly with no hacks.
tokio::net::UnixStream::connect(socket_path).await
} else {
// The path to the socket is too long. Create a symlink to it with a shorter path.
let short_path = std::env::temp_dir().join(format!(
"compute_ctl.short-socket.{}.{}",
std::process::id(),
tokio::task::id()
));
std::os::unix::fs::symlink(&socket_path, &short_path)?;
// Delete the symlink as soon as we have connected to it. There's a small chance
// of leaking if the process dies before we remove it, so try to keep that window
// as small as possible.
scopeguard::defer! {
if let Err(err) = std::fs::remove_file(&short_path) {
tracing::warn!("could not remove symlink \"{}\" created for socket: {}",
short_path.display(), err);
}
}
tracing::info!(
"created symlink \"{}\" for socket \"{}\", opening it now",
short_path.display(),
socket_path.display()
);
tokio::net::UnixStream::connect(&short_path).await
};
let stream = connect_result.context("connecting to communicator control socket")?;
let io = TokioIo::new(stream);
let (request_sender, connection) = hyper::client::conn::http1::handshake(io).await?;
// spawn a task to poll the connection and drive the HTTP state
tokio::spawn(async move {
if let Err(err) = connection.await {
eprintln!("Error in connection: {err}");
}
});
Ok(request_sender)
}

View File

@@ -114,6 +114,11 @@ pub struct ComputeNodeParams {
/// Interval for installed extensions collection
pub installed_extensions_collection_interval: Arc<AtomicU64>,
/// Timeout of PG compute startup in the Init state.
pub pg_init_timeout: Option<Duration>,
pub lakebase_mode: bool,
}
type TaskHandle = Mutex<Option<JoinHandle<()>>>;
@@ -155,6 +160,7 @@ pub struct RemoteExtensionMetrics {
#[derive(Clone, Debug)]
pub struct ComputeState {
pub start_time: DateTime<Utc>,
pub pg_start_time: Option<DateTime<Utc>>,
pub status: ComputeStatus,
/// Timestamp of the last Postgres activity. It could be `None` if
/// compute wasn't used since start.
@@ -192,6 +198,7 @@ impl ComputeState {
pub fn new() -> Self {
Self {
start_time: Utc::now(),
pg_start_time: None,
status: ComputeStatus::Empty,
last_active: None,
error: None,
@@ -737,6 +744,9 @@ impl ComputeNode {
};
_this_entered = start_compute_span.enter();
// Hadron: Record postgres start time (used to enforce pg_init_timeout).
state_guard.pg_start_time.replace(Utc::now());
state_guard.set_status(ComputeStatus::Init, &self.state_changed);
compute_state = state_guard.clone()
}
@@ -1544,7 +1554,7 @@ impl ComputeNode {
.with_context(|| format!("failed to get basebackup@{lsn}"))?;
// Update pg_hba.conf received with basebackup.
update_pg_hba(pgdata_path)?;
update_pg_hba(pgdata_path, None)?;
// Place pg_dynshmem under /dev/shm. This allows us to use
// 'dynamic_shared_memory_type = mmap' so that the files are placed in
@@ -1849,6 +1859,7 @@ impl ComputeNode {
}
// Run migrations separately to not hold up cold starts
let lakebase_mode = self.params.lakebase_mode;
let params = self.params.clone();
tokio::spawn(async move {
let mut conf = conf.as_ref().clone();
@@ -1861,7 +1872,7 @@ impl ComputeNode {
eprintln!("connection error: {e}");
}
});
if let Err(e) = handle_migrations(params, &mut client).await {
if let Err(e) = handle_migrations(params, &mut client, lakebase_mode).await {
error!("Failed to run migrations: {}", e);
}
}

View File

@@ -1,10 +1,18 @@
use std::path::Path;
use std::sync::Arc;
use anyhow::Context;
use axum::body::Body;
use axum::extract::State;
use axum::response::Response;
use http::StatusCode;
use http::header::CONTENT_TYPE;
use http_body_util::BodyExt;
use hyper::{Request, StatusCode};
use metrics::proto::MetricFamily;
use metrics::{Encoder, TextEncoder};
use crate::communicator_socket_client::connect_communicator_socket;
use crate::compute::ComputeNode;
use crate::http::JsonResponse;
use crate::metrics::collect;
@@ -31,3 +39,42 @@ pub(in crate::http) async fn get_metrics() -> Response {
.body(Body::from(buffer))
.unwrap()
}
/// Fetch and forward metrics from the Postgres neon extension's metrics
/// exporter that are used by autoscaling-agent.
///
/// The neon extension exposes these metrics over a Unix domain socket
/// in the data directory. That's not accessible directly from the outside
/// world, so we have this endpoint in compute_ctl to expose it
pub(in crate::http) async fn get_autoscaling_metrics(
State(compute): State<Arc<ComputeNode>>,
) -> Result<Response, Response> {
let pgdata = Path::new(&compute.params.pgdata);
// Connect to the communicator process's metrics socket
let mut metrics_client = connect_communicator_socket(pgdata)
.await
.map_err(|e| JsonResponse::error(StatusCode::INTERNAL_SERVER_ERROR, format!("{e:#}")))?;
// Make a request for /autoscaling_metrics
let request = Request::builder()
.method("GET")
.uri("/autoscaling_metrics")
.header("Host", "localhost") // hyper requires Host, even though the server won't care
.body(Body::from(""))
.unwrap();
let resp = metrics_client
.send_request(request)
.await
.context("fetching metrics from Postgres metrics service")
.map_err(|e| JsonResponse::error(StatusCode::INTERNAL_SERVER_ERROR, format!("{e:#}")))?;
// Build a response that just forwards the response we got.
let mut response = Response::builder();
response = response.status(resp.status());
if let Some(content_type) = resp.headers().get(CONTENT_TYPE) {
response = response.header(CONTENT_TYPE, content_type);
}
let body = tonic::service::AxumBody::from_stream(resp.into_body().into_data_stream());
Ok(response.body(body).unwrap())
}

View File

@@ -81,8 +81,12 @@ impl From<&Server> for Router<Arc<ComputeNode>> {
Server::External {
config, compute_id, ..
} => {
let unauthenticated_router =
Router::<Arc<ComputeNode>>::new().route("/metrics", get(metrics::get_metrics));
let unauthenticated_router = Router::<Arc<ComputeNode>>::new()
.route("/metrics", get(metrics::get_metrics))
.route(
"/autoscaling_metrics",
get(metrics::get_autoscaling_metrics),
);
let authenticated_router = Router::<Arc<ComputeNode>>::new()
.route("/lfc/prewarm", get(lfc::prewarm_state).post(lfc::prewarm))

View File

@@ -4,6 +4,7 @@
#![deny(clippy::undocumented_unsafe_blocks)]
pub mod checker;
pub mod communicator_socket_client;
pub mod config;
pub mod configurator;
pub mod http;

View File

@@ -13,7 +13,9 @@ use tracing_subscriber::prelude::*;
/// set `OTEL_EXPORTER_OTLP_ENDPOINT=http://jaeger:4318`. See
/// `tracing-utils` package description.
///
pub async fn init_tracing_and_logging(default_log_level: &str) -> anyhow::Result<()> {
pub fn init_tracing_and_logging(
default_log_level: &str,
) -> anyhow::Result<Option<tracing_utils::Provider>> {
// Initialize Logging
let env_filter = tracing_subscriber::EnvFilter::try_from_default_env()
.unwrap_or_else(|_| tracing_subscriber::EnvFilter::new(default_log_level));
@@ -24,8 +26,9 @@ pub async fn init_tracing_and_logging(default_log_level: &str) -> anyhow::Result
.with_writer(std::io::stderr);
// Initialize OpenTelemetry
let otlp_layer =
tracing_utils::init_tracing("compute_ctl", tracing_utils::ExportConfig::default()).await;
let provider =
tracing_utils::init_tracing("compute_ctl", tracing_utils::ExportConfig::default());
let otlp_layer = provider.as_ref().map(tracing_utils::layer);
// Put it all together
tracing_subscriber::registry()
@@ -37,7 +40,7 @@ pub async fn init_tracing_and_logging(default_log_level: &str) -> anyhow::Result
utils::logging::replace_panic_hook_with_tracing_panic_hook().forget();
Ok(())
Ok(provider)
}
/// Replace all newline characters with a special character to make it

View File

@@ -9,15 +9,20 @@ use crate::metrics::DB_MIGRATION_FAILED;
pub(crate) struct MigrationRunner<'m> {
client: &'m mut Client,
migrations: &'m [&'m str],
lakebase_mode: bool,
}
impl<'m> MigrationRunner<'m> {
/// Create a new migration runner
pub fn new(client: &'m mut Client, migrations: &'m [&'m str]) -> Self {
pub fn new(client: &'m mut Client, migrations: &'m [&'m str], lakebase_mode: bool) -> Self {
// The neon_migration.migration_id::id column is a bigint, which is equivalent to an i64
assert!(migrations.len() + 1 < i64::MAX as usize);
Self { client, migrations }
Self {
client,
migrations,
lakebase_mode,
}
}
/// Get the current value neon_migration.migration_id
@@ -130,8 +135,13 @@ impl<'m> MigrationRunner<'m> {
// ID is also the next index
let migration_id = (current_migration + 1) as i64;
let migration = self.migrations[current_migration];
let migration = if self.lakebase_mode {
migration.replace("neon_superuser", "databricks_superuser")
} else {
migration.to_string()
};
match Self::run_migration(self.client, migration_id, migration).await {
match Self::run_migration(self.client, migration_id, &migration).await {
Ok(_) => {
info!("Finished migration id={}", migration_id);
}

View File

@@ -11,6 +11,7 @@ use tracing::{Level, error, info, instrument, span};
use crate::compute::ComputeNode;
use crate::metrics::{PG_CURR_DOWNTIME_MS, PG_TOTAL_DOWNTIME_MS};
const PG_DEFAULT_INIT_TIMEOUIT: Duration = Duration::from_secs(60);
const MONITOR_CHECK_INTERVAL: Duration = Duration::from_millis(500);
/// Struct to store runtime state of the compute monitor thread.
@@ -352,13 +353,47 @@ impl ComputeMonitor {
// Hang on condition variable waiting until the compute status is `Running`.
fn wait_for_postgres_start(compute: &ComputeNode) {
let mut state = compute.state.lock().unwrap();
let pg_init_timeout = compute
.params
.pg_init_timeout
.unwrap_or(PG_DEFAULT_INIT_TIMEOUIT);
while state.status != ComputeStatus::Running {
info!("compute is not running, waiting before monitoring activity");
state = compute.state_changed.wait(state).unwrap();
if !compute.params.lakebase_mode {
state = compute.state_changed.wait(state).unwrap();
if state.status == ComputeStatus::Running {
break;
if state.status == ComputeStatus::Running {
break;
}
continue;
}
if state.pg_start_time.is_some()
&& Utc::now()
.signed_duration_since(state.pg_start_time.unwrap())
.to_std()
.unwrap_or_default()
> pg_init_timeout
{
// If Postgres isn't up and running with working PS/SK connections within POSTGRES_STARTUP_TIMEOUT, it is
// possible that we started Postgres with a wrong spec (so it is talking to the wrong PS/SK nodes). To prevent
// deadends we simply exit (panic) the compute node so it can restart with the latest spec.
//
// NB: We skip this check if we have not attempted to start PG yet (indicated by state.pg_start_up == None).
// This is to make sure the more appropriate errors are surfaced if we encounter issues before we even attempt
// to start PG (e.g., if we can't pull the spec, can't sync safekeepers, or can't get the basebackup).
error!(
"compute did not enter Running state in {} seconds, exiting",
pg_init_timeout.as_secs()
);
std::process::exit(1);
}
state = compute
.state_changed
.wait_timeout(state, Duration::from_secs(5))
.unwrap()
.0;
}
}

View File

@@ -11,7 +11,9 @@ use std::time::{Duration, Instant};
use anyhow::{Result, bail};
use compute_api::responses::TlsConfig;
use compute_api::spec::{Database, GenericOption, GenericOptions, PgIdent, Role};
use compute_api::spec::{
Database, DatabricksSettings, GenericOption, GenericOptions, PgIdent, Role,
};
use futures::StreamExt;
use indexmap::IndexMap;
use ini::Ini;
@@ -184,6 +186,42 @@ impl DatabaseExt for Database {
}
}
pub trait DatabricksSettingsExt {
fn as_pg_settings(&self) -> String;
}
impl DatabricksSettingsExt for DatabricksSettings {
fn as_pg_settings(&self) -> String {
// Postgres GUCs rendered from DatabricksSettings
vec![
// ssl_ca_file
Some(format!(
"ssl_ca_file = '{}'",
self.pg_compute_tls_settings.ca_file
)),
// [Optional] databricks.workspace_url
Some(format!(
"databricks.workspace_url = '{}'",
&self.databricks_workspace_host
)),
// todo(vikas.jain): these are not required anymore as they are moved to static
// conf but keeping these to avoid image mismatch between hcc and pg.
// Once hcc and pg are in sync, we can remove these.
//
// databricks.enable_databricks_identity_login
Some("databricks.enable_databricks_identity_login = true".to_string()),
// databricks.enable_sql_restrictions
Some("databricks.enable_sql_restrictions = true".to_string()),
]
.into_iter()
// Removes `None`s
.flatten()
.collect::<Vec<String>>()
.join("\n")
+ "\n"
}
}
/// Generic trait used to provide quoting / encoding for strings used in the
/// Postgres SQL queries and DATABASE_URL.
pub trait Escaping {

View File

@@ -1,4 +1,6 @@
use std::fs::File;
use std::fs::{self, Permissions};
use std::os::unix::fs::PermissionsExt;
use std::path::Path;
use anyhow::{Result, anyhow, bail};
@@ -133,10 +135,25 @@ pub fn get_config_from_control_plane(base_uri: &str, compute_id: &str) -> Result
}
/// Check `pg_hba.conf` and update if needed to allow external connections.
pub fn update_pg_hba(pgdata_path: &Path) -> Result<()> {
pub fn update_pg_hba(pgdata_path: &Path, databricks_pg_hba: Option<&String>) -> Result<()> {
// XXX: consider making it a part of config.json
let pghba_path = pgdata_path.join("pg_hba.conf");
// Update pg_hba to contains databricks specfic settings before adding neon settings
// PG uses the first record that matches to perform authentication, so we need to have
// our rules before the default ones from neon.
// See https://www.postgresql.org/docs/16/auth-pg-hba-conf.html
if let Some(databricks_pg_hba) = databricks_pg_hba {
if config::line_in_file(
&pghba_path,
&format!("include_if_exists {}\n", *databricks_pg_hba),
)? {
info!("updated pg_hba.conf to include databricks_pg_hba.conf");
} else {
info!("pg_hba.conf already included databricks_pg_hba.conf");
}
}
if config::line_in_file(&pghba_path, PG_HBA_ALL_MD5)? {
info!("updated pg_hba.conf to allow external connections");
} else {
@@ -146,6 +163,59 @@ pub fn update_pg_hba(pgdata_path: &Path) -> Result<()> {
Ok(())
}
/// Check `pg_ident.conf` and update if needed to allow databricks config.
pub fn update_pg_ident(pgdata_path: &Path, databricks_pg_ident: Option<&String>) -> Result<()> {
info!("checking pg_ident.conf");
let pghba_path = pgdata_path.join("pg_ident.conf");
// Update pg_ident to contains databricks specfic settings
if let Some(databricks_pg_ident) = databricks_pg_ident {
if config::line_in_file(
&pghba_path,
&format!("include_if_exists {}\n", *databricks_pg_ident),
)? {
info!("updated pg_ident.conf to include databricks_pg_ident.conf");
} else {
info!("pg_ident.conf already included databricks_pg_ident.conf");
}
}
Ok(())
}
/// Copy tls key_file and cert_file from k8s secret mount directory
/// to pgdata and set private key file permissions as expected by Postgres.
/// See this doc for expected permission <https://www.postgresql.org/docs/current/ssl-tcp.html>
/// K8s secrets mount on dblet does not honor permission and ownership
/// specified in the Volume or VolumeMount. So we need to explicitly copy the file and set the permissions.
pub fn copy_tls_certificates(
key_file: &String,
cert_file: &String,
pgdata_path: &Path,
) -> Result<()> {
let files = [cert_file, key_file];
for file in files.iter() {
let source = Path::new(file);
let dest = pgdata_path.join(source.file_name().unwrap());
if !dest.exists() {
std::fs::copy(source, &dest)?;
info!(
"Copying tls file: {} to {}",
&source.display(),
&dest.display()
);
}
if *file == key_file {
// Postgres requires private key to be readable only by the owner by having
// chmod 600 permissions.
let permissions = Permissions::from_mode(0o600);
fs::set_permissions(&dest, permissions)?;
info!("Setting permission on {}.", &dest.display());
}
}
Ok(())
}
/// Create a standby.signal file
pub fn add_standby_signal(pgdata_path: &Path) -> Result<()> {
// XXX: consider making it a part of config.json
@@ -170,7 +240,11 @@ pub async fn handle_neon_extension_upgrade(client: &mut Client) -> Result<()> {
}
#[instrument(skip_all)]
pub async fn handle_migrations(params: ComputeNodeParams, client: &mut Client) -> Result<()> {
pub async fn handle_migrations(
params: ComputeNodeParams,
client: &mut Client,
lakebase_mode: bool,
) -> Result<()> {
info!("handle migrations");
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
@@ -234,7 +308,7 @@ pub async fn handle_migrations(params: ComputeNodeParams, client: &mut Client) -
),
];
MigrationRunner::new(client, &migrations)
MigrationRunner::new(client, &migrations, lakebase_mode)
.run_migrations()
.await?;

View File

@@ -411,7 +411,8 @@ impl ComputeNode {
.map(|limit| match limit {
0..10 => limit,
10..30 => 10,
30.. => limit / 3,
30..300 => limit / 3,
300.. => 100,
})
// If we didn't find max_connections, default to 10 concurrent connections.
.unwrap_or(10)

View File

@@ -411,6 +411,12 @@ struct StorageControllerStartCmdArgs {
help = "Base port for the storage controller instance idenfified by instance-id (defaults to pageserver cplane api)"
)]
base_port: Option<u16>,
#[clap(
long,
help = "Whether the storage controller should handle pageserver-reported local disk loss events."
)]
handle_ps_local_disk_loss: Option<bool>,
}
#[derive(clap::Args)]
@@ -1800,6 +1806,7 @@ async fn handle_storage_controller(
instance_id: args.instance_id,
base_port: args.base_port,
start_timeout: args.start_timeout,
handle_ps_local_disk_loss: args.handle_ps_local_disk_loss,
};
if let Err(e) = svc.start(start_args).await {

View File

@@ -728,12 +728,9 @@ impl Endpoint {
// For the sake of backwards-compatibility, also fill in 'pageserver_connstring'
//
// XXX: I believe this is not really needed, except to make
// test_forward_compatibility happy.
//
// Use a closure so that we can conviniently return None in the middle of the
// loop.
let pageserver_connstring = (|| {
let pageserver_connstring: Option<String> = (|| {
let num_shards = if args.pageserver_conninfo.shard_count.is_unsharded() {
1
} else {
@@ -749,22 +746,24 @@ impl Endpoint {
.pageserver_conninfo
.shards
.get(&shard_index)
.expect(&format!(
"shard {} not found in pageserver_connection_info",
shard_index
));
.ok_or_else(|| {
anyhow!(
"shard {} not found in pageserver_connection_info",
shard_index
)
})?;
let pageserver = shard
.pageservers
.first()
.expect("must have at least one pageserver");
.ok_or(anyhow!("must have at least one pageserver"))?;
if let Some(libpq_url) = &pageserver.libpq_url {
connstrings.push(libpq_url.clone());
} else {
return None;
return Ok::<_, anyhow::Error>(None);
}
}
Some(connstrings.join(","))
})();
Ok(Some(connstrings.join(",")))
})()?;
// Create config file
let config = {

View File

@@ -56,6 +56,7 @@ pub struct NeonStorageControllerStartArgs {
pub instance_id: u8,
pub base_port: Option<u16>,
pub start_timeout: humantime::Duration,
pub handle_ps_local_disk_loss: Option<bool>,
}
impl NeonStorageControllerStartArgs {
@@ -64,6 +65,7 @@ impl NeonStorageControllerStartArgs {
instance_id: 1,
base_port: None,
start_timeout,
handle_ps_local_disk_loss: None,
}
}
}
@@ -669,6 +671,10 @@ impl StorageController {
println!("Starting storage controller at {scheme}://{host}:{listen_port}");
if start_args.handle_ps_local_disk_loss.unwrap_or_default() {
args.push("--handle-ps-local-disk-loss".to_string());
}
background_process::start_process(
COMMAND,
&instance_dir,

View File

@@ -35,6 +35,7 @@ reason = "The paste crate is a build-only dependency with no runtime components.
# More documentation for the licenses section can be found here:
# https://embarkstudios.github.io/cargo-deny/checks/licenses/cfg.html
[licenses]
version = 2
allow = [
"0BSD",
"Apache-2.0",

View File

@@ -233,7 +233,7 @@ mod tests {
.unwrap()
.as_millis();
use rand::Rng;
let random = rand::thread_rng().r#gen::<u32>();
let random = rand::rng().random::<u32>();
let s3_config = remote_storage::S3Config {
bucket_name: var(REAL_S3_BUCKET).unwrap(),

View File

@@ -460,6 +460,32 @@ pub struct GenericOption {
pub vartype: String,
}
/// Postgres compute TLS settings.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
pub struct PgComputeTlsSettings {
// Absolute path to the certificate file for server-side TLS.
pub cert_file: String,
// Absolute path to the private key file for server-side TLS.
pub key_file: String,
// Absolute path to the certificate authority file for verifying client certificates.
pub ca_file: String,
}
/// Databricks specific options for compute instance.
/// This is used to store any other settings that needs to be propagate to Compute
/// but should not be persisted to ComputeSpec in the database.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
pub struct DatabricksSettings {
pub pg_compute_tls_settings: PgComputeTlsSettings,
// Absolute file path to databricks_pg_hba.conf file.
pub databricks_pg_hba: String,
// Absolute file path to databricks_pg_ident.conf file.
pub databricks_pg_ident: String,
// Hostname portion of the Databricks workspace URL of the endpoint, or empty string if not known.
// A valid hostname is required for the compute instance to support PAT logins.
pub databricks_workspace_host: String,
}
/// Optional collection of `GenericOption`'s. Type alias allows us to
/// declare a `trait` on it.
pub type GenericOptions = Option<Vec<GenericOption>>;

View File

@@ -90,7 +90,7 @@ impl<'a> IdempotencyKey<'a> {
IdempotencyKey {
now: Utc::now(),
node_id,
nonce: rand::thread_rng().gen_range(0..=9999),
nonce: rand::rng().random_range(0..=9999),
}
}

View File

@@ -41,7 +41,7 @@ impl NodeOs {
/// Generate a random number in range [0, max).
pub fn random(&self, max: u64) -> u64 {
self.internal.rng.lock().gen_range(0..max)
self.internal.rng.lock().random_range(0..max)
}
/// Append a new event to the world event log.

View File

@@ -32,10 +32,10 @@ impl Delay {
/// Generate a random delay in range [min, max]. Return None if the
/// message should be dropped.
pub fn delay(&self, rng: &mut StdRng) -> Option<u64> {
if rng.gen_bool(self.fail_prob) {
if rng.random_bool(self.fail_prob) {
return None;
}
Some(rng.gen_range(self.min..=self.max))
Some(rng.random_range(self.min..=self.max))
}
}

View File

@@ -69,7 +69,7 @@ impl World {
/// Create a new random number generator.
pub fn new_rng(&self) -> StdRng {
let mut rng = self.rng.lock();
StdRng::from_rng(rng.deref_mut()).unwrap()
StdRng::from_rng(rng.deref_mut())
}
/// Create a new node.

View File

@@ -17,5 +17,5 @@ procfs.workspace = true
measured-process.workspace = true
[dev-dependencies]
rand = "0.8"
rand_distr = "0.4.3"
rand.workspace = true
rand_distr = "0.5"

View File

@@ -260,7 +260,7 @@ mod tests {
#[test]
fn test_cardinality_small() {
let (actual, estimate) = test_cardinality(100, Zipf::new(100, 1.2f64).unwrap());
let (actual, estimate) = test_cardinality(100, Zipf::new(100.0, 1.2f64).unwrap());
assert_eq!(actual, [46, 30, 32]);
assert!(51.3 < estimate[0] && estimate[0] < 51.4);
@@ -270,7 +270,7 @@ mod tests {
#[test]
fn test_cardinality_medium() {
let (actual, estimate) = test_cardinality(10000, Zipf::new(10000, 1.2f64).unwrap());
let (actual, estimate) = test_cardinality(10000, Zipf::new(10000.0, 1.2f64).unwrap());
assert_eq!(actual, [2529, 1618, 1629]);
assert!(2309.1 < estimate[0] && estimate[0] < 2309.2);
@@ -280,7 +280,8 @@ mod tests {
#[test]
fn test_cardinality_large() {
let (actual, estimate) = test_cardinality(1_000_000, Zipf::new(1_000_000, 1.2f64).unwrap());
let (actual, estimate) =
test_cardinality(1_000_000, Zipf::new(1_000_000.0, 1.2f64).unwrap());
assert_eq!(actual, [129077, 79579, 79630]);
assert!(126067.2 < estimate[0] && estimate[0] < 126067.3);
@@ -290,7 +291,7 @@ mod tests {
#[test]
fn test_cardinality_small2() {
let (actual, estimate) = test_cardinality(100, Zipf::new(200, 0.8f64).unwrap());
let (actual, estimate) = test_cardinality(100, Zipf::new(200.0, 0.8f64).unwrap());
assert_eq!(actual, [92, 58, 60]);
assert!(116.1 < estimate[0] && estimate[0] < 116.2);
@@ -300,7 +301,7 @@ mod tests {
#[test]
fn test_cardinality_medium2() {
let (actual, estimate) = test_cardinality(10000, Zipf::new(20000, 0.8f64).unwrap());
let (actual, estimate) = test_cardinality(10000, Zipf::new(20000.0, 0.8f64).unwrap());
assert_eq!(actual, [8201, 5131, 5051]);
assert!(6846.4 < estimate[0] && estimate[0] < 6846.5);
@@ -310,7 +311,8 @@ mod tests {
#[test]
fn test_cardinality_large2() {
let (actual, estimate) = test_cardinality(1_000_000, Zipf::new(2_000_000, 0.8f64).unwrap());
let (actual, estimate) =
test_cardinality(1_000_000, Zipf::new(2_000_000.0, 0.8f64).unwrap());
assert_eq!(actual, [777847, 482069, 482246]);
assert!(699437.4 < estimate[0] && estimate[0] < 699437.5);

View File

@@ -394,7 +394,7 @@ impl From<&OtelExporterConfig> for tracing_utils::ExportConfig {
tracing_utils::ExportConfig {
endpoint: Some(val.endpoint.clone()),
protocol: val.protocol.into(),
timeout: val.timeout,
timeout: Some(val.timeout),
}
}
}

View File

@@ -596,6 +596,7 @@ pub struct TimelineImportRequest {
pub timeline_id: TimelineId,
pub start_lsn: Lsn,
pub sk_set: Vec<NodeId>,
pub force_upsert: bool,
}
#[derive(serde::Serialize, serde::Deserialize, Clone)]

View File

@@ -981,12 +981,12 @@ mod tests {
let mut rng = rand::rngs::StdRng::seed_from_u64(42);
let key = Key {
field1: rng.r#gen(),
field2: rng.r#gen(),
field3: rng.r#gen(),
field4: rng.r#gen(),
field5: rng.r#gen(),
field6: rng.r#gen(),
field1: rng.random(),
field2: rng.random(),
field3: rng.random(),
field4: rng.random(),
field5: rng.random(),
field6: rng.random(),
};
assert_eq!(key, Key::from_str(&format!("{key}")).unwrap());

View File

@@ -443,9 +443,9 @@ pub struct ImportPgdataIdempotencyKey(pub String);
impl ImportPgdataIdempotencyKey {
pub fn random() -> Self {
use rand::Rng;
use rand::distributions::Alphanumeric;
use rand::distr::Alphanumeric;
Self(
rand::thread_rng()
rand::rng()
.sample_iter(&Alphanumeric)
.take(20)
.map(char::from)

View File

@@ -69,22 +69,6 @@ impl Hash for ShardIdentity {
}
}
/// Stripe size in number of pages
#[derive(Clone, Copy, Serialize, Deserialize, Eq, PartialEq, Debug)]
pub struct ShardStripeSize(pub u32);
impl Default for ShardStripeSize {
fn default() -> Self {
DEFAULT_STRIPE_SIZE
}
}
impl std::fmt::Display for ShardStripeSize {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
self.0.fmt(f)
}
}
/// Layout version: for future upgrades where we might change how the key->shard mapping works
#[derive(Clone, Copy, Serialize, Deserialize, Eq, PartialEq, Hash, Debug)]
pub struct ShardLayout(u8);

View File

@@ -21,6 +21,14 @@ pub struct ReAttachRequest {
/// if the node already has a node_id set.
#[serde(skip_serializing_if = "Option::is_none", default)]
pub register: Option<NodeRegisterRequest>,
/// Hadron: Optional flag to indicate whether the node is starting with an empty local disk.
/// Will be set to true if the node couldn't find any local tenant data on startup, could be
/// due to the node starting for the first time or due to a local SSD failure/disk wipe event.
/// The flag may be used by the storage controller to update its observed state of the world
/// to make sure that it sends explicit location_config calls to the node following the
/// re-attach request.
pub empty_local_disk: Option<bool>,
}
#[derive(Serialize, Deserialize, Debug)]

View File

@@ -203,12 +203,12 @@ impl fmt::Display for CancelKeyData {
}
}
use rand::distributions::{Distribution, Standard};
impl Distribution<CancelKeyData> for Standard {
use rand::distr::{Distribution, StandardUniform};
impl Distribution<CancelKeyData> for StandardUniform {
fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> CancelKeyData {
CancelKeyData {
backend_pid: rng.r#gen(),
cancel_key: rng.r#gen(),
backend_pid: rng.random(),
cancel_key: rng.random(),
}
}
}

View File

@@ -155,10 +155,10 @@ pub struct ScramSha256 {
fn nonce() -> String {
// rand 0.5's ThreadRng is cryptographically secure
let mut rng = rand::thread_rng();
let mut rng = rand::rng();
(0..NONCE_LENGTH)
.map(|_| {
let mut v = rng.gen_range(0x21u8..0x7e);
let mut v = rng.random_range(0x21u8..0x7e);
if v == 0x2c {
v = 0x7e
}

View File

@@ -74,7 +74,6 @@ impl Header {
}
/// An enum representing Postgres backend messages.
#[non_exhaustive]
pub enum Message {
AuthenticationCleartextPassword,
AuthenticationGss,
@@ -145,16 +144,7 @@ impl Message {
PARSE_COMPLETE_TAG => Message::ParseComplete,
BIND_COMPLETE_TAG => Message::BindComplete,
CLOSE_COMPLETE_TAG => Message::CloseComplete,
NOTIFICATION_RESPONSE_TAG => {
let process_id = buf.read_i32::<BigEndian>()?;
let channel = buf.read_cstr()?;
let message = buf.read_cstr()?;
Message::NotificationResponse(NotificationResponseBody {
process_id,
channel,
message,
})
}
NOTIFICATION_RESPONSE_TAG => Message::NotificationResponse(NotificationResponseBody {}),
COPY_DONE_TAG => Message::CopyDone,
COMMAND_COMPLETE_TAG => {
let tag = buf.read_cstr()?;
@@ -543,28 +533,7 @@ impl NoticeResponseBody {
}
}
pub struct NotificationResponseBody {
process_id: i32,
channel: Bytes,
message: Bytes,
}
impl NotificationResponseBody {
#[inline]
pub fn process_id(&self) -> i32 {
self.process_id
}
#[inline]
pub fn channel(&self) -> io::Result<&str> {
get_str(&self.channel)
}
#[inline]
pub fn message(&self) -> io::Result<&str> {
get_str(&self.message)
}
}
pub struct NotificationResponseBody {}
pub struct ParameterDescriptionBody {
storage: Bytes,

View File

@@ -28,7 +28,7 @@ const SCRAM_DEFAULT_SALT_LEN: usize = 16;
/// special characters that would require escaping in an SQL command.
pub async fn scram_sha_256(password: &[u8]) -> String {
let mut salt: [u8; SCRAM_DEFAULT_SALT_LEN] = [0; SCRAM_DEFAULT_SALT_LEN];
let mut rng = rand::thread_rng();
let mut rng = rand::rng();
rng.fill_bytes(&mut salt);
scram_sha_256_salt(password, salt).await
}

View File

@@ -13,7 +13,7 @@ use serde::{Deserialize, Serialize};
use tokio::sync::mpsc;
use crate::cancel_token::RawCancelToken;
use crate::codec::{BackendMessages, FrontendMessage};
use crate::codec::{BackendMessages, FrontendMessage, RecordNotices};
use crate::config::{Host, SslMode};
use crate::query::RowStream;
use crate::simple_query::SimpleQueryStream;
@@ -221,6 +221,18 @@ impl Client {
&mut self.inner
}
pub fn record_notices(&mut self, limit: usize) -> mpsc::UnboundedReceiver<Box<str>> {
let (tx, rx) = mpsc::unbounded_channel();
let notices = RecordNotices { sender: tx, limit };
self.inner
.sender
.send(FrontendMessage::RecordNotices(notices))
.ok();
rx
}
/// Pass text directly to the Postgres backend to allow it to sort out typing itself and
/// to save a roundtrip
pub async fn query_raw_txt<S, I>(

View File

@@ -3,10 +3,17 @@ use std::io;
use bytes::{Bytes, BytesMut};
use fallible_iterator::FallibleIterator;
use postgres_protocol2::message::backend;
use tokio::sync::mpsc::UnboundedSender;
use tokio_util::codec::{Decoder, Encoder};
pub enum FrontendMessage {
Raw(Bytes),
RecordNotices(RecordNotices),
}
pub struct RecordNotices {
pub sender: UnboundedSender<Box<str>>,
pub limit: usize,
}
pub enum BackendMessage {
@@ -33,14 +40,11 @@ impl FallibleIterator for BackendMessages {
pub struct PostgresCodec;
impl Encoder<FrontendMessage> for PostgresCodec {
impl Encoder<Bytes> for PostgresCodec {
type Error = io::Error;
fn encode(&mut self, item: FrontendMessage, dst: &mut BytesMut) -> io::Result<()> {
match item {
FrontendMessage::Raw(buf) => dst.extend_from_slice(&buf),
}
fn encode(&mut self, item: Bytes, dst: &mut BytesMut) -> io::Result<()> {
dst.extend_from_slice(&item);
Ok(())
}
}

View File

@@ -1,11 +1,9 @@
use std::net::IpAddr;
use postgres_protocol2::message::backend::Message;
use tokio::net::TcpStream;
use tokio::sync::mpsc;
use crate::client::SocketConfig;
use crate::codec::BackendMessage;
use crate::config::Host;
use crate::connect_raw::connect_raw;
use crate::connect_socket::connect_socket;
@@ -48,8 +46,8 @@ where
let stream = connect_tls(socket, config.ssl_mode, tls).await?;
let RawConnection {
stream,
parameters,
delayed_notice,
parameters: _,
delayed_notice: _,
process_id,
secret_key,
} = connect_raw(stream, config).await?;
@@ -72,13 +70,7 @@ where
secret_key,
);
// delayed notices are always sent as "Async" messages.
let delayed = delayed_notice
.into_iter()
.map(|m| BackendMessage::Async(Message::NoticeResponse(m)))
.collect();
let connection = Connection::new(stream, delayed, parameters, conn_tx, conn_rx);
let connection = Connection::new(stream, conn_tx, conn_rx);
Ok((client, connection))
}

View File

@@ -3,7 +3,7 @@ use std::io;
use std::pin::Pin;
use std::task::{Context, Poll};
use bytes::BytesMut;
use bytes::{Bytes, BytesMut};
use fallible_iterator::FallibleIterator;
use futures_util::{Sink, SinkExt, Stream, TryStreamExt, ready};
use postgres_protocol2::authentication::sasl;
@@ -14,7 +14,7 @@ use tokio::io::{AsyncRead, AsyncWrite};
use tokio_util::codec::Framed;
use crate::Error;
use crate::codec::{BackendMessage, BackendMessages, FrontendMessage, PostgresCodec};
use crate::codec::{BackendMessage, BackendMessages, PostgresCodec};
use crate::config::{self, AuthKeys, Config};
use crate::maybe_tls_stream::MaybeTlsStream;
use crate::tls::TlsStream;
@@ -25,7 +25,7 @@ pub struct StartupStream<S, T> {
delayed_notice: Vec<NoticeResponseBody>,
}
impl<S, T> Sink<FrontendMessage> for StartupStream<S, T>
impl<S, T> Sink<Bytes> for StartupStream<S, T>
where
S: AsyncRead + AsyncWrite + Unpin,
T: AsyncRead + AsyncWrite + Unpin,
@@ -36,7 +36,7 @@ where
Pin::new(&mut self.inner).poll_ready(cx)
}
fn start_send(mut self: Pin<&mut Self>, item: FrontendMessage) -> io::Result<()> {
fn start_send(mut self: Pin<&mut Self>, item: Bytes) -> io::Result<()> {
Pin::new(&mut self.inner).start_send(item)
}
@@ -120,10 +120,7 @@ where
let mut buf = BytesMut::new();
frontend::startup_message(&config.server_params, &mut buf).map_err(Error::encode)?;
stream
.send(FrontendMessage::Raw(buf.freeze()))
.await
.map_err(Error::io)
stream.send(buf.freeze()).await.map_err(Error::io)
}
async fn authenticate<S, T>(stream: &mut StartupStream<S, T>, config: &Config) -> Result<(), Error>
@@ -191,10 +188,7 @@ where
let mut buf = BytesMut::new();
frontend::password_message(password, &mut buf).map_err(Error::encode)?;
stream
.send(FrontendMessage::Raw(buf.freeze()))
.await
.map_err(Error::io)
stream.send(buf.freeze()).await.map_err(Error::io)
}
async fn authenticate_sasl<S, T>(
@@ -253,10 +247,7 @@ where
let mut buf = BytesMut::new();
frontend::sasl_initial_response(mechanism, scram.message(), &mut buf).map_err(Error::encode)?;
stream
.send(FrontendMessage::Raw(buf.freeze()))
.await
.map_err(Error::io)?;
stream.send(buf.freeze()).await.map_err(Error::io)?;
let body = match stream.try_next().await.map_err(Error::io)? {
Some(Message::AuthenticationSaslContinue(body)) => body,
@@ -272,10 +263,7 @@ where
let mut buf = BytesMut::new();
frontend::sasl_response(scram.message(), &mut buf).map_err(Error::encode)?;
stream
.send(FrontendMessage::Raw(buf.freeze()))
.await
.map_err(Error::io)?;
stream.send(buf.freeze()).await.map_err(Error::io)?;
let body = match stream.try_next().await.map_err(Error::io)? {
Some(Message::AuthenticationSaslFinal(body)) => body,

View File

@@ -1,22 +1,23 @@
use std::collections::{HashMap, VecDeque};
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use bytes::BytesMut;
use futures_util::{Sink, Stream, ready};
use postgres_protocol2::message::backend::Message;
use fallible_iterator::FallibleIterator;
use futures_util::{Sink, StreamExt, ready};
use postgres_protocol2::message::backend::{Message, NoticeResponseBody};
use postgres_protocol2::message::frontend;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::sync::mpsc;
use tokio_util::codec::Framed;
use tokio_util::sync::PollSender;
use tracing::{info, trace};
use tracing::trace;
use crate::codec::{BackendMessage, BackendMessages, FrontendMessage, PostgresCodec};
use crate::error::DbError;
use crate::Error;
use crate::codec::{
BackendMessage, BackendMessages, FrontendMessage, PostgresCodec, RecordNotices,
};
use crate::maybe_tls_stream::MaybeTlsStream;
use crate::{AsyncMessage, Error, Notification};
#[derive(PartialEq, Debug)]
enum State {
@@ -33,18 +34,18 @@ enum State {
/// occurred, or because its associated `Client` has dropped and all outstanding work has completed.
#[must_use = "futures do nothing unless polled"]
pub struct Connection<S, T> {
/// HACK: we need this in the Neon Proxy.
pub stream: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
/// HACK: we need this in the Neon Proxy to forward params.
pub parameters: HashMap<String, String>,
stream: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
sender: PollSender<BackendMessages>,
receiver: mpsc::UnboundedReceiver<FrontendMessage>,
notices: Option<RecordNotices>,
pending_responses: VecDeque<BackendMessage>,
pending_response: Option<BackendMessages>,
state: State,
}
pub enum Never {}
impl<S, T> Connection<S, T>
where
S: AsyncRead + AsyncWrite + Unpin,
@@ -52,70 +53,42 @@ where
{
pub(crate) fn new(
stream: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
pending_responses: VecDeque<BackendMessage>,
parameters: HashMap<String, String>,
sender: mpsc::Sender<BackendMessages>,
receiver: mpsc::UnboundedReceiver<FrontendMessage>,
) -> Connection<S, T> {
Connection {
stream,
parameters,
sender: PollSender::new(sender),
receiver,
pending_responses,
notices: None,
pending_response: None,
state: State::Active,
}
}
fn poll_response(
&mut self,
cx: &mut Context<'_>,
) -> Poll<Option<Result<BackendMessage, Error>>> {
if let Some(message) = self.pending_responses.pop_front() {
trace!("retrying pending response");
return Poll::Ready(Some(Ok(message)));
}
Pin::new(&mut self.stream)
.poll_next(cx)
.map(|o| o.map(|r| r.map_err(Error::io)))
}
/// Read and process messages from the connection to postgres.
/// client <- postgres
fn poll_read(&mut self, cx: &mut Context<'_>) -> Poll<Result<AsyncMessage, Error>> {
fn poll_read(&mut self, cx: &mut Context<'_>) -> Poll<Result<Never, Error>> {
loop {
let message = match self.poll_response(cx)? {
Poll::Ready(Some(message)) => message,
Poll::Ready(None) => return Poll::Ready(Err(Error::closed())),
Poll::Pending => {
trace!("poll_read: waiting on response");
return Poll::Pending;
}
};
let messages = match message {
BackendMessage::Async(Message::NoticeResponse(body)) => {
let error = DbError::parse(&mut body.fields()).map_err(Error::parse)?;
return Poll::Ready(Ok(AsyncMessage::Notice(error)));
}
BackendMessage::Async(Message::NotificationResponse(body)) => {
let notification = Notification {
process_id: body.process_id(),
channel: body.channel().map_err(Error::parse)?.to_string(),
payload: body.message().map_err(Error::parse)?.to_string(),
let messages = match self.pending_response.take() {
Some(messages) => messages,
None => {
let message = match self.stream.poll_next_unpin(cx) {
Poll::Pending => return Poll::Pending,
Poll::Ready(None) => return Poll::Ready(Err(Error::closed())),
Poll::Ready(Some(Err(e))) => return Poll::Ready(Err(Error::io(e))),
Poll::Ready(Some(Ok(message))) => message,
};
return Poll::Ready(Ok(AsyncMessage::Notification(notification)));
match message {
BackendMessage::Async(Message::NoticeResponse(body)) => {
self.handle_notice(body)?;
continue;
}
BackendMessage::Async(_) => continue,
BackendMessage::Normal { messages } => messages,
}
}
BackendMessage::Async(Message::ParameterStatus(body)) => {
self.parameters.insert(
body.name().map_err(Error::parse)?.to_string(),
body.value().map_err(Error::parse)?.to_string(),
);
continue;
}
BackendMessage::Async(_) => unreachable!(),
BackendMessage::Normal { messages } => messages,
};
match self.sender.poll_reserve(cx) {
@@ -126,8 +99,7 @@ where
return Poll::Ready(Err(Error::closed()));
}
Poll::Pending => {
self.pending_responses
.push_back(BackendMessage::Normal { messages });
self.pending_response = Some(messages);
trace!("poll_read: waiting on sender");
return Poll::Pending;
}
@@ -135,6 +107,31 @@ where
}
}
fn handle_notice(&mut self, body: NoticeResponseBody) -> Result<(), Error> {
let Some(notices) = &mut self.notices else {
return Ok(());
};
let mut fields = body.fields();
while let Some(field) = fields.next().map_err(Error::parse)? {
// loop until we find the message field
if field.type_() == b'M' {
// if the message field is within the limit, send it.
if let Some(new_limit) = notices.limit.checked_sub(field.value().len()) {
match notices.sender.send(field.value().into()) {
// set the new limit.
Ok(()) => notices.limit = new_limit,
// closed.
Err(_) => self.notices = None,
}
}
break;
}
}
Ok(())
}
/// Fetch the next client request and enqueue the response sender.
fn poll_request(&mut self, cx: &mut Context<'_>) -> Poll<Option<FrontendMessage>> {
if self.receiver.is_closed() {
@@ -168,21 +165,23 @@ where
match self.poll_request(cx) {
// send the message to postgres
Poll::Ready(Some(request)) => {
Poll::Ready(Some(FrontendMessage::Raw(request))) => {
Pin::new(&mut self.stream)
.start_send(request)
.map_err(Error::io)?;
}
Poll::Ready(Some(FrontendMessage::RecordNotices(notices))) => {
self.notices = Some(notices)
}
// No more messages from the client, and no more responses to wait for.
// Send a terminate message to postgres
Poll::Ready(None) => {
trace!("poll_write: at eof, terminating");
let mut request = BytesMut::new();
frontend::terminate(&mut request);
let request = FrontendMessage::Raw(request.freeze());
Pin::new(&mut self.stream)
.start_send(request)
.start_send(request.freeze())
.map_err(Error::io)?;
trace!("poll_write: sent eof, closing");
@@ -231,34 +230,17 @@ where
}
}
/// Returns the value of a runtime parameter for this connection.
pub fn parameter(&self, name: &str) -> Option<&str> {
self.parameters.get(name).map(|s| &**s)
}
/// Polls for asynchronous messages from the server.
///
/// The server can send notices as well as notifications asynchronously to the client. Applications that wish to
/// examine those messages should use this method to drive the connection rather than its `Future` implementation.
pub fn poll_message(
&mut self,
cx: &mut Context<'_>,
) -> Poll<Option<Result<AsyncMessage, Error>>> {
fn poll_message(&mut self, cx: &mut Context<'_>) -> Poll<Option<Result<Never, Error>>> {
if self.state != State::Closing {
// if the state is still active, try read from and write to postgres.
let message = self.poll_read(cx)?;
let closing = self.poll_write(cx)?;
if let Poll::Ready(()) = closing {
let Poll::Pending = self.poll_read(cx)?;
if self.poll_write(cx)?.is_ready() {
self.state = State::Closing;
}
if let Poll::Ready(message) = message {
return Poll::Ready(Some(Ok(message)));
}
// poll_read returned Pending.
// poll_write returned Pending or Ready(WriteReady::WaitingOnRead).
// if poll_write returned Ready(WriteReady::WaitingOnRead), then we are waiting to read more data from postgres.
// poll_write returned Pending or Ready(()).
// if poll_write returned Ready(()), then we are waiting to read more data from postgres.
if self.state != State::Closing {
return Poll::Pending;
}
@@ -280,11 +262,9 @@ where
type Output = Result<(), Error>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
while let Some(message) = ready!(self.poll_message(cx)?) {
if let AsyncMessage::Notice(notice) = message {
info!("{}: {}", notice.severity(), notice.message());
}
match self.poll_message(cx)? {
Poll::Ready(None) => Poll::Ready(Ok(())),
Poll::Pending => Poll::Pending,
}
Poll::Ready(Ok(()))
}
}

View File

@@ -8,7 +8,6 @@ pub use crate::client::{Client, SocketConfig};
pub use crate::config::Config;
pub use crate::connect_raw::RawConnection;
pub use crate::connection::Connection;
use crate::error::DbError;
pub use crate::error::Error;
pub use crate::generic_client::GenericClient;
pub use crate::query::RowStream;
@@ -93,21 +92,6 @@ impl Notification {
}
}
/// An asynchronous message from the server.
#[allow(clippy::large_enum_variant)]
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum AsyncMessage {
/// A notice.
///
/// Notices use the same format as errors, but aren't "errors" per-se.
Notice(DbError),
/// A notification.
///
/// Connections can subscribe to notifications with the `LISTEN` command.
Notification(Notification),
}
/// Message returned by the `SimpleQuery` stream.
#[derive(Debug)]
#[non_exhaustive]

View File

@@ -43,7 +43,7 @@ itertools.workspace = true
sync_wrapper = { workspace = true, features = ["futures"] }
byteorder = "1.4"
rand = "0.8.5"
rand.workspace = true
[dev-dependencies]
camino-tempfile.workspace = true

View File

@@ -81,7 +81,7 @@ impl UnreliableWrapper {
///
fn attempt(&self, op: RemoteOp) -> anyhow::Result<u64> {
let mut attempts = self.attempts.lock().unwrap();
let mut rng = rand::thread_rng();
let mut rng = rand::rng();
match attempts.entry(op) {
Entry::Occupied(mut e) => {
@@ -94,7 +94,7 @@ impl UnreliableWrapper {
/* BEGIN_HADRON */
// If there are more attempts to fail, fail the request by probability.
if (attempts_before_this < self.attempts_to_fail)
&& (rng.gen_range(0..=100) < self.attempt_failure_probability)
&& (rng.random_range(0..=100) < self.attempt_failure_probability)
{
let error =
anyhow::anyhow!("simulated failure of remote operation {:?}", e.key());

View File

@@ -208,7 +208,7 @@ async fn create_azure_client(
.as_millis();
// because nanos can be the same for two threads so can millis, add randomness
let random = rand::thread_rng().r#gen::<u32>();
let random = rand::rng().random::<u32>();
let remote_storage_config = RemoteStorageConfig {
storage: RemoteStorageKind::AzureContainer(AzureConfig {

View File

@@ -385,7 +385,7 @@ async fn create_s3_client(
.as_millis();
// because nanos can be the same for two threads so can millis, add randomness
let random = rand::thread_rng().r#gen::<u32>();
let random = rand::rng().random::<u32>();
let remote_storage_config = RemoteStorageConfig {
storage: RemoteStorageKind::AwsS3(S3Config {

View File

@@ -1,11 +1,5 @@
//! Helper functions to set up OpenTelemetry tracing.
//!
//! This comes in two variants, depending on whether you have a Tokio runtime available.
//! If you do, call `init_tracing()`. It sets up the trace processor and exporter to use
//! the current tokio runtime. If you don't have a runtime available, or you don't want
//! to share the runtime with the tracing tasks, call `init_tracing_without_runtime()`
//! instead. It sets up a dedicated single-threaded Tokio runtime for the tracing tasks.
//!
//! Example:
//!
//! ```rust,no_run
@@ -21,7 +15,8 @@
//! .with_writer(std::io::stderr);
//!
//! // Initialize OpenTelemetry. Exports tracing spans as OpenTelemetry traces
//! let otlp_layer = tracing_utils::init_tracing("my_application", tracing_utils::ExportConfig::default()).await;
//! let provider = tracing_utils::init_tracing("my_application", tracing_utils::ExportConfig::default());
//! let otlp_layer = provider.as_ref().map(tracing_utils::layer);
//!
//! // Put it all together
//! tracing_subscriber::registry()
@@ -36,16 +31,18 @@
pub mod http;
pub mod perf_span;
use opentelemetry::KeyValue;
use opentelemetry::trace::TracerProvider;
use opentelemetry_otlp::WithExportConfig;
pub use opentelemetry_otlp::{ExportConfig, Protocol};
use opentelemetry_sdk::trace::SdkTracerProvider;
use tracing::level_filters::LevelFilter;
use tracing::{Dispatch, Subscriber};
use tracing_subscriber::Layer;
use tracing_subscriber::layer::SubscriberExt;
use tracing_subscriber::registry::LookupSpan;
pub type Provider = SdkTracerProvider;
/// Set up OpenTelemetry exporter, using configuration from environment variables.
///
/// `service_name` is set as the OpenTelemetry 'service.name' resource (see
@@ -70,16 +67,7 @@ use tracing_subscriber::registry::LookupSpan;
/// If you need some other setting, please test if it works first. And perhaps
/// add a comment in the list above to save the effort of testing for the next
/// person.
///
/// This doesn't block, but is marked as 'async' to hint that this must be called in
/// asynchronous execution context.
pub async fn init_tracing<S>(
service_name: &str,
export_config: ExportConfig,
) -> Option<impl Layer<S>>
where
S: Subscriber + for<'span> LookupSpan<'span>,
{
pub fn init_tracing(service_name: &str, export_config: ExportConfig) -> Option<Provider> {
if std::env::var("OTEL_SDK_DISABLED") == Ok("true".to_string()) {
return None;
};
@@ -89,52 +77,14 @@ where
))
}
/// Like `init_tracing`, but creates a separate tokio Runtime for the tracing
/// tasks.
pub fn init_tracing_without_runtime<S>(
service_name: &str,
export_config: ExportConfig,
) -> Option<impl Layer<S>>
pub fn layer<S>(p: &Provider) -> impl Layer<S>
where
S: Subscriber + for<'span> LookupSpan<'span>,
{
if std::env::var("OTEL_SDK_DISABLED") == Ok("true".to_string()) {
return None;
};
// The opentelemetry batch processor and the OTLP exporter needs a Tokio
// runtime. Create a dedicated runtime for them. One thread should be
// enough.
//
// (Alternatively, instead of batching, we could use the "simple
// processor", which doesn't need Tokio, and use "reqwest-blocking"
// feature for the OTLP exporter, which also doesn't need Tokio. However,
// batching is considered best practice, and also I have the feeling that
// the non-Tokio codepaths in the opentelemetry crate are less used and
// might be more buggy, so better to stay on the well-beaten path.)
//
// We leak the runtime so that it keeps running after we exit the
// function.
let runtime = Box::leak(Box::new(
tokio::runtime::Builder::new_multi_thread()
.enable_all()
.thread_name("otlp runtime thread")
.worker_threads(1)
.build()
.unwrap(),
));
let _guard = runtime.enter();
Some(init_tracing_internal(
service_name.to_string(),
export_config,
))
tracing_opentelemetry::layer().with_tracer(p.tracer("global"))
}
fn init_tracing_internal<S>(service_name: String, export_config: ExportConfig) -> impl Layer<S>
where
S: Subscriber + for<'span> LookupSpan<'span>,
{
fn init_tracing_internal(service_name: String, export_config: ExportConfig) -> Provider {
// Sets up exporter from the provided [`ExportConfig`] parameter.
// If the endpoint is not specified, it is loaded from the
// OTEL_EXPORTER_OTLP_ENDPOINT environment variable.
@@ -153,22 +103,14 @@ where
opentelemetry_sdk::propagation::TraceContextPropagator::new(),
);
let tracer = opentelemetry_sdk::trace::TracerProvider::builder()
.with_batch_exporter(exporter, opentelemetry_sdk::runtime::Tokio)
.with_resource(opentelemetry_sdk::Resource::new(vec![KeyValue::new(
opentelemetry_semantic_conventions::resource::SERVICE_NAME,
service_name,
)]))
Provider::builder()
.with_batch_exporter(exporter)
.with_resource(
opentelemetry_sdk::Resource::builder()
.with_service_name(service_name)
.build(),
)
.build()
.tracer("global");
tracing_opentelemetry::layer().with_tracer(tracer)
}
// Shutdown trace pipeline gracefully, so that it has a chance to send any
// pending traces before we exit.
pub fn shutdown_tracing() {
opentelemetry::global::shutdown_tracer_provider();
}
pub enum OtelEnablement {
@@ -176,17 +118,17 @@ pub enum OtelEnablement {
Enabled {
service_name: String,
export_config: ExportConfig,
runtime: &'static tokio::runtime::Runtime,
},
}
pub struct OtelGuard {
provider: Provider,
pub dispatch: Dispatch,
}
impl Drop for OtelGuard {
fn drop(&mut self) {
shutdown_tracing();
_ = self.provider.shutdown();
}
}
@@ -199,22 +141,19 @@ impl Drop for OtelGuard {
/// The lifetime of the guard should match taht of the application. On drop, it tears down the
/// OTEL infra.
pub fn init_performance_tracing(otel_enablement: OtelEnablement) -> Option<OtelGuard> {
let otel_subscriber = match otel_enablement {
match otel_enablement {
OtelEnablement::Disabled => None,
OtelEnablement::Enabled {
service_name,
export_config,
runtime,
} => {
let otel_layer = runtime
.block_on(init_tracing(&service_name, export_config))
.with_filter(LevelFilter::INFO);
let provider = init_tracing(&service_name, export_config)?;
let otel_layer = layer(&provider).with_filter(LevelFilter::INFO);
let otel_subscriber = tracing_subscriber::registry().with(otel_layer);
let otel_dispatch = Dispatch::new(otel_subscriber);
let dispatch = Dispatch::new(otel_subscriber);
Some(otel_dispatch)
Some(OtelGuard { dispatch, provider })
}
};
otel_subscriber.map(|dispatch| OtelGuard { dispatch })
}
}

View File

@@ -104,7 +104,7 @@ impl Id {
pub fn generate() -> Self {
let mut tli_buf = [0u8; 16];
rand::thread_rng().fill(&mut tli_buf);
rand::rng().fill(&mut tli_buf);
Id::from(tli_buf)
}

View File

@@ -364,42 +364,37 @@ impl MonotonicCounter<Lsn> for RecordLsn {
}
}
/// Implements [`rand::distributions::uniform::UniformSampler`] so we can sample [`Lsn`]s.
/// Implements [`rand::distr::uniform::UniformSampler`] so we can sample [`Lsn`]s.
///
/// This is used by the `pagebench` pageserver benchmarking tool.
pub struct LsnSampler(<u64 as rand::distributions::uniform::SampleUniform>::Sampler);
pub struct LsnSampler(<u64 as rand::distr::uniform::SampleUniform>::Sampler);
impl rand::distributions::uniform::SampleUniform for Lsn {
impl rand::distr::uniform::SampleUniform for Lsn {
type Sampler = LsnSampler;
}
impl rand::distributions::uniform::UniformSampler for LsnSampler {
impl rand::distr::uniform::UniformSampler for LsnSampler {
type X = Lsn;
fn new<B1, B2>(low: B1, high: B2) -> Self
fn new<B1, B2>(low: B1, high: B2) -> Result<Self, rand::distr::uniform::Error>
where
B1: rand::distributions::uniform::SampleBorrow<Self::X> + Sized,
B2: rand::distributions::uniform::SampleBorrow<Self::X> + Sized,
B1: rand::distr::uniform::SampleBorrow<Self::X> + Sized,
B2: rand::distr::uniform::SampleBorrow<Self::X> + Sized,
{
Self(
<u64 as rand::distributions::uniform::SampleUniform>::Sampler::new(
low.borrow().0,
high.borrow().0,
),
)
<u64 as rand::distr::uniform::SampleUniform>::Sampler::new(low.borrow().0, high.borrow().0)
.map(Self)
}
fn new_inclusive<B1, B2>(low: B1, high: B2) -> Self
fn new_inclusive<B1, B2>(low: B1, high: B2) -> Result<Self, rand::distr::uniform::Error>
where
B1: rand::distributions::uniform::SampleBorrow<Self::X> + Sized,
B2: rand::distributions::uniform::SampleBorrow<Self::X> + Sized,
B1: rand::distr::uniform::SampleBorrow<Self::X> + Sized,
B2: rand::distr::uniform::SampleBorrow<Self::X> + Sized,
{
Self(
<u64 as rand::distributions::uniform::SampleUniform>::Sampler::new_inclusive(
low.borrow().0,
high.borrow().0,
),
<u64 as rand::distr::uniform::SampleUniform>::Sampler::new_inclusive(
low.borrow().0,
high.borrow().0,
)
.map(Self)
}
fn sample<R: rand::prelude::Rng + ?Sized>(&self, rng: &mut R) -> Self::X {

View File

@@ -25,6 +25,12 @@ pub struct ShardIndex {
pub shard_count: ShardCount,
}
/// Stripe size as number of pages.
///
/// NB: don't implement Default, so callers don't lazily use it by mistake. See DEFAULT_STRIPE_SIZE.
#[derive(Clone, Copy, Serialize, Deserialize, Eq, PartialEq, Debug)]
pub struct ShardStripeSize(pub u32);
/// Formatting helper, for generating the `shard_id` label in traces.
pub struct ShardSlug<'a>(&'a TenantShardId);
@@ -181,6 +187,12 @@ impl std::fmt::Display for ShardCount {
}
}
impl std::fmt::Display for ShardStripeSize {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
self.0.fmt(f)
}
}
impl std::fmt::Display for ShardSlug<'_> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(

View File

@@ -11,7 +11,8 @@ use pageserver::tenant::layer_map::LayerMap;
use pageserver::tenant::storage_layer::{LayerName, PersistentLayerDesc};
use pageserver_api::key::Key;
use pageserver_api::shard::TenantShardId;
use rand::prelude::{SeedableRng, SliceRandom, StdRng};
use rand::prelude::{SeedableRng, StdRng};
use rand::seq::IndexedRandom;
use utils::id::{TenantId, TimelineId};
use utils::lsn::Lsn;

View File

@@ -16,10 +16,9 @@ use crate::pool::{ChannelPool, ClientGuard, ClientPool, StreamGuard, StreamPool}
use crate::retry::Retry;
use crate::split::GetPageSplitter;
use compute_api::spec::PageserverProtocol;
use pageserver_api::shard::ShardStripeSize;
use pageserver_page_api as page_api;
use utils::id::{TenantId, TimelineId};
use utils::shard::{ShardCount, ShardIndex, ShardNumber};
use utils::shard::{ShardCount, ShardIndex, ShardNumber, ShardStripeSize};
/// Max number of concurrent clients per channel (i.e. TCP connection). New channels will be spun up
/// when full.
@@ -141,8 +140,8 @@ impl PageserverClient {
if !old.count.is_unsharded() && shard_spec.stripe_size != old.stripe_size {
return Err(anyhow!(
"can't change stripe size from {} to {}",
old.stripe_size,
shard_spec.stripe_size
old.stripe_size.expect("always Some when sharded"),
shard_spec.stripe_size.expect("always Some when sharded")
));
}
@@ -157,23 +156,6 @@ impl PageserverClient {
Ok(())
}
/// Returns whether a relation exists.
#[instrument(skip_all, fields(rel=%req.rel, lsn=%req.read_lsn))]
pub async fn check_rel_exists(
&self,
req: page_api::CheckRelExistsRequest,
) -> tonic::Result<page_api::CheckRelExistsResponse> {
debug!("sending request: {req:?}");
let resp = Self::with_retries(CALL_TIMEOUT, async |_| {
// Relation metadata is only available on shard 0.
let mut client = self.shards.load_full().get_zero().client().await?;
Self::with_timeout(REQUEST_TIMEOUT, client.check_rel_exists(req)).await
})
.await?;
debug!("received response: {resp:?}");
Ok(resp)
}
/// Returns the total size of a database, as # of bytes.
#[instrument(skip_all, fields(db_oid=%req.db_oid, lsn=%req.read_lsn))]
pub async fn get_db_size(
@@ -249,13 +231,15 @@ impl PageserverClient {
// Fast path: request is for a single shard.
if let Some(shard_id) =
GetPageSplitter::for_single_shard(&req, shards.count, shards.stripe_size)
.map_err(|err| tonic::Status::internal(err.to_string()))?
{
return Self::get_page_with_shard(req, shards.get(shard_id)?).await;
}
// Request spans multiple shards. Split it, dispatch concurrent per-shard requests, and
// reassemble the responses.
let mut splitter = GetPageSplitter::split(req, shards.count, shards.stripe_size);
let mut splitter = GetPageSplitter::split(req, shards.count, shards.stripe_size)
.map_err(|err| tonic::Status::internal(err.to_string()))?;
let mut shard_requests = FuturesUnordered::new();
for (shard_id, shard_req) in splitter.drain_requests() {
@@ -265,10 +249,14 @@ impl PageserverClient {
}
while let Some((shard_id, shard_response)) = shard_requests.next().await.transpose()? {
splitter.add_response(shard_id, shard_response)?;
splitter
.add_response(shard_id, shard_response)
.map_err(|err| tonic::Status::internal(err.to_string()))?;
}
splitter.get_response()
splitter
.get_response()
.map_err(|err| tonic::Status::internal(err.to_string()))
}
/// Fetches pages on the given shard. Does not retry internally.
@@ -396,12 +384,14 @@ pub struct ShardSpec {
/// NB: this is 0 for unsharded tenants, following `ShardIndex::unsharded()` convention.
count: ShardCount,
/// The stripe size for these shards.
stripe_size: ShardStripeSize,
///
/// INVARIANT: None for unsharded tenants, Some for sharded.
stripe_size: Option<ShardStripeSize>,
}
impl ShardSpec {
/// Creates a new shard spec with the given URLs and stripe size. All shards must be given.
/// The stripe size may be omitted for unsharded tenants.
/// The stripe size must be Some for sharded tenants, or None for unsharded tenants.
pub fn new(
urls: HashMap<ShardIndex, String>,
stripe_size: Option<ShardStripeSize>,
@@ -414,11 +404,13 @@ impl ShardSpec {
n => ShardCount::new(n as u8),
};
// Determine the stripe size. It doesn't matter for unsharded tenants.
// Validate the stripe size.
if stripe_size.is_none() && !count.is_unsharded() {
return Err(anyhow!("stripe size must be given for sharded tenants"));
}
let stripe_size = stripe_size.unwrap_or_default();
if stripe_size.is_some() && count.is_unsharded() {
return Err(anyhow!("stripe size can't be given for unsharded tenants"));
}
// Validate the shard spec.
for (shard_id, url) in &urls {
@@ -458,8 +450,10 @@ struct Shards {
///
/// NB: this is 0 for unsharded tenants, following `ShardIndex::unsharded()` convention.
count: ShardCount,
/// The stripe size. Only used for sharded tenants.
stripe_size: ShardStripeSize,
/// The stripe size.
///
/// INVARIANT: None for unsharded tenants, Some for sharded.
stripe_size: Option<ShardStripeSize>,
}
impl Shards {

View File

@@ -1,11 +1,12 @@
use std::collections::HashMap;
use anyhow::anyhow;
use bytes::Bytes;
use pageserver_api::key::rel_block_to_key;
use pageserver_api::shard::{ShardStripeSize, key_to_shard_number};
use pageserver_api::shard::key_to_shard_number;
use pageserver_page_api as page_api;
use utils::shard::{ShardCount, ShardIndex, ShardNumber};
use utils::shard::{ShardCount, ShardIndex, ShardStripeSize};
/// Splits GetPageRequests that straddle shard boundaries and assembles the responses.
/// TODO: add tests for this.
@@ -25,43 +26,54 @@ impl GetPageSplitter {
pub fn for_single_shard(
req: &page_api::GetPageRequest,
count: ShardCount,
stripe_size: ShardStripeSize,
) -> Option<ShardIndex> {
stripe_size: Option<ShardStripeSize>,
) -> anyhow::Result<Option<ShardIndex>> {
// Fast path: unsharded tenant.
if count.is_unsharded() {
return Some(ShardIndex::unsharded());
return Ok(Some(ShardIndex::unsharded()));
}
// Find the first page's shard, for comparison. If there are no pages, just return the first
// shard (caller likely checked already, otherwise the server will reject it).
let Some(stripe_size) = stripe_size else {
return Err(anyhow!("stripe size must be given for sharded tenants"));
};
// Find the first page's shard, for comparison.
let Some(&first_page) = req.block_numbers.first() else {
return Some(ShardIndex::new(ShardNumber(0), count));
return Err(anyhow!("no block numbers in request"));
};
let key = rel_block_to_key(req.rel, first_page);
let shard_number = key_to_shard_number(count, stripe_size, &key);
req.block_numbers
Ok(req
.block_numbers
.iter()
.skip(1) // computed above
.all(|&blkno| {
let key = rel_block_to_key(req.rel, blkno);
key_to_shard_number(count, stripe_size, &key) == shard_number
})
.then_some(ShardIndex::new(shard_number, count))
.then_some(ShardIndex::new(shard_number, count)))
}
/// Splits the given request.
pub fn split(
req: page_api::GetPageRequest,
count: ShardCount,
stripe_size: ShardStripeSize,
) -> Self {
stripe_size: Option<ShardStripeSize>,
) -> anyhow::Result<Self> {
// The caller should make sure we don't split requests unnecessarily.
debug_assert!(
Self::for_single_shard(&req, count, stripe_size).is_none(),
Self::for_single_shard(&req, count, stripe_size)?.is_none(),
"unnecessary request split"
);
if count.is_unsharded() {
return Err(anyhow!("unsharded tenant, no point in splitting request"));
}
let Some(stripe_size) = stripe_size else {
return Err(anyhow!("stripe size must be given for sharded tenants"));
};
// Split the requests by shard index.
let mut requests = HashMap::with_capacity(2); // common case
let mut block_shards = Vec::with_capacity(req.block_numbers.len());
@@ -103,11 +115,11 @@ impl GetPageSplitter {
.collect(),
};
Self {
Ok(Self {
requests,
response,
block_shards,
}
})
}
/// Drains the per-shard requests, moving them out of the splitter to avoid extra allocations.
@@ -124,21 +136,30 @@ impl GetPageSplitter {
&mut self,
shard_id: ShardIndex,
response: page_api::GetPageResponse,
) -> tonic::Result<()> {
) -> anyhow::Result<()> {
// The caller should already have converted status codes into tonic::Status.
if response.status_code != page_api::GetPageStatusCode::Ok {
return Err(tonic::Status::internal(format!(
return Err(anyhow!(
"unexpected non-OK response for shard {shard_id}: {} {}",
response.status_code,
response.reason.unwrap_or_default()
)));
));
}
if response.request_id != self.response.request_id {
return Err(tonic::Status::internal(format!(
return Err(anyhow!(
"response ID mismatch for shard {shard_id}: expected {}, got {}",
self.response.request_id, response.request_id
)));
self.response.request_id,
response.request_id
));
}
if response.request_id != self.response.request_id {
return Err(anyhow!(
"response ID mismatch for shard {shard_id}: expected {}, got {}",
self.response.request_id,
response.request_id
));
}
// Place the shard response pages into the assembled response, in request order.
@@ -150,27 +171,26 @@ impl GetPageSplitter {
}
let Some(slot) = self.response.pages.get_mut(i) else {
return Err(tonic::Status::internal(format!(
"no block_shards slot {i} for shard {shard_id}"
)));
return Err(anyhow!("no block_shards slot {i} for shard {shard_id}"));
};
let Some(page) = pages.next() else {
return Err(tonic::Status::internal(format!(
return Err(anyhow!(
"missing page {} in shard {shard_id} response",
slot.block_number
)));
));
};
if page.block_number != slot.block_number {
return Err(tonic::Status::internal(format!(
return Err(anyhow!(
"shard {shard_id} returned wrong page at index {i}, expected {} got {}",
slot.block_number, page.block_number
)));
slot.block_number,
page.block_number
));
}
if !slot.image.is_empty() {
return Err(tonic::Status::internal(format!(
return Err(anyhow!(
"shard {shard_id} returned duplicate page {} at index {i}",
slot.block_number
)));
));
}
*slot = page;
@@ -178,10 +198,10 @@ impl GetPageSplitter {
// Make sure we've consumed all pages from the shard response.
if let Some(extra_page) = pages.next() {
return Err(tonic::Status::internal(format!(
return Err(anyhow!(
"shard {shard_id} returned extra page: {}",
extra_page.block_number
)));
));
}
Ok(())
@@ -189,18 +209,18 @@ impl GetPageSplitter {
/// Fetches the final, assembled response.
#[allow(clippy::result_large_err)]
pub fn get_response(self) -> tonic::Result<page_api::GetPageResponse> {
pub fn get_response(self) -> anyhow::Result<page_api::GetPageResponse> {
// Check that the response is complete.
for (i, page) in self.response.pages.iter().enumerate() {
if page.image.is_empty() {
return Err(tonic::Status::internal(format!(
return Err(anyhow!(
"missing page {} for shard {}",
page.block_number,
self.block_shards
.get(i)
.map(|s| s.to_string())
.unwrap_or_else(|| "?".to_string())
)));
));
}
}

View File

@@ -89,7 +89,7 @@ async fn simulate(cmd: &SimulateCmd, results_path: &Path) -> anyhow::Result<()>
let cold_key_range = splitpoint..key_range.end;
for i in 0..cmd.num_records {
let chosen_range = if rand::thread_rng().gen_bool(0.9) {
let chosen_range = if rand::rng().random_bool(0.9) {
&hot_key_range
} else {
&cold_key_range

View File

@@ -300,9 +300,9 @@ impl MockTimeline {
key_range: &Range<Key>,
) -> anyhow::Result<()> {
crate::helpers::union_to_keyspace(&mut self.keyspace, vec![key_range.clone()]);
let mut rng = rand::thread_rng();
let mut rng = rand::rng();
for _ in 0..num_records {
self.ingest_record(rng.gen_range(key_range.clone()), len);
self.ingest_record(rng.random_range(key_range.clone()), len);
self.wal_ingested += len;
}
Ok(())

View File

@@ -4,7 +4,7 @@ use anyhow::Context;
use clap::Parser;
use pageserver_api::key::Key;
use pageserver_api::reltag::{BlockNumber, RelTag, SlruKind};
use pageserver_api::shard::{ShardCount, ShardStripeSize};
use pageserver_api::shard::{DEFAULT_STRIPE_SIZE, ShardCount, ShardStripeSize};
#[derive(Parser)]
pub(super) struct DescribeKeyCommand {
@@ -128,7 +128,9 @@ impl DescribeKeyCommand {
// seeing the sharding placement might be confusing, so leave it out unless shard
// count was given.
let stripe_size = stripe_size.map(ShardStripeSize).unwrap_or_default();
let stripe_size = stripe_size
.map(ShardStripeSize)
.unwrap_or(DEFAULT_STRIPE_SIZE);
println!(
"# placement with shard_count: {} and stripe_size: {}:",
shard_count.0, stripe_size.0

View File

@@ -17,11 +17,11 @@
// grpcurl \
// -plaintext \
// -H "neon-tenant-id: 7c4a1f9e3bd6470c8f3e21a65bd2e980" \
// -H "neon-shard-id: 0b10" \
// -H "neon-shard-id: 0000" \
// -H "neon-timeline-id: f08c4e9a2d5f76b1e3a7c2d8910f4b3e" \
// -H "authorization: Bearer $JWT" \
// -d '{"read_lsn": {"request_lsn": 1234567890}, "rel": {"spc_oid": 1663, "db_oid": 1234, "rel_number": 5678, "fork_number": 0}}'
// localhost:51051 page_api.PageService/CheckRelExists
// -d '{"read_lsn": {"request_lsn": 100000000, "not_modified_since_lsn": 1}, "db_oid": 1}' \
// localhost:51051 page_api.PageService/GetDbSize
// ```
//
// TODO: consider adding neon-compute-mode ("primary", "static", "replica").
@@ -38,8 +38,8 @@ package page_api;
import "google/protobuf/timestamp.proto";
service PageService {
// Returns whether a relation exists.
rpc CheckRelExists(CheckRelExistsRequest) returns (CheckRelExistsResponse);
// NB: unlike libpq, there is no CheckRelExists in gRPC, at the compute team's request. Instead,
// use GetRelSize with allow_missing=true to check existence.
// Fetches a base backup.
rpc GetBaseBackup (GetBaseBackupRequest) returns (stream GetBaseBackupResponseChunk);
@@ -97,17 +97,6 @@ message RelTag {
uint32 fork_number = 4;
}
// Checks whether a relation exists, at the given LSN. Only valid on shard 0,
// other shards will error.
message CheckRelExistsRequest {
ReadLsn read_lsn = 1;
RelTag rel = 2;
}
message CheckRelExistsResponse {
bool exists = 1;
}
// Requests a base backup.
message GetBaseBackupRequest {
// The LSN to fetch the base backup at. 0 or absent means the latest LSN known to the Pageserver.
@@ -260,10 +249,15 @@ enum GetPageStatusCode {
message GetRelSizeRequest {
ReadLsn read_lsn = 1;
RelTag rel = 2;
// If true, return missing=true for missing relations instead of a NotFound error.
bool allow_missing = 3;
}
message GetRelSizeResponse {
// The number of blocks in the relation.
uint32 num_blocks = 1;
// If allow_missing=true, this is true for missing relations.
bool missing = 2;
}
// Requests an SLRU segment. Only valid on shard 0, other shards will error.

View File

@@ -69,16 +69,6 @@ impl Client {
Ok(Self { inner })
}
/// Returns whether a relation exists.
pub async fn check_rel_exists(
&mut self,
req: CheckRelExistsRequest,
) -> tonic::Result<CheckRelExistsResponse> {
let req = proto::CheckRelExistsRequest::from(req);
let resp = self.inner.check_rel_exists(req).await?.into_inner();
Ok(resp.into())
}
/// Fetches a base backup.
pub async fn get_base_backup(
&mut self,
@@ -114,7 +104,8 @@ impl Client {
Ok(resps.and_then(|resp| ready(GetPageResponse::try_from(resp).map_err(|err| err.into()))))
}
/// Returns the size of a relation, as # of blocks.
/// Returns the size of a relation as # of blocks, or None if allow_missing=true and the
/// relation does not exist.
pub async fn get_rel_size(
&mut self,
req: GetRelSizeRequest,

View File

@@ -141,50 +141,6 @@ impl From<RelTag> for proto::RelTag {
}
}
/// Checks whether a relation exists, at the given LSN. Only valid on shard 0, other shards error.
#[derive(Clone, Copy, Debug)]
pub struct CheckRelExistsRequest {
pub read_lsn: ReadLsn,
pub rel: RelTag,
}
impl TryFrom<proto::CheckRelExistsRequest> for CheckRelExistsRequest {
type Error = ProtocolError;
fn try_from(pb: proto::CheckRelExistsRequest) -> Result<Self, Self::Error> {
Ok(Self {
read_lsn: pb
.read_lsn
.ok_or(ProtocolError::Missing("read_lsn"))?
.try_into()?,
rel: pb.rel.ok_or(ProtocolError::Missing("rel"))?.try_into()?,
})
}
}
impl From<CheckRelExistsRequest> for proto::CheckRelExistsRequest {
fn from(request: CheckRelExistsRequest) -> Self {
Self {
read_lsn: Some(request.read_lsn.into()),
rel: Some(request.rel.into()),
}
}
}
pub type CheckRelExistsResponse = bool;
impl From<proto::CheckRelExistsResponse> for CheckRelExistsResponse {
fn from(pb: proto::CheckRelExistsResponse) -> Self {
pb.exists
}
}
impl From<CheckRelExistsResponse> for proto::CheckRelExistsResponse {
fn from(exists: CheckRelExistsResponse) -> Self {
Self { exists }
}
}
/// Requests a base backup.
#[derive(Clone, Copy, Debug)]
pub struct GetBaseBackupRequest {
@@ -709,6 +665,8 @@ impl From<GetPageStatusCode> for tonic::Code {
pub struct GetRelSizeRequest {
pub read_lsn: ReadLsn,
pub rel: RelTag,
/// If true, return missing=true for missing relations instead of a NotFound error.
pub allow_missing: bool,
}
impl TryFrom<proto::GetRelSizeRequest> for GetRelSizeRequest {
@@ -721,6 +679,7 @@ impl TryFrom<proto::GetRelSizeRequest> for GetRelSizeRequest {
.ok_or(ProtocolError::Missing("read_lsn"))?
.try_into()?,
rel: proto.rel.ok_or(ProtocolError::Missing("rel"))?.try_into()?,
allow_missing: proto.allow_missing,
})
}
}
@@ -730,21 +689,29 @@ impl From<GetRelSizeRequest> for proto::GetRelSizeRequest {
Self {
read_lsn: Some(request.read_lsn.into()),
rel: Some(request.rel.into()),
allow_missing: request.allow_missing,
}
}
}
pub type GetRelSizeResponse = u32;
/// The size of a relation as number of blocks, or None if `allow_missing=true` and the relation
/// does not exist.
///
/// INVARIANT: never None if `allow_missing=false` (returns `NotFound` error instead).
pub type GetRelSizeResponse = Option<u32>;
impl From<proto::GetRelSizeResponse> for GetRelSizeResponse {
fn from(proto: proto::GetRelSizeResponse) -> Self {
proto.num_blocks
fn from(pb: proto::GetRelSizeResponse) -> Self {
(!pb.missing).then_some(pb.num_blocks)
}
}
impl From<GetRelSizeResponse> for proto::GetRelSizeResponse {
fn from(num_blocks: GetRelSizeResponse) -> Self {
Self { num_blocks }
fn from(resp: GetRelSizeResponse) -> Self {
Self {
num_blocks: resp.unwrap_or_default(),
missing: resp.is_none(),
}
}
}

View File

@@ -188,9 +188,9 @@ async fn main_impl(
start_work_barrier.wait().await;
loop {
let (timeline, work) = {
let mut rng = rand::thread_rng();
let mut rng = rand::rng();
let target = all_targets.choose(&mut rng).unwrap();
let lsn = target.lsn_range.clone().map(|r| rng.gen_range(r));
let lsn = target.lsn_range.clone().map(|r| rng.random_range(r));
(target.timeline, Work { lsn })
};
let sender = work_senders.get(&timeline).unwrap();

View File

@@ -354,8 +354,7 @@ async fn main_impl(
.cloned()
.collect();
let weights =
rand::distributions::weighted::WeightedIndex::new(ranges.iter().map(|v| v.len()))
.unwrap();
rand::distr::weighted::WeightedIndex::new(ranges.iter().map(|v| v.len())).unwrap();
Box::pin(async move {
let scheme = match Url::parse(&args.page_service_connstring) {
@@ -455,7 +454,7 @@ async fn run_worker(
cancel: CancellationToken,
rps_period: Option<Duration>,
ranges: Vec<KeyRange>,
weights: rand::distributions::weighted::WeightedIndex<i128>,
weights: rand::distr::weighted::WeightedIndex<i128>,
) {
shared_state.start_work_barrier.wait().await;
let client_start = Instant::now();
@@ -497,9 +496,9 @@ async fn run_worker(
}
// Pick a random page from a random relation.
let mut rng = rand::thread_rng();
let mut rng = rand::rng();
let r = &ranges[weights.sample(&mut rng)];
let key: i128 = rng.gen_range(r.start..r.end);
let key: i128 = rng.random_range(r.start..r.end);
let (rel_tag, block_no) = key_to_block(key);
let mut blks = VecDeque::with_capacity(batch_size);
@@ -530,7 +529,7 @@ async fn run_worker(
// We assume that the entire batch can fit within the relation.
assert_eq!(blks.len(), batch_size, "incomplete batch");
let req_lsn = if rng.gen_bool(args.req_latest_probability) {
let req_lsn = if rng.random_bool(args.req_latest_probability) {
Lsn::MAX
} else {
r.timeline_lsn

View File

@@ -7,7 +7,7 @@ use std::time::{Duration, Instant};
use pageserver_api::models::HistoricLayerInfo;
use pageserver_api::shard::TenantShardId;
use pageserver_client::mgmt_api;
use rand::seq::SliceRandom;
use rand::seq::IndexedMutRandom;
use tokio::sync::{OwnedSemaphorePermit, mpsc};
use tokio::task::JoinSet;
use tokio_util::sync::CancellationToken;
@@ -260,7 +260,7 @@ async fn timeline_actor(
loop {
let layer_tx = {
let mut rng = rand::thread_rng();
let mut rng = rand::rng();
timeline.layers.choose_mut(&mut rng).expect("no layers")
};
match layer_tx.try_send(permit.take().unwrap()) {

View File

@@ -126,7 +126,6 @@ fn main() -> anyhow::Result<()> {
Some(cfg) => tracing_utils::OtelEnablement::Enabled {
service_name: "pageserver".to_string(),
export_config: (&cfg.export_config).into(),
runtime: *COMPUTE_REQUEST_RUNTIME,
},
None => tracing_utils::OtelEnablement::Disabled,
};

View File

@@ -1,4 +1,5 @@
use std::collections::HashMap;
use std::net::IpAddr;
use futures::Future;
use pageserver_api::config::NodeMetadata;
@@ -16,7 +17,7 @@ use tokio_util::sync::CancellationToken;
use url::Url;
use utils::generation::Generation;
use utils::id::{NodeId, TimelineId};
use utils::{backoff, failpoint_support};
use utils::{backoff, failpoint_support, ip_address};
use crate::config::PageServerConf;
use crate::virtual_file::on_fatal_io_error;
@@ -27,6 +28,7 @@ pub struct StorageControllerUpcallClient {
http_client: reqwest::Client,
base_url: Url,
node_id: NodeId,
node_ip_addr: Option<IpAddr>,
cancel: CancellationToken,
}
@@ -40,6 +42,7 @@ pub trait StorageControllerUpcallApi {
fn re_attach(
&self,
conf: &PageServerConf,
empty_local_disk: bool,
) -> impl Future<
Output = Result<HashMap<TenantShardId, ReAttachResponseTenant>, RetryForeverError>,
> + Send;
@@ -91,11 +94,18 @@ impl StorageControllerUpcallClient {
);
}
// Intentionally panics if we encountered any errors parsing or reading the IP address.
// Note that if the required environment variable is not set, `read_node_ip_addr_from_env` returns `Ok(None)`
// instead of an error.
let node_ip_addr =
ip_address::read_node_ip_addr_from_env().expect("Error reading node IP address.");
Self {
http_client: client.build().expect("Failed to construct HTTP client"),
base_url: url,
node_id: conf.id,
cancel: cancel.clone(),
node_ip_addr,
}
}
@@ -146,6 +156,7 @@ impl StorageControllerUpcallApi for StorageControllerUpcallClient {
async fn re_attach(
&self,
conf: &PageServerConf,
empty_local_disk: bool,
) -> Result<HashMap<TenantShardId, ReAttachResponseTenant>, RetryForeverError> {
let url = self
.base_url
@@ -193,8 +204,8 @@ impl StorageControllerUpcallApi for StorageControllerUpcallClient {
listen_http_addr: m.http_host,
listen_http_port: m.http_port,
listen_https_port: m.https_port,
node_ip_addr: self.node_ip_addr,
availability_zone_id: az_id.expect("Checked above"),
node_ip_addr: None,
})
}
Err(e) => {
@@ -217,6 +228,7 @@ impl StorageControllerUpcallApi for StorageControllerUpcallClient {
let request = ReAttachRequest {
node_id: self.node_id,
register: register.clone(),
empty_local_disk: Some(empty_local_disk),
};
let response: ReAttachResponse = self

View File

@@ -768,6 +768,7 @@ mod test {
async fn re_attach(
&self,
_conf: &PageServerConf,
_empty_local_disk: bool,
) -> Result<HashMap<TenantShardId, ReAttachResponseTenant>, RetryForeverError> {
unimplemented!()
}

View File

@@ -155,7 +155,7 @@ impl FeatureResolver {
);
let tenant_properties = PerTenantProperties {
remote_size_mb: Some(rand::thread_rng().gen_range(100.0..1000000.00)),
remote_size_mb: Some(rand::rng().random_range(100.0..1000000.00)),
}
.into_posthog_properties();

View File

@@ -1636,9 +1636,10 @@ impl PageServerHandler {
let (shard, ctx) = upgrade_handle_and_set_context!(shard);
(
vec![
Self::handle_get_nblocks_request(&shard, &req, &ctx)
Self::handle_get_nblocks_request(&shard, &req, false, &ctx)
.instrument(span.clone())
.await
.map(|msg| msg.expect("allow_missing=false"))
.map(|msg| (PagestreamBeMessage::Nblocks(msg), timer, ctx))
.map_err(|err| BatchedPageStreamError { err, req: req.hdr }),
],
@@ -2303,12 +2304,16 @@ impl PageServerHandler {
Ok(PagestreamExistsResponse { req: *req, exists })
}
/// If `allow_missing` is true, returns None instead of Err on missing relations. Otherwise,
/// never returns None. It is only supported by the gRPC protocol, so we pass it separately to
/// avoid changing the libpq protocol types.
#[instrument(skip_all, fields(shard_id))]
async fn handle_get_nblocks_request(
timeline: &Timeline,
req: &PagestreamNblocksRequest,
allow_missing: bool,
ctx: &RequestContext,
) -> Result<PagestreamNblocksResponse, PageStreamError> {
) -> Result<Option<PagestreamNblocksResponse>, PageStreamError> {
let latest_gc_cutoff_lsn = timeline.get_applied_gc_cutoff_lsn();
let lsn = Self::wait_or_get_last_lsn(
timeline,
@@ -2320,20 +2325,25 @@ impl PageServerHandler {
.await?;
let n_blocks = timeline
.get_rel_size(
.get_rel_size_in_reldir(
req.rel,
Version::LsnRange(LsnRange {
effective_lsn: lsn,
request_lsn: req.hdr.request_lsn,
}),
None,
allow_missing,
ctx,
)
.await?;
let Some(n_blocks) = n_blocks else {
return Ok(None);
};
Ok(PagestreamNblocksResponse {
Ok(Some(PagestreamNblocksResponse {
req: *req,
n_blocks,
})
}))
}
#[instrument(skip_all, fields(shard_id))]
@@ -3525,8 +3535,8 @@ impl GrpcPageServiceHandler {
/// Implements the gRPC page service.
///
/// Tonic will drop the request handler futures if the client goes away (e.g. due to a timeout or
/// cancellation), so the read path must be cancellation-safe. On shutdown, Tonic will wait for
/// On client disconnect (e.g. timeout or client shutdown), Tonic will drop the request handler
/// futures, so the read path must be cancellation-safe. On server shutdown, Tonic will wait for
/// in-flight requests to complete.
///
/// TODO: when the libpq impl is removed, remove the Pagestream types and inline the handler code.
@@ -3539,39 +3549,6 @@ impl proto::PageService for GrpcPageServiceHandler {
type GetPagesStream =
Pin<Box<dyn Stream<Item = Result<proto::GetPageResponse, tonic::Status>> + Send>>;
#[instrument(skip_all, fields(rel, lsn))]
async fn check_rel_exists(
&self,
req: tonic::Request<proto::CheckRelExistsRequest>,
) -> Result<tonic::Response<proto::CheckRelExistsResponse>, tonic::Status> {
let received_at = extract::<ReceivedAt>(&req).0;
let timeline = self.get_request_timeline(&req).await?;
let ctx = self.ctx.with_scope_page_service_pagestream(&timeline);
// Validate the request, decorate the span, and convert it to a Pagestream request.
Self::ensure_shard_zero(&timeline)?;
let req: page_api::CheckRelExistsRequest = req.into_inner().try_into()?;
span_record!(rel=%req.rel, lsn=%req.read_lsn);
let req = PagestreamExistsRequest {
hdr: Self::make_hdr(req.read_lsn, None),
rel: req.rel,
};
// Execute the request and convert the response.
let _timer = Self::record_op_start_and_throttle(
&timeline,
metrics::SmgrQueryType::GetRelExists,
received_at,
)
.await?;
let resp = PageServerHandler::handle_get_rel_exists_request(&timeline, &req, &ctx).await?;
let resp: page_api::CheckRelExistsResponse = resp.exists;
Ok(tonic::Response::new(resp.into()))
}
#[instrument(skip_all, fields(lsn))]
async fn get_base_backup(
&self,
@@ -3766,12 +3743,14 @@ impl proto::PageService for GrpcPageServiceHandler {
// NB: Tonic considers the entire stream to be an in-flight request and will wait
// for it to complete before shutting down. React to cancellation between requests.
let req = tokio::select! {
biased;
_ = cancel.cancelled() => Err(tonic::Status::unavailable("shutting down")),
result = reqs.message() => match result {
Ok(Some(req)) => Ok(req),
Ok(None) => break, // client closed the stream
Err(err) => Err(err),
},
_ = cancel.cancelled() => Err(tonic::Status::unavailable("shutting down")),
}?;
let req_id = req.request_id.map(page_api::RequestID::from).unwrap_or_default();
let result = Self::get_page(&ctx, &timeline, req, io_concurrency.clone())
@@ -3796,7 +3775,7 @@ impl proto::PageService for GrpcPageServiceHandler {
Ok(tonic::Response::new(Box::pin(resps)))
}
#[instrument(skip_all, fields(rel, lsn))]
#[instrument(skip_all, fields(rel, lsn, allow_missing))]
async fn get_rel_size(
&self,
req: tonic::Request<proto::GetRelSizeRequest>,
@@ -3808,8 +3787,9 @@ impl proto::PageService for GrpcPageServiceHandler {
// Validate the request, decorate the span, and convert it to a Pagestream request.
Self::ensure_shard_zero(&timeline)?;
let req: page_api::GetRelSizeRequest = req.into_inner().try_into()?;
let allow_missing = req.allow_missing;
span_record!(rel=%req.rel, lsn=%req.read_lsn);
span_record!(rel=%req.rel, lsn=%req.read_lsn, allow_missing=%req.allow_missing);
let req = PagestreamNblocksRequest {
hdr: Self::make_hdr(req.read_lsn, None),
@@ -3824,8 +3804,11 @@ impl proto::PageService for GrpcPageServiceHandler {
)
.await?;
let resp = PageServerHandler::handle_get_nblocks_request(&timeline, &req, &ctx).await?;
let resp: page_api::GetRelSizeResponse = resp.n_blocks;
let resp =
PageServerHandler::handle_get_nblocks_request(&timeline, &req, allow_missing, &ctx)
.await?;
let resp: page_api::GetRelSizeResponse = resp.map(|resp| resp.n_blocks);
Ok(tonic::Response::new(resp.into()))
}

View File

@@ -504,8 +504,9 @@ impl Timeline {
for rel in rels {
let n_blocks = self
.get_rel_size_in_reldir(rel, version, Some((reldir_key, &reldir)), ctx)
.await?;
.get_rel_size_in_reldir(rel, version, Some((reldir_key, &reldir)), false, ctx)
.await?
.expect("allow_missing=false");
total_blocks += n_blocks as usize;
}
Ok(total_blocks)
@@ -521,10 +522,16 @@ impl Timeline {
version: Version<'_>,
ctx: &RequestContext,
) -> Result<BlockNumber, PageReconstructError> {
self.get_rel_size_in_reldir(tag, version, None, ctx).await
Ok(self
.get_rel_size_in_reldir(tag, version, None, false, ctx)
.await?
.expect("allow_missing=false"))
}
/// Get size of a relation file. The relation must exist, otherwise an error is returned.
/// Get size of a relation file. If `allow_missing` is true, returns None for missing relations,
/// otherwise errors.
///
/// INVARIANT: never returns None if `allow_missing=false`.
///
/// See [`Self::get_rel_exists_in_reldir`] on why we need `deserialized_reldir_v1`.
pub(crate) async fn get_rel_size_in_reldir(
@@ -532,8 +539,9 @@ impl Timeline {
tag: RelTag,
version: Version<'_>,
deserialized_reldir_v1: Option<(Key, &RelDirectory)>,
allow_missing: bool,
ctx: &RequestContext,
) -> Result<BlockNumber, PageReconstructError> {
) -> Result<Option<BlockNumber>, PageReconstructError> {
if tag.relnode == 0 {
return Err(PageReconstructError::Other(
RelationError::InvalidRelnode.into(),
@@ -541,7 +549,15 @@ impl Timeline {
}
if let Some(nblocks) = self.get_cached_rel_size(&tag, version) {
return Ok(nblocks);
return Ok(Some(nblocks));
}
if allow_missing
&& !self
.get_rel_exists_in_reldir(tag, version, deserialized_reldir_v1, ctx)
.await?
{
return Ok(None);
}
if (tag.forknum == FSM_FORKNUM || tag.forknum == VISIBILITYMAP_FORKNUM)
@@ -553,7 +569,7 @@ impl Timeline {
// FSM, and smgrnblocks() on it immediately afterwards,
// without extending it. Tolerate that by claiming that
// any non-existent FSM fork has size 0.
return Ok(0);
return Ok(Some(0));
}
let key = rel_size_to_key(tag);
@@ -562,7 +578,7 @@ impl Timeline {
self.update_cached_rel_size(tag, version, nblocks);
Ok(nblocks)
Ok(Some(nblocks))
}
/// Does the relation exist?
@@ -2912,9 +2928,8 @@ static ZERO_PAGE: Bytes = Bytes::from_static(&[0u8; BLCKSZ as usize]);
mod tests {
use hex_literal::hex;
use pageserver_api::models::ShardParameters;
use pageserver_api::shard::ShardStripeSize;
use utils::id::TimelineId;
use utils::shard::{ShardCount, ShardNumber};
use utils::shard::{ShardCount, ShardNumber, ShardStripeSize};
use super::*;
use crate::DEFAULT_PG_VERSION;

View File

@@ -6161,11 +6161,11 @@ mod tests {
use pageserver_api::keyspace::KeySpaceRandomAccum;
use pageserver_api::models::{CompactionAlgorithm, CompactionAlgorithmSettings, LsnLease};
use pageserver_compaction::helpers::overlaps_with;
use rand::Rng;
#[cfg(feature = "testing")]
use rand::SeedableRng;
#[cfg(feature = "testing")]
use rand::rngs::StdRng;
use rand::{Rng, thread_rng};
#[cfg(feature = "testing")]
use std::ops::Range;
use storage_layer::{IoConcurrency, PersistentLayerKey};
@@ -6286,8 +6286,8 @@ mod tests {
while lsn < lsn_range.end {
let mut key = key_range.start;
while key < key_range.end {
let gap = random.gen_range(1..=100) <= spec.gap_chance;
let will_init = random.gen_range(1..=100) <= spec.will_init_chance;
let gap = random.random_range(1..=100) <= spec.gap_chance;
let will_init = random.random_range(1..=100) <= spec.will_init_chance;
if gap {
continue;
@@ -6330,8 +6330,8 @@ mod tests {
while lsn < lsn_range.end {
let mut key = key_range.start;
while key < key_range.end {
let gap = random.gen_range(1..=100) <= spec.gap_chance;
let will_init = random.gen_range(1..=100) <= spec.will_init_chance;
let gap = random.random_range(1..=100) <= spec.gap_chance;
let will_init = random.random_range(1..=100) <= spec.will_init_chance;
if gap {
continue;
@@ -7808,7 +7808,7 @@ mod tests {
for _ in 0..50 {
for _ in 0..NUM_KEYS {
lsn = Lsn(lsn.0 + 0x10);
let blknum = thread_rng().gen_range(0..NUM_KEYS);
let blknum = rand::rng().random_range(0..NUM_KEYS);
test_key.field6 = blknum as u32;
let mut writer = tline.writer().await;
writer
@@ -7897,7 +7897,7 @@ mod tests {
for _ in 0..NUM_KEYS {
lsn = Lsn(lsn.0 + 0x10);
let blknum = thread_rng().gen_range(0..NUM_KEYS);
let blknum = rand::rng().random_range(0..NUM_KEYS);
test_key.field6 = blknum as u32;
let mut writer = tline.writer().await;
writer
@@ -7965,7 +7965,7 @@ mod tests {
for _ in 0..NUM_KEYS {
lsn = Lsn(lsn.0 + 0x10);
let blknum = thread_rng().gen_range(0..NUM_KEYS);
let blknum = rand::rng().random_range(0..NUM_KEYS);
test_key.field6 = blknum as u32;
let mut writer = tline.writer().await;
writer
@@ -8229,7 +8229,7 @@ mod tests {
for _ in 0..NUM_KEYS {
lsn = Lsn(lsn.0 + 0x10);
let blknum = thread_rng().gen_range(0..NUM_KEYS);
let blknum = rand::rng().random_range(0..NUM_KEYS);
test_key.field6 = (blknum * STEP) as u32;
let mut writer = tline.writer().await;
writer
@@ -8502,7 +8502,7 @@ mod tests {
for iter in 1..=10 {
for _ in 0..NUM_KEYS {
lsn = Lsn(lsn.0 + 0x10);
let blknum = thread_rng().gen_range(0..NUM_KEYS);
let blknum = rand::rng().random_range(0..NUM_KEYS);
test_key.field6 = (blknum * STEP) as u32;
let mut writer = tline.writer().await;
writer
@@ -11291,10 +11291,10 @@ mod tests {
#[cfg(feature = "testing")]
#[tokio::test]
async fn test_read_path() -> anyhow::Result<()> {
use rand::seq::SliceRandom;
use rand::seq::IndexedRandom;
let seed = if cfg!(feature = "fuzz-read-path") {
let seed: u64 = thread_rng().r#gen();
let seed: u64 = rand::rng().random();
seed
} else {
// Use a hard-coded seed when not in fuzzing mode.
@@ -11308,8 +11308,8 @@ mod tests {
let (queries, will_init_chance, gap_chance) = if cfg!(feature = "fuzz-read-path") {
const QUERIES: u64 = 5000;
let will_init_chance: u8 = random.gen_range(0..=10);
let gap_chance: u8 = random.gen_range(0..=50);
let will_init_chance: u8 = random.random_range(0..=10);
let gap_chance: u8 = random.random_range(0..=50);
(QUERIES, will_init_chance, gap_chance)
} else {
@@ -11410,7 +11410,8 @@ mod tests {
while used_keys.len() < tenant.conf.max_get_vectored_keys.get() {
let selected_lsn = interesting_lsns.choose(&mut random).expect("not empty");
let mut selected_key = start_key.add(random.gen_range(0..KEY_DIMENSION_SIZE));
let mut selected_key =
start_key.add(random.random_range(0..KEY_DIMENSION_SIZE));
while used_keys.len() < tenant.conf.max_get_vectored_keys.get() {
if used_keys.contains(&selected_key)
@@ -11425,7 +11426,7 @@ mod tests {
.add_key(selected_key);
used_keys.insert(selected_key);
let pick_next = random.gen_range(0..=100) <= PICK_NEXT_CHANCE;
let pick_next = random.random_range(0..=100) <= PICK_NEXT_CHANCE;
if pick_next {
selected_key = selected_key.next();
} else {

View File

@@ -535,8 +535,8 @@ pub(crate) mod tests {
}
pub(crate) fn random_array(len: usize) -> Vec<u8> {
let mut rng = rand::thread_rng();
(0..len).map(|_| rng.r#gen()).collect::<_>()
let mut rng = rand::rng();
(0..len).map(|_| rng.random()).collect::<_>()
}
#[tokio::test]
@@ -588,9 +588,9 @@ pub(crate) mod tests {
let mut rng = rand::rngs::StdRng::seed_from_u64(42);
let blobs = (0..1024)
.map(|_| {
let mut sz: u16 = rng.r#gen();
let mut sz: u16 = rng.random();
// Make 50% of the arrays small
if rng.r#gen() {
if rng.random() {
sz &= 63;
}
random_array(sz.into())

View File

@@ -1090,7 +1090,7 @@ pub(crate) mod tests {
const NUM_KEYS: usize = 100000;
let mut all_data: BTreeMap<u128, u64> = BTreeMap::new();
for idx in 0..NUM_KEYS {
let u: f64 = rand::thread_rng().gen_range(0.0..1.0);
let u: f64 = rand::rng().random_range(0.0..1.0);
let t = -(f64::ln(u));
let key_int = (t * 1000000.0) as u128;
@@ -1116,7 +1116,7 @@ pub(crate) mod tests {
// Test get() operations on random keys, most of which will not exist
for _ in 0..100000 {
let key_int = rand::thread_rng().r#gen::<u128>();
let key_int = rand::rng().random::<u128>();
let search_key = u128::to_be_bytes(key_int);
assert!(reader.get(&search_key, &ctx).await? == all_data.get(&key_int).cloned());
}

View File

@@ -508,8 +508,8 @@ mod tests {
let write_nbytes = cap * 2 + cap / 2;
let content: Vec<u8> = rand::thread_rng()
.sample_iter(rand::distributions::Standard)
let content: Vec<u8> = rand::rng()
.sample_iter(rand::distr::StandardUniform)
.take(write_nbytes)
.collect();
@@ -565,8 +565,8 @@ mod tests {
let cap = writer.mutable().capacity();
drop(writer);
let content: Vec<u8> = rand::thread_rng()
.sample_iter(rand::distributions::Standard)
let content: Vec<u8> = rand::rng()
.sample_iter(rand::distr::StandardUniform)
.take(cap * 2 + cap / 2)
.collect();
@@ -614,8 +614,8 @@ mod tests {
let cap = mutable.capacity();
let align = mutable.align();
drop(writer);
let content: Vec<u8> = rand::thread_rng()
.sample_iter(rand::distributions::Standard)
let content: Vec<u8> = rand::rng()
.sample_iter(rand::distr::StandardUniform)
.take(cap * 2 + cap / 2)
.collect();

View File

@@ -19,7 +19,7 @@ use pageserver_api::shard::{
};
use pageserver_api::upcall_api::ReAttachResponseTenant;
use rand::Rng;
use rand::distributions::Alphanumeric;
use rand::distr::Alphanumeric;
use remote_storage::TimeoutOrCancel;
use sysinfo::SystemExt;
use tokio::fs;
@@ -218,7 +218,7 @@ async fn safe_rename_tenant_dir(path: impl AsRef<Utf8Path>) -> std::io::Result<U
std::io::ErrorKind::InvalidInput,
"Path must be absolute",
))?;
let rand_suffix = rand::thread_rng()
let rand_suffix = rand::rng()
.sample_iter(&Alphanumeric)
.take(8)
.map(char::from)
@@ -328,7 +328,7 @@ fn emergency_generations(
LocationMode::Attached(alc) => TenantStartupMode::Attached((
alc.attach_mode,
alc.generation,
ShardStripeSize::default(),
lc.shard.stripe_size,
)),
LocationMode::Secondary(_) => TenantStartupMode::Secondary,
},
@@ -352,7 +352,8 @@ async fn init_load_generations(
let client = StorageControllerUpcallClient::new(conf, cancel);
info!("Calling {} API to re-attach tenants", client.base_url());
// If we are configured to use the control plane API, then it is the source of truth for what tenants to load.
match client.re_attach(conf).await {
let empty_local_disk = tenant_confs.is_empty();
match client.re_attach(conf, empty_local_disk).await {
Ok(tenants) => tenants
.into_iter()
.flat_map(|(id, rart)| {

View File

@@ -1,8 +1,8 @@
use chrono::NaiveDateTime;
use pageserver_api::shard::ShardStripeSize;
use serde::{Deserialize, Serialize};
use utils::id::TimelineId;
use utils::lsn::Lsn;
use utils::shard::ShardStripeSize;
/// Tenant shard manifest, stored in remote storage. Contains offloaded timelines and other tenant
/// shard-wide information that must be persisted in remote storage.

View File

@@ -25,7 +25,7 @@ pub(super) fn period_jitter(d: Duration, pct: u32) -> Duration {
if d == Duration::ZERO {
d
} else {
rand::thread_rng().gen_range((d * (100 - pct)) / 100..(d * (100 + pct)) / 100)
rand::rng().random_range((d * (100 - pct)) / 100..(d * (100 + pct)) / 100)
}
}
@@ -35,7 +35,7 @@ pub(super) fn period_warmup(period: Duration) -> Duration {
if period == Duration::ZERO {
period
} else {
rand::thread_rng().gen_range(Duration::ZERO..period)
rand::rng().random_range(Duration::ZERO..period)
}
}

View File

@@ -1634,7 +1634,8 @@ pub(crate) mod test {
use bytes::Bytes;
use itertools::MinMaxResult;
use postgres_ffi::PgMajorVersion;
use rand::prelude::{SeedableRng, SliceRandom, StdRng};
use rand::prelude::{SeedableRng, StdRng};
use rand::seq::IndexedRandom;
use rand::{Rng, RngCore};
/// Construct an index for a fictional delta layer and and then
@@ -1788,14 +1789,14 @@ pub(crate) mod test {
let mut entries = Vec::new();
for _ in 0..constants::KEY_COUNT {
let count = rng.gen_range(1..constants::MAX_ENTRIES_PER_KEY);
let count = rng.random_range(1..constants::MAX_ENTRIES_PER_KEY);
let mut lsns_iter =
std::iter::successors(Some(Lsn(constants::LSN_OFFSET.0 + 0x08)), |lsn| {
Some(Lsn(lsn.0 + 0x08))
});
let mut lsns = Vec::new();
while lsns.len() < count as usize {
let take = rng.gen_bool(0.5);
let take = rng.random_bool(0.5);
let lsn = lsns_iter.next().unwrap();
if take {
lsns.push(lsn);
@@ -1869,12 +1870,13 @@ pub(crate) mod test {
for _ in 0..constants::RANGES_COUNT {
let mut range: Option<Range<Key>> = Option::default();
while range.is_none() || keyspace.overlaps(range.as_ref().unwrap()) {
let range_start = rng.gen_range(start..end);
let range_start = rng.random_range(start..end);
let range_end_offset = range_start + constants::MIN_RANGE_SIZE;
if range_end_offset >= end {
range = Some(Key::from_i128(range_start)..Key::from_i128(end));
} else {
let range_end = rng.gen_range((range_start + constants::MIN_RANGE_SIZE)..end);
let range_end =
rng.random_range((range_start + constants::MIN_RANGE_SIZE)..end);
range = Some(Key::from_i128(range_start)..Key::from_i128(range_end));
}
}

View File

@@ -440,8 +440,8 @@ mod tests {
impl InMemoryFile {
fn new_random(len: usize) -> Self {
Self {
content: rand::thread_rng()
.sample_iter(rand::distributions::Standard)
content: rand::rng()
.sample_iter(rand::distr::StandardUniform)
.take(len)
.collect(),
}
@@ -498,7 +498,7 @@ mod tests {
len
}
};
rand::Rng::fill(&mut rand::thread_rng(), &mut dst_slice[nread..]); // to discover bugs
rand::Rng::fill(&mut rand::rng(), &mut dst_slice[nread..]); // to discover bugs
Ok((dst, nread))
}
}
@@ -763,7 +763,7 @@ mod tests {
let len = std::cmp::min(dst.bytes_total(), mocked_bytes.len());
let dst_slice: &mut [u8] = dst.as_mut_rust_slice_full_zeroed();
dst_slice[..len].copy_from_slice(&mocked_bytes[..len]);
rand::Rng::fill(&mut rand::thread_rng(), &mut dst_slice[len..]); // to discover bugs
rand::Rng::fill(&mut rand::rng(), &mut dst_slice[len..]); // to discover bugs
Ok((dst, len))
}
Err(e) => Err(std::io::Error::other(e)),

View File

@@ -515,7 +515,7 @@ pub(crate) async fn sleep_random_range(
interval: RangeInclusive<Duration>,
cancel: &CancellationToken,
) -> Result<Duration, Cancelled> {
let delay = rand::thread_rng().gen_range(interval);
let delay = rand::rng().random_range(interval);
if delay == Duration::ZERO {
return Ok(delay);
}

View File

@@ -448,6 +448,7 @@ pub struct Timeline {
/// A channel to send async requests to prepare a basebackup for the basebackup cache.
basebackup_cache: Arc<BasebackupCache>,
#[expect(dead_code)]
feature_resolver: Arc<TenantFeatureResolver>,
}
@@ -2826,7 +2827,7 @@ impl Timeline {
if r.numerator == 0 {
false
} else {
rand::thread_rng().gen_range(0..r.denominator) < r.numerator
rand::rng().random_range(0..r.denominator) < r.numerator
}
}
None => false,
@@ -3908,7 +3909,7 @@ impl Timeline {
// 1hour base
(60_i64 * 60_i64)
// 10min jitter
+ rand::thread_rng().gen_range(-10 * 60..10 * 60),
+ rand::rng().random_range(-10 * 60..10 * 60),
)
.expect("10min < 1hour"),
);

View File

@@ -1326,13 +1326,7 @@ impl Timeline {
.max()
};
let (partition_mode, partition_lsn) = if cfg!(test)
|| cfg!(feature = "testing")
|| self
.feature_resolver
.evaluate_boolean("image-compaction-boundary")
.is_ok()
{
let (partition_mode, partition_lsn) = {
let last_repartition_lsn = self.partitioning.read().1;
let lsn = match l0_l1_boundary_lsn {
Some(boundary) => gc_cutoff
@@ -1348,8 +1342,6 @@ impl Timeline {
} else {
("l0_l1_boundary", lsn)
}
} else {
("latest_record", self.get_last_record_lsn())
};
// 2. Repartition and create image layers if necessary

View File

@@ -654,7 +654,7 @@ mod tests {
use pageserver_api::key::{DBDIR_KEY, Key, rel_block_to_key};
use pageserver_api::models::ShardParameters;
use pageserver_api::reltag::RelTag;
use pageserver_api::shard::ShardStripeSize;
use pageserver_api::shard::DEFAULT_STRIPE_SIZE;
use utils::shard::ShardCount;
use utils::sync::gate::GateGuard;
@@ -955,7 +955,7 @@ mod tests {
});
let child_params = ShardParameters {
count: ShardCount(2),
stripe_size: ShardStripeSize::default(),
stripe_size: DEFAULT_STRIPE_SIZE,
};
let child0 = Arc::new_cyclic(|myself| StubTimeline {
gate: Default::default(),

View File

@@ -1275,8 +1275,8 @@ mod tests {
use std::sync::Arc;
use owned_buffers_io::io_buf_ext::IoBufExt;
use rand::Rng;
use rand::seq::SliceRandom;
use rand::{Rng, thread_rng};
use super::*;
use crate::context::DownloadBehavior;
@@ -1358,7 +1358,7 @@ mod tests {
// Check that all the other FDs still work too. Use them in random order for
// good measure.
file_b_dupes.as_mut_slice().shuffle(&mut thread_rng());
file_b_dupes.as_mut_slice().shuffle(&mut rand::rng());
for vfile in file_b_dupes.iter_mut() {
assert_first_512_eq(vfile, b"content_b").await;
}
@@ -1413,9 +1413,8 @@ mod tests {
let ctx = ctx.detached_child(TaskKind::UnitTest, DownloadBehavior::Error);
let hdl = rt.spawn(async move {
let mut buf = IoBufferMut::with_capacity_zeroed(SIZE);
let mut rng = rand::rngs::OsRng;
for _ in 1..1000 {
let f = &files[rng.gen_range(0..files.len())];
let f = &files[rand::rng().random_range(0..files.len())];
buf = f
.read_exact_at(buf.slice_full(), 0, &ctx)
.await

View File

@@ -6,6 +6,7 @@ OBJS = \
$(WIN32RES) \
communicator.o \
communicator_new.o \
communicator_process.o \
extension_server.o \
file_cache.o \
hll.o \
@@ -65,6 +66,8 @@ WALPROP_OBJS = \
# libcommunicator.a is built by cargo from the Rust sources under communicator/
# subdirectory. `cargo build` also generates communicator_bindings.h.
communicator_new.o: communicator/communicator_bindings.h
communicator_process.o: communicator/communicator_bindings.h
file_cache.o: communicator/communicator_bindings.h
$(NEON_CARGO_ARTIFACT_TARGET_DIR)/libcommunicator.a communicator/communicator_bindings.h &:
(cd $(srcdir)/communicator && cargo build $(CARGO_BUILD_FLAGS) $(CARGO_PROFILE))

View File

@@ -1820,12 +1820,12 @@ nm_to_string(NeonMessage *msg)
}
case T_NeonGetPageResponse:
{
#if 0
NeonGetPageResponse *msg_resp = (NeonGetPageResponse *) msg;
#endif
appendStringInfoString(&s, "{\"type\": \"NeonGetPageResponse\"");
appendStringInfo(&s, ", \"page\": \"XXX\"}");
appendStringInfo(&s, ", \"rinfo\": %u/%u/%u", RelFileInfoFmt(msg_resp->req.rinfo));
appendStringInfo(&s, ", \"forknum\": %d", msg_resp->req.forknum);
appendStringInfo(&s, ", \"blkno\": %u", msg_resp->req.blkno);
appendStringInfoChar(&s, '}');
break;
}

View File

@@ -7,6 +7,9 @@ edition.workspace = true
# 'testing' feature is currently unused in the communicator, but we accept it for convenience of
# calling build scripts, so that you can pass the same feature to all packages.
testing = []
# 'rest_broker' feature is currently unused in the communicator, but we accept it for convenience of
# calling build scripts, so that you can pass the same feature to all packages.
rest_broker = []
[lib]
crate-type = ["staticlib"]
@@ -19,12 +22,13 @@ http.workspace = true
libc.workspace = true
nix.workspace = true
atomic_enum = "0.3.0"
measured.workspace = true
prometheus.workspace = true
prost.workspace = true
tonic = { version = "0.12.0", default-features = false, features=["codegen", "prost", "transport"] }
tokio = { version = "1.43.1", features = ["macros", "net", "io-util", "rt", "rt-multi-thread"] }
tokio-pipe = { version = "0.2.12" }
thiserror.workspace = true
tonic = { workspace = true, default-features = false, features=["codegen", "prost", "transport"] }
tokio = { workspace = true, features = ["macros", "net", "io-util", "rt", "rt-multi-thread"] }
tokio-pipe = { version = "0.2.12" }
tracing.workspace = true
tracing-subscriber.workspace = true

View File

@@ -1,11 +1,15 @@
# Communicator
This package provides the so-called "compute-pageserver communicator",
or just "communicator" in short. It runs in a PostgreSQL server, as
part of the neon extension, and handles the communication with the
pageservers. On the PostgreSQL side, the glue code in pgxn/neon/ uses
the communicator to implement the PostgreSQL Storage Manager (SMGR)
interface.
or just "communicator" in short. The communicator is a separate
background worker process that runs in the PostgreSQL server. It's
part of the neon extension.
The commuicator handles the communication with the pageservers, and
also provides an HTTP endpoint for metrics over a local Unix Domain
socket (aka. the "communicator control socket"). On the PostgreSQL
side, the glue code in pgxn/neon/ uses the communicator to implement
the PostgreSQL Storage Manager (SMGR) interface.
## Design criteria
@@ -14,9 +18,14 @@ interface.
## Source code view
pgxn/neon/communicator_process.c
Contains code needed to start up the communicator process, and
the glue that interacts with PostgreSQL code and the Rust
code in the communicator process.
pgxn/neon/communicator_new.c
Contains the glue that interact with PostgreSQL code and the Rust
communicator code.
Contains the backend code that interacts with the communicator
process.
pgxn/neon/communicator/src/backend_interface.rs
The entry point for calls from each backend.
@@ -24,9 +33,6 @@ pgxn/neon/communicator/src/backend_interface.rs
pgxn/neon/communicator/src/init.rs
Initialization at server startup
pgxn/neon/communicator/src/worker_process/
Worker process main loop and glue code
At compilation time, pgxn/neon/communicator/ produces a static
library, libcommunicator.a. It is linked to the neon.so extension
library.

View File

@@ -215,11 +215,17 @@ pub struct FileCacheIterator {
/// Iterate over LFC contents
#[unsafe(no_mangle)]
pub extern "C" fn bcomm_cache_iterate_begin(_bs: &mut CommunicatorBackendStruct, iter: *mut FileCacheIterator) {
pub extern "C" fn bcomm_cache_iterate_begin(
_bs: &mut CommunicatorBackendStruct,
iter: *mut FileCacheIterator,
) {
unsafe { (*iter).next_bucket = 0 };
}
#[unsafe(no_mangle)]
pub extern "C" fn bcomm_cache_iterate_next(bs: &mut CommunicatorBackendStruct, iter: *mut FileCacheIterator) -> bool {
pub extern "C" fn bcomm_cache_iterate_next(
bs: &mut CommunicatorBackendStruct,
iter: *mut FileCacheIterator,
) -> bool {
use crate::integrated_cache::GetBucketResult;
loop {
let next_bucket = unsafe { (*iter).next_bucket } as usize;
@@ -235,7 +241,7 @@ pub extern "C" fn bcomm_cache_iterate_next(bs: &mut CommunicatorBackendStruct, i
(*iter).next_bucket += 1;
}
break true;
},
}
GetBucketResult::Vacant => {
unsafe {
(*iter).next_bucket += 1;

View File

@@ -759,7 +759,6 @@ impl<'t> IntegratedCacheReadAccess<'t> {
Some((key, _)) => GetBucketResult::Occupied(key.rel, key.block_number),
}
}
}
pub struct BackendCacheReadOp<'t> {

View File

@@ -21,5 +21,9 @@ mod worker_process;
mod global_allocator;
/// Name of the Unix Domain Socket that serves the metrics, and other APIs in the
/// future. This is within the Postgres data directory.
const NEON_COMMUNICATOR_SOCKET_NAME: &str = "neon-communicator.socket";
// FIXME: get this from postgres headers somehow
pub const BLCKSZ: usize = 8192;

View File

@@ -14,6 +14,8 @@ pub type COid = u32;
// This conveniently matches PG_IOV_MAX
pub const MAX_GETPAGEV_PAGES: usize = 32;
pub const INVALID_BLOCK_NUMBER: u32 = u32::MAX;
use std::ffi::CStr;
use pageserver_page_api::{self as page_api, SlruKind};
@@ -27,7 +29,6 @@ pub enum NeonIORequest {
// Read requests. These are C-friendly variants of the corresponding structs in
// pageserver_page_api.
RelExists(CRelExistsRequest),
RelSize(CRelSizeRequest),
GetPageV(CGetPageVRequest),
ReadSlruSegment(CReadSlruSegmentRequest),
@@ -51,7 +52,7 @@ pub enum NeonIORequest {
#[derive(Copy, Clone, Debug)]
pub enum NeonIOResult {
Empty,
RelExists(bool),
/// InvalidBlockNumber == 0xffffffff means "rel does not exist"
RelSize(u32),
/// the result pages are written to the shared memory addresses given in the request
@@ -85,7 +86,6 @@ impl NeonIORequest {
use NeonIORequest::*;
match self {
Empty => 0,
RelExists(req) => req.request_id,
RelSize(req) => req.request_id,
GetPageV(req) => req.request_id,
ReadSlruSegment(req) => req.request_id,
@@ -164,16 +164,6 @@ impl ShmemBuf {
}
}
#[repr(C)]
#[derive(Copy, Clone, Debug)]
pub struct CRelExistsRequest {
pub request_id: u64,
pub spc_oid: COid,
pub db_oid: COid,
pub rel_number: u32,
pub fork_number: u8,
}
#[repr(C)]
#[derive(Copy, Clone, Debug)]
pub struct CRelSizeRequest {
@@ -182,6 +172,7 @@ pub struct CRelSizeRequest {
pub db_oid: COid,
pub rel_number: u32,
pub fork_number: u8,
pub allow_missing: bool,
}
#[repr(C)]
@@ -317,17 +308,6 @@ pub struct CRelUnlinkRequest {
pub lsn: CLsn,
}
impl CRelExistsRequest {
pub fn reltag(&self) -> page_api::RelTag {
page_api::RelTag {
spcnode: self.spc_oid,
dbnode: self.db_oid,
relnode: self.rel_number,
forknum: self.fork_number,
}
}
}
impl CRelSizeRequest {
pub fn reltag(&self) -> page_api::RelTag {
page_api::RelTag {

Some files were not shown because too many files have changed in this diff Show More