Compare commits

..

2 Commits

Author SHA1 Message Date
Conrad Ludgate
f18b07aa05 hack around azure_core reqwest features 2025-07-18 21:08:50 +01:00
Conrad Ludgate
019fd30bb9 edit features to remove webpki-roots 2025-07-18 16:22:40 +01:00
224 changed files with 2432 additions and 13285 deletions

View File

@@ -21,14 +21,13 @@ platforms = [
# "x86_64-apple-darwin",
# "x86_64-pc-windows-msvc",
]
[final-excludes]
workspace-members = [
# vm_monitor benefits from the same Cargo.lock as the rest of our artifacts, but
# it is built primarly in separate repo neondatabase/autoscaling and thus is excluded
# from depending on workspace-hack because most of the dependencies are not used.
"vm_monitor",
# subzero-core is a stub crate that should be excluded from workspace-hack
"subzero-core",
# All of these exist in libs and are not usually built independently.
# Putting workspace hack there adds a bottleneck for cargo builds.
"compute_api",

View File

@@ -1,28 +0,0 @@
name: 'Prepare current job for subzero'
description: >
Set git token to access `neondatabase/subzero` from cargo build,
and set `CARGO_NET_GIT_FETCH_WITH_CLI=true` env variable to use git CLI
inputs:
token:
description: 'GitHub token with access to neondatabase/subzero'
required: true
runs:
using: "composite"
steps:
- name: Set git token for neondatabase/subzero
uses: pyTooling/Actions/with-post-step@2307b526df64d55e95884e072e49aac2a00a9afa # v5.1.0
env:
SUBZERO_ACCESS_TOKEN: ${{ inputs.token }}
with:
main: |
git config --global url."https://x-access-token:${SUBZERO_ACCESS_TOKEN}@github.com/neondatabase/subzero".insteadOf "https://github.com/neondatabase/subzero"
cargo add -p proxy subzero-core --git https://github.com/neondatabase/subzero --rev 396264617e78e8be428682f87469bb25429af88a
post: |
git config --global --unset url."https://x-access-token:${SUBZERO_ACCESS_TOKEN}@github.com/neondatabase/subzero".insteadOf "https://github.com/neondatabase/subzero"
- name: Set `CARGO_NET_GIT_FETCH_WITH_CLI=true` env variable
shell: bash -euxo pipefail {0}
run: echo "CARGO_NET_GIT_FETCH_WITH_CLI=true" >> ${GITHUB_ENV}

View File

@@ -86,10 +86,6 @@ jobs:
with:
submodules: true
- uses: ./.github/actions/prepare-for-subzero
with:
token: ${{ secrets.CI_ACCESS_TOKEN }}
- name: Set pg 14 revision for caching
id: pg_v14_rev
run: echo pg_rev=$(git rev-parse HEAD:vendor/postgres-v14) >> $GITHUB_OUTPUT
@@ -120,7 +116,7 @@ jobs:
ARCH: ${{ inputs.arch }}
SANITIZERS: ${{ inputs.sanitizers }}
run: |
CARGO_FLAGS="--locked --features testing,rest_broker"
CARGO_FLAGS="--locked --features testing"
if [[ $BUILD_TYPE == "debug" && $ARCH == 'x64' ]]; then
cov_prefix="scripts/coverage --profraw-prefix=$GITHUB_JOB --dir=/tmp/coverage run"
CARGO_PROFILE=""

View File

@@ -46,10 +46,6 @@ jobs:
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
with:
submodules: true
- uses: ./.github/actions/prepare-for-subzero
with:
token: ${{ secrets.CI_ACCESS_TOKEN }}
- name: Cache cargo deps
uses: tespkg/actions-cache@b7bf5fcc2f98a52ac6080eb0fd282c2f752074b1 # v1.8.0

View File

@@ -54,10 +54,6 @@ jobs:
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
with:
submodules: true
- uses: ./.github/actions/prepare-for-subzero
with:
token: ${{ secrets.CI_ACCESS_TOKEN }}
- name: Install build dependencies
run: |

View File

@@ -632,8 +632,6 @@ jobs:
BUILD_TAG=${{ needs.meta.outputs.release-tag || needs.meta.outputs.build-tag }}
TAG=${{ needs.build-build-tools-image.outputs.image-tag }}-bookworm
DEBIAN_VERSION=bookworm
secrets: |
SUBZERO_ACCESS_TOKEN=${{ secrets.CI_ACCESS_TOKEN }}
provenance: false
push: true
pull: true

View File

@@ -72,7 +72,6 @@ jobs:
check-macos-build:
needs: [ check-permissions, files-changed ]
uses: ./.github/workflows/build-macos.yml
secrets: inherit
with:
pg_versions: ${{ needs.files-changed.outputs.postgres_changes }}
rebuild_rust_code: ${{ fromJSON(needs.files-changed.outputs.rebuild_rust_code) }}

6
.gitignore vendored
View File

@@ -15,7 +15,6 @@ neon.iml
/.neon
/integration_tests/.neon
compaction-suite-results.*
pgxn/neon/communicator/communicator_bindings.h
docker-compose/docker-compose-parallel.yml
# Coverage
@@ -27,14 +26,9 @@ docker-compose/docker-compose-parallel.yml
*.o
*.so
*.Po
*.pid
# pgindent typedef lists
*.list
# Node
**/node_modules/
# various files for local testing
/proxy/.subzero
local_proxy.json

521
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -49,7 +49,6 @@ members = [
"libs/proxy/tokio-postgres2",
"endpoint_storage",
"pgxn/neon/communicator",
"proxy/subzero_core",
]
[workspace.package]
@@ -93,7 +92,6 @@ clap = { version = "4.0", features = ["derive", "env"] }
clashmap = { version = "1.0", features = ["raw-api"] }
comfy-table = "7.1"
const_format = "0.2"
crossbeam-utils = "0.8.21"
crc32c = "0.6"
diatomic-waker = { version = "0.2.3" }
either = "1.8"
@@ -144,29 +142,26 @@ notify = "6.0.0"
num_cpus = "1.15"
num-traits = "0.2.19"
once_cell = "1.13"
opentelemetry = "0.30"
opentelemetry_sdk = "0.30"
opentelemetry-otlp = { version = "0.30", default-features = false, features = ["http-proto", "trace", "http", "reqwest-blocking-client"] }
opentelemetry-semantic-conventions = "0.30"
opentelemetry = "0.27"
opentelemetry_sdk = "0.27"
opentelemetry-otlp = { version = "0.27", default-features = false, features = ["http-proto", "trace", "http", "reqwest-client"] }
opentelemetry-semantic-conventions = "0.27"
parking_lot = "0.12"
parquet = { version = "53", default-features = false, features = ["zstd"] }
parquet_derive = "53"
pbkdf2 = { version = "0.12.1", features = ["simple", "std"] }
pem = "3.0.3"
peekable = "0.3.0"
pin-project-lite = "0.2"
pprof = { version = "0.14", features = ["criterion", "flamegraph", "frame-pointer", "prost-codec"] }
procfs = "0.16"
prometheus = {version = "0.13", default-features=false, features = ["process"]} # removes protobuf dependency
prost = "0.13.5"
prost-types = "0.13.5"
rand = "0.9"
# Remove after p256 is updated to 0.14.
rand_core = "=0.6"
rand = "0.8"
redis = { version = "0.29.2", features = ["tokio-rustls-comp", "keep-alive"] }
regex = "1.10.2"
reqwest = { version = "0.12", default-features = false, features = ["rustls-tls"] }
reqwest-tracing = { version = "0.5", features = ["opentelemetry_0_30"] }
reqwest = { version = "0.12", default-features = false, features = ["rustls-tls-native-roots"] }
reqwest-tracing = { version = "0.5", features = ["opentelemetry_0_27"] }
reqwest-middleware = "0.4"
reqwest-retry = "0.7"
routerify = "3"
@@ -179,7 +174,7 @@ scopeguard = "1.1"
sysinfo = "0.29.2"
sd-notify = "0.4.1"
send-future = "0.1.0"
sentry = { version = "0.37", default-features = false, features = ["backtrace", "contexts", "panic", "rustls", "reqwest" ] }
sentry = { version = "0.37", default-features = false, features = ["backtrace", "contexts", "panic", "reqwest" ] }
serde = { version = "1.0", features = ["derive"] }
serde_json = "1"
serde_path_to_error = "0.1"
@@ -192,7 +187,6 @@ smallvec = "1.11"
smol_str = { version = "0.2.0", features = ["serde"] }
socket2 = "0.5"
spki = "0.7.3"
spin = "0.9.8"
strum = "0.26"
strum_macros = "0.26"
"subtle" = "2.5.0"
@@ -204,6 +198,7 @@ thiserror = "1.0"
tikv-jemallocator = { version = "0.6", features = ["profiling", "stats", "unprefixed_malloc_on_supported_platforms"] }
tikv-jemalloc-ctl = { version = "0.6", features = ["stats"] }
tokio = { version = "1.43.1", features = ["macros"] }
tokio-epoll-uring = { git = "https://github.com/neondatabase/tokio-epoll-uring.git" , branch = "main" }
tokio-io-timeout = "1.2.0"
tokio-postgres-rustls = "0.12.0"
tokio-rustls = { version = "0.26.0", default-features = false, features = ["tls12", "ring"]}
@@ -216,12 +211,15 @@ tonic = { version = "0.13.1", default-features = false, features = ["channel", "
tonic-reflection = { version = "0.13.1", features = ["server"] }
tower = { version = "0.5.2", default-features = false }
tower-http = { version = "0.6.2", features = ["auth", "request-id", "trace"] }
tower-otel = { version = "0.6", features = ["axum"] }
# This revision uses opentelemetry 0.27. There's no tag for it.
tower-otel = { git = "https://github.com/mattiapenati/tower-otel", rev = "56a7321053bcb72443888257b622ba0d43a11fcd" }
tower-service = "0.3.3"
tracing = "0.1"
tracing-error = "0.2"
tracing-log = "0.2"
tracing-opentelemetry = "0.31"
tracing-opentelemetry = "0.28"
tracing-serde = "0.2.0"
tracing-subscriber = { version = "0.3", default-features = false, features = ["smallvec", "fmt", "tracing-log", "std", "env-filter", "json"] }
try-lock = "0.2.5"
@@ -242,9 +240,6 @@ x509-cert = { version = "0.2.5" }
env_logger = "0.11"
log = "0.4"
tokio-epoll-uring = { git = "https://github.com/neondatabase/tokio-epoll-uring.git" , branch = "main" }
uring-common = { git = "https://github.com/neondatabase/tokio-epoll-uring.git" , branch = "main" }
## Libraries from neondatabase/ git forks, ideally with changes to be upstreamed
postgres = { git = "https://github.com/neondatabase/rust-postgres.git", branch = "neon" }
postgres-protocol = { git = "https://github.com/neondatabase/rust-postgres.git", branch = "neon" }
@@ -252,10 +247,10 @@ postgres-types = { git = "https://github.com/neondatabase/rust-postgres.git", br
tokio-postgres = { git = "https://github.com/neondatabase/rust-postgres.git", branch = "neon" }
## Azure SDK crates
azure_core = { git = "https://github.com/neondatabase/azure-sdk-for-rust.git", branch = "neon", default-features = false, features = ["enable_reqwest_rustls", "hmac_rust"] }
azure_identity = { git = "https://github.com/neondatabase/azure-sdk-for-rust.git", branch = "neon", default-features = false, features = ["enable_reqwest_rustls"] }
azure_storage = { git = "https://github.com/neondatabase/azure-sdk-for-rust.git", branch = "neon", default-features = false, features = ["enable_reqwest_rustls"] }
azure_storage_blobs = { git = "https://github.com/neondatabase/azure-sdk-for-rust.git", branch = "neon", default-features = false, features = ["enable_reqwest_rustls"] }
azure_core = { git = "https://github.com/neondatabase/azure-sdk-for-rust.git", branch = "neon", default-features = false, features = ["hmac_rust"] }
azure_identity = { git = "https://github.com/neondatabase/azure-sdk-for-rust.git", branch = "neon", default-features = false, features = [] }
azure_storage = { git = "https://github.com/neondatabase/azure-sdk-for-rust.git", branch = "neon", default-features = false, features = [] }
azure_storage_blobs = { git = "https://github.com/neondatabase/azure-sdk-for-rust.git", branch = "neon", default-features = false, features = [] }
## Local libraries
compute_api = { version = "0.1", path = "./libs/compute_api/" }

View File

@@ -63,14 +63,7 @@ WORKDIR /home/nonroot
COPY --chown=nonroot . .
RUN --mount=type=secret,uid=1000,id=SUBZERO_ACCESS_TOKEN \
set -e \
&& if [ -s /run/secrets/SUBZERO_ACCESS_TOKEN ]; then \
export CARGO_NET_GIT_FETCH_WITH_CLI=true && \
git config --global url."https://$(cat /run/secrets/SUBZERO_ACCESS_TOKEN)@github.com/neondatabase/subzero".insteadOf "https://github.com/neondatabase/subzero" && \
cargo add -p proxy subzero-core --git https://github.com/neondatabase/subzero --rev 396264617e78e8be428682f87469bb25429af88a; \
fi \
&& cargo chef prepare --recipe-path recipe.json
RUN cargo chef prepare --recipe-path recipe.json
# Main build image
FROM $REPOSITORY/$IMAGE:$TAG AS build
@@ -78,33 +71,20 @@ WORKDIR /home/nonroot
ARG GIT_VERSION=local
ARG BUILD_TAG
ARG ADDITIONAL_RUSTFLAGS=""
ENV CARGO_FEATURES="default"
# 3. Build cargo dependencies. Note that this step doesn't depend on anything else than
# `recipe.json`, so the layer can be reused as long as none of the dependencies change.
COPY --from=plan /home/nonroot/recipe.json recipe.json
RUN --mount=type=secret,uid=1000,id=SUBZERO_ACCESS_TOKEN \
set -e \
&& if [ -s /run/secrets/SUBZERO_ACCESS_TOKEN ]; then \
export CARGO_NET_GIT_FETCH_WITH_CLI=true && \
git config --global url."https://$(cat /run/secrets/SUBZERO_ACCESS_TOKEN)@github.com/neondatabase/subzero".insteadOf "https://github.com/neondatabase/subzero"; \
fi \
RUN set -e \
&& RUSTFLAGS="-Clinker=clang -Clink-arg=-fuse-ld=mold -Clink-arg=-Wl,--no-rosegment -Cforce-frame-pointers=yes ${ADDITIONAL_RUSTFLAGS}" cargo chef cook --locked --release --recipe-path recipe.json
# Perform the main build. We reuse the Postgres build artifacts from the intermediate 'pg-build'
# layer, and the cargo dependencies built in the previous step.
COPY --chown=nonroot --from=pg-build /home/nonroot/pg_install/ pg_install
COPY --chown=nonroot . .
COPY --chown=nonroot --from=plan /home/nonroot/proxy/Cargo.toml proxy/Cargo.toml
COPY --chown=nonroot --from=plan /home/nonroot/Cargo.lock Cargo.lock
RUN --mount=type=secret,uid=1000,id=SUBZERO_ACCESS_TOKEN \
set -e \
&& if [ -s /run/secrets/SUBZERO_ACCESS_TOKEN ]; then \
export CARGO_FEATURES="rest_broker"; \
fi \
RUN set -e \
&& RUSTFLAGS="-Clinker=clang -Clink-arg=-fuse-ld=mold -Clink-arg=-Wl,--no-rosegment -Cforce-frame-pointers=yes ${ADDITIONAL_RUSTFLAGS}" cargo build \
--features $CARGO_FEATURES \
--bin pg_sni_router \
--bin pageserver \
--bin pagectl \

View File

@@ -27,10 +27,7 @@ fail.workspace = true
flate2.workspace = true
futures.workspace = true
http.workspace = true
http-body-util.workspace = true
hostname-validator = "1.1"
hyper.workspace = true
hyper-util.workspace = true
indexmap.workspace = true
itertools.workspace = true
jsonwebtoken.workspace = true
@@ -47,7 +44,6 @@ postgres.workspace = true
regex.workspace = true
reqwest = { workspace = true, features = ["json"] }
ring = "0.17"
scopeguard.workspace = true
serde.workspace = true
serde_with.workspace = true
serde_json.workspace = true

View File

@@ -138,12 +138,6 @@ struct Cli {
/// Run in development mode, skipping VM-specific operations like process termination
#[arg(long, action = clap::ArgAction::SetTrue)]
pub dev: bool,
#[arg(long)]
pub pg_init_timeout: Option<u64>,
#[arg(long, default_value_t = false, action = clap::ArgAction::Set)]
pub lakebase_mode: bool,
}
impl Cli {
@@ -194,7 +188,7 @@ fn main() -> Result<()> {
.build()?;
let _rt_guard = runtime.enter();
let tracing_provider = init(cli.dev)?;
runtime.block_on(init(cli.dev))?;
// enable core dumping for all child processes
setrlimit(Resource::CORE, rlimit::INFINITY, rlimit::INFINITY)?;
@@ -225,8 +219,6 @@ fn main() -> Result<()> {
installed_extensions_collection_interval: Arc::new(AtomicU64::new(
cli.installed_extensions_collection_interval,
)),
pg_init_timeout: cli.pg_init_timeout.map(Duration::from_secs),
lakebase_mode: cli.lakebase_mode,
},
config,
)?;
@@ -235,11 +227,11 @@ fn main() -> Result<()> {
scenario.teardown();
deinit_and_exit(tracing_provider, exit_code);
deinit_and_exit(exit_code);
}
fn init(dev_mode: bool) -> Result<Option<tracing_utils::Provider>> {
let provider = init_tracing_and_logging(DEFAULT_LOG_LEVEL)?;
async fn init(dev_mode: bool) -> Result<()> {
init_tracing_and_logging(DEFAULT_LOG_LEVEL).await?;
let mut signals = Signals::new([SIGINT, SIGTERM, SIGQUIT])?;
thread::spawn(move || {
@@ -250,7 +242,7 @@ fn init(dev_mode: bool) -> Result<Option<tracing_utils::Provider>> {
info!("compute build_tag: {}", &BUILD_TAG.to_string());
Ok(provider)
Ok(())
}
fn get_config(cli: &Cli) -> Result<ComputeConfig> {
@@ -275,27 +267,25 @@ fn get_config(cli: &Cli) -> Result<ComputeConfig> {
}
}
fn deinit_and_exit(tracing_provider: Option<tracing_utils::Provider>, exit_code: Option<i32>) -> ! {
if let Some(p) = tracing_provider {
// Shutdown trace pipeline gracefully, so that it has a chance to send any
// pending traces before we exit. Shutting down OTEL tracing provider may
// hang for quite some time, see, for example:
// - https://github.com/open-telemetry/opentelemetry-rust/issues/868
// - and our problems with staging https://github.com/neondatabase/cloud/issues/3707#issuecomment-1493983636
//
// Yet, we want computes to shut down fast enough, as we may need a new one
// for the same timeline ASAP. So wait no longer than 2s for the shutdown to
// complete, then just error out and exit the main thread.
info!("shutting down tracing");
let (sender, receiver) = mpsc::channel();
let _ = thread::spawn(move || {
_ = p.shutdown();
sender.send(()).ok()
});
let shutdown_res = receiver.recv_timeout(Duration::from_millis(2000));
if shutdown_res.is_err() {
error!("timed out while shutting down tracing, exiting anyway");
}
fn deinit_and_exit(exit_code: Option<i32>) -> ! {
// Shutdown trace pipeline gracefully, so that it has a chance to send any
// pending traces before we exit. Shutting down OTEL tracing provider may
// hang for quite some time, see, for example:
// - https://github.com/open-telemetry/opentelemetry-rust/issues/868
// - and our problems with staging https://github.com/neondatabase/cloud/issues/3707#issuecomment-1493983636
//
// Yet, we want computes to shut down fast enough, as we may need a new one
// for the same timeline ASAP. So wait no longer than 2s for the shutdown to
// complete, then just error out and exit the main thread.
info!("shutting down tracing");
let (sender, receiver) = mpsc::channel();
let _ = thread::spawn(move || {
tracing_utils::shutdown_tracing();
sender.send(()).ok()
});
let shutdown_res = receiver.recv_timeout(Duration::from_millis(2000));
if shutdown_res.is_err() {
error!("timed out while shutting down tracing, exiting anyway");
}
info!("shutting down");

View File

@@ -1,98 +0,0 @@
//! Client for making request to a running Postgres server's communicator control socket.
//!
//! The storage communicator process that runs inside Postgres exposes an HTTP endpoint in
//! a Unix Domain Socket in the Postgres data directory. This provides access to it.
use std::path::Path;
use anyhow::Context;
use hyper::client::conn::http1::SendRequest;
use hyper_util::rt::TokioIo;
/// Name of the socket within the Postgres data directory. This better match that in
/// `pgxn/neon/communicator/src/lib.rs`.
const NEON_COMMUNICATOR_SOCKET_NAME: &str = "neon-communicator.socket";
/// Open a connection to the communicator's control socket, prepare to send requests to it
/// with hyper.
pub async fn connect_communicator_socket<B>(pgdata: &Path) -> anyhow::Result<SendRequest<B>>
where
B: hyper::body::Body + 'static + Send,
B::Data: Send,
B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
{
let socket_path = pgdata.join(NEON_COMMUNICATOR_SOCKET_NAME);
let socket_path_len = socket_path.display().to_string().len();
// There is a limit of around 100 bytes (108 on Linux?) on the length of the path to a
// Unix Domain socket. The limit is on the connect(2) function used to open the
// socket, not on the absolute path itself. Postgres changes the current directory to
// the data directory and uses a relative path to bind to the socket, and the relative
// path "./neon-communicator.socket" is always short, but when compute_ctl needs to
// open the socket, we need to use a full path, which can be arbitrarily long.
//
// There are a few ways we could work around this:
//
// 1. Change the current directory to the Postgres data directory and use a relative
// path in the connect(2) call. That's problematic because the current directory
// applies to the whole process. We could change the current directory early in
// compute_ctl startup, and that might be a good idea anyway for other reasons too:
// it would be more robust if the data directory is moved around or unlinked for
// some reason, and you would be less likely to accidentally litter other parts of
// the filesystem with e.g. temporary files. However, that's a pretty invasive
// change.
//
// 2. On Linux, you could open() the data directory, and refer to the the socket
// inside it as "/proc/self/fd/<fd>/neon-communicator.socket". But that's
// Linux-only.
//
// 3. Create a symbolic link to the socket with a shorter path, and use that.
//
// We use the symbolic link approach here. Hopefully the paths we use in production
// are shorter, so that we can open the socket directly, so that this hack is needed
// only in development.
let connect_result = if socket_path_len < 100 {
// We can open the path directly with no hacks.
tokio::net::UnixStream::connect(socket_path).await
} else {
// The path to the socket is too long. Create a symlink to it with a shorter path.
let short_path = std::env::temp_dir().join(format!(
"compute_ctl.short-socket.{}.{}",
std::process::id(),
tokio::task::id()
));
std::os::unix::fs::symlink(&socket_path, &short_path)?;
// Delete the symlink as soon as we have connected to it. There's a small chance
// of leaking if the process dies before we remove it, so try to keep that window
// as small as possible.
scopeguard::defer! {
if let Err(err) = std::fs::remove_file(&short_path) {
tracing::warn!("could not remove symlink \"{}\" created for socket: {}",
short_path.display(), err);
}
}
tracing::info!(
"created symlink \"{}\" for socket \"{}\", opening it now",
short_path.display(),
socket_path.display()
);
tokio::net::UnixStream::connect(&short_path).await
};
let stream = connect_result.context("connecting to communicator control socket")?;
let io = TokioIo::new(stream);
let (request_sender, connection) = hyper::client::conn::http1::handshake(io).await?;
// spawn a task to poll the connection and drive the HTTP state
tokio::spawn(async move {
if let Err(err) = connection.await {
eprintln!("Error in connection: {err}");
}
});
Ok(request_sender)
}

View File

@@ -6,8 +6,7 @@ use compute_api::responses::{
LfcPrewarmState, PromoteState, TlsConfig,
};
use compute_api::spec::{
ComputeAudit, ComputeFeature, ComputeMode, ComputeSpec, ExtVersion, PageserverConnectionInfo,
PageserverProtocol, PageserverShardConnectionInfo, PageserverShardInfo, PgIdent,
ComputeAudit, ComputeFeature, ComputeMode, ComputeSpec, ExtVersion, PageserverProtocol, PgIdent,
};
use futures::StreamExt;
use futures::future::join_all;
@@ -114,11 +113,6 @@ pub struct ComputeNodeParams {
/// Interval for installed extensions collection
pub installed_extensions_collection_interval: Arc<AtomicU64>,
/// Timeout of PG compute startup in the Init state.
pub pg_init_timeout: Option<Duration>,
pub lakebase_mode: bool,
}
type TaskHandle = Mutex<Option<JoinHandle<()>>>;
@@ -160,7 +154,6 @@ pub struct RemoteExtensionMetrics {
#[derive(Clone, Debug)]
pub struct ComputeState {
pub start_time: DateTime<Utc>,
pub pg_start_time: Option<DateTime<Utc>>,
pub status: ComputeStatus,
/// Timestamp of the last Postgres activity. It could be `None` if
/// compute wasn't used since start.
@@ -198,7 +191,6 @@ impl ComputeState {
pub fn new() -> Self {
Self {
start_time: Utc::now(),
pg_start_time: None,
status: ComputeStatus::Empty,
last_active: None,
error: None,
@@ -241,7 +233,7 @@ pub struct ParsedSpec {
pub spec: ComputeSpec,
pub tenant_id: TenantId,
pub timeline_id: TimelineId,
pub pageserver_conninfo: PageserverConnectionInfo,
pub pageserver_connstr: String,
pub safekeeper_connstrings: Vec<String>,
pub storage_auth_token: Option<String>,
/// k8s dns name and port
@@ -288,114 +280,26 @@ impl ParsedSpec {
}
}
/// Extract PageserverConnectionInfo from a comma-separated list of libpq connection strings.
///
/// This is used for backwards-compatilibity, to parse the legacye `pageserver_connstr`
/// field in the compute spec, or the 'neon.pageserver_connstring' GUC. Nowadays, the
/// 'pageserver_connection_info' field should be used instead.
fn extract_pageserver_conninfo_from_connstr(
connstr: &str,
stripe_size: Option<u32>,
) -> Result<PageserverConnectionInfo, anyhow::Error> {
let shard_infos: Vec<_> = connstr
.split(',')
.map(|connstr| PageserverShardInfo {
pageservers: vec![PageserverShardConnectionInfo {
id: None,
libpq_url: Some(connstr.to_string()),
grpc_url: None,
}],
})
.collect();
match shard_infos.len() {
0 => anyhow::bail!("empty connection string"),
1 => {
// We assume that if there's only connection string, it means "unsharded",
// rather than a sharded system with just a single shard. The latter is
// possible in principle, but we never do it.
let shard_count = ShardCount::unsharded();
let only_shard = shard_infos.first().unwrap().clone();
let shards = vec![(ShardIndex::unsharded(), only_shard)];
Ok(PageserverConnectionInfo {
shard_count,
stripe_size: None,
shards: shards.into_iter().collect(),
prefer_protocol: PageserverProtocol::Libpq,
})
}
n => {
if stripe_size.is_none() {
anyhow::bail!("{n} shards but no stripe_size");
}
let shard_count = ShardCount(n.try_into()?);
let shards = shard_infos
.into_iter()
.enumerate()
.map(|(idx, shard_info)| {
(
ShardIndex {
shard_count,
shard_number: ShardNumber(
idx.try_into().expect("shard number fits in u8"),
),
},
shard_info,
)
})
.collect();
Ok(PageserverConnectionInfo {
shard_count,
stripe_size,
shards,
prefer_protocol: PageserverProtocol::Libpq,
})
}
}
}
impl TryFrom<ComputeSpec> for ParsedSpec {
type Error = anyhow::Error;
fn try_from(spec: ComputeSpec) -> Result<Self, anyhow::Error> {
type Error = String;
fn try_from(spec: ComputeSpec) -> Result<Self, String> {
// Extract the options from the spec file that are needed to connect to
// the storage system.
//
// In compute specs generated by old control plane versions, the spec file might
// be missing the `pageserver_connection_info` field. In that case, we need to dig
// the pageserver connection info from the `pageserver_connstr` field instead, or
// if that's missing too, from the GUC in the cluster.settings field.
let mut pageserver_conninfo = spec.pageserver_connection_info.clone();
if pageserver_conninfo.is_none() {
if let Some(pageserver_connstr_field) = &spec.pageserver_connstring {
pageserver_conninfo = Some(extract_pageserver_conninfo_from_connstr(
pageserver_connstr_field,
spec.shard_stripe_size,
)?);
}
}
if pageserver_conninfo.is_none() {
if let Some(guc) = spec.cluster.settings.find("neon.pageserver_connstring") {
let stripe_size = if let Some(guc) = spec.cluster.settings.find("neon.stripe_size")
{
Some(u32::from_str(&guc)?)
} else {
None
};
pageserver_conninfo =
Some(extract_pageserver_conninfo_from_connstr(&guc, stripe_size)?);
}
}
let pageserver_conninfo = pageserver_conninfo.ok_or(anyhow::anyhow!(
"pageserver connection information should be provided"
))?;
// Similarly for safekeeper connection strings
// For backwards-compatibility, the top-level fields in the spec file
// may be empty. In that case, we need to dig them from the GUCs in the
// cluster.settings field.
let pageserver_connstr = spec
.pageserver_connstring
.clone()
.or_else(|| spec.cluster.settings.find("neon.pageserver_connstring"))
.ok_or("pageserver connstr should be provided")?;
let safekeeper_connstrings = if spec.safekeeper_connstrings.is_empty() {
if matches!(spec.mode, ComputeMode::Primary) {
spec.cluster
.settings
.find("neon.safekeepers")
.ok_or(anyhow::anyhow!("safekeeper connstrings should be provided"))?
.ok_or("safekeeper connstrings should be provided")?
.split(',')
.map(|str| str.to_string())
.collect()
@@ -410,22 +314,22 @@ impl TryFrom<ComputeSpec> for ParsedSpec {
let tenant_id: TenantId = if let Some(tenant_id) = spec.tenant_id {
tenant_id
} else {
let guc = spec
.cluster
spec.cluster
.settings
.find("neon.tenant_id")
.ok_or(anyhow::anyhow!("tenant id should be provided"))?;
TenantId::from_str(&guc).context("invalid tenant id")?
.ok_or("tenant id should be provided")
.map(|s| TenantId::from_str(&s))?
.or(Err("invalid tenant id"))?
};
let timeline_id: TimelineId = if let Some(timeline_id) = spec.timeline_id {
timeline_id
} else {
let guc = spec
.cluster
spec.cluster
.settings
.find("neon.timeline_id")
.ok_or(anyhow::anyhow!("timeline id should be provided"))?;
TimelineId::from_str(&guc).context(anyhow::anyhow!("invalid timeline id"))?
.ok_or("timeline id should be provided")
.map(|s| TimelineId::from_str(&s))?
.or(Err("invalid timeline id"))?
};
let endpoint_storage_addr: Option<String> = spec
@@ -439,7 +343,7 @@ impl TryFrom<ComputeSpec> for ParsedSpec {
let res = ParsedSpec {
spec,
pageserver_conninfo,
pageserver_connstr,
safekeeper_connstrings,
storage_auth_token,
tenant_id,
@@ -449,7 +353,7 @@ impl TryFrom<ComputeSpec> for ParsedSpec {
};
// Now check validity of the parsed specification
res.validate().map_err(anyhow::Error::msg)?;
res.validate()?;
Ok(res)
}
}
@@ -744,9 +648,6 @@ impl ComputeNode {
};
_this_entered = start_compute_span.enter();
// Hadron: Record postgres start time (used to enforce pg_init_timeout).
state_guard.pg_start_time.replace(Utc::now());
state_guard.set_status(ComputeStatus::Init, &self.state_changed);
compute_state = state_guard.clone()
}
@@ -1139,10 +1040,12 @@ impl ComputeNode {
fn try_get_basebackup(&self, compute_state: &ComputeState, lsn: Lsn) -> Result<()> {
let spec = compute_state.pspec.as_ref().expect("spec must be set");
let shard0_connstr = spec.pageserver_connstr.split(',').next().unwrap();
let started = Instant::now();
let (connected, size) = match spec.pageserver_conninfo.prefer_protocol {
PageserverProtocol::Grpc => self.try_get_basebackup_grpc(spec, lsn)?,
let (connected, size) = match PageserverProtocol::from_connstring(shard0_connstr)? {
PageserverProtocol::Libpq => self.try_get_basebackup_libpq(spec, lsn)?,
PageserverProtocol::Grpc => self.try_get_basebackup_grpc(spec, lsn)?,
};
self.fix_zenith_signal_neon_signal()?;
@@ -1180,32 +1083,23 @@ impl ComputeNode {
/// Fetches a basebackup via gRPC. The connstring must use grpc://. Returns the timestamp when
/// the connection was established, and the (compressed) size of the basebackup.
fn try_get_basebackup_grpc(&self, spec: &ParsedSpec, lsn: Lsn) -> Result<(Instant, usize)> {
let shard0_index = ShardIndex {
shard_number: ShardNumber(0),
shard_count: spec.pageserver_conninfo.shard_count,
let shard0_connstr = spec
.pageserver_connstr
.split(',')
.next()
.unwrap()
.to_string();
let shard_index = match spec.pageserver_connstr.split(',').count() as u8 {
0 | 1 => ShardIndex::unsharded(),
count => ShardIndex::new(ShardNumber(0), ShardCount(count)),
};
let shard0 = spec
.pageserver_conninfo
.shards
.get(&shard0_index)
.ok_or_else(|| {
anyhow::anyhow!("shard connection info missing for shard {}", shard0_index)
})?;
let pageserver = shard0
.pageservers
.first()
.expect("must have at least one pageserver");
let shard0_url = pageserver
.grpc_url
.clone()
.expect("no grpc_url for shard 0");
let (reader, connected) = tokio::runtime::Handle::current().block_on(async move {
let mut client = page_api::Client::connect(
shard0_url,
shard0_connstr,
spec.tenant_id,
spec.timeline_id,
shard0_index,
shard_index,
spec.storage_auth_token.clone(),
None, // NB: base backups use payload compression
)
@@ -1237,26 +1131,8 @@ impl ComputeNode {
/// Fetches a basebackup via libpq. The connstring must use postgresql://. Returns the timestamp
/// when the connection was established, and the (compressed) size of the basebackup.
fn try_get_basebackup_libpq(&self, spec: &ParsedSpec, lsn: Lsn) -> Result<(Instant, usize)> {
let shard0_index = ShardIndex {
shard_number: ShardNumber(0),
shard_count: spec.pageserver_conninfo.shard_count,
};
let shard0 = spec
.pageserver_conninfo
.shards
.get(&shard0_index)
.ok_or_else(|| {
anyhow::anyhow!("shard connection info missing for shard {}", shard0_index)
})?;
let pageserver = shard0
.pageservers
.first()
.expect("must have at least one pageserver");
let shard0_connstr = pageserver
.libpq_url
.clone()
.expect("no libpq_url for shard 0");
let mut config = postgres::Config::from_str(&shard0_connstr)?;
let shard0_connstr = spec.pageserver_connstr.split(',').next().unwrap();
let mut config = postgres::Config::from_str(shard0_connstr)?;
// Use the storage auth token from the config file, if given.
// Note: this overrides any password set in the connection string.
@@ -1342,7 +1218,10 @@ impl ComputeNode {
return result;
}
Err(ref e) if attempts < max_attempts => {
warn!("Failed to get basebackup: {e:?} (attempt {attempts}/{max_attempts})");
warn!(
"Failed to get basebackup: {} (attempt {}/{})",
e, attempts, max_attempts
);
std::thread::sleep(std::time::Duration::from_millis(retry_period_ms as u64));
retry_period_ms *= 1.5;
}
@@ -1550,11 +1429,19 @@ impl ComputeNode {
}
};
self.get_basebackup(compute_state, lsn)
.with_context(|| format!("failed to get basebackup@{lsn}"))?;
info!(
"getting basebackup@{} from pageserver {}",
lsn, &pspec.pageserver_connstr
);
self.get_basebackup(compute_state, lsn).with_context(|| {
format!(
"failed to get basebackup@{} from pageserver {}",
lsn, &pspec.pageserver_connstr
)
})?;
// Update pg_hba.conf received with basebackup.
update_pg_hba(pgdata_path, None)?;
update_pg_hba(pgdata_path)?;
// Place pg_dynshmem under /dev/shm. This allows us to use
// 'dynamic_shared_memory_type = mmap' so that the files are placed in
@@ -1859,7 +1746,6 @@ impl ComputeNode {
}
// Run migrations separately to not hold up cold starts
let lakebase_mode = self.params.lakebase_mode;
let params = self.params.clone();
tokio::spawn(async move {
let mut conf = conf.as_ref().clone();
@@ -1872,7 +1758,7 @@ impl ComputeNode {
eprintln!("connection error: {e}");
}
});
if let Err(e) = handle_migrations(params, &mut client, lakebase_mode).await {
if let Err(e) = handle_migrations(params, &mut client).await {
error!("Failed to run migrations: {}", e);
}
}
@@ -2495,22 +2381,22 @@ LIMIT 100",
/// The operation will time out after a specified duration.
pub fn wait_timeout_while_pageserver_connstr_unchanged(&self, duration: Duration) {
let state = self.state.lock().unwrap();
let old_pageserver_conninfo = state
let old_pageserver_connstr = state
.pspec
.as_ref()
.expect("spec must be set")
.pageserver_conninfo
.pageserver_connstr
.clone();
let mut unchanged = true;
let _ = self
.state_changed
.wait_timeout_while(state, duration, |s| {
let pageserver_conninfo = &s
let pageserver_connstr = &s
.pspec
.as_ref()
.expect("spec must be set")
.pageserver_conninfo;
unchanged = pageserver_conninfo == &old_pageserver_conninfo;
.pageserver_connstr;
unchanged = pageserver_connstr == &old_pageserver_connstr;
unchanged
})
.unwrap();
@@ -2740,10 +2626,7 @@ mod tests {
match ParsedSpec::try_from(spec.clone()) {
Ok(_p) => panic!("Failed to detect duplicate entry"),
Err(e) => assert!(
e.to_string()
.starts_with("duplicate entry in safekeeper_connstrings:")
),
Err(e) => assert!(e.starts_with("duplicate entry in safekeeper_connstrings:")),
};
}
}

View File

@@ -15,8 +15,6 @@ use crate::pg_helpers::{
};
use crate::tls::{self, SERVER_CRT, SERVER_KEY};
use utils::shard::{ShardIndex, ShardNumber};
/// Check that `line` is inside a text file and put it there if it is not.
/// Create file if it doesn't exist.
pub fn line_in_file(path: &Path, line: &str) -> Result<bool> {
@@ -58,101 +56,15 @@ pub fn write_postgres_conf(
writeln!(file, "{conf}")?;
}
// Stripe size GUC should be defined prior to connection string
if let Some(stripe_size) = spec.shard_stripe_size {
writeln!(file, "neon.stripe_size={stripe_size}")?;
}
// Add options for connecting to storage
writeln!(file, "# Neon storage settings")?;
writeln!(file)?;
if let Some(conninfo) = &spec.pageserver_connection_info {
// Stripe size GUC should be defined prior to connection string
if let Some(stripe_size) = conninfo.stripe_size {
writeln!(
file,
"# from compute spec's pageserver_conninfo.stripe_size field"
)?;
writeln!(file, "neon.stripe_size={stripe_size}")?;
}
let mut libpq_urls: Option<Vec<String>> = Some(Vec::new());
let mut grpc_urls: Option<Vec<String>> = Some(Vec::new());
let num_shards = if conninfo.shard_count.0 == 0 {
1 // unsharded, treat it as a single shard
} else {
conninfo.shard_count.0
};
for shard_number in 0..num_shards {
let shard_index = ShardIndex {
shard_number: ShardNumber(shard_number),
shard_count: conninfo.shard_count,
};
let info = conninfo.shards.get(&shard_index).ok_or_else(|| {
anyhow::anyhow!(
"shard {shard_index} missing from pageserver_connection_info shard map"
)
})?;
let first_pageserver = info
.pageservers
.first()
.expect("must have at least one pageserver");
// Add the libpq URL to the array, or if the URL is missing, reset the array
// forgetting any previous entries. All servers must have a libpq URL, or none
// at all.
if let Some(url) = &first_pageserver.libpq_url {
if let Some(ref mut urls) = libpq_urls {
urls.push(url.clone());
}
} else {
libpq_urls = None
}
// Similarly for gRPC URLs
if let Some(url) = &first_pageserver.grpc_url {
if let Some(ref mut urls) = grpc_urls {
urls.push(url.clone());
}
} else {
grpc_urls = None
}
}
if let Some(libpq_urls) = libpq_urls {
writeln!(
file,
"# derived from compute spec's pageserver_conninfo field"
)?;
writeln!(
file,
"neon.pageserver_connstring={}",
escape_conf_value(&libpq_urls.join(","))
)?;
} else {
writeln!(file, "# no neon.pageserver_connstring")?;
}
if let Some(grpc_urls) = grpc_urls {
writeln!(
file,
"# derived from compute spec's pageserver_conninfo field"
)?;
writeln!(
file,
"neon.pageserver_grpc_urls={}",
escape_conf_value(&grpc_urls.join(","))
)?;
} else {
writeln!(file, "# no neon.pageserver_grpc_urls")?;
}
} else {
// Stripe size GUC should be defined prior to connection string
if let Some(stripe_size) = spec.shard_stripe_size {
writeln!(file, "# from compute spec's shard_stripe_size field")?;
writeln!(file, "neon.stripe_size={stripe_size}")?;
}
if let Some(s) = &spec.pageserver_connstring {
writeln!(file, "# from compute spec's pageserver_connstring field")?;
writeln!(file, "neon.pageserver_connstring={}", escape_conf_value(s))?;
}
if let Some(s) = &spec.pageserver_connstring {
writeln!(file, "neon.pageserver_connstring={}", escape_conf_value(s))?;
}
if !spec.safekeeper_connstrings.is_empty() {
let mut neon_safekeepers_value = String::new();
tracing::info!(

View File

@@ -1,18 +1,10 @@
use std::path::Path;
use std::sync::Arc;
use anyhow::Context;
use axum::body::Body;
use axum::extract::State;
use axum::response::Response;
use http::StatusCode;
use http::header::CONTENT_TYPE;
use http_body_util::BodyExt;
use hyper::{Request, StatusCode};
use metrics::proto::MetricFamily;
use metrics::{Encoder, TextEncoder};
use crate::communicator_socket_client::connect_communicator_socket;
use crate::compute::ComputeNode;
use crate::http::JsonResponse;
use crate::metrics::collect;
@@ -39,42 +31,3 @@ pub(in crate::http) async fn get_metrics() -> Response {
.body(Body::from(buffer))
.unwrap()
}
/// Fetch and forward metrics from the Postgres neon extension's metrics
/// exporter that are used by autoscaling-agent.
///
/// The neon extension exposes these metrics over a Unix domain socket
/// in the data directory. That's not accessible directly from the outside
/// world, so we have this endpoint in compute_ctl to expose it
pub(in crate::http) async fn get_autoscaling_metrics(
State(compute): State<Arc<ComputeNode>>,
) -> Result<Response, Response> {
let pgdata = Path::new(&compute.params.pgdata);
// Connect to the communicator process's metrics socket
let mut metrics_client = connect_communicator_socket(pgdata)
.await
.map_err(|e| JsonResponse::error(StatusCode::INTERNAL_SERVER_ERROR, format!("{e:#}")))?;
// Make a request for /autoscaling_metrics
let request = Request::builder()
.method("GET")
.uri("/autoscaling_metrics")
.header("Host", "localhost") // hyper requires Host, even though the server won't care
.body(Body::from(""))
.unwrap();
let resp = metrics_client
.send_request(request)
.await
.context("fetching metrics from Postgres metrics service")
.map_err(|e| JsonResponse::error(StatusCode::INTERNAL_SERVER_ERROR, format!("{e:#}")))?;
// Build a response that just forwards the response we got.
let mut response = Response::builder();
response = response.status(resp.status());
if let Some(content_type) = resp.headers().get(CONTENT_TYPE) {
response = response.header(CONTENT_TYPE, content_type);
}
let body = tonic::service::AxumBody::from_stream(resp.into_body().into_data_stream());
Ok(response.body(body).unwrap())
}

View File

@@ -81,12 +81,8 @@ impl From<&Server> for Router<Arc<ComputeNode>> {
Server::External {
config, compute_id, ..
} => {
let unauthenticated_router = Router::<Arc<ComputeNode>>::new()
.route("/metrics", get(metrics::get_metrics))
.route(
"/autoscaling_metrics",
get(metrics::get_autoscaling_metrics),
);
let unauthenticated_router =
Router::<Arc<ComputeNode>>::new().route("/metrics", get(metrics::get_metrics));
let authenticated_router = Router::<Arc<ComputeNode>>::new()
.route("/lfc/prewarm", get(lfc::prewarm_state).post(lfc::prewarm))

View File

@@ -4,7 +4,6 @@
#![deny(clippy::undocumented_unsafe_blocks)]
pub mod checker;
pub mod communicator_socket_client;
pub mod config;
pub mod configurator;
pub mod http;

View File

@@ -13,9 +13,7 @@ use tracing_subscriber::prelude::*;
/// set `OTEL_EXPORTER_OTLP_ENDPOINT=http://jaeger:4318`. See
/// `tracing-utils` package description.
///
pub fn init_tracing_and_logging(
default_log_level: &str,
) -> anyhow::Result<Option<tracing_utils::Provider>> {
pub async fn init_tracing_and_logging(default_log_level: &str) -> anyhow::Result<()> {
// Initialize Logging
let env_filter = tracing_subscriber::EnvFilter::try_from_default_env()
.unwrap_or_else(|_| tracing_subscriber::EnvFilter::new(default_log_level));
@@ -26,9 +24,8 @@ pub fn init_tracing_and_logging(
.with_writer(std::io::stderr);
// Initialize OpenTelemetry
let provider =
tracing_utils::init_tracing("compute_ctl", tracing_utils::ExportConfig::default());
let otlp_layer = provider.as_ref().map(tracing_utils::layer);
let otlp_layer =
tracing_utils::init_tracing("compute_ctl", tracing_utils::ExportConfig::default()).await;
// Put it all together
tracing_subscriber::registry()
@@ -40,7 +37,7 @@ pub fn init_tracing_and_logging(
utils::logging::replace_panic_hook_with_tracing_panic_hook().forget();
Ok(provider)
Ok(())
}
/// Replace all newline characters with a special character to make it

View File

@@ -4,13 +4,14 @@ use std::thread;
use std::time::{Duration, SystemTime};
use anyhow::{Result, bail};
use compute_api::spec::{ComputeMode, PageserverConnectionInfo, PageserverProtocol};
use compute_api::spec::{ComputeMode, PageserverProtocol};
use itertools::Itertools as _;
use pageserver_page_api as page_api;
use postgres::{NoTls, SimpleQueryMessage};
use tracing::{info, warn};
use utils::id::{TenantId, TimelineId};
use utils::lsn::Lsn;
use utils::shard::TenantShardId;
use utils::shard::{ShardCount, ShardNumber, TenantShardId};
use crate::compute::ComputeNode;
@@ -77,16 +78,17 @@ fn acquire_lsn_lease_with_retry(
loop {
// Note: List of pageservers is dynamic, need to re-read configs before each attempt.
let (conninfo, auth) = {
let (connstrings, auth) = {
let state = compute.state.lock().unwrap();
let spec = state.pspec.as_ref().expect("spec must be set");
(
spec.pageserver_conninfo.clone(),
spec.pageserver_connstr.clone(),
spec.storage_auth_token.clone(),
)
};
let result = try_acquire_lsn_lease(conninfo, auth.as_deref(), tenant_id, timeline_id, lsn);
let result =
try_acquire_lsn_lease(&connstrings, auth.as_deref(), tenant_id, timeline_id, lsn);
match result {
Ok(Some(res)) => {
return Ok(res);
@@ -110,44 +112,35 @@ fn acquire_lsn_lease_with_retry(
/// Tries to acquire LSN leases on all Pageserver shards.
fn try_acquire_lsn_lease(
conninfo: PageserverConnectionInfo,
connstrings: &str,
auth: Option<&str>,
tenant_id: TenantId,
timeline_id: TimelineId,
lsn: Lsn,
) -> Result<Option<SystemTime>> {
let connstrings = connstrings.split(',').collect_vec();
let shard_count = connstrings.len();
let mut leases = Vec::new();
for (shard_index, shard) in conninfo.shards.into_iter() {
let tenant_shard_id = TenantShardId {
tenant_id,
shard_number: shard_index.shard_number,
shard_count: shard_index.shard_count,
for (shard_number, &connstring) in connstrings.iter().enumerate() {
let tenant_shard_id = match shard_count {
0 | 1 => TenantShardId::unsharded(tenant_id),
shard_count => TenantShardId {
tenant_id,
shard_number: ShardNumber(shard_number as u8),
shard_count: ShardCount::new(shard_count as u8),
},
};
// XXX: If there are more than pageserver for the one shard, do we need to get a
// leas on all of them? Currently, that's what we assume, but this is hypothetical
// as of this writing, as we never pass the info for more than one pageserver per
// shard.
for pageserver in shard.pageservers {
let lease = match conninfo.prefer_protocol {
PageserverProtocol::Grpc => acquire_lsn_lease_grpc(
&pageserver.grpc_url.unwrap(),
auth,
tenant_shard_id,
timeline_id,
lsn,
)?,
PageserverProtocol::Libpq => acquire_lsn_lease_libpq(
&pageserver.libpq_url.unwrap(),
auth,
tenant_shard_id,
timeline_id,
lsn,
)?,
};
leases.push(lease);
}
let lease = match PageserverProtocol::from_connstring(connstring)? {
PageserverProtocol::Libpq => {
acquire_lsn_lease_libpq(connstring, auth, tenant_shard_id, timeline_id, lsn)?
}
PageserverProtocol::Grpc => {
acquire_lsn_lease_grpc(connstring, auth, tenant_shard_id, timeline_id, lsn)?
}
};
leases.push(lease);
}
Ok(leases.into_iter().min().flatten())

View File

@@ -9,20 +9,15 @@ use crate::metrics::DB_MIGRATION_FAILED;
pub(crate) struct MigrationRunner<'m> {
client: &'m mut Client,
migrations: &'m [&'m str],
lakebase_mode: bool,
}
impl<'m> MigrationRunner<'m> {
/// Create a new migration runner
pub fn new(client: &'m mut Client, migrations: &'m [&'m str], lakebase_mode: bool) -> Self {
pub fn new(client: &'m mut Client, migrations: &'m [&'m str]) -> Self {
// The neon_migration.migration_id::id column is a bigint, which is equivalent to an i64
assert!(migrations.len() + 1 < i64::MAX as usize);
Self {
client,
migrations,
lakebase_mode,
}
Self { client, migrations }
}
/// Get the current value neon_migration.migration_id
@@ -135,13 +130,8 @@ impl<'m> MigrationRunner<'m> {
// ID is also the next index
let migration_id = (current_migration + 1) as i64;
let migration = self.migrations[current_migration];
let migration = if self.lakebase_mode {
migration.replace("neon_superuser", "databricks_superuser")
} else {
migration.to_string()
};
match Self::run_migration(self.client, migration_id, &migration).await {
match Self::run_migration(self.client, migration_id, migration).await {
Ok(_) => {
info!("Finished migration id={}", migration_id);
}

View File

@@ -11,7 +11,6 @@ use tracing::{Level, error, info, instrument, span};
use crate::compute::ComputeNode;
use crate::metrics::{PG_CURR_DOWNTIME_MS, PG_TOTAL_DOWNTIME_MS};
const PG_DEFAULT_INIT_TIMEOUIT: Duration = Duration::from_secs(60);
const MONITOR_CHECK_INTERVAL: Duration = Duration::from_millis(500);
/// Struct to store runtime state of the compute monitor thread.
@@ -353,47 +352,13 @@ impl ComputeMonitor {
// Hang on condition variable waiting until the compute status is `Running`.
fn wait_for_postgres_start(compute: &ComputeNode) {
let mut state = compute.state.lock().unwrap();
let pg_init_timeout = compute
.params
.pg_init_timeout
.unwrap_or(PG_DEFAULT_INIT_TIMEOUIT);
while state.status != ComputeStatus::Running {
info!("compute is not running, waiting before monitoring activity");
if !compute.params.lakebase_mode {
state = compute.state_changed.wait(state).unwrap();
state = compute.state_changed.wait(state).unwrap();
if state.status == ComputeStatus::Running {
break;
}
continue;
if state.status == ComputeStatus::Running {
break;
}
if state.pg_start_time.is_some()
&& Utc::now()
.signed_duration_since(state.pg_start_time.unwrap())
.to_std()
.unwrap_or_default()
> pg_init_timeout
{
// If Postgres isn't up and running with working PS/SK connections within POSTGRES_STARTUP_TIMEOUT, it is
// possible that we started Postgres with a wrong spec (so it is talking to the wrong PS/SK nodes). To prevent
// deadends we simply exit (panic) the compute node so it can restart with the latest spec.
//
// NB: We skip this check if we have not attempted to start PG yet (indicated by state.pg_start_up == None).
// This is to make sure the more appropriate errors are surfaced if we encounter issues before we even attempt
// to start PG (e.g., if we can't pull the spec, can't sync safekeepers, or can't get the basebackup).
error!(
"compute did not enter Running state in {} seconds, exiting",
pg_init_timeout.as_secs()
);
std::process::exit(1);
}
state = compute
.state_changed
.wait_timeout(state, Duration::from_secs(5))
.unwrap()
.0;
}
}

View File

@@ -11,9 +11,7 @@ use std::time::{Duration, Instant};
use anyhow::{Result, bail};
use compute_api::responses::TlsConfig;
use compute_api::spec::{
Database, DatabricksSettings, GenericOption, GenericOptions, PgIdent, Role,
};
use compute_api::spec::{Database, GenericOption, GenericOptions, PgIdent, Role};
use futures::StreamExt;
use indexmap::IndexMap;
use ini::Ini;
@@ -186,42 +184,6 @@ impl DatabaseExt for Database {
}
}
pub trait DatabricksSettingsExt {
fn as_pg_settings(&self) -> String;
}
impl DatabricksSettingsExt for DatabricksSettings {
fn as_pg_settings(&self) -> String {
// Postgres GUCs rendered from DatabricksSettings
vec![
// ssl_ca_file
Some(format!(
"ssl_ca_file = '{}'",
self.pg_compute_tls_settings.ca_file
)),
// [Optional] databricks.workspace_url
Some(format!(
"databricks.workspace_url = '{}'",
&self.databricks_workspace_host
)),
// todo(vikas.jain): these are not required anymore as they are moved to static
// conf but keeping these to avoid image mismatch between hcc and pg.
// Once hcc and pg are in sync, we can remove these.
//
// databricks.enable_databricks_identity_login
Some("databricks.enable_databricks_identity_login = true".to_string()),
// databricks.enable_sql_restrictions
Some("databricks.enable_sql_restrictions = true".to_string()),
]
.into_iter()
// Removes `None`s
.flatten()
.collect::<Vec<String>>()
.join("\n")
+ "\n"
}
}
/// Generic trait used to provide quoting / encoding for strings used in the
/// Postgres SQL queries and DATABASE_URL.
pub trait Escaping {

View File

@@ -1,6 +1,4 @@
use std::fs::File;
use std::fs::{self, Permissions};
use std::os::unix::fs::PermissionsExt;
use std::path::Path;
use anyhow::{Result, anyhow, bail};
@@ -135,25 +133,10 @@ pub fn get_config_from_control_plane(base_uri: &str, compute_id: &str) -> Result
}
/// Check `pg_hba.conf` and update if needed to allow external connections.
pub fn update_pg_hba(pgdata_path: &Path, databricks_pg_hba: Option<&String>) -> Result<()> {
pub fn update_pg_hba(pgdata_path: &Path) -> Result<()> {
// XXX: consider making it a part of config.json
let pghba_path = pgdata_path.join("pg_hba.conf");
// Update pg_hba to contains databricks specfic settings before adding neon settings
// PG uses the first record that matches to perform authentication, so we need to have
// our rules before the default ones from neon.
// See https://www.postgresql.org/docs/16/auth-pg-hba-conf.html
if let Some(databricks_pg_hba) = databricks_pg_hba {
if config::line_in_file(
&pghba_path,
&format!("include_if_exists {}\n", *databricks_pg_hba),
)? {
info!("updated pg_hba.conf to include databricks_pg_hba.conf");
} else {
info!("pg_hba.conf already included databricks_pg_hba.conf");
}
}
if config::line_in_file(&pghba_path, PG_HBA_ALL_MD5)? {
info!("updated pg_hba.conf to allow external connections");
} else {
@@ -163,59 +146,6 @@ pub fn update_pg_hba(pgdata_path: &Path, databricks_pg_hba: Option<&String>) ->
Ok(())
}
/// Check `pg_ident.conf` and update if needed to allow databricks config.
pub fn update_pg_ident(pgdata_path: &Path, databricks_pg_ident: Option<&String>) -> Result<()> {
info!("checking pg_ident.conf");
let pghba_path = pgdata_path.join("pg_ident.conf");
// Update pg_ident to contains databricks specfic settings
if let Some(databricks_pg_ident) = databricks_pg_ident {
if config::line_in_file(
&pghba_path,
&format!("include_if_exists {}\n", *databricks_pg_ident),
)? {
info!("updated pg_ident.conf to include databricks_pg_ident.conf");
} else {
info!("pg_ident.conf already included databricks_pg_ident.conf");
}
}
Ok(())
}
/// Copy tls key_file and cert_file from k8s secret mount directory
/// to pgdata and set private key file permissions as expected by Postgres.
/// See this doc for expected permission <https://www.postgresql.org/docs/current/ssl-tcp.html>
/// K8s secrets mount on dblet does not honor permission and ownership
/// specified in the Volume or VolumeMount. So we need to explicitly copy the file and set the permissions.
pub fn copy_tls_certificates(
key_file: &String,
cert_file: &String,
pgdata_path: &Path,
) -> Result<()> {
let files = [cert_file, key_file];
for file in files.iter() {
let source = Path::new(file);
let dest = pgdata_path.join(source.file_name().unwrap());
if !dest.exists() {
std::fs::copy(source, &dest)?;
info!(
"Copying tls file: {} to {}",
&source.display(),
&dest.display()
);
}
if *file == key_file {
// Postgres requires private key to be readable only by the owner by having
// chmod 600 permissions.
let permissions = Permissions::from_mode(0o600);
fs::set_permissions(&dest, permissions)?;
info!("Setting permission on {}.", &dest.display());
}
}
Ok(())
}
/// Create a standby.signal file
pub fn add_standby_signal(pgdata_path: &Path) -> Result<()> {
// XXX: consider making it a part of config.json
@@ -240,11 +170,7 @@ pub async fn handle_neon_extension_upgrade(client: &mut Client) -> Result<()> {
}
#[instrument(skip_all)]
pub async fn handle_migrations(
params: ComputeNodeParams,
client: &mut Client,
lakebase_mode: bool,
) -> Result<()> {
pub async fn handle_migrations(params: ComputeNodeParams, client: &mut Client) -> Result<()> {
info!("handle migrations");
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
@@ -308,7 +234,7 @@ pub async fn handle_migrations(
),
];
MigrationRunner::new(client, &migrations, lakebase_mode)
MigrationRunner::new(client, &migrations)
.run_migrations()
.await?;

View File

@@ -411,8 +411,7 @@ impl ComputeNode {
.map(|limit| match limit {
0..10 => limit,
10..30 => 10,
30..300 => limit / 3,
300.. => 100,
30.. => limit / 3,
})
// If we didn't find max_connections, default to 10 concurrent connections.
.unwrap_or(10)

View File

@@ -8,10 +8,10 @@ code changes locally, but not suitable for running production systems.
## Example: Start with Postgres 16
To create and start a local development environment with Postgres 16, you will need to provide `--pg-version` flag to 2 of the start-up commands.
To create and start a local development environment with Postgres 16, you will need to provide `--pg-version` flag to 3 of the start-up commands.
```shell
cargo neon init
cargo neon init --pg-version 16
cargo neon start
cargo neon tenant create --set-default --pg-version 16
cargo neon endpoint create main --pg-version 16

View File

@@ -16,14 +16,9 @@ use std::time::Duration;
use anyhow::{Context, Result, anyhow, bail};
use clap::Parser;
use compute_api::requests::ComputeClaimsScope;
use compute_api::spec::{
ComputeMode, PageserverConnectionInfo, PageserverProtocol, PageserverShardInfo,
};
use compute_api::spec::{ComputeMode, PageserverProtocol};
use control_plane::broker::StorageBroker;
use control_plane::endpoint::{ComputeControlPlane, EndpointTerminateMode};
use control_plane::endpoint::{
pageserver_conf_to_shard_conn_info, tenant_locate_response_to_conn_info,
};
use control_plane::endpoint_storage::{ENDPOINT_STORAGE_DEFAULT_ADDR, EndpointStorage};
use control_plane::local_env;
use control_plane::local_env::{
@@ -49,6 +44,7 @@ use pageserver_api::models::{
};
use pageserver_api::shard::{DEFAULT_STRIPE_SIZE, ShardCount, ShardStripeSize, TenantShardId};
use postgres_backend::AuthType;
use postgres_connection::parse_host_port;
use safekeeper_api::membership::{SafekeeperGeneration, SafekeeperId};
use safekeeper_api::{
DEFAULT_HTTP_LISTEN_PORT as DEFAULT_SAFEKEEPER_HTTP_PORT,
@@ -56,11 +52,11 @@ use safekeeper_api::{
};
use storage_broker::DEFAULT_LISTEN_ADDR as DEFAULT_BROKER_ADDR;
use tokio::task::JoinSet;
use url::Host;
use utils::auth::{Claims, Scope};
use utils::id::{NodeId, TenantId, TenantTimelineId, TimelineId};
use utils::lsn::Lsn;
use utils::project_git_version;
use utils::shard::ShardIndex;
// Default id of a safekeeper node, if not specified on the command line.
const DEFAULT_SAFEKEEPER_ID: NodeId = NodeId(1);
@@ -411,12 +407,6 @@ struct StorageControllerStartCmdArgs {
help = "Base port for the storage controller instance idenfified by instance-id (defaults to pageserver cplane api)"
)]
base_port: Option<u16>,
#[clap(
long,
help = "Whether the storage controller should handle pageserver-reported local disk loss events."
)]
handle_ps_local_disk_loss: Option<bool>,
}
#[derive(clap::Args)]
@@ -1531,56 +1521,62 @@ async fn handle_endpoint(subcmd: &EndpointCmd, env: &local_env::LocalEnv) -> Res
)?;
}
let prefer_protocol = if endpoint.grpc {
PageserverProtocol::Grpc
} else {
PageserverProtocol::Libpq
};
let mut pageserver_conninfo = if let Some(ps_id) = pageserver_id {
let conf = env.get_pageserver_conf(ps_id).unwrap();
let ps_conninfo = pageserver_conf_to_shard_conn_info(conf)?;
let shard_info = PageserverShardInfo {
pageservers: vec![ps_conninfo],
let (pageservers, stripe_size) = if let Some(pageserver_id) = pageserver_id {
let conf = env.get_pageserver_conf(pageserver_id).unwrap();
// Use gRPC if requested.
let pageserver = if endpoint.grpc {
let grpc_addr = conf.listen_grpc_addr.as_ref().expect("bad config");
let (host, port) = parse_host_port(grpc_addr)?;
let port = port.unwrap_or(DEFAULT_PAGESERVER_GRPC_PORT);
(PageserverProtocol::Grpc, host, port)
} else {
let (host, port) = parse_host_port(&conf.listen_pg_addr)?;
let port = port.unwrap_or(5432);
(PageserverProtocol::Libpq, host, port)
};
// If caller is telling us what pageserver to use, this is not a tenant which is
// fully managed by storage controller, therefore not sharded.
let shards: HashMap<_, _> = vec![(ShardIndex::unsharded(), shard_info)]
.into_iter()
.collect();
PageserverConnectionInfo {
shard_count: ShardCount(0),
stripe_size: None,
shards,
prefer_protocol,
}
(vec![pageserver], DEFAULT_STRIPE_SIZE)
} else {
// Look up the currently attached location of the tenant, and its striping metadata,
// to pass these on to postgres.
let storage_controller = StorageController::from_env(env);
let locate_result = storage_controller.tenant_locate(endpoint.tenant_id).await?;
assert!(!locate_result.shards.is_empty());
// Initialize LSN leases for static computes.
if let ComputeMode::Static(lsn) = endpoint.mode {
futures::future::try_join_all(locate_result.shards.iter().map(
|shard| async move {
let pageservers = futures::future::try_join_all(
locate_result.shards.into_iter().map(|shard| async move {
if let ComputeMode::Static(lsn) = endpoint.mode {
// Initialize LSN leases for static computes.
let conf = env.get_pageserver_conf(shard.node_id).unwrap();
let pageserver = PageServerNode::from_env(env, conf);
pageserver
.http_client
.timeline_init_lsn_lease(shard.shard_id, endpoint.timeline_id, lsn)
.await
},
))
.await?;
}
.await?;
}
tenant_locate_response_to_conn_info(&locate_result)?
let pageserver = if endpoint.grpc {
(
PageserverProtocol::Grpc,
Host::parse(&shard.listen_grpc_addr.expect("no gRPC address"))?,
shard.listen_grpc_port.expect("no gRPC port"),
)
} else {
(
PageserverProtocol::Libpq,
Host::parse(&shard.listen_pg_addr)?,
shard.listen_pg_port,
)
};
anyhow::Ok(pageserver)
}),
)
.await?;
let stripe_size = locate_result.shard_params.stripe_size;
(pageservers, stripe_size)
};
pageserver_conninfo.prefer_protocol = prefer_protocol;
assert!(!pageservers.is_empty());
let ps_conf = env.get_pageserver_conf(DEFAULT_PAGESERVER_ID)?;
let auth_token = if matches!(ps_conf.pg_auth_type, AuthType::NeonJWT) {
@@ -1610,8 +1606,9 @@ async fn handle_endpoint(subcmd: &EndpointCmd, env: &local_env::LocalEnv) -> Res
endpoint_storage_addr,
safekeepers_generation,
safekeepers,
pageserver_conninfo,
pageservers,
remote_ext_base_url: remote_ext_base_url.clone(),
shard_stripe_size: stripe_size.0 as usize,
create_test_user: args.create_test_user,
start_timeout: args.start_timeout,
autoprewarm: args.autoprewarm,
@@ -1628,45 +1625,51 @@ async fn handle_endpoint(subcmd: &EndpointCmd, env: &local_env::LocalEnv) -> Res
.endpoints
.get(endpoint_id.as_str())
.with_context(|| format!("postgres endpoint {endpoint_id} is not found"))?;
let prefer_protocol = if endpoint.grpc {
PageserverProtocol::Grpc
} else {
PageserverProtocol::Libpq
};
let mut pageserver_conninfo = if let Some(ps_id) = args.endpoint_pageserver_id {
let pageservers = if let Some(ps_id) = args.endpoint_pageserver_id {
let conf = env.get_pageserver_conf(ps_id)?;
let ps_conninfo = pageserver_conf_to_shard_conn_info(conf)?;
let shard_info = PageserverShardInfo {
pageservers: vec![ps_conninfo],
// Use gRPC if requested.
let pageserver = if endpoint.grpc {
let grpc_addr = conf.listen_grpc_addr.as_ref().expect("bad config");
let (host, port) = parse_host_port(grpc_addr)?;
let port = port.unwrap_or(DEFAULT_PAGESERVER_GRPC_PORT);
(PageserverProtocol::Grpc, host, port)
} else {
let (host, port) = parse_host_port(&conf.listen_pg_addr)?;
let port = port.unwrap_or(5432);
(PageserverProtocol::Libpq, host, port)
};
// If caller is telling us what pageserver to use, this is not a tenant which is
// fully managed by storage controller, therefore not sharded.
let shards: HashMap<_, _> = vec![(ShardIndex::unsharded(), shard_info)]
.into_iter()
.collect();
PageserverConnectionInfo {
shard_count: ShardCount::unsharded(),
stripe_size: None,
shards,
prefer_protocol,
}
vec![pageserver]
} else {
// Look up the currently attached location of the tenant, and its striping metadata,
// to pass these on to postgres.
let storage_controller = StorageController::from_env(env);
let locate_result = storage_controller.tenant_locate(endpoint.tenant_id).await?;
tenant_locate_response_to_conn_info(&locate_result)?
storage_controller
.tenant_locate(endpoint.tenant_id)
.await?
.shards
.into_iter()
.map(|shard| {
// Use gRPC if requested.
if endpoint.grpc {
(
PageserverProtocol::Grpc,
Host::parse(&shard.listen_grpc_addr.expect("no gRPC address"))
.expect("bad hostname"),
shard.listen_grpc_port.expect("no gRPC port"),
)
} else {
(
PageserverProtocol::Libpq,
Host::parse(&shard.listen_pg_addr).expect("bad hostname"),
shard.listen_pg_port,
)
}
})
.collect::<Vec<_>>()
};
pageserver_conninfo.prefer_protocol = prefer_protocol;
// If --safekeepers argument is given, use only the listed
// safekeeper nodes; otherwise all from the env.
let safekeepers = parse_safekeepers(&args.safekeepers)?;
endpoint
.reconfigure(Some(&pageserver_conninfo), safekeepers, None)
.reconfigure(Some(pageservers), None, safekeepers, None)
.await?;
}
EndpointCmd::Stop(args) => {
@@ -1806,7 +1809,6 @@ async fn handle_storage_controller(
instance_id: args.instance_id,
base_port: args.base_port,
start_timeout: args.start_timeout,
handle_ps_local_disk_loss: args.handle_ps_local_disk_loss,
};
if let Err(e) = svc.start(start_args).await {

View File

@@ -37,7 +37,7 @@
//! <other PostgreSQL files>
//! ```
//!
use std::collections::{BTreeMap, HashMap};
use std::collections::BTreeMap;
use std::fmt::Display;
use std::net::{IpAddr, Ipv4Addr, SocketAddr, TcpStream};
use std::path::PathBuf;
@@ -58,17 +58,14 @@ use compute_api::responses::{
};
use compute_api::spec::{
Cluster, ComputeAudit, ComputeFeature, ComputeMode, ComputeSpec, Database, PageserverProtocol,
PageserverShardInfo, PgIdent, RemoteExtSpec, Role,
PgIdent, RemoteExtSpec, Role,
};
// re-export these, because they're used in the reconfigure() function
pub use compute_api::spec::{PageserverConnectionInfo, PageserverShardConnectionInfo};
use jsonwebtoken::jwk::{
AlgorithmParameters, CommonParameters, EllipticCurve, Jwk, JwkSet, KeyAlgorithm, KeyOperations,
OctetKeyPairParameters, OctetKeyPairType, PublicKeyUse,
};
use nix::sys::signal::{Signal, kill};
use pageserver_api::shard::ShardStripeSize;
use pem::Pem;
use reqwest::header::CONTENT_TYPE;
use safekeeper_api::PgMajorVersion;
@@ -78,11 +75,8 @@ use sha2::{Digest, Sha256};
use spki::der::Decode;
use spki::{SubjectPublicKeyInfo, SubjectPublicKeyInfoRef};
use tracing::debug;
use url::Host;
use utils::id::{NodeId, TenantId, TimelineId};
use utils::shard::{ShardIndex, ShardNumber};
use pageserver_api::config::DEFAULT_GRPC_LISTEN_PORT as DEFAULT_PAGESERVER_GRPC_PORT;
use postgres_connection::parse_host_port;
use crate::local_env::LocalEnv;
use crate::postgresql_conf::PostgresConf;
@@ -393,8 +387,9 @@ pub struct EndpointStartArgs {
pub endpoint_storage_addr: String,
pub safekeepers_generation: Option<SafekeeperGeneration>,
pub safekeepers: Vec<NodeId>,
pub pageserver_conninfo: PageserverConnectionInfo,
pub pageservers: Vec<(PageserverProtocol, Host, u16)>,
pub remote_ext_base_url: Option<String>,
pub shard_stripe_size: usize,
pub create_test_user: bool,
pub start_timeout: Duration,
pub autoprewarm: bool,
@@ -667,6 +662,14 @@ impl Endpoint {
}
}
fn build_pageserver_connstr(pageservers: &[(PageserverProtocol, Host, u16)]) -> String {
pageservers
.iter()
.map(|(scheme, host, port)| format!("{scheme}://no_user@{host}:{port}"))
.collect::<Vec<_>>()
.join(",")
}
/// Map safekeepers ids to the actual connection strings.
fn build_safekeepers_connstrs(&self, sk_ids: Vec<NodeId>) -> Result<Vec<String>> {
let mut safekeeper_connstrings = Vec::new();
@@ -712,6 +715,9 @@ impl Endpoint {
std::fs::remove_dir_all(self.pgdata())?;
}
let pageserver_connstring = Self::build_pageserver_connstr(&args.pageservers);
assert!(!pageserver_connstring.is_empty());
let safekeeper_connstrings = self.build_safekeepers_connstrs(args.safekeepers)?;
// check for file remote_extensions_spec.json
@@ -726,45 +732,6 @@ impl Endpoint {
remote_extensions = None;
};
// For the sake of backwards-compatibility, also fill in 'pageserver_connstring'
//
// Use a closure so that we can conviniently return None in the middle of the
// loop.
let pageserver_connstring: Option<String> = (|| {
let num_shards = if args.pageserver_conninfo.shard_count.is_unsharded() {
1
} else {
args.pageserver_conninfo.shard_count.0
};
let mut connstrings = Vec::new();
for shard_no in 0..num_shards {
let shard_index = ShardIndex {
shard_count: args.pageserver_conninfo.shard_count,
shard_number: ShardNumber(shard_no),
};
let shard = args
.pageserver_conninfo
.shards
.get(&shard_index)
.ok_or_else(|| {
anyhow!(
"shard {} not found in pageserver_connection_info",
shard_index
)
})?;
let pageserver = shard
.pageservers
.first()
.ok_or(anyhow!("must have at least one pageserver"))?;
if let Some(libpq_url) = &pageserver.libpq_url {
connstrings.push(libpq_url.clone());
} else {
return Ok::<_, anyhow::Error>(None);
}
}
Ok(Some(connstrings.join(",")))
})()?;
// Create config file
let config = {
let mut spec = ComputeSpec {
@@ -809,14 +776,13 @@ impl Endpoint {
branch_id: None,
endpoint_id: Some(self.endpoint_id.clone()),
mode: self.mode,
pageserver_connection_info: Some(args.pageserver_conninfo.clone()),
pageserver_connstring,
pageserver_connstring: Some(pageserver_connstring),
safekeepers_generation: args.safekeepers_generation.map(|g| g.into_inner()),
safekeeper_connstrings,
storage_auth_token: args.auth_token.clone(),
remote_extensions,
pgbouncer_settings: None,
shard_stripe_size: args.pageserver_conninfo.stripe_size, // redundant with pageserver_connection_info.stripe_size
shard_stripe_size: Some(args.shard_stripe_size),
local_proxy_config: None,
reconfigure_concurrency: self.reconfigure_concurrency,
drop_subscriptions_before_start: self.drop_subscriptions_before_start,
@@ -1028,7 +994,8 @@ impl Endpoint {
pub async fn reconfigure(
&self,
pageserver_conninfo: Option<&PageserverConnectionInfo>,
pageservers: Option<Vec<(PageserverProtocol, Host, u16)>>,
stripe_size: Option<ShardStripeSize>,
safekeepers: Option<Vec<NodeId>>,
safekeeper_generation: Option<SafekeeperGeneration>,
) -> Result<()> {
@@ -1043,15 +1010,15 @@ impl Endpoint {
let postgresql_conf = self.read_postgresql_conf()?;
spec.cluster.postgresql_conf = Some(postgresql_conf);
if let Some(pageserver_conninfo) = pageserver_conninfo {
// If pageservers are provided, we need to ensure that they are not empty.
// This is a requirement for the compute_ctl configuration.
anyhow::ensure!(
!pageserver_conninfo.shards.is_empty(),
"no pageservers provided"
);
spec.pageserver_connection_info = Some(pageserver_conninfo.clone());
spec.shard_stripe_size = pageserver_conninfo.stripe_size;
// If pageservers are not specified, don't change them.
if let Some(pageservers) = pageservers {
anyhow::ensure!(!pageservers.is_empty(), "no pageservers provided");
let pageserver_connstr = Self::build_pageserver_connstr(&pageservers);
spec.pageserver_connstring = Some(pageserver_connstr);
if stripe_size.is_some() {
spec.shard_stripe_size = stripe_size.map(|s| s.0 as usize);
}
}
// If safekeepers are not specified, don't change them.
@@ -1100,9 +1067,11 @@ impl Endpoint {
pub async fn reconfigure_pageservers(
&self,
pageservers: &PageserverConnectionInfo,
pageservers: Vec<(PageserverProtocol, Host, u16)>,
stripe_size: Option<ShardStripeSize>,
) -> Result<()> {
self.reconfigure(Some(pageservers), None, None).await
self.reconfigure(Some(pageservers), stripe_size, None, None)
.await
}
pub async fn reconfigure_safekeepers(
@@ -1110,7 +1079,7 @@ impl Endpoint {
safekeepers: Vec<NodeId>,
generation: SafekeeperGeneration,
) -> Result<()> {
self.reconfigure(None, Some(safekeepers), Some(generation))
self.reconfigure(None, None, Some(safekeepers), Some(generation))
.await
}
@@ -1166,68 +1135,3 @@ impl Endpoint {
)
}
}
pub fn pageserver_conf_to_shard_conn_info(
conf: &crate::local_env::PageServerConf,
) -> Result<PageserverShardConnectionInfo> {
let libpq_url = {
let (host, port) = parse_host_port(&conf.listen_pg_addr)?;
let port = port.unwrap_or(5432);
Some(format!("postgres://no_user@{host}:{port}"))
};
let grpc_url = if let Some(grpc_addr) = &conf.listen_grpc_addr {
let (host, port) = parse_host_port(grpc_addr)?;
let port = port.unwrap_or(DEFAULT_PAGESERVER_GRPC_PORT);
Some(format!("grpc://no_user@{host}:{port}"))
} else {
None
};
Ok(PageserverShardConnectionInfo {
id: Some(conf.id.to_string()),
libpq_url,
grpc_url,
})
}
pub fn tenant_locate_response_to_conn_info(
response: &pageserver_api::controller_api::TenantLocateResponse,
) -> Result<PageserverConnectionInfo> {
let mut shards = HashMap::new();
for shard in response.shards.iter() {
tracing::info!("parsing {}", shard.listen_pg_addr);
let libpq_url = {
let host = &shard.listen_pg_addr;
let port = shard.listen_pg_port;
Some(format!("postgres://no_user@{host}:{port}"))
};
let grpc_url = if let Some(grpc_addr) = &shard.listen_grpc_addr {
let host = grpc_addr;
let port = shard.listen_grpc_port.expect("no gRPC port");
Some(format!("grpc://no_user@{host}:{port}"))
} else {
None
};
let shard_info = PageserverShardInfo {
pageservers: vec![PageserverShardConnectionInfo {
id: Some(shard.node_id.to_string()),
libpq_url,
grpc_url,
}],
};
shards.insert(shard.shard_id.to_index(), shard_info);
}
let stripe_size = if response.shard_params.count.is_unsharded() {
None
} else {
Some(response.shard_params.stripe_size.0)
};
Ok(PageserverConnectionInfo {
shard_count: response.shard_params.count,
stripe_size,
shards,
prefer_protocol: PageserverProtocol::default(),
})
}

View File

@@ -56,7 +56,6 @@ pub struct NeonStorageControllerStartArgs {
pub instance_id: u8,
pub base_port: Option<u16>,
pub start_timeout: humantime::Duration,
pub handle_ps_local_disk_loss: Option<bool>,
}
impl NeonStorageControllerStartArgs {
@@ -65,7 +64,6 @@ impl NeonStorageControllerStartArgs {
instance_id: 1,
base_port: None,
start_timeout,
handle_ps_local_disk_loss: None,
}
}
}
@@ -671,10 +669,6 @@ impl StorageController {
println!("Starting storage controller at {scheme}://{host}:{listen_port}");
if start_args.handle_ps_local_disk_loss.unwrap_or_default() {
args.push("--handle-ps-local-disk-loss".to_string());
}
background_process::start_process(
COMMAND,
&instance_dir,

View File

@@ -35,7 +35,6 @@ reason = "The paste crate is a build-only dependency with no runtime components.
# More documentation for the licenses section can be found here:
# https://embarkstudios.github.io/cargo-deny/checks/licenses/cfg.html
[licenses]
version = 2
allow = [
"0BSD",
"Apache-2.0",

View File

@@ -233,7 +233,7 @@ mod tests {
.unwrap()
.as_millis();
use rand::Rng;
let random = rand::rng().random::<u32>();
let random = rand::thread_rng().r#gen::<u32>();
let s3_config = remote_storage::S3Config {
bucket_name: var(REAL_S3_BUCKET).unwrap(),

View File

@@ -14,7 +14,6 @@ use serde::{Deserialize, Serialize};
use url::Url;
use utils::id::{TenantId, TimelineId};
use utils::lsn::Lsn;
use utils::shard::{ShardCount, ShardIndex};
use crate::responses::TlsConfig;
@@ -106,17 +105,6 @@ pub struct ComputeSpec {
// updated to fill these fields, we can make these non optional.
pub tenant_id: Option<TenantId>,
pub timeline_id: Option<TimelineId>,
/// Pageserver information can be passed in three different ways:
/// 1. Here in `pageserver_connection_info`
/// 2. In the `pageserver_connstring` field.
/// 3. in `cluster.settings`.
///
/// The goal is to use method 1. everywhere. But for backwards-compatibility with old
/// versions of the control plane, `compute_ctl` will check 2. and 3. if the
/// `pageserver_connection_info` field is missing.
pub pageserver_connection_info: Option<PageserverConnectionInfo>,
pub pageserver_connstring: Option<String>,
// More neon ids that we expose to the compute_ctl
@@ -153,7 +141,7 @@ pub struct ComputeSpec {
// Stripe size for pageserver sharding, in pages
#[serde(default)]
pub shard_stripe_size: Option<u32>,
pub shard_stripe_size: Option<usize>,
/// Local Proxy configuration used for JWT authentication
#[serde(default)]
@@ -226,32 +214,6 @@ pub enum ComputeFeature {
UnknownFeature,
}
#[derive(Clone, Debug, Deserialize, Serialize, Eq, PartialEq)]
pub struct PageserverConnectionInfo {
/// NB: 0 for unsharded tenants, 1 for sharded tenants with 1 shard, following storage
pub shard_count: ShardCount,
/// INVARIANT: null if shard_count is 0, otherwise non-null and immutable
pub stripe_size: Option<u32>,
pub shards: HashMap<ShardIndex, PageserverShardInfo>,
#[serde(default)]
pub prefer_protocol: PageserverProtocol,
}
#[derive(Clone, Debug, Deserialize, Serialize, Eq, PartialEq)]
pub struct PageserverShardInfo {
pub pageservers: Vec<PageserverShardConnectionInfo>,
}
#[derive(Clone, Debug, Deserialize, Serialize, Eq, PartialEq)]
pub struct PageserverShardConnectionInfo {
pub id: Option<String>,
pub libpq_url: Option<String>,
pub grpc_url: Option<String>,
}
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
pub struct RemoteExtSpec {
pub public_extensions: Option<Vec<String>>,
@@ -369,12 +331,6 @@ impl ComputeMode {
}
}
impl Display for ComputeMode {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str(self.to_type_str())
}
}
/// Log level for audit logging
#[derive(Clone, Debug, Default, Eq, PartialEq, Deserialize, Serialize)]
pub enum ComputeAudit {
@@ -460,32 +416,6 @@ pub struct GenericOption {
pub vartype: String,
}
/// Postgres compute TLS settings.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
pub struct PgComputeTlsSettings {
// Absolute path to the certificate file for server-side TLS.
pub cert_file: String,
// Absolute path to the private key file for server-side TLS.
pub key_file: String,
// Absolute path to the certificate authority file for verifying client certificates.
pub ca_file: String,
}
/// Databricks specific options for compute instance.
/// This is used to store any other settings that needs to be propagate to Compute
/// but should not be persisted to ComputeSpec in the database.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
pub struct DatabricksSettings {
pub pg_compute_tls_settings: PgComputeTlsSettings,
// Absolute file path to databricks_pg_hba.conf file.
pub databricks_pg_hba: String,
// Absolute file path to databricks_pg_ident.conf file.
pub databricks_pg_ident: String,
// Hostname portion of the Databricks workspace URL of the endpoint, or empty string if not known.
// A valid hostname is required for the compute instance to support PAT logins.
pub databricks_workspace_host: String,
}
/// Optional collection of `GenericOption`'s. Type alias allows us to
/// declare a `trait` on it.
pub type GenericOptions = Option<Vec<GenericOption>>;
@@ -511,15 +441,13 @@ pub struct JwksSettings {
pub jwt_audience: Option<String>,
}
/// Protocol used to connect to a Pageserver.
#[derive(Clone, Copy, Debug, Default, Deserialize, Serialize, PartialEq, Eq)]
/// Protocol used to connect to a Pageserver. Parsed from the connstring scheme.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum PageserverProtocol {
/// The original protocol based on libpq and COPY. Uses postgresql:// or postgres:// scheme.
#[default]
#[serde(rename = "libpq")]
Libpq,
/// A newer, gRPC-based protocol. Uses grpc:// scheme.
#[serde(rename = "grpc")]
Grpc,
}

View File

@@ -90,7 +90,7 @@ impl<'a> IdempotencyKey<'a> {
IdempotencyKey {
now: Utc::now(),
node_id,
nonce: rand::rng().random_range(0..=9999),
nonce: rand::thread_rng().gen_range(0..=9999),
}
}

View File

@@ -41,7 +41,7 @@ impl NodeOs {
/// Generate a random number in range [0, max).
pub fn random(&self, max: u64) -> u64 {
self.internal.rng.lock().random_range(0..max)
self.internal.rng.lock().gen_range(0..max)
}
/// Append a new event to the world event log.

View File

@@ -32,10 +32,10 @@ impl Delay {
/// Generate a random delay in range [min, max]. Return None if the
/// message should be dropped.
pub fn delay(&self, rng: &mut StdRng) -> Option<u64> {
if rng.random_bool(self.fail_prob) {
if rng.gen_bool(self.fail_prob) {
return None;
}
Some(rng.random_range(self.min..=self.max))
Some(rng.gen_range(self.min..=self.max))
}
}

View File

@@ -69,7 +69,7 @@ impl World {
/// Create a new random number generator.
pub fn new_rng(&self) -> StdRng {
let mut rng = self.rng.lock();
StdRng::from_rng(rng.deref_mut())
StdRng::from_rng(rng.deref_mut()).unwrap()
}
/// Create a new node.

View File

@@ -17,5 +17,5 @@ procfs.workspace = true
measured-process.workspace = true
[dev-dependencies]
rand.workspace = true
rand_distr = "0.5"
rand = "0.8"
rand_distr = "0.4.3"

View File

@@ -260,7 +260,7 @@ mod tests {
#[test]
fn test_cardinality_small() {
let (actual, estimate) = test_cardinality(100, Zipf::new(100.0, 1.2f64).unwrap());
let (actual, estimate) = test_cardinality(100, Zipf::new(100, 1.2f64).unwrap());
assert_eq!(actual, [46, 30, 32]);
assert!(51.3 < estimate[0] && estimate[0] < 51.4);
@@ -270,7 +270,7 @@ mod tests {
#[test]
fn test_cardinality_medium() {
let (actual, estimate) = test_cardinality(10000, Zipf::new(10000.0, 1.2f64).unwrap());
let (actual, estimate) = test_cardinality(10000, Zipf::new(10000, 1.2f64).unwrap());
assert_eq!(actual, [2529, 1618, 1629]);
assert!(2309.1 < estimate[0] && estimate[0] < 2309.2);
@@ -280,8 +280,7 @@ mod tests {
#[test]
fn test_cardinality_large() {
let (actual, estimate) =
test_cardinality(1_000_000, Zipf::new(1_000_000.0, 1.2f64).unwrap());
let (actual, estimate) = test_cardinality(1_000_000, Zipf::new(1_000_000, 1.2f64).unwrap());
assert_eq!(actual, [129077, 79579, 79630]);
assert!(126067.2 < estimate[0] && estimate[0] < 126067.3);
@@ -291,7 +290,7 @@ mod tests {
#[test]
fn test_cardinality_small2() {
let (actual, estimate) = test_cardinality(100, Zipf::new(200.0, 0.8f64).unwrap());
let (actual, estimate) = test_cardinality(100, Zipf::new(200, 0.8f64).unwrap());
assert_eq!(actual, [92, 58, 60]);
assert!(116.1 < estimate[0] && estimate[0] < 116.2);
@@ -301,7 +300,7 @@ mod tests {
#[test]
fn test_cardinality_medium2() {
let (actual, estimate) = test_cardinality(10000, Zipf::new(20000.0, 0.8f64).unwrap());
let (actual, estimate) = test_cardinality(10000, Zipf::new(20000, 0.8f64).unwrap());
assert_eq!(actual, [8201, 5131, 5051]);
assert!(6846.4 < estimate[0] && estimate[0] < 6846.5);
@@ -311,8 +310,7 @@ mod tests {
#[test]
fn test_cardinality_large2() {
let (actual, estimate) =
test_cardinality(1_000_000, Zipf::new(2_000_000.0, 0.8f64).unwrap());
let (actual, estimate) = test_cardinality(1_000_000, Zipf::new(2_000_000, 0.8f64).unwrap());
assert_eq!(actual, [777847, 482069, 482246]);
assert!(699437.4 < estimate[0] && estimate[0] < 699437.5);

View File

@@ -6,26 +6,15 @@ license.workspace = true
[dependencies]
thiserror.workspace = true
nix.workspace = true
nix.workspace=true
workspace_hack = { version = "0.1", path = "../../workspace_hack" }
libc.workspace = true
lock_api.workspace = true
rustc-hash.workspace = true
[dev-dependencies]
criterion = { workspace = true, features = ["html_reports"] }
rand = "0.9"
rand_distr = "0.5.1"
xxhash-rust = { version = "0.8.15", features = ["xxh3"] }
ahash.workspace = true
twox-hash = { version = "2.1.1" }
seahash = "4.1.0"
hashbrown = { git = "https://github.com/quantumish/hashbrown.git", rev = "6610e6d" }
[target.'cfg(target_os = "macos")'.dependencies]
tempfile = "3.14.0"
[[bench]]
name = "hmap_resize"
harness = false
[dev-dependencies]
rand = "0.9"
rand_distr = "0.5.1"

View File

@@ -1,330 +0,0 @@
use criterion::{BatchSize, BenchmarkId, Criterion, criterion_group, criterion_main};
use neon_shmem::hash::HashMapAccess;
use neon_shmem::hash::HashMapInit;
use neon_shmem::hash::entry::Entry;
use rand::distr::{Distribution, StandardUniform};
use rand::prelude::*;
use std::default::Default;
use std::hash::BuildHasher;
// Taken from bindings to C code
#[derive(Clone, Debug, Hash, Eq, PartialEq)]
#[repr(C)]
pub struct FileCacheKey {
pub _spc_id: u32,
pub _db_id: u32,
pub _rel_number: u32,
pub _fork_num: u32,
pub _block_num: u32,
}
impl Distribution<FileCacheKey> for StandardUniform {
// questionable, but doesn't need to be good randomness
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> FileCacheKey {
FileCacheKey {
_spc_id: rng.random(),
_db_id: rng.random(),
_rel_number: rng.random(),
_fork_num: rng.random(),
_block_num: rng.random(),
}
}
}
#[derive(Clone, Debug)]
#[repr(C)]
pub struct FileCacheEntry {
pub _offset: u32,
pub _access_count: u32,
pub _prev: *mut FileCacheEntry,
pub _next: *mut FileCacheEntry,
pub _state: [u32; 8],
}
impl FileCacheEntry {
fn dummy() -> Self {
Self {
_offset: 0,
_access_count: 0,
_prev: std::ptr::null_mut(),
_next: std::ptr::null_mut(),
_state: [0; 8],
}
}
}
// Utilities for applying operations.
#[derive(Clone, Debug)]
struct TestOp<K, V>(K, Option<V>);
fn apply_op<K: Clone + std::hash::Hash + Eq, V, S: std::hash::BuildHasher>(
op: TestOp<K, V>,
map: &mut HashMapAccess<K, V, S>,
) {
let entry = map.entry(op.0);
match op.1 {
Some(new) => match entry {
Entry::Occupied(mut e) => Some(e.insert(new)),
Entry::Vacant(e) => {
_ = e.insert(new).unwrap();
None
}
},
None => match entry {
Entry::Occupied(e) => Some(e.remove()),
Entry::Vacant(_) => None,
},
};
}
// Hash utilities
struct SeaRandomState {
k1: u64,
k2: u64,
k3: u64,
k4: u64,
}
impl std::hash::BuildHasher for SeaRandomState {
type Hasher = seahash::SeaHasher;
fn build_hasher(&self) -> Self::Hasher {
seahash::SeaHasher::with_seeds(self.k1, self.k2, self.k3, self.k4)
}
}
impl SeaRandomState {
fn new() -> Self {
let mut rng = rand::rng();
Self {
k1: rng.random(),
k2: rng.random(),
k3: rng.random(),
k4: rng.random(),
}
}
}
fn small_benchs(c: &mut Criterion) {
let mut group = c.benchmark_group("Small maps");
group.sample_size(10);
group.bench_function("small_rehash", |b| {
let ideal_filled = 4_000_000;
let size = 5_000_000;
let mut writer = HashMapInit::new_resizeable(size, size * 2).attach_writer();
let mut rng = rand::rng();
while writer.get_num_buckets_in_use() < ideal_filled as usize {
let key: FileCacheKey = rng.random();
let val = FileCacheEntry::dummy();
apply_op(TestOp(key, Some(val)), &mut writer);
}
b.iter(|| writer.shuffle());
});
group.bench_function("small_rehash_xxhash", |b| {
let ideal_filled = 4_000_000;
let size = 5_000_000;
let mut writer = HashMapInit::new_resizeable(size, size * 2)
.with_hasher(twox_hash::xxhash64::RandomState::default())
.attach_writer();
let mut rng = rand::rng();
while writer.get_num_buckets_in_use() < ideal_filled as usize {
let key: FileCacheKey = rng.random();
let val = FileCacheEntry::dummy();
apply_op(TestOp(key, Some(val)), &mut writer);
}
b.iter(|| writer.shuffle());
});
group.bench_function("small_rehash_ahash", |b| {
let ideal_filled = 4_000_000;
let size = 5_000_000;
let mut writer = HashMapInit::new_resizeable(size, size * 2)
.with_hasher(ahash::RandomState::default())
.attach_writer();
let mut rng = rand::rng();
while writer.get_num_buckets_in_use() < ideal_filled as usize {
let key: FileCacheKey = rng.random();
let val = FileCacheEntry::dummy();
apply_op(TestOp(key, Some(val)), &mut writer);
}
b.iter(|| writer.shuffle());
});
group.bench_function("small_rehash_seahash", |b| {
let ideal_filled = 4_000_000;
let size = 5_000_000;
let mut writer = HashMapInit::new_resizeable(size, size * 2)
.with_hasher(SeaRandomState::new())
.attach_writer();
let mut rng = rand::rng();
while writer.get_num_buckets_in_use() < ideal_filled as usize {
let key: FileCacheKey = rng.random();
let val = FileCacheEntry::dummy();
apply_op(TestOp(key, Some(val)), &mut writer);
}
b.iter(|| writer.shuffle());
});
group.finish();
}
fn real_benchs(c: &mut Criterion) {
let mut group = c.benchmark_group("Realistic workloads");
group.sample_size(10);
group.bench_function("real_bulk_insert", |b| {
let size = 125_000_000;
let ideal_filled = 100_000_000;
let mut rng = rand::rng();
b.iter_batched(
|| HashMapInit::new_resizeable(size, size * 2).attach_writer(),
|writer| {
for _ in 0..ideal_filled {
let key: FileCacheKey = rng.random();
let val = FileCacheEntry::dummy();
let entry = writer.entry(key);
match entry {
Entry::Occupied(mut e) => {
std::hint::black_box(e.insert(val));
}
Entry::Vacant(e) => {
let _ = std::hint::black_box(e.insert(val).unwrap());
}
}
}
},
BatchSize::SmallInput,
)
});
group.bench_function("real_rehash", |b| {
let size = 125_000_000;
let ideal_filled = 100_000_000;
let mut writer = HashMapInit::new_resizeable(size, size).attach_writer();
let mut rng = rand::rng();
while writer.get_num_buckets_in_use() < ideal_filled {
let key: FileCacheKey = rng.random();
let val = FileCacheEntry::dummy();
apply_op(TestOp(key, Some(val)), &mut writer);
}
b.iter(|| writer.shuffle());
});
group.bench_function("real_rehash_hashbrown", |b| {
let size = 125_000_000;
let ideal_filled = 100_000_000;
let mut writer = hashbrown::raw::RawTable::new();
let mut rng = rand::rng();
let hasher = rustc_hash::FxBuildHasher;
unsafe {
writer
.resize(
size,
|(k, _)| hasher.hash_one(k),
hashbrown::raw::Fallibility::Infallible,
)
.unwrap();
}
while writer.len() < ideal_filled as usize {
let key: FileCacheKey = rng.random();
let val = FileCacheEntry::dummy();
writer.insert(hasher.hash_one(&key), (key, val), |(k, _)| {
hasher.hash_one(k)
});
}
b.iter(|| unsafe {
writer.table.rehash_in_place(
&|table, index| {
hasher.hash_one(
&table
.bucket::<(FileCacheKey, FileCacheEntry)>(index)
.as_ref()
.0,
)
},
std::mem::size_of::<(FileCacheKey, FileCacheEntry)>(),
if std::mem::needs_drop::<(FileCacheKey, FileCacheEntry)>() {
Some(|ptr| std::ptr::drop_in_place(ptr as *mut (FileCacheKey, FileCacheEntry)))
} else {
None
},
)
});
});
for elems in [2, 4, 8, 16, 32, 64, 96, 112] {
group.bench_with_input(
BenchmarkId::new("real_rehash_varied", elems),
&elems,
|b, &size| {
let ideal_filled = size * 1_000_000;
let size = 125_000_000;
let mut writer = HashMapInit::new_resizeable(size, size).attach_writer();
let mut rng = rand::rng();
while writer.get_num_buckets_in_use() < ideal_filled as usize {
let key: FileCacheKey = rng.random();
let val = FileCacheEntry::dummy();
apply_op(TestOp(key, Some(val)), &mut writer);
}
b.iter(|| writer.shuffle());
},
);
group.bench_with_input(
BenchmarkId::new("real_rehash_varied_hashbrown", elems),
&elems,
|b, &size| {
let ideal_filled = size * 1_000_000;
let size = 125_000_000;
let mut writer = hashbrown::raw::RawTable::new();
let mut rng = rand::rng();
let hasher = rustc_hash::FxBuildHasher;
unsafe {
writer
.resize(
size,
|(k, _)| hasher.hash_one(k),
hashbrown::raw::Fallibility::Infallible,
)
.unwrap();
}
while writer.len() < ideal_filled as usize {
let key: FileCacheKey = rng.random();
let val = FileCacheEntry::dummy();
writer.insert(hasher.hash_one(&key), (key, val), |(k, _)| {
hasher.hash_one(k)
});
}
b.iter(|| unsafe {
writer.table.rehash_in_place(
&|table, index| {
hasher.hash_one(
&table
.bucket::<(FileCacheKey, FileCacheEntry)>(index)
.as_ref()
.0,
)
},
std::mem::size_of::<(FileCacheKey, FileCacheEntry)>(),
if std::mem::needs_drop::<(FileCacheKey, FileCacheEntry)>() {
Some(|ptr| {
std::ptr::drop_in_place(ptr as *mut (FileCacheKey, FileCacheEntry))
})
} else {
None
},
)
});
},
);
}
group.finish();
}
criterion_group!(benches, small_benchs, real_benchs);
criterion_main!(benches);

View File

@@ -16,7 +16,6 @@
//!
//! Concurrency is managed very simply: the entire map is guarded by one shared-memory RwLock.
use std::fmt::Debug;
use std::hash::{BuildHasher, Hash};
use std::mem::MaybeUninit;
@@ -57,22 +56,6 @@ pub struct HashMapInit<'a, K, V, S = rustc_hash::FxBuildHasher> {
num_buckets: u32,
}
impl<'a, K, V, S> Debug for HashMapInit<'a, K, V, S>
where
K: Debug,
V: Debug,
{
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("HashMapInit")
.field("shmem_handle", &self.shmem_handle)
.field("shared_ptr", &self.shared_ptr)
.field("shared_size", &self.shared_size)
// .field("hasher", &self.hasher)
.field("num_buckets", &self.num_buckets)
.finish()
}
}
/// This is a per-process handle to a hash table that (possibly) lives in shared memory.
/// If a child process is launched with fork(), the child process should
/// get its own HashMapAccess by calling HashMapInit::attach_writer/reader().
@@ -88,20 +71,6 @@ pub struct HashMapAccess<'a, K, V, S = rustc_hash::FxBuildHasher> {
unsafe impl<K: Sync, V: Sync, S> Sync for HashMapAccess<'_, K, V, S> {}
unsafe impl<K: Send, V: Send, S> Send for HashMapAccess<'_, K, V, S> {}
impl<'a, K, V, S> Debug for HashMapAccess<'a, K, V, S>
where
K: Debug,
V: Debug,
{
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("HashMapAccess")
.field("shmem_handle", &self.shmem_handle)
.field("shared_ptr", &self.shared_ptr)
// .field("hasher", &self.hasher)
.finish()
}
}
impl<'a, K: Clone + Hash + Eq, V, S> HashMapInit<'a, K, V, S> {
/// Change the 'hasher' used by the hash table.
///
@@ -329,7 +298,7 @@ where
/// Get a reference to the entry containing a key.
///
/// NB: This takes a write lock as there's no way to distinguish whether the intention
/// NB: THis takes a write lock as there's no way to distinguish whether the intention
/// is to use the entry for reading or for writing in advance.
pub fn entry(&self, key: K) -> Entry<'a, '_, K, V> {
let hash = self.get_hash_value(&key);

View File

@@ -1,6 +1,5 @@
//! Simple hash table with chaining.
use std::fmt::Debug;
use std::hash::Hash;
use std::mem::MaybeUninit;
@@ -18,19 +17,6 @@ pub(crate) struct Bucket<K, V> {
pub(crate) inner: Option<(K, V)>,
}
impl<K, V> Debug for Bucket<K, V>
where
K: Debug,
V: Debug,
{
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("Bucket")
.field("next", &self.next)
.field("inner", &self.inner)
.finish()
}
}
/// Core hash table implementation.
pub(crate) struct CoreHashMap<'a, K, V> {
/// Dictionary used to map hashes to bucket indices.
@@ -45,22 +31,6 @@ pub(crate) struct CoreHashMap<'a, K, V> {
pub(crate) buckets_in_use: u32,
}
impl<'a, K, V> Debug for CoreHashMap<'a, K, V>
where
K: Debug,
V: Debug,
{
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("CoreHashMap")
.field("dictionary", &self.dictionary)
.field("buckets", &self.buckets)
.field("free_head", &self.free_head)
.field("alloc_limit", &self.alloc_limit)
.field("buckets_in_use", &self.buckets_in_use)
.finish()
}
}
/// Error for when there are no empty buckets left but one is needed.
#[derive(Debug, PartialEq)]
pub struct FullError;

View File

@@ -61,10 +61,6 @@ impl<K, V> OccupiedEntry<'_, '_, K, V> {
///
/// This may result in multiple bucket accesses if the entry was obtained by index as the
/// previous chain entry needs to be discovered in this case.
///
/// # Panics
/// Panics if the `prev_pos` field is equal to [`PrevPos::Unknown`]. In practice, this means
/// the entry was obtained via calling something like [`super::HashMapAccess::entry_at_bucket`].
pub fn remove(mut self) -> V {
// If this bucket was queried by index, go ahead and follow its chain from the start.
let prev = if let PrevPos::Unknown(hash) = self.prev_pos {

View File

@@ -21,7 +21,6 @@ use nix::unistd::ftruncate as nix_ftruncate;
/// the underlying file is resized. Do not access the area beyond the current size. Currently, that
/// will cause the file to be expanded, but we might use `mprotect()` etc. to enforce that in the
/// future.
#[derive(Debug)]
pub struct ShmemHandle {
/// memfd file descriptor
fd: OwnedFd,
@@ -36,7 +35,6 @@ pub struct ShmemHandle {
}
/// This is stored at the beginning in the shared memory area.
#[derive(Debug)]
struct SharedStruct {
max_size: usize,

View File

@@ -394,7 +394,7 @@ impl From<&OtelExporterConfig> for tracing_utils::ExportConfig {
tracing_utils::ExportConfig {
endpoint: Some(val.endpoint.clone()),
protocol: val.protocol.into(),
timeout: Some(val.timeout),
timeout: val.timeout,
}
}
}

View File

@@ -596,7 +596,6 @@ pub struct TimelineImportRequest {
pub timeline_id: TimelineId,
pub start_lsn: Lsn,
pub sk_set: Vec<NodeId>,
pub force_upsert: bool,
}
#[derive(serde::Serialize, serde::Deserialize, Clone)]

View File

@@ -981,12 +981,12 @@ mod tests {
let mut rng = rand::rngs::StdRng::seed_from_u64(42);
let key = Key {
field1: rng.random(),
field2: rng.random(),
field3: rng.random(),
field4: rng.random(),
field5: rng.random(),
field6: rng.random(),
field1: rng.r#gen(),
field2: rng.r#gen(),
field3: rng.r#gen(),
field4: rng.r#gen(),
field5: rng.r#gen(),
field6: rng.r#gen(),
};
assert_eq!(key, Key::from_str(&format!("{key}")).unwrap());

View File

@@ -443,9 +443,9 @@ pub struct ImportPgdataIdempotencyKey(pub String);
impl ImportPgdataIdempotencyKey {
pub fn random() -> Self {
use rand::Rng;
use rand::distr::Alphanumeric;
use rand::distributions::Alphanumeric;
Self(
rand::rng()
rand::thread_rng()
.sample_iter(&Alphanumeric)
.take(20)
.map(char::from)

View File

@@ -69,6 +69,22 @@ impl Hash for ShardIdentity {
}
}
/// Stripe size in number of pages
#[derive(Clone, Copy, Serialize, Deserialize, Eq, PartialEq, Debug)]
pub struct ShardStripeSize(pub u32);
impl Default for ShardStripeSize {
fn default() -> Self {
DEFAULT_STRIPE_SIZE
}
}
impl std::fmt::Display for ShardStripeSize {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
self.0.fmt(f)
}
}
/// Layout version: for future upgrades where we might change how the key->shard mapping works
#[derive(Clone, Copy, Serialize, Deserialize, Eq, PartialEq, Hash, Debug)]
pub struct ShardLayout(u8);

View File

@@ -21,14 +21,6 @@ pub struct ReAttachRequest {
/// if the node already has a node_id set.
#[serde(skip_serializing_if = "Option::is_none", default)]
pub register: Option<NodeRegisterRequest>,
/// Hadron: Optional flag to indicate whether the node is starting with an empty local disk.
/// Will be set to true if the node couldn't find any local tenant data on startup, could be
/// due to the node starting for the first time or due to a local SSD failure/disk wipe event.
/// The flag may be used by the storage controller to update its observed state of the world
/// to make sure that it sends explicit location_config calls to the node following the
/// re-attach request.
pub empty_local_disk: Option<bool>,
}
#[derive(Serialize, Deserialize, Debug)]

View File

@@ -203,12 +203,12 @@ impl fmt::Display for CancelKeyData {
}
}
use rand::distr::{Distribution, StandardUniform};
impl Distribution<CancelKeyData> for StandardUniform {
use rand::distributions::{Distribution, Standard};
impl Distribution<CancelKeyData> for Standard {
fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> CancelKeyData {
CancelKeyData {
backend_pid: rng.random(),
cancel_key: rng.random(),
backend_pid: rng.r#gen(),
cancel_key: rng.r#gen(),
}
}
}

View File

@@ -155,10 +155,10 @@ pub struct ScramSha256 {
fn nonce() -> String {
// rand 0.5's ThreadRng is cryptographically secure
let mut rng = rand::rng();
let mut rng = rand::thread_rng();
(0..NONCE_LENGTH)
.map(|_| {
let mut v = rng.random_range(0x21u8..0x7e);
let mut v = rng.gen_range(0x21u8..0x7e);
if v == 0x2c {
v = 0x7e
}

View File

@@ -74,6 +74,7 @@ impl Header {
}
/// An enum representing Postgres backend messages.
#[non_exhaustive]
pub enum Message {
AuthenticationCleartextPassword,
AuthenticationGss,
@@ -144,7 +145,16 @@ impl Message {
PARSE_COMPLETE_TAG => Message::ParseComplete,
BIND_COMPLETE_TAG => Message::BindComplete,
CLOSE_COMPLETE_TAG => Message::CloseComplete,
NOTIFICATION_RESPONSE_TAG => Message::NotificationResponse(NotificationResponseBody {}),
NOTIFICATION_RESPONSE_TAG => {
let process_id = buf.read_i32::<BigEndian>()?;
let channel = buf.read_cstr()?;
let message = buf.read_cstr()?;
Message::NotificationResponse(NotificationResponseBody {
process_id,
channel,
message,
})
}
COPY_DONE_TAG => Message::CopyDone,
COMMAND_COMPLETE_TAG => {
let tag = buf.read_cstr()?;
@@ -533,7 +543,28 @@ impl NoticeResponseBody {
}
}
pub struct NotificationResponseBody {}
pub struct NotificationResponseBody {
process_id: i32,
channel: Bytes,
message: Bytes,
}
impl NotificationResponseBody {
#[inline]
pub fn process_id(&self) -> i32 {
self.process_id
}
#[inline]
pub fn channel(&self) -> io::Result<&str> {
get_str(&self.channel)
}
#[inline]
pub fn message(&self) -> io::Result<&str> {
get_str(&self.message)
}
}
pub struct ParameterDescriptionBody {
storage: Bytes,

View File

@@ -28,7 +28,7 @@ const SCRAM_DEFAULT_SALT_LEN: usize = 16;
/// special characters that would require escaping in an SQL command.
pub async fn scram_sha_256(password: &[u8]) -> String {
let mut salt: [u8; SCRAM_DEFAULT_SALT_LEN] = [0; SCRAM_DEFAULT_SALT_LEN];
let mut rng = rand::rng();
let mut rng = rand::thread_rng();
rng.fill_bytes(&mut salt);
scram_sha_256_salt(password, salt).await
}

View File

@@ -13,7 +13,7 @@ use serde::{Deserialize, Serialize};
use tokio::sync::mpsc;
use crate::cancel_token::RawCancelToken;
use crate::codec::{BackendMessages, FrontendMessage, RecordNotices};
use crate::codec::{BackendMessages, FrontendMessage};
use crate::config::{Host, SslMode};
use crate::query::RowStream;
use crate::simple_query::SimpleQueryStream;
@@ -221,18 +221,6 @@ impl Client {
&mut self.inner
}
pub fn record_notices(&mut self, limit: usize) -> mpsc::UnboundedReceiver<Box<str>> {
let (tx, rx) = mpsc::unbounded_channel();
let notices = RecordNotices { sender: tx, limit };
self.inner
.sender
.send(FrontendMessage::RecordNotices(notices))
.ok();
rx
}
/// Pass text directly to the Postgres backend to allow it to sort out typing itself and
/// to save a roundtrip
pub async fn query_raw_txt<S, I>(

View File

@@ -3,17 +3,10 @@ use std::io;
use bytes::{Bytes, BytesMut};
use fallible_iterator::FallibleIterator;
use postgres_protocol2::message::backend;
use tokio::sync::mpsc::UnboundedSender;
use tokio_util::codec::{Decoder, Encoder};
pub enum FrontendMessage {
Raw(Bytes),
RecordNotices(RecordNotices),
}
pub struct RecordNotices {
pub sender: UnboundedSender<Box<str>>,
pub limit: usize,
}
pub enum BackendMessage {
@@ -40,11 +33,14 @@ impl FallibleIterator for BackendMessages {
pub struct PostgresCodec;
impl Encoder<Bytes> for PostgresCodec {
impl Encoder<FrontendMessage> for PostgresCodec {
type Error = io::Error;
fn encode(&mut self, item: Bytes, dst: &mut BytesMut) -> io::Result<()> {
dst.extend_from_slice(&item);
fn encode(&mut self, item: FrontendMessage, dst: &mut BytesMut) -> io::Result<()> {
match item {
FrontendMessage::Raw(buf) => dst.extend_from_slice(&buf),
}
Ok(())
}
}

View File

@@ -1,9 +1,11 @@
use std::net::IpAddr;
use postgres_protocol2::message::backend::Message;
use tokio::net::TcpStream;
use tokio::sync::mpsc;
use crate::client::SocketConfig;
use crate::codec::BackendMessage;
use crate::config::Host;
use crate::connect_raw::connect_raw;
use crate::connect_socket::connect_socket;
@@ -46,8 +48,8 @@ where
let stream = connect_tls(socket, config.ssl_mode, tls).await?;
let RawConnection {
stream,
parameters: _,
delayed_notice: _,
parameters,
delayed_notice,
process_id,
secret_key,
} = connect_raw(stream, config).await?;
@@ -70,7 +72,13 @@ where
secret_key,
);
let connection = Connection::new(stream, conn_tx, conn_rx);
// delayed notices are always sent as "Async" messages.
let delayed = delayed_notice
.into_iter()
.map(|m| BackendMessage::Async(Message::NoticeResponse(m)))
.collect();
let connection = Connection::new(stream, delayed, parameters, conn_tx, conn_rx);
Ok((client, connection))
}

View File

@@ -3,7 +3,7 @@ use std::io;
use std::pin::Pin;
use std::task::{Context, Poll};
use bytes::{Bytes, BytesMut};
use bytes::BytesMut;
use fallible_iterator::FallibleIterator;
use futures_util::{Sink, SinkExt, Stream, TryStreamExt, ready};
use postgres_protocol2::authentication::sasl;
@@ -14,7 +14,7 @@ use tokio::io::{AsyncRead, AsyncWrite};
use tokio_util::codec::Framed;
use crate::Error;
use crate::codec::{BackendMessage, BackendMessages, PostgresCodec};
use crate::codec::{BackendMessage, BackendMessages, FrontendMessage, PostgresCodec};
use crate::config::{self, AuthKeys, Config};
use crate::maybe_tls_stream::MaybeTlsStream;
use crate::tls::TlsStream;
@@ -25,7 +25,7 @@ pub struct StartupStream<S, T> {
delayed_notice: Vec<NoticeResponseBody>,
}
impl<S, T> Sink<Bytes> for StartupStream<S, T>
impl<S, T> Sink<FrontendMessage> for StartupStream<S, T>
where
S: AsyncRead + AsyncWrite + Unpin,
T: AsyncRead + AsyncWrite + Unpin,
@@ -36,7 +36,7 @@ where
Pin::new(&mut self.inner).poll_ready(cx)
}
fn start_send(mut self: Pin<&mut Self>, item: Bytes) -> io::Result<()> {
fn start_send(mut self: Pin<&mut Self>, item: FrontendMessage) -> io::Result<()> {
Pin::new(&mut self.inner).start_send(item)
}
@@ -120,7 +120,10 @@ where
let mut buf = BytesMut::new();
frontend::startup_message(&config.server_params, &mut buf).map_err(Error::encode)?;
stream.send(buf.freeze()).await.map_err(Error::io)
stream
.send(FrontendMessage::Raw(buf.freeze()))
.await
.map_err(Error::io)
}
async fn authenticate<S, T>(stream: &mut StartupStream<S, T>, config: &Config) -> Result<(), Error>
@@ -188,7 +191,10 @@ where
let mut buf = BytesMut::new();
frontend::password_message(password, &mut buf).map_err(Error::encode)?;
stream.send(buf.freeze()).await.map_err(Error::io)
stream
.send(FrontendMessage::Raw(buf.freeze()))
.await
.map_err(Error::io)
}
async fn authenticate_sasl<S, T>(
@@ -247,7 +253,10 @@ where
let mut buf = BytesMut::new();
frontend::sasl_initial_response(mechanism, scram.message(), &mut buf).map_err(Error::encode)?;
stream.send(buf.freeze()).await.map_err(Error::io)?;
stream
.send(FrontendMessage::Raw(buf.freeze()))
.await
.map_err(Error::io)?;
let body = match stream.try_next().await.map_err(Error::io)? {
Some(Message::AuthenticationSaslContinue(body)) => body,
@@ -263,7 +272,10 @@ where
let mut buf = BytesMut::new();
frontend::sasl_response(scram.message(), &mut buf).map_err(Error::encode)?;
stream.send(buf.freeze()).await.map_err(Error::io)?;
stream
.send(FrontendMessage::Raw(buf.freeze()))
.await
.map_err(Error::io)?;
let body = match stream.try_next().await.map_err(Error::io)? {
Some(Message::AuthenticationSaslFinal(body)) => body,

View File

@@ -1,23 +1,22 @@
use std::collections::{HashMap, VecDeque};
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use bytes::BytesMut;
use fallible_iterator::FallibleIterator;
use futures_util::{Sink, StreamExt, ready};
use postgres_protocol2::message::backend::{Message, NoticeResponseBody};
use futures_util::{Sink, Stream, ready};
use postgres_protocol2::message::backend::Message;
use postgres_protocol2::message::frontend;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::sync::mpsc;
use tokio_util::codec::Framed;
use tokio_util::sync::PollSender;
use tracing::trace;
use tracing::{info, trace};
use crate::Error;
use crate::codec::{
BackendMessage, BackendMessages, FrontendMessage, PostgresCodec, RecordNotices,
};
use crate::codec::{BackendMessage, BackendMessages, FrontendMessage, PostgresCodec};
use crate::error::DbError;
use crate::maybe_tls_stream::MaybeTlsStream;
use crate::{AsyncMessage, Error, Notification};
#[derive(PartialEq, Debug)]
enum State {
@@ -34,18 +33,18 @@ enum State {
/// occurred, or because its associated `Client` has dropped and all outstanding work has completed.
#[must_use = "futures do nothing unless polled"]
pub struct Connection<S, T> {
stream: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
/// HACK: we need this in the Neon Proxy.
pub stream: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
/// HACK: we need this in the Neon Proxy to forward params.
pub parameters: HashMap<String, String>,
sender: PollSender<BackendMessages>,
receiver: mpsc::UnboundedReceiver<FrontendMessage>,
notices: Option<RecordNotices>,
pending_response: Option<BackendMessages>,
pending_responses: VecDeque<BackendMessage>,
state: State,
}
pub enum Never {}
impl<S, T> Connection<S, T>
where
S: AsyncRead + AsyncWrite + Unpin,
@@ -53,44 +52,72 @@ where
{
pub(crate) fn new(
stream: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
pending_responses: VecDeque<BackendMessage>,
parameters: HashMap<String, String>,
sender: mpsc::Sender<BackendMessages>,
receiver: mpsc::UnboundedReceiver<FrontendMessage>,
) -> Connection<S, T> {
Connection {
stream,
parameters,
sender: PollSender::new(sender),
receiver,
notices: None,
pending_response: None,
pending_responses,
state: State::Active,
}
}
fn poll_response(
&mut self,
cx: &mut Context<'_>,
) -> Poll<Option<Result<BackendMessage, Error>>> {
if let Some(message) = self.pending_responses.pop_front() {
trace!("retrying pending response");
return Poll::Ready(Some(Ok(message)));
}
Pin::new(&mut self.stream)
.poll_next(cx)
.map(|o| o.map(|r| r.map_err(Error::io)))
}
/// Read and process messages from the connection to postgres.
/// client <- postgres
fn poll_read(&mut self, cx: &mut Context<'_>) -> Poll<Result<Never, Error>> {
fn poll_read(&mut self, cx: &mut Context<'_>) -> Poll<Result<AsyncMessage, Error>> {
loop {
let messages = match self.pending_response.take() {
Some(messages) => messages,
None => {
let message = match self.stream.poll_next_unpin(cx) {
Poll::Pending => return Poll::Pending,
Poll::Ready(None) => return Poll::Ready(Err(Error::closed())),
Poll::Ready(Some(Err(e))) => return Poll::Ready(Err(Error::io(e))),
Poll::Ready(Some(Ok(message))) => message,
};
match message {
BackendMessage::Async(Message::NoticeResponse(body)) => {
self.handle_notice(body)?;
continue;
}
BackendMessage::Async(_) => continue,
BackendMessage::Normal { messages } => messages,
}
let message = match self.poll_response(cx)? {
Poll::Ready(Some(message)) => message,
Poll::Ready(None) => return Poll::Ready(Err(Error::closed())),
Poll::Pending => {
trace!("poll_read: waiting on response");
return Poll::Pending;
}
};
let messages = match message {
BackendMessage::Async(Message::NoticeResponse(body)) => {
let error = DbError::parse(&mut body.fields()).map_err(Error::parse)?;
return Poll::Ready(Ok(AsyncMessage::Notice(error)));
}
BackendMessage::Async(Message::NotificationResponse(body)) => {
let notification = Notification {
process_id: body.process_id(),
channel: body.channel().map_err(Error::parse)?.to_string(),
payload: body.message().map_err(Error::parse)?.to_string(),
};
return Poll::Ready(Ok(AsyncMessage::Notification(notification)));
}
BackendMessage::Async(Message::ParameterStatus(body)) => {
self.parameters.insert(
body.name().map_err(Error::parse)?.to_string(),
body.value().map_err(Error::parse)?.to_string(),
);
continue;
}
BackendMessage::Async(_) => unreachable!(),
BackendMessage::Normal { messages } => messages,
};
match self.sender.poll_reserve(cx) {
Poll::Ready(Ok(())) => {
let _ = self.sender.send_item(messages);
@@ -99,7 +126,8 @@ where
return Poll::Ready(Err(Error::closed()));
}
Poll::Pending => {
self.pending_response = Some(messages);
self.pending_responses
.push_back(BackendMessage::Normal { messages });
trace!("poll_read: waiting on sender");
return Poll::Pending;
}
@@ -107,31 +135,6 @@ where
}
}
fn handle_notice(&mut self, body: NoticeResponseBody) -> Result<(), Error> {
let Some(notices) = &mut self.notices else {
return Ok(());
};
let mut fields = body.fields();
while let Some(field) = fields.next().map_err(Error::parse)? {
// loop until we find the message field
if field.type_() == b'M' {
// if the message field is within the limit, send it.
if let Some(new_limit) = notices.limit.checked_sub(field.value().len()) {
match notices.sender.send(field.value().into()) {
// set the new limit.
Ok(()) => notices.limit = new_limit,
// closed.
Err(_) => self.notices = None,
}
}
break;
}
}
Ok(())
}
/// Fetch the next client request and enqueue the response sender.
fn poll_request(&mut self, cx: &mut Context<'_>) -> Poll<Option<FrontendMessage>> {
if self.receiver.is_closed() {
@@ -165,23 +168,21 @@ where
match self.poll_request(cx) {
// send the message to postgres
Poll::Ready(Some(FrontendMessage::Raw(request))) => {
Poll::Ready(Some(request)) => {
Pin::new(&mut self.stream)
.start_send(request)
.map_err(Error::io)?;
}
Poll::Ready(Some(FrontendMessage::RecordNotices(notices))) => {
self.notices = Some(notices)
}
// No more messages from the client, and no more responses to wait for.
// Send a terminate message to postgres
Poll::Ready(None) => {
trace!("poll_write: at eof, terminating");
let mut request = BytesMut::new();
frontend::terminate(&mut request);
let request = FrontendMessage::Raw(request.freeze());
Pin::new(&mut self.stream)
.start_send(request.freeze())
.start_send(request)
.map_err(Error::io)?;
trace!("poll_write: sent eof, closing");
@@ -230,17 +231,34 @@ where
}
}
fn poll_message(&mut self, cx: &mut Context<'_>) -> Poll<Option<Result<Never, Error>>> {
/// Returns the value of a runtime parameter for this connection.
pub fn parameter(&self, name: &str) -> Option<&str> {
self.parameters.get(name).map(|s| &**s)
}
/// Polls for asynchronous messages from the server.
///
/// The server can send notices as well as notifications asynchronously to the client. Applications that wish to
/// examine those messages should use this method to drive the connection rather than its `Future` implementation.
pub fn poll_message(
&mut self,
cx: &mut Context<'_>,
) -> Poll<Option<Result<AsyncMessage, Error>>> {
if self.state != State::Closing {
// if the state is still active, try read from and write to postgres.
let Poll::Pending = self.poll_read(cx)?;
if self.poll_write(cx)?.is_ready() {
let message = self.poll_read(cx)?;
let closing = self.poll_write(cx)?;
if let Poll::Ready(()) = closing {
self.state = State::Closing;
}
if let Poll::Ready(message) = message {
return Poll::Ready(Some(Ok(message)));
}
// poll_read returned Pending.
// poll_write returned Pending or Ready(()).
// if poll_write returned Ready(()), then we are waiting to read more data from postgres.
// poll_write returned Pending or Ready(WriteReady::WaitingOnRead).
// if poll_write returned Ready(WriteReady::WaitingOnRead), then we are waiting to read more data from postgres.
if self.state != State::Closing {
return Poll::Pending;
}
@@ -262,9 +280,11 @@ where
type Output = Result<(), Error>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
match self.poll_message(cx)? {
Poll::Ready(None) => Poll::Ready(Ok(())),
Poll::Pending => Poll::Pending,
while let Some(message) = ready!(self.poll_message(cx)?) {
if let AsyncMessage::Notice(notice) = message {
info!("{}: {}", notice.severity(), notice.message());
}
}
Poll::Ready(Ok(()))
}
}

View File

@@ -8,6 +8,7 @@ pub use crate::client::{Client, SocketConfig};
pub use crate::config::Config;
pub use crate::connect_raw::RawConnection;
pub use crate::connection::Connection;
use crate::error::DbError;
pub use crate::error::Error;
pub use crate::generic_client::GenericClient;
pub use crate::query::RowStream;
@@ -92,6 +93,21 @@ impl Notification {
}
}
/// An asynchronous message from the server.
#[allow(clippy::large_enum_variant)]
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum AsyncMessage {
/// A notice.
///
/// Notices use the same format as errors, but aren't "errors" per-se.
Notice(DbError),
/// A notification.
///
/// Connections can subscribe to notifications with the `LISTEN` command.
Notification(Notification),
}
/// Message returned by the `SimpleQuery` stream.
#[derive(Debug)]
#[non_exhaustive]

View File

@@ -43,7 +43,7 @@ itertools.workspace = true
sync_wrapper = { workspace = true, features = ["futures"] }
byteorder = "1.4"
rand.workspace = true
rand = "0.8.5"
[dev-dependencies]
camino-tempfile.workspace = true

View File

@@ -129,7 +129,7 @@ impl AzureBlobStorage {
.pool_max_idle_per_host(conn_pool_size)
.build()
.expect("failed to build `reqwest` client");
Arc::new(client)
Arc::new(super::azure_reqwest_polyfill::Client(client))
}
pub fn relative_path_to_name(&self, path: &RemotePath) -> String {

View File

@@ -0,0 +1,96 @@
use std::{collections::HashMap, str::FromStr};
use azure_core::{
Body, HttpClient, Method, Request, Response, StatusCode,
error::{Error, ErrorKind, ResultExt},
headers::{HeaderName, HeaderValue, Headers},
};
use futures::TryStreamExt;
use tracing::{debug, warn};
#[derive(Debug)]
pub struct Client(pub reqwest::Client);
#[async_trait::async_trait]
impl HttpClient for Client {
async fn execute_request(&self, request: &Request) -> azure_core::Result<Response> {
let url = request.url().clone();
let method = request.method();
let mut req = self.0.request(try_from_method(*method)?, url.clone());
for (name, value) in request.headers().iter() {
req = req.header(name.as_str(), value.as_str());
}
let req = match request.body().clone() {
Body::Bytes(bytes) => req.body(bytes),
Body::SeekableStream(seekable_stream) => {
req.body(reqwest::Body::wrap_stream(seekable_stream))
}
};
let reqwest_request = req
.build()
.context(ErrorKind::Other, "failed to build `reqwest` request")?;
debug!("performing request {method} '{url}' with `reqwest`");
let rsp = self
.0
.execute(reqwest_request)
.await
.context(ErrorKind::Io, "failed to execute `reqwest` request")?;
let status = rsp.status();
let headers = to_headers(rsp.headers());
let body = Box::pin(
rsp.bytes_stream()
.map_err(|error| Error::new(ErrorKind::Io, error)),
);
Ok(Response::new(try_from_status(status)?, headers, body))
}
}
fn to_headers(map: &reqwest::header::HeaderMap) -> Headers {
let map = map
.iter()
.filter_map(|(k, v)| {
let key = k.as_str();
if let Ok(value) = v.to_str() {
Some((
HeaderName::from(key.to_owned()),
HeaderValue::from(value.to_owned()),
))
} else {
warn!("header value for `{key}` is not utf8");
None
}
})
.collect::<HashMap<_, _>>();
Headers::from(map)
}
fn try_from_method(method: Method) -> azure_core::Result<reqwest::Method> {
match method {
Method::Connect => Ok(reqwest::Method::CONNECT),
Method::Delete => Ok(reqwest::Method::DELETE),
Method::Get => Ok(reqwest::Method::GET),
Method::Head => Ok(reqwest::Method::HEAD),
Method::Options => Ok(reqwest::Method::OPTIONS),
Method::Patch => Ok(reqwest::Method::PATCH),
Method::Post => Ok(reqwest::Method::POST),
Method::Put => Ok(reqwest::Method::PUT),
Method::Trace => Ok(reqwest::Method::TRACE),
_ => reqwest::Method::from_str(method.as_ref()).map_kind(ErrorKind::DataConversion),
}
}
fn try_from_status(status: reqwest::StatusCode) -> azure_core::Result<StatusCode> {
let status = u16::from(status);
StatusCode::try_from(status).map_err(|_| {
Error::with_message(ErrorKind::DataConversion, || {
format!("invalid status code {status}")
})
})
}

View File

@@ -10,6 +10,7 @@
#![deny(clippy::undocumented_unsafe_blocks)]
mod azure_blob;
mod azure_reqwest_polyfill;
mod config;
mod error;
mod local_fs;

View File

@@ -81,7 +81,7 @@ impl UnreliableWrapper {
///
fn attempt(&self, op: RemoteOp) -> anyhow::Result<u64> {
let mut attempts = self.attempts.lock().unwrap();
let mut rng = rand::rng();
let mut rng = rand::thread_rng();
match attempts.entry(op) {
Entry::Occupied(mut e) => {
@@ -94,7 +94,7 @@ impl UnreliableWrapper {
/* BEGIN_HADRON */
// If there are more attempts to fail, fail the request by probability.
if (attempts_before_this < self.attempts_to_fail)
&& (rng.random_range(0..=100) < self.attempt_failure_probability)
&& (rng.gen_range(0..=100) < self.attempt_failure_probability)
{
let error =
anyhow::anyhow!("simulated failure of remote operation {:?}", e.key());

View File

@@ -208,7 +208,7 @@ async fn create_azure_client(
.as_millis();
// because nanos can be the same for two threads so can millis, add randomness
let random = rand::rng().random::<u32>();
let random = rand::thread_rng().r#gen::<u32>();
let remote_storage_config = RemoteStorageConfig {
storage: RemoteStorageKind::AzureContainer(AzureConfig {

View File

@@ -385,7 +385,7 @@ async fn create_s3_client(
.as_millis();
// because nanos can be the same for two threads so can millis, add randomness
let random = rand::rng().random::<u32>();
let random = rand::thread_rng().r#gen::<u32>();
let remote_storage_config = RemoteStorageConfig {
storage: RemoteStorageKind::AwsS3(S3Config {

View File

@@ -8,7 +8,7 @@ license.workspace = true
hyper0.workspace = true
opentelemetry = { workspace = true, features = ["trace"] }
opentelemetry_sdk = { workspace = true, features = ["rt-tokio"] }
opentelemetry-otlp = { workspace = true, default-features = false, features = ["http-proto", "trace", "http", "reqwest-blocking-client"] }
opentelemetry-otlp = { workspace = true, default-features = false, features = ["http-proto", "trace", "http", "reqwest-client"] }
opentelemetry-semantic-conventions.workspace = true
tokio = { workspace = true, features = ["rt", "rt-multi-thread"] }
tracing.workspace = true

View File

@@ -1,5 +1,11 @@
//! Helper functions to set up OpenTelemetry tracing.
//!
//! This comes in two variants, depending on whether you have a Tokio runtime available.
//! If you do, call `init_tracing()`. It sets up the trace processor and exporter to use
//! the current tokio runtime. If you don't have a runtime available, or you don't want
//! to share the runtime with the tracing tasks, call `init_tracing_without_runtime()`
//! instead. It sets up a dedicated single-threaded Tokio runtime for the tracing tasks.
//!
//! Example:
//!
//! ```rust,no_run
@@ -15,8 +21,7 @@
//! .with_writer(std::io::stderr);
//!
//! // Initialize OpenTelemetry. Exports tracing spans as OpenTelemetry traces
//! let provider = tracing_utils::init_tracing("my_application", tracing_utils::ExportConfig::default());
//! let otlp_layer = provider.as_ref().map(tracing_utils::layer);
//! let otlp_layer = tracing_utils::init_tracing("my_application", tracing_utils::ExportConfig::default()).await;
//!
//! // Put it all together
//! tracing_subscriber::registry()
@@ -31,18 +36,16 @@
pub mod http;
pub mod perf_span;
use opentelemetry::KeyValue;
use opentelemetry::trace::TracerProvider;
use opentelemetry_otlp::WithExportConfig;
pub use opentelemetry_otlp::{ExportConfig, Protocol};
use opentelemetry_sdk::trace::SdkTracerProvider;
use tracing::level_filters::LevelFilter;
use tracing::{Dispatch, Subscriber};
use tracing_subscriber::Layer;
use tracing_subscriber::layer::SubscriberExt;
use tracing_subscriber::registry::LookupSpan;
pub type Provider = SdkTracerProvider;
/// Set up OpenTelemetry exporter, using configuration from environment variables.
///
/// `service_name` is set as the OpenTelemetry 'service.name' resource (see
@@ -67,7 +70,16 @@ pub type Provider = SdkTracerProvider;
/// If you need some other setting, please test if it works first. And perhaps
/// add a comment in the list above to save the effort of testing for the next
/// person.
pub fn init_tracing(service_name: &str, export_config: ExportConfig) -> Option<Provider> {
///
/// This doesn't block, but is marked as 'async' to hint that this must be called in
/// asynchronous execution context.
pub async fn init_tracing<S>(
service_name: &str,
export_config: ExportConfig,
) -> Option<impl Layer<S>>
where
S: Subscriber + for<'span> LookupSpan<'span>,
{
if std::env::var("OTEL_SDK_DISABLED") == Ok("true".to_string()) {
return None;
};
@@ -77,14 +89,52 @@ pub fn init_tracing(service_name: &str, export_config: ExportConfig) -> Option<P
))
}
pub fn layer<S>(p: &Provider) -> impl Layer<S>
/// Like `init_tracing`, but creates a separate tokio Runtime for the tracing
/// tasks.
pub fn init_tracing_without_runtime<S>(
service_name: &str,
export_config: ExportConfig,
) -> Option<impl Layer<S>>
where
S: Subscriber + for<'span> LookupSpan<'span>,
{
tracing_opentelemetry::layer().with_tracer(p.tracer("global"))
if std::env::var("OTEL_SDK_DISABLED") == Ok("true".to_string()) {
return None;
};
// The opentelemetry batch processor and the OTLP exporter needs a Tokio
// runtime. Create a dedicated runtime for them. One thread should be
// enough.
//
// (Alternatively, instead of batching, we could use the "simple
// processor", which doesn't need Tokio, and use "reqwest-blocking"
// feature for the OTLP exporter, which also doesn't need Tokio. However,
// batching is considered best practice, and also I have the feeling that
// the non-Tokio codepaths in the opentelemetry crate are less used and
// might be more buggy, so better to stay on the well-beaten path.)
//
// We leak the runtime so that it keeps running after we exit the
// function.
let runtime = Box::leak(Box::new(
tokio::runtime::Builder::new_multi_thread()
.enable_all()
.thread_name("otlp runtime thread")
.worker_threads(1)
.build()
.unwrap(),
));
let _guard = runtime.enter();
Some(init_tracing_internal(
service_name.to_string(),
export_config,
))
}
fn init_tracing_internal(service_name: String, export_config: ExportConfig) -> Provider {
fn init_tracing_internal<S>(service_name: String, export_config: ExportConfig) -> impl Layer<S>
where
S: Subscriber + for<'span> LookupSpan<'span>,
{
// Sets up exporter from the provided [`ExportConfig`] parameter.
// If the endpoint is not specified, it is loaded from the
// OTEL_EXPORTER_OTLP_ENDPOINT environment variable.
@@ -103,14 +153,22 @@ fn init_tracing_internal(service_name: String, export_config: ExportConfig) -> P
opentelemetry_sdk::propagation::TraceContextPropagator::new(),
);
Provider::builder()
.with_batch_exporter(exporter)
.with_resource(
opentelemetry_sdk::Resource::builder()
.with_service_name(service_name)
.build(),
)
let tracer = opentelemetry_sdk::trace::TracerProvider::builder()
.with_batch_exporter(exporter, opentelemetry_sdk::runtime::Tokio)
.with_resource(opentelemetry_sdk::Resource::new(vec![KeyValue::new(
opentelemetry_semantic_conventions::resource::SERVICE_NAME,
service_name,
)]))
.build()
.tracer("global");
tracing_opentelemetry::layer().with_tracer(tracer)
}
// Shutdown trace pipeline gracefully, so that it has a chance to send any
// pending traces before we exit.
pub fn shutdown_tracing() {
opentelemetry::global::shutdown_tracer_provider();
}
pub enum OtelEnablement {
@@ -118,17 +176,17 @@ pub enum OtelEnablement {
Enabled {
service_name: String,
export_config: ExportConfig,
runtime: &'static tokio::runtime::Runtime,
},
}
pub struct OtelGuard {
provider: Provider,
pub dispatch: Dispatch,
}
impl Drop for OtelGuard {
fn drop(&mut self) {
_ = self.provider.shutdown();
shutdown_tracing();
}
}
@@ -141,19 +199,22 @@ impl Drop for OtelGuard {
/// The lifetime of the guard should match taht of the application. On drop, it tears down the
/// OTEL infra.
pub fn init_performance_tracing(otel_enablement: OtelEnablement) -> Option<OtelGuard> {
match otel_enablement {
let otel_subscriber = match otel_enablement {
OtelEnablement::Disabled => None,
OtelEnablement::Enabled {
service_name,
export_config,
runtime,
} => {
let provider = init_tracing(&service_name, export_config)?;
let otel_layer = layer(&provider).with_filter(LevelFilter::INFO);
let otel_layer = runtime
.block_on(init_tracing(&service_name, export_config))
.with_filter(LevelFilter::INFO);
let otel_subscriber = tracing_subscriber::registry().with(otel_layer);
let dispatch = Dispatch::new(otel_subscriber);
let otel_dispatch = Dispatch::new(otel_subscriber);
Some(OtelGuard { dispatch, provider })
Some(otel_dispatch)
}
}
};
otel_subscriber.map(|dispatch| OtelGuard { dispatch })
}

View File

@@ -104,7 +104,7 @@ impl Id {
pub fn generate() -> Self {
let mut tli_buf = [0u8; 16];
rand::rng().fill(&mut tli_buf);
rand::thread_rng().fill(&mut tli_buf);
Id::from(tli_buf)
}

View File

@@ -364,37 +364,42 @@ impl MonotonicCounter<Lsn> for RecordLsn {
}
}
/// Implements [`rand::distr::uniform::UniformSampler`] so we can sample [`Lsn`]s.
/// Implements [`rand::distributions::uniform::UniformSampler`] so we can sample [`Lsn`]s.
///
/// This is used by the `pagebench` pageserver benchmarking tool.
pub struct LsnSampler(<u64 as rand::distr::uniform::SampleUniform>::Sampler);
pub struct LsnSampler(<u64 as rand::distributions::uniform::SampleUniform>::Sampler);
impl rand::distr::uniform::SampleUniform for Lsn {
impl rand::distributions::uniform::SampleUniform for Lsn {
type Sampler = LsnSampler;
}
impl rand::distr::uniform::UniformSampler for LsnSampler {
impl rand::distributions::uniform::UniformSampler for LsnSampler {
type X = Lsn;
fn new<B1, B2>(low: B1, high: B2) -> Result<Self, rand::distr::uniform::Error>
fn new<B1, B2>(low: B1, high: B2) -> Self
where
B1: rand::distr::uniform::SampleBorrow<Self::X> + Sized,
B2: rand::distr::uniform::SampleBorrow<Self::X> + Sized,
B1: rand::distributions::uniform::SampleBorrow<Self::X> + Sized,
B2: rand::distributions::uniform::SampleBorrow<Self::X> + Sized,
{
<u64 as rand::distr::uniform::SampleUniform>::Sampler::new(low.borrow().0, high.borrow().0)
.map(Self)
Self(
<u64 as rand::distributions::uniform::SampleUniform>::Sampler::new(
low.borrow().0,
high.borrow().0,
),
)
}
fn new_inclusive<B1, B2>(low: B1, high: B2) -> Result<Self, rand::distr::uniform::Error>
fn new_inclusive<B1, B2>(low: B1, high: B2) -> Self
where
B1: rand::distr::uniform::SampleBorrow<Self::X> + Sized,
B2: rand::distr::uniform::SampleBorrow<Self::X> + Sized,
B1: rand::distributions::uniform::SampleBorrow<Self::X> + Sized,
B2: rand::distributions::uniform::SampleBorrow<Self::X> + Sized,
{
<u64 as rand::distr::uniform::SampleUniform>::Sampler::new_inclusive(
low.borrow().0,
high.borrow().0,
Self(
<u64 as rand::distributions::uniform::SampleUniform>::Sampler::new_inclusive(
low.borrow().0,
high.borrow().0,
),
)
.map(Self)
}
fn sample<R: rand::prelude::Rng + ?Sized>(&self, rng: &mut R) -> Self::X {

View File

@@ -25,12 +25,6 @@ pub struct ShardIndex {
pub shard_count: ShardCount,
}
/// Stripe size as number of pages.
///
/// NB: don't implement Default, so callers don't lazily use it by mistake. See DEFAULT_STRIPE_SIZE.
#[derive(Clone, Copy, Serialize, Deserialize, Eq, PartialEq, Debug)]
pub struct ShardStripeSize(pub u32);
/// Formatting helper, for generating the `shard_id` label in traces.
pub struct ShardSlug<'a>(&'a TenantShardId);
@@ -59,10 +53,6 @@ impl ShardCount {
pub const MAX: Self = Self(u8::MAX);
pub const MIN: Self = Self(0);
pub fn unsharded() -> Self {
ShardCount(0)
}
/// The internal value of a ShardCount may be zero, which means "1 shard, but use
/// legacy format for TenantShardId that excludes the shard suffix", also known
/// as [`TenantShardId::unsharded`].
@@ -187,12 +177,6 @@ impl std::fmt::Display for ShardCount {
}
}
impl std::fmt::Display for ShardStripeSize {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
self.0.fmt(f)
}
}
impl std::fmt::Display for ShardSlug<'_> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(

View File

@@ -54,7 +54,6 @@ pageserver_api.workspace = true
pageserver_client.workspace = true # for ResponseErrorMessageExt TOOD refactor that
pageserver_compaction.workspace = true
pageserver_page_api.workspace = true
peekable.workspace = true
pem.workspace = true
pin-project-lite.workspace = true
postgres_backend.workspace = true
@@ -67,7 +66,6 @@ postgres-types.workspace = true
posthog_client_lite.workspace = true
pprof.workspace = true
pq_proto.workspace = true
prost.workspace = true
rand.workspace = true
range-set-blaze = { version = "0.1.16", features = ["alloc"] }
regex.workspace = true

View File

@@ -11,8 +11,7 @@ use pageserver::tenant::layer_map::LayerMap;
use pageserver::tenant::storage_layer::{LayerName, PersistentLayerDesc};
use pageserver_api::key::Key;
use pageserver_api::shard::TenantShardId;
use rand::prelude::{SeedableRng, StdRng};
use rand::seq::IndexedRandom;
use rand::prelude::{SeedableRng, SliceRandom, StdRng};
use utils::id::{TenantId, TimelineId};
use utils::lsn::Lsn;

View File

@@ -14,11 +14,12 @@ use utils::logging::warn_slow;
use crate::pool::{ChannelPool, ClientGuard, ClientPool, StreamGuard, StreamPool};
use crate::retry::Retry;
use crate::split::GetPageSplitter;
use compute_api::spec::PageserverProtocol;
use pageserver_api::shard::ShardStripeSize;
use pageserver_page_api as page_api;
use pageserver_page_api::GetPageSplitter;
use utils::id::{TenantId, TimelineId};
use utils::shard::{ShardCount, ShardIndex, ShardNumber, ShardStripeSize};
use utils::shard::{ShardCount, ShardIndex, ShardNumber};
/// Max number of concurrent clients per channel (i.e. TCP connection). New channels will be spun up
/// when full.
@@ -140,8 +141,8 @@ impl PageserverClient {
if !old.count.is_unsharded() && shard_spec.stripe_size != old.stripe_size {
return Err(anyhow!(
"can't change stripe size from {} to {}",
old.stripe_size.expect("always Some when sharded"),
shard_spec.stripe_size.expect("always Some when sharded")
old.stripe_size,
shard_spec.stripe_size
));
}
@@ -156,6 +157,23 @@ impl PageserverClient {
Ok(())
}
/// Returns whether a relation exists.
#[instrument(skip_all, fields(rel=%req.rel, lsn=%req.read_lsn))]
pub async fn check_rel_exists(
&self,
req: page_api::CheckRelExistsRequest,
) -> tonic::Result<page_api::CheckRelExistsResponse> {
debug!("sending request: {req:?}");
let resp = Self::with_retries(CALL_TIMEOUT, async |_| {
// Relation metadata is only available on shard 0.
let mut client = self.shards.load_full().get_zero().client().await?;
Self::with_timeout(REQUEST_TIMEOUT, client.check_rel_exists(req)).await
})
.await?;
debug!("received response: {resp:?}");
Ok(resp)
}
/// Returns the total size of a database, as # of bytes.
#[instrument(skip_all, fields(db_oid=%req.db_oid, lsn=%req.read_lsn))]
pub async fn get_db_size(
@@ -231,15 +249,13 @@ impl PageserverClient {
// Fast path: request is for a single shard.
if let Some(shard_id) =
GetPageSplitter::for_single_shard(&req, shards.count, shards.stripe_size)
.map_err(|err| tonic::Status::internal(err.to_string()))?
{
return Self::get_page_with_shard(req, shards.get(shard_id)?).await;
}
// Request spans multiple shards. Split it, dispatch concurrent per-shard requests, and
// reassemble the responses.
let mut splitter = GetPageSplitter::split(req, shards.count, shards.stripe_size)
.map_err(|err| tonic::Status::internal(err.to_string()))?;
let mut splitter = GetPageSplitter::split(req, shards.count, shards.stripe_size);
let mut shard_requests = FuturesUnordered::new();
for (shard_id, shard_req) in splitter.drain_requests() {
@@ -249,14 +265,10 @@ impl PageserverClient {
}
while let Some((shard_id, shard_response)) = shard_requests.next().await.transpose()? {
splitter
.add_response(shard_id, shard_response)
.map_err(|err| tonic::Status::internal(err.to_string()))?;
splitter.add_response(shard_id, shard_response)?;
}
splitter
.get_response()
.map_err(|err| tonic::Status::internal(err.to_string()))
splitter.get_response()
}
/// Fetches pages on the given shard. Does not retry internally.
@@ -384,14 +396,12 @@ pub struct ShardSpec {
/// NB: this is 0 for unsharded tenants, following `ShardIndex::unsharded()` convention.
count: ShardCount,
/// The stripe size for these shards.
///
/// INVARIANT: None for unsharded tenants, Some for sharded.
stripe_size: Option<ShardStripeSize>,
stripe_size: ShardStripeSize,
}
impl ShardSpec {
/// Creates a new shard spec with the given URLs and stripe size. All shards must be given.
/// The stripe size must be Some for sharded tenants, or None for unsharded tenants.
/// The stripe size may be omitted for unsharded tenants.
pub fn new(
urls: HashMap<ShardIndex, String>,
stripe_size: Option<ShardStripeSize>,
@@ -404,13 +414,11 @@ impl ShardSpec {
n => ShardCount::new(n as u8),
};
// Validate the stripe size.
// Determine the stripe size. It doesn't matter for unsharded tenants.
if stripe_size.is_none() && !count.is_unsharded() {
return Err(anyhow!("stripe size must be given for sharded tenants"));
}
if stripe_size.is_some() && count.is_unsharded() {
return Err(anyhow!("stripe size can't be given for unsharded tenants"));
}
let stripe_size = stripe_size.unwrap_or_default();
// Validate the shard spec.
for (shard_id, url) in &urls {
@@ -450,10 +458,8 @@ struct Shards {
///
/// NB: this is 0 for unsharded tenants, following `ShardIndex::unsharded()` convention.
count: ShardCount,
/// The stripe size.
///
/// INVARIANT: None for unsharded tenants, Some for sharded.
stripe_size: Option<ShardStripeSize>,
/// The stripe size. Only used for sharded tenants.
stripe_size: ShardStripeSize,
}
impl Shards {

View File

@@ -1,6 +1,6 @@
mod client;
mod pool;
mod retry;
mod split;
pub use client::{PageserverClient, ShardSpec};
pub use pageserver_api::shard::ShardStripeSize; // used in ShardSpec

View File

@@ -1,20 +1,19 @@
use std::collections::HashMap;
use anyhow::anyhow;
use bytes::Bytes;
use crate::model::*;
use pageserver_api::key::rel_block_to_key;
use pageserver_api::shard::key_to_shard_number;
use utils::shard::{ShardCount, ShardIndex, ShardStripeSize};
use pageserver_api::shard::{ShardStripeSize, key_to_shard_number};
use pageserver_page_api as page_api;
use utils::shard::{ShardCount, ShardIndex, ShardNumber};
/// Splits GetPageRequests that straddle shard boundaries and assembles the responses.
/// TODO: add tests for this.
pub struct GetPageSplitter {
/// Split requests by shard index.
requests: HashMap<ShardIndex, GetPageRequest>,
requests: HashMap<ShardIndex, page_api::GetPageRequest>,
/// The response being assembled. Preallocated with empty pages, to be filled in.
response: GetPageResponse,
response: page_api::GetPageResponse,
/// Maps the offset in `request.block_numbers` and `response.pages` to the owning shard. Used
/// to assemble the response pages in the same order as the original request.
block_shards: Vec<ShardIndex>,
@@ -24,56 +23,45 @@ impl GetPageSplitter {
/// Checks if the given request only touches a single shard, and returns the shard ID. This is
/// the common case, so we check first in order to avoid unnecessary allocations and overhead.
pub fn for_single_shard(
req: &GetPageRequest,
req: &page_api::GetPageRequest,
count: ShardCount,
stripe_size: Option<ShardStripeSize>,
) -> anyhow::Result<Option<ShardIndex>> {
stripe_size: ShardStripeSize,
) -> Option<ShardIndex> {
// Fast path: unsharded tenant.
if count.is_unsharded() {
return Ok(Some(ShardIndex::unsharded()));
return Some(ShardIndex::unsharded());
}
let Some(stripe_size) = stripe_size else {
return Err(anyhow!("stripe size must be given for sharded tenants"));
};
// Find the first page's shard, for comparison.
// Find the first page's shard, for comparison. If there are no pages, just return the first
// shard (caller likely checked already, otherwise the server will reject it).
let Some(&first_page) = req.block_numbers.first() else {
return Err(anyhow!("no block numbers in request"));
return Some(ShardIndex::new(ShardNumber(0), count));
};
let key = rel_block_to_key(req.rel, first_page);
let shard_number = key_to_shard_number(count, stripe_size, &key);
Ok(req
.block_numbers
req.block_numbers
.iter()
.skip(1) // computed above
.all(|&blkno| {
let key = rel_block_to_key(req.rel, blkno);
key_to_shard_number(count, stripe_size, &key) == shard_number
})
.then_some(ShardIndex::new(shard_number, count)))
.then_some(ShardIndex::new(shard_number, count))
}
/// Splits the given request.
pub fn split(
req: GetPageRequest,
req: page_api::GetPageRequest,
count: ShardCount,
stripe_size: Option<ShardStripeSize>,
) -> anyhow::Result<Self> {
stripe_size: ShardStripeSize,
) -> Self {
// The caller should make sure we don't split requests unnecessarily.
debug_assert!(
Self::for_single_shard(&req, count, stripe_size)?.is_none(),
Self::for_single_shard(&req, count, stripe_size).is_none(),
"unnecessary request split"
);
if count.is_unsharded() {
return Err(anyhow!("unsharded tenant, no point in splitting request"));
}
let Some(stripe_size) = stripe_size else {
return Err(anyhow!("stripe size must be given for sharded tenants"));
};
// Split the requests by shard index.
let mut requests = HashMap::with_capacity(2); // common case
let mut block_shards = Vec::with_capacity(req.block_numbers.len());
@@ -84,7 +72,7 @@ impl GetPageSplitter {
requests
.entry(shard_id)
.or_insert_with(|| GetPageRequest {
.or_insert_with(|| page_api::GetPageRequest {
request_id: req.request_id,
request_class: req.request_class,
rel: req.rel,
@@ -98,16 +86,16 @@ impl GetPageSplitter {
// Construct a response to be populated by shard responses. Preallocate empty page slots
// with the expected block numbers.
let response = GetPageResponse {
let response = page_api::GetPageResponse {
request_id: req.request_id,
status_code: GetPageStatusCode::Ok,
status_code: page_api::GetPageStatusCode::Ok,
reason: None,
rel: req.rel,
pages: req
.block_numbers
.into_iter()
.map(|block_number| {
Page {
page_api::Page {
block_number,
image: Bytes::new(), // empty page slot to be filled in
}
@@ -115,15 +103,17 @@ impl GetPageSplitter {
.collect(),
};
Ok(Self {
Self {
requests,
response,
block_shards,
})
}
}
/// Drains the per-shard requests, moving them out of the splitter to avoid extra allocations.
pub fn drain_requests(&mut self) -> impl Iterator<Item = (ShardIndex, GetPageRequest)> {
pub fn drain_requests(
&mut self,
) -> impl Iterator<Item = (ShardIndex, page_api::GetPageRequest)> {
self.requests.drain()
}
@@ -133,31 +123,22 @@ impl GetPageSplitter {
pub fn add_response(
&mut self,
shard_id: ShardIndex,
response: GetPageResponse,
) -> anyhow::Result<()> {
response: page_api::GetPageResponse,
) -> tonic::Result<()> {
// The caller should already have converted status codes into tonic::Status.
if response.status_code != GetPageStatusCode::Ok {
return Err(anyhow!(
if response.status_code != page_api::GetPageStatusCode::Ok {
return Err(tonic::Status::internal(format!(
"unexpected non-OK response for shard {shard_id}: {} {}",
response.status_code,
response.reason.unwrap_or_default()
));
)));
}
if response.request_id != self.response.request_id {
return Err(anyhow!(
return Err(tonic::Status::internal(format!(
"response ID mismatch for shard {shard_id}: expected {}, got {}",
self.response.request_id,
response.request_id
));
}
if response.request_id != self.response.request_id {
return Err(anyhow!(
"response ID mismatch for shard {shard_id}: expected {}, got {}",
self.response.request_id,
response.request_id
));
self.response.request_id, response.request_id
)));
}
// Place the shard response pages into the assembled response, in request order.
@@ -169,26 +150,27 @@ impl GetPageSplitter {
}
let Some(slot) = self.response.pages.get_mut(i) else {
return Err(anyhow!("no block_shards slot {i} for shard {shard_id}"));
return Err(tonic::Status::internal(format!(
"no block_shards slot {i} for shard {shard_id}"
)));
};
let Some(page) = pages.next() else {
return Err(anyhow!(
return Err(tonic::Status::internal(format!(
"missing page {} in shard {shard_id} response",
slot.block_number
));
)));
};
if page.block_number != slot.block_number {
return Err(anyhow!(
return Err(tonic::Status::internal(format!(
"shard {shard_id} returned wrong page at index {i}, expected {} got {}",
slot.block_number,
page.block_number
));
slot.block_number, page.block_number
)));
}
if !slot.image.is_empty() {
return Err(anyhow!(
return Err(tonic::Status::internal(format!(
"shard {shard_id} returned duplicate page {} at index {i}",
slot.block_number
));
)));
}
*slot = page;
@@ -196,10 +178,10 @@ impl GetPageSplitter {
// Make sure we've consumed all pages from the shard response.
if let Some(extra_page) = pages.next() {
return Err(anyhow!(
return Err(tonic::Status::internal(format!(
"shard {shard_id} returned extra page: {}",
extra_page.block_number
));
)));
}
Ok(())
@@ -207,18 +189,18 @@ impl GetPageSplitter {
/// Fetches the final, assembled response.
#[allow(clippy::result_large_err)]
pub fn get_response(self) -> anyhow::Result<GetPageResponse> {
pub fn get_response(self) -> tonic::Result<page_api::GetPageResponse> {
// Check that the response is complete.
for (i, page) in self.response.pages.iter().enumerate() {
if page.image.is_empty() {
return Err(anyhow!(
return Err(tonic::Status::internal(format!(
"missing page {} for shard {}",
page.block_number,
self.block_shards
.get(i)
.map(|s| s.to_string())
.unwrap_or_else(|| "?".to_string())
));
)));
}
}

View File

@@ -89,7 +89,7 @@ async fn simulate(cmd: &SimulateCmd, results_path: &Path) -> anyhow::Result<()>
let cold_key_range = splitpoint..key_range.end;
for i in 0..cmd.num_records {
let chosen_range = if rand::rng().random_bool(0.9) {
let chosen_range = if rand::thread_rng().gen_bool(0.9) {
&hot_key_range
} else {
&cold_key_range

View File

@@ -300,9 +300,9 @@ impl MockTimeline {
key_range: &Range<Key>,
) -> anyhow::Result<()> {
crate::helpers::union_to_keyspace(&mut self.keyspace, vec![key_range.clone()]);
let mut rng = rand::rng();
let mut rng = rand::thread_rng();
for _ in 0..num_records {
self.ingest_record(rng.random_range(key_range.clone()), len);
self.ingest_record(rng.gen_range(key_range.clone()), len);
self.wal_ingested += len;
}
Ok(())

View File

@@ -4,7 +4,7 @@ use anyhow::Context;
use clap::Parser;
use pageserver_api::key::Key;
use pageserver_api::reltag::{BlockNumber, RelTag, SlruKind};
use pageserver_api::shard::{DEFAULT_STRIPE_SIZE, ShardCount, ShardStripeSize};
use pageserver_api::shard::{ShardCount, ShardStripeSize};
#[derive(Parser)]
pub(super) struct DescribeKeyCommand {
@@ -128,9 +128,7 @@ impl DescribeKeyCommand {
// seeing the sharding placement might be confusing, so leave it out unless shard
// count was given.
let stripe_size = stripe_size
.map(ShardStripeSize)
.unwrap_or(DEFAULT_STRIPE_SIZE);
let stripe_size = stripe_size.map(ShardStripeSize).unwrap_or_default();
println!(
"# placement with shard_count: {} and stripe_size: {}:",
shard_count.0, stripe_size.0

View File

@@ -17,11 +17,11 @@
// grpcurl \
// -plaintext \
// -H "neon-tenant-id: 7c4a1f9e3bd6470c8f3e21a65bd2e980" \
// -H "neon-shard-id: 0000" \
// -H "neon-shard-id: 0b10" \
// -H "neon-timeline-id: f08c4e9a2d5f76b1e3a7c2d8910f4b3e" \
// -H "authorization: Bearer $JWT" \
// -d '{"read_lsn": {"request_lsn": 100000000, "not_modified_since_lsn": 1}, "db_oid": 1}' \
// localhost:51051 page_api.PageService/GetDbSize
// -d '{"read_lsn": {"request_lsn": 1234567890}, "rel": {"spc_oid": 1663, "db_oid": 1234, "rel_number": 5678, "fork_number": 0}}'
// localhost:51051 page_api.PageService/CheckRelExists
// ```
//
// TODO: consider adding neon-compute-mode ("primary", "static", "replica").
@@ -38,8 +38,8 @@ package page_api;
import "google/protobuf/timestamp.proto";
service PageService {
// NB: unlike libpq, there is no CheckRelExists in gRPC, at the compute team's request. Instead,
// use GetRelSize with allow_missing=true to check existence.
// Returns whether a relation exists.
rpc CheckRelExists(CheckRelExistsRequest) returns (CheckRelExistsResponse);
// Fetches a base backup.
rpc GetBaseBackup (GetBaseBackupRequest) returns (stream GetBaseBackupResponseChunk);
@@ -97,6 +97,17 @@ message RelTag {
uint32 fork_number = 4;
}
// Checks whether a relation exists, at the given LSN. Only valid on shard 0,
// other shards will error.
message CheckRelExistsRequest {
ReadLsn read_lsn = 1;
RelTag rel = 2;
}
message CheckRelExistsResponse {
bool exists = 1;
}
// Requests a base backup.
message GetBaseBackupRequest {
// The LSN to fetch the base backup at. 0 or absent means the latest LSN known to the Pageserver.
@@ -249,15 +260,10 @@ enum GetPageStatusCode {
message GetRelSizeRequest {
ReadLsn read_lsn = 1;
RelTag rel = 2;
// If true, return missing=true for missing relations instead of a NotFound error.
bool allow_missing = 3;
}
message GetRelSizeResponse {
// The number of blocks in the relation.
uint32 num_blocks = 1;
// If allow_missing=true, this is true for missing relations.
bool missing = 2;
}
// Requests an SLRU segment. Only valid on shard 0, other shards will error.

View File

@@ -69,6 +69,16 @@ impl Client {
Ok(Self { inner })
}
/// Returns whether a relation exists.
pub async fn check_rel_exists(
&mut self,
req: CheckRelExistsRequest,
) -> tonic::Result<CheckRelExistsResponse> {
let req = proto::CheckRelExistsRequest::from(req);
let resp = self.inner.check_rel_exists(req).await?.into_inner();
Ok(resp.into())
}
/// Fetches a base backup.
pub async fn get_base_backup(
&mut self,
@@ -104,8 +114,7 @@ impl Client {
Ok(resps.and_then(|resp| ready(GetPageResponse::try_from(resp).map_err(|err| err.into()))))
}
/// Returns the size of a relation as # of blocks, or None if allow_missing=true and the
/// relation does not exist.
/// Returns the size of a relation, as # of blocks.
pub async fn get_rel_size(
&mut self,
req: GetRelSizeRequest,

View File

@@ -19,9 +19,7 @@ pub mod proto {
}
mod client;
mod model;
mod split;
pub use client::Client;
mod model;
pub use model::*;
pub use split::GetPageSplitter;

View File

@@ -33,8 +33,6 @@ pub enum ProtocolError {
Invalid(&'static str, String),
#[error("required field '{0}' is missing")]
Missing(&'static str),
#[error("invalid combination of not_modified_lsn '{0}' and request_lsn '{1}'")]
InvalidLsns(Lsn, Lsn),
}
impl ProtocolError {
@@ -87,9 +85,9 @@ impl TryFrom<proto::ReadLsn> for ReadLsn {
return Err(ProtocolError::invalid("request_lsn", pb.request_lsn));
}
if pb.not_modified_since_lsn > pb.request_lsn {
return Err(ProtocolError::InvalidLsns(
Lsn(pb.not_modified_since_lsn),
Lsn(pb.request_lsn),
return Err(ProtocolError::invalid(
"not_modified_since_lsn",
pb.not_modified_since_lsn,
));
}
Ok(Self {
@@ -141,6 +139,50 @@ impl From<RelTag> for proto::RelTag {
}
}
/// Checks whether a relation exists, at the given LSN. Only valid on shard 0, other shards error.
#[derive(Clone, Copy, Debug)]
pub struct CheckRelExistsRequest {
pub read_lsn: ReadLsn,
pub rel: RelTag,
}
impl TryFrom<proto::CheckRelExistsRequest> for CheckRelExistsRequest {
type Error = ProtocolError;
fn try_from(pb: proto::CheckRelExistsRequest) -> Result<Self, Self::Error> {
Ok(Self {
read_lsn: pb
.read_lsn
.ok_or(ProtocolError::Missing("read_lsn"))?
.try_into()?,
rel: pb.rel.ok_or(ProtocolError::Missing("rel"))?.try_into()?,
})
}
}
impl From<CheckRelExistsRequest> for proto::CheckRelExistsRequest {
fn from(request: CheckRelExistsRequest) -> Self {
Self {
read_lsn: Some(request.read_lsn.into()),
rel: Some(request.rel.into()),
}
}
}
pub type CheckRelExistsResponse = bool;
impl From<proto::CheckRelExistsResponse> for CheckRelExistsResponse {
fn from(pb: proto::CheckRelExistsResponse) -> Self {
pb.exists
}
}
impl From<CheckRelExistsResponse> for proto::CheckRelExistsResponse {
fn from(exists: CheckRelExistsResponse) -> Self {
Self { exists }
}
}
/// Requests a base backup.
#[derive(Clone, Copy, Debug)]
pub struct GetBaseBackupRequest {
@@ -665,8 +707,6 @@ impl From<GetPageStatusCode> for tonic::Code {
pub struct GetRelSizeRequest {
pub read_lsn: ReadLsn,
pub rel: RelTag,
/// If true, return missing=true for missing relations instead of a NotFound error.
pub allow_missing: bool,
}
impl TryFrom<proto::GetRelSizeRequest> for GetRelSizeRequest {
@@ -679,7 +719,6 @@ impl TryFrom<proto::GetRelSizeRequest> for GetRelSizeRequest {
.ok_or(ProtocolError::Missing("read_lsn"))?
.try_into()?,
rel: proto.rel.ok_or(ProtocolError::Missing("rel"))?.try_into()?,
allow_missing: proto.allow_missing,
})
}
}
@@ -689,29 +728,21 @@ impl From<GetRelSizeRequest> for proto::GetRelSizeRequest {
Self {
read_lsn: Some(request.read_lsn.into()),
rel: Some(request.rel.into()),
allow_missing: request.allow_missing,
}
}
}
/// The size of a relation as number of blocks, or None if `allow_missing=true` and the relation
/// does not exist.
///
/// INVARIANT: never None if `allow_missing=false` (returns `NotFound` error instead).
pub type GetRelSizeResponse = Option<u32>;
pub type GetRelSizeResponse = u32;
impl From<proto::GetRelSizeResponse> for GetRelSizeResponse {
fn from(pb: proto::GetRelSizeResponse) -> Self {
(!pb.missing).then_some(pb.num_blocks)
fn from(proto: proto::GetRelSizeResponse) -> Self {
proto.num_blocks
}
}
impl From<GetRelSizeResponse> for proto::GetRelSizeResponse {
fn from(resp: GetRelSizeResponse) -> Self {
Self {
num_blocks: resp.unwrap_or_default(),
missing: resp.is_none(),
}
fn from(num_blocks: GetRelSizeResponse) -> Self {
Self { num_blocks }
}
}

View File

@@ -25,9 +25,6 @@ tracing.workspace = true
tokio.workspace = true
tokio-stream.workspace = true
tokio-util.workspace = true
axum.workspace = true
http.workspace = true
metrics.workspace = true
tonic.workspace = true
url.workspace = true

View File

@@ -188,9 +188,9 @@ async fn main_impl(
start_work_barrier.wait().await;
loop {
let (timeline, work) = {
let mut rng = rand::rng();
let mut rng = rand::thread_rng();
let target = all_targets.choose(&mut rng).unwrap();
let lsn = target.lsn_range.clone().map(|r| rng.random_range(r));
let lsn = target.lsn_range.clone().map(|r| rng.gen_range(r));
(target.timeline, Work { lsn })
};
let sender = work_senders.get(&timeline).unwrap();

View File

@@ -34,10 +34,6 @@ use crate::util::{request_stats, tokio_thread_local_stats};
/// GetPage@LatestLSN, uniformly distributed across the compute-accessible keyspace.
#[derive(clap::Parser)]
pub(crate) struct Args {
#[clap(long, default_value = "false")]
grpc: bool,
#[clap(long, default_value = "false")]
grpc_stream: bool,
#[clap(long, default_value = "http://localhost:9898")]
mgmt_api_endpoint: String,
/// Pageserver connection string. Supports postgresql:// and grpc:// protocols.
@@ -82,9 +78,6 @@ pub(crate) struct Args {
#[clap(long)]
set_io_mode: Option<pageserver_api::models::virtual_file::IoMode>,
#[clap(long)]
only_relnode: Option<u32>,
/// Queue depth generated in each client.
#[clap(long, default_value = "1")]
queue_depth: NonZeroUsize,
@@ -99,31 +92,10 @@ pub(crate) struct Args {
#[clap(long, default_value = "1")]
batch_size: NonZeroUsize,
#[clap(long)]
only_relnode: Option<u32>,
targets: Option<Vec<TenantTimelineId>>,
#[clap(long, default_value = "100")]
pool_max_consumers: NonZeroUsize,
#[clap(long, default_value = "5")]
pool_error_threshold: NonZeroUsize,
#[clap(long, default_value = "5000")]
pool_connect_timeout: NonZeroUsize,
#[clap(long, default_value = "1000")]
pool_connect_backoff: NonZeroUsize,
#[clap(long, default_value = "60000")]
pool_max_idle_duration: NonZeroUsize,
#[clap(long, default_value = "0")]
max_delay_ms: usize,
#[clap(long, default_value = "0")]
percent_drops: usize,
#[clap(long, default_value = "0")]
percent_hangs: usize,
}
/// State shared by all clients
@@ -180,6 +152,7 @@ pub(crate) fn main(args: Args) -> anyhow::Result<()> {
main_impl(args, thread_local_stats)
})
}
async fn main_impl(
args: Args,
all_thread_local_stats: AllThreadLocalStats<request_stats::Stats>,
@@ -344,7 +317,6 @@ async fn main_impl(
let rps_period = args
.per_client_rate
.map(|rps_limit| Duration::from_secs_f64(1.0 / (rps_limit as f64)));
let make_worker: &dyn Fn(WorkerId) -> Pin<Box<dyn Send + Future<Output = ()>>> = &|worker_id| {
let ss = shared_state.clone();
let cancel = cancel.clone();
@@ -354,7 +326,8 @@ async fn main_impl(
.cloned()
.collect();
let weights =
rand::distr::weighted::WeightedIndex::new(ranges.iter().map(|v| v.len())).unwrap();
rand::distributions::weighted::WeightedIndex::new(ranges.iter().map(|v| v.len()))
.unwrap();
Box::pin(async move {
let scheme = match Url::parse(&args.page_service_connstring) {
@@ -454,7 +427,7 @@ async fn run_worker(
cancel: CancellationToken,
rps_period: Option<Duration>,
ranges: Vec<KeyRange>,
weights: rand::distr::weighted::WeightedIndex<i128>,
weights: rand::distributions::weighted::WeightedIndex<i128>,
) {
shared_state.start_work_barrier.wait().await;
let client_start = Instant::now();
@@ -496,9 +469,9 @@ async fn run_worker(
}
// Pick a random page from a random relation.
let mut rng = rand::rng();
let mut rng = rand::thread_rng();
let r = &ranges[weights.sample(&mut rng)];
let key: i128 = rng.random_range(r.start..r.end);
let key: i128 = rng.gen_range(r.start..r.end);
let (rel_tag, block_no) = key_to_block(key);
let mut blks = VecDeque::with_capacity(batch_size);
@@ -529,7 +502,7 @@ async fn run_worker(
// We assume that the entire batch can fit within the relation.
assert_eq!(blks.len(), batch_size, "incomplete batch");
let req_lsn = if rng.random_bool(args.req_latest_probability) {
let req_lsn = if rng.gen_bool(args.req_latest_probability) {
Lsn::MAX
} else {
r.timeline_lsn

View File

@@ -7,7 +7,7 @@ use std::time::{Duration, Instant};
use pageserver_api::models::HistoricLayerInfo;
use pageserver_api::shard::TenantShardId;
use pageserver_client::mgmt_api;
use rand::seq::IndexedMutRandom;
use rand::seq::SliceRandom;
use tokio::sync::{OwnedSemaphorePermit, mpsc};
use tokio::task::JoinSet;
use tokio_util::sync::CancellationToken;
@@ -260,7 +260,7 @@ async fn timeline_actor(
loop {
let layer_tx = {
let mut rng = rand::rng();
let mut rng = rand::thread_rng();
timeline.layers.choose_mut(&mut rng).expect("no layers")
};
match layer_tx.try_send(permit.take().unwrap()) {

View File

@@ -126,6 +126,7 @@ fn main() -> anyhow::Result<()> {
Some(cfg) => tracing_utils::OtelEnablement::Enabled {
service_name: "pageserver".to_string(),
export_config: (&cfg.export_config).into(),
runtime: *COMPUTE_REQUEST_RUNTIME,
},
None => tracing_utils::OtelEnablement::Disabled,
};

View File

@@ -1,5 +1,4 @@
use std::collections::HashMap;
use std::net::IpAddr;
use futures::Future;
use pageserver_api::config::NodeMetadata;
@@ -17,7 +16,7 @@ use tokio_util::sync::CancellationToken;
use url::Url;
use utils::generation::Generation;
use utils::id::{NodeId, TimelineId};
use utils::{backoff, failpoint_support, ip_address};
use utils::{backoff, failpoint_support};
use crate::config::PageServerConf;
use crate::virtual_file::on_fatal_io_error;
@@ -28,7 +27,6 @@ pub struct StorageControllerUpcallClient {
http_client: reqwest::Client,
base_url: Url,
node_id: NodeId,
node_ip_addr: Option<IpAddr>,
cancel: CancellationToken,
}
@@ -42,7 +40,6 @@ pub trait StorageControllerUpcallApi {
fn re_attach(
&self,
conf: &PageServerConf,
empty_local_disk: bool,
) -> impl Future<
Output = Result<HashMap<TenantShardId, ReAttachResponseTenant>, RetryForeverError>,
> + Send;
@@ -94,18 +91,11 @@ impl StorageControllerUpcallClient {
);
}
// Intentionally panics if we encountered any errors parsing or reading the IP address.
// Note that if the required environment variable is not set, `read_node_ip_addr_from_env` returns `Ok(None)`
// instead of an error.
let node_ip_addr =
ip_address::read_node_ip_addr_from_env().expect("Error reading node IP address.");
Self {
http_client: client.build().expect("Failed to construct HTTP client"),
base_url: url,
node_id: conf.id,
cancel: cancel.clone(),
node_ip_addr,
}
}
@@ -156,7 +146,6 @@ impl StorageControllerUpcallApi for StorageControllerUpcallClient {
async fn re_attach(
&self,
conf: &PageServerConf,
empty_local_disk: bool,
) -> Result<HashMap<TenantShardId, ReAttachResponseTenant>, RetryForeverError> {
let url = self
.base_url
@@ -204,8 +193,8 @@ impl StorageControllerUpcallApi for StorageControllerUpcallClient {
listen_http_addr: m.http_host,
listen_http_port: m.http_port,
listen_https_port: m.https_port,
node_ip_addr: self.node_ip_addr,
availability_zone_id: az_id.expect("Checked above"),
node_ip_addr: None,
})
}
Err(e) => {
@@ -228,7 +217,6 @@ impl StorageControllerUpcallApi for StorageControllerUpcallClient {
let request = ReAttachRequest {
node_id: self.node_id,
register: register.clone(),
empty_local_disk: Some(empty_local_disk),
};
let response: ReAttachResponse = self

View File

@@ -768,7 +768,6 @@ mod test {
async fn re_attach(
&self,
_conf: &PageServerConf,
_empty_local_disk: bool,
) -> Result<HashMap<TenantShardId, ReAttachResponseTenant>, RetryForeverError> {
unimplemented!()
}

View File

@@ -155,7 +155,7 @@ impl FeatureResolver {
);
let tenant_properties = PerTenantProperties {
remote_size_mb: Some(rand::rng().random_range(100.0..1000000.00)),
remote_size_mb: Some(rand::thread_rng().gen_range(100.0..1000000.00)),
}
.into_posthog_properties();

View File

@@ -16,8 +16,7 @@ use anyhow::{Context as _, bail};
use bytes::{Buf as _, BufMut as _, BytesMut};
use chrono::Utc;
use futures::future::BoxFuture;
use futures::stream::FuturesUnordered;
use futures::{FutureExt, Stream, StreamExt as _};
use futures::{FutureExt, Stream};
use itertools::Itertools;
use jsonwebtoken::TokenData;
use once_cell::sync::OnceCell;
@@ -36,8 +35,8 @@ use pageserver_api::pagestream_api::{
};
use pageserver_api::reltag::SlruKind;
use pageserver_api::shard::TenantShardId;
use pageserver_page_api as page_api;
use pageserver_page_api::proto;
use pageserver_page_api::{self as page_api, GetPageSplitter};
use postgres_backend::{
AuthType, PostgresBackend, PostgresBackendReader, QueryError, is_expected_io_error,
};
@@ -444,7 +443,6 @@ impl TimelineHandles {
handles: Default::default(),
}
}
async fn get(
&mut self,
tenant_id: TenantId,
@@ -471,13 +469,6 @@ impl TimelineHandles {
fn tenant_id(&self) -> Option<TenantId> {
self.wrapper.tenant_id.get().copied()
}
/// Returns whether a child shard exists locally for the given shard.
fn has_child_shard(&self, tenant_id: TenantId, shard_index: ShardIndex) -> bool {
self.wrapper
.tenant_manager
.has_child_shard(tenant_id, shard_index)
}
}
pub(crate) struct TenantManagerWrapper {
@@ -1645,10 +1636,9 @@ impl PageServerHandler {
let (shard, ctx) = upgrade_handle_and_set_context!(shard);
(
vec![
Self::handle_get_nblocks_request(&shard, &req, false, &ctx)
Self::handle_get_nblocks_request(&shard, &req, &ctx)
.instrument(span.clone())
.await
.map(|msg| msg.expect("allow_missing=false"))
.map(|msg| (PagestreamBeMessage::Nblocks(msg), timer, ctx))
.map_err(|err| BatchedPageStreamError { err, req: req.hdr }),
],
@@ -2313,16 +2303,12 @@ impl PageServerHandler {
Ok(PagestreamExistsResponse { req: *req, exists })
}
/// If `allow_missing` is true, returns None instead of Err on missing relations. Otherwise,
/// never returns None. It is only supported by the gRPC protocol, so we pass it separately to
/// avoid changing the libpq protocol types.
#[instrument(skip_all, fields(shard_id))]
async fn handle_get_nblocks_request(
timeline: &Timeline,
req: &PagestreamNblocksRequest,
allow_missing: bool,
ctx: &RequestContext,
) -> Result<Option<PagestreamNblocksResponse>, PageStreamError> {
) -> Result<PagestreamNblocksResponse, PageStreamError> {
let latest_gc_cutoff_lsn = timeline.get_applied_gc_cutoff_lsn();
let lsn = Self::wait_or_get_last_lsn(
timeline,
@@ -2334,25 +2320,20 @@ impl PageServerHandler {
.await?;
let n_blocks = timeline
.get_rel_size_in_reldir(
.get_rel_size(
req.rel,
Version::LsnRange(LsnRange {
effective_lsn: lsn,
request_lsn: req.hdr.request_lsn,
}),
None,
allow_missing,
ctx,
)
.await?;
let Some(n_blocks) = n_blocks else {
return Ok(None);
};
Ok(Some(PagestreamNblocksResponse {
Ok(PagestreamNblocksResponse {
req: *req,
n_blocks,
}))
})
}
#[instrument(skip_all, fields(shard_id))]
@@ -3237,25 +3218,13 @@ where
pub struct GrpcPageServiceHandler {
tenant_manager: Arc<TenantManager>,
ctx: RequestContext,
/// Cancelled to shut down the server. Tonic will shut down in response to this, but wait for
/// in-flight requests to complete. Any tasks we spawn ourselves must respect this token.
cancel: CancellationToken,
/// Any tasks we spawn ourselves should clone this gate guard, so that we can wait for them to
/// complete during shutdown. Request handlers implicitly hold this guard already.
gate_guard: GateGuard,
/// `get_vectored` concurrency setting.
get_vectored_concurrent_io: GetVectoredConcurrentIo,
}
impl GrpcPageServiceHandler {
/// Spawns a gRPC server for the page service.
///
/// Returns a `CancellableTask` handle that can be used to shut down the server. It waits for
/// any in-flight requests and tasks to complete first.
///
/// TODO: this doesn't support TLS. We need TLS reloading via ReloadingCertificateResolver, so we
/// need to reimplement the TCP+TLS accept loop ourselves.
pub fn spawn(
@@ -3265,15 +3234,12 @@ impl GrpcPageServiceHandler {
get_vectored_concurrent_io: GetVectoredConcurrentIo,
listener: std::net::TcpListener,
) -> anyhow::Result<CancellableTask> {
// Set up a cancellation token for shutting down the server, and a gate to wait for all
// requests and spawned tasks to complete.
let cancel = CancellationToken::new();
let gate = Gate::default();
let ctx = RequestContextBuilder::new(TaskKind::PageRequestHandler)
.download_behavior(DownloadBehavior::Download)
.perf_span_dispatch(perf_trace_dispatch)
.detached_child();
let gate = Gate::default();
// Set up the TCP socket. We take a preconfigured TcpListener to bind the
// port early during startup.
@@ -3304,7 +3270,6 @@ impl GrpcPageServiceHandler {
let page_service_handler = GrpcPageServiceHandler {
tenant_manager,
ctx,
cancel: cancel.clone(),
gate_guard: gate.enter().expect("gate was just created"),
get_vectored_concurrent_io,
};
@@ -3341,20 +3306,19 @@ impl GrpcPageServiceHandler {
.build_v1()?;
let server = server.add_service(reflection_service);
// Spawn server task. It runs until the cancellation token fires and in-flight requests and
// tasks complete. The `CancellableTask` will wait for the task's join handle, which
// implicitly waits for the gate to close.
// Spawn server task.
let task_cancel = cancel.clone();
let task = COMPUTE_REQUEST_RUNTIME.spawn(task_mgr::exit_on_panic_or_error(
"grpc pageservice listener",
"grpc listener",
async move {
server
let result = server
.serve_with_incoming_shutdown(incoming, task_cancel.cancelled())
.await?;
// Server exited cleanly. All requests should have completed by now. Wait for any
// spawned tasks to complete as well (e.g. IoConcurrency sidecars) via the gate.
gate.close().await;
anyhow::Ok(())
.await;
if result.is_ok() {
// TODO: revisit shutdown logic once page service is implemented.
gate.close().await;
}
result
},
));
@@ -3387,9 +3351,17 @@ impl GrpcPageServiceHandler {
}
}
/// Acquires a timeline handle for the given request. The shard index must match a local shard.
/// Acquires a timeline handle for the given request.
///
/// NB: this will fail during shard splits, see comment on [`Self::maybe_split_get_page`].
/// TODO: during shard splits, the compute may still be sending requests to the parent shard
/// until the entire split is committed and the compute is notified. Consider installing a
/// temporary shard router from the parent to the children while the split is in progress.
///
/// TODO: consider moving this to a middleware layer; all requests need it. Needs to manage
/// the TimelineHandles lifecycle.
///
/// TODO: untangle acquisition from TenantManagerWrapper::resolve() and Cache::get(), to avoid
/// the unnecessary overhead.
async fn get_request_timeline(
&self,
req: &tonic::Request<impl Any>,
@@ -3398,62 +3370,11 @@ impl GrpcPageServiceHandler {
let shard_index = *extract::<ShardIndex>(req);
let shard_selector = ShardSelector::Known(shard_index);
// TODO: untangle acquisition from TenantManagerWrapper::resolve() and Cache::get(), to
// avoid the unnecessary overhead.
TimelineHandles::new(self.tenant_manager.clone())
.get(ttid.tenant_id, ttid.timeline_id, shard_selector)
.await
}
/// Acquires a timeline handle for the given request, which must be for shard zero.
///
/// NB: during an ongoing shard split, the compute will keep talking to the parent shard until
/// the split is committed, but the parent shard may have been removed in the meanwhile. In that
/// case, we reroute the request to the new child shard. See [`Self::maybe_split_get_page`].
///
/// TODO: revamp the split protocol to avoid this child routing.
async fn get_shard_zero_request_timeline(
&self,
req: &tonic::Request<impl Any>,
) -> Result<Handle<TenantManagerTypes>, tonic::Status> {
let ttid = *extract::<TenantTimelineId>(req);
let shard_index = *extract::<ShardIndex>(req);
if shard_index.shard_number.0 != 0 {
return Err(tonic::Status::invalid_argument(format!(
"request must use shard zero (requested shard {shard_index})",
)));
}
// TODO: untangle acquisition from TenantManagerWrapper::resolve() and Cache::get(), to
// avoid the unnecessary overhead.
//
// TODO: this does internal retries, which will delay requests during shard splits (we won't
// look for the child until the parent's retries are exhausted). Don't do that.
let mut handles = TimelineHandles::new(self.tenant_manager.clone());
match handles
.get(
ttid.tenant_id,
ttid.timeline_id,
ShardSelector::Known(shard_index),
)
.await
{
Ok(timeline) => Ok(timeline),
Err(err) => {
// We may be in the middle of a shard split. Try to find a child shard 0.
if let Ok(timeline) = handles
.get(ttid.tenant_id, ttid.timeline_id, ShardSelector::Zero)
.await
&& timeline.get_shard_index().shard_count > shard_index.shard_count
{
return Ok(timeline);
}
Err(err.into())
}
}
}
/// Starts a SmgrOpTimer at received_at, throttles the request, and records execution start.
/// Only errors if the timeline is shutting down.
///
@@ -3485,22 +3406,28 @@ impl GrpcPageServiceHandler {
/// TODO: get_vectored() currently enforces a batch limit of 32. Postgres will typically send
/// batches up to effective_io_concurrency = 100. Either we have to accept large batches, or
/// split them up in the client or server.
#[instrument(skip_all, fields(
req_id = %req.request_id,
rel = %req.rel,
blkno = %req.block_numbers[0],
blks = %req.block_numbers.len(),
lsn = %req.read_lsn,
))]
#[instrument(skip_all, fields(req_id, rel, blkno, blks, req_lsn, mod_lsn))]
async fn get_page(
ctx: &RequestContext,
timeline: Handle<TenantManagerTypes>,
req: page_api::GetPageRequest,
timeline: &WeakHandle<TenantManagerTypes>,
req: proto::GetPageRequest,
io_concurrency: IoConcurrency,
received_at: Instant,
) -> Result<page_api::GetPageResponse, tonic::Status> {
) -> Result<proto::GetPageResponse, tonic::Status> {
let received_at = Instant::now();
let timeline = timeline.upgrade()?;
let ctx = ctx.with_scope_page_service_pagestream(&timeline);
// Validate the request, decorate the span, and convert it to a Pagestream request.
let req = page_api::GetPageRequest::try_from(req)?;
span_record!(
req_id = %req.request_id,
rel = %req.rel,
blkno = %req.block_numbers[0],
blks = %req.block_numbers.len(),
lsn = %req.read_lsn,
);
let latest_gc_cutoff_lsn = timeline.get_applied_gc_cutoff_lsn(); // hold guard
let effective_lsn = PageServerHandler::effective_request_lsn(
&timeline,
@@ -3575,105 +3502,13 @@ impl GrpcPageServiceHandler {
};
}
Ok(resp)
}
/// Processes a GetPage request when there is a potential shard split in progress. We have to
/// reroute the request any local child shards, and split batch requests that straddle multiple
/// child shards.
///
/// Parent shards are split and removed incrementally, but the compute is only notified once the
/// entire split commits, which can take several minutes. In the meanwhile, the compute will be
/// sending requests to the parent shard.
///
/// TODO: add test infrastructure to provoke this situation frequently and for long periods of
/// time, to properly exercise it.
///
/// TODO: revamp the split protocol to avoid this, e.g.:
/// * Keep the parent shard until the split commits and the compute is notified.
/// * Notify the compute about each subsplit.
/// * Return an error that updates the compute's shard map.
#[instrument(skip_all)]
async fn maybe_split_get_page(
ctx: &RequestContext,
handles: &mut TimelineHandles,
ttid: TenantTimelineId,
parent: ShardIndex,
req: page_api::GetPageRequest,
io_concurrency: IoConcurrency,
received_at: Instant,
) -> Result<page_api::GetPageResponse, tonic::Status> {
// Check the first page to see if we have any child shards at all. Otherwise, the compute is
// just talking to the wrong Pageserver. If the parent has been split, the shard now owning
// the page must have a higher shard count.
let timeline = handles
.get(
ttid.tenant_id,
ttid.timeline_id,
ShardSelector::Page(rel_block_to_key(req.rel, req.block_numbers[0])),
)
.await?;
let shard_id = timeline.get_shard_identity();
if shard_id.count <= parent.shard_count {
return Err(HandleUpgradeError::ShutDown.into()); // emulate original error
}
// Fast path: the request fits in a single shard.
if let Some(shard_index) =
GetPageSplitter::for_single_shard(&req, shard_id.count, Some(shard_id.stripe_size))
.map_err(|err| tonic::Status::internal(err.to_string()))?
{
// We got the shard ID from the first page, so these must be equal.
assert_eq!(shard_index.shard_number, shard_id.number);
assert_eq!(shard_index.shard_count, shard_id.count);
return Self::get_page(ctx, timeline, req, io_concurrency, received_at).await;
}
// The request spans multiple shards; split it and dispatch parallel requests. All pages
// were originally in the parent shard, and during a split all children are local, so we
// expect to find local shards for all pages.
let mut splitter = GetPageSplitter::split(req, shard_id.count, Some(shard_id.stripe_size))
.map_err(|err| tonic::Status::internal(err.to_string()))?;
let mut shard_requests = FuturesUnordered::new();
for (shard_index, shard_req) in splitter.drain_requests() {
let timeline = handles
.get(
ttid.tenant_id,
ttid.timeline_id,
ShardSelector::Known(shard_index),
)
.await?;
let future = Self::get_page(
ctx,
timeline,
shard_req,
io_concurrency.clone(),
received_at,
)
.map(move |result| result.map(|resp| (shard_index, resp)));
shard_requests.push(future);
}
while let Some((shard_index, shard_response)) = shard_requests.next().await.transpose()? {
splitter
.add_response(shard_index, shard_response)
.map_err(|err| tonic::Status::internal(err.to_string()))?;
}
splitter
.get_response()
.map_err(|err| tonic::Status::internal(err.to_string()))
Ok(resp.into())
}
}
/// Implements the gRPC page service.
///
/// On client disconnect (e.g. timeout or client shutdown), Tonic will drop the request handler
/// futures, so the read path must be cancellation-safe. On server shutdown, Tonic will wait for
/// in-flight requests to complete.
///
/// TODO: cancellation.
/// TODO: when the libpq impl is removed, remove the Pagestream types and inline the handler code.
#[tonic::async_trait]
impl proto::PageService for GrpcPageServiceHandler {
@@ -3684,6 +3519,39 @@ impl proto::PageService for GrpcPageServiceHandler {
type GetPagesStream =
Pin<Box<dyn Stream<Item = Result<proto::GetPageResponse, tonic::Status>> + Send>>;
#[instrument(skip_all, fields(rel, lsn))]
async fn check_rel_exists(
&self,
req: tonic::Request<proto::CheckRelExistsRequest>,
) -> Result<tonic::Response<proto::CheckRelExistsResponse>, tonic::Status> {
let received_at = extract::<ReceivedAt>(&req).0;
let timeline = self.get_request_timeline(&req).await?;
let ctx = self.ctx.with_scope_page_service_pagestream(&timeline);
// Validate the request, decorate the span, and convert it to a Pagestream request.
Self::ensure_shard_zero(&timeline)?;
let req: page_api::CheckRelExistsRequest = req.into_inner().try_into()?;
span_record!(rel=%req.rel, lsn=%req.read_lsn);
let req = PagestreamExistsRequest {
hdr: Self::make_hdr(req.read_lsn, None),
rel: req.rel,
};
// Execute the request and convert the response.
let _timer = Self::record_op_start_and_throttle(
&timeline,
metrics::SmgrQueryType::GetRelExists,
received_at,
)
.await?;
let resp = PageServerHandler::handle_get_rel_exists_request(&timeline, &req, &ctx).await?;
let resp: page_api::CheckRelExistsResponse = resp.exists;
Ok(tonic::Response::new(resp.into()))
}
#[instrument(skip_all, fields(lsn))]
async fn get_base_backup(
&self,
@@ -3693,10 +3561,11 @@ impl proto::PageService for GrpcPageServiceHandler {
// to be the sweet spot where throughput is saturated.
const CHUNK_SIZE: usize = 256 * 1024;
let timeline = self.get_shard_zero_request_timeline(&req).await?;
let timeline = self.get_request_timeline(&req).await?;
let ctx = self.ctx.with_scope_timeline(&timeline);
// Validate the request and decorate the span.
Self::ensure_shard_zero(&timeline)?;
if timeline.is_archived() == Some(true) {
return Err(tonic::Status::failed_precondition("timeline is archived"));
}
@@ -3724,14 +3593,8 @@ impl proto::PageService for GrpcPageServiceHandler {
// Spawn a task to run the basebackup.
let span = Span::current();
let gate_guard = self
.gate_guard
.try_clone()
.map_err(|_| tonic::Status::unavailable("shutting down"))?;
let (mut simplex_read, mut simplex_write) = tokio::io::simplex(CHUNK_SIZE);
let jh = tokio::spawn(async move {
let _gate_guard = gate_guard; // keep gate open until task completes
let gzip_level = match req.compression {
page_api::BaseBackupCompression::None => None,
// NB: using fast compression because it's on the critical path for compute
@@ -3812,10 +3675,11 @@ impl proto::PageService for GrpcPageServiceHandler {
req: tonic::Request<proto::GetDbSizeRequest>,
) -> Result<tonic::Response<proto::GetDbSizeResponse>, tonic::Status> {
let received_at = extract::<ReceivedAt>(&req).0;
let timeline = self.get_shard_zero_request_timeline(&req).await?;
let timeline = self.get_request_timeline(&req).await?;
let ctx = self.ctx.with_scope_page_service_pagestream(&timeline);
// Validate the request, decorate the span, and convert it to a Pagestream request.
Self::ensure_shard_zero(&timeline)?;
let req: page_api::GetDbSizeRequest = req.into_inner().try_into()?;
span_record!(db_oid=%req.db_oid, lsn=%req.read_lsn);
@@ -3844,101 +3708,43 @@ impl proto::PageService for GrpcPageServiceHandler {
req: tonic::Request<tonic::Streaming<proto::GetPageRequest>>,
) -> Result<tonic::Response<Self::GetPagesStream>, tonic::Status> {
// Extract the timeline from the request and check that it exists.
//
// NB: during shard splits, the compute may still send requests to the parent shard. We'll
// reroute requests to the child shards below, but we also detect the common cases here
// where either the shard exists or no shards exist at all. If we have a child shard, we
// can't acquire a weak handle because we don't know which child shard to use yet.
//
// TODO: TimelineHandles.get() does internal retries, which will delay requests during shard
// splits. It shouldn't.
let ttid = *extract::<TenantTimelineId>(&req);
let shard_index = *extract::<ShardIndex>(&req);
let shard_selector = ShardSelector::Known(shard_index);
let mut handles = TimelineHandles::new(self.tenant_manager.clone());
let timeline = match handles
.get(
ttid.tenant_id,
ttid.timeline_id,
ShardSelector::Known(shard_index),
)
.await
{
// The timeline shard exists. Keep a weak handle to reuse for each request.
Ok(timeline) => Some(timeline.downgrade()),
// The shard doesn't exist, but a child shard does. We'll reroute requests later.
Err(_) if handles.has_child_shard(ttid.tenant_id, shard_index) => None,
// Failed to fetch the timeline, and no child shard exists. Error out.
Err(err) => return Err(err.into()),
};
handles
.get(ttid.tenant_id, ttid.timeline_id, shard_selector)
.await?;
// Spawn an IoConcurrency sidecar, if enabled.
let gate_guard = self
.gate_guard
.try_clone()
.map_err(|_| tonic::Status::unavailable("shutting down"))?;
let Ok(gate_guard) = self.gate_guard.try_clone() else {
return Err(tonic::Status::unavailable("shutting down"));
};
let io_concurrency =
IoConcurrency::spawn_from_conf(self.get_vectored_concurrent_io, gate_guard);
// Construct the GetPageRequest stream handler.
// Spawn a task to handle the GetPageRequest stream.
let span = Span::current();
let ctx = self.ctx.attached_child();
let cancel = self.cancel.clone();
let mut reqs = req.into_inner();
let resps = async_stream::try_stream! {
loop {
// Wait for the next client request.
//
// NB: Tonic considers the entire stream to be an in-flight request and will wait
// for it to complete before shutting down. React to cancellation between requests.
let req = tokio::select! {
biased;
_ = cancel.cancelled() => Err(tonic::Status::unavailable("shutting down")),
result = reqs.message() => match result {
Ok(Some(req)) => Ok(req),
Ok(None) => break, // client closed the stream
Err(err) => Err(err),
},
}?;
let received_at = Instant::now();
let timeline = handles
.get(ttid.tenant_id, ttid.timeline_id, shard_selector)
.await?
.downgrade();
while let Some(req) = reqs.message().await? {
let req_id = req.request_id.map(page_api::RequestID::from).unwrap_or_default();
// Process the request, using a closure to capture errors.
let process_request = async || {
let req = page_api::GetPageRequest::try_from(req)?;
// Fast path: use the pre-acquired timeline handle.
if let Some(Ok(timeline)) = timeline.as_ref().map(|t| t.upgrade()) {
return Self::get_page(&ctx, timeline, req, io_concurrency.clone(), received_at)
.instrument(span.clone()) // propagate request span
.await
}
// The timeline handle is stale. During shard splits, the compute may still be
// sending requests to the parent shard. Try to re-route requests to the child
// shards, and split any batch requests that straddle multiple child shards.
Self::maybe_split_get_page(
&ctx,
&mut handles,
ttid,
shard_index,
req,
io_concurrency.clone(),
received_at,
)
let result = Self::get_page(&ctx, &timeline, req, io_concurrency.clone())
.instrument(span.clone()) // propagate request span
.await
};
// Return the response. Convert per-request errors to GetPageResponses if
// appropriate, or terminate the stream with a tonic::Status.
yield match process_request().await {
Ok(resp) => resp.into(),
.await;
yield match result {
Ok(resp) => resp,
// Convert per-request errors to GetPageResponses as appropriate, or terminate
// the stream with a tonic::Status. Log the error regardless, since
// ObservabilityLayer can't automatically log stream errors.
Err(status) => {
// Log the error, since ObservabilityLayer won't see stream errors.
// TODO: it would be nice if we could propagate the get_page() fields here.
span.in_scope(|| {
warn!("request failed with {:?}: {}", status.code(), status.message());
@@ -3952,20 +3758,20 @@ impl proto::PageService for GrpcPageServiceHandler {
Ok(tonic::Response::new(Box::pin(resps)))
}
#[instrument(skip_all, fields(rel, lsn, allow_missing))]
#[instrument(skip_all, fields(rel, lsn))]
async fn get_rel_size(
&self,
req: tonic::Request<proto::GetRelSizeRequest>,
) -> Result<tonic::Response<proto::GetRelSizeResponse>, tonic::Status> {
let received_at = extract::<ReceivedAt>(&req).0;
let timeline = self.get_shard_zero_request_timeline(&req).await?;
let timeline = self.get_request_timeline(&req).await?;
let ctx = self.ctx.with_scope_page_service_pagestream(&timeline);
// Validate the request, decorate the span, and convert it to a Pagestream request.
Self::ensure_shard_zero(&timeline)?;
let req: page_api::GetRelSizeRequest = req.into_inner().try_into()?;
let allow_missing = req.allow_missing;
span_record!(rel=%req.rel, lsn=%req.read_lsn, allow_missing=%req.allow_missing);
span_record!(rel=%req.rel, lsn=%req.read_lsn);
let req = PagestreamNblocksRequest {
hdr: Self::make_hdr(req.read_lsn, None),
@@ -3980,11 +3786,8 @@ impl proto::PageService for GrpcPageServiceHandler {
)
.await?;
let resp =
PageServerHandler::handle_get_nblocks_request(&timeline, &req, allow_missing, &ctx)
.await?;
let resp: page_api::GetRelSizeResponse = resp.map(|resp| resp.n_blocks);
let resp = PageServerHandler::handle_get_nblocks_request(&timeline, &req, &ctx).await?;
let resp: page_api::GetRelSizeResponse = resp.n_blocks;
Ok(tonic::Response::new(resp.into()))
}
@@ -3994,7 +3797,7 @@ impl proto::PageService for GrpcPageServiceHandler {
req: tonic::Request<proto::GetSlruSegmentRequest>,
) -> Result<tonic::Response<proto::GetSlruSegmentResponse>, tonic::Status> {
let received_at = extract::<ReceivedAt>(&req).0;
let timeline = self.get_shard_zero_request_timeline(&req).await?;
let timeline = self.get_request_timeline(&req).await?;
let ctx = self.ctx.with_scope_page_service_pagestream(&timeline);
// Validate the request, decorate the span, and convert it to a Pagestream request.
@@ -4028,10 +3831,6 @@ impl proto::PageService for GrpcPageServiceHandler {
&self,
req: tonic::Request<proto::LeaseLsnRequest>,
) -> Result<tonic::Response<proto::LeaseLsnResponse>, tonic::Status> {
// TODO: this won't work during shard splits, as the request is directed at a specific shard
// but the parent shard is removed before the split commits and the compute is notified
// (which can take several minutes for large tenants). That's also the case for the libpq
// implementation, so we keep the behavior for now.
let timeline = self.get_request_timeline(&req).await?;
let ctx = self.ctx.with_scope_timeline(&timeline);

View File

@@ -286,10 +286,6 @@ impl Timeline {
/// Like [`Self::get_rel_page_at_lsn`], but returns a batch of pages.
///
/// The ordering of the returned vec corresponds to the ordering of `pages`.
///
/// NB: the read path must be cancellation-safe. The Tonic gRPC service will drop the future
/// if the client goes away (e.g. due to timeout or cancellation).
/// TODO: verify that it actually is cancellation-safe.
pub(crate) async fn get_rel_page_at_lsn_batched(
&self,
pages: impl ExactSizeIterator<Item = (&RelTag, &BlockNumber, LsnRange, RequestContext)>,
@@ -504,9 +500,8 @@ impl Timeline {
for rel in rels {
let n_blocks = self
.get_rel_size_in_reldir(rel, version, Some((reldir_key, &reldir)), false, ctx)
.await?
.expect("allow_missing=false");
.get_rel_size_in_reldir(rel, version, Some((reldir_key, &reldir)), ctx)
.await?;
total_blocks += n_blocks as usize;
}
Ok(total_blocks)
@@ -522,16 +517,10 @@ impl Timeline {
version: Version<'_>,
ctx: &RequestContext,
) -> Result<BlockNumber, PageReconstructError> {
Ok(self
.get_rel_size_in_reldir(tag, version, None, false, ctx)
.await?
.expect("allow_missing=false"))
self.get_rel_size_in_reldir(tag, version, None, ctx).await
}
/// Get size of a relation file. If `allow_missing` is true, returns None for missing relations,
/// otherwise errors.
///
/// INVARIANT: never returns None if `allow_missing=false`.
/// Get size of a relation file. The relation must exist, otherwise an error is returned.
///
/// See [`Self::get_rel_exists_in_reldir`] on why we need `deserialized_reldir_v1`.
pub(crate) async fn get_rel_size_in_reldir(
@@ -539,9 +528,8 @@ impl Timeline {
tag: RelTag,
version: Version<'_>,
deserialized_reldir_v1: Option<(Key, &RelDirectory)>,
allow_missing: bool,
ctx: &RequestContext,
) -> Result<Option<BlockNumber>, PageReconstructError> {
) -> Result<BlockNumber, PageReconstructError> {
if tag.relnode == 0 {
return Err(PageReconstructError::Other(
RelationError::InvalidRelnode.into(),
@@ -549,15 +537,7 @@ impl Timeline {
}
if let Some(nblocks) = self.get_cached_rel_size(&tag, version) {
return Ok(Some(nblocks));
}
if allow_missing
&& !self
.get_rel_exists_in_reldir(tag, version, deserialized_reldir_v1, ctx)
.await?
{
return Ok(None);
return Ok(nblocks);
}
if (tag.forknum == FSM_FORKNUM || tag.forknum == VISIBILITYMAP_FORKNUM)
@@ -569,7 +549,7 @@ impl Timeline {
// FSM, and smgrnblocks() on it immediately afterwards,
// without extending it. Tolerate that by claiming that
// any non-existent FSM fork has size 0.
return Ok(Some(0));
return Ok(0);
}
let key = rel_size_to_key(tag);
@@ -578,7 +558,7 @@ impl Timeline {
self.update_cached_rel_size(tag, version, nblocks);
Ok(Some(nblocks))
Ok(nblocks)
}
/// Does the relation exist?
@@ -2928,8 +2908,9 @@ static ZERO_PAGE: Bytes = Bytes::from_static(&[0u8; BLCKSZ as usize]);
mod tests {
use hex_literal::hex;
use pageserver_api::models::ShardParameters;
use pageserver_api::shard::ShardStripeSize;
use utils::id::TimelineId;
use utils::shard::{ShardCount, ShardNumber, ShardStripeSize};
use utils::shard::{ShardCount, ShardNumber};
use super::*;
use crate::DEFAULT_PG_VERSION;

View File

@@ -6161,11 +6161,11 @@ mod tests {
use pageserver_api::keyspace::KeySpaceRandomAccum;
use pageserver_api::models::{CompactionAlgorithm, CompactionAlgorithmSettings, LsnLease};
use pageserver_compaction::helpers::overlaps_with;
use rand::Rng;
#[cfg(feature = "testing")]
use rand::SeedableRng;
#[cfg(feature = "testing")]
use rand::rngs::StdRng;
use rand::{Rng, thread_rng};
#[cfg(feature = "testing")]
use std::ops::Range;
use storage_layer::{IoConcurrency, PersistentLayerKey};
@@ -6286,8 +6286,8 @@ mod tests {
while lsn < lsn_range.end {
let mut key = key_range.start;
while key < key_range.end {
let gap = random.random_range(1..=100) <= spec.gap_chance;
let will_init = random.random_range(1..=100) <= spec.will_init_chance;
let gap = random.gen_range(1..=100) <= spec.gap_chance;
let will_init = random.gen_range(1..=100) <= spec.will_init_chance;
if gap {
continue;
@@ -6330,8 +6330,8 @@ mod tests {
while lsn < lsn_range.end {
let mut key = key_range.start;
while key < key_range.end {
let gap = random.random_range(1..=100) <= spec.gap_chance;
let will_init = random.random_range(1..=100) <= spec.will_init_chance;
let gap = random.gen_range(1..=100) <= spec.gap_chance;
let will_init = random.gen_range(1..=100) <= spec.will_init_chance;
if gap {
continue;
@@ -7808,7 +7808,7 @@ mod tests {
for _ in 0..50 {
for _ in 0..NUM_KEYS {
lsn = Lsn(lsn.0 + 0x10);
let blknum = rand::rng().random_range(0..NUM_KEYS);
let blknum = thread_rng().gen_range(0..NUM_KEYS);
test_key.field6 = blknum as u32;
let mut writer = tline.writer().await;
writer
@@ -7897,7 +7897,7 @@ mod tests {
for _ in 0..NUM_KEYS {
lsn = Lsn(lsn.0 + 0x10);
let blknum = rand::rng().random_range(0..NUM_KEYS);
let blknum = thread_rng().gen_range(0..NUM_KEYS);
test_key.field6 = blknum as u32;
let mut writer = tline.writer().await;
writer
@@ -7965,7 +7965,7 @@ mod tests {
for _ in 0..NUM_KEYS {
lsn = Lsn(lsn.0 + 0x10);
let blknum = rand::rng().random_range(0..NUM_KEYS);
let blknum = thread_rng().gen_range(0..NUM_KEYS);
test_key.field6 = blknum as u32;
let mut writer = tline.writer().await;
writer
@@ -8229,7 +8229,7 @@ mod tests {
for _ in 0..NUM_KEYS {
lsn = Lsn(lsn.0 + 0x10);
let blknum = rand::rng().random_range(0..NUM_KEYS);
let blknum = thread_rng().gen_range(0..NUM_KEYS);
test_key.field6 = (blknum * STEP) as u32;
let mut writer = tline.writer().await;
writer
@@ -8502,7 +8502,7 @@ mod tests {
for iter in 1..=10 {
for _ in 0..NUM_KEYS {
lsn = Lsn(lsn.0 + 0x10);
let blknum = rand::rng().random_range(0..NUM_KEYS);
let blknum = thread_rng().gen_range(0..NUM_KEYS);
test_key.field6 = (blknum * STEP) as u32;
let mut writer = tline.writer().await;
writer
@@ -11291,10 +11291,10 @@ mod tests {
#[cfg(feature = "testing")]
#[tokio::test]
async fn test_read_path() -> anyhow::Result<()> {
use rand::seq::IndexedRandom;
use rand::seq::SliceRandom;
let seed = if cfg!(feature = "fuzz-read-path") {
let seed: u64 = rand::rng().random();
let seed: u64 = thread_rng().r#gen();
seed
} else {
// Use a hard-coded seed when not in fuzzing mode.
@@ -11308,8 +11308,8 @@ mod tests {
let (queries, will_init_chance, gap_chance) = if cfg!(feature = "fuzz-read-path") {
const QUERIES: u64 = 5000;
let will_init_chance: u8 = random.random_range(0..=10);
let gap_chance: u8 = random.random_range(0..=50);
let will_init_chance: u8 = random.gen_range(0..=10);
let gap_chance: u8 = random.gen_range(0..=50);
(QUERIES, will_init_chance, gap_chance)
} else {
@@ -11410,8 +11410,7 @@ mod tests {
while used_keys.len() < tenant.conf.max_get_vectored_keys.get() {
let selected_lsn = interesting_lsns.choose(&mut random).expect("not empty");
let mut selected_key =
start_key.add(random.random_range(0..KEY_DIMENSION_SIZE));
let mut selected_key = start_key.add(random.gen_range(0..KEY_DIMENSION_SIZE));
while used_keys.len() < tenant.conf.max_get_vectored_keys.get() {
if used_keys.contains(&selected_key)
@@ -11426,7 +11425,7 @@ mod tests {
.add_key(selected_key);
used_keys.insert(selected_key);
let pick_next = random.random_range(0..=100) <= PICK_NEXT_CHANCE;
let pick_next = random.gen_range(0..=100) <= PICK_NEXT_CHANCE;
if pick_next {
selected_key = selected_key.next();
} else {

View File

@@ -535,8 +535,8 @@ pub(crate) mod tests {
}
pub(crate) fn random_array(len: usize) -> Vec<u8> {
let mut rng = rand::rng();
(0..len).map(|_| rng.random()).collect::<_>()
let mut rng = rand::thread_rng();
(0..len).map(|_| rng.r#gen()).collect::<_>()
}
#[tokio::test]
@@ -588,9 +588,9 @@ pub(crate) mod tests {
let mut rng = rand::rngs::StdRng::seed_from_u64(42);
let blobs = (0..1024)
.map(|_| {
let mut sz: u16 = rng.random();
let mut sz: u16 = rng.r#gen();
// Make 50% of the arrays small
if rng.random() {
if rng.r#gen() {
sz &= 63;
}
random_array(sz.into())

View File

@@ -1090,7 +1090,7 @@ pub(crate) mod tests {
const NUM_KEYS: usize = 100000;
let mut all_data: BTreeMap<u128, u64> = BTreeMap::new();
for idx in 0..NUM_KEYS {
let u: f64 = rand::rng().random_range(0.0..1.0);
let u: f64 = rand::thread_rng().gen_range(0.0..1.0);
let t = -(f64::ln(u));
let key_int = (t * 1000000.0) as u128;
@@ -1116,7 +1116,7 @@ pub(crate) mod tests {
// Test get() operations on random keys, most of which will not exist
for _ in 0..100000 {
let key_int = rand::rng().random::<u128>();
let key_int = rand::thread_rng().r#gen::<u128>();
let search_key = u128::to_be_bytes(key_int);
assert!(reader.get(&search_key, &ctx).await? == all_data.get(&key_int).cloned());
}

View File

@@ -508,8 +508,8 @@ mod tests {
let write_nbytes = cap * 2 + cap / 2;
let content: Vec<u8> = rand::rng()
.sample_iter(rand::distr::StandardUniform)
let content: Vec<u8> = rand::thread_rng()
.sample_iter(rand::distributions::Standard)
.take(write_nbytes)
.collect();
@@ -565,8 +565,8 @@ mod tests {
let cap = writer.mutable().capacity();
drop(writer);
let content: Vec<u8> = rand::rng()
.sample_iter(rand::distr::StandardUniform)
let content: Vec<u8> = rand::thread_rng()
.sample_iter(rand::distributions::Standard)
.take(cap * 2 + cap / 2)
.collect();
@@ -614,8 +614,8 @@ mod tests {
let cap = mutable.capacity();
let align = mutable.align();
drop(writer);
let content: Vec<u8> = rand::rng()
.sample_iter(rand::distr::StandardUniform)
let content: Vec<u8> = rand::thread_rng()
.sample_iter(rand::distributions::Standard)
.take(cap * 2 + cap / 2)
.collect();

View File

@@ -19,7 +19,7 @@ use pageserver_api::shard::{
};
use pageserver_api::upcall_api::ReAttachResponseTenant;
use rand::Rng;
use rand::distr::Alphanumeric;
use rand::distributions::Alphanumeric;
use remote_storage::TimeoutOrCancel;
use sysinfo::SystemExt;
use tokio::fs;
@@ -218,7 +218,7 @@ async fn safe_rename_tenant_dir(path: impl AsRef<Utf8Path>) -> std::io::Result<U
std::io::ErrorKind::InvalidInput,
"Path must be absolute",
))?;
let rand_suffix = rand::rng()
let rand_suffix = rand::thread_rng()
.sample_iter(&Alphanumeric)
.take(8)
.map(char::from)
@@ -328,7 +328,7 @@ fn emergency_generations(
LocationMode::Attached(alc) => TenantStartupMode::Attached((
alc.attach_mode,
alc.generation,
lc.shard.stripe_size,
ShardStripeSize::default(),
)),
LocationMode::Secondary(_) => TenantStartupMode::Secondary,
},
@@ -352,8 +352,7 @@ async fn init_load_generations(
let client = StorageControllerUpcallClient::new(conf, cancel);
info!("Calling {} API to re-attach tenants", client.base_url());
// If we are configured to use the control plane API, then it is the source of truth for what tenants to load.
let empty_local_disk = tenant_confs.is_empty();
match client.re_attach(conf, empty_local_disk).await {
match client.re_attach(conf).await {
Ok(tenants) => tenants
.into_iter()
.flat_map(|(id, rart)| {
@@ -826,18 +825,6 @@ impl TenantManager {
peek_slot.is_some()
}
/// Returns whether a local slot exists for a child shard of the given tenant and shard count.
/// Note that this just checks for a shard with a larger shard count, and it may not be a
/// direct child of the given shard.
pub(crate) fn has_child_shard(&self, tenant_id: TenantId, shard_index: ShardIndex) -> bool {
match &*self.tenants.read().unwrap() {
TenantsMap::Initializing => false,
TenantsMap::Open(slots) | TenantsMap::ShuttingDown(slots) => slots
.range(TenantShardId::tenant_range(tenant_id))
.any(|(tsid, _)| tsid.shard_count > shard_index.shard_count),
}
}
#[instrument(skip_all, fields(tenant_id=%tenant_shard_id.tenant_id, shard_id=%tenant_shard_id.shard_slug()))]
pub(crate) async fn upsert_location(
&self,

Some files were not shown because too many files have changed in this diff Show More