Compare commits

..

6 Commits

Author SHA1 Message Date
Conrad Ludgate
a4a72c8075 support TLS for sql-over-http 2024-12-18 12:17:31 +00:00
Conrad Ludgate
90ce4f3002 properly remove clones 2024-12-18 12:06:57 +00:00
Conrad Ludgate
b79a1dd337 reduce cloning 2024-12-18 11:57:10 +00:00
Conrad Ludgate
bbc799ce77 chore(proxy): pre-load native tls certificates and propagate compute client config 2024-12-18 09:29:42 +00:00
Conrad Ludgate
cd0924c686 fmt 2024-12-17 19:43:23 +00:00
Conrad Ludgate
2548926ea6 chore(proxy): fully remove allow-self-signed-compute flag 2024-12-17 19:32:42 +00:00
114 changed files with 887 additions and 1975 deletions

View File

@@ -308,7 +308,6 @@ jobs:
"image": [ "'"$image_default"'" ],
"include": [{ "pg_version": 16, "region_id": "'"$region_id_default"'", "platform": "neonvm-captest-freetier", "db_size": "3gb" ,"runner": '"$runner_default"', "image": "'"$image_default"'" },
{ "pg_version": 16, "region_id": "'"$region_id_default"'", "platform": "neonvm-captest-new", "db_size": "10gb","runner": '"$runner_default"', "image": "'"$image_default"'" },
{ "pg_version": 16, "region_id": "'"$region_id_default"'", "platform": "neonvm-captest-new-many-tables","db_size": "10gb","runner": '"$runner_default"', "image": "'"$image_default"'" },
{ "pg_version": 16, "region_id": "'"$region_id_default"'", "platform": "neonvm-captest-new", "db_size": "50gb","runner": '"$runner_default"', "image": "'"$image_default"'" },
{ "pg_version": 16, "region_id": "azure-eastus2", "platform": "neonvm-azure-captest-freetier", "db_size": "3gb" ,"runner": '"$runner_azure"', "image": "neondatabase/build-tools:pinned-bookworm" },
{ "pg_version": 16, "region_id": "azure-eastus2", "platform": "neonvm-azure-captest-new", "db_size": "10gb","runner": '"$runner_azure"', "image": "neondatabase/build-tools:pinned-bookworm" },
@@ -411,7 +410,7 @@ jobs:
aws-oicd-role-arn: ${{ vars.DEV_AWS_OIDC_ROLE_ARN }}
- name: Create Neon Project
if: contains(fromJson('["neonvm-captest-new", "neonvm-captest-new-many-tables", "neonvm-captest-freetier", "neonvm-azure-captest-freetier", "neonvm-azure-captest-new"]'), matrix.platform)
if: contains(fromJson('["neonvm-captest-new", "neonvm-captest-freetier", "neonvm-azure-captest-freetier", "neonvm-azure-captest-new"]'), matrix.platform)
id: create-neon-project
uses: ./.github/actions/neon-project-create
with:
@@ -430,7 +429,7 @@ jobs:
neonvm-captest-sharding-reuse)
CONNSTR=${{ secrets.BENCHMARK_CAPTEST_SHARDING_CONNSTR }}
;;
neonvm-captest-new | neonvm-captest-new-many-tables | neonvm-captest-freetier | neonvm-azure-captest-new | neonvm-azure-captest-freetier)
neonvm-captest-new | neonvm-captest-freetier | neonvm-azure-captest-new | neonvm-azure-captest-freetier)
CONNSTR=${{ steps.create-neon-project.outputs.dsn }}
;;
rds-aurora)
@@ -447,26 +446,6 @@ jobs:
echo "connstr=${CONNSTR}" >> $GITHUB_OUTPUT
# we want to compare Neon project OLTP throughput and latency at scale factor 10 GB
# without (neonvm-captest-new)
# and with (neonvm-captest-new-many-tables) many relations in the database
- name: Create many relations before the run
if: contains(fromJson('["neonvm-captest-new-many-tables"]'), matrix.platform)
uses: ./.github/actions/run-python-test-set
with:
build_type: ${{ env.BUILD_TYPE }}
test_selection: performance
run_in_parallel: false
save_perf_report: ${{ env.SAVE_PERF_REPORT }}
extra_params: -m remote_cluster --timeout 21600 -k test_perf_many_relations
pg_version: ${{ env.DEFAULT_PG_VERSION }}
aws-oicd-role-arn: ${{ vars.DEV_AWS_OIDC_ROLE_ARN }}
env:
BENCHMARK_CONNSTR: ${{ steps.set-up-connstr.outputs.connstr }}
VIP_VAP_ACCESS_TOKEN: "${{ secrets.VIP_VAP_ACCESS_TOKEN }}"
PERF_TEST_RESULT_CONNSTR: "${{ secrets.PERF_TEST_RESULT_CONNSTR }}"
TEST_NUM_RELATIONS: 10000
- name: Benchmark init
uses: ./.github/actions/run-python-test-set
with:

View File

@@ -212,7 +212,7 @@ jobs:
fi
echo "CLIPPY_COMMON_ARGS=${CLIPPY_COMMON_ARGS}" >> $GITHUB_ENV
- name: Run cargo clippy (debug)
run: cargo hack --features default --ignore-unknown-features --feature-powerset clippy $CLIPPY_COMMON_ARGS
run: cargo hack --feature-powerset clippy $CLIPPY_COMMON_ARGS
- name: Check documentation generation
run: cargo doc --workspace --no-deps --document-private-items

View File

@@ -21,8 +21,6 @@ concurrency:
permissions:
id-token: write # aws-actions/configure-aws-credentials
statuses: write
contents: write
jobs:
regress:

1
Cargo.lock generated
View File

@@ -1274,7 +1274,6 @@ dependencies = [
"chrono",
"clap",
"compute_api",
"fail",
"flate2",
"futures",
"hyper 0.14.30",

View File

@@ -1285,7 +1285,7 @@ RUN make -j $(getconf _NPROCESSORS_ONLN) \
#########################################################################################
#
# Compile the Neon-specific `compute_ctl`, `fast_import`, and `local_proxy` binaries
# Compile and run the Neon-specific `compute_ctl` and `fast_import` binaries
#
#########################################################################################
FROM $REPOSITORY/$IMAGE:$TAG AS compute-tools
@@ -1295,7 +1295,7 @@ ENV BUILD_TAG=$BUILD_TAG
USER nonroot
# Copy entire project to get Cargo.* files with proper dependencies for the whole project
COPY --chown=nonroot . .
RUN mold -run cargo build --locked --profile release-line-debug-size-lto --bin compute_ctl --bin fast_import --bin local_proxy
RUN cd compute_tools && mold -run cargo build --locked --profile release-line-debug-size-lto
#########################################################################################
#
@@ -1338,6 +1338,20 @@ RUN set -e \
&& make -j $(nproc) dist_man_MANS= \
&& make install dist_man_MANS=
#########################################################################################
#
# Compile the Neon-specific `local_proxy` binary
#
#########################################################################################
FROM $REPOSITORY/$IMAGE:$TAG AS local_proxy
ARG BUILD_TAG
ENV BUILD_TAG=$BUILD_TAG
USER nonroot
# Copy entire project to get Cargo.* files with proper dependencies for the whole project
COPY --chown=nonroot . .
RUN mold -run cargo build --locked --profile release-line-debug-size-lto --bin local_proxy
#########################################################################################
#
# Layers "postgres-exporter" and "sql-exporter"
@@ -1477,7 +1491,7 @@ COPY --from=pgbouncer /usr/local/pgbouncer/bin/pgbouncer /usr/local/bin/
COPY --chmod=0666 --chown=postgres compute/etc/pgbouncer.ini /etc/pgbouncer.ini
# local_proxy and its config
COPY --from=compute-tools --chown=postgres /home/nonroot/target/release-line-debug-size-lto/local_proxy /usr/local/bin/local_proxy
COPY --from=local_proxy --chown=postgres /home/nonroot/target/release-line-debug-size-lto/local_proxy /usr/local/bin/local_proxy
RUN mkdir -p /etc/local_proxy && chown postgres:postgres /etc/local_proxy
# Metrics exporter binaries and configuration files
@@ -1542,30 +1556,28 @@ RUN apt update && \
locales \
procps \
ca-certificates \
curl \
unzip \
$VERSION_INSTALLS && \
apt clean && rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/* && \
localedef -i en_US -c -f UTF-8 -A /usr/share/locale/locale.alias en_US.UTF-8
# aws cli is used by fast_import (curl and unzip above are at this time only used for this installation step)
# s5cmd 2.2.2 from https://github.com/peak/s5cmd/releases/tag/v2.2.2
# used by fast_import
ARG TARGETARCH
ADD https://github.com/peak/s5cmd/releases/download/v2.2.2/s5cmd_2.2.2_linux_$TARGETARCH.deb /tmp/s5cmd.deb
RUN set -ex; \
\
# Determine the expected checksum based on TARGETARCH
if [ "${TARGETARCH}" = "amd64" ]; then \
TARGETARCH_ALT="x86_64"; \
CHECKSUM="c9a9df3770a3ff9259cb469b6179e02829687a464e0824d5c32d378820b53a00"; \
CHECKSUM="392c385320cd5ffa435759a95af77c215553d967e4b1c0fffe52e4f14c29cf85"; \
elif [ "${TARGETARCH}" = "arm64" ]; then \
TARGETARCH_ALT="aarch64"; \
CHECKSUM="8181730be7891582b38b028112e81b4899ca817e8c616aad807c9e9d1289223a"; \
CHECKSUM="939bee3cf4b5604ddb00e67f8c157b91d7c7a5b553d1fbb6890fad32894b7b46"; \
else \
echo "Unsupported architecture: ${TARGETARCH}"; exit 1; \
fi; \
curl -L "https://awscli.amazonaws.com/awscli-exe-linux-${TARGETARCH_ALT}-2.17.5.zip" -o /tmp/awscliv2.zip; \
echo "${CHECKSUM} /tmp/awscliv2.zip" | sha256sum -c -; \
unzip /tmp/awscliv2.zip -d /tmp/awscliv2; \
/tmp/awscliv2/aws/install; \
rm -rf /tmp/awscliv2.zip /tmp/awscliv2; \
true
\
# Compute and validate the checksum
echo "${CHECKSUM} /tmp/s5cmd.deb" | sha256sum -c -
RUN dpkg -i /tmp/s5cmd.deb && rm /tmp/s5cmd.deb
ENV LANG=en_US.utf8
USER postgres

View File

@@ -7,7 +7,7 @@ license.workspace = true
[features]
default = []
# Enables test specific features.
testing = ["fail/failpoints"]
testing = []
[dependencies]
base64.workspace = true
@@ -19,7 +19,6 @@ camino.workspace = true
chrono.workspace = true
cfg-if.workspace = true
clap.workspace = true
fail.workspace = true
flate2.workspace = true
futures.workspace = true
hyper0 = { workspace = true, features = ["full"] }

View File

@@ -67,15 +67,12 @@ use compute_tools::params::*;
use compute_tools::spec::*;
use compute_tools::swap::resize_swap;
use rlimit::{setrlimit, Resource};
use utils::failpoint_support;
// this is an arbitrary build tag. Fine as a default / for testing purposes
// in-case of not-set environment var
const BUILD_TAG_DEFAULT: &str = "latest";
fn main() -> Result<()> {
let scenario = failpoint_support::init();
let (build_tag, clap_args) = init()?;
// enable core dumping for all child processes
@@ -103,8 +100,6 @@ fn main() -> Result<()> {
maybe_delay_exit(delay_exit);
scenario.teardown();
deinit_and_exit(wait_pg_result);
}
@@ -424,13 +419,9 @@ fn start_postgres(
"running compute with features: {:?}",
state.pspec.as_ref().unwrap().spec.features
);
// before we release the mutex, fetch some parameters for later.
let &ComputeSpec {
swap_size_bytes,
disk_quota_bytes,
disable_lfc_resizing,
..
} = &state.pspec.as_ref().unwrap().spec;
// before we release the mutex, fetch the swap size (if any) for later.
let swap_size_bytes = state.pspec.as_ref().unwrap().spec.swap_size_bytes;
let disk_quota_bytes = state.pspec.as_ref().unwrap().spec.disk_quota_bytes;
drop(state);
// Launch remaining service threads
@@ -535,18 +526,11 @@ fn start_postgres(
// This token is used internally by the monitor to clean up all threads
let token = CancellationToken::new();
// don't pass postgres connection string to vm-monitor if we don't want it to resize LFC
let pgconnstr = if disable_lfc_resizing.unwrap_or(false) {
None
} else {
file_cache_connstr.cloned()
};
let vm_monitor = rt.as_ref().map(|rt| {
rt.spawn(vm_monitor::start(
Box::leak(Box::new(vm_monitor::Args {
cgroup: cgroup.cloned(),
pgconnstr,
pgconnstr: file_cache_connstr.cloned(),
addr: vm_monitor_addr.clone(),
})),
token.clone(),

View File

@@ -34,12 +34,12 @@ use nix::unistd::Pid;
use tracing::{info, info_span, warn, Instrument};
use utils::fs_ext::is_directory_empty;
#[path = "fast_import/aws_s3_sync.rs"]
mod aws_s3_sync;
#[path = "fast_import/child_stdio_to_log.rs"]
mod child_stdio_to_log;
#[path = "fast_import/s3_uri.rs"]
mod s3_uri;
#[path = "fast_import/s5cmd.rs"]
mod s5cmd;
#[derive(clap::Parser)]
struct Args {
@@ -326,7 +326,7 @@ pub(crate) async fn main() -> anyhow::Result<()> {
}
info!("upload pgdata");
aws_s3_sync::sync(Utf8Path::new(&pgdata_dir), &s3_prefix.append("/pgdata/"))
s5cmd::sync(Utf8Path::new(&pgdata_dir), &s3_prefix.append("/"))
.await
.context("sync dump directory to destination")?;
@@ -334,10 +334,10 @@ pub(crate) async fn main() -> anyhow::Result<()> {
{
let status_dir = working_directory.join("status");
std::fs::create_dir(&status_dir).context("create status directory")?;
let status_file = status_dir.join("pgdata");
let status_file = status_dir.join("status");
std::fs::write(&status_file, serde_json::json!({"done": true}).to_string())
.context("write status file")?;
aws_s3_sync::sync(&status_dir, &s3_prefix.append("/status/"))
s5cmd::sync(&status_file, &s3_prefix.append("/status/pgdata"))
.await
.context("sync status directory to destination")?;
}

View File

@@ -4,21 +4,24 @@ use camino::Utf8Path;
use super::s3_uri::S3Uri;
pub(crate) async fn sync(local: &Utf8Path, remote: &S3Uri) -> anyhow::Result<()> {
let mut builder = tokio::process::Command::new("aws");
let mut builder = tokio::process::Command::new("s5cmd");
// s5cmd uses aws-sdk-go v1, hence doesn't support AWS_ENDPOINT_URL
if let Some(val) = std::env::var_os("AWS_ENDPOINT_URL") {
builder.arg("--endpoint-url").arg(val);
}
builder
.arg("s3")
.arg("sync")
.arg(local.as_str())
.arg(remote.to_string());
let st = builder
.spawn()
.context("spawn aws s3 sync")?
.context("spawn s5cmd")?
.wait()
.await
.context("wait for aws s3 sync")?;
.context("wait for s5cmd")?;
if st.success() {
Ok(())
} else {
Err(anyhow::anyhow!("aws s3 sync failed"))
Err(anyhow::anyhow!("s5cmd failed"))
}
}

View File

@@ -1181,19 +1181,8 @@ impl ComputeNode {
let mut conf = postgres::config::Config::from(conf);
conf.application_name("compute_ctl:migrations");
match conf.connect(NoTls) {
Ok(mut client) => {
if let Err(e) = handle_migrations(&mut client) {
error!("Failed to run migrations: {}", e);
}
}
Err(e) => {
error!(
"Failed to connect to the compute for running migrations: {}",
e
);
}
};
let mut client = conf.connect(NoTls)?;
handle_migrations(&mut client).context("apply_config handle_migrations")
});
Ok::<(), anyhow::Error>(())

View File

@@ -24,11 +24,8 @@ use metrics::proto::MetricFamily;
use metrics::Encoder;
use metrics::TextEncoder;
use tokio::task;
use tokio_util::sync::CancellationToken;
use tracing::{debug, error, info, warn};
use tracing_utils::http::OtelName;
use utils::failpoint_support::failpoints_handler;
use utils::http::error::ApiError;
use utils::http::request::must_get_query_param;
fn status_response_from_state(state: &ComputeState) -> ComputeStatusResponse {
@@ -313,18 +310,6 @@ async fn routes(req: Request<Body>, compute: &Arc<ComputeNode>) -> Response<Body
}
}
(&Method::POST, "/failpoints") if cfg!(feature = "testing") => {
match failpoints_handler(req, CancellationToken::new()).await {
Ok(r) => r,
Err(ApiError::BadRequest(e)) => {
render_json_error(&e.to_string(), StatusCode::BAD_REQUEST)
}
Err(_) => {
render_json_error("Internal server error", StatusCode::INTERNAL_SERVER_ERROR)
}
}
}
// download extension files from remote extension storage on demand
(&Method::POST, route) if route.starts_with("/extension_server/") => {
info!("serving {:?} POST request", route);

View File

@@ -1,16 +1,13 @@
use anyhow::{Context, Result};
use fail::fail_point;
use postgres::Client;
use tracing::info;
/// Runs a series of migrations on a target database
pub(crate) struct MigrationRunner<'m> {
client: &'m mut Client,
migrations: &'m [&'m str],
}
impl<'m> MigrationRunner<'m> {
/// Create a new migration runner
pub fn new(client: &'m mut Client, migrations: &'m [&'m str]) -> Self {
// The neon_migration.migration_id::id column is a bigint, which is equivalent to an i64
assert!(migrations.len() + 1 < i64::MAX as usize);
@@ -18,7 +15,6 @@ impl<'m> MigrationRunner<'m> {
Self { client, migrations }
}
/// Get the current value neon_migration.migration_id
fn get_migration_id(&mut self) -> Result<i64> {
let query = "SELECT id FROM neon_migration.migration_id";
let row = self
@@ -29,61 +25,37 @@ impl<'m> MigrationRunner<'m> {
Ok(row.get::<&str, i64>("id"))
}
/// Update the neon_migration.migration_id value
///
/// This function has a fail point called compute-migration, which can be
/// used if you would like to fail the application of a series of migrations
/// at some point.
fn update_migration_id(&mut self, migration_id: i64) -> Result<()> {
// We use this fail point in order to check that failing in the
// middle of applying a series of migrations fails in an expected
// manner
if cfg!(feature = "testing") {
let fail = (|| {
fail_point!("compute-migration", |fail_migration_id| {
migration_id == fail_migration_id.unwrap().parse::<i64>().unwrap()
});
false
})();
if fail {
return Err(anyhow::anyhow!(format!(
"migration {} was configured to fail because of a failpoint",
migration_id
)));
}
}
let setval = format!("UPDATE neon_migration.migration_id SET id={}", migration_id);
self.client
.query(
"UPDATE neon_migration.migration_id SET id = $1",
&[&migration_id],
)
.simple_query(&setval)
.context("run_migrations update id")?;
Ok(())
}
/// Prepare the target database for handling migrations
fn prepare_database(&mut self) -> Result<()> {
self.client
.simple_query("CREATE SCHEMA IF NOT EXISTS neon_migration")?;
self.client.simple_query("CREATE TABLE IF NOT EXISTS neon_migration.migration_id (key INT NOT NULL PRIMARY KEY, id bigint NOT NULL DEFAULT 0)")?;
self.client.simple_query(
"INSERT INTO neon_migration.migration_id VALUES (0, 0) ON CONFLICT DO NOTHING",
)?;
self.client
.simple_query("ALTER SCHEMA neon_migration OWNER TO cloud_admin")?;
self.client
.simple_query("REVOKE ALL ON SCHEMA neon_migration FROM PUBLIC")?;
fn prepare_migrations(&mut self) -> Result<()> {
let query = "CREATE SCHEMA IF NOT EXISTS neon_migration";
self.client.simple_query(query)?;
let query = "CREATE TABLE IF NOT EXISTS neon_migration.migration_id (key INT NOT NULL PRIMARY KEY, id bigint NOT NULL DEFAULT 0)";
self.client.simple_query(query)?;
let query = "INSERT INTO neon_migration.migration_id VALUES (0, 0) ON CONFLICT DO NOTHING";
self.client.simple_query(query)?;
let query = "ALTER SCHEMA neon_migration OWNER TO cloud_admin";
self.client.simple_query(query)?;
let query = "REVOKE ALL ON SCHEMA neon_migration FROM PUBLIC";
self.client.simple_query(query)?;
Ok(())
}
/// Run the configured set of migrations
pub fn run_migrations(mut self) -> Result<()> {
self.prepare_database()?;
self.prepare_migrations()?;
let mut current_migration = self.get_migration_id()? as usize;
while current_migration < self.migrations.len() {
@@ -97,11 +69,6 @@ impl<'m> MigrationRunner<'m> {
if migration.starts_with("-- SKIP") {
info!("Skipping migration id={}", migration_id!(current_migration));
// Even though we are skipping the migration, updating the
// migration ID should help keep logic easy to understand when
// trying to understand the state of a cluster.
self.update_migration_id(migration_id!(current_migration))?;
} else {
info!(
"Running migration id={}:\n{}\n",
@@ -120,6 +87,7 @@ impl<'m> MigrationRunner<'m> {
)
})?;
// Migration IDs start at 1
self.update_migration_id(migration_id!(current_migration))?;
self.client

View File

@@ -1,9 +0,0 @@
-- Verifies that the neon_superuser role has the BYPASSRLS attribute.
DO $$
DECLARE
    bypassrls boolean;
BEGIN
    SELECT rolbypassrls INTO bypassrls FROM pg_roles WHERE rolname = 'neon_superuser';

    -- SELECT ... INTO leaves the variable NULL when no row matches, and
    -- `IF NOT NULL` never fires. IS NOT TRUE makes a missing neon_superuser
    -- role fail the check instead of passing it silently.
    IF bypassrls IS NOT TRUE THEN
        RAISE EXCEPTION 'neon_superuser cannot bypass RLS';
    END IF;
END $$;

View File

@@ -1,25 +0,0 @@
-- Checks role attributes relative to neon_superuser:
--   * every role for which pg_has_role(rolname, 'neon_superuser', 'member')
--     holds must have rolinherit set;
--   * every other role whose name does not start with 'pg_' must not have
--     rolbypassrls set.
DO $$
DECLARE
    candidate record;
BEGIN
    -- Roles that have neon_superuser membership must inherit.
    FOR candidate IN
        SELECT rolname, rolinherit
        FROM pg_roles
        WHERE pg_has_role(rolname, 'neon_superuser', 'member')
    LOOP
        IF NOT candidate.rolinherit THEN
            RAISE EXCEPTION '% cannot inherit', quote_ident(candidate.rolname);
        END IF;
    END LOOP;

    -- Roles without that membership (other than pg_* roles) must not bypass RLS.
    FOR candidate IN
        SELECT rolname, rolbypassrls
        FROM pg_roles
        WHERE NOT pg_has_role(rolname, 'neon_superuser', 'member')
            AND NOT starts_with(rolname, 'pg_')
    LOOP
        IF candidate.rolbypassrls THEN
            RAISE EXCEPTION '% can bypass RLS', quote_ident(candidate.rolname);
        END IF;
    END LOOP;
END $$;

View File

@@ -1,10 +0,0 @@
-- When server_version_num is at least 160000, verifies that neon_superuser
-- is a member of pg_create_subscription. Lower versions return without
-- checking anything.
DO $$
DECLARE
    version_num numeric := current_setting('server_version_num')::numeric;
BEGIN
    -- Bail out before touching pg_create_subscription on older servers.
    IF version_num < 160000 THEN
        RETURN;
    END IF;

    IF NOT pg_has_role('neon_superuser', 'pg_create_subscription', 'member') THEN
        RAISE EXCEPTION 'neon_superuser cannot execute pg_create_subscription';
    END IF;
END $$;

View File

@@ -1,19 +0,0 @@
-- Verifies that neon_superuser is a member of pg_monitor and holds the admin
-- option on that membership (i.e. it can grant pg_monitor to other roles).
DO $$
DECLARE
    can_grant boolean;
BEGIN
    IF NOT pg_has_role('neon_superuser', 'pg_monitor', 'member') THEN
        RAISE EXCEPTION 'neon_superuser is not a member of pg_monitor';
    END IF;

    -- Look up the membership row(s) of neon_superuser in pg_monitor. The
    -- previous filter compared `member` against pg_monitor itself, which
    -- matches nothing, so both checks silently passed on an all-NULL record.
    -- bool_or tolerates several membership rows for the same pair and yields
    -- NULL when there is none.
    SELECT bool_or(admin_option)
        INTO can_grant
        FROM pg_auth_members
        WHERE roleid = 'pg_monitor'::regrole
            AND member = 'neon_superuser'::regrole;

    -- IS NOT TRUE also covers the "no membership row" (NULL) case.
    IF can_grant IS NOT TRUE THEN
        RAISE EXCEPTION 'neon_superuser cannot grant pg_monitor';
    END IF;
END $$;

View File

@@ -1,2 +0,0 @@
-- This test was never written because at the time migration tests were added
-- the accompanying migration was already skipped.

View File

@@ -1,2 +0,0 @@
-- This test was never written because at the time migration tests were added
-- the accompanying migration was already skipped.

View File

@@ -1,2 +0,0 @@
-- This test was never written because at the time migration tests were added
-- the accompanying migration was already skipped.

View File

@@ -1,2 +0,0 @@
-- This test was never written because at the time migration tests were added
-- the accompanying migration was already skipped.

View File

@@ -1,2 +0,0 @@
-- This test was never written because at the time migration tests were added
-- the accompanying migration was already skipped.

View File

@@ -1,13 +0,0 @@
-- Verifies that neon_superuser may execute every pg_catalog function named
-- pg_export_snapshot or pg_log_standby_snapshot that exists on this server.
DO $$
BEGIN
    -- Fail as soon as any matching function lacks the EXECUTE privilege;
    -- if no function matches at all, nothing is raised.
    IF EXISTS (
        SELECT 1
        FROM pg_proc
        WHERE proname IN ('pg_export_snapshot', 'pg_log_standby_snapshot')
            AND pronamespace = 'pg_catalog'::regnamespace
            AND NOT has_function_privilege('neon_superuser', oid, 'execute')
    ) THEN
        RAISE EXCEPTION 'neon_superuser cannot execute both pg_export_snapshot and pg_log_standby_snapshot';
    END IF;
END $$;

View File

@@ -1,13 +0,0 @@
-- Verifies that neon_superuser may execute
-- pg_catalog.pg_show_replication_origin_status.
DO $$
DECLARE
    allowed boolean;
BEGIN
    SELECT has_function_privilege('neon_superuser', p.oid, 'execute')
        INTO allowed
        FROM pg_proc AS p
        WHERE p.proname = 'pg_show_replication_origin_status'
            AND p.pronamespace = 'pg_catalog'::regnamespace;

    IF NOT allowed THEN
        RAISE EXCEPTION 'neon_superuser cannot execute pg_show_replication_origin_status';
    END IF;
END $$;

View File

@@ -19,7 +19,6 @@ use control_plane::storage_controller::{
NeonStorageControllerStartArgs, NeonStorageControllerStopArgs, StorageController,
};
use control_plane::{broker, local_env};
use nix::fcntl::{flock, FlockArg};
use pageserver_api::config::{
DEFAULT_HTTP_LISTEN_PORT as DEFAULT_PAGESERVER_HTTP_PORT,
DEFAULT_PG_LISTEN_PORT as DEFAULT_PAGESERVER_PG_PORT,
@@ -37,8 +36,6 @@ use safekeeper_api::{
};
use std::borrow::Cow;
use std::collections::{BTreeSet, HashMap};
use std::fs::File;
use std::os::fd::AsRawFd;
use std::path::PathBuf;
use std::process::exit;
use std::str::FromStr;
@@ -692,21 +689,6 @@ struct TimelineTreeEl {
pub children: BTreeSet<TimelineId>,
}
/// Guard that takes an exclusive `flock` on the neon_local repository
/// directory (`local_env::base_path()`).
///
/// The opened directory handle is stored in the guard so that the descriptor
/// stays open for as long as the guard is alive.
struct RepoLock {
    _file: File,
}

impl RepoLock {
    /// Opens the repository directory and requests an exclusive lock on it.
    ///
    /// Returns an error if the directory cannot be opened or if `flock` fails.
    fn new() -> Result<Self> {
        let dir_handle = File::open(local_env::base_path())?;
        flock(dir_handle.as_raw_fd(), FlockArg::LockExclusive)?;
        Ok(Self { _file: dir_handle })
    }
}
// Main entry point for the 'neon_local' CLI utility
//
// This utility helps to manage neon installation. That includes following:
@@ -718,14 +700,9 @@ fn main() -> Result<()> {
let cli = Cli::parse();
// Check for 'neon init' command first.
let (subcommand_result, _lock) = if let NeonLocalCmd::Init(args) = cli.command {
(handle_init(&args).map(|env| Some(Cow::Owned(env))), None)
let subcommand_result = if let NeonLocalCmd::Init(args) = cli.command {
handle_init(&args).map(|env| Some(Cow::Owned(env)))
} else {
// This tool uses a collection of simple files to store its state, and consequently
// it is not generally safe to run multiple commands concurrently. Rather than expect
// all callers to know this, use a lock file to protect against concurrent execution.
let _repo_lock = RepoLock::new().unwrap();
// all other commands need an existing config
let env = LocalEnv::load_config(&local_env::base_path()).context("Error loading config")?;
let original_env = env.clone();
@@ -751,12 +728,11 @@ fn main() -> Result<()> {
NeonLocalCmd::Mappings(subcmd) => handle_mappings(&subcmd, env),
};
let subcommand_result = if &original_env != env {
if &original_env != env {
subcommand_result.map(|()| Some(Cow::Borrowed(env)))
} else {
subcommand_result.map(|()| None)
};
(subcommand_result, Some(_repo_lock))
}
};
match subcommand_result {
@@ -946,7 +922,7 @@ fn handle_init(args: &InitCmdArgs) -> anyhow::Result<LocalEnv> {
} else {
// User (likely interactive) did not provide a description of the environment, give them the default
NeonLocalInitConf {
control_plane_api: Some(DEFAULT_PAGESERVER_CONTROL_PLANE_API.parse().unwrap()),
control_plane_api: Some(Some(DEFAULT_PAGESERVER_CONTROL_PLANE_API.parse().unwrap())),
broker: NeonBroker {
listen_addr: DEFAULT_BROKER_ADDR.parse().unwrap(),
},
@@ -1742,15 +1718,18 @@ async fn handle_start_all_impl(
broker::start_broker_process(env, &retry_timeout).await
});
js.spawn(async move {
let storage_controller = StorageController::from_env(env);
storage_controller
.start(NeonStorageControllerStartArgs::with_default_instance_id(
retry_timeout,
))
.await
.map_err(|e| e.context("start storage_controller"))
});
// Only start the storage controller if the pageserver is configured to need it
if env.control_plane_api.is_some() {
js.spawn(async move {
let storage_controller = StorageController::from_env(env);
storage_controller
.start(NeonStorageControllerStartArgs::with_default_instance_id(
retry_timeout,
))
.await
.map_err(|e| e.context("start storage_controller"))
});
}
for ps_conf in &env.pageservers {
js.spawn(async move {
@@ -1795,6 +1774,10 @@ async fn neon_start_status_check(
const RETRY_INTERVAL: Duration = Duration::from_millis(100);
const NOTICE_AFTER_RETRIES: Duration = Duration::from_secs(5);
if env.control_plane_api.is_none() {
return Ok(());
}
let storcon = StorageController::from_env(env);
let retries = retry_timeout.as_millis() / RETRY_INTERVAL.as_millis();

View File

@@ -316,10 +316,6 @@ impl Endpoint {
// and can cause errors like 'no unpinned buffers available', see
// <https://github.com/neondatabase/neon/issues/9956>
conf.append("shared_buffers", "1MB");
// Postgres defaults to effective_io_concurrency=1, which does not exercise the pageserver's
// batching logic. Set this to 2 so that we exercise the code a bit without letting
// individual tests do a lot of concurrent work on underpowered test machines
conf.append("effective_io_concurrency", "2");
conf.append("fsync", "off");
conf.append("max_connections", "100");
conf.append("wal_level", "logical");
@@ -585,7 +581,6 @@ impl Endpoint {
features: self.features.clone(),
swap_size_bytes: None,
disk_quota_bytes: None,
disable_lfc_resizing: None,
cluster: Cluster {
cluster_id: None, // project ID: not used
name: None, // project name: not used

View File

@@ -76,7 +76,7 @@ pub struct LocalEnv {
// Control plane upcall API for pageserver: if None, we will not run storage_controller. If set, this will
// be propagated into each pageserver's configuration.
pub control_plane_api: Url,
pub control_plane_api: Option<Url>,
// Control plane upcall API for storage controller. If set, this will be propagated into the
// storage controller's configuration.
@@ -133,7 +133,7 @@ pub struct NeonLocalInitConf {
pub storage_controller: Option<NeonStorageControllerConf>,
pub pageservers: Vec<NeonLocalInitPageserverConf>,
pub safekeepers: Vec<SafekeeperConf>,
pub control_plane_api: Option<Url>,
pub control_plane_api: Option<Option<Url>>,
pub control_plane_compute_hook_api: Option<Option<Url>>,
}
@@ -180,7 +180,7 @@ impl NeonStorageControllerConf {
const DEFAULT_MAX_WARMING_UP_INTERVAL: std::time::Duration = std::time::Duration::from_secs(30);
// Very tight heartbeat interval to speed up tests
const DEFAULT_HEARTBEAT_INTERVAL: std::time::Duration = std::time::Duration::from_millis(1000);
const DEFAULT_HEARTBEAT_INTERVAL: std::time::Duration = std::time::Duration::from_millis(100);
}
impl Default for NeonStorageControllerConf {
@@ -535,7 +535,7 @@ impl LocalEnv {
storage_controller,
pageservers,
safekeepers,
control_plane_api: control_plane_api.unwrap(),
control_plane_api,
control_plane_compute_hook_api,
branch_name_mappings,
}
@@ -638,7 +638,7 @@ impl LocalEnv {
storage_controller: self.storage_controller.clone(),
pageservers: vec![], // it's skip_serializing anyway
safekeepers: self.safekeepers.clone(),
control_plane_api: Some(self.control_plane_api.clone()),
control_plane_api: self.control_plane_api.clone(),
control_plane_compute_hook_api: self.control_plane_compute_hook_api.clone(),
branch_name_mappings: self.branch_name_mappings.clone(),
},
@@ -768,7 +768,7 @@ impl LocalEnv {
storage_controller: storage_controller.unwrap_or_default(),
pageservers: pageservers.iter().map(Into::into).collect(),
safekeepers,
control_plane_api: control_plane_api.unwrap(),
control_plane_api: control_plane_api.unwrap_or_default(),
control_plane_compute_hook_api: control_plane_compute_hook_api.unwrap_or_default(),
branch_name_mappings: Default::default(),
};

View File

@@ -95,19 +95,21 @@ impl PageServerNode {
let mut overrides = vec![pg_distrib_dir_param, broker_endpoint_param];
overrides.push(format!(
"control_plane_api='{}'",
self.env.control_plane_api.as_str()
));
if let Some(control_plane_api) = &self.env.control_plane_api {
overrides.push(format!(
"control_plane_api='{}'",
control_plane_api.as_str()
));
// Storage controller uses the same auth as pageserver: if JWT is enabled
// for us, we will also need it to talk to them.
if matches!(conf.http_auth_type, AuthType::NeonJWT) {
let jwt_token = self
.env
.generate_auth_token(&Claims::new(None, Scope::GenerationsApi))
.unwrap();
overrides.push(format!("control_plane_api_token='{}'", jwt_token));
// Storage controller uses the same auth as pageserver: if JWT is enabled
// for us, we will also need it to talk to them.
if matches!(conf.http_auth_type, AuthType::NeonJWT) {
let jwt_token = self
.env
.generate_auth_token(&Claims::new(None, Scope::GenerationsApi))
.unwrap();
overrides.push(format!("control_plane_api_token='{}'", jwt_token));
}
}
if !conf.other.contains_key("remote_storage") {

View File

@@ -338,7 +338,7 @@ impl StorageController {
.port(),
)
} else {
let listen_url = self.env.control_plane_api.clone();
let listen_url = self.env.control_plane_api.clone().unwrap();
let listen = format!(
"{}:{}",
@@ -708,7 +708,7 @@ impl StorageController {
} else {
// The configured URL has the /upcall path prefix for pageservers to use: we will strip that out
// for general purpose API access.
let listen_url = self.env.control_plane_api.clone();
let listen_url = self.env.control_plane_api.clone().unwrap();
Url::from_str(&format!(
"http://{}:{}/{path}",
listen_url.host_str().unwrap(),

View File

@@ -5,8 +5,7 @@ use clap::{Parser, Subcommand};
use pageserver_api::{
controller_api::{
AvailabilityZone, NodeAvailabilityWrapper, NodeDescribeResponse, NodeShardResponse,
SafekeeperDescribeResponse, ShardSchedulingPolicy, TenantCreateRequest,
TenantDescribeResponse, TenantPolicyRequest,
ShardSchedulingPolicy, TenantCreateRequest, TenantDescribeResponse, TenantPolicyRequest,
},
models::{
EvictionPolicy, EvictionPolicyLayerAccessThreshold, LocationConfigSecondary,
@@ -212,8 +211,6 @@ enum Command {
#[arg(long)]
timeout: humantime::Duration,
},
/// List safekeepers known to the storage controller
Safekeepers {},
}
#[derive(Parser)]
@@ -1023,31 +1020,6 @@ async fn main() -> anyhow::Result<()> {
"Fill was cancelled for node {node_id}. Schedulling policy is now {final_policy:?}"
);
}
Command::Safekeepers {} => {
let mut resp = storcon_client
.dispatch::<(), Vec<SafekeeperDescribeResponse>>(
Method::GET,
"control/v1/safekeeper".to_string(),
None,
)
.await?;
resp.sort_by(|a, b| a.id.cmp(&b.id));
let mut table = comfy_table::Table::new();
table.set_header(["Id", "Version", "Host", "Port", "Http Port", "AZ Id"]);
for sk in resp {
table.add_row([
format!("{}", sk.id),
format!("{}", sk.version),
sk.host,
format!("{}", sk.port),
format!("{}", sk.http_port),
sk.availability_zone_id.to_string(),
]);
}
println!("{table}");
}
}
Ok(())

View File

@@ -67,15 +67,6 @@ pub struct ComputeSpec {
#[serde(default)]
pub disk_quota_bytes: Option<u64>,
/// Disables the vm-monitor behavior that resizes LFC on upscale/downscale, instead relying on
/// the initial size of LFC.
///
/// This is intended for use when the LFC size is being overridden from the default but
/// autoscaling is still enabled, and we don't want the vm-monitor to interfere with the custom
/// LFC sizing.
#[serde(default)]
pub disable_lfc_resizing: Option<bool>,
/// Expected cluster state at the end of transition process.
pub cluster: Cluster,
pub delta_operations: Option<Vec<DeltaOp>>,

View File

@@ -372,23 +372,6 @@ pub struct MetadataHealthListOutdatedResponse {
pub health_records: Vec<MetadataHealthRecord>,
}
/// Publicly exposed safekeeper description
///
/// The `active` flag which we have in the DB is not included on purpose: it is deprecated.
#[derive(Serialize, Deserialize, Clone)]
pub struct SafekeeperDescribeResponse {
pub id: NodeId,
pub region_id: String,
/// 1 is special, it means just created (not currently posted to storcon).
/// Zero or negative is not really expected.
/// Otherwise the number from `release-$(number_of_commits_on_branch)` tag.
pub version: i64,
pub host: String,
pub port: i32,
pub http_port: i32,
pub availability_zone_id: String,
}
#[cfg(test)]
mod test {
use super::*;

View File

@@ -6,7 +6,6 @@ pub mod utilization;
use camino::Utf8PathBuf;
pub use utilization::PageserverUtilization;
use core::ops::Range;
use std::{
collections::HashMap,
fmt::Display,
@@ -29,7 +28,6 @@ use utils::{
};
use crate::{
key::Key,
reltag::RelTag,
shard::{ShardCount, ShardStripeSize, TenantShardId},
};
@@ -212,68 +210,6 @@ pub enum TimelineState {
Broken { reason: String, backtrace: String },
}
#[serde_with::serde_as]
#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
pub struct CompactLsnRange {
pub start: Lsn,
pub end: Lsn,
}
#[serde_with::serde_as]
#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
pub struct CompactKeyRange {
#[serde_as(as = "serde_with::DisplayFromStr")]
pub start: Key,
#[serde_as(as = "serde_with::DisplayFromStr")]
pub end: Key,
}
impl From<Range<Lsn>> for CompactLsnRange {
fn from(range: Range<Lsn>) -> Self {
Self {
start: range.start,
end: range.end,
}
}
}
impl From<Range<Key>> for CompactKeyRange {
fn from(range: Range<Key>) -> Self {
Self {
start: range.start,
end: range.end,
}
}
}
impl From<CompactLsnRange> for Range<Lsn> {
fn from(range: CompactLsnRange) -> Self {
range.start..range.end
}
}
impl From<CompactKeyRange> for Range<Key> {
fn from(range: CompactKeyRange) -> Self {
range.start..range.end
}
}
impl CompactLsnRange {
pub fn above(lsn: Lsn) -> Self {
Self {
start: lsn,
end: Lsn::MAX,
}
}
}
#[derive(Debug, Clone, Serialize)]
pub struct CompactInfoResponse {
pub compact_key_range: Option<CompactKeyRange>,
pub compact_lsn_range: Option<CompactLsnRange>,
pub sub_compaction: bool,
}
#[derive(Serialize, Deserialize, Clone)]
pub struct TimelineCreateRequest {
pub new_timeline_id: TimelineId,

View File

@@ -106,11 +106,11 @@ impl<R: RecordGenerator> WalGenerator<R> {
const TIMELINE_ID: u32 = 1;
/// Creates a new WAL generator with the given record generator.
pub fn new(record_generator: R, start_lsn: Lsn) -> WalGenerator<R> {
pub fn new(record_generator: R) -> WalGenerator<R> {
Self {
record_generator,
lsn: start_lsn,
prev_lsn: start_lsn,
lsn: Lsn(0),
prev_lsn: Lsn(0),
}
}

View File

@@ -1,7 +1,7 @@
[package]
name = "postgres-protocol2"
version = "0.1.0"
edition = "2021"
edition = "2018"
license = "MIT/Apache-2.0"
[dependencies]

View File

@@ -9,7 +9,8 @@
//!
//! This library assumes that the `client_encoding` backend parameter has been
//! set to `UTF8`. It will most likely not behave properly if that is not the case.
#![warn(missing_docs, clippy::all)]
#![doc(html_root_url = "https://docs.rs/postgres-protocol/0.6")]
#![warn(missing_docs, rust_2018_idioms, clippy::all)]
use byteorder::{BigEndian, ByteOrder};
use bytes::{BufMut, BytesMut};

View File

@@ -3,6 +3,7 @@
use byteorder::{BigEndian, ByteOrder};
use bytes::{Buf, BufMut, BytesMut};
use std::convert::TryFrom;
use std::error::Error;
use std::io;
use std::marker;

View File

@@ -1,7 +1,7 @@
[package]
name = "postgres-types2"
version = "0.1.0"
edition = "2021"
edition = "2018"
license = "MIT/Apache-2.0"
[dependencies]

View File

@@ -2,7 +2,8 @@
//!
//! This crate is used by the `tokio-postgres` and `postgres` crates. You normally don't need to depend directly on it
//! unless you want to define your own `ToSql` or `FromSql` definitions.
#![warn(clippy::all, missing_docs)]
#![doc(html_root_url = "https://docs.rs/postgres-types/0.2")]
#![warn(clippy::all, rust_2018_idioms, missing_docs)]
use fallible_iterator::FallibleIterator;
use postgres_protocol2::types;

View File

@@ -1,7 +1,7 @@
[package]
name = "tokio-postgres2"
version = "0.1.0"
edition = "2021"
edition = "2018"
license = "MIT/Apache-2.0"
[dependencies]

View File

@@ -9,7 +9,7 @@ use std::io;
pub(crate) async fn cancel_query<T>(
config: Option<SocketConfig>,
ssl_mode: SslMode,
mut tls: T,
tls: T,
process_id: i32,
secret_key: i32,
) -> Result<(), Error>

View File

@@ -10,7 +10,7 @@ use tokio::net::TcpStream;
use tokio::sync::mpsc;
pub async fn connect<T>(
mut tls: T,
tls: T,
config: &Config,
) -> Result<(Client, Connection<TcpStream, T::Stream>), Error>
where

View File

@@ -33,12 +33,8 @@ pub struct Response {
#[derive(PartialEq, Debug)]
enum State {
Active,
Closing,
}
enum WriteReady {
Terminating,
WaitingOnRead,
Closing,
}
/// A connection to a PostgreSQL database.
@@ -55,6 +51,7 @@ pub struct Connection<S, T> {
/// HACK: we need this in the Neon Proxy to forward params.
pub parameters: HashMap<String, String>,
receiver: mpsc::UnboundedReceiver<Request>,
pending_request: Option<RequestMessages>,
pending_responses: VecDeque<BackendMessage>,
responses: VecDeque<Response>,
state: State,
@@ -75,6 +72,7 @@ where
stream,
parameters,
receiver,
pending_request: None,
pending_responses,
responses: VecDeque::new(),
state: State::Active,
@@ -95,23 +93,26 @@ where
.map(|o| o.map(|r| r.map_err(Error::io)))
}
/// Read and process messages from the connection to postgres.
/// client <- postgres
fn poll_read(&mut self, cx: &mut Context<'_>) -> Poll<Result<AsyncMessage, Error>> {
fn poll_read(&mut self, cx: &mut Context<'_>) -> Result<Option<AsyncMessage>, Error> {
if self.state != State::Active {
trace!("poll_read: done");
return Ok(None);
}
loop {
let message = match self.poll_response(cx)? {
Poll::Ready(Some(message)) => message,
Poll::Ready(None) => return Poll::Ready(Err(Error::closed())),
Poll::Ready(None) => return Err(Error::closed()),
Poll::Pending => {
trace!("poll_read: waiting on response");
return Poll::Pending;
return Ok(None);
}
};
let (mut messages, request_complete) = match message {
BackendMessage::Async(Message::NoticeResponse(body)) => {
let error = DbError::parse(&mut body.fields()).map_err(Error::parse)?;
return Poll::Ready(Ok(AsyncMessage::Notice(error)));
return Ok(Some(AsyncMessage::Notice(error)));
}
BackendMessage::Async(Message::NotificationResponse(body)) => {
let notification = Notification {
@@ -119,7 +120,7 @@ where
channel: body.channel().map_err(Error::parse)?.to_string(),
payload: body.message().map_err(Error::parse)?.to_string(),
};
return Poll::Ready(Ok(AsyncMessage::Notification(notification)));
return Ok(Some(AsyncMessage::Notification(notification)));
}
BackendMessage::Async(Message::ParameterStatus(body)) => {
self.parameters.insert(
@@ -138,10 +139,8 @@ where
let mut response = match self.responses.pop_front() {
Some(response) => response,
None => match messages.next().map_err(Error::parse)? {
Some(Message::ErrorResponse(error)) => {
return Poll::Ready(Err(Error::db(error)))
}
_ => return Poll::Ready(Err(Error::unexpected_message())),
Some(Message::ErrorResponse(error)) => return Err(Error::db(error)),
_ => return Err(Error::unexpected_message()),
},
};
@@ -165,14 +164,18 @@ where
request_complete,
});
trace!("poll_read: waiting on sender");
return Poll::Pending;
return Ok(None);
}
}
}
}
/// Fetch the next client request and enqueue the response sender.
fn poll_request(&mut self, cx: &mut Context<'_>) -> Poll<Option<RequestMessages>> {
if let Some(messages) = self.pending_request.take() {
trace!("retrying pending request");
return Poll::Ready(Some(messages));
}
if self.receiver.is_closed() {
return Poll::Ready(None);
}
@@ -190,80 +193,74 @@ where
}
}
/// Process client requests and write them to the postgres connection, flushing if necessary.
/// client -> postgres
fn poll_write(&mut self, cx: &mut Context<'_>) -> Poll<Result<WriteReady, Error>> {
fn poll_write(&mut self, cx: &mut Context<'_>) -> Result<bool, Error> {
loop {
if self.state == State::Closing {
trace!("poll_write: done");
return Ok(false);
}
if Pin::new(&mut self.stream)
.poll_ready(cx)
.map_err(Error::io)?
.is_pending()
{
trace!("poll_write: waiting on socket");
// poll_ready is self-flushing.
return Poll::Pending;
return Ok(false);
}
match self.poll_request(cx) {
// send the message to postgres
Poll::Ready(Some(RequestMessages::Single(request))) => {
Pin::new(&mut self.stream)
.start_send(request)
.map_err(Error::io)?;
}
// No more messages from the client, and no more responses to wait for.
// Send a terminate message to postgres
Poll::Ready(None) if self.responses.is_empty() => {
let request = match self.poll_request(cx) {
Poll::Ready(Some(request)) => request,
Poll::Ready(None) if self.responses.is_empty() && self.state == State::Active => {
trace!("poll_write: at eof, terminating");
self.state = State::Terminating;
let mut request = BytesMut::new();
frontend::terminate(&mut request);
let request = FrontendMessage::Raw(request.freeze());
Pin::new(&mut self.stream)
.start_send(request)
.map_err(Error::io)?;
trace!("poll_write: sent eof, closing");
trace!("poll_write: done");
return Poll::Ready(Ok(WriteReady::Terminating));
RequestMessages::Single(FrontendMessage::Raw(request.freeze()))
}
// No more messages from the client, but there are still some responses to wait for.
Poll::Ready(None) => {
trace!(
"poll_write: at eof, pending responses {}",
self.responses.len()
);
ready!(self.poll_flush(cx))?;
return Poll::Ready(Ok(WriteReady::WaitingOnRead));
return Ok(true);
}
// Still waiting for a message from the client.
Poll::Pending => {
trace!("poll_write: waiting on request");
ready!(self.poll_flush(cx))?;
return Poll::Pending;
return Ok(true);
}
};
match request {
RequestMessages::Single(request) => {
Pin::new(&mut self.stream)
.start_send(request)
.map_err(Error::io)?;
if self.state == State::Terminating {
trace!("poll_write: sent eof, closing");
self.state = State::Closing;
}
}
}
}
}
fn poll_flush(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
fn poll_flush(&mut self, cx: &mut Context<'_>) -> Result<(), Error> {
match Pin::new(&mut self.stream)
.poll_flush(cx)
.map_err(Error::io)?
{
Poll::Ready(()) => {
trace!("poll_flush: flushed");
Poll::Ready(Ok(()))
}
Poll::Pending => {
trace!("poll_flush: waiting on socket");
Poll::Pending
}
Poll::Ready(()) => trace!("poll_flush: flushed"),
Poll::Pending => trace!("poll_flush: waiting on socket"),
}
Ok(())
}
fn poll_shutdown(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
if self.state != State::Closing {
return Poll::Pending;
}
match Pin::new(&mut self.stream)
.poll_close(cx)
.map_err(Error::io)?
@@ -292,30 +289,18 @@ where
&mut self,
cx: &mut Context<'_>,
) -> Poll<Option<Result<AsyncMessage, Error>>> {
if self.state != State::Closing {
// if the state is still active, try read from and write to postgres.
let message = self.poll_read(cx)?;
let closing = self.poll_write(cx)?;
if let Poll::Ready(WriteReady::Terminating) = closing {
self.state = State::Closing;
}
if let Poll::Ready(message) = message {
return Poll::Ready(Some(Ok(message)));
}
// poll_read returned Pending.
// poll_write returned Pending or Ready(WriteReady::WaitingOnRead).
// if poll_write returned Ready(WriteReady::WaitingOnRead), then we are waiting to read more data from postgres.
if self.state != State::Closing {
return Poll::Pending;
}
let message = self.poll_read(cx)?;
let want_flush = self.poll_write(cx)?;
if want_flush {
self.poll_flush(cx)?;
}
match self.poll_shutdown(cx) {
Poll::Ready(Ok(())) => Poll::Ready(None),
Poll::Ready(Err(e)) => Poll::Ready(Some(Err(e))),
Poll::Pending => Poll::Pending,
match message {
Some(message) => Poll::Ready(Some(Ok(message))),
None => match self.poll_shutdown(cx) {
Poll::Ready(Ok(())) => Poll::Ready(None),
Poll::Ready(Err(e)) => Poll::Ready(Some(Err(e))),
Poll::Pending => Poll::Pending,
},
}
}
}

View File

@@ -1,5 +1,5 @@
//! An asynchronous, pipelined, PostgreSQL client.
#![warn(clippy::all)]
#![warn(rust_2018_idioms, clippy::all)]
pub use crate::cancel_token::CancelToken;
pub use crate::client::{Client, SocketConfig};

View File

@@ -46,7 +46,7 @@ pub trait MakeTlsConnect<S> {
/// Creates a new `TlsConnect`or.
///
/// The domain name is provided for certificate verification and SNI.
fn make_tls_connect(&mut self, domain: &str) -> Result<Self::TlsConnect, Self::Error>;
fn make_tls_connect(self, domain: &str) -> Result<Self::TlsConnect, Self::Error>;
}
/// An asynchronous function wrapping a stream in a TLS session.
@@ -84,7 +84,7 @@ impl<S> MakeTlsConnect<S> for NoTls {
type TlsConnect = NoTls;
type Error = NoTlsError;
fn make_tls_connect(&mut self, _: &str) -> Result<NoTls, NoTlsError> {
fn make_tls_connect(self, _: &str) -> Result<NoTls, NoTlsError> {
Ok(NoTls)
}
}

View File

@@ -9,7 +9,7 @@ use serde::{Deserialize, Serialize};
use tokio_util::sync::CancellationToken;
use tracing::*;
/// Declare a failpoint that can use to `pause` failpoint action.
/// Declare a failpoint that can use the `pause` failpoint action.
/// We don't want to block the executor thread, hence, spawn_blocking + await.
#[macro_export]
macro_rules! pausable_failpoint {
@@ -181,7 +181,7 @@ pub async fn failpoints_handler(
) -> Result<Response<Body>, ApiError> {
if !fail::has_failpoints() {
return Err(ApiError::BadRequest(anyhow::anyhow!(
"Cannot manage failpoints because neon was compiled without failpoints support"
"Cannot manage failpoints because storage was compiled without failpoints support"
)));
}

View File

@@ -53,12 +53,10 @@ project_build_tag!(BUILD_TAG);
#[global_allocator]
static GLOBAL: tikv_jemallocator::Jemalloc = tikv_jemallocator::Jemalloc;
// Configure jemalloc to sample allocations for profiles every 1 MB (1 << 20).
// TODO: disabled because concurrent CPU profiles cause seg faults. See:
// https://github.com/neondatabase/neon/issues/10225.
//#[allow(non_upper_case_globals)]
//#[export_name = "malloc_conf"]
//pub static malloc_conf: &[u8] = b"prof:true,prof_active:true,lg_prof_sample:20\0";
/// Configure jemalloc to sample allocations for profiles every 1 MB (1 << 20).
#[allow(non_upper_case_globals)]
#[export_name = "malloc_conf"]
pub static malloc_conf: &[u8] = b"prof:true,prof_active:true,lg_prof_sample:20\0";
const PID_FILE_NAME: &str = "pageserver.pid";

View File

@@ -97,8 +97,8 @@ use crate::tenant::{LogicalSizeCalculationCause, PageReconstructError};
use crate::DEFAULT_PG_VERSION;
use crate::{disk_usage_eviction_task, tenant};
use pageserver_api::models::{
CompactInfoResponse, StatusResponse, TenantConfigRequest, TenantInfo, TimelineCreateRequest,
TimelineGcRequest, TimelineInfo,
StatusResponse, TenantConfigRequest, TenantInfo, TimelineCreateRequest, TimelineGcRequest,
TimelineInfo,
};
use utils::{
auth::SwappableJwtAuth,
@@ -2039,34 +2039,6 @@ async fn timeline_cancel_compact_handler(
.await
}
// Get compact info of a timeline
async fn timeline_compact_info_handler(
request: Request<Body>,
_cancel: CancellationToken,
) -> Result<Response<Body>, ApiError> {
let tenant_shard_id: TenantShardId = parse_request_param(&request, "tenant_shard_id")?;
let timeline_id: TimelineId = parse_request_param(&request, "timeline_id")?;
check_permission(&request, Some(tenant_shard_id.tenant_id))?;
let state = get_state(&request);
async {
let tenant = state
.tenant_manager
.get_attached_tenant_shard(tenant_shard_id)?;
let res = tenant.get_scheduled_compaction_tasks(timeline_id);
let mut resp = Vec::new();
for item in res {
resp.push(CompactInfoResponse {
compact_key_range: item.compact_key_range,
compact_lsn_range: item.compact_lsn_range,
sub_compaction: item.sub_compaction,
});
}
json_response(StatusCode::OK, resp)
}
.instrument(info_span!("timeline_compact_info", tenant_id = %tenant_shard_id.tenant_id, shard_id = %tenant_shard_id.shard_slug(), %timeline_id))
.await
}
// Run compaction immediately on given timeline.
async fn timeline_compact_handler(
mut request: Request<Body>,
@@ -3428,10 +3400,6 @@ pub fn make_router(
"/v1/tenant/:tenant_shard_id/timeline/:timeline_id/do_gc",
|r| api_handler(r, timeline_gc_handler),
)
.get(
"/v1/tenant/:tenant_shard_id/timeline/:timeline_id/compact",
|r| api_handler(r, timeline_compact_info_handler),
)
.put(
"/v1/tenant/:tenant_shard_id/timeline/:timeline_id/compact",
|r| api_handler(r, timeline_compact_handler),

View File

@@ -3,7 +3,7 @@ use metrics::{
register_counter_vec, register_gauge_vec, register_histogram, register_histogram_vec,
register_int_counter, register_int_counter_pair_vec, register_int_counter_vec,
register_int_gauge, register_int_gauge_vec, register_uint_gauge, register_uint_gauge_vec,
Counter, CounterVec, Gauge, GaugeVec, Histogram, HistogramVec, IntCounter, IntCounterPair,
Counter, CounterVec, GaugeVec, Histogram, HistogramVec, IntCounter, IntCounterPair,
IntCounterPairVec, IntCounterVec, IntGauge, IntGaugeVec, UIntGauge, UIntGaugeVec,
};
use once_cell::sync::Lazy;
@@ -445,15 +445,6 @@ pub(crate) static WAIT_LSN_TIME: Lazy<Histogram> = Lazy::new(|| {
.expect("failed to define a metric")
});
static FLUSH_WAIT_UPLOAD_TIME: Lazy<GaugeVec> = Lazy::new(|| {
register_gauge_vec!(
"pageserver_flush_wait_upload_seconds",
"Time spent waiting for preceding uploads during layer flush",
&["tenant_id", "shard_id", "timeline_id"]
)
.expect("failed to define a metric")
});
static LAST_RECORD_LSN: Lazy<IntGaugeVec> = Lazy::new(|| {
register_int_gauge_vec!(
"pageserver_last_record_lsn",
@@ -2586,7 +2577,6 @@ pub(crate) struct TimelineMetrics {
shard_id: String,
timeline_id: String,
pub flush_time_histo: StorageTimeMetrics,
pub flush_wait_upload_time_gauge: Gauge,
pub compact_time_histo: StorageTimeMetrics,
pub create_images_time_histo: StorageTimeMetrics,
pub logical_size_histo: StorageTimeMetrics,
@@ -2632,9 +2622,6 @@ impl TimelineMetrics {
&shard_id,
&timeline_id,
);
let flush_wait_upload_time_gauge = FLUSH_WAIT_UPLOAD_TIME
.get_metric_with_label_values(&[&tenant_id, &shard_id, &timeline_id])
.unwrap();
let compact_time_histo = StorageTimeMetrics::new(
StorageTimeOperation::Compact,
&tenant_id,
@@ -2780,7 +2767,6 @@ impl TimelineMetrics {
shard_id,
timeline_id,
flush_time_histo,
flush_wait_upload_time_gauge,
compact_time_histo,
create_images_time_histo,
logical_size_histo,
@@ -2830,14 +2816,6 @@ impl TimelineMetrics {
self.resident_physical_size_gauge.get()
}
pub(crate) fn flush_wait_upload_time_gauge_add(&self, duration: f64) {
self.flush_wait_upload_time_gauge.add(duration);
crate::metrics::FLUSH_WAIT_UPLOAD_TIME
.get_metric_with_label_values(&[&self.tenant_id, &self.shard_id, &self.timeline_id])
.unwrap()
.add(duration);
}
pub(crate) fn shutdown(&self) {
let was_shutdown = self
.shutdown
@@ -2855,7 +2833,6 @@ impl TimelineMetrics {
let shard_id = &self.shard_id;
let _ = LAST_RECORD_LSN.remove_label_values(&[tenant_id, shard_id, timeline_id]);
let _ = DISK_CONSISTENT_LSN.remove_label_values(&[tenant_id, shard_id, timeline_id]);
let _ = FLUSH_WAIT_UPLOAD_TIME.remove_label_values(&[tenant_id, shard_id, timeline_id]);
let _ = STANDBY_HORIZON.remove_label_values(&[tenant_id, shard_id, timeline_id]);
{
RESIDENT_PHYSICAL_SIZE_GLOBAL.sub(self.resident_physical_size_get());

View File

@@ -3122,23 +3122,6 @@ impl Tenant {
}
}
pub(crate) fn get_scheduled_compaction_tasks(
&self,
timeline_id: TimelineId,
) -> Vec<CompactOptions> {
use itertools::Itertools;
let guard = self.scheduled_compaction_tasks.lock().unwrap();
guard
.get(&timeline_id)
.map(|tline_pending_tasks| {
tline_pending_tasks
.iter()
.map(|x| x.options.clone())
.collect_vec()
})
.unwrap_or_default()
}
/// Schedule a compaction task for a timeline.
pub(crate) async fn schedule_compaction(
&self,
@@ -5776,13 +5759,13 @@ mod tests {
use timeline::{CompactOptions, DeltaLayerTestDesc};
use utils::id::TenantId;
#[cfg(feature = "testing")]
use models::CompactLsnRange;
#[cfg(feature = "testing")]
use pageserver_api::record::NeonWalRecord;
#[cfg(feature = "testing")]
use timeline::compaction::{KeyHistoryRetention, KeyLogAtLsn};
#[cfg(feature = "testing")]
use timeline::CompactLsnRange;
#[cfg(feature = "testing")]
use timeline::GcInfo;
static TEST_KEY: Lazy<Key> =
@@ -9651,7 +9634,7 @@ mod tests {
#[cfg(feature = "testing")]
#[tokio::test]
async fn test_simple_bottom_most_compaction_on_branch() -> anyhow::Result<()> {
use models::CompactLsnRange;
use timeline::CompactLsnRange;
let harness = TenantHarness::create("test_simple_bottom_most_compaction_on_branch").await?;
let (tenant, ctx) = harness.load().await;

View File

@@ -1,15 +1,12 @@
use std::collections::BTreeSet;
use itertools::Itertools;
use pageserver_compaction::helpers::overlaps_with;
use super::storage_layer::LayerName;
/// Checks whether a layer map is valid (i.e., is a valid result of the current compaction algorithm if nothing goes wrong).
///
/// The function implements a fast path check and a slow path check.
///
/// The fast path checks if we can split the LSN range of a delta layer only at the LSNs of the delta layers. For example,
/// The function checks if we can split the LSN range of a delta layer only at the LSNs of the delta layers. For example,
///
/// ```plain
/// | | | |
@@ -28,47 +25,31 @@ use super::storage_layer::LayerName;
/// | | | 4 | | |
///
/// If layer 2 and 4 contain the same single key, this is also a valid layer map.
///
/// However, if a partial compaction is still going on, it is possible that we get a layer map not satisfying the above condition.
/// Therefore, we fallback to simply check if any of the two delta layers overlap. (See "A slow path...")
pub fn check_valid_layermap(metadata: &[LayerName]) -> Option<String> {
let mut lsn_split_point = BTreeSet::new(); // TODO: use a better data structure (range tree / range set?)
let mut all_delta_layers = Vec::new();
for name in metadata {
if let LayerName::Delta(layer) = name {
all_delta_layers.push(layer.clone());
if layer.key_range.start.next() != layer.key_range.end {
all_delta_layers.push(layer.clone());
}
}
}
for layer in &all_delta_layers {
if layer.key_range.start.next() != layer.key_range.end {
let lsn_range = &layer.lsn_range;
lsn_split_point.insert(lsn_range.start);
lsn_split_point.insert(lsn_range.end);
}
let lsn_range = &layer.lsn_range;
lsn_split_point.insert(lsn_range.start);
lsn_split_point.insert(lsn_range.end);
}
for (idx, layer) in all_delta_layers.iter().enumerate() {
if layer.key_range.start.next() == layer.key_range.end {
continue;
}
for layer in &all_delta_layers {
let lsn_range = layer.lsn_range.clone();
let intersects = lsn_split_point.range(lsn_range).collect_vec();
if intersects.len() > 1 {
// A slow path to check if the layer intersects with any other delta layer.
for (other_idx, other_layer) in all_delta_layers.iter().enumerate() {
if other_idx == idx {
// do not check self intersects with self
continue;
}
if overlaps_with(&layer.lsn_range, &other_layer.lsn_range)
&& overlaps_with(&layer.key_range, &other_layer.key_range)
{
let err = format!(
"layer violates the layer map LSN split assumption: layer {} intersects with layer {}",
layer, other_layer
);
return Some(err);
}
}
let err = format!(
"layer violates the layer map LSN split assumption: layer {} intersects with LSN [{}]",
layer,
intersects.into_iter().map(|lsn| lsn.to_string()).join(", ")
);
return Some(err);
}
}
None

View File

@@ -31,9 +31,9 @@ use pageserver_api::{
},
keyspace::{KeySpaceAccum, KeySpaceRandomAccum, SparseKeyPartitioning},
models::{
CompactKeyRange, CompactLsnRange, CompactionAlgorithm, CompactionAlgorithmSettings,
DownloadRemoteLayersTaskInfo, DownloadRemoteLayersTaskSpawnRequest, EvictionPolicy,
InMemoryLayerInfo, LayerMapInfo, LsnLease, TimelineState,
CompactionAlgorithm, CompactionAlgorithmSettings, DownloadRemoteLayersTaskInfo,
DownloadRemoteLayersTaskSpawnRequest, EvictionPolicy, InMemoryLayerInfo, LayerMapInfo,
LsnLease, TimelineState,
},
reltag::BlockNumber,
shard::{ShardIdentity, ShardNumber, TenantShardId},
@@ -144,19 +144,15 @@ use self::layer_manager::LayerManager;
use self::logical_size::LogicalSize;
use self::walreceiver::{WalReceiver, WalReceiverConf};
use super::config::TenantConf;
use super::remote_timeline_client::index::IndexPart;
use super::remote_timeline_client::RemoteTimelineClient;
use super::secondary::heatmap::{HeatMapLayer, HeatMapTimeline};
use super::storage_layer::{LayerFringe, LayerVisibilityHint, ReadableLayer};
use super::upload_queue::NotInitialized;
use super::GcError;
use super::{
config::TenantConf, storage_layer::LayerVisibilityHint, upload_queue::NotInitialized,
MaybeOffloaded,
};
use super::{debug_assert_current_span_has_tenant_and_timeline_id, AttachedTenantConf};
use super::{remote_timeline_client::index::IndexPart, storage_layer::LayerFringe};
use super::{
remote_timeline_client::RemoteTimelineClient, remote_timeline_client::WaitCompletionError,
storage_layer::ReadableLayer,
};
use super::{
secondary::heatmap::{HeatMapLayer, HeatMapTimeline},
GcError,
debug_assert_current_span_has_tenant_and_timeline_id, AttachedTenantConf, MaybeOffloaded,
};
#[cfg(test)]
@@ -792,6 +788,63 @@ pub(crate) struct CompactRequest {
pub sub_compaction_max_job_size_mb: Option<u64>,
}
#[serde_with::serde_as]
#[derive(Debug, Clone, serde::Deserialize)]
pub(crate) struct CompactLsnRange {
pub start: Lsn,
pub end: Lsn,
}
#[serde_with::serde_as]
#[derive(Debug, Clone, serde::Deserialize)]
pub(crate) struct CompactKeyRange {
#[serde_as(as = "serde_with::DisplayFromStr")]
pub start: Key,
#[serde_as(as = "serde_with::DisplayFromStr")]
pub end: Key,
}
impl From<Range<Lsn>> for CompactLsnRange {
fn from(range: Range<Lsn>) -> Self {
Self {
start: range.start,
end: range.end,
}
}
}
impl From<Range<Key>> for CompactKeyRange {
fn from(range: Range<Key>) -> Self {
Self {
start: range.start,
end: range.end,
}
}
}
impl From<CompactLsnRange> for Range<Lsn> {
fn from(range: CompactLsnRange) -> Self {
range.start..range.end
}
}
impl From<CompactKeyRange> for Range<Key> {
fn from(range: CompactKeyRange) -> Self {
range.start..range.end
}
}
impl CompactLsnRange {
#[cfg(test)]
#[cfg(feature = "testing")]
pub fn above(lsn: Lsn) -> Self {
Self {
start: lsn,
end: Lsn::MAX,
}
}
}
#[derive(Debug, Clone, Default)]
pub(crate) struct CompactOptions {
pub flags: EnumSet<CompactFlags>,
@@ -3840,24 +3893,6 @@ impl Timeline {
// release lock on 'layers'
};
// Backpressure mechanism: wait with continuation of the flush loop until we have uploaded all layer files.
// This makes us refuse ingest until the new layers have been persisted to the remote
let start = Instant::now();
self.remote_client
.wait_completion()
.await
.map_err(|e| match e {
WaitCompletionError::UploadQueueShutDownOrStopped
| WaitCompletionError::NotInitialized(
NotInitialized::ShuttingDown | NotInitialized::Stopped,
) => FlushLayerError::Cancelled,
WaitCompletionError::NotInitialized(NotInitialized::Uninitialized) => {
FlushLayerError::Other(anyhow!(e).into())
}
})?;
let duration = start.elapsed().as_secs_f64();
self.metrics.flush_wait_upload_time_gauge_add(duration);
// FIXME: between create_delta_layer and the scheduling of the upload in `update_metadata_file`,
// a compaction can delete the file and then it won't be available for uploads any more.
// We still schedule the upload, resulting in an error, but ideally we'd somehow avoid this

View File

@@ -29,7 +29,6 @@ use utils::id::TimelineId;
use crate::context::{AccessStatsBehavior, RequestContext, RequestContextBuilder};
use crate::page_cache;
use crate::statvfs::Statvfs;
use crate::tenant::checks::check_valid_layermap;
use crate::tenant::remote_timeline_client::WaitCompletionError;
use crate::tenant::storage_layer::batch_split_writer::{
BatchWriterResult, SplitDeltaLayerWriter, SplitImageLayerWriter,
@@ -1824,7 +1823,7 @@ impl Timeline {
// by estimating the amount of files read for a compaction job. We should also partition on LSN.
let ((dense_ks, sparse_ks), _) = {
let Ok(partition) = self.partitioning.try_lock() else {
bail!("failed to acquire partition lock during gc-compaction");
bail!("failed to acquire partition lock");
};
partition.clone()
};
@@ -2157,14 +2156,15 @@ impl Timeline {
// Step 1: construct a k-merge iterator over all layers.
// Also, verify if the layer map can be split by drawing a horizontal line at every LSN start/end split point.
let layer_names = job_desc
.selected_layers
.iter()
.map(|layer| layer.layer_desc().layer_name())
.collect_vec();
if let Some(err) = check_valid_layermap(&layer_names) {
bail!("gc-compaction layer map check failed because {}, cannot proceed with compaction due to potential data loss", err);
}
// disable the check for now because we need to adjust the check for partial compactions, will enable later.
// let layer_names = job_desc
// .selected_layers
// .iter()
// .map(|layer| layer.layer_desc().layer_name())
// .collect_vec();
// if let Some(err) = check_valid_layermap(&layer_names) {
// warn!("gc-compaction layer map check failed because {}, this is normal if partial compaction is not finished yet", err);
// }
// The maximum LSN we are processing in this compaction loop
let end_lsn = job_desc
.selected_layers
@@ -2546,48 +2546,13 @@ impl Timeline {
);
// Step 3: Place back to the layer map.
// First, do a sanity check to ensure the newly-created layer map does not contain overlaps.
let all_layers = {
let guard = self.layers.read().await;
let layer_map = guard.layer_map()?;
layer_map.iter_historic_layers().collect_vec()
};
let mut final_layers = all_layers
.iter()
.map(|layer| layer.layer_name())
.collect::<HashSet<_>>();
for layer in &layer_selection {
final_layers.remove(&layer.layer_desc().layer_name());
}
for layer in &compact_to {
final_layers.insert(layer.layer_desc().layer_name());
}
let final_layers = final_layers.into_iter().collect_vec();
// TODO: move this check before we call `finish` on image layer writers. However, this will require us to get the layer name before we finish
// the writer, so potentially, we will need a function like `ImageLayerBatchWriter::get_all_pending_layer_keys` to get all the keys that are
// in the writer before finalizing the persistent layers. Now we would leave some dangling layers on the disk if the check fails.
if let Some(err) = check_valid_layermap(&final_layers) {
bail!("gc-compaction layer map check failed after compaction because {}, compaction result not applied to the layer map due to potential data loss", err);
}
// Between the sanity check and this compaction update, there could be new layers being flushed, but it should be fine because we only
// operate on L1 layers.
{
// TODO: sanity check if the layer map is valid (i.e., should not have overlaps)
let mut guard = self.layers.write().await;
guard
.open_mut()?
.finish_gc_compaction(&layer_selection, &compact_to, &self.metrics)
};
// Schedule an index-only upload to update the `latest_gc_cutoff` in the index_part.json.
// Otherwise, after restart, the index_part only contains the old `latest_gc_cutoff` and
// find_gc_cutoffs will try accessing things below the cutoff. TODO: ideally, this should
// be batched into `schedule_compaction_update`.
let disk_consistent_lsn = self.disk_consistent_lsn.load();
self.schedule_uploads(disk_consistent_lsn, None)?;
self.remote_client
.schedule_compaction_update(&layer_selection, &compact_to)?;

View File

@@ -541,7 +541,6 @@ lfc_cache_containsv(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno,
}
else
{
LWLockRelease(lfc_lock);
return found;
}

View File

@@ -827,6 +827,7 @@ pageserver_send(shardno_t shard_no, NeonRequest *request)
{
while (!pageserver_connect(shard_no, shard->n_reconnect_attempts < max_reconnect_attempts ? LOG : ERROR))
{
HandleMainLoopInterrupts();
shard->n_reconnect_attempts += 1;
}
shard->n_reconnect_attempts = 0;

View File

@@ -678,9 +678,6 @@ mod tests {
.await
.unwrap();
// flush the final server message
stream.flush().await.unwrap();
handle.await.unwrap();
}

View File

@@ -10,6 +10,7 @@ use tracing::info;
use super::backend::ComputeCredentialKeys;
use super::{AuthError, PasswordHackPayload};
use crate::config::TlsServerEndPoint;
use crate::context::RequestContext;
use crate::control_plane::AuthSecret;
use crate::intern::EndpointIdInt;
@@ -17,7 +18,6 @@ use crate::sasl;
use crate::scram::threadpool::ThreadPool;
use crate::scram::{self};
use crate::stream::{PqStream, Stream};
use crate::tls::TlsServerEndPoint;
/// Every authentication selector is supposed to implement this trait.
pub(crate) trait AuthMethod {

View File

@@ -27,7 +27,6 @@ use proxy::rate_limiter::{
use proxy::scram::threadpool::ThreadPool;
use proxy::serverless::cancel_set::CancelSet;
use proxy::serverless::{self, GlobalConnPoolOptions};
use proxy::tls::client_config::compute_client_config_with_root_certs;
use proxy::types::RoleName;
use proxy::url::ApiUrl;
@@ -35,6 +34,8 @@ project_git_version!(GIT_VERSION);
project_build_tag!(BUILD_TAG);
use clap::Parser;
use rustls::crypto::ring;
use rustls::RootCertStore;
use thiserror::Error;
use tokio::net::TcpListener;
use tokio::sync::Notify;
@@ -272,9 +273,19 @@ fn build_config(args: &LocalProxyCliArgs) -> anyhow::Result<&'static ProxyConfig
max_response_size_bytes: args.sql_over_http.sql_over_http_max_response_size_bytes,
};
// local_proxy won't use TLS to talk to postgres.
let root_store = RootCertStore::empty();
let client_config =
rustls::ClientConfig::builder_with_provider(Arc::new(ring::default_provider()))
.with_safe_default_protocol_versions()
.expect("ring should support the default protocol versions")
.with_root_certificates(root_store)
.with_no_client_auth();
let compute_config = ComputeConfig {
retry: RetryConfig::parse(RetryConfig::CONNECT_TO_COMPUTE_DEFAULT_VALUES)?,
tls: Arc::new(compute_client_config_with_root_certs()?),
tls: Arc::new(client_config).into(),
timeout: Duration::from_secs(2),
};

View File

@@ -10,12 +10,12 @@ use clap::Arg;
use futures::future::Either;
use futures::TryFutureExt;
use itertools::Itertools;
use proxy::config::TlsServerEndPoint;
use proxy::context::RequestContext;
use proxy::metrics::{Metrics, ThreadPoolMetrics};
use proxy::protocol2::ConnectionInfo;
use proxy::proxy::{copy_bidirectional_client_compute, run_until_cancelled, ErrorSource};
use proxy::stream::{PqStream, Stream};
use proxy::tls::TlsServerEndPoint;
use rustls::crypto::ring;
use rustls::pki_types::PrivateKeyDer;
use tokio::io::{AsyncRead, AsyncWrite};

View File

@@ -3,7 +3,7 @@ use std::pin::pin;
use std::sync::Arc;
use std::time::Duration;
use anyhow::bail;
use anyhow::{bail, Context};
use futures::future::Either;
use proxy::auth::backend::jwt::JwkCache;
use proxy::auth::backend::{AuthRateLimiter, ConsoleRedirectBackend, MaybeOwned};
@@ -24,9 +24,9 @@ use proxy::redis::{elasticache, notifications};
use proxy::scram::threadpool::ThreadPool;
use proxy::serverless::cancel_set::CancelSet;
use proxy::serverless::GlobalConnPoolOptions;
use proxy::tls::client_config::compute_client_config_with_root_certs;
use proxy::{auth, control_plane, http, serverless, usage_metrics};
use remote_storage::RemoteStorageConfig;
use rustls::crypto::ring;
use tokio::net::TcpListener;
use tokio::sync::Mutex;
use tokio::task::JoinSet;
@@ -637,9 +637,20 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
console_redirect_confirmation_timeout: args.webauth_confirmation_timeout,
};
let root_store = load_certs()
.context("loading native tls certificates")?
.clone();
let client_config =
rustls::ClientConfig::builder_with_provider(Arc::new(ring::default_provider()))
.with_safe_default_protocol_versions()
.expect("ring should support the default protocol versions")
.with_root_certificates(root_store)
.with_no_client_auth();
let compute_config = ComputeConfig {
retry: config::RetryConfig::parse(&args.connect_to_compute_retry)?,
tls: Arc::new(compute_client_config_with_root_certs()?),
tls: Arc::new(client_config).into(),
timeout: Duration::from_secs(2),
};
@@ -663,6 +674,18 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
Ok(config)
}
/// Loads the platform's native root certificates into a shared rustls root store.
///
/// Fails if the native certificate loader reported any error; certificates
/// that loaded successfully are added to the store via
/// `add_parsable_certificates`.
pub(crate) fn load_certs() -> anyhow::Result<Arc<rustls::RootCertStore>> {
    let rustls_native_certs::CertificateResult { certs, errors, .. } =
        rustls_native_certs::load_native_certs();

    // Any loader error is treated as fatal rather than silently continuing.
    if !errors.is_empty() {
        bail!("could not parse certificates: {:?}", errors);
    }

    let mut roots = rustls::RootCertStore::empty();
    roots.add_parsable_certificates(certs);
    Ok(Arc::new(roots))
}
/// auth::Backend is created at proxy startup, and lives forever.
fn build_auth_backend(
args: &ProxyCliArgs,

View File

@@ -17,11 +17,11 @@ use crate::config::ComputeConfig;
use crate::error::ReportableError;
use crate::ext::LockExt;
use crate::metrics::{CancellationRequest, CancellationSource, Metrics};
use crate::postgres_rustls::MakeRustlsConnect;
use crate::rate_limiter::LeakyBucketRateLimiter;
use crate::redis::cancellation_publisher::{
CancellationPublisher, CancellationPublisherMut, RedisPublisherClient,
};
use crate::tls::postgres_rustls::MakeRustlsConnect;
pub type CancelMap = Arc<DashMap<CancelKeyData, Option<CancelClosure>>>;
pub type CancellationHandlerMain = CancellationHandler<Option<Arc<Mutex<RedisPublisherClient>>>>;
@@ -271,10 +271,8 @@ impl CancelClosure {
) -> Result<(), CancelError> {
let socket = TcpStream::connect(self.socket_addr).await?;
let mut mk_tls =
crate::tls::postgres_rustls::MakeRustlsConnect::new(compute_config.tls.clone());
let tls = <MakeRustlsConnect as MakeTlsConnect<tokio::net::TcpStream>>::make_tls_connect(
&mut mk_tls,
crate::postgres_rustls::MakeRustlsConnect::new(&compute_config.tls),
&self.hostname,
)
.map_err(|e| {
@@ -326,9 +324,11 @@ impl<P> Drop for Session<P> {
mod tests {
use std::time::Duration;
use rustls::crypto::ring;
use rustls::RootCertStore;
use super::*;
use crate::config::RetryConfig;
use crate::tls::client_config::compute_client_config_with_certs;
fn config() -> ComputeConfig {
let retry = RetryConfig {
@@ -337,9 +337,18 @@ mod tests {
backoff_factor: 2.0,
};
let root_store = RootCertStore::empty();
let client_config =
rustls::ClientConfig::builder_with_provider(Arc::new(ring::default_provider()))
.with_safe_default_protocol_versions()
.expect("ring should support the default protocol versions")
.with_root_certificates(root_store)
.with_no_client_auth();
ComputeConfig {
retry,
tls: Arc::new(compute_client_config_with_certs(std::iter::empty())),
tls: Arc::new(client_config).into(),
timeout: Duration::from_secs(2),
}
}

View File

@@ -22,8 +22,8 @@ use crate::control_plane::errors::WakeComputeError;
use crate::control_plane::messages::MetricsAuxInfo;
use crate::error::{ReportableError, UserFacingError};
use crate::metrics::{Metrics, NumDbConnectionsGuard};
use crate::postgres_rustls::MakeRustlsConnect;
use crate::proxy::neon_option;
use crate::tls::postgres_rustls::MakeRustlsConnect;
use crate::types::Host;
pub const COULD_NOT_CONNECT: &str = "Couldn't connect to compute node";
@@ -193,15 +193,11 @@ impl ConnCfg {
let connect_once = |host, port| {
debug!("trying to connect to compute node at {host}:{port}");
connect_with_timeout(host, port).and_then(|stream| async {
let socket_addr = stream.peer_addr()?;
let socket = socket2::SockRef::from(&stream);
// Disable Nagle's algorithm to not introduce latency between
// client and compute.
socket.set_nodelay(true)?;
connect_with_timeout(host, port).and_then(|socket| async {
let socket_addr = socket.peer_addr()?;
// This prevents load balancer from severing the connection.
socket.set_keepalive(true)?;
Ok((socket_addr, stream))
socket2::SockRef::from(&socket).set_keepalive(true)?;
Ok((socket_addr, socket))
})
};
@@ -225,7 +221,7 @@ impl ConnCfg {
}
}
type RustlsStream = <MakeRustlsConnect as MakeTlsConnect<tokio::net::TcpStream>>::Stream;
type RustlsStream = crate::postgres_rustls::RustlsStream<tokio::net::TcpStream>;
pub(crate) struct PostgresConnection {
/// Socket connected to a compute node.
@@ -255,9 +251,8 @@ impl ConnCfg {
let (socket_addr, stream, host) = self.connect_raw(config.timeout).await?;
drop(pause);
let mut mk_tls = crate::tls::postgres_rustls::MakeRustlsConnect::new(config.tls.clone());
let tls = <MakeRustlsConnect as MakeTlsConnect<tokio::net::TcpStream>>::make_tls_connect(
&mut mk_tls,
crate::postgres_rustls::MakeRustlsConnect::new(&config.tls),
host,
)?;

View File

@@ -1,10 +1,18 @@
use std::collections::{HashMap, HashSet};
use std::str::FromStr;
use std::sync::Arc;
use std::time::Duration;
use anyhow::{bail, ensure, Context, Ok};
use clap::ValueEnum;
use itertools::Itertools;
use remote_storage::RemoteStorageConfig;
use rustls::crypto::ring::{self, sign};
use rustls::pki_types::{CertificateDer, PrivateKeyDer};
use sha2::{Digest, Sha256};
use tokio_rustls::TlsConnector;
use tracing::{error, info};
use x509_parser::oid_registry;
use crate::auth::backend::jwt::JwkCache;
use crate::auth::backend::AuthRateLimiter;
@@ -13,7 +21,6 @@ use crate::rate_limiter::{RateBucketInfo, RateLimitAlgorithm, RateLimiterConfig}
use crate::scram::threadpool::ThreadPool;
use crate::serverless::cancel_set::CancelSet;
use crate::serverless::GlobalConnPoolOptions;
pub use crate::tls::server_config::{configure_tls, TlsConfig};
use crate::types::Host;
pub struct ProxyConfig {
@@ -31,7 +38,7 @@ pub struct ProxyConfig {
pub struct ComputeConfig {
pub retry: RetryConfig,
pub tls: Arc<rustls::ClientConfig>,
pub tls: TlsConnector,
pub timeout: Duration,
}
@@ -52,6 +59,12 @@ pub struct MetricCollectionConfig {
pub backup_metric_collection_config: MetricBackupCollectionConfig,
}
/// Server-side TLS state produced by `configure_tls`.
pub struct TlsConfig {
/// The rustls server configuration, with `cert_resolver` installed as its certificate resolver.
pub config: Arc<rustls::ServerConfig>,
/// Common names of all certificates registered in `cert_resolver`.
pub common_names: HashSet<String>,
/// Resolver that selects a certificate (and its channel binding data) by server name.
pub cert_resolver: Arc<CertResolver>,
}
pub struct HttpConfig {
pub accept_websockets: bool,
pub pool_options: GlobalConnPoolOptions,
@@ -74,6 +87,272 @@ pub struct AuthenticationConfig {
pub console_redirect_confirmation_timeout: tokio::time::Duration,
}
impl TlsConfig {
    /// Returns another shared handle to the underlying rustls server configuration.
    pub fn to_server_config(&self) -> Arc<rustls::ServerConfig> {
        Arc::clone(&self.config)
    }
}
/// <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/include/libpq/pqcomm.h#L159>
pub const PG_ALPN_PROTOCOL: &[u8] = b"postgresql";
/// Configure TLS for the main endpoint.
///
/// The key/certificate pair at `key_path`/`cert_path` is registered as the
/// default certificate. If `certs_dir` is given, every sub-directory of it
/// that contains both `tls.key` and `tls.crt` contributes one additional
/// certificate. When `allow_tls_keylogfile` is set, a `rustls::KeyLogFile`
/// is installed on the resulting server configuration.
pub fn configure_tls(
    key_path: &str,
    cert_path: &str,
    certs_dir: Option<&String>,
    allow_tls_keylogfile: bool,
) -> anyhow::Result<TlsConfig> {
    let mut resolver = CertResolver::new();

    // add default certificate
    resolver.add_cert_path(key_path, cert_path, true)?;

    // add extra certificates
    if let Some(extra_dir) = certs_dir {
        for dir_entry in std::fs::read_dir(extra_dir)? {
            let candidate = dir_entry?.path();
            if !candidate.is_dir() {
                continue;
            }

            // file names aligned with default cert-manager names
            let extra_key = candidate.join("tls.key");
            let extra_cert = candidate.join("tls.crt");
            if !(extra_key.exists() && extra_cert.exists()) {
                continue;
            }

            resolver.add_cert_path(
                &extra_key.to_string_lossy(),
                &extra_cert.to_string_lossy(),
                false,
            )?;
        }
    }

    let common_names = resolver.get_common_names();
    let cert_resolver = Arc::new(resolver);

    // allow TLS 1.2 to be compatible with older client libraries
    let crypto_provider = Arc::new(ring::default_provider());
    let mut server_config = rustls::ServerConfig::builder_with_provider(crypto_provider)
        .with_protocol_versions(&[&rustls::version::TLS13, &rustls::version::TLS12])
        .context("ring should support TLS1.2 and TLS1.3")?
        .with_no_client_auth()
        .with_cert_resolver(cert_resolver.clone());

    server_config.alpn_protocols = vec![PG_ALPN_PROTOCOL.to_vec()];

    if allow_tls_keylogfile {
        // KeyLogFile will check for the SSLKEYLOGFILE environment variable.
        server_config.key_log = Arc::new(rustls::KeyLogFile::new());
    }

    Ok(TlsConfig {
        config: Arc::new(server_config),
        common_names,
        cert_resolver,
    })
}
/// Channel binding parameter
///
/// <https://www.rfc-editor.org/rfc/rfc5929#section-4>
/// Description: The hash of the TLS server's certificate as it
/// appears, octet for octet, in the server's Certificate message. Note
/// that the Certificate message contains a certificate_list, in which
/// the first element is the server's certificate.
///
/// The hash function is to be selected as follows:
///
/// * if the certificate's signatureAlgorithm uses a single hash
/// function, and that hash function is either MD5 or SHA-1, then use SHA-256;
///
/// * if the certificate's signatureAlgorithm uses a single hash
/// function and that hash function neither MD5 nor SHA-1, then use
/// the hash function associated with the certificate's
/// signatureAlgorithm;
///
/// * if the certificate's signatureAlgorithm uses no hash functions or
/// uses multiple hash functions, then this channel binding type's
/// channel bindings are undefined at this time (updates to is channel
/// binding type may occur to address this issue if it ever arises).
#[derive(Debug, Clone, Copy)]
pub enum TlsServerEndPoint {
/// SHA-256 digest of the DER-encoded server certificate.
Sha256([u8; 32]),
/// The certificate's signature algorithm is not one of the SHA-256 based
/// algorithms recognised by `TlsServerEndPoint::new`, so no channel binding
/// data is available.
Undefined,
}
impl TlsServerEndPoint {
pub fn new(cert: &CertificateDer<'_>) -> anyhow::Result<Self> {
let sha256_oids = [
// I'm explicitly not adding MD5 or SHA1 here... They're bad.
oid_registry::OID_SIG_ECDSA_WITH_SHA256,
oid_registry::OID_PKCS1_SHA256WITHRSA,
];
let pem = x509_parser::parse_x509_certificate(cert)
.context("Failed to parse PEM object from cerficiate")?
.1;
info!(subject = %pem.subject, "parsing TLS certificate");
let reg = oid_registry::OidRegistry::default().with_all_crypto();
let oid = pem.signature_algorithm.oid();
let alg = reg.get(oid);
if sha256_oids.contains(oid) {
let tls_server_end_point: [u8; 32] = Sha256::new().chain_update(cert).finalize().into();
info!(subject = %pem.subject, signature_algorithm = alg.map(|a| a.description()), tls_server_end_point = %base64::encode(tls_server_end_point), "determined channel binding");
Ok(Self::Sha256(tls_server_end_point))
} else {
error!(subject = %pem.subject, signature_algorithm = alg.map(|a| a.description()), "unknown channel binding");
Ok(Self::Undefined)
}
}
pub fn supported(&self) -> bool {
!matches!(self, TlsServerEndPoint::Undefined)
}
}
/// Holds the server certificates and picks one by server name.
#[derive(Default, Debug)]
pub struct CertResolver {
/// Certificates keyed by the common name parsed from the certificate subject
/// (with a leading `*.` or `apiauth.` removed), each paired with its channel
/// binding data.
certs: HashMap<String, (Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint)>,
/// Certificate returned when no server name is supplied; set by adding a
/// certificate with `is_default = true`.
default: Option<(Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint)>,
}
impl CertResolver {
    /// Creates an empty resolver with no certificates and no default.
    pub fn new() -> Self {
        Self::default()
    }

    /// Reads a PEM private key and a PEM certificate chain from disk and
    /// registers them via [`CertResolver::add_cert`].
    ///
    /// Fails if either file cannot be read or parsed, or if `add_cert`
    /// rejects the pair.
    fn add_cert_path(
        &mut self,
        key_path: &str,
        cert_path: &str,
        is_default: bool,
    ) -> anyhow::Result<()> {
        let priv_key = {
            let key_bytes = std::fs::read(key_path)
                .with_context(|| format!("Failed to read TLS keys at '{key_path}'"))?;
            // The first `?` handles a PEM parsing error, the second handles
            // a file that contains no private key at all.
            rustls_pemfile::private_key(&mut &key_bytes[..])
                .with_context(|| format!("Failed to parse TLS keys at '{key_path}'"))?
                .with_context(|| format!("Failed to parse TLS keys at '{key_path}'"))?
        };

        // `with_context` defers building the message to the error path.
        let cert_chain_bytes = std::fs::read(cert_path)
            .with_context(|| format!("Failed to read TLS cert file at '{cert_path}.'"))?;

        let cert_chain = {
            rustls_pemfile::certs(&mut &cert_chain_bytes[..])
                .try_collect()
                .with_context(|| {
                    format!("Failed to read TLS certificate chain from bytes from file at '{cert_path}'.")
                })?
        };

        self.add_cert(priv_key, cert_chain, is_default)
    }

    /// Registers a private key and certificate chain under the common name of
    /// the chain's first certificate. When `is_default` is set, the pair also
    /// becomes the certificate served when no server name is supplied.
    ///
    /// Fails if the key type is unsupported, the chain is empty, the first
    /// certificate cannot be parsed, or its subject does not start with `CN=`.
    pub fn add_cert(
        &mut self,
        priv_key: PrivateKeyDer<'static>,
        cert_chain: Vec<CertificateDer<'static>>,
        is_default: bool,
    ) -> anyhow::Result<()> {
        let key = sign::any_supported_type(&priv_key).context("invalid private key")?;

        // A certificate file without any PEM certificate blocks yields an empty
        // chain; report that as an error instead of panicking on `cert_chain[0]`.
        let first_cert = cert_chain
            .first()
            .context("certificate chain is empty")?;
        let tls_server_end_point = TlsServerEndPoint::new(first_cert)?;
        let pem = x509_parser::parse_x509_certificate(first_cert)
            .context("Failed to parse PEM object from cerficiate")?
            .1;

        let common_name = pem.subject().to_string();

        // We need to get the canonical name for this certificate so we can match them against any domain names
        // seen within the proxy codebase.
        //
        // In scram-proxy we use wildcard certificates only, with the database endpoint as the wildcard subdomain, taken from SNI.
        // We need to remove the wildcard prefix for the purposes of certificate selection.
        //
        // auth-broker does not use SNI and instead uses the Neon-Connection-String header.
        // Auth broker has the subdomain `apiauth` we need to remove for the purposes of validating the Neon-Connection-String.
        //
        // Console Redirect proxy does not use any wildcard domains and does not need any certificate selection or conn string
        // validation, so we can continue with any common-name
        let common_name = if let Some(s) = common_name.strip_prefix("CN=*.") {
            s.to_string()
        } else if let Some(s) = common_name.strip_prefix("CN=apiauth.") {
            s.to_string()
        } else if let Some(s) = common_name.strip_prefix("CN=") {
            s.to_string()
        } else {
            bail!("Failed to parse common name from certificate")
        };

        let cert = Arc::new(rustls::sign::CertifiedKey::new(cert_chain, key));

        if is_default {
            self.default = Some((cert.clone(), tls_server_end_point));
        }

        self.certs.insert(common_name, (cert, tls_server_end_point));

        Ok(())
    }

    /// Returns the set of common names for which a certificate is registered.
    pub fn get_common_names(&self) -> HashSet<String> {
        self.certs.keys().cloned().collect()
    }
}
impl rustls::server::ResolvesServerCert for CertResolver {
fn resolve(
&self,
client_hello: rustls::server::ClientHello<'_>,
) -> Option<Arc<rustls::sign::CertifiedKey>> {
self.resolve(client_hello.server_name()).map(|x| x.0)
}
}
impl CertResolver {
    /// Looks up the certificate and channel binding data for `server_name`.
    ///
    /// With a server name, the name itself is tried first, then each parent
    /// domain obtained by dropping the leftmost label, until a registered
    /// certificate matches; `None` if nothing matches. Without a server name
    /// the default certificate (if any) is returned.
    pub fn resolve(
        &self,
        server_name: Option<&str>,
    ) -> Option<(Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint)> {
        let Some(sni_name) = server_name else {
            // No SNI, use the default certificate, otherwise we can't get to
            // options parameter which can be used to set endpoint name too.
            // That means that non-SNI flow will not work for CNAME domains in
            // verify-full mode.
            //
            // If that will be a problem we can:
            //
            // a) Instead of multi-cert approach use single cert with extra
            //    domains listed in Subject Alternative Name (SAN).
            // b) Deploy separate proxy instances for extra domains.
            return self.default.clone();
        };

        // Walk from the full name up through its parent domains, cutting off
        // one leading label at a time, and take the first registered match.
        //
        // With the current coding foo.com will match *.foo.com and that
        // repeats behavior of the old code.
        std::iter::successors(Some(sni_name), |name| {
            name.split_once('.').map(|(_, parent)| parent)
        })
        .find_map(|candidate| self.certs.get(candidate).cloned())
    }
}
#[derive(Debug)]
pub struct EndpointCacheConfig {
/// Batch size to receive all endpoints on the startup.

View File

@@ -89,6 +89,7 @@ pub mod jemalloc;
pub mod logging;
pub mod metrics;
pub mod parse;
pub mod postgres_rustls;
pub mod protocol2;
pub mod proxy;
pub mod rate_limiter;
@@ -98,7 +99,6 @@ pub mod scram;
pub mod serverless;
pub mod signals;
pub mod stream;
pub mod tls;
pub mod types;
pub mod url;
pub mod usage_metrics;

View File

@@ -1,10 +1,10 @@
use std::convert::TryFrom;
use std::sync::Arc;
use postgres_client::tls::MakeTlsConnect;
pub use private::RustlsStream;
use rustls::pki_types::ServerName;
use rustls::ClientConfig;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_rustls::TlsConnector;
mod private {
use std::future::Future;
@@ -18,7 +18,7 @@ mod private {
use tokio_rustls::client::TlsStream;
use tokio_rustls::TlsConnector;
use crate::tls::TlsServerEndPoint;
use crate::config::TlsServerEndPoint;
pub struct TlsConnectFuture<S> {
inner: tokio_rustls::Connect<S>,
@@ -35,14 +35,12 @@ mod private {
}
}
pub struct RustlsConnect(pub RustlsConnectData);
pub struct RustlsConnectData {
pub struct RustlsConnectData<'a> {
pub hostname: ServerName<'static>,
pub connector: TlsConnector,
pub connector: &'a TlsConnector,
}
impl<S> TlsConnect<S> for RustlsConnect
impl<S> TlsConnect<S> for RustlsConnectData<'_>
where
S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
@@ -52,7 +50,7 @@ mod private {
fn connect(self, stream: S) -> Self::Future {
TlsConnectFuture {
inner: self.0.connector.connect(self.0.hostname, stream),
inner: self.connector.connect(self.hostname, stream),
}
}
}
@@ -124,33 +122,30 @@ mod private {
/// A `MakeTlsConnect` implementation using `rustls`.
///
/// That way you can connect to PostgreSQL using `rustls` as the TLS stack.
#[derive(Clone)]
pub struct MakeRustlsConnect {
pub config: Arc<ClientConfig>,
pub struct MakeRustlsConnect<'a> {
pub connector: &'a TlsConnector,
}
impl MakeRustlsConnect {
/// Creates a new `MakeRustlsConnect` from the provided `ClientConfig`.
impl<'a> MakeRustlsConnect<'a> {
/// Creates a new `MakeRustlsConnect` from the provided `TlsConnector`.
#[must_use]
pub fn new(config: Arc<ClientConfig>) -> Self {
Self { config }
pub fn new(connector: &'a TlsConnector) -> Self {
Self { connector }
}
}
impl<S> MakeTlsConnect<S> for MakeRustlsConnect
impl<'a, S> MakeTlsConnect<S> for MakeRustlsConnect<'a>
where
S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
type Stream = private::RustlsStream<S>;
type TlsConnect = private::RustlsConnect;
type TlsConnect = private::RustlsConnectData<'a>;
type Error = rustls::pki_types::InvalidDnsNameError;
fn make_tls_connect(&mut self, hostname: &str) -> Result<Self::TlsConnect, Self::Error> {
ServerName::try_from(hostname).map(|dns_name| {
private::RustlsConnect(private::RustlsConnectData {
hostname: dns_name.to_owned(),
connector: Arc::clone(&self.config).into(),
})
fn make_tls_connect(self, hostname: &str) -> Result<Self::TlsConnect, Self::Error> {
ServerName::try_from(hostname).map(|dns_name| private::RustlsConnectData {
hostname: dns_name.to_owned(),
connector: self.connector,
})
}
}

View File

@@ -8,13 +8,12 @@ use tokio::io::{AsyncRead, AsyncWrite};
use tracing::{debug, info, warn};
use crate::auth::endpoint_sni;
use crate::config::TlsConfig;
use crate::config::{TlsConfig, PG_ALPN_PROTOCOL};
use crate::context::RequestContext;
use crate::error::ReportableError;
use crate::metrics::Metrics;
use crate::proxy::ERR_INSECURE_CONNECTION;
use crate::stream::{PqStream, Stream, StreamUpgradeError};
use crate::tls::PG_ALPN_PROTOCOL;
#[derive(Error, Debug)]
pub(crate) enum HandshakeError {

View File

@@ -13,8 +13,9 @@ use postgres_client::tls::{MakeTlsConnect, NoTls};
use retry::{retry_after, ShouldRetryWakeCompute};
use rstest::rstest;
use rustls::crypto::ring;
use rustls::pki_types;
use rustls::{pki_types, RootCertStore};
use tokio::io::DuplexStream;
use tokio_rustls::TlsConnector;
use super::connect_compute::ConnectMechanism;
use super::retry::CouldRetry;
@@ -22,16 +23,14 @@ use super::*;
use crate::auth::backend::{
ComputeCredentialKeys, ComputeCredentials, ComputeUserInfo, MaybeOwned,
};
use crate::config::{ComputeConfig, RetryConfig};
use crate::config::{CertResolver, ComputeConfig, RetryConfig};
use crate::control_plane::client::{ControlPlaneClient, TestControlPlaneClient};
use crate::control_plane::messages::{ControlPlaneErrorMessage, Details, MetricsAuxInfo, Status};
use crate::control_plane::{
self, CachedAllowedIps, CachedNodeInfo, CachedRoleSecret, NodeInfo, NodeInfoCache,
};
use crate::error::ErrorKind;
use crate::tls::client_config::compute_client_config_with_certs;
use crate::tls::postgres_rustls::MakeRustlsConnect;
use crate::tls::server_config::CertResolver;
use crate::postgres_rustls::MakeRustlsConnect;
use crate::types::{BranchId, EndpointId, ProjectId};
use crate::{sasl, scram};
@@ -69,16 +68,16 @@ fn generate_certs(
}
struct ClientConfig<'a> {
config: Arc<rustls::ClientConfig>,
config: TlsConnector,
hostname: &'a str,
}
type TlsConnect<S> = <MakeRustlsConnect as MakeTlsConnect<S>>::TlsConnect;
type TlsConnect<'a, S> = <MakeRustlsConnect<'a> as MakeTlsConnect<S>>::TlsConnect;
impl ClientConfig<'_> {
fn make_tls_connect(self) -> anyhow::Result<TlsConnect<DuplexStream>> {
let mut mk = MakeRustlsConnect::new(self.config);
let tls = MakeTlsConnect::<DuplexStream>::make_tls_connect(&mut mk, self.hostname)?;
fn make_tls_connect(&self) -> anyhow::Result<TlsConnect<DuplexStream>> {
let mk = MakeRustlsConnect::new(&self.config);
let tls = MakeTlsConnect::<DuplexStream>::make_tls_connect(mk, self.hostname)?;
Ok(tls)
}
}
@@ -112,9 +111,22 @@ fn generate_tls_config<'a>(
};
let client_config = {
let config = Arc::new(compute_client_config_with_certs([ca]));
let config =
rustls::ClientConfig::builder_with_provider(Arc::new(ring::default_provider()))
.with_safe_default_protocol_versions()
.context("ring should support the default protocol versions")?
.with_root_certificates({
let mut store = rustls::RootCertStore::empty();
store.add(ca)?;
store
})
.with_no_client_auth();
let config = Arc::new(config);
ClientConfig { config, hostname }
ClientConfig {
config: TlsConnector::from(config),
hostname,
}
};
Ok((client_config, tls_config))
@@ -576,9 +588,18 @@ fn config() -> ComputeConfig {
backoff_factor: 2.0,
};
let root_store = RootCertStore::empty();
let client_config =
rustls::ClientConfig::builder_with_provider(Arc::new(ring::default_provider()))
.with_safe_default_protocol_versions()
.expect("ring should support the default protocol versions")
.with_root_certificates(root_store)
.with_no_client_auth();
ComputeConfig {
retry,
tls: Arc::new(compute_client_config_with_certs(std::iter::empty())),
tls: Arc::new(client_config).into(),
timeout: Duration::from_secs(2),
}
}

View File

@@ -40,27 +40,6 @@ pub(crate) enum Notification {
AllowedIpsUpdate {
allowed_ips_update: AllowedIpsUpdate,
},
#[serde(
rename = "/block_public_or_vpc_access_updated",
deserialize_with = "deserialize_json_string"
)]
BlockPublicOrVpcAccessUpdated {
block_public_or_vpc_access_updated: BlockPublicOrVpcAccessUpdated,
},
#[serde(
rename = "/allowed_vpc_endpoints_updated_for_org",
deserialize_with = "deserialize_json_string"
)]
AllowedVpcEndpointsUpdatedForOrg {
allowed_vpc_endpoints_updated_for_org: AllowedVpcEndpointsUpdatedForOrg,
},
#[serde(
rename = "/allowed_vpc_endpoints_updated_for_projects",
deserialize_with = "deserialize_json_string"
)]
AllowedVpcEndpointsUpdatedForProjects {
allowed_vpc_endpoints_updated_for_projects: AllowedVpcEndpointsUpdatedForProjects,
},
#[serde(
rename = "/password_updated",
deserialize_with = "deserialize_json_string"
@@ -73,24 +52,6 @@ pub(crate) enum Notification {
pub(crate) struct AllowedIpsUpdate {
project_id: ProjectIdInt,
}
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
pub(crate) struct BlockPublicOrVpcAccessUpdated {
project_id: ProjectIdInt,
}
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
pub(crate) struct AllowedVpcEndpointsUpdatedForOrg {
// TODO: change type once the implementation is more fully fledged.
// See e.g. https://github.com/neondatabase/neon/pull/10073.
account_id: ProjectIdInt,
}
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
pub(crate) struct AllowedVpcEndpointsUpdatedForProjects {
project_ids: Vec<ProjectIdInt>,
}
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
pub(crate) struct PasswordUpdate {
project_id: ProjectIdInt,
@@ -204,11 +165,7 @@ impl<C: ProjectInfoCache + Send + Sync + 'static> MessageHandler<C> {
}
}
}
Notification::AllowedIpsUpdate { .. }
| Notification::PasswordUpdate { .. }
| Notification::BlockPublicOrVpcAccessUpdated { .. }
| Notification::AllowedVpcEndpointsUpdatedForOrg { .. }
| Notification::AllowedVpcEndpointsUpdatedForProjects { .. } => {
Notification::AllowedIpsUpdate { .. } | Notification::PasswordUpdate { .. } => {
invalidate_cache(self.cache.clone(), msg.clone());
if matches!(msg, Notification::AllowedIpsUpdate { .. }) {
Metrics::get()
@@ -221,8 +178,6 @@ impl<C: ProjectInfoCache + Send + Sync + 'static> MessageHandler<C> {
.redis_events_count
.inc(RedisEventsCount::PasswordUpdate);
}
// TODO: add additional metrics for the other event types.
// It might happen that the invalid entry is on the way to be cached.
// To make sure that the entry is invalidated, let's repeat the invalidation in INVALIDATION_LAG seconds.
// TODO: include the version (or the timestamp) in the message and invalidate only if the entry is cached before the message.
@@ -249,15 +204,6 @@ fn invalidate_cache<C: ProjectInfoCache>(cache: Arc<C>, msg: Notification) {
password_update.role_name,
),
Notification::Cancel(_) => unreachable!("cancel message should be handled separately"),
Notification::BlockPublicOrVpcAccessUpdated { .. } => {
// https://github.com/neondatabase/neon/pull/10073
}
Notification::AllowedVpcEndpointsUpdatedForOrg { .. } => {
// https://github.com/neondatabase/neon/pull/10073
}
Notification::AllowedVpcEndpointsUpdatedForProjects { .. } => {
// https://github.com/neondatabase/neon/pull/10073
}
}
}

View File

@@ -50,12 +50,6 @@ impl<S: AsyncWrite + Unpin> SaslStream<'_, S> {
self.stream.write_message(&msg.to_reply()).await?;
Ok(())
}
// Queue a SASL message for the client.
fn send_noflush(&mut self, msg: &ServerMessage<&str>) -> io::Result<()> {
self.stream.write_message_noflush(&msg.to_reply())?;
Ok(())
}
}
/// SASL authentication outcome.
@@ -91,7 +85,7 @@ impl<S: AsyncRead + AsyncWrite + Unpin> SaslStream<'_, S> {
continue;
}
Step::Success(result, reply) => {
self.send_noflush(&ServerMessage::Final(&reply))?;
self.send(&ServerMessage::Final(&reply)).await?;
Outcome::Success(result)
}
Step::Failure(reason) => Outcome::Failure(reason),

View File

@@ -13,6 +13,7 @@ use super::secret::ServerSecret;
use super::signature::SignatureBuilder;
use super::threadpool::ThreadPool;
use super::ScramKey;
use crate::config;
use crate::intern::EndpointIdInt;
use crate::sasl::{self, ChannelBinding, Error as SaslError};
@@ -58,14 +59,14 @@ enum ExchangeState {
pub(crate) struct Exchange<'a> {
state: ExchangeState,
secret: &'a ServerSecret,
tls_server_end_point: crate::tls::TlsServerEndPoint,
tls_server_end_point: config::TlsServerEndPoint,
}
impl<'a> Exchange<'a> {
pub(crate) fn new(
secret: &'a ServerSecret,
nonce: fn() -> [u8; SCRAM_RAW_NONCE_LEN],
tls_server_end_point: crate::tls::TlsServerEndPoint,
tls_server_end_point: config::TlsServerEndPoint,
) -> Self {
Self {
state: ExchangeState::Initial(SaslInitial { nonce }),
@@ -119,7 +120,7 @@ impl SaslInitial {
fn transition(
&self,
secret: &ServerSecret,
tls_server_end_point: &crate::tls::TlsServerEndPoint,
tls_server_end_point: &config::TlsServerEndPoint,
input: &str,
) -> sasl::Result<sasl::Step<SaslSentInner, Infallible>> {
let client_first_message = ClientFirstMessage::parse(input)
@@ -154,7 +155,7 @@ impl SaslSentInner {
fn transition(
&self,
secret: &ServerSecret,
tls_server_end_point: &crate::tls::TlsServerEndPoint,
tls_server_end_point: &config::TlsServerEndPoint,
input: &str,
) -> sasl::Result<sasl::Step<Infallible, super::ScramKey>> {
let Self {
@@ -167,8 +168,8 @@ impl SaslSentInner {
.ok_or(SaslError::BadClientMessage("invalid client-final-message"))?;
let channel_binding = cbind_flag.encode(|_| match tls_server_end_point {
crate::tls::TlsServerEndPoint::Sha256(x) => Ok(x),
crate::tls::TlsServerEndPoint::Undefined => Err(SaslError::MissingBinding),
config::TlsServerEndPoint::Sha256(x) => Ok(x),
config::TlsServerEndPoint::Undefined => Err(SaslError::MissingBinding),
})?;
// This might've been caused by a MITM attack

View File

@@ -77,8 +77,11 @@ mod tests {
const NONCE: [u8; 18] = [
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
];
let mut exchange =
Exchange::new(&secret, || NONCE, crate::tls::TlsServerEndPoint::Undefined);
let mut exchange = Exchange::new(
&secret,
|| NONCE,
crate::config::TlsServerEndPoint::Undefined,
);
let client_first = "n,,n=user,r=rOprNGfwEbeRWgbNEkqO";
let client_final = "c=biws,r=rOprNGfwEbeRWgbNEkqOAQIDBAUGBwgJCgsMDQ4PEBES,p=rw1r5Kph5ThxmaUBC2GAQ6MfXbPnNkFiTIvdb/Rear0=";

View File

@@ -30,6 +30,7 @@ use crate::control_plane::locks::ApiLocks;
use crate::control_plane::CachedNodeInfo;
use crate::error::{ErrorKind, ReportableError, UserFacingError};
use crate::intern::EndpointIdInt;
use crate::postgres_rustls::MakeRustlsConnect;
use crate::proxy::connect_compute::ConnectMechanism;
use crate::proxy::retry::{CouldRetry, ShouldRetryWakeCompute};
use crate::rate_limiter::EndpointRateLimiter;
@@ -514,7 +515,9 @@ impl ConnectMechanism for TokioMechanism {
.connect_timeout(compute_config.timeout);
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
let res = config.connect(postgres_client::NoTls).await;
let res = config
.connect(MakeRustlsConnect::new(&compute_config.tls))
.await;
drop(pause);
let (client, connection) = permit.release_result(res)?;
@@ -560,6 +563,10 @@ impl ConnectMechanism for HyperMechanism {
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
let port = node_info.config.get_port();
// TODO(conrad): how would we roll-out TLS for these connections?
// Postgres has negotiation, but no such thing for HTTP.
// Assume https, fall back to http (on the same port)?
let res = connect_http2(&host, port, config.timeout).await;
drop(pause);
let (client, connection) = permit.release_result(res)?;

View File

@@ -5,7 +5,6 @@ use std::task::{ready, Poll};
use futures::future::poll_fn;
use futures::Future;
use postgres_client::tls::NoTlsStream;
use postgres_client::AsyncMessage;
use smallvec::SmallVec;
use tokio::net::TcpStream;
@@ -26,6 +25,7 @@ use super::conn_pool_lib::{
use crate::context::RequestContext;
use crate::control_plane::messages::MetricsAuxInfo;
use crate::metrics::Metrics;
use crate::postgres_rustls::RustlsStream;
#[derive(Debug, Clone)]
pub(crate) struct ConnInfoWithAuth {
@@ -58,7 +58,7 @@ pub(crate) fn poll_client<C: ClientInnerExt>(
ctx: &RequestContext,
conn_info: ConnInfo,
client: C,
mut connection: postgres_client::Connection<TcpStream, NoTlsStream>,
mut connection: postgres_client::Connection<TcpStream, RustlsStream<tokio::net::TcpStream>>,
conn_id: uuid::Uuid,
aux: MetricsAuxInfo,
) -> Client<C> {

View File

@@ -11,9 +11,9 @@ use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio_rustls::server::TlsStream;
use tracing::debug;
use crate::config::TlsServerEndPoint;
use crate::error::{ErrorKind, ReportableError, UserFacingError};
use crate::metrics::Metrics;
use crate::tls::TlsServerEndPoint;
/// Stream wrapper which implements libpq's protocol.
///

View File

@@ -1,42 +0,0 @@
use std::sync::Arc;
use anyhow::bail;
use rustls::crypto::ring;
/// Collects the platform's native root certificates into a shareable
/// [`rustls::RootCertStore`].
///
/// # Errors
/// Returns an error if the native certificate loader reported any errors.
pub(crate) fn load_certs() -> anyhow::Result<Arc<rustls::RootCertStore>> {
    let loaded = rustls_native_certs::load_native_certs();

    if loaded.errors.is_empty() {
        let mut roots = rustls::RootCertStore::empty();
        roots.add_parsable_certificates(loaded.certs);
        Ok(Arc::new(roots))
    } else {
        bail!("could not parse certificates: {:?}", loaded.errors)
    }
}
/// Builds the rustls client configuration used to connect to the neon compute,
/// trusting the root certificates returned by [`load_certs`].
///
/// This function is blocking.
///
/// # Errors
/// Propagates any error from [`load_certs`].
pub fn compute_client_config_with_root_certs() -> anyhow::Result<rustls::ClientConfig> {
    let provider = Arc::new(ring::default_provider());
    let builder = rustls::ClientConfig::builder_with_provider(provider)
        .with_safe_default_protocol_versions()
        .expect("ring should support the default protocol versions");

    // Root certificates are loaded only after the builder has been set up.
    let roots = load_certs()?;
    let client_config = builder
        .with_root_certificates(roots)
        .with_no_client_auth();
    Ok(client_config)
}
/// Test-only variant of the compute client configuration that trusts exactly
/// the certificates in `certs` (those that can be parsed) as roots.
#[cfg(test)]
pub fn compute_client_config_with_certs(
    certs: impl IntoIterator<Item = rustls::pki_types::CertificateDer<'static>>,
) -> rustls::ClientConfig {
    let mut roots = rustls::RootCertStore::empty();
    roots.add_parsable_certificates(certs);

    let provider = Arc::new(ring::default_provider());
    let builder = rustls::ClientConfig::builder_with_provider(provider)
        .with_safe_default_protocol_versions()
        .expect("ring should support the default protocol versions");

    builder.with_root_certificates(roots).with_no_client_auth()
}

View File

@@ -1,72 +0,0 @@
pub mod client_config;
pub mod postgres_rustls;
pub mod server_config;
use anyhow::Context;
use rustls::pki_types::CertificateDer;
use sha2::{Digest, Sha256};
use tracing::{error, info};
use x509_parser::oid_registry;
/// ALPN protocol identifier (`postgresql`) placed into the server's
/// `alpn_protocols` list by `server_config::configure_tls`.
///
/// The link below points at the corresponding definition in the Postgres sources.
///
/// <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/include/libpq/pqcomm.h#L159>
pub const PG_ALPN_PROTOCOL: &[u8] = b"postgresql";
/// Channel binding parameter
///
/// <https://www.rfc-editor.org/rfc/rfc5929#section-4>
/// Description: The hash of the TLS server's certificate as it
/// appears, octet for octet, in the server's Certificate message. Note
/// that the Certificate message contains a certificate_list, in which
/// the first element is the server's certificate.
///
/// The hash function is to be selected as follows:
///
/// * if the certificate's signatureAlgorithm uses a single hash
/// function, and that hash function is either MD5 or SHA-1, then use SHA-256;
///
/// * if the certificate's signatureAlgorithm uses a single hash
/// function and that hash function neither MD5 nor SHA-1, then use
/// the hash function associated with the certificate's
/// signatureAlgorithm;
///
/// * if the certificate's signatureAlgorithm uses no hash functions or
/// uses multiple hash functions, then this channel binding type's
/// channel bindings are undefined at this time (updates to is channel
/// binding type may occur to address this issue if it ever arises).
///
/// Implementation note: `TlsServerEndPoint::new` only yields `Sha256` for
/// certificates signed with ECDSA-with-SHA256 or RSA-with-SHA256; every other
/// signature algorithm (including the MD5/SHA-1 case above) maps to `Undefined`.
#[derive(Debug, Clone, Copy)]
pub enum TlsServerEndPoint {
/// SHA-256 digest of the certificate bytes passed to `TlsServerEndPoint::new`.
Sha256([u8; 32]),
/// No channel binding value is available for this certificate;
/// `supported()` returns `false` for this variant.
Undefined,
}
impl TlsServerEndPoint {
    /// Derives the `tls-server-end-point` channel binding value for `cert`.
    ///
    /// Returns `Sha256(digest)` — the SHA-256 hash of the certificate bytes — when the
    /// certificate's signature algorithm is ECDSA-with-SHA256 or RSA-with-SHA256.
    /// For any other signature algorithm an error is logged and `Undefined` is returned;
    /// this is not treated as a failure.
    ///
    /// # Errors
    /// Fails only if `cert` cannot be parsed as an X.509 certificate.
    pub fn new(cert: &CertificateDer<'_>) -> anyhow::Result<Self> {
        let sha256_oids = [
            // I'm explicitly not adding MD5 or SHA1 here... They're bad.
            oid_registry::OID_SIG_ECDSA_WITH_SHA256,
            oid_registry::OID_PKCS1_SHA256WITHRSA,
        ];

        let pem = x509_parser::parse_x509_certificate(cert)
            .context("Failed to parse PEM object from certificate")?
            .1;

        info!(subject = %pem.subject, "parsing TLS certificate");

        // The registry is only used to obtain a human-readable algorithm description for logs.
        let reg = oid_registry::OidRegistry::default().with_all_crypto();
        let oid = pem.signature_algorithm.oid();
        let alg = reg.get(oid);
        if sha256_oids.contains(oid) {
            let tls_server_end_point: [u8; 32] = Sha256::new().chain_update(cert).finalize().into();
            info!(subject = %pem.subject, signature_algorithm = alg.map(|a| a.description()), tls_server_end_point = %base64::encode(tls_server_end_point), "determined channel binding");
            Ok(Self::Sha256(tls_server_end_point))
        } else {
            error!(subject = %pem.subject, signature_algorithm = alg.map(|a| a.description()), "unknown channel binding");
            Ok(Self::Undefined)
        }
    }

    /// Returns `true` unless the channel binding is `Undefined`.
    pub fn supported(&self) -> bool {
        !matches!(self, TlsServerEndPoint::Undefined)
    }
}

View File

@@ -1,218 +0,0 @@
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use anyhow::{bail, Context};
use itertools::Itertools;
use rustls::crypto::ring::{self, sign};
use rustls::pki_types::{CertificateDer, PrivateKeyDer};
use super::{TlsServerEndPoint, PG_ALPN_PROTOCOL};
/// TLS state produced by `configure_tls`.
pub struct TlsConfig {
// The rustls server configuration (shared; handed out by `to_server_config`).
pub config: Arc<rustls::ServerConfig>,
// Normalized common names of every loaded certificate, as returned by
// `CertResolver::get_common_names`.
pub common_names: HashSet<String>,
// The same resolver that `configure_tls` installs into `config`.
pub cert_resolver: Arc<CertResolver>,
}
impl TlsConfig {
    /// Returns another shared handle to the rustls server configuration.
    pub fn to_server_config(&self) -> Arc<rustls::ServerConfig> {
        Arc::clone(&self.config)
    }
}
/// Builds the TLS server configuration for the main endpoint.
///
/// * `key_path` / `cert_path` — the default key/certificate pair.
/// * `certs_dir` — optional directory; each sub-directory containing both
///   `tls.key` and `tls.crt` contributes one additional certificate.
/// * `allow_tls_keylogfile` — when set, installs a [`rustls::KeyLogFile`].
///
/// # Errors
/// Fails if a certificate cannot be loaded, if `certs_dir` cannot be listed,
/// or if the protocol versions are rejected by the crypto provider.
pub fn configure_tls(
    key_path: &str,
    cert_path: &str,
    certs_dir: Option<&String>,
    allow_tls_keylogfile: bool,
) -> anyhow::Result<TlsConfig> {
    let mut resolver = CertResolver::new();

    // The default certificate always goes in first.
    resolver.add_cert_path(key_path, cert_path, true)?;

    // Then any extra certificates found one level below `certs_dir`.
    if let Some(dir) = certs_dir {
        for dir_entry in std::fs::read_dir(dir)? {
            let subdir = dir_entry?.path();
            if !subdir.is_dir() {
                continue;
            }

            // File names follow the cert-manager defaults.
            let extra_key = subdir.join("tls.key");
            let extra_cert = subdir.join("tls.crt");
            if !(extra_key.exists() && extra_cert.exists()) {
                continue;
            }

            resolver.add_cert_path(
                &extra_key.to_string_lossy(),
                &extra_cert.to_string_lossy(),
                false,
            )?;
        }
    }

    let common_names = resolver.get_common_names();
    let cert_resolver = Arc::new(resolver);

    // TLS 1.2 stays enabled for compatibility with older client libraries.
    let provider = Arc::new(ring::default_provider());
    let mut server_config = rustls::ServerConfig::builder_with_provider(provider)
        .with_protocol_versions(&[&rustls::version::TLS13, &rustls::version::TLS12])
        .context("ring should support TLS1.2 and TLS1.3")?
        .with_no_client_auth()
        .with_cert_resolver(cert_resolver.clone());

    server_config.alpn_protocols = vec![PG_ALPN_PROTOCOL.to_vec()];

    if allow_tls_keylogfile {
        // KeyLogFile will check for the SSLKEYLOGFILE environment variable.
        server_config.key_log = Arc::new(rustls::KeyLogFile::new());
    }

    Ok(TlsConfig {
        config: Arc::new(server_config),
        common_names,
        cert_resolver,
    })
}
/// Holds the loaded server certificates, each paired with its channel binding value.
#[derive(Default, Debug)]
pub struct CertResolver {
// Certificates keyed by their normalized common name (see `add_cert` for
// the prefixes that are stripped before insertion).
certs: HashMap<String, (Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint)>,
// Certificate registered with `is_default = true`; returned by `resolve`
// when no server name is given.
default: Option<(Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint)>,
}
impl CertResolver {
    /// Creates an empty resolver with no certificates and no default.
    pub fn new() -> Self {
        Self::default()
    }

    /// Reads a PEM private key and a PEM certificate chain from disk and
    /// registers them via [`CertResolver::add_cert`].
    ///
    /// # Errors
    /// Fails if either file cannot be read or parsed, if the key file contains
    /// no private key, or if `add_cert` rejects the pair.
    fn add_cert_path(
        &mut self,
        key_path: &str,
        cert_path: &str,
        is_default: bool,
    ) -> anyhow::Result<()> {
        let priv_key = {
            let key_bytes = std::fs::read(key_path)
                .with_context(|| format!("Failed to read TLS keys at '{key_path}'"))?;
            // First `with_context` covers a parse error, the second covers a
            // file that parsed fine but contained no private key.
            rustls_pemfile::private_key(&mut &key_bytes[..])
                .with_context(|| format!("Failed to parse TLS keys at '{key_path}'"))?
                .with_context(|| format!("Failed to parse TLS keys at '{key_path}'"))?
        };

        let cert_chain_bytes = std::fs::read(cert_path)
            .with_context(|| format!("Failed to read TLS cert file at '{cert_path}'"))?;

        let cert_chain = {
            rustls_pemfile::certs(&mut &cert_chain_bytes[..])
                .try_collect()
                .with_context(|| {
                    format!("Failed to read TLS certificate chain from bytes from file at '{cert_path}'.")
                })?
        };

        self.add_cert(priv_key, cert_chain, is_default)
    }

    /// Registers a certificate chain under the normalized common name of its
    /// first certificate, optionally also making it the default certificate.
    ///
    /// A later call with the same normalized common name replaces the earlier entry.
    ///
    /// # Errors
    /// Fails if the private key is not of a supported type, if `cert_chain` is
    /// empty, if the first certificate cannot be parsed, or if its subject does
    /// not start with `CN=`.
    pub fn add_cert(
        &mut self,
        priv_key: PrivateKeyDer<'static>,
        cert_chain: Vec<CertificateDer<'static>>,
        is_default: bool,
    ) -> anyhow::Result<()> {
        let key = sign::any_supported_type(&priv_key).context("invalid private key")?;

        // An empty chain (e.g. a cert file without any CERTIFICATE entries)
        // must be reported as an error rather than panic on indexing.
        let first_cert = cert_chain
            .first()
            .context("TLS certificate chain is empty")?;
        let tls_server_end_point = TlsServerEndPoint::new(first_cert)?;
        let pem = x509_parser::parse_x509_certificate(first_cert)
            .context("Failed to parse PEM object from certificate")?
            .1;

        let common_name = pem.subject().to_string();

        // We need to get the canonical name for this certificate so we can match them against any domain names
        // seen within the proxy codebase.
        //
        // In scram-proxy we use wildcard certificates only, with the database endpoint as the wildcard subdomain, taken from SNI.
        // We need to remove the wildcard prefix for the purposes of certificate selection.
        //
        // auth-broker does not use SNI and instead uses the Neon-Connection-String header.
        // Auth broker has the subdomain `apiauth` we need to remove for the purposes of validating the Neon-Connection-String.
        //
        // Console Redirect proxy does not use any wildcard domains and does not need any certificate selection or conn string
        // validation, so we can continue with any common name.
        let common_name = if let Some(s) = common_name.strip_prefix("CN=*.") {
            s.to_string()
        } else if let Some(s) = common_name.strip_prefix("CN=apiauth.") {
            s.to_string()
        } else if let Some(s) = common_name.strip_prefix("CN=") {
            s.to_string()
        } else {
            bail!("Failed to parse common name from certificate")
        };

        let cert = Arc::new(rustls::sign::CertifiedKey::new(cert_chain, key));

        if is_default {
            self.default = Some((cert.clone(), tls_server_end_point));
        }

        self.certs.insert(common_name, (cert, tls_server_end_point));

        Ok(())
    }

    /// Returns the normalized common names of all registered certificates.
    pub fn get_common_names(&self) -> HashSet<String> {
        self.certs.keys().map(|s| s.to_string()).collect()
    }
}
impl rustls::server::ResolvesServerCert for CertResolver {
    /// Picks the certificate for a TLS handshake from the client's server name,
    /// delegating to the inherent `CertResolver::resolve` and discarding the
    /// channel binding value.
    fn resolve(
        &self,
        client_hello: rustls::server::ClientHello<'_>,
    ) -> Option<Arc<rustls::sign::CertifiedKey>> {
        let sni = client_hello.server_name();
        let (certified_key, _end_point) = self.resolve(sni)?;
        Some(certified_key)
    }
}
impl CertResolver {
    /// Looks up the certificate (and its channel binding value) for `server_name`.
    ///
    /// With a server name, the name itself is tried first and then each parent
    /// domain obtained by dropping the left-most label; `None` is returned when
    /// nothing matches. Without a server name, the default certificate is returned.
    pub fn resolve(
        &self,
        server_name: Option<&str>,
    ) -> Option<(Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint)> {
        let Some(mut candidate) = server_name else {
            // No SNI: fall back to the default certificate, since otherwise the
            // options parameter (which can also carry the endpoint name) would be
            // unreachable. As a consequence the non-SNI flow does not work for
            // CNAME domains in verify-full mode.
            //
            // Should that become a problem, the options are:
            //
            // a) Replace the multi-cert approach with a single certificate that
            //    lists the extra domains in Subject Alternative Name (SAN).
            // b) Run separate proxy instances for the extra domains.
            return self.default.clone();
        };

        // Strip one leading label per iteration until a registered name is hit;
        // this is what provides wildcard matching. Nested domains are not in use
        // at the moment, so the lookup is kept deliberately simple.
        //
        // Note that foo.com itself matches a `*.foo.com` certificate, which
        // mirrors the behavior of the old code.
        loop {
            if let Some(entry) = self.certs.get(candidate) {
                return Some(entry.clone());
            }
            // No more labels to strip means no certificate matches.
            candidate = candidate.split_once('.')?.1;
        }
    }
}

View File

@@ -9,7 +9,6 @@ default = []
# Enables test-only APIs, including failpoints. In particular, enables the `fail_point!` macro,
# which adds some runtime cost to run tests on outage conditions
testing = ["fail/failpoints"]
benchmarking = []
[dependencies]
async-stream.workspace = true
@@ -78,4 +77,3 @@ tracing-subscriber = { workspace = true, features = ["json"] }
[[bench]]
name = "receive_wal"
harness = false
required-features = ["benchmarking"]

View File

@@ -1,18 +1,18 @@
use std::sync::Arc;
use crate::rate_limit::RateLimiter;
use crate::safekeeper::{ProposerAcceptorMessage, ProposerElected, SafeKeeper, TermHistory};
use crate::state::{TimelinePersistentState, TimelineState};
use crate::timeline::{get_timeline_dir, SharedState, StateSK, Timeline};
use crate::timelines_set::TimelinesSet;
use crate::wal_backup::remote_timeline_path;
use crate::{control_file, wal_storage, SafeKeeperConf};
use camino_tempfile::Utf8TempDir;
use safekeeper::rate_limit::RateLimiter;
use safekeeper::safekeeper::{ProposerAcceptorMessage, ProposerElected, SafeKeeper, TermHistory};
use safekeeper::state::{TimelinePersistentState, TimelineState};
use safekeeper::timeline::{get_timeline_dir, SharedState, StateSK, Timeline};
use safekeeper::timelines_set::TimelinesSet;
use safekeeper::wal_backup::remote_timeline_path;
use safekeeper::{control_file, wal_storage, SafeKeeperConf};
use tokio::fs::create_dir_all;
use utils::id::{NodeId, TenantTimelineId};
use utils::lsn::Lsn;
/// A Safekeeper testing or benchmarking environment. Uses a tempdir for storage, removed on drop.
/// A Safekeeper benchmarking environment. Uses a tempdir for storage, removed on drop.
pub struct Env {
/// Whether to enable fsync.
pub fsync: bool,
@@ -21,7 +21,7 @@ pub struct Env {
}
impl Env {
/// Creates a new test or benchmarking environment in a temporary directory. fsync controls whether to
/// Creates a new benchmarking environment in a temporary directory. fsync controls whether to
/// enable fsyncing.
pub fn new(fsync: bool) -> anyhow::Result<Self> {
let tempdir = camino_tempfile::tempdir()?;
@@ -47,7 +47,6 @@ impl Env {
&self,
node_id: NodeId,
ttid: TenantTimelineId,
start_lsn: Lsn,
) -> anyhow::Result<SafeKeeper<control_file::FileStorage, wal_storage::PhysicalStorage>> {
let conf = self.make_conf(node_id);
@@ -68,9 +67,9 @@ impl Env {
safekeeper
.process_msg(&ProposerAcceptorMessage::Elected(ProposerElected {
term: 1,
start_streaming_at: start_lsn,
term_history: TermHistory(vec![(1, start_lsn).into()]),
timeline_start_lsn: start_lsn,
start_streaming_at: Lsn(0),
term_history: TermHistory(vec![(1, Lsn(0)).into()]),
timeline_start_lsn: Lsn(0),
}))
.await?;
@@ -83,13 +82,12 @@ impl Env {
&self,
node_id: NodeId,
ttid: TenantTimelineId,
start_lsn: Lsn,
) -> anyhow::Result<Arc<Timeline>> {
let conf = Arc::new(self.make_conf(node_id));
let timeline_dir = get_timeline_dir(&conf, &ttid);
let remote_path = remote_timeline_path(&ttid)?;
let safekeeper = self.make_safekeeper(node_id, ttid, start_lsn).await?;
let safekeeper = self.make_safekeeper(node_id, ttid).await?;
let shared_state = SharedState::new(StateSK::Loaded(safekeeper));
let timeline = Timeline::new(

View File

@@ -1,7 +1,11 @@
//! WAL ingestion benchmarks.
#[path = "benchutils.rs"]
mod benchutils;
use std::io::Write as _;
use benchutils::Env;
use bytes::BytesMut;
use camino_tempfile::tempfile;
use criterion::{criterion_group, criterion_main, BatchSize, Bencher, Criterion};
@@ -12,7 +16,6 @@ use safekeeper::receive_wal::{self, WalAcceptor};
use safekeeper::safekeeper::{
AcceptorProposerMessage, AppendRequest, AppendRequestHeader, ProposerAcceptorMessage,
};
use safekeeper::test_utils::Env;
use tokio::io::AsyncWriteExt as _;
use utils::id::{NodeId, TenantTimelineId};
use utils::lsn::Lsn;
@@ -73,15 +76,12 @@ fn bench_process_msg(c: &mut Criterion) {
assert!(size >= prefixlen);
let message = vec![0; size - prefixlen];
let walgen = &mut WalGenerator::new(LogicalMessageGenerator::new(prefix, &message), Lsn(0));
let walgen = &mut WalGenerator::new(LogicalMessageGenerator::new(prefix, &message));
// Set up the Safekeeper.
let env = Env::new(fsync)?;
let mut safekeeper = runtime.block_on(env.make_safekeeper(
NodeId(1),
TenantTimelineId::generate(),
Lsn(0),
))?;
let mut safekeeper =
runtime.block_on(env.make_safekeeper(NodeId(1), TenantTimelineId::generate()))?;
b.iter_batched_ref(
// Pre-construct WAL records and requests. Criterion will batch them.
@@ -134,8 +134,7 @@ fn bench_wal_acceptor(c: &mut Criterion) {
let runtime = tokio::runtime::Runtime::new()?; // needs multithreaded
let env = Env::new(fsync)?;
let walgen =
&mut WalGenerator::new(LogicalMessageGenerator::new(c"prefix", b"message"), Lsn(0));
let walgen = &mut WalGenerator::new(LogicalMessageGenerator::new(c"prefix", b"message"));
// Create buffered channels that can fit all requests, to avoid blocking on channels.
let (msg_tx, msg_rx) = tokio::sync::mpsc::channel(n);
@@ -146,7 +145,7 @@ fn bench_wal_acceptor(c: &mut Criterion) {
// TODO: WalAcceptor doesn't actually need a full timeline, only
// Safekeeper::process_msg(). Consider decoupling them to simplify the setup.
let tli = env
.make_timeline(NodeId(1), TenantTimelineId::generate(), Lsn(0))
.make_timeline(NodeId(1), TenantTimelineId::generate())
.await?
.wal_residence_guard()
.await?;
@@ -240,7 +239,7 @@ fn bench_wal_acceptor_throughput(c: &mut Criterion) {
assert!(size >= prefixlen);
let message = vec![0; size - prefixlen];
let walgen = &mut WalGenerator::new(LogicalMessageGenerator::new(prefix, &message), Lsn(0));
let walgen = &mut WalGenerator::new(LogicalMessageGenerator::new(prefix, &message));
// Construct and spawn the WalAcceptor task.
let env = Env::new(fsync)?;
@@ -250,7 +249,7 @@ fn bench_wal_acceptor_throughput(c: &mut Criterion) {
runtime.block_on(async {
let tli = env
.make_timeline(NodeId(1), TenantTimelineId::generate(), Lsn(0))
.make_timeline(NodeId(1), TenantTimelineId::generate())
.await?
.wal_residence_guard()
.await?;

View File

@@ -51,12 +51,10 @@ use utils::{
#[global_allocator]
static GLOBAL: tikv_jemallocator::Jemalloc = tikv_jemallocator::Jemalloc;
// Configure jemalloc to sample allocations for profiles every 1 MB (1 << 20).
// TODO: disabled because concurrent CPU profiles cause seg faults. See:
// https://github.com/neondatabase/neon/issues/10225.
//#[allow(non_upper_case_globals)]
//#[export_name = "malloc_conf"]
//pub static malloc_conf: &[u8] = b"prof:true,prof_active:true,lg_prof_sample:20\0";
/// Configure jemalloc to sample allocations for profiles every 1 MB (1 << 20).
#[allow(non_upper_case_globals)]
#[export_name = "malloc_conf"]
pub static malloc_conf: &[u8] = b"prof:true,prof_active:true,lg_prof_sample:20\0";
const PID_FILE_NAME: &str = "safekeeper.pid";
const ID_FILE_NAME: &str = "safekeeper.id";

View File

@@ -564,7 +564,7 @@ pub fn make_router(
if conf.http_auth.is_some() {
router = router.middleware(auth_middleware(|request| {
const ALLOWLIST_ROUTES: &[&str] =
&["/v1/status", "/metrics", "/profile/cpu", "/profile/heap"];
&["/v1/status", "/metrics", "/profile/cpu", "profile/heap"];
if ALLOWLIST_ROUTES.contains(&request.uri().path()) {
None
} else {

View File

@@ -43,9 +43,6 @@ pub mod wal_reader_stream;
pub mod wal_service;
pub mod wal_storage;
#[cfg(any(test, feature = "benchmarking"))]
pub mod test_utils;
mod timelines_global_map;
use std::sync::Arc;
pub use timelines_global_map::GlobalTimelines;

View File

@@ -94,14 +94,9 @@ impl<IO: AsyncRead + AsyncWrite + Unpin> InterpretedWalSender<'_, IO> {
}
}
let max_next_record_lsn = match max_next_record_lsn {
Some(lsn) => lsn,
None => { continue; }
};
let batch = InterpretedWalRecords {
records,
next_record_lsn: Some(max_next_record_lsn),
next_record_lsn: max_next_record_lsn
};
tx.send(Batch {wal_end_lsn, available_wal_end_lsn, records: batch}).await.unwrap();

View File

@@ -18,7 +18,7 @@ impl DiskWalProposer {
internal_available_lsn: Lsn(0),
prev_lsn: Lsn(0),
disk: BlockStorage::new(),
wal_generator: WalGenerator::new(LogicalMessageGenerator::new(c"", &[]), Lsn(0)),
wal_generator: WalGenerator::new(LogicalMessageGenerator::new(c"", &[])),
}),
})
}

View File

@@ -1,4 +1,3 @@
use std::borrow::Cow;
use std::error::Error as _;
use std::sync::Arc;
use std::{collections::HashMap, time::Duration};
@@ -7,7 +6,6 @@ use control_plane::endpoint::{ComputeControlPlane, EndpointStatus};
use control_plane::local_env::LocalEnv;
use futures::StreamExt;
use hyper::StatusCode;
use pageserver_api::controller_api::AvailabilityZone;
use pageserver_api::shard::{ShardCount, ShardNumber, ShardStripeSize, TenantShardId};
use postgres_connection::parse_host_port;
use serde::{Deserialize, Serialize};
@@ -30,9 +28,6 @@ struct UnshardedComputeHookTenant {
// Which node is this tenant attached to
node_id: NodeId,
// The tenant's preferred AZ, so that we may pass this on to the control plane
preferred_az: Option<AvailabilityZone>,
// Must hold this lock to send a notification.
send_lock: Arc<tokio::sync::Mutex<Option<ComputeRemoteState>>>,
}
@@ -41,9 +36,6 @@ struct ShardedComputeHookTenant {
shard_count: ShardCount,
shards: Vec<(ShardNumber, NodeId)>,
// The tenant's preferred AZ, so that we may pass this on to the control plane
preferred_az: Option<AvailabilityZone>,
// Must hold this lock to send a notification. The contents represent
// the last successfully sent notification, and are used to coalesce multiple
// updates by only sending when there is a change since our last successful send.
@@ -72,24 +64,17 @@ enum ComputeHookTenant {
impl ComputeHookTenant {
/// Construct with at least one shard's information
fn new(
tenant_shard_id: TenantShardId,
stripe_size: ShardStripeSize,
preferred_az: Option<AvailabilityZone>,
node_id: NodeId,
) -> Self {
fn new(tenant_shard_id: TenantShardId, stripe_size: ShardStripeSize, node_id: NodeId) -> Self {
if tenant_shard_id.shard_count.count() > 1 {
Self::Sharded(ShardedComputeHookTenant {
shards: vec![(tenant_shard_id.shard_number, node_id)],
stripe_size,
shard_count: tenant_shard_id.shard_count,
preferred_az,
send_lock: Arc::default(),
})
} else {
Self::Unsharded(UnshardedComputeHookTenant {
node_id,
preferred_az,
send_lock: Arc::default(),
})
}
@@ -135,20 +120,15 @@ impl ComputeHookTenant {
/// Set one shard's location. If stripe size or shard count have changed, Self is reset
/// and drops existing content.
fn update(&mut self, shard_update: ShardUpdate) {
let tenant_shard_id = shard_update.tenant_shard_id;
let node_id = shard_update.node_id;
let stripe_size = shard_update.stripe_size;
let preferred_az = shard_update.preferred_az;
fn update(
&mut self,
tenant_shard_id: TenantShardId,
stripe_size: ShardStripeSize,
node_id: NodeId,
) {
match self {
Self::Unsharded(unsharded_tenant) if tenant_shard_id.shard_count.count() == 1 => {
unsharded_tenant.node_id = node_id;
if unsharded_tenant.preferred_az.as_ref()
!= preferred_az.as_ref().map(|az| az.as_ref())
{
unsharded_tenant.preferred_az = preferred_az.map(|az| az.as_ref().clone());
}
unsharded_tenant.node_id = node_id
}
Self::Sharded(sharded_tenant)
if sharded_tenant.stripe_size == stripe_size
@@ -166,21 +146,10 @@ impl ComputeHookTenant {
.push((tenant_shard_id.shard_number, node_id));
sharded_tenant.shards.sort_by_key(|s| s.0)
}
if sharded_tenant.preferred_az.as_ref()
!= preferred_az.as_ref().map(|az| az.as_ref())
{
sharded_tenant.preferred_az = preferred_az.map(|az| az.as_ref().clone());
}
}
_ => {
// Shard count changed: reset struct.
*self = Self::new(
tenant_shard_id,
stripe_size,
preferred_az.map(|az| az.into_owned()),
node_id,
);
*self = Self::new(tenant_shard_id, stripe_size, node_id);
}
}
}
@@ -196,7 +165,6 @@ struct ComputeHookNotifyRequestShard {
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
struct ComputeHookNotifyRequest {
tenant_id: TenantId,
preferred_az: Option<String>,
stripe_size: Option<ShardStripeSize>,
shards: Vec<ComputeHookNotifyRequestShard>,
}
@@ -270,10 +238,6 @@ impl ComputeHookTenant {
node_id: unsharded_tenant.node_id,
}],
stripe_size: None,
preferred_az: unsharded_tenant
.preferred_az
.as_ref()
.map(|az| az.0.clone()),
}),
Self::Sharded(sharded_tenant)
if sharded_tenant.shards.len() == sharded_tenant.shard_count.count() as usize =>
@@ -289,7 +253,6 @@ impl ComputeHookTenant {
})
.collect(),
stripe_size: Some(sharded_tenant.stripe_size),
preferred_az: sharded_tenant.preferred_az.as_ref().map(|az| az.0.clone()),
})
}
Self::Sharded(sharded_tenant) => {
@@ -350,17 +313,6 @@ pub(super) struct ComputeHook {
client: reqwest::Client,
}
/// Callers may give us a list of these when asking us to send a bulk batch
/// of notifications in the background. This is a 'notification' in the sense of
/// other code notifying us of a shard's status, rather than being the final notification
/// that we send upwards to the control plane for the whole tenant.
pub(crate) struct ShardUpdate<'a> {
pub(crate) tenant_shard_id: TenantShardId,
pub(crate) node_id: NodeId,
pub(crate) stripe_size: ShardStripeSize,
pub(crate) preferred_az: Option<Cow<'a, AvailabilityZone>>,
}
impl ComputeHook {
pub(super) fn new(config: Config) -> Self {
let authorization_header = config
@@ -411,7 +363,6 @@ impl ComputeHook {
tenant_id,
shards,
stripe_size,
preferred_az: _preferred_az,
} = reconfigure_request;
let compute_pageservers = shards
@@ -552,30 +503,24 @@ impl ComputeHook {
}
/// Synchronous phase: update the per-tenant state for the next intended notification
fn notify_prepare(&self, shard_update: ShardUpdate) -> MaybeSendResult {
fn notify_prepare(
&self,
tenant_shard_id: TenantShardId,
node_id: NodeId,
stripe_size: ShardStripeSize,
) -> MaybeSendResult {
let mut state_locked = self.state.lock().unwrap();
use std::collections::hash_map::Entry;
let tenant_shard_id = shard_update.tenant_shard_id;
let tenant = match state_locked.entry(tenant_shard_id.tenant_id) {
Entry::Vacant(e) => {
let ShardUpdate {
tenant_shard_id,
node_id,
stripe_size,
preferred_az,
} = shard_update;
e.insert(ComputeHookTenant::new(
tenant_shard_id,
stripe_size,
preferred_az.map(|az| az.into_owned()),
node_id,
))
}
Entry::Vacant(e) => e.insert(ComputeHookTenant::new(
tenant_shard_id,
stripe_size,
node_id,
)),
Entry::Occupied(e) => {
let tenant = e.into_mut();
tenant.update(shard_update);
tenant.update(tenant_shard_id, stripe_size, node_id);
tenant
}
};
@@ -663,14 +608,13 @@ impl ComputeHook {
/// if something failed.
pub(super) fn notify_background(
self: &Arc<Self>,
notifications: Vec<ShardUpdate>,
notifications: Vec<(TenantShardId, NodeId, ShardStripeSize)>,
result_tx: tokio::sync::mpsc::Sender<Result<(), (TenantShardId, NotifyError)>>,
cancel: &CancellationToken,
) {
let mut maybe_sends = Vec::new();
for shard_update in notifications {
let tenant_shard_id = shard_update.tenant_shard_id;
let maybe_send_result = self.notify_prepare(shard_update);
for (tenant_shard_id, node_id, stripe_size) in notifications {
let maybe_send_result = self.notify_prepare(tenant_shard_id, node_id, stripe_size);
maybe_sends.push((tenant_shard_id, maybe_send_result))
}
@@ -734,14 +678,15 @@ impl ComputeHook {
/// periods, but we don't retry forever. The **caller** is responsible for handling failures and
/// ensuring that they eventually call again to ensure that the compute is eventually notified of
/// the proper pageserver nodes for a tenant.
#[tracing::instrument(skip_all, fields(tenant_id=%shard_update.tenant_shard_id.tenant_id, shard_id=%shard_update.tenant_shard_id.shard_slug(), node_id))]
pub(super) async fn notify<'a>(
#[tracing::instrument(skip_all, fields(tenant_id=%tenant_shard_id.tenant_id, shard_id=%tenant_shard_id.shard_slug(), node_id))]
pub(super) async fn notify(
&self,
shard_update: ShardUpdate<'a>,
tenant_shard_id: TenantShardId,
node_id: NodeId,
stripe_size: ShardStripeSize,
cancel: &CancellationToken,
) -> Result<(), NotifyError> {
let tenant_shard_id = shard_update.tenant_shard_id;
let maybe_send_result = self.notify_prepare(shard_update);
let maybe_send_result = self.notify_prepare(tenant_shard_id, node_id, stripe_size);
self.notify_execute(maybe_send_result, tenant_shard_id, cancel)
.await
}
@@ -794,7 +739,6 @@ pub(crate) mod tests {
shard_number: ShardNumber(0),
},
ShardStripeSize(12345),
None,
NodeId(1),
);
@@ -821,32 +765,30 @@ pub(crate) mod tests {
// Writing the first shard of a multi-sharded situation (i.e. in a split)
// resets the tenant state and puts it in an non-notifying state (need to
// see all shards)
tenant_state.update(ShardUpdate {
tenant_shard_id: TenantShardId {
tenant_state.update(
TenantShardId {
tenant_id,
shard_count: ShardCount::new(2),
shard_number: ShardNumber(1),
},
stripe_size: ShardStripeSize(32768),
preferred_az: None,
node_id: NodeId(1),
});
ShardStripeSize(32768),
NodeId(1),
);
assert!(matches!(
tenant_state.maybe_send(tenant_id, None),
MaybeSendResult::Noop
));
// Writing the second shard makes it ready to notify
tenant_state.update(ShardUpdate {
tenant_shard_id: TenantShardId {
tenant_state.update(
TenantShardId {
tenant_id,
shard_count: ShardCount::new(2),
shard_number: ShardNumber(0),
},
stripe_size: ShardStripeSize(32768),
preferred_az: None,
node_id: NodeId(1),
});
ShardStripeSize(32768),
NodeId(1),
);
let send_result = tenant_state.maybe_send(tenant_id, None);
let MaybeSendResult::Transmit((request, mut guard)) = send_result else {

View File

@@ -11,7 +11,6 @@ use diesel::Connection;
use itertools::Itertools;
use pageserver_api::controller_api::AvailabilityZone;
use pageserver_api::controller_api::MetadataHealthRecord;
use pageserver_api::controller_api::SafekeeperDescribeResponse;
use pageserver_api::controller_api::ShardSchedulingPolicy;
use pageserver_api::controller_api::{NodeSchedulingPolicy, PlacementPolicy};
use pageserver_api::models::TenantConfig;
@@ -1242,18 +1241,6 @@ impl SafekeeperPersistence {
availability_zone_id: &self.availability_zone_id,
}
}
pub(crate) fn as_describe_response(&self) -> SafekeeperDescribeResponse {
// omit the `active` flag on purpose: it is deprecated.
SafekeeperDescribeResponse {
id: NodeId(self.id as u64),
region_id: self.region_id.clone(),
version: self.version,
host: self.host.clone(),
port: self.port,
http_port: self.http_port,
availability_zone_id: self.availability_zone_id.clone(),
}
}
}
#[derive(Insertable, AsChangeset)]

View File

@@ -1,14 +1,13 @@
use crate::pageserver_client::PageserverClient;
use crate::persistence::Persistence;
use crate::{compute_hook, service};
use pageserver_api::controller_api::{AvailabilityZone, PlacementPolicy};
use crate::service;
use pageserver_api::controller_api::PlacementPolicy;
use pageserver_api::models::{
LocationConfig, LocationConfigMode, LocationConfigSecondary, TenantConfig,
};
use pageserver_api::shard::{ShardIdentity, TenantShardId};
use pageserver_client::mgmt_api;
use reqwest::StatusCode;
use std::borrow::Cow;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
@@ -46,7 +45,6 @@ pub(super) struct Reconciler {
pub(crate) reconciler_config: ReconcilerConfig,
pub(crate) config: TenantConfig,
pub(crate) preferred_az: Option<AvailabilityZone>,
/// Observed state from the point of view of the reconciler.
/// This gets updated as the reconciliation makes progress.
@@ -836,12 +834,9 @@ impl Reconciler {
let result = self
.compute_hook
.notify(
compute_hook::ShardUpdate {
tenant_shard_id: self.tenant_shard_id,
node_id: node.get_id(),
stripe_size: self.shard.stripe_size,
preferred_az: self.preferred_az.as_ref().map(Cow::Borrowed),
},
self.tenant_shard_id,
node.get_id(),
self.shard.stripe_size,
&self.cancel,
)
.await;

View File

@@ -18,7 +18,7 @@ use crate::{
background_node_operations::{
Drain, Fill, Operation, OperationError, OperationHandler, MAX_RECONCILES_PER_OPERATION,
},
compute_hook::{self, NotifyError},
compute_hook::NotifyError,
drain_utils::{self, TenantShardDrain, TenantShardIterator},
id_lock_map::{trace_exclusive_lock, trace_shared_lock, IdLockMap, TracingExclusiveGuard},
leadership::Leadership,
@@ -46,11 +46,10 @@ use pageserver_api::{
controller_api::{
AvailabilityZone, MetadataHealthRecord, MetadataHealthUpdateRequest, NodeAvailability,
NodeRegisterRequest, NodeSchedulingPolicy, NodeShard, NodeShardResponse, PlacementPolicy,
SafekeeperDescribeResponse, ShardSchedulingPolicy, ShardsPreferredAzsRequest,
ShardsPreferredAzsResponse, TenantCreateRequest, TenantCreateResponse,
TenantCreateResponseShard, TenantDescribeResponse, TenantDescribeResponseShard,
TenantLocateResponse, TenantPolicyRequest, TenantShardMigrateRequest,
TenantShardMigrateResponse,
ShardSchedulingPolicy, ShardsPreferredAzsRequest, ShardsPreferredAzsResponse,
TenantCreateRequest, TenantCreateResponse, TenantCreateResponseShard,
TenantDescribeResponse, TenantDescribeResponseShard, TenantLocateResponse,
TenantPolicyRequest, TenantShardMigrateRequest, TenantShardMigrateResponse,
},
models::{
SecondaryProgress, TenantConfigPatchRequest, TenantConfigRequest,
@@ -657,14 +656,11 @@ impl Service {
// emit a compute notification for this. In the case where our observed state does not
// yet match our intent, we will eventually reconcile, and that will emit a compute notification.
if let Some(attached_at) = tenant_shard.stably_attached() {
compute_notifications.push(compute_hook::ShardUpdate {
tenant_shard_id: *tenant_shard_id,
node_id: attached_at,
stripe_size: tenant_shard.shard.stripe_size,
preferred_az: tenant_shard
.preferred_az()
.map(|az| Cow::Owned(az.clone())),
});
compute_notifications.push((
*tenant_shard_id,
attached_at,
tenant_shard.shard.stripe_size,
));
}
}
}
@@ -4790,15 +4786,7 @@ impl Service {
for (child_id, child_ps, stripe_size) in child_locations {
if let Err(e) = self
.compute_hook
.notify(
compute_hook::ShardUpdate {
tenant_shard_id: child_id,
node_id: child_ps,
stripe_size,
preferred_az: preferred_az_id.as_ref().map(Cow::Borrowed),
},
&self.cancel,
)
.notify(child_id, child_ps, stripe_size, &self.cancel)
.await
{
tracing::warn!("Failed to update compute of {}->{} during split, proceeding anyway to complete split ({e})",
@@ -7170,24 +7158,15 @@ impl Service {
pub(crate) async fn safekeepers_list(
&self,
) -> Result<Vec<SafekeeperDescribeResponse>, DatabaseError> {
Ok(self
.persistence
.list_safekeepers()
.await?
.into_iter()
.map(|v| v.as_describe_response())
.collect::<Vec<_>>())
) -> Result<Vec<crate::persistence::SafekeeperPersistence>, DatabaseError> {
self.persistence.list_safekeepers().await
}
pub(crate) async fn get_safekeeper(
&self,
id: i64,
) -> Result<SafekeeperDescribeResponse, DatabaseError> {
self.persistence
.safekeeper_get(id)
.await
.map(|v| v.as_describe_response())
) -> Result<crate::persistence::SafekeeperPersistence, DatabaseError> {
self.persistence.safekeeper_get(id).await
}
pub(crate) async fn upsert_safekeeper(

View File

@@ -1198,7 +1198,6 @@ impl TenantShard {
detach,
reconciler_config,
config: self.config.clone(),
preferred_az: self.preferred_az_id.clone(),
observed: self.observed.clone(),
original_observed: self.observed.clone(),
compute_hook: compute_hook.clone(),

View File

@@ -310,7 +310,7 @@ pub(crate) enum BlobDataParseResult {
index_part_generation: Generation,
s3_layers: HashSet<(LayerName, Generation)>,
},
/// The remains of an uncleanly deleted Timeline or aborted timeline creation(e.g. an initdb archive only, or some layer without an index)
/// The remains of a deleted Timeline (i.e. an initdb archive only)
Relic,
Incorrect {
errors: Vec<String>,
@@ -346,7 +346,7 @@ pub(crate) async fn list_timeline_blobs(
match res {
ListTimelineBlobsResult::Ready(data) => Ok(data),
ListTimelineBlobsResult::MissingIndexPart(_) => {
// Retry if listing raced with removal of an index
// Retry if index is missing.
let data = list_timeline_blobs_impl(remote_client, id, root_target)
.await?
.into_data();
@@ -358,7 +358,7 @@ pub(crate) async fn list_timeline_blobs(
enum ListTimelineBlobsResult {
/// Blob data is ready to be interpreted.
Ready(RemoteTimelineBlobData),
/// The listing contained an index but when we tried to fetch it, we couldn't
/// The timeline blob listing has layer files but is missing [`IndexPart`].
MissingIndexPart(RemoteTimelineBlobData),
}
@@ -467,19 +467,19 @@ async fn list_timeline_blobs_impl(
match index_part_object.as_ref() {
Some(selected) => index_part_keys.retain(|k| k != selected),
None => {
// This case does not indicate corruption, but it should be very unusual. It can
// happen if:
// - timeline creation is in progress (first layer is written before index is written)
// - timeline deletion happened while a stale pageserver was still attached, it might upload
// a layer after the deletion is done.
tracing::info!(
// It is possible that the branch gets deleted after we got some layer files listed
// and we no longer have the index file in the listing.
errors.push(
"S3 list response got no index_part.json file but still has layer files"
.to_string(),
);
return Ok(ListTimelineBlobsResult::Ready(RemoteTimelineBlobData {
blob_data: BlobDataParseResult::Relic,
unused_index_keys: index_part_keys,
unknown_keys,
}));
return Ok(ListTimelineBlobsResult::MissingIndexPart(
RemoteTimelineBlobData {
blob_data: BlobDataParseResult::Incorrect { errors, s3_layers },
unused_index_keys: index_part_keys,
unknown_keys,
},
));
}
}

View File

@@ -8,7 +8,6 @@ pytest_plugins = (
"fixtures.compute_reconfigure",
"fixtures.storage_controller_proxy",
"fixtures.paths",
"fixtures.compute_migrations",
"fixtures.neon_fixtures",
"fixtures.benchmark_fixture",
"fixtures.pg_stats",

View File

@@ -1,34 +0,0 @@
from __future__ import annotations
import os
from typing import TYPE_CHECKING
import pytest
from fixtures.paths import BASE_DIR
if TYPE_CHECKING:
from collections.abc import Iterator
from pathlib import Path
COMPUTE_MIGRATIONS_DIR = BASE_DIR / "compute_tools" / "src" / "migrations"
COMPUTE_MIGRATIONS_TEST_DIR = COMPUTE_MIGRATIONS_DIR / "tests"
COMPUTE_MIGRATIONS = sorted(next(os.walk(COMPUTE_MIGRATIONS_DIR))[2])
NUM_COMPUTE_MIGRATIONS = len(COMPUTE_MIGRATIONS)
@pytest.fixture(scope="session")
def compute_migrations_dir() -> Iterator[Path]:
    """
    Session-scoped fixture that yields COMPUTE_MIGRATIONS_DIR, the path of
    the compute migrations directory.
    """
    yield COMPUTE_MIGRATIONS_DIR
@pytest.fixture(scope="session")
def compute_migrations_test_dir() -> Iterator[Path]:
    """
    Session-scoped fixture that yields COMPUTE_MIGRATIONS_TEST_DIR, the path
    of the compute migrations test directory.
    """
    yield COMPUTE_MIGRATIONS_TEST_DIR

View File

@@ -55,17 +55,3 @@ class EndpointHttpClient(requests.Session):
res = self.get(f"http://localhost:{self.port}/metrics")
res.raise_for_status()
return res.text
def configure_failpoints(self, *args: tuple[str, str]) -> None:
    """
    POST the given failpoints to this endpoint's `/failpoints` route.

    Each positional argument is a `(name, action)` pair; the pairs are sent
    as a JSON list of `{"name": ..., "action": ...}` objects, in the order
    given. Raises via `raise_for_status()` if the response is an HTTP error.
    """
    payload: list[dict[str, str]] = [{"name": fp[0], "action": fp[1]} for fp in args]
    response = self.post(f"http://localhost:{self.port}/failpoints", json=payload)
    response.raise_for_status()

View File

@@ -170,7 +170,6 @@ PAGESERVER_PER_TENANT_METRICS: tuple[str, ...] = (
"pageserver_evictions_with_low_residence_duration_total",
"pageserver_aux_file_estimated_size",
"pageserver_valid_lsn_lease_count",
"pageserver_flush_wait_upload_seconds",
counter("pageserver_tenant_throttling_count_accounted_start"),
counter("pageserver_tenant_throttling_count_accounted_finish"),
counter("pageserver_tenant_throttling_wait_usecs_sum"),

View File

@@ -522,15 +522,14 @@ class NeonLocalCli(AbstractNeonCli):
safekeepers: list[int] | None = None,
remote_ext_config: str | None = None,
pageserver_id: int | None = None,
allow_multiple: bool = False,
allow_multiple=False,
basebackup_request_tries: int | None = None,
env: dict[str, str] | None = None,
) -> subprocess.CompletedProcess[str]:
args = [
"endpoint",
"start",
]
extra_env_vars = env or {}
extra_env_vars = {}
if basebackup_request_tries is not None:
extra_env_vars["NEON_COMPUTE_TESTING_BASEBACKUP_TRIES"] = str(basebackup_request_tries)
if remote_ext_config is not None:

View File

@@ -54,7 +54,6 @@ from fixtures.common_types import (
TimelineArchivalState,
TimelineId,
)
from fixtures.compute_migrations import NUM_COMPUTE_MIGRATIONS
from fixtures.endpoint.http import EndpointHttpClient
from fixtures.h2server import H2Server
from fixtures.log_helper import log
@@ -135,9 +134,6 @@ DEFAULT_BRANCH_NAME: str = "main"
BASE_PORT: int = 15000
# By default we create pageservers with this phony AZ
DEFAULT_AZ_ID: str = "us-east-2a"
@pytest.fixture(scope="session")
def neon_api_key() -> str:
@@ -1097,7 +1093,7 @@ class NeonEnv:
"pg_auth_type": pg_auth_type,
"http_auth_type": http_auth_type,
# Default which can be overridden with `NeonEnvBuilder.pageserver_config_override`
"availability_zone": DEFAULT_AZ_ID,
"availability_zone": "us-east-2a",
# Disable pageserver disk syncs in tests: when running tests concurrently, this avoids
# the pageserver taking a long time to start up due to syncfs flushing other tests' data
"no_sync": True,
@@ -3856,7 +3852,6 @@ class Endpoint(PgProtocol, LogUtils):
safekeepers: list[int] | None = None,
allow_multiple: bool = False,
basebackup_request_tries: int | None = None,
env: dict[str, str] | None = None,
) -> Self:
"""
Start the Postgres instance.
@@ -3877,7 +3872,6 @@ class Endpoint(PgProtocol, LogUtils):
pageserver_id=pageserver_id,
allow_multiple=allow_multiple,
basebackup_request_tries=basebackup_request_tries,
env=env,
)
self._running.release(1)
self.log_config_value("shared_buffers")
@@ -3991,17 +3985,14 @@ class Endpoint(PgProtocol, LogUtils):
log.info("Updating compute spec to: %s", json.dumps(data_dict, indent=4))
json.dump(data_dict, file, indent=4)
def wait_for_migrations(self, wait_for: int = NUM_COMPUTE_MIGRATIONS) -> None:
"""
Wait for all compute migrations to be run. Remember that migrations only
run if "pg_skip_catalog_updates" is set in the compute spec to false.
"""
# Please note: Migrations only run if pg_skip_catalog_updates is false
def wait_for_migrations(self, num_migrations: int = 11):
with self.cursor() as cur:
def check_migrations_done():
cur.execute("SELECT id FROM neon_migration.migration_id")
migration_id: int = cur.fetchall()[0][0]
assert migration_id >= wait_for
assert migration_id >= num_migrations
wait_until(check_migrations_done)

View File

@@ -738,18 +738,6 @@ class PageserverHttpClient(requests.Session, MetricsGetter):
res_json = res.json()
assert res_json is None
def timeline_compact_info(
    self,
    tenant_id: TenantId | TenantShardId,
    timeline_id: TimelineId,
) -> Any:
    """
    GET the `compact` route of the given tenant/timeline and return the
    decoded JSON body of the response.

    `self.verbose_error` is invoked on the response before it is decoded.
    """
    url = f"http://localhost:{self.port}/v1/tenant/{tenant_id}/timeline/{timeline_id}/compact"
    response = self.get(url)
    self.verbose_error(response)
    return response.json()
def timeline_compact(
self,
tenant_id: TenantId | TenantShardId,
@@ -761,6 +749,7 @@ class PageserverHttpClient(requests.Session, MetricsGetter):
enhanced_gc_bottom_most_compaction=False,
body: dict[str, Any] | None = None,
):
self.is_testing_enabled_or_skip()
query = {}
if force_repartition:
query["force_repartition"] = "true"

View File

@@ -21,8 +21,8 @@ if TYPE_CHECKING:
BASE_DIR = Path(__file__).parents[2]
DEFAULT_OUTPUT_DIR: str = "test_output"
COMPUTE_CONFIG_DIR = BASE_DIR / "compute" / "etc"
DEFAULT_OUTPUT_DIR: str = "test_output"
def get_test_dir(request: FixtureRequest, top_output_dir: Path, prefix: str | None = None) -> Path:

View File

@@ -1,199 +0,0 @@
-- create a schema that simulates Neon control plane operations table
-- however use partitioned operations tables with many (e.g. 500) child partition tables per table
-- in summary we create several of these partitioned operations tables (with 500 child partitions each) until we reach the requested number of tables
-- first we need some other tables that can be referenced by the operations table
-- Table for branches
CREATE TABLE public.branches (
id text PRIMARY KEY
);
-- Table for endpoints
CREATE TABLE public.endpoints (
id text PRIMARY KEY
);
-- Table for projects
CREATE TABLE public.projects (
id text PRIMARY KEY
);
INSERT INTO public.branches (id)
VALUES ('branch_1');
-- Insert one row into endpoints
INSERT INTO public.endpoints (id)
VALUES ('endpoint_1');
-- Insert one row into projects
INSERT INTO public.projects (id)
VALUES ('project_1');
-- now we create a procedure that can create n operations tables
-- we do that in a procedure to save roundtrip latency when scaling the test to many tables
-- prefix is the base table name, e.g. 'operations_scale_1000' if we create 1000 tables
--
-- Parameters:
--   prefix  base name; the tables are created as public.<prefix>_1 .. public.<prefix>_n
--   n       number of partitioned tables to create (the loop body never runs when n < 1)
--
-- Each table is created PARTITION BY RANGE (created_at) together with its secondary indexes.
-- No child partitions are created here; see create_operations_partitions for that.
CREATE OR REPLACE PROCEDURE create_partitioned_tables(prefix text, n INT)
LANGUAGE plpgsql AS $$
DECLARE
    table_name TEXT; -- Name of the table created in the current iteration
BEGIN
    -- The integer FOR loop declares its own loop variable, so i is not declared above
    FOR i IN 1..n LOOP
        table_name := format('%s_%s', prefix, i);

        -- Create the partitioned table.
        -- Every identifier goes through %I so that it is quoted when necessary; derived
        -- constraint names are concatenated first so that they are quoted as a whole.
        -- The referenced tables are schema-qualified because they are created in public.
        EXECUTE format(
            'CREATE TABLE public.%I (
                project_id character varying NOT NULL,
                id uuid NOT NULL,
                status integer,
                action character varying NOT NULL,
                error character varying,
                created_at timestamp with time zone NOT NULL DEFAULT now(),
                updated_at timestamp with time zone NOT NULL DEFAULT now(),
                spec jsonb,
                retry_at timestamp with time zone,
                failures_count integer DEFAULT 0,
                metadata jsonb NOT NULL DEFAULT ''{}''::jsonb,
                executor_id text NOT NULL,
                attempt_duration_ms integer,
                metrics jsonb DEFAULT ''{}''::jsonb,
                branch_id text,
                endpoint_id text,
                next_operation_id uuid,
                compute_id text,
                connection_attempt_at timestamp with time zone,
                concurrency_key text,
                queue_id text,
                CONSTRAINT %I PRIMARY KEY (id, created_at),
                CONSTRAINT %I FOREIGN KEY (branch_id) REFERENCES public.branches(id) ON DELETE CASCADE,
                CONSTRAINT %I FOREIGN KEY (endpoint_id) REFERENCES public.endpoints(id) ON DELETE CASCADE,
                CONSTRAINT %I FOREIGN KEY (next_operation_id, created_at) REFERENCES public.%I(id, created_at),
                CONSTRAINT %I FOREIGN KEY (project_id) REFERENCES public.projects(id) ON DELETE CASCADE
            ) PARTITION BY RANGE (created_at)',
            table_name,
            table_name || '_pkey',
            table_name || '_branch_id_fk',
            table_name || '_endpoint_id_fk',
            table_name || '_next_operation_id_fk',
            table_name,
            table_name || '_project_id_fk'
        );

        -- Add indexes for the partitioned table
        EXECUTE format(
            'CREATE INDEX %I ON public.%I (next_operation_id)',
            'index_' || table_name || '_on_next_operation_id', table_name
        );
        EXECUTE format(
            'CREATE INDEX %I ON public.%I (project_id)',
            'index_' || table_name || '_on_project_id', table_name
        );
        EXECUTE format(
            'CREATE INDEX %I ON public.%I (branch_id)',
            table_name || '_branch_id', table_name
        );
        EXECUTE format(
            'CREATE INDEX %I ON public.%I (branch_id, created_at)',
            table_name || '_branch_id_created_idx', table_name
        );
        EXECUTE format(
            'CREATE INDEX %I ON public.%I (created_at)',
            table_name || '_created_at_idx', table_name
        );
        EXECUTE format(
            'CREATE INDEX %I ON public.%I (created_at, project_id, id)',
            table_name || '_created_at_project_id_id_cond_idx', table_name
        );
        EXECUTE format(
            'CREATE INDEX %I ON public.%I (endpoint_id)',
            table_name || '_endpoint_id', table_name
        );
        EXECUTE format(
            'CREATE INDEX %I ON public.%I (executor_id) WHERE status <> 1',
            table_name || '_for_redo_worker_idx', table_name
        );
        EXECUTE format(
            'CREATE INDEX %I ON public.%I ((project_id::text), status)',
            table_name || '_project_id_status_index', table_name
        );
        EXECUTE format(
            'CREATE INDEX %I ON public.%I (status) WHERE status <> 1',
            table_name || '_status_not_finished', table_name
        );
        EXECUTE format(
            'CREATE INDEX %I ON public.%I (updated_at DESC)',
            table_name || '_updated_at_desc_idx', table_name
        );
        EXECUTE format(
            'CREATE INDEX %I ON public.%I (failures_count) WHERE failures_count > 0',
            table_name || '_with_failures', table_name
        );
    END LOOP;
END;
$$;
-- next we create a procedure that can add the child partitions (one per day) to each of the operations tables
--
-- Parameters:
--   table_name  name of an existing partitioned operations table in schema public
--   start_date  first day for which a partition is created (inclusive)
--   end_date    day at which partition creation stops (exclusive)
--
-- After creating the partitions, one synthetic row per day is inserted into the table.
-- The procedure issues COMMIT itself (every 100 partitions, after the remaining partitions
-- and after the insert), so it has to be invoked via CALL outside an explicit transaction block.
CREATE OR REPLACE PROCEDURE create_operations_partitions(
    table_name TEXT,
    start_date DATE,
    end_date DATE
)
LANGUAGE plpgsql AS $$
DECLARE
    partition_date DATE;
    partition_name TEXT;
    counter INT := 0; -- Counter to track the number of tables created in the current transaction
BEGIN
    partition_date := start_date;

    -- Create partitions in batches
    WHILE partition_date < end_date LOOP
        partition_name := format('%s_%s', table_name, to_char(partition_date, 'YYYY_MM_DD'));

        -- %I quotes both table names the same way the INSERT below does, and %L renders
        -- the partition bounds as properly quoted literals.
        EXECUTE format(
            'CREATE TABLE IF NOT EXISTS public.%I PARTITION OF public.%I
            FOR VALUES FROM (%L) TO (%L)',
            partition_name,
            table_name,
            partition_date,
            partition_date + INTERVAL '1 day'
        );

        counter := counter + 1;

        -- Commit and reset counter after every 100 partitions
        IF counter >= 100 THEN
            COMMIT;
            counter := 0; -- Reset the counter
        END IF;

        -- Advance to the next day
        partition_date := partition_date + INTERVAL '1 day';
    END LOOP;

    -- Final commit for remaining partitions
    IF counter > 0 THEN
        COMMIT;
    END IF;

    -- Insert one synthetic row per day in [start_date, end_date), i.e. one row per partition.
    -- The target is schema-qualified like the CREATE TABLE statements above.
    EXECUTE format(
        'INSERT INTO public.%I (
            project_id,
            branch_id,
            endpoint_id,
            id,
            status,
            action,
            created_at,
            updated_at,
            spec,
            metadata,
            executor_id,
            failures_count
        )
        SELECT
            ''project_1'', -- project_id
            ''branch_1'', -- branch_id
            ''endpoint_1'', -- endpoint_id
            ''e8bba687-0df9-4291-bfcd-7d5f6aa7c158'', -- id (same for every row; the key is (id, created_at))
            1, -- status
            ''SYNTHETIC_ACTION'', -- action
            gs::timestamp + interval ''0 ms'', -- created_at
            gs::timestamp + interval ''1 minute'', -- updated_at
            ''{"key": "value"}'', -- spec (JSONB)
            ''{"metadata_key": "metadata_value"}'', -- metadata (JSONB)
            ''executor_1'', -- executor_id
            0 -- failures_count
        FROM generate_series(%L, %L::DATE - INTERVAL ''1 day'', INTERVAL ''1 day'') AS gs',
        table_name, start_date, end_date
    );

    -- Commit the inserted rows
    COMMIT;
END;
$$;
-- we can now create partitioned tables using something like
-- CALL create_partitioned_tables('operations_scale_1000' ,10);
-- and we can create the child partitions for a table using something like
-- CALL create_operations_partitions(
-- 'operations_scale_1000_1',
-- '2000-01-01', -- Start date
-- ('2000-01-01'::DATE + INTERVAL '1 day' * 500)::DATE -- End date (start date + number of days)
-- );

View File

@@ -22,7 +22,7 @@ def gc_feedback_impl(neon_env_builder: NeonEnvBuilder, zenbenchmark: NeonBenchma
"checkpoint_distance": f"{1024 ** 2}",
"compaction_target_size": f"{1024 ** 2}",
# set PITR interval to be small, so we can do GC
"pitr_interval": "10 s",
"pitr_interval": "60 s",
# "compaction_threshold": "3",
# "image_creation_threshold": "2",
}
@@ -32,7 +32,6 @@ def gc_feedback_impl(neon_env_builder: NeonEnvBuilder, zenbenchmark: NeonBenchma
n_steps = 10
n_update_iters = 100
step_size = 10000
branch_created = 0
with endpoint.cursor() as cur:
cur.execute("SET statement_timeout='1000s'")
cur.execute(
@@ -67,7 +66,6 @@ def gc_feedback_impl(neon_env_builder: NeonEnvBuilder, zenbenchmark: NeonBenchma
if mode == "with_snapshots":
if step == n_steps / 2:
env.create_branch("child")
branch_created += 1
max_num_of_deltas_above_image = 0
max_total_num_of_deltas = 0
@@ -144,15 +142,6 @@ def gc_feedback_impl(neon_env_builder: NeonEnvBuilder, zenbenchmark: NeonBenchma
with layer_map_path.open("w") as f:
f.write(json.dumps(client.timeline_layer_map_info(tenant_id, timeline_id)))
# We should have collected all garbage
if mode == "normal":
# in theory we should get physical size ~= logical size, but given that gc interval is 10s,
# and the layer has indexes that might contribute to the fluctuation, we allow a small margin
# of 1 here, and the end ratio we are asserting is 1 (margin) + 1 (expected) = 2.
assert physical_size / logical_size < 2
elif mode == "with_snapshots":
assert physical_size / logical_size < (2 + branch_created)
@pytest.mark.timeout(10000)
def test_gc_feedback(neon_env_builder: NeonEnvBuilder, zenbenchmark: NeonBenchmarker):

Some files were not shown because too many files have changed in this diff Show More