From f64eb0cbaf8edf8b9cf83533146295fb0f1a0a45 Mon Sep 17 00:00:00 2001 From: a-masterov <72613290+a-masterov@users.noreply.github.com> Date: Thu, 5 Jun 2025 17:02:38 +0200 Subject: [PATCH 01/12] Remove the Flaky Test computed-columns from postgis v16 (#12132) ## Problem The `computed_columns` test assumes that computed columns are always faster than the request itself. However, this is not always the case on Neon, which can lead to flaky results. ## Summary of changes The `computed_columns` test is excluded from the PostGIS test for PostgreSQL v16, accompanied by related patch refactoring. --- .../ext-src/postgis-src/neon-test.sh | 9 ++--- ...de-test.patch => postgis-common-v16.patch} | 16 +++++++++ .../postgis-src/postgis-common-v17.patch | 35 +++++++++++++++++++ .../postgis-src/postgis-regular-v16.patch | 14 +------- .../postgis-src/postgis-regular-v17.patch | 12 +------ .../ext-src/postgis-src/regular-test.sh | 4 +-- 6 files changed, 58 insertions(+), 32 deletions(-) rename docker-compose/ext-src/postgis-src/{postgis-no-upgrade-test.patch => postgis-common-v16.patch} (61%) create mode 100644 docker-compose/ext-src/postgis-src/postgis-common-v17.patch diff --git a/docker-compose/ext-src/postgis-src/neon-test.sh b/docker-compose/ext-src/postgis-src/neon-test.sh index 2866649a1b..13df1ec9d1 100755 --- a/docker-compose/ext-src/postgis-src/neon-test.sh +++ b/docker-compose/ext-src/postgis-src/neon-test.sh @@ -1,9 +1,6 @@ -#!/bin/bash +#!/bin/sh set -ex cd "$(dirname "$0")" -if [[ ${PG_VERSION} = v17 ]]; then - sed -i '/computed_columns/d' regress/core/tests.mk -fi -patch -p1 =" 120),1) +- TESTS += \ +- $(top_srcdir)/regress/core/computed_columns +-endif +- + ifeq ($(shell expr "$(POSTGIS_GEOS_VERSION)" ">=" 30700),1) + # GEOS-3.7 adds: + # ST_FrechetDistance diff --git a/regress/runtest.mk b/regress/runtest.mk index c051f03..010e493 100644 --- a/regress/runtest.mk diff --git a/docker-compose/ext-src/postgis-src/postgis-common-v17.patch b/docker-compose/ext-src/postgis-src/postgis-common-v17.patch new file mode 100644 index 0000000000..0b8978281e --- /dev/null +++ b/docker-compose/ext-src/postgis-src/postgis-common-v17.patch @@ -0,0 +1,35 @@ +diff --git a/regress/core/tests.mk b/regress/core/tests.mk +index 9e05244..90987df 100644 +--- a/regress/core/tests.mk ++++ b/regress/core/tests.mk +@@ -143,8 +143,7 @@ TESTS += \ + $(top_srcdir)/regress/core/oriented_envelope \ + $(top_srcdir)/regress/core/point_coordinates \ + $(top_srcdir)/regress/core/out_geojson \ +- $(top_srcdir)/regress/core/wrapx \ +- $(top_srcdir)/regress/core/computed_columns ++ $(top_srcdir)/regress/core/wrapx + + # Slow slow tests + TESTS_SLOW = \ +diff --git a/regress/runtest.mk b/regress/runtest.mk +index 4b95b7e..449d5a2 100644 +--- a/regress/runtest.mk ++++ b/regress/runtest.mk +@@ -24,16 +24,6 @@ check-regress: + + @POSTGIS_TOP_BUILD_DIR=$(abs_top_builddir) $(PERL) $(top_srcdir)/regress/run_test.pl $(RUNTESTFLAGS) $(RUNTESTFLAGS_INTERNAL) $(TESTS) + +- @if echo "$(RUNTESTFLAGS)" | grep -vq -- --upgrade; then \ +- echo "Running upgrade test as RUNTESTFLAGS did not contain that"; \ +- POSTGIS_TOP_BUILD_DIR=$(abs_top_builddir) $(PERL) $(top_srcdir)/regress/run_test.pl \ +- --upgrade \ +- $(RUNTESTFLAGS) \ +- $(RUNTESTFLAGS_INTERNAL) \ +- $(TESTS); \ +- else \ +- echo "Skipping upgrade test as RUNTESTFLAGS already requested upgrades"; \ +- fi + + check-long: + $(PERL) $(top_srcdir)/regress/run_test.pl $(RUNTESTFLAGS) $(TESTS) $(TESTS_SLOW) diff --git a/docker-compose/ext-src/postgis-src/postgis-regular-v16.patch 
b/docker-compose/ext-src/postgis-src/postgis-regular-v16.patch index 2fd214c534..e7f01ad288 100644 --- a/docker-compose/ext-src/postgis-src/postgis-regular-v16.patch +++ b/docker-compose/ext-src/postgis-src/postgis-regular-v16.patch @@ -125,7 +125,7 @@ index 7a36b65..ad78fc7 100644 DROP SCHEMA tm CASCADE; + diff --git a/regress/core/tests.mk b/regress/core/tests.mk -index 3abd7bc..94903c3 100644 +index 64a9254..94903c3 100644 --- a/regress/core/tests.mk +++ b/regress/core/tests.mk @@ -23,7 +23,6 @@ current_dir := $(dir $(abspath $(lastword $(MAKEFILE_LIST)))) @@ -160,18 +160,6 @@ index 3abd7bc..94903c3 100644 $(top_srcdir)/regress/core/wkb \ $(top_srcdir)/regress/core/wkt \ $(top_srcdir)/regress/core/wmsservers \ -@@ -144,11 +140,6 @@ TESTS_SLOW = \ - $(top_srcdir)/regress/core/concave_hull_hard \ - $(top_srcdir)/regress/core/knn_recheck - --ifeq ($(shell expr "$(POSTGIS_PGSQL_VERSION)" ">=" 120),1) -- TESTS += \ -- $(top_srcdir)/regress/core/computed_columns --endif -- - ifeq ($(shell expr "$(POSTGIS_GEOS_VERSION)" ">=" 30700),1) - # GEOS-3.7 adds: - # ST_FrechetDistance diff --git a/regress/loader/tests.mk b/regress/loader/tests.mk index 1fc77ac..c3cb9de 100644 --- a/regress/loader/tests.mk diff --git a/docker-compose/ext-src/postgis-src/postgis-regular-v17.patch b/docker-compose/ext-src/postgis-src/postgis-regular-v17.patch index f4a9d83478..ae76e559df 100644 --- a/docker-compose/ext-src/postgis-src/postgis-regular-v17.patch +++ b/docker-compose/ext-src/postgis-src/postgis-regular-v17.patch @@ -125,7 +125,7 @@ index 7a36b65..ad78fc7 100644 DROP SCHEMA tm CASCADE; + diff --git a/regress/core/tests.mk b/regress/core/tests.mk -index 9e05244..a63a3e1 100644 +index 90987df..74fe3f1 100644 --- a/regress/core/tests.mk +++ b/regress/core/tests.mk @@ -16,14 +16,13 @@ POSTGIS_PGSQL_VERSION=170 @@ -168,16 +168,6 @@ index 9e05244..a63a3e1 100644 $(top_srcdir)/regress/core/wkb \ $(top_srcdir)/regress/core/wkt \ $(top_srcdir)/regress/core/wmsservers \ -@@ -143,8 +139,7 @@ TESTS += \ - $(top_srcdir)/regress/core/oriented_envelope \ - $(top_srcdir)/regress/core/point_coordinates \ - $(top_srcdir)/regress/core/out_geojson \ -- $(top_srcdir)/regress/core/wrapx \ -- $(top_srcdir)/regress/core/computed_columns -+ $(top_srcdir)/regress/core/wrapx - - # Slow slow tests - TESTS_SLOW = \ diff --git a/regress/loader/tests.mk b/regress/loader/tests.mk index ac4f8ad..4bad4fc 100644 --- a/regress/loader/tests.mk diff --git a/docker-compose/ext-src/postgis-src/regular-test.sh b/docker-compose/ext-src/postgis-src/regular-test.sh index 4b0b929946..1b1683b3f1 100755 --- a/docker-compose/ext-src/postgis-src/regular-test.sh +++ b/docker-compose/ext-src/postgis-src/regular-test.sh @@ -10,8 +10,8 @@ psql -d contrib_regression -c "ALTER DATABASE contrib_regression SET TimeZone='U -c "CREATE EXTENSION postgis_tiger_geocoder CASCADE" \ -c "CREATE EXTENSION postgis_raster SCHEMA public" \ -c "CREATE EXTENSION postgis_sfcgal SCHEMA public" -patch -p1 Date: Thu, 5 Jun 2025 20:53:14 +0200 Subject: [PATCH 02/12] neon_local timeline import: create timelines on safekeepers (#12138) neon_local's timeline import subcommand creates timelines manually, but doesn't create them on the safekeepers. If a test then tries to open an endpoint to read from the timeline, it will error in the new world with `--timelines-onto-safekeepers`. Therefore, if that flag is enabled, create the timelines on the safekeepers. 
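In rough terms, the import path now looks up the imported timeline's info on the pageserver, builds a single-member membership configuration pointing at the first locally configured safekeeper, and sends a timeline-create request starting at the timeline's last-record LSN. A condensed sketch of that step (mirroring the code added to `neon_local.rs` below; it assumes the surrounding `handle_timeline` context, i.e. `env`, `args`, `tenant_id`, `timeline_id` and the fetched `timeline_info` are in scope):

```rust
// Only when the storage controller is configured to put timelines onto safekeepers.
if env.storage_controller.timelines_onto_safekeepers {
    let default_sk = SafekeeperNode::from_env(env, env.safekeepers.first().unwrap());
    let mconf = safekeeper_api::membership::Configuration {
        generation: SafekeeperGeneration::new(1),
        members: safekeeper_api::membership::MemberSet {
            m: vec![SafekeeperId {
                host: default_sk
                    .conf
                    .listen_addr
                    .clone()
                    .unwrap_or_else(|| "localhost".to_string()),
                id: default_sk.conf.id,
                pg_port: default_sk.conf.pg_port,
            }],
        },
        new_members: None,
    };
    let req = safekeeper_api::models::TimelineCreateRequest {
        tenant_id,
        timeline_id,
        mconf,
        pg_version: args.pg_version * 10000, // safekeepers expect the numeric PG_VERSION_NUM form
        system_id: None,
        wal_seg_size: None,
        start_lsn: timeline_info.last_record_lsn,
        commit_lsn: None,
    };
    default_sk.create_timeline(&req).await?;
}
```

The `timeline_info` lookup is why this change also adds a `PageServerNode::timeline_info` helper, and the safekeeper call goes through the new `mgmt_api`-backed client on `SafekeeperNode`.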
Note that this import functionality is different from the fast import feature (https://github.com/neondatabase/neon/issues/10188, #11801). Part of #11670 As well as part of #11712 --- Cargo.lock | 1 + control_plane/Cargo.toml | 1 + control_plane/src/bin/neon_local.rs | 41 ++++++++++++++++++- control_plane/src/pageserver.rs | 12 ++++++ control_plane/src/safekeeper.rs | 63 ++++++++++++----------------- libs/safekeeper_api/src/models.rs | 2 +- safekeeper/client/src/mgmt_api.rs | 10 ++++- test_runner/regress/test_import.py | 3 ++ 8 files changed, 91 insertions(+), 42 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 588a63b6a3..5f71af118c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1445,6 +1445,7 @@ dependencies = [ "regex", "reqwest", "safekeeper_api", + "safekeeper_client", "scopeguard", "serde", "serde_json", diff --git a/control_plane/Cargo.toml b/control_plane/Cargo.toml index 62c039047f..bbaa3f12b9 100644 --- a/control_plane/Cargo.toml +++ b/control_plane/Cargo.toml @@ -36,6 +36,7 @@ pageserver_api.workspace = true pageserver_client.workspace = true postgres_backend.workspace = true safekeeper_api.workspace = true +safekeeper_client.workspace = true postgres_connection.workspace = true storage_broker.workspace = true http-utils.workspace = true diff --git a/control_plane/src/bin/neon_local.rs b/control_plane/src/bin/neon_local.rs index ef6985d697..76e33e4bff 100644 --- a/control_plane/src/bin/neon_local.rs +++ b/control_plane/src/bin/neon_local.rs @@ -45,7 +45,7 @@ use pageserver_api::models::{ use pageserver_api::shard::{DEFAULT_STRIPE_SIZE, ShardCount, ShardStripeSize, TenantShardId}; use postgres_backend::AuthType; use postgres_connection::parse_host_port; -use safekeeper_api::membership::SafekeeperGeneration; +use safekeeper_api::membership::{SafekeeperGeneration, SafekeeperId}; use safekeeper_api::{ DEFAULT_HTTP_LISTEN_PORT as DEFAULT_SAFEKEEPER_HTTP_PORT, DEFAULT_PG_LISTEN_PORT as DEFAULT_SAFEKEEPER_PG_PORT, @@ -1255,6 +1255,45 @@ async fn handle_timeline(cmd: &TimelineCmd, env: &mut local_env::LocalEnv) -> Re pageserver .timeline_import(tenant_id, timeline_id, base, pg_wal, args.pg_version) .await?; + if env.storage_controller.timelines_onto_safekeepers { + println!("Creating timeline on safekeeper ..."); + let timeline_info = pageserver + .timeline_info( + TenantShardId::unsharded(tenant_id), + timeline_id, + pageserver_client::mgmt_api::ForceAwaitLogicalSize::No, + ) + .await?; + let default_sk = SafekeeperNode::from_env(env, env.safekeepers.first().unwrap()); + let default_host = default_sk + .conf + .listen_addr + .clone() + .unwrap_or_else(|| "localhost".to_string()); + let mconf = safekeeper_api::membership::Configuration { + generation: SafekeeperGeneration::new(1), + members: safekeeper_api::membership::MemberSet { + m: vec![SafekeeperId { + host: default_host, + id: default_sk.conf.id, + pg_port: default_sk.conf.pg_port, + }], + }, + new_members: None, + }; + let pg_version = args.pg_version * 10000; + let req = safekeeper_api::models::TimelineCreateRequest { + tenant_id, + timeline_id, + mconf, + pg_version, + system_id: None, + wal_seg_size: None, + start_lsn: timeline_info.last_record_lsn, + commit_lsn: None, + }; + default_sk.create_timeline(&req).await?; + } env.register_branch_mapping(branch_name.to_string(), tenant_id, timeline_id)?; println!("Done"); } diff --git a/control_plane/src/pageserver.rs b/control_plane/src/pageserver.rs index 0cf7ca184d..3b7c4ec39f 100644 --- a/control_plane/src/pageserver.rs +++ b/control_plane/src/pageserver.rs @@ -635,4 +635,16 
@@ impl PageServerNode { Ok(()) } + pub async fn timeline_info( + &self, + tenant_shard_id: TenantShardId, + timeline_id: TimelineId, + force_await_logical_size: mgmt_api::ForceAwaitLogicalSize, + ) -> anyhow::Result { + let timeline_info = self + .http_client + .timeline_info(tenant_shard_id, timeline_id, force_await_logical_size) + .await?; + Ok(timeline_info) + } } diff --git a/control_plane/src/safekeeper.rs b/control_plane/src/safekeeper.rs index eec2c997e6..28d369a315 100644 --- a/control_plane/src/safekeeper.rs +++ b/control_plane/src/safekeeper.rs @@ -6,7 +6,6 @@ //! .neon/safekeepers/ //! ``` use std::error::Error as _; -use std::future::Future; use std::io::Write; use std::path::PathBuf; use std::time::Duration; @@ -14,9 +13,9 @@ use std::{io, result}; use anyhow::Context; use camino::Utf8PathBuf; -use http_utils::error::HttpErrorBody; use postgres_connection::PgConnectionConfig; -use reqwest::{IntoUrl, Method}; +use safekeeper_api::models::TimelineCreateRequest; +use safekeeper_client::mgmt_api; use thiserror::Error; use utils::auth::{Claims, Scope}; use utils::id::NodeId; @@ -35,25 +34,14 @@ pub enum SafekeeperHttpError { type Result = result::Result; -pub(crate) trait ResponseErrorMessageExt: Sized { - fn error_from_body(self) -> impl Future> + Send; -} - -impl ResponseErrorMessageExt for reqwest::Response { - async fn error_from_body(self) -> Result { - let status = self.status(); - if !(status.is_client_error() || status.is_server_error()) { - return Ok(self); - } - - // reqwest does not export its error construction utility functions, so let's craft the message ourselves - let url = self.url().to_owned(); - Err(SafekeeperHttpError::Response( - match self.json::().await { - Ok(err_body) => format!("Error: {}", err_body.msg), - Err(_) => format!("Http error ({}) at {}.", status.as_u16(), url), - }, - )) +fn err_from_client_err(err: mgmt_api::Error) -> SafekeeperHttpError { + use mgmt_api::Error::*; + match err { + ApiError(_, str) => SafekeeperHttpError::Response(str), + Cancelled => SafekeeperHttpError::Response("Cancelled".to_owned()), + ReceiveBody(err) => SafekeeperHttpError::Transport(err), + ReceiveErrorBody(err) => SafekeeperHttpError::Response(err), + Timeout(str) => SafekeeperHttpError::Response(format!("timeout: {str}")), } } @@ -70,9 +58,8 @@ pub struct SafekeeperNode { pub pg_connection_config: PgConnectionConfig, pub env: LocalEnv, - pub http_client: reqwest::Client, + pub http_client: mgmt_api::Client, pub listen_addr: String, - pub http_base_url: String, } impl SafekeeperNode { @@ -82,13 +69,14 @@ impl SafekeeperNode { } else { "127.0.0.1".to_string() }; + let jwt = None; + let http_base_url = format!("http://{}:{}", listen_addr, conf.http_port); SafekeeperNode { id: conf.id, conf: conf.clone(), pg_connection_config: Self::safekeeper_connection_config(&listen_addr, conf.pg_port), env: env.clone(), - http_client: env.create_http_client(), - http_base_url: format!("http://{}:{}/v1", listen_addr, conf.http_port), + http_client: mgmt_api::Client::new(env.create_http_client(), http_base_url, jwt), listen_addr, } } @@ -278,20 +266,19 @@ impl SafekeeperNode { ) } - fn http_request(&self, method: Method, url: U) -> reqwest::RequestBuilder { - // TODO: authentication - //if self.env.auth_type == AuthType::NeonJWT { - // builder = builder.bearer_auth(&self.env.safekeeper_auth_token) - //} - self.http_client.request(method, url) + pub async fn check_status(&self) -> Result<()> { + self.http_client + .status() + .await + .map_err(err_from_client_err)?; + Ok(()) } - pub 
async fn check_status(&self) -> Result<()> { - self.http_request(Method::GET, format!("{}/{}", self.http_base_url, "status")) - .send() - .await? - .error_from_body() - .await?; + pub async fn create_timeline(&self, req: &TimelineCreateRequest) -> Result<()> { + self.http_client + .create_timeline(req) + .await + .map_err(err_from_client_err)?; Ok(()) } } diff --git a/libs/safekeeper_api/src/models.rs b/libs/safekeeper_api/src/models.rs index 8658dc4011..fd05f6fda3 100644 --- a/libs/safekeeper_api/src/models.rs +++ b/libs/safekeeper_api/src/models.rs @@ -13,7 +13,7 @@ use utils::pageserver_feedback::PageserverFeedback; use crate::membership::Configuration; use crate::{ServerInfo, Term}; -#[derive(Debug, Serialize)] +#[derive(Debug, Serialize, Deserialize)] pub struct SafekeeperStatus { pub id: NodeId, } diff --git a/safekeeper/client/src/mgmt_api.rs b/safekeeper/client/src/mgmt_api.rs index b364ac8e48..2e46a7b529 100644 --- a/safekeeper/client/src/mgmt_api.rs +++ b/safekeeper/client/src/mgmt_api.rs @@ -8,8 +8,8 @@ use std::error::Error as _; use http_utils::error::HttpErrorBody; use reqwest::{IntoUrl, Method, StatusCode}; use safekeeper_api::models::{ - self, PullTimelineRequest, PullTimelineResponse, SafekeeperUtilization, TimelineCreateRequest, - TimelineStatus, + self, PullTimelineRequest, PullTimelineResponse, SafekeeperStatus, SafekeeperUtilization, + TimelineCreateRequest, TimelineStatus, }; use utils::id::{NodeId, TenantId, TimelineId}; use utils::logging::SecretString; @@ -183,6 +183,12 @@ impl Client { self.get(&uri).await } + pub async fn status(&self) -> Result { + let uri = format!("{}/v1/status", self.mgmt_api_endpoint); + let resp = self.get(&uri).await?; + resp.json().await.map_err(Error::ReceiveBody) + } + pub async fn utilization(&self) -> Result { let uri = format!("{}/v1/utilization", self.mgmt_api_endpoint); let resp = self.get(&uri).await?; diff --git a/test_runner/regress/test_import.py b/test_runner/regress/test_import.py index 55737c35f0..e1070a81e6 100644 --- a/test_runner/regress/test_import.py +++ b/test_runner/regress/test_import.py @@ -87,6 +87,9 @@ def test_import_from_vanilla(test_output_dir, pg_bin, vanilla_pg, neon_env_build # Set up pageserver for import neon_env_builder.enable_pageserver_remote_storage(RemoteStorageKind.LOCAL_FS) + neon_env_builder.storage_controller_config = { + "timelines_onto_safekeepers": True, + } env = neon_env_builder.init_start() env.pageserver.tenant_create(tenant) From b23e75ebfe6b6991cdff94d7ce1f627997db797e Mon Sep 17 00:00:00 2001 From: "Alex Chi Z." <4198311+skyzh@users.noreply.github.com> Date: Fri, 6 Jun 2025 14:50:54 +0800 Subject: [PATCH 03/12] test(pageserver): ensure offload cleans up metrics (#12127) Add a test to ensure timeline metrics are fully cleaned up after offloading. 
Signed-off-by: Alex Chi Z --- test_runner/regress/test_tenants.py | 61 ++++++++++++++++++++++++++++- 1 file changed, 60 insertions(+), 1 deletion(-) diff --git a/test_runner/regress/test_tenants.py b/test_runner/regress/test_tenants.py index d08692500f..c54dd8b38d 100644 --- a/test_runner/regress/test_tenants.py +++ b/test_runner/regress/test_tenants.py @@ -12,7 +12,7 @@ from typing import TYPE_CHECKING import pytest import requests -from fixtures.common_types import Lsn, TenantId, TimelineId +from fixtures.common_types import Lsn, TenantId, TimelineArchivalState, TimelineId from fixtures.log_helper import log from fixtures.metrics import ( PAGESERVER_GLOBAL_METRICS, @@ -299,6 +299,65 @@ def test_pageserver_metrics_removed_after_detach(neon_env_builder: NeonEnvBuilde assert post_detach_samples == set() +def test_pageserver_metrics_removed_after_offload(neon_env_builder: NeonEnvBuilder): + """Tests that when a timeline is offloaded, the tenant specific metrics are not left behind""" + + neon_env_builder.enable_pageserver_remote_storage(RemoteStorageKind.MOCK_S3) + + neon_env_builder.num_safekeepers = 3 + + env = neon_env_builder.init_start() + tenant_1, _ = env.create_tenant() + + timeline_1 = env.create_timeline("test_metrics_removed_after_offload_1", tenant_id=tenant_1) + timeline_2 = env.create_timeline("test_metrics_removed_after_offload_2", tenant_id=tenant_1) + + endpoint_tenant1 = env.endpoints.create_start( + "test_metrics_removed_after_offload_1", tenant_id=tenant_1 + ) + endpoint_tenant2 = env.endpoints.create_start( + "test_metrics_removed_after_offload_2", tenant_id=tenant_1 + ) + + for endpoint in [endpoint_tenant1, endpoint_tenant2]: + with closing(endpoint.connect()) as conn: + with conn.cursor() as cur: + cur.execute("CREATE TABLE t(key int primary key, value text)") + cur.execute("INSERT INTO t SELECT generate_series(1,100000), 'payload'") + cur.execute("SELECT sum(key) FROM t") + assert cur.fetchone() == (5000050000,) + endpoint.stop() + + def get_ps_metric_samples_for_timeline( + tenant_id: TenantId, timeline_id: TimelineId + ) -> list[Sample]: + ps_metrics = env.pageserver.http_client().get_metrics() + samples = [] + for metric_name in ps_metrics.metrics: + for sample in ps_metrics.query_all( + name=metric_name, + filter={"tenant_id": str(tenant_id), "timeline_id": str(timeline_id)}, + ): + samples.append(sample) + return samples + + for timeline in [timeline_1, timeline_2]: + pre_offload_samples = set( + [x.name for x in get_ps_metric_samples_for_timeline(tenant_1, timeline)] + ) + assert len(pre_offload_samples) > 0, f"expected at least one sample for {timeline}" + env.pageserver.http_client().timeline_archival_config( + tenant_1, + timeline, + state=TimelineArchivalState.ARCHIVED, + ) + env.pageserver.http_client().timeline_offload(tenant_1, timeline) + post_offload_samples = set( + [x.name for x in get_ps_metric_samples_for_timeline(tenant_1, timeline)] + ) + assert post_offload_samples == set() + + def test_pageserver_with_empty_tenants(neon_env_builder: NeonEnvBuilder): env = neon_env_builder.init_start() From fe31baf9859d46f1eb9bf884fe480a1755ab01e9 Mon Sep 17 00:00:00 2001 From: "Alex Chi Z." <4198311+skyzh@users.noreply.github.com> Date: Fri, 6 Jun 2025 17:38:58 +0800 Subject: [PATCH 04/12] feat(build): add aws cli into the docker image (#12161) ## Problem Makes it easier to debug AWS permission issues (i.e., storage scrubber) ## Summary of changes Install awscliv2 into the docker image. 
Signed-off-by: Alex Chi Z --- Dockerfile | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/Dockerfile b/Dockerfile index 3b7962dcf9..0b7ef491fd 100644 --- a/Dockerfile +++ b/Dockerfile @@ -110,6 +110,19 @@ RUN set -e \ # System postgres for use with client libraries (e.g. in storage controller) postgresql-15 \ openssl \ + unzip \ + curl \ + && ARCH=$(uname -m) \ + && if [ "$ARCH" = "x86_64" ]; then \ + curl "https://awscli.amazonaws.com/awscli-exe-linux-x86_64.zip" -o "awscliv2.zip"; \ + elif [ "$ARCH" = "aarch64" ]; then \ + curl "https://awscli.amazonaws.com/awscli-exe-linux-aarch64.zip" -o "awscliv2.zip"; \ + else \ + echo "Unsupported architecture: $ARCH" && exit 1; \ + fi \ + && unzip awscliv2.zip \ + && ./aws/install \ + && rm -rf aws awscliv2.zip \ && rm -f /etc/apt/apt.conf.d/80-retries \ && rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/* \ && useradd -d /data neon \ From c511786548c8f09048b09a33b0e560fe2e518a5f Mon Sep 17 00:00:00 2001 From: Erik Grinaker Date: Fri, 6 Jun 2025 12:01:58 +0200 Subject: [PATCH 05/12] pageserver: move `spawn_grpc` to `GrpcPageServiceHandler::spawn` (#12147) Mechanical move, no logic changes. --- pageserver/src/bin/pageserver.rs | 3 +- pageserver/src/page_service.rs | 188 ++++++++++++++++--------------- 2 files changed, 97 insertions(+), 94 deletions(-) diff --git a/pageserver/src/bin/pageserver.rs b/pageserver/src/bin/pageserver.rs index a1a95ad2d1..5cd865f53e 100644 --- a/pageserver/src/bin/pageserver.rs +++ b/pageserver/src/bin/pageserver.rs @@ -23,6 +23,7 @@ use pageserver::deletion_queue::DeletionQueue; use pageserver::disk_usage_eviction_task::{self, launch_disk_usage_global_eviction_task}; use pageserver::feature_resolver::FeatureResolver; use pageserver::metrics::{STARTUP_DURATION, STARTUP_IS_LOADING}; +use pageserver::page_service::GrpcPageServiceHandler; use pageserver::task_mgr::{ BACKGROUND_RUNTIME, COMPUTE_REQUEST_RUNTIME, MGMT_REQUEST_RUNTIME, WALRECEIVER_RUNTIME, }; @@ -814,7 +815,7 @@ fn start_pageserver( // necessary? let mut page_service_grpc = None; if let Some(grpc_listener) = grpc_listener { - page_service_grpc = Some(page_service::spawn_grpc( + page_service_grpc = Some(GrpcPageServiceHandler::spawn( tenant_manager.clone(), grpc_auth, otel_guard.as_ref().map(|g| g.dispatch.clone()), diff --git a/pageserver/src/page_service.rs b/pageserver/src/page_service.rs index 4a1ddf09b5..d47f6bd095 100644 --- a/pageserver/src/page_service.rs +++ b/pageserver/src/page_service.rs @@ -169,99 +169,6 @@ pub fn spawn( Listener { cancel, task } } -/// Spawns a gRPC server for the page service. -/// -/// TODO: move this onto GrpcPageServiceHandler::spawn(). -/// TODO: this doesn't support TLS. We need TLS reloading via ReloadingCertificateResolver, so we -/// need to reimplement the TCP+TLS accept loop ourselves. -pub fn spawn_grpc( - tenant_manager: Arc, - auth: Option>, - perf_trace_dispatch: Option, - get_vectored_concurrent_io: GetVectoredConcurrentIo, - listener: std::net::TcpListener, -) -> anyhow::Result { - let cancel = CancellationToken::new(); - let ctx = RequestContextBuilder::new(TaskKind::PageRequestHandler) - .download_behavior(DownloadBehavior::Download) - .perf_span_dispatch(perf_trace_dispatch) - .detached_child(); - let gate = Gate::default(); - - // Set up the TCP socket. We take a preconfigured TcpListener to bind the - // port early during startup. 
- let incoming = { - let _runtime = COMPUTE_REQUEST_RUNTIME.enter(); // required by TcpListener::from_std - listener.set_nonblocking(true)?; - tonic::transport::server::TcpIncoming::from(tokio::net::TcpListener::from_std(listener)?) - .with_nodelay(Some(GRPC_TCP_NODELAY)) - .with_keepalive(Some(GRPC_TCP_KEEPALIVE_TIME)) - }; - - // Set up the gRPC server. - // - // TODO: consider tuning window sizes. - let mut server = tonic::transport::Server::builder() - .http2_keepalive_interval(Some(GRPC_HTTP2_KEEPALIVE_INTERVAL)) - .http2_keepalive_timeout(Some(GRPC_HTTP2_KEEPALIVE_TIMEOUT)) - .max_concurrent_streams(Some(GRPC_MAX_CONCURRENT_STREAMS)); - - // Main page service stack. Uses a mix of Tonic interceptors and Tower layers: - // - // * Interceptors: can inspect and modify the gRPC request. Sync code only, runs before service. - // - // * Layers: allow async code, can run code after the service response. However, only has access - // to the raw HTTP request/response, not the gRPC types. - let page_service_handler = GrpcPageServiceHandler { - tenant_manager, - ctx, - gate_guard: gate.enter().expect("gate was just created"), - get_vectored_concurrent_io, - }; - - let observability_layer = ObservabilityLayer; - let mut tenant_interceptor = TenantMetadataInterceptor; - let mut auth_interceptor = TenantAuthInterceptor::new(auth); - - let page_service = tower::ServiceBuilder::new() - // Create tracing span and record request start time. - .layer(observability_layer) - // Intercept gRPC requests. - .layer(tonic::service::InterceptorLayer::new(move |mut req| { - // Extract tenant metadata. - req = tenant_interceptor.call(req)?; - // Authenticate tenant JWT token. - req = auth_interceptor.call(req)?; - Ok(req) - })) - .service(proto::PageServiceServer::new(page_service_handler)); - let server = server.add_service(page_service); - - // Reflection service for use with e.g. grpcurl. - let reflection_service = tonic_reflection::server::Builder::configure() - .register_encoded_file_descriptor_set(proto::FILE_DESCRIPTOR_SET) - .build_v1()?; - let server = server.add_service(reflection_service); - - // Spawn server task. - let task_cancel = cancel.clone(); - let task = COMPUTE_REQUEST_RUNTIME.spawn(task_mgr::exit_on_panic_or_error( - "grpc listener", - async move { - let result = server - .serve_with_incoming_shutdown(incoming, task_cancel.cancelled()) - .await; - if result.is_ok() { - // TODO: revisit shutdown logic once page service is implemented. - gate.close().await; - } - result - }, - )); - - Ok(CancellableTask { task, cancel }) -} - impl Listener { pub async fn stop_accepting(self) -> Connections { self.cancel.cancel(); @@ -3366,6 +3273,101 @@ pub struct GrpcPageServiceHandler { } impl GrpcPageServiceHandler { + /// Spawns a gRPC server for the page service. + /// + /// TODO: this doesn't support TLS. We need TLS reloading via ReloadingCertificateResolver, so we + /// need to reimplement the TCP+TLS accept loop ourselves. + pub fn spawn( + tenant_manager: Arc, + auth: Option>, + perf_trace_dispatch: Option, + get_vectored_concurrent_io: GetVectoredConcurrentIo, + listener: std::net::TcpListener, + ) -> anyhow::Result { + let cancel = CancellationToken::new(); + let ctx = RequestContextBuilder::new(TaskKind::PageRequestHandler) + .download_behavior(DownloadBehavior::Download) + .perf_span_dispatch(perf_trace_dispatch) + .detached_child(); + let gate = Gate::default(); + + // Set up the TCP socket. We take a preconfigured TcpListener to bind the + // port early during startup. 
+ let incoming = { + let _runtime = COMPUTE_REQUEST_RUNTIME.enter(); // required by TcpListener::from_std + listener.set_nonblocking(true)?; + tonic::transport::server::TcpIncoming::from(tokio::net::TcpListener::from_std( + listener, + )?) + .with_nodelay(Some(GRPC_TCP_NODELAY)) + .with_keepalive(Some(GRPC_TCP_KEEPALIVE_TIME)) + }; + + // Set up the gRPC server. + // + // TODO: consider tuning window sizes. + let mut server = tonic::transport::Server::builder() + .http2_keepalive_interval(Some(GRPC_HTTP2_KEEPALIVE_INTERVAL)) + .http2_keepalive_timeout(Some(GRPC_HTTP2_KEEPALIVE_TIMEOUT)) + .max_concurrent_streams(Some(GRPC_MAX_CONCURRENT_STREAMS)); + + // Main page service stack. Uses a mix of Tonic interceptors and Tower layers: + // + // * Interceptors: can inspect and modify the gRPC request. Sync code only, runs before service. + // + // * Layers: allow async code, can run code after the service response. However, only has access + // to the raw HTTP request/response, not the gRPC types. + let page_service_handler = GrpcPageServiceHandler { + tenant_manager, + ctx, + gate_guard: gate.enter().expect("gate was just created"), + get_vectored_concurrent_io, + }; + + let observability_layer = ObservabilityLayer; + let mut tenant_interceptor = TenantMetadataInterceptor; + let mut auth_interceptor = TenantAuthInterceptor::new(auth); + + let page_service = tower::ServiceBuilder::new() + // Create tracing span and record request start time. + .layer(observability_layer) + // Intercept gRPC requests. + .layer(tonic::service::InterceptorLayer::new(move |mut req| { + // Extract tenant metadata. + req = tenant_interceptor.call(req)?; + // Authenticate tenant JWT token. + req = auth_interceptor.call(req)?; + Ok(req) + })) + // Run the page service. + .service(proto::PageServiceServer::new(page_service_handler)); + let server = server.add_service(page_service); + + // Reflection service for use with e.g. grpcurl. + let reflection_service = tonic_reflection::server::Builder::configure() + .register_encoded_file_descriptor_set(proto::FILE_DESCRIPTOR_SET) + .build_v1()?; + let server = server.add_service(reflection_service); + + // Spawn server task. + let task_cancel = cancel.clone(); + let task = COMPUTE_REQUEST_RUNTIME.spawn(task_mgr::exit_on_panic_or_error( + "grpc listener", + async move { + let result = server + .serve_with_incoming_shutdown(incoming, task_cancel.cancelled()) + .await; + if result.is_ok() { + // TODO: revisit shutdown logic once page service is implemented. + gate.close().await; + } + result + }, + )); + + Ok(CancellableTask { task, cancel }) + } + /// Errors if the request is executed on a non-zero shard. Only shard 0 has a complete view of /// relations and their sizes, as well as SLRU segments and similar data. #[allow(clippy::result_large_err)] From 590301df08b4eb3e8afc7afa7e3a91b6ab5dc420 Mon Sep 17 00:00:00 2001 From: Alexander Sarantcev <99037063+ephemeralsad@users.noreply.github.com> Date: Fri, 6 Jun 2025 14:16:55 +0400 Subject: [PATCH 06/12] storcon: Introduce deletion tombstones to support flaky node scenario (#12096) ## Problem Removed nodes can re-add themselves on restart if not properly tombstoned. We need a mechanism (e.g. soft-delete flag) to prevent this, especially in cases where the node is unreachable. More details there: #12036 ## Summary of changes - Introduced `NodeLifecycle` enum to represent node lifecycle states. - Added a string representation of `NodeLifecycle` to the `nodes` table. - Implemented node removal using a tombstone mechanism. 
- Introduced `/debug/v1/tombstone*` handlers to manage the tombstone state. --- control_plane/storcon_cli/src/main.rs | 41 +++++++ libs/pageserver_api/src/controller_api.rs | 29 +++++ .../down.sql | 1 + .../up.sql | 1 + storage_controller/src/http.rs | 50 ++++++++ storage_controller/src/node.rs | 6 +- storage_controller/src/persistence.rs | 115 +++++++++++++++--- storage_controller/src/schema.rs | 1 + storage_controller/src/service.rs | 57 ++++++++- test_runner/fixtures/neon_fixtures.py | 16 +++ .../regress/test_storage_controller.py | 52 ++++++++ 11 files changed, 345 insertions(+), 24 deletions(-) create mode 100644 storage_controller/migrations/2025-06-01-201442_add_lifecycle_to_nodes/down.sql create mode 100644 storage_controller/migrations/2025-06-01-201442_add_lifecycle_to_nodes/up.sql diff --git a/control_plane/storcon_cli/src/main.rs b/control_plane/storcon_cli/src/main.rs index 19c686dcfd..1a9e944e07 100644 --- a/control_plane/storcon_cli/src/main.rs +++ b/control_plane/storcon_cli/src/main.rs @@ -61,10 +61,16 @@ enum Command { #[arg(long)] scheduling: Option, }, + // Set a node status as deleted. NodeDelete { #[arg(long)] node_id: NodeId, }, + /// Delete a tombstone of node from the storage controller. + NodeDeleteTombstone { + #[arg(long)] + node_id: NodeId, + }, /// Modify a tenant's policies in the storage controller TenantPolicy { #[arg(long)] @@ -82,6 +88,8 @@ enum Command { }, /// List nodes known to the storage controller Nodes {}, + /// List soft deleted nodes known to the storage controller + NodeTombstones {}, /// List tenants known to the storage controller Tenants { /// If this field is set, it will list the tenants on a specific node @@ -900,6 +908,39 @@ async fn main() -> anyhow::Result<()> { .dispatch::<(), ()>(Method::DELETE, format!("control/v1/node/{node_id}"), None) .await?; } + Command::NodeDeleteTombstone { node_id } => { + storcon_client + .dispatch::<(), ()>( + Method::DELETE, + format!("debug/v1/tombstone/{node_id}"), + None, + ) + .await?; + } + Command::NodeTombstones {} => { + let mut resp = storcon_client + .dispatch::<(), Vec>( + Method::GET, + "debug/v1/tombstone".to_string(), + None, + ) + .await?; + + resp.sort_by(|a, b| a.listen_http_addr.cmp(&b.listen_http_addr)); + + let mut table = comfy_table::Table::new(); + table.set_header(["Id", "Hostname", "AZ", "Scheduling", "Availability"]); + for node in resp { + table.add_row([ + format!("{}", node.id), + node.listen_http_addr, + node.availability_zone_id, + format!("{:?}", node.scheduling), + format!("{:?}", node.availability), + ]); + } + println!("{table}"); + } Command::TenantSetTimeBasedEviction { tenant_id, period, diff --git a/libs/pageserver_api/src/controller_api.rs b/libs/pageserver_api/src/controller_api.rs index c5b49edba0..ae792cc81c 100644 --- a/libs/pageserver_api/src/controller_api.rs +++ b/libs/pageserver_api/src/controller_api.rs @@ -344,6 +344,35 @@ impl Default for ShardSchedulingPolicy { } } +#[derive(Serialize, Deserialize, Clone, Copy, Eq, PartialEq, Debug)] +pub enum NodeLifecycle { + Active, + Deleted, +} + +impl FromStr for NodeLifecycle { + type Err = anyhow::Error; + + fn from_str(s: &str) -> Result { + match s { + "active" => Ok(Self::Active), + "deleted" => Ok(Self::Deleted), + _ => Err(anyhow::anyhow!("Unknown node lifecycle '{s}'")), + } + } +} + +impl From for String { + fn from(value: NodeLifecycle) -> String { + use NodeLifecycle::*; + match value { + Active => "active", + Deleted => "deleted", + } + .to_string() + } +} + #[derive(Serialize, Deserialize, Clone, Copy, 
Eq, PartialEq, Debug)] pub enum NodeSchedulingPolicy { Active, diff --git a/storage_controller/migrations/2025-06-01-201442_add_lifecycle_to_nodes/down.sql b/storage_controller/migrations/2025-06-01-201442_add_lifecycle_to_nodes/down.sql new file mode 100644 index 0000000000..a09acb916b --- /dev/null +++ b/storage_controller/migrations/2025-06-01-201442_add_lifecycle_to_nodes/down.sql @@ -0,0 +1 @@ +ALTER TABLE nodes DROP COLUMN lifecycle; diff --git a/storage_controller/migrations/2025-06-01-201442_add_lifecycle_to_nodes/up.sql b/storage_controller/migrations/2025-06-01-201442_add_lifecycle_to_nodes/up.sql new file mode 100644 index 0000000000..e03a0cadba --- /dev/null +++ b/storage_controller/migrations/2025-06-01-201442_add_lifecycle_to_nodes/up.sql @@ -0,0 +1 @@ +ALTER TABLE nodes ADD COLUMN lifecycle VARCHAR NOT NULL DEFAULT 'active'; diff --git a/storage_controller/src/http.rs b/storage_controller/src/http.rs index 2b1c0db12f..705b81077e 100644 --- a/storage_controller/src/http.rs +++ b/storage_controller/src/http.rs @@ -907,6 +907,42 @@ async fn handle_node_delete(req: Request) -> Result, ApiErr json_response(StatusCode::OK, state.service.node_delete(node_id).await?) } +async fn handle_tombstone_list(req: Request) -> Result, ApiError> { + check_permissions(&req, Scope::Admin)?; + + let req = match maybe_forward(req).await { + ForwardOutcome::Forwarded(res) => { + return res; + } + ForwardOutcome::NotForwarded(req) => req, + }; + + let state = get_state(&req); + let mut nodes = state.service.tombstone_list().await?; + nodes.sort_by_key(|n| n.get_id()); + let api_nodes = nodes.into_iter().map(|n| n.describe()).collect::>(); + + json_response(StatusCode::OK, api_nodes) +} + +async fn handle_tombstone_delete(req: Request) -> Result, ApiError> { + check_permissions(&req, Scope::Admin)?; + + let req = match maybe_forward(req).await { + ForwardOutcome::Forwarded(res) => { + return res; + } + ForwardOutcome::NotForwarded(req) => req, + }; + + let state = get_state(&req); + let node_id: NodeId = parse_request_param(&req, "node_id")?; + json_response( + StatusCode::OK, + state.service.tombstone_delete(node_id).await?, + ) +} + async fn handle_node_configure(req: Request) -> Result, ApiError> { check_permissions(&req, Scope::Admin)?; @@ -2062,6 +2098,20 @@ pub fn make_router( .post("/debug/v1/node/:node_id/drop", |r| { named_request_span(r, handle_node_drop, RequestName("debug_v1_node_drop")) }) + .delete("/debug/v1/tombstone/:node_id", |r| { + named_request_span( + r, + handle_tombstone_delete, + RequestName("debug_v1_tombstone_delete"), + ) + }) + .get("/debug/v1/tombstone", |r| { + named_request_span( + r, + handle_tombstone_list, + RequestName("debug_v1_tombstone_list"), + ) + }) .post("/debug/v1/tenant/:tenant_id/import", |r| { named_request_span( r, diff --git a/storage_controller/src/node.rs b/storage_controller/src/node.rs index e180c49b43..8e0f1873e5 100644 --- a/storage_controller/src/node.rs +++ b/storage_controller/src/node.rs @@ -2,7 +2,7 @@ use std::str::FromStr; use std::time::Duration; use pageserver_api::controller_api::{ - AvailabilityZone, NodeAvailability, NodeDescribeResponse, NodeRegisterRequest, + AvailabilityZone, NodeAvailability, NodeDescribeResponse, NodeLifecycle, NodeRegisterRequest, NodeSchedulingPolicy, TenantLocateResponseShard, }; use pageserver_api::shard::TenantShardId; @@ -29,6 +29,7 @@ pub(crate) struct Node { availability: NodeAvailability, scheduling: NodeSchedulingPolicy, + lifecycle: NodeLifecycle, listen_http_addr: String, listen_http_port: u16, @@ 
-228,6 +229,7 @@ impl Node { listen_pg_addr, listen_pg_port, scheduling: NodeSchedulingPolicy::Active, + lifecycle: NodeLifecycle::Active, availability: NodeAvailability::Offline, availability_zone_id, use_https, @@ -239,6 +241,7 @@ impl Node { NodePersistence { node_id: self.id.0 as i64, scheduling_policy: self.scheduling.into(), + lifecycle: self.lifecycle.into(), listen_http_addr: self.listen_http_addr.clone(), listen_http_port: self.listen_http_port as i32, listen_https_port: self.listen_https_port.map(|x| x as i32), @@ -263,6 +266,7 @@ impl Node { availability: NodeAvailability::Offline, scheduling: NodeSchedulingPolicy::from_str(&np.scheduling_policy) .expect("Bad scheduling policy in DB"), + lifecycle: NodeLifecycle::from_str(&np.lifecycle).expect("Bad lifecycle in DB"), listen_http_addr: np.listen_http_addr, listen_http_port: np.listen_http_port as u16, listen_https_port: np.listen_https_port.map(|x| x as u16), diff --git a/storage_controller/src/persistence.rs b/storage_controller/src/persistence.rs index 052c0f02eb..2edfe3a338 100644 --- a/storage_controller/src/persistence.rs +++ b/storage_controller/src/persistence.rs @@ -19,7 +19,7 @@ use futures::FutureExt; use futures::future::BoxFuture; use itertools::Itertools; use pageserver_api::controller_api::{ - AvailabilityZone, MetadataHealthRecord, NodeSchedulingPolicy, PlacementPolicy, + AvailabilityZone, MetadataHealthRecord, NodeLifecycle, NodeSchedulingPolicy, PlacementPolicy, SafekeeperDescribeResponse, ShardSchedulingPolicy, SkSchedulingPolicy, }; use pageserver_api::models::{ShardImportStatus, TenantConfig}; @@ -102,6 +102,7 @@ pub(crate) enum DatabaseOperation { UpdateNode, DeleteNode, ListNodes, + ListTombstones, BeginShardSplit, CompleteShardSplit, AbortShardSplit, @@ -357,6 +358,8 @@ impl Persistence { } /// When a node is first registered, persist it before using it for anything + /// If the provided node_id already exists, it will be error. + /// The common case is when a node marked for deletion wants to register. pub(crate) async fn insert_node(&self, node: &Node) -> DatabaseResult<()> { let np = &node.to_persistent(); self.with_measured_conn(DatabaseOperation::InsertNode, move |conn| { @@ -373,19 +376,41 @@ impl Persistence { /// At startup, populate the list of nodes which our shards may be placed on pub(crate) async fn list_nodes(&self) -> DatabaseResult> { - let nodes: Vec = self + use crate::schema::nodes::dsl::*; + + let result: Vec = self .with_measured_conn(DatabaseOperation::ListNodes, move |conn| { Box::pin(async move { Ok(crate::schema::nodes::table + .filter(lifecycle.ne(String::from(NodeLifecycle::Deleted))) .load::(conn) .await?) }) }) .await?; - tracing::info!("list_nodes: loaded {} nodes", nodes.len()); + tracing::info!("list_nodes: loaded {} nodes", result.len()); - Ok(nodes) + Ok(result) + } + + pub(crate) async fn list_tombstones(&self) -> DatabaseResult> { + use crate::schema::nodes::dsl::*; + + let result: Vec = self + .with_measured_conn(DatabaseOperation::ListTombstones, move |conn| { + Box::pin(async move { + Ok(crate::schema::nodes::table + .filter(lifecycle.eq(String::from(NodeLifecycle::Deleted))) + .load::(conn) + .await?) 
+ }) + }) + .await?; + + tracing::info!("list_tombstones: loaded {} nodes", result.len()); + + Ok(result) } pub(crate) async fn update_node( @@ -404,6 +429,7 @@ impl Persistence { Box::pin(async move { let updated = diesel::update(nodes) .filter(node_id.eq(input_node_id.0 as i64)) + .filter(lifecycle.ne(String::from(NodeLifecycle::Deleted))) .set(values) .execute(conn) .await?; @@ -447,6 +473,57 @@ impl Persistence { .await } + /// Tombstone is a special state where the node is not deleted from the database, + /// but it is not available for usage. + /// The main reason for it is to prevent the flaky node to register. + pub(crate) async fn set_tombstone(&self, del_node_id: NodeId) -> DatabaseResult<()> { + use crate::schema::nodes::dsl::*; + self.update_node( + del_node_id, + lifecycle.eq(String::from(NodeLifecycle::Deleted)), + ) + .await + } + + pub(crate) async fn delete_node(&self, del_node_id: NodeId) -> DatabaseResult<()> { + use crate::schema::nodes::dsl::*; + self.with_measured_conn(DatabaseOperation::DeleteNode, move |conn| { + Box::pin(async move { + // You can hard delete a node only if it has a tombstone. + // So we need to check if the node has lifecycle set to deleted. + let node_to_delete = nodes + .filter(node_id.eq(del_node_id.0 as i64)) + .first::(conn) + .await + .optional()?; + + if let Some(np) = node_to_delete { + let lc = NodeLifecycle::from_str(&np.lifecycle).map_err(|e| { + DatabaseError::Logical(format!( + "Node {} has invalid lifecycle: {}", + del_node_id, e + )) + })?; + + if lc != NodeLifecycle::Deleted { + return Err(DatabaseError::Logical(format!( + "Node {} was not soft deleted before, cannot hard delete it", + del_node_id + ))); + } + + diesel::delete(nodes) + .filter(node_id.eq(del_node_id.0 as i64)) + .execute(conn) + .await?; + } + + Ok(()) + }) + }) + .await + } + /// At startup, load the high level state for shards, such as their config + policy. This will /// be enriched at runtime with state discovered on pageservers. /// @@ -543,21 +620,6 @@ impl Persistence { .await } - pub(crate) async fn delete_node(&self, del_node_id: NodeId) -> DatabaseResult<()> { - use crate::schema::nodes::dsl::*; - self.with_measured_conn(DatabaseOperation::DeleteNode, move |conn| { - Box::pin(async move { - diesel::delete(nodes) - .filter(node_id.eq(del_node_id.0 as i64)) - .execute(conn) - .await?; - - Ok(()) - }) - }) - .await - } - /// When a tenant invokes the /re-attach API, this function is responsible for doing an efficient /// batched increment of the generations of all tenants whose generation_pageserver is equal to /// the node that called /re-attach. 
@@ -571,6 +633,20 @@ impl Persistence { let updated = self .with_measured_conn(DatabaseOperation::ReAttach, move |conn| { Box::pin(async move { + // Check if the node is not marked as deleted + let deleted_node: i64 = nodes + .filter(node_id.eq(input_node_id.0 as i64)) + .filter(lifecycle.eq(String::from(NodeLifecycle::Deleted))) + .count() + .get_result(conn) + .await?; + if deleted_node > 0 { + return Err(DatabaseError::Logical(format!( + "Node {} is marked as deleted, re-attach is not allowed", + input_node_id + ))); + } + let rows_updated = diesel::update(tenant_shards) .filter(generation_pageserver.eq(input_node_id.0 as i64)) .set(generation.eq(generation + 1)) @@ -2048,6 +2124,7 @@ pub(crate) struct NodePersistence { pub(crate) listen_pg_port: i32, pub(crate) availability_zone_id: String, pub(crate) listen_https_port: Option, + pub(crate) lifecycle: String, } /// Tenant metadata health status that are stored durably. diff --git a/storage_controller/src/schema.rs b/storage_controller/src/schema.rs index 20be9bb5ca..f5807cfcd2 100644 --- a/storage_controller/src/schema.rs +++ b/storage_controller/src/schema.rs @@ -33,6 +33,7 @@ diesel::table! { listen_pg_port -> Int4, availability_zone_id -> Varchar, listen_https_port -> Nullable, + lifecycle -> Varchar, } } diff --git a/storage_controller/src/service.rs b/storage_controller/src/service.rs index 790797bae2..cb29993e8c 100644 --- a/storage_controller/src/service.rs +++ b/storage_controller/src/service.rs @@ -166,6 +166,7 @@ enum NodeOperations { Register, Configure, Delete, + DeleteTombstone, } /// The leadership status for the storage controller process. @@ -6909,7 +6910,7 @@ impl Service { /// detaching or deleting it on pageservers. We do not try and re-schedule any /// tenants that were on this node. pub(crate) async fn node_drop(&self, node_id: NodeId) -> Result<(), ApiError> { - self.persistence.delete_node(node_id).await?; + self.persistence.set_tombstone(node_id).await?; let mut locked = self.inner.write().unwrap(); @@ -7033,9 +7034,10 @@ impl Service { // That is safe because in Service::spawn we only use generation_pageserver if it refers to a node // that exists. - // 2. Actually delete the node from the database and from in-memory state + // 2. Actually delete the node from in-memory state and set tombstone to the database + // for preventing the node to register again. tracing::info!("Deleting node from database"); - self.persistence.delete_node(node_id).await?; + self.persistence.set_tombstone(node_id).await?; Ok(()) } @@ -7054,6 +7056,35 @@ impl Service { Ok(nodes) } + pub(crate) async fn tombstone_list(&self) -> Result, ApiError> { + self.persistence + .list_tombstones() + .await? 
+ .into_iter() + .map(|np| Node::from_persistent(np, false)) + .collect::, _>>() + .map_err(ApiError::InternalServerError) + } + + pub(crate) async fn tombstone_delete(&self, node_id: NodeId) -> Result<(), ApiError> { + let _node_lock = trace_exclusive_lock( + &self.node_op_locks, + node_id, + NodeOperations::DeleteTombstone, + ) + .await; + + if matches!(self.get_node(node_id).await, Err(ApiError::NotFound(_))) { + self.persistence.delete_node(node_id).await?; + Ok(()) + } else { + Err(ApiError::Conflict(format!( + "Node {} is in use, consider using tombstone API first", + node_id + ))) + } + } + pub(crate) async fn get_node(&self, node_id: NodeId) -> Result { self.inner .read() @@ -7224,7 +7255,25 @@ impl Service { }; match registration_status { - RegistrationStatus::New => self.persistence.insert_node(&new_node).await?, + RegistrationStatus::New => { + self.persistence.insert_node(&new_node).await.map_err(|e| { + if matches!( + e, + crate::persistence::DatabaseError::Query( + diesel::result::Error::DatabaseError( + diesel::result::DatabaseErrorKind::UniqueViolation, + _, + ) + ) + ) { + // The node can be deleted by tombstone API, and not show up in the list of nodes. + // If you see this error, check tombstones first. + ApiError::Conflict(format!("Node {} is already exists", new_node.get_id())) + } else { + ApiError::from(e) + } + })?; + } RegistrationStatus::NeedUpdate => { self.persistence .update_node_on_registration( diff --git a/test_runner/fixtures/neon_fixtures.py b/test_runner/fixtures/neon_fixtures.py index db3f080261..5223e34baf 100644 --- a/test_runner/fixtures/neon_fixtures.py +++ b/test_runner/fixtures/neon_fixtures.py @@ -2054,6 +2054,14 @@ class NeonStorageController(MetricsGetter, LogUtils): headers=self.headers(TokenScope.ADMIN), ) + def tombstone_delete(self, node_id): + log.info(f"tombstone_delete({node_id})") + self.request( + "DELETE", + f"{self.api}/debug/v1/tombstone/{node_id}", + headers=self.headers(TokenScope.ADMIN), + ) + def node_drain(self, node_id): log.info(f"node_drain({node_id})") self.request( @@ -2110,6 +2118,14 @@ class NeonStorageController(MetricsGetter, LogUtils): ) return response.json() + def tombstone_list(self): + response = self.request( + "GET", + f"{self.api}/debug/v1/tombstone", + headers=self.headers(TokenScope.ADMIN), + ) + return response.json() + def tenant_shard_dump(self): """ Debug listing API: dumps the internal map of tenant shards diff --git a/test_runner/regress/test_storage_controller.py b/test_runner/regress/test_storage_controller.py index 346ef0951d..5e0dd780c3 100644 --- a/test_runner/regress/test_storage_controller.py +++ b/test_runner/regress/test_storage_controller.py @@ -3093,6 +3093,58 @@ def test_storage_controller_ps_restarted_during_drain(neon_env_builder: NeonEnvB wait_until(reconfigure_node_again) +def test_ps_unavailable_after_delete(neon_env_builder: NeonEnvBuilder): + neon_env_builder.num_pageservers = 3 + + env = neon_env_builder.init_start() + + def assert_nodes_count(n: int): + nodes = env.storage_controller.node_list() + assert len(nodes) == n + + # Nodes count must remain the same before deletion + assert_nodes_count(3) + + ps = env.pageservers[0] + env.storage_controller.node_delete(ps.id) + + # After deletion, the node count must be reduced + assert_nodes_count(2) + + # Running pageserver CLI init in a separate thread + with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor: + log.info("Restarting tombstoned pageserver...") + ps.stop() + ps_start_fut = executor.submit(lambda: 
ps.start(await_active=False)) + + # After deleted pageserver restart, the node count must remain the same + assert_nodes_count(2) + + tombstones = env.storage_controller.tombstone_list() + assert len(tombstones) == 1 and tombstones[0]["id"] == ps.id + + env.storage_controller.tombstone_delete(ps.id) + + tombstones = env.storage_controller.tombstone_list() + assert len(tombstones) == 0 + + # Wait for the pageserver start operation to complete. + # If it fails with an exception, we try restarting the pageserver since the failure + # may be due to the storage controller refusing to register the node. + # However, if we get a TimeoutError that means the pageserver is completely hung, + # which is an unexpected failure mode that we'll let propagate up. + try: + ps_start_fut.result(timeout=20) + except TimeoutError: + raise + except Exception: + log.info("Restarting deleted pageserver...") + ps.restart() + + # Finally, the node can be registered again after tombstone is deleted + wait_until(lambda: assert_nodes_count(3)) + + def test_storage_controller_timeline_crud_race(neon_env_builder: NeonEnvBuilder): """ The storage controller is meant to handle the case where a timeline CRUD operation races From 4d99b6ff4d1e5ab87f198421bae8bab3948c6b66 Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Fri, 6 Jun 2025 11:29:55 +0100 Subject: [PATCH 07/12] [proxy] separate compute connect from compute authentication (#12145) ## Problem PGLB/Neonkeeper needs to separate the concerns of connecting to compute, and authenticating to compute. Additionally, the code within `connect_to_compute` is rather messy, spending effort on recovering the authentication info after wake_compute. ## Summary of changes Split `ConnCfg` into `ConnectInfo` and `AuthInfo`. `wake_compute` only returns `ConnectInfo` and `AuthInfo` is determined separately from the `handshake`/`authenticate` process. Additionally, `ConnectInfo::connect_raw` is in-charge or establishing the TLS connection, and the `postgres_client::Config::connect_raw` is configured to use `NoTls` which will force it to skip the TLS negotiation. This should just work. 
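As a rough sketch of the resulting split (shapes inferred from the call sites in this patch rather than the verbatim definitions in `proxy/src/compute/mod.rs`; the `AuthInfo` internals in particular are assumed here):

```rust
use std::net::IpAddr;

use postgres_client::config::SslMode;

/// Where and how to reach a compute. This is all that `wake_compute` has to
/// return now, and it carries the TLS mode so that `ConnectInfo::connect_raw`
/// can establish the TLS session itself.
struct ConnectInfo {
    host_addr: Option<IpAddr>, // pre-resolved address, when known (e.g. the local backend)
    host: Box<str>,            // hostname, also used for SNI-based routing
    port: u16,
    ssl_mode: SslMode,
}

/// Credentials and startup parameters, determined from the client handshake /
/// authentication flow rather than from the wake. Field layout is assumed; the
/// patch only shows the `for_console_redirect` constructor.
struct AuthInfo {
    dbname: String,             // assumed field
    user: String,               // assumed field
    password: Option<Box<str>>, // assumed field
}
```

With the split in place, `connect_to_compute` no longer has to recover authentication info after `wake_compute` (the messiness the problem statement refers to), and the `postgres_client` config is driven with `NoTls` because the TLS negotiation has already happened on the raw stream.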
--- .../proxy/tokio-postgres2/src/cancel_query.rs | 2 +- libs/proxy/tokio-postgres2/src/config.rs | 3 +- libs/proxy/tokio-postgres2/src/connect.rs | 2 +- libs/proxy/tokio-postgres2/src/tls.rs | 4 +- proxy/src/auth/backend/classic.rs | 5 - proxy/src/auth/backend/console_redirect.rs | 59 +++-- proxy/src/auth/backend/local.rs | 10 +- proxy/src/auth/backend/mod.rs | 9 - proxy/src/auth/flow.rs | 7 - proxy/src/cancellation.rs | 7 +- proxy/src/{compute.rs => compute/mod.rs} | 202 ++++++++++-------- proxy/src/compute/tls.rs | 63 ++++++ proxy/src/console_redirect_proxy.rs | 6 +- .../control_plane/client/cplane_proxy_v1.rs | 24 +-- proxy/src/control_plane/client/mock.rs | 46 ++-- proxy/src/control_plane/mod.rs | 30 +-- proxy/src/pglb/connect_compute.rs | 47 ++-- proxy/src/pqproto.rs | 77 ++++--- proxy/src/proxy/mod.rs | 16 +- proxy/src/proxy/retry.rs | 4 +- proxy/src/proxy/tests/mod.rs | 26 ++- proxy/src/serverless/backend.rs | 39 ++-- proxy/src/serverless/conn_pool.rs | 4 +- proxy/src/tls/postgres_rustls.rs | 46 ++-- 24 files changed, 382 insertions(+), 356 deletions(-) rename proxy/src/{compute.rs => compute/mod.rs} (68%) create mode 100644 proxy/src/compute/tls.rs diff --git a/libs/proxy/tokio-postgres2/src/cancel_query.rs b/libs/proxy/tokio-postgres2/src/cancel_query.rs index 0bdad0b554..4c2a5ef50f 100644 --- a/libs/proxy/tokio-postgres2/src/cancel_query.rs +++ b/libs/proxy/tokio-postgres2/src/cancel_query.rs @@ -10,7 +10,7 @@ use crate::{Error, cancel_query_raw, connect_socket}; pub(crate) async fn cancel_query( config: Option, ssl_mode: SslMode, - mut tls: T, + tls: T, process_id: i32, secret_key: i32, ) -> Result<(), Error> diff --git a/libs/proxy/tokio-postgres2/src/config.rs b/libs/proxy/tokio-postgres2/src/config.rs index 978d348741..243a5bc725 100644 --- a/libs/proxy/tokio-postgres2/src/config.rs +++ b/libs/proxy/tokio-postgres2/src/config.rs @@ -17,7 +17,6 @@ use crate::{Client, Connection, Error}; /// TLS configuration. #[derive(Debug, Copy, Clone, PartialEq, Eq, Serialize, Deserialize)] -#[non_exhaustive] pub enum SslMode { /// Do not use TLS. Disable, @@ -231,7 +230,7 @@ impl Config { /// Requires the `runtime` Cargo feature (enabled by default). pub async fn connect( &self, - tls: T, + tls: &T, ) -> Result<(Client, Connection), Error> where T: MakeTlsConnect, diff --git a/libs/proxy/tokio-postgres2/src/connect.rs b/libs/proxy/tokio-postgres2/src/connect.rs index 39a0a87c74..f7bc863337 100644 --- a/libs/proxy/tokio-postgres2/src/connect.rs +++ b/libs/proxy/tokio-postgres2/src/connect.rs @@ -13,7 +13,7 @@ use crate::tls::{MakeTlsConnect, TlsConnect}; use crate::{Client, Config, Connection, Error, RawConnection}; pub async fn connect( - mut tls: T, + tls: &T, config: &Config, ) -> Result<(Client, Connection), Error> where diff --git a/libs/proxy/tokio-postgres2/src/tls.rs b/libs/proxy/tokio-postgres2/src/tls.rs index 41b51368ff..f9cbcf4991 100644 --- a/libs/proxy/tokio-postgres2/src/tls.rs +++ b/libs/proxy/tokio-postgres2/src/tls.rs @@ -47,7 +47,7 @@ pub trait MakeTlsConnect { /// Creates a new `TlsConnect`or. /// /// The domain name is provided for certificate verification and SNI. - fn make_tls_connect(&mut self, domain: &str) -> Result; + fn make_tls_connect(&self, domain: &str) -> Result; } /// An asynchronous function wrapping a stream in a TLS session. 
@@ -85,7 +85,7 @@ impl MakeTlsConnect for NoTls { type TlsConnect = NoTls; type Error = NoTlsError; - fn make_tls_connect(&mut self, _: &str) -> Result { + fn make_tls_connect(&self, _: &str) -> Result { Ok(NoTls) } } diff --git a/proxy/src/auth/backend/classic.rs b/proxy/src/auth/backend/classic.rs index 8445368740..f35b3ecc05 100644 --- a/proxy/src/auth/backend/classic.rs +++ b/proxy/src/auth/backend/classic.rs @@ -18,11 +18,6 @@ pub(super) async fn authenticate( secret: AuthSecret, ) -> auth::Result { let scram_keys = match secret { - #[cfg(any(test, feature = "testing"))] - AuthSecret::Md5(_) => { - debug!("auth endpoint chooses MD5"); - return Err(auth::AuthError::MalformedPassword("MD5 not supported")); - } AuthSecret::Scram(secret) => { debug!("auth endpoint chooses SCRAM"); diff --git a/proxy/src/auth/backend/console_redirect.rs b/proxy/src/auth/backend/console_redirect.rs index c388848926..455d96c90a 100644 --- a/proxy/src/auth/backend/console_redirect.rs +++ b/proxy/src/auth/backend/console_redirect.rs @@ -6,10 +6,9 @@ use thiserror::Error; use tokio::io::{AsyncRead, AsyncWrite}; use tracing::{info, info_span}; -use super::ComputeCredentialKeys; -use crate::auth::IpPattern; use crate::auth::backend::ComputeUserInfo; use crate::cache::Cached; +use crate::compute::AuthInfo; use crate::config::AuthenticationConfig; use crate::context::RequestContext; use crate::control_plane::client::cplane_proxy_v1; @@ -98,15 +97,11 @@ impl ConsoleRedirectBackend { ctx: &RequestContext, auth_config: &'static AuthenticationConfig, client: &mut PqStream, - ) -> auth::Result<( - ConsoleRedirectNodeInfo, - ComputeUserInfo, - Option>, - )> { + ) -> auth::Result<(ConsoleRedirectNodeInfo, AuthInfo, ComputeUserInfo)> { authenticate(ctx, auth_config, &self.console_uri, client) .await - .map(|(node_info, user_info, ip_allowlist)| { - (ConsoleRedirectNodeInfo(node_info), user_info, ip_allowlist) + .map(|(node_info, auth_info, user_info)| { + (ConsoleRedirectNodeInfo(node_info), auth_info, user_info) }) } } @@ -121,10 +116,6 @@ impl ComputeConnectBackend for ConsoleRedirectNodeInfo { ) -> Result { Ok(Cached::new_uncached(self.0.clone())) } - - fn get_keys(&self) -> &ComputeCredentialKeys { - &ComputeCredentialKeys::None - } } async fn authenticate( @@ -132,7 +123,7 @@ async fn authenticate( auth_config: &'static AuthenticationConfig, link_uri: &reqwest::Url, client: &mut PqStream, -) -> auth::Result<(NodeInfo, ComputeUserInfo, Option>)> { +) -> auth::Result<(NodeInfo, AuthInfo, ComputeUserInfo)> { ctx.set_auth_method(crate::context::AuthMethod::ConsoleRedirect); // registering waiter can fail if we get unlucky with rng. @@ -192,10 +183,24 @@ async fn authenticate( client.write_message(BeMessage::NoticeResponse("Connecting to database.")); - // This config should be self-contained, because we won't - // take username or dbname from client's startup message. - let mut config = compute::ConnCfg::new(db_info.host.to_string(), db_info.port); - config.dbname(&db_info.dbname).user(&db_info.user); + // Backwards compatibility. pg_sni_proxy uses "--" in domain names + // while direct connections do not. Once we migrate to pg_sni_proxy + // everywhere, we can remove this. 
+ let ssl_mode = if db_info.host.contains("--") { + // we need TLS connection with SNI info to properly route it + SslMode::Require + } else { + SslMode::Disable + }; + + let conn_info = compute::ConnectInfo { + host: db_info.host.into(), + port: db_info.port, + ssl_mode, + host_addr: None, + }; + let auth_info = + AuthInfo::for_console_redirect(&db_info.dbname, &db_info.user, db_info.password.as_deref()); let user: RoleName = db_info.user.into(); let user_info = ComputeUserInfo { @@ -209,26 +214,12 @@ async fn authenticate( ctx.set_project(db_info.aux.clone()); info!("woken up a compute node"); - // Backwards compatibility. pg_sni_proxy uses "--" in domain names - // while direct connections do not. Once we migrate to pg_sni_proxy - // everywhere, we can remove this. - if db_info.host.contains("--") { - // we need TLS connection with SNI info to properly route it - config.ssl_mode(SslMode::Require); - } else { - config.ssl_mode(SslMode::Disable); - } - - if let Some(password) = db_info.password { - config.password(password.as_ref()); - } - Ok(( NodeInfo { - config, + conn_info, aux: db_info.aux, }, + auth_info, user_info, - db_info.allowed_ips, )) } diff --git a/proxy/src/auth/backend/local.rs b/proxy/src/auth/backend/local.rs index 7a6dceb194..2224f492b8 100644 --- a/proxy/src/auth/backend/local.rs +++ b/proxy/src/auth/backend/local.rs @@ -1,11 +1,12 @@ use std::net::SocketAddr; use arc_swap::ArcSwapOption; +use postgres_client::config::SslMode; use tokio::sync::Semaphore; use super::jwt::{AuthRule, FetchAuthRules}; use crate::auth::backend::jwt::FetchAuthRulesError; -use crate::compute::ConnCfg; +use crate::compute::ConnectInfo; use crate::compute_ctl::ComputeCtlApi; use crate::context::RequestContext; use crate::control_plane::NodeInfo; @@ -29,7 +30,12 @@ impl LocalBackend { api: http::Endpoint::new(compute_ctl, http::new_client()), }, node_info: NodeInfo { - config: ConnCfg::new(postgres_addr.ip().to_string(), postgres_addr.port()), + conn_info: ConnectInfo { + host_addr: Some(postgres_addr.ip()), + host: postgres_addr.ip().to_string().into(), + port: postgres_addr.port(), + ssl_mode: SslMode::Disable, + }, // TODO(conrad): make this better reflect compute info rather than endpoint info. 
aux: MetricsAuxInfo { endpoint_id: EndpointIdTag::get_interner().get_or_intern("local"), diff --git a/proxy/src/auth/backend/mod.rs b/proxy/src/auth/backend/mod.rs index f978f655c4..edc1ae06d9 100644 --- a/proxy/src/auth/backend/mod.rs +++ b/proxy/src/auth/backend/mod.rs @@ -168,8 +168,6 @@ impl ComputeUserInfo { #[cfg_attr(test, derive(Debug))] pub(crate) enum ComputeCredentialKeys { - #[cfg(any(test, feature = "testing"))] - Password(Vec), AuthKeys(AuthKeys), JwtPayload(Vec), None, @@ -419,13 +417,6 @@ impl ComputeConnectBackend for Backend<'_, ComputeCredentials> { Self::Local(local) => Ok(Cached::new_uncached(local.node_info.clone())), } } - - fn get_keys(&self) -> &ComputeCredentialKeys { - match self { - Self::ControlPlane(_, creds) => &creds.keys, - Self::Local(_) => &ComputeCredentialKeys::None, - } - } } #[cfg(test)] diff --git a/proxy/src/auth/flow.rs b/proxy/src/auth/flow.rs index 8fbc4577e9..c825d5bf4b 100644 --- a/proxy/src/auth/flow.rs +++ b/proxy/src/auth/flow.rs @@ -169,13 +169,6 @@ pub(crate) async fn validate_password_and_exchange( secret: AuthSecret, ) -> super::Result> { match secret { - #[cfg(any(test, feature = "testing"))] - AuthSecret::Md5(_) => { - // test only - Ok(sasl::Outcome::Success(ComputeCredentialKeys::Password( - password.to_owned(), - ))) - } // perform scram authentication as both client and server to validate the keys AuthSecret::Scram(scram_secret) => { let outcome = crate::scram::exchange(pool, endpoint, &scram_secret, password).await?; diff --git a/proxy/src/cancellation.rs b/proxy/src/cancellation.rs index d26641db46..cce4c1d3a0 100644 --- a/proxy/src/cancellation.rs +++ b/proxy/src/cancellation.rs @@ -24,7 +24,6 @@ use crate::pqproto::CancelKeyData; use crate::rate_limiter::LeakyBucketRateLimiter; use crate::redis::keys::KeyPrefix; use crate::redis::kv_ops::RedisKVClient; -use crate::tls::postgres_rustls::MakeRustlsConnect; type IpSubnetKey = IpNet; @@ -497,10 +496,8 @@ impl CancelClosure { ) -> Result<(), CancelError> { let socket = TcpStream::connect(self.socket_addr).await?; - let mut mk_tls = - crate::tls::postgres_rustls::MakeRustlsConnect::new(compute_config.tls.clone()); - let tls = >::make_tls_connect( - &mut mk_tls, + let tls = <_ as MakeTlsConnect>::make_tls_connect( + compute_config, &self.hostname, ) .map_err(|e| CancelError::IO(std::io::Error::other(e.to_string())))?; diff --git a/proxy/src/compute.rs b/proxy/src/compute/mod.rs similarity index 68% rename from proxy/src/compute.rs rename to proxy/src/compute/mod.rs index 2899f25129..0dacd15547 100644 --- a/proxy/src/compute.rs +++ b/proxy/src/compute/mod.rs @@ -1,21 +1,24 @@ +mod tls; + use std::fmt::Debug; use std::io; -use std::net::SocketAddr; -use std::time::Duration; +use std::net::{IpAddr, SocketAddr}; use futures::{FutureExt, TryFutureExt}; use itertools::Itertools; +use postgres_client::config::{AuthKeys, SslMode}; +use postgres_client::maybe_tls_stream::MaybeTlsStream; use postgres_client::tls::MakeTlsConnect; -use postgres_client::{CancelToken, RawConnection}; +use postgres_client::{CancelToken, NoTls, RawConnection}; use postgres_protocol::message::backend::NoticeResponseBody; -use rustls::pki_types::InvalidDnsNameError; use thiserror::Error; use tokio::net::{TcpStream, lookup_host}; use tracing::{debug, error, info, warn}; -use crate::auth::backend::ComputeUserInfo; +use crate::auth::backend::{ComputeCredentialKeys, ComputeUserInfo}; use crate::auth::parse_endpoint_param; use crate::cancellation::CancelClosure; +use crate::compute::tls::TlsError; use 
crate::config::ComputeConfig; use crate::context::RequestContext; use crate::control_plane::client::ApiLockError; @@ -25,7 +28,6 @@ use crate::error::{ReportableError, UserFacingError}; use crate::metrics::{Metrics, NumDbConnectionsGuard}; use crate::pqproto::StartupMessageParams; use crate::proxy::neon_option; -use crate::tls::postgres_rustls::MakeRustlsConnect; use crate::types::Host; pub const COULD_NOT_CONNECT: &str = "Couldn't connect to compute node"; @@ -38,10 +40,7 @@ pub(crate) enum ConnectionError { Postgres(#[from] postgres_client::Error), #[error("{COULD_NOT_CONNECT}: {0}")] - CouldNotConnect(#[from] io::Error), - - #[error("{COULD_NOT_CONNECT}: {0}")] - TlsError(#[from] InvalidDnsNameError), + TlsError(#[from] TlsError), #[error("{COULD_NOT_CONNECT}: {0}")] WakeComputeError(#[from] WakeComputeError), @@ -73,7 +72,7 @@ impl UserFacingError for ConnectionError { ConnectionError::TooManyConnectionAttempts(_) => { "Failed to acquire permit to connect to the database. Too many database connection attempts are currently ongoing.".to_owned() } - _ => COULD_NOT_CONNECT.to_owned(), + ConnectionError::TlsError(_) => COULD_NOT_CONNECT.to_owned(), } } } @@ -85,7 +84,6 @@ impl ReportableError for ConnectionError { crate::error::ErrorKind::Postgres } ConnectionError::Postgres(_) => crate::error::ErrorKind::Compute, - ConnectionError::CouldNotConnect(_) => crate::error::ErrorKind::Compute, ConnectionError::TlsError(_) => crate::error::ErrorKind::Compute, ConnectionError::WakeComputeError(e) => e.get_error_kind(), ConnectionError::TooManyConnectionAttempts(e) => e.get_error_kind(), @@ -96,34 +94,85 @@ impl ReportableError for ConnectionError { /// A pair of `ClientKey` & `ServerKey` for `SCRAM-SHA-256`. pub(crate) type ScramKeys = postgres_client::config::ScramKeys<32>; -/// A config for establishing a connection to compute node. -/// Eventually, `postgres_client` will be replaced with something better. -/// Newtype allows us to implement methods on top of it. #[derive(Clone)] -pub(crate) struct ConnCfg(Box); +pub enum Auth { + /// Only used during console-redirect. + Password(Vec), + /// Used by sql-over-http, ws, tcp. + Scram(Box), +} + +/// A config for authenticating to the compute node. +pub(crate) struct AuthInfo { + /// None for local-proxy, as we use trust-based localhost auth. + /// Some for sql-over-http, ws, tcp, and in most cases for console-redirect. + /// Might be None for console-redirect, but that's only a consequence of testing environments ATM. + auth: Option, + server_params: StartupMessageParams, + + /// Console redirect sets user and database, we shouldn't re-use those from the params. + skip_db_user: bool, +} + +/// Contains only the data needed to establish a secure connection to compute. +#[derive(Clone)] +pub struct ConnectInfo { + pub host_addr: Option, + pub host: Host, + pub port: u16, + pub ssl_mode: SslMode, +} /// Creation and initialization routines. -impl ConnCfg { - pub(crate) fn new(host: String, port: u16) -> Self { - Self(Box::new(postgres_client::Config::new(host, port))) - } - - /// Reuse password or auth keys from the other config. 
- pub(crate) fn reuse_password(&mut self, other: Self) { - if let Some(password) = other.get_password() { - self.password(password); - } - - if let Some(keys) = other.get_auth_keys() { - self.auth_keys(keys); +impl AuthInfo { + pub(crate) fn for_console_redirect(db: &str, user: &str, pw: Option<&str>) -> Self { + let mut server_params = StartupMessageParams::default(); + server_params.insert("database", db); + server_params.insert("user", user); + Self { + auth: pw.map(|pw| Auth::Password(pw.as_bytes().to_owned())), + server_params, + skip_db_user: true, } } - pub(crate) fn get_host(&self) -> Host { - match self.0.get_host() { - postgres_client::config::Host::Tcp(s) => s.into(), + pub(crate) fn with_auth_keys(keys: &ComputeCredentialKeys) -> Self { + Self { + auth: match keys { + ComputeCredentialKeys::AuthKeys(AuthKeys::ScramSha256(auth_keys)) => { + Some(Auth::Scram(Box::new(*auth_keys))) + } + ComputeCredentialKeys::JwtPayload(_) | ComputeCredentialKeys::None => None, + }, + server_params: StartupMessageParams::default(), + skip_db_user: false, } } +} + +impl ConnectInfo { + pub fn to_postgres_client_config(&self) -> postgres_client::Config { + let mut config = postgres_client::Config::new(self.host.to_string(), self.port); + config.ssl_mode(self.ssl_mode); + if let Some(host_addr) = self.host_addr { + config.set_host_addr(host_addr); + } + config + } +} + +impl AuthInfo { + fn enrich(&self, mut config: postgres_client::Config) -> postgres_client::Config { + match &self.auth { + Some(Auth::Scram(keys)) => config.auth_keys(AuthKeys::ScramSha256(**keys)), + Some(Auth::Password(pw)) => config.password(pw), + None => &mut config, + }; + for (k, v) in self.server_params.iter() { + config.set_param(k, v); + } + config + } /// Apply startup message params to the connection config. pub(crate) fn set_startup_params( @@ -132,27 +181,26 @@ impl ConnCfg { arbitrary_params: bool, ) { if !arbitrary_params { - self.set_param("client_encoding", "UTF8"); + self.server_params.insert("client_encoding", "UTF8"); } for (k, v) in params.iter() { match k { // Only set `user` if it's not present in the config. // Console redirect auth flow takes username from the console's response. - "user" if self.user_is_set() => {} - "database" if self.db_is_set() => {} + "user" | "database" if self.skip_db_user => {} "options" => { if let Some(options) = filtered_options(v) { - self.set_param(k, &options); + self.server_params.insert(k, &options); } } "user" | "database" | "application_name" | "replication" => { - self.set_param(k, v); + self.server_params.insert(k, v); } // if we allow arbitrary params, then we forward them through. // this is a flag for a period of backwards compatibility k if arbitrary_params => { - self.set_param(k, v); + self.server_params.insert(k, v); } _ => {} } @@ -160,25 +208,13 @@ impl ConnCfg { } } -impl std::ops::Deref for ConnCfg { - type Target = postgres_client::Config; - - fn deref(&self) -> &Self::Target { - &self.0 - } -} - -/// For now, let's make it easier to setup the config. -impl std::ops::DerefMut for ConnCfg { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.0 - } -} - -impl ConnCfg { - /// Establish a raw TCP connection to the compute node. - async fn connect_raw(&self, timeout: Duration) -> io::Result<(SocketAddr, TcpStream, &str)> { - use postgres_client::config::Host; +impl ConnectInfo { + /// Establish a raw TCP+TLS connection to the compute node. 
+ async fn connect_raw( + &self, + config: &ComputeConfig, + ) -> Result<(SocketAddr, MaybeTlsStream), TlsError> { + let timeout = config.timeout; // wrap TcpStream::connect with timeout let connect_with_timeout = |addrs| { @@ -208,34 +244,32 @@ impl ConnCfg { // We can't reuse connection establishing logic from `postgres_client` here, // because it has no means for extracting the underlying socket which we // require for our business. - let port = self.0.get_port(); - let host = self.0.get_host(); + let port = self.port; + let host = &*self.host; - let host = match host { - Host::Tcp(host) => host.as_str(), - }; - - let addrs = match self.0.get_host_addr() { + let addrs = match self.host_addr { Some(addr) => vec![SocketAddr::new(addr, port)], None => lookup_host((host, port)).await?.collect(), }; match connect_once(&*addrs).await { - Ok((sockaddr, stream)) => Ok((sockaddr, stream, host)), + Ok((sockaddr, stream)) => Ok(( + sockaddr, + tls::connect_tls(stream, self.ssl_mode, config, host).await?, + )), Err(err) => { warn!("couldn't connect to compute node at {host}:{port}: {err}"); - Err(err) + Err(TlsError::Connection(err)) } } } } -type RustlsStream = >::Stream; +type RustlsStream = >::Stream; pub(crate) struct PostgresConnection { /// Socket connected to a compute node. - pub(crate) stream: - postgres_client::maybe_tls_stream::MaybeTlsStream, + pub(crate) stream: MaybeTlsStream, /// PostgreSQL connection parameters. pub(crate) params: std::collections::HashMap, /// Query cancellation token. @@ -248,28 +282,23 @@ pub(crate) struct PostgresConnection { _guage: NumDbConnectionsGuard<'static>, } -impl ConnCfg { +impl ConnectInfo { /// Connect to a corresponding compute node. pub(crate) async fn connect( &self, ctx: &RequestContext, aux: MetricsAuxInfo, + auth: &AuthInfo, config: &ComputeConfig, user_info: ComputeUserInfo, ) -> Result { - let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute); - let (socket_addr, stream, host) = self.connect_raw(config.timeout).await?; - drop(pause); + let mut tmp_config = auth.enrich(self.to_postgres_client_config()); + // we setup SSL early in `ConnectInfo::connect_raw`. + tmp_config.ssl_mode(SslMode::Disable); - let mut mk_tls = crate::tls::postgres_rustls::MakeRustlsConnect::new(config.tls.clone()); - let tls = >::make_tls_connect( - &mut mk_tls, - host, - )?; - - // connect_raw() will not use TLS if sslmode is "disable" let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute); - let connection = self.0.connect_raw(stream, tls).await?; + let (socket_addr, stream) = self.connect_raw(config).await?; + let connection = tmp_config.connect_raw(stream, NoTls).await?; drop(pause); let RawConnection { @@ -282,13 +311,14 @@ impl ConnCfg { tracing::Span::current().record("pid", tracing::field::display(process_id)); tracing::Span::current().record("compute_id", tracing::field::display(&aux.compute_id)); - let stream = stream.into_inner(); + let MaybeTlsStream::Raw(stream) = stream.into_inner(); // TODO: lots of useful info but maybe we can move it elsewhere (eg traces?) 
info!( cold_start_info = ctx.cold_start_info().as_str(), - "connected to compute node at {host} ({socket_addr}) sslmode={:?}, latency={}, query_id={}", - self.0.get_ssl_mode(), + "connected to compute node at {} ({socket_addr}) sslmode={:?}, latency={}, query_id={}", + self.host, + self.ssl_mode, ctx.get_proxy_latency(), ctx.get_testodrome_id().unwrap_or_default(), ); @@ -299,11 +329,11 @@ impl ConnCfg { socket_addr, CancelToken { socket_config: None, - ssl_mode: self.0.get_ssl_mode(), + ssl_mode: self.ssl_mode, process_id, secret_key, }, - host.to_string(), + self.host.to_string(), user_info, ); diff --git a/proxy/src/compute/tls.rs b/proxy/src/compute/tls.rs new file mode 100644 index 0000000000..000d75fca5 --- /dev/null +++ b/proxy/src/compute/tls.rs @@ -0,0 +1,63 @@ +use futures::FutureExt; +use postgres_client::config::SslMode; +use postgres_client::maybe_tls_stream::MaybeTlsStream; +use postgres_client::tls::{MakeTlsConnect, TlsConnect}; +use rustls::pki_types::InvalidDnsNameError; +use thiserror::Error; +use tokio::io::{AsyncRead, AsyncWrite}; + +use crate::pqproto::request_tls; +use crate::proxy::retry::CouldRetry; + +#[derive(Debug, Error)] +pub enum TlsError { + #[error(transparent)] + Dns(#[from] InvalidDnsNameError), + #[error(transparent)] + Connection(#[from] std::io::Error), + #[error("TLS required but not provided")] + Required, +} + +impl CouldRetry for TlsError { + fn could_retry(&self) -> bool { + match self { + TlsError::Dns(_) => false, + TlsError::Connection(err) => err.could_retry(), + // perhaps compute didn't realise it supports TLS? + TlsError::Required => true, + } + } +} + +pub async fn connect_tls( + mut stream: S, + mode: SslMode, + tls: &T, + host: &str, +) -> Result, TlsError> +where + S: AsyncRead + AsyncWrite + Unpin + Send, + T: MakeTlsConnect< + S, + Error = InvalidDnsNameError, + TlsConnect: TlsConnect, + >, +{ + match mode { + SslMode::Disable => return Ok(MaybeTlsStream::Raw(stream)), + SslMode::Prefer | SslMode::Require => {} + } + + if !request_tls(&mut stream).await? 
{ + if SslMode::Require == mode { + return Err(TlsError::Required); + } + + return Ok(MaybeTlsStream::Raw(stream)); + } + + Ok(MaybeTlsStream::Tls( + tls.make_tls_connect(host)?.connect(stream).boxed().await?, + )) +} diff --git a/proxy/src/console_redirect_proxy.rs b/proxy/src/console_redirect_proxy.rs index f2484b54b8..324dcf5824 100644 --- a/proxy/src/console_redirect_proxy.rs +++ b/proxy/src/console_redirect_proxy.rs @@ -210,20 +210,20 @@ pub(crate) async fn handle_client( ctx.set_db_options(params.clone()); - let (node_info, user_info, _ip_allowlist) = match backend + let (node_info, mut auth_info, user_info) = match backend .authenticate(ctx, &config.authentication_config, &mut stream) .await { Ok(auth_result) => auth_result, Err(e) => Err(stream.throw_error(e, Some(ctx)).await)?, }; + auth_info.set_startup_params(¶ms, true); let node = connect_to_compute( ctx, &TcpMechanism { user_info, - params_compat: true, - params: ¶ms, + auth: auth_info, locks: &config.connect_compute_locks, }, &node_info, diff --git a/proxy/src/control_plane/client/cplane_proxy_v1.rs b/proxy/src/control_plane/client/cplane_proxy_v1.rs index da548d6b2c..cf2d9fba14 100644 --- a/proxy/src/control_plane/client/cplane_proxy_v1.rs +++ b/proxy/src/control_plane/client/cplane_proxy_v1.rs @@ -261,24 +261,18 @@ impl NeonControlPlaneClient { Some(_) => SslMode::Require, None => SslMode::Disable, }; - let host_name = match body.server_name { - Some(host) => host, - None => host.to_owned(), + let host = match body.server_name { + Some(host) => host.into(), + None => host.into(), }; - // Don't set anything but host and port! This config will be cached. - // We'll set username and such later using the startup message. - // TODO: add more type safety (in progress). - let mut config = compute::ConnCfg::new(host_name, port); - - if let Some(addr) = host_addr { - config.set_host_addr(addr); - } - - config.ssl_mode(ssl_mode); - let node = NodeInfo { - config, + conn_info: compute::ConnectInfo { + host_addr, + host, + port, + ssl_mode, + }, aux: body.aux, }; diff --git a/proxy/src/control_plane/client/mock.rs b/proxy/src/control_plane/client/mock.rs index ece7153fce..aeea57f2fc 100644 --- a/proxy/src/control_plane/client/mock.rs +++ b/proxy/src/control_plane/client/mock.rs @@ -6,6 +6,7 @@ use std::str::FromStr; use std::sync::Arc; use futures::TryFutureExt; +use postgres_client::config::SslMode; use thiserror::Error; use tokio_postgres::Client; use tracing::{Instrument, error, info, info_span, warn}; @@ -14,6 +15,7 @@ use crate::auth::IpPattern; use crate::auth::backend::ComputeUserInfo; use crate::auth::backend::jwt::AuthRule; use crate::cache::Cached; +use crate::compute::ConnectInfo; use crate::context::RequestContext; use crate::control_plane::errors::{ ControlPlaneError, GetAuthInfoError, GetEndpointJwksError, WakeComputeError, @@ -24,9 +26,9 @@ use crate::control_plane::{ RoleAccessControl, }; use crate::intern::RoleNameInt; +use crate::scram; use crate::types::{BranchId, EndpointId, ProjectId, RoleName}; use crate::url::ApiUrl; -use crate::{compute, scram}; #[derive(Debug, Error)] enum MockApiError { @@ -87,8 +89,7 @@ impl MockControlPlane { .await? 
{ info!("got a secret: {entry}"); // safe since it's not a prod scenario - let secret = scram::ServerSecret::parse(&entry).map(AuthSecret::Scram); - secret.or_else(|| parse_md5(&entry).map(AuthSecret::Md5)) + scram::ServerSecret::parse(&entry).map(AuthSecret::Scram) } else { warn!("user '{role}' does not exist"); None @@ -170,25 +171,23 @@ impl MockControlPlane { async fn do_wake_compute(&self) -> Result { let port = self.endpoint.port().unwrap_or(5432); - let mut config = match self.endpoint.host_str() { - None => { - let mut config = compute::ConnCfg::new("localhost".to_string(), port); - config.set_host_addr(IpAddr::V4(Ipv4Addr::LOCALHOST)); - config - } - Some(host) => { - let mut config = compute::ConnCfg::new(host.to_string(), port); - if let Ok(addr) = IpAddr::from_str(host) { - config.set_host_addr(addr); - } - config - } + let conn_info = match self.endpoint.host_str() { + None => ConnectInfo { + host_addr: Some(IpAddr::V4(Ipv4Addr::LOCALHOST)), + host: "localhost".into(), + port, + ssl_mode: SslMode::Disable, + }, + Some(host) => ConnectInfo { + host_addr: IpAddr::from_str(host).ok(), + host: host.into(), + port, + ssl_mode: SslMode::Disable, + }, }; - config.ssl_mode(postgres_client::config::SslMode::Disable); - let node = NodeInfo { - config, + conn_info, aux: MetricsAuxInfo { endpoint_id: (&EndpointId::from("endpoint")).into(), project_id: (&ProjectId::from("project")).into(), @@ -266,12 +265,3 @@ impl super::ControlPlaneApi for MockControlPlane { self.do_wake_compute().map_ok(Cached::new_uncached).await } } - -fn parse_md5(input: &str) -> Option<[u8; 16]> { - let text = input.strip_prefix("md5")?; - - let mut bytes = [0u8; 16]; - hex::decode_to_slice(text, &mut bytes).ok()?; - - Some(bytes) -} diff --git a/proxy/src/control_plane/mod.rs b/proxy/src/control_plane/mod.rs index 7ff093d9dc..ad10cf4257 100644 --- a/proxy/src/control_plane/mod.rs +++ b/proxy/src/control_plane/mod.rs @@ -11,8 +11,8 @@ pub(crate) mod errors; use std::sync::Arc; +use crate::auth::backend::ComputeUserInfo; use crate::auth::backend::jwt::AuthRule; -use crate::auth::backend::{ComputeCredentialKeys, ComputeUserInfo}; use crate::auth::{AuthError, IpPattern, check_peer_addr_is_in_list}; use crate::cache::{Cached, TimedLru}; use crate::config::ComputeConfig; @@ -39,10 +39,6 @@ pub mod mgmt; /// Auth secret which is managed by the cloud. #[derive(Clone, Eq, PartialEq, Debug)] pub(crate) enum AuthSecret { - #[cfg(any(test, feature = "testing"))] - /// Md5 hash of user's password. - Md5([u8; 16]), - /// [SCRAM](crate::scram) authentication info. Scram(scram::ServerSecret), } @@ -63,13 +59,9 @@ pub(crate) struct AuthInfo { } /// Info for establishing a connection to a compute node. -/// This is what we get after auth succeeded, but not before! #[derive(Clone)] pub(crate) struct NodeInfo { - /// Compute node connection params. - /// It's sad that we have to clone this, but this will improve - /// once we migrate to a bespoke connection logic. - pub(crate) config: compute::ConnCfg, + pub(crate) conn_info: compute::ConnectInfo, /// Labels for proxy's metrics. 
pub(crate) aux: MetricsAuxInfo, @@ -79,26 +71,14 @@ impl NodeInfo { pub(crate) async fn connect( &self, ctx: &RequestContext, + auth: &compute::AuthInfo, config: &ComputeConfig, user_info: ComputeUserInfo, ) -> Result { - self.config - .connect(ctx, self.aux.clone(), config, user_info) + self.conn_info + .connect(ctx, self.aux.clone(), auth, config, user_info) .await } - - pub(crate) fn reuse_settings(&mut self, other: Self) { - self.config.reuse_password(other.config); - } - - pub(crate) fn set_keys(&mut self, keys: &ComputeCredentialKeys) { - match keys { - #[cfg(any(test, feature = "testing"))] - ComputeCredentialKeys::Password(password) => self.config.password(password), - ComputeCredentialKeys::AuthKeys(auth_keys) => self.config.auth_keys(*auth_keys), - ComputeCredentialKeys::JwtPayload(_) | ComputeCredentialKeys::None => &mut self.config, - }; - } } #[derive(Copy, Clone, Default)] diff --git a/proxy/src/pglb/connect_compute.rs b/proxy/src/pglb/connect_compute.rs index 1d6ca5fbb3..1807cdff0e 100644 --- a/proxy/src/pglb/connect_compute.rs +++ b/proxy/src/pglb/connect_compute.rs @@ -2,8 +2,8 @@ use async_trait::async_trait; use tokio::time; use tracing::{debug, info, warn}; -use crate::auth::backend::{ComputeCredentialKeys, ComputeUserInfo}; -use crate::compute::{self, COULD_NOT_CONNECT, PostgresConnection}; +use crate::auth::backend::ComputeUserInfo; +use crate::compute::{self, AuthInfo, COULD_NOT_CONNECT, PostgresConnection}; use crate::config::{ComputeConfig, RetryConfig}; use crate::context::RequestContext; use crate::control_plane::errors::WakeComputeError; @@ -13,7 +13,6 @@ use crate::error::ReportableError; use crate::metrics::{ ConnectOutcome, ConnectionFailureKind, Metrics, RetriesMetricGroup, RetryType, }; -use crate::pqproto::StartupMessageParams; use crate::proxy::retry::{CouldRetry, ShouldRetryWakeCompute, retry_after, should_retry}; use crate::proxy::wake_compute::wake_compute; use crate::types::Host; @@ -48,8 +47,6 @@ pub(crate) trait ConnectMechanism { node_info: &control_plane::CachedNodeInfo, config: &ComputeConfig, ) -> Result; - - fn update_connect_config(&self, conf: &mut compute::ConnCfg); } #[async_trait] @@ -58,24 +55,17 @@ pub(crate) trait ComputeConnectBackend { &self, ctx: &RequestContext, ) -> Result; - - fn get_keys(&self) -> &ComputeCredentialKeys; } -pub(crate) struct TcpMechanism<'a> { - pub(crate) params_compat: bool, - - /// KV-dictionary with PostgreSQL connection params. 
- pub(crate) params: &'a StartupMessageParams, - +pub(crate) struct TcpMechanism { + pub(crate) auth: AuthInfo, /// connect_to_compute concurrency lock pub(crate) locks: &'static ApiLocks, - pub(crate) user_info: ComputeUserInfo, } #[async_trait] -impl ConnectMechanism for TcpMechanism<'_> { +impl ConnectMechanism for TcpMechanism { type Connection = PostgresConnection; type ConnectError = compute::ConnectionError; type Error = compute::ConnectionError; @@ -90,13 +80,12 @@ impl ConnectMechanism for TcpMechanism<'_> { node_info: &control_plane::CachedNodeInfo, config: &ComputeConfig, ) -> Result { - let host = node_info.config.get_host(); - let permit = self.locks.get_permit(&host).await?; - permit.release_result(node_info.connect(ctx, config, self.user_info.clone()).await) - } - - fn update_connect_config(&self, config: &mut compute::ConnCfg) { - config.set_startup_params(self.params, self.params_compat); + let permit = self.locks.get_permit(&node_info.conn_info.host).await?; + permit.release_result( + node_info + .connect(ctx, &self.auth, config, self.user_info.clone()) + .await, + ) } } @@ -114,12 +103,9 @@ where M::Error: From, { let mut num_retries = 0; - let mut node_info = + let node_info = wake_compute(&mut num_retries, ctx, user_info, wake_compute_retry_config).await?; - node_info.set_keys(user_info.get_keys()); - mechanism.update_connect_config(&mut node_info.config); - // try once let err = match mechanism.connect_once(ctx, &node_info, compute).await { Ok(res) => { @@ -155,14 +141,9 @@ where } else { // if we failed to connect, it's likely that the compute node was suspended, wake a new compute node debug!("compute node's state has likely changed; requesting a wake-up"); - let old_node_info = invalidate_cache(node_info); + invalidate_cache(node_info); // TODO: increment num_retries? - let mut node_info = - wake_compute(&mut num_retries, ctx, user_info, wake_compute_retry_config).await?; - node_info.reuse_settings(old_node_info); - - mechanism.update_connect_config(&mut node_info.config); - node_info + wake_compute(&mut num_retries, ctx, user_info, wake_compute_retry_config).await? }; // now that we have a new node, try connect to it repeatedly. diff --git a/proxy/src/pqproto.rs b/proxy/src/pqproto.rs index 43074bf208..ad99eecda5 100644 --- a/proxy/src/pqproto.rs +++ b/proxy/src/pqproto.rs @@ -8,7 +8,7 @@ use std::io::{self, Cursor}; use bytes::{Buf, BufMut}; use itertools::Itertools; use rand::distributions::{Distribution, Standard}; -use tokio::io::{AsyncRead, AsyncReadExt}; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use zerocopy::{FromBytes, Immutable, IntoBytes, big_endian}; pub type ErrorCode = [u8; 5]; @@ -53,6 +53,28 @@ impl fmt::Debug for ProtocolVersion { } } +/// +const MAX_STARTUP_PACKET_LENGTH: usize = 10000; +const RESERVED_INVALID_MAJOR_VERSION: u16 = 1234; +/// +const CANCEL_REQUEST_CODE: ProtocolVersion = ProtocolVersion::new(1234, 5678); +/// +const NEGOTIATE_SSL_CODE: ProtocolVersion = ProtocolVersion::new(1234, 5679); +/// +const NEGOTIATE_GSS_CODE: ProtocolVersion = ProtocolVersion::new(1234, 5680); + +/// This first reads the startup message header, is 8 bytes. +/// The first 4 bytes is a big-endian message length, and the next 4 bytes is a version number. +/// +/// The length value is inclusive of the header. For example, +/// an empty message will always have length 8. 
+#[derive(Clone, Copy, FromBytes, IntoBytes, Immutable)] +#[repr(C)] +struct StartupHeader { + len: big_endian::U32, + version: ProtocolVersion, +} + /// read the type from the stream using zerocopy. /// /// not cancel safe. @@ -66,32 +88,38 @@ macro_rules! read { }}; } +/// Returns true if TLS is supported. +/// +/// This is not cancel safe. +pub async fn request_tls(stream: &mut S) -> io::Result +where + S: AsyncRead + AsyncWrite + Unpin, +{ + let payload = StartupHeader { + len: 8.into(), + version: NEGOTIATE_SSL_CODE, + }; + stream.write_all(payload.as_bytes()).await?; + stream.flush().await?; + + // we expect back either `S` or `N` as a single byte. + let mut res = *b"0"; + stream.read_exact(&mut res).await?; + + debug_assert!( + res == *b"S" || res == *b"N", + "unexpected SSL negotiation response: {}", + char::from(res[0]), + ); + + // S for SSL. + Ok(res == *b"S") +} + pub async fn read_startup(stream: &mut S) -> io::Result where S: AsyncRead + Unpin, { - /// - const MAX_STARTUP_PACKET_LENGTH: usize = 10000; - const RESERVED_INVALID_MAJOR_VERSION: u16 = 1234; - /// - const CANCEL_REQUEST_CODE: ProtocolVersion = ProtocolVersion::new(1234, 5678); - /// - const NEGOTIATE_SSL_CODE: ProtocolVersion = ProtocolVersion::new(1234, 5679); - /// - const NEGOTIATE_GSS_CODE: ProtocolVersion = ProtocolVersion::new(1234, 5680); - - /// This first reads the startup message header, is 8 bytes. - /// The first 4 bytes is a big-endian message length, and the next 4 bytes is a version number. - /// - /// The length value is inclusive of the header. For example, - /// an empty message will always have length 8. - #[derive(Clone, Copy, FromBytes, IntoBytes, Immutable)] - #[repr(C)] - struct StartupHeader { - len: big_endian::U32, - version: ProtocolVersion, - } - let header = read!(stream => StartupHeader); // @@ -564,9 +592,8 @@ mod tests { use tokio::io::{AsyncWriteExt, duplex}; use zerocopy::IntoBytes; - use crate::pqproto::{FeStartupPacket, read_message, read_startup}; - use super::ProtocolVersion; + use crate::pqproto::{FeStartupPacket, read_message, read_startup}; #[tokio::test] async fn reject_large_startup() { diff --git a/proxy/src/proxy/mod.rs b/proxy/src/proxy/mod.rs index 0e138cc0c7..0e00c4f97e 100644 --- a/proxy/src/proxy/mod.rs +++ b/proxy/src/proxy/mod.rs @@ -358,21 +358,19 @@ pub(crate) async fn handle_client( } }; - let compute_user_info = match &user_info { - auth::Backend::ControlPlane(_, info) => &info.info, + let creds = match &user_info { + auth::Backend::ControlPlane(_, creds) => creds, auth::Backend::Local(_) => unreachable!("local proxy does not run tcp proxy service"), }; - let params_compat = compute_user_info - .options - .get(NeonOptions::PARAMS_COMPAT) - .is_some(); + let params_compat = creds.info.options.get(NeonOptions::PARAMS_COMPAT).is_some(); + let mut auth_info = compute::AuthInfo::with_auth_keys(&creds.keys); + auth_info.set_startup_params(¶ms, params_compat); let res = connect_to_compute( ctx, &TcpMechanism { - user_info: compute_user_info.clone(), - params_compat, - params: ¶ms, + user_info: creds.info.clone(), + auth: auth_info, locks: &config.connect_compute_locks, }, &user_info, diff --git a/proxy/src/proxy/retry.rs b/proxy/src/proxy/retry.rs index 01e603ec14..0f19944afa 100644 --- a/proxy/src/proxy/retry.rs +++ b/proxy/src/proxy/retry.rs @@ -100,9 +100,9 @@ impl CouldRetry for compute::ConnectionError { fn could_retry(&self) -> bool { match self { compute::ConnectionError::Postgres(err) => err.could_retry(), - compute::ConnectionError::CouldNotConnect(err) => 
err.could_retry(), + compute::ConnectionError::TlsError(err) => err.could_retry(), compute::ConnectionError::WakeComputeError(err) => err.could_retry(), - _ => false, + compute::ConnectionError::TooManyConnectionAttempts(_) => false, } } } diff --git a/proxy/src/proxy/tests/mod.rs b/proxy/src/proxy/tests/mod.rs index e5db0013a7..028247a97d 100644 --- a/proxy/src/proxy/tests/mod.rs +++ b/proxy/src/proxy/tests/mod.rs @@ -8,7 +8,7 @@ use std::time::Duration; use anyhow::{Context, bail}; use async_trait::async_trait; use http::StatusCode; -use postgres_client::config::SslMode; +use postgres_client::config::{AuthKeys, ScramKeys, SslMode}; use postgres_client::tls::{MakeTlsConnect, NoTls}; use retry::{ShouldRetryWakeCompute, retry_after}; use rstest::rstest; @@ -29,7 +29,6 @@ use crate::control_plane::{self, CachedNodeInfo, NodeInfo, NodeInfoCache}; use crate::error::ErrorKind; use crate::pglb::connect_compute::ConnectMechanism; use crate::tls::client_config::compute_client_config_with_certs; -use crate::tls::postgres_rustls::MakeRustlsConnect; use crate::tls::server_config::CertResolver; use crate::types::{BranchId, EndpointId, ProjectId}; use crate::{sasl, scram}; @@ -72,13 +71,14 @@ struct ClientConfig<'a> { hostname: &'a str, } -type TlsConnect = >::TlsConnect; +type TlsConnect = >::TlsConnect; impl ClientConfig<'_> { fn make_tls_connect(self) -> anyhow::Result> { - let mut mk = MakeRustlsConnect::new(self.config); - let tls = MakeTlsConnect::::make_tls_connect(&mut mk, self.hostname)?; - Ok(tls) + Ok(crate::tls::postgres_rustls::make_tls_connect( + &self.config, + self.hostname, + )?) } } @@ -497,8 +497,6 @@ impl ConnectMechanism for TestConnectMechanism { x => panic!("expecting action {x:?}, connect is called instead"), } } - - fn update_connect_config(&self, _conf: &mut compute::ConnCfg) {} } impl TestControlPlaneClient for TestConnectMechanism { @@ -557,7 +555,12 @@ impl TestControlPlaneClient for TestConnectMechanism { fn helper_create_cached_node_info(cache: &'static NodeInfoCache) -> CachedNodeInfo { let node = NodeInfo { - config: compute::ConnCfg::new("test".to_owned(), 5432), + conn_info: compute::ConnectInfo { + host: "test".into(), + port: 5432, + ssl_mode: SslMode::Disable, + host_addr: None, + }, aux: MetricsAuxInfo { endpoint_id: (&EndpointId::from("endpoint")).into(), project_id: (&ProjectId::from("project")).into(), @@ -581,7 +584,10 @@ fn helper_create_connect_info( user: "user".into(), options: NeonOptions::parse_options_raw(""), }, - keys: ComputeCredentialKeys::Password("password".into()), + keys: ComputeCredentialKeys::AuthKeys(AuthKeys::ScramSha256(ScramKeys { + client_key: [0; 32], + server_key: [0; 32], + })), }, ) } diff --git a/proxy/src/serverless/backend.rs b/proxy/src/serverless/backend.rs index 748e0ce6f2..a0e782dab0 100644 --- a/proxy/src/serverless/backend.rs +++ b/proxy/src/serverless/backend.rs @@ -23,7 +23,6 @@ use super::local_conn_pool::{self, EXT_NAME, EXT_SCHEMA, EXT_VERSION, LocalConnP use crate::auth::backend::local::StaticAuthRules; use crate::auth::backend::{ComputeCredentials, ComputeUserInfo}; use crate::auth::{self, AuthError}; -use crate::compute; use crate::compute_ctl::{ ComputeCtlError, ExtensionInstallRequest, Privilege, SetRoleGrantsRequest, }; @@ -305,12 +304,13 @@ impl PoolingBackend { tracing::Span::current().record("conn_id", display(conn_id)); info!(%conn_id, "local_pool: opening a new connection '{conn_info}'"); - let mut node_info = local_backend.node_info.clone(); - let (key, jwk) = create_random_jwk(); - let config = node_info - 
.config + let mut config = local_backend + .node_info + .conn_info + .to_postgres_client_config(); + config .user(&conn_info.user_info.user) .dbname(&conn_info.dbname) .set_param( @@ -322,7 +322,7 @@ impl PoolingBackend { ); let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute); - let (client, connection) = config.connect(postgres_client::NoTls).await?; + let (client, connection) = config.connect(&postgres_client::NoTls).await?; drop(pause); let pid = client.get_process_id(); @@ -336,7 +336,7 @@ impl PoolingBackend { connection, key, conn_id, - node_info.aux.clone(), + local_backend.node_info.aux.clone(), ); { @@ -512,19 +512,16 @@ impl ConnectMechanism for TokioMechanism { node_info: &CachedNodeInfo, compute_config: &ComputeConfig, ) -> Result { - let host = node_info.config.get_host(); - let permit = self.locks.get_permit(&host).await?; + let permit = self.locks.get_permit(&node_info.conn_info.host).await?; - let mut config = (*node_info.config).clone(); + let mut config = node_info.conn_info.to_postgres_client_config(); let config = config .user(&self.conn_info.user_info.user) .dbname(&self.conn_info.dbname) .connect_timeout(compute_config.timeout); - let mk_tls = - crate::tls::postgres_rustls::MakeRustlsConnect::new(compute_config.tls.clone()); let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute); - let res = config.connect(mk_tls).await; + let res = config.connect(compute_config).await; drop(pause); let (client, connection) = permit.release_result(res)?; @@ -548,8 +545,6 @@ impl ConnectMechanism for TokioMechanism { node_info.aux.clone(), )) } - - fn update_connect_config(&self, _config: &mut compute::ConnCfg) {} } struct HyperMechanism { @@ -573,20 +568,20 @@ impl ConnectMechanism for HyperMechanism { node_info: &CachedNodeInfo, config: &ComputeConfig, ) -> Result { - let host_addr = node_info.config.get_host_addr(); - let host = node_info.config.get_host(); - let permit = self.locks.get_permit(&host).await?; + let host_addr = node_info.conn_info.host_addr; + let host = &node_info.conn_info.host; + let permit = self.locks.get_permit(host).await?; let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute); - let tls = if node_info.config.get_ssl_mode() == SslMode::Disable { + let tls = if node_info.conn_info.ssl_mode == SslMode::Disable { None } else { Some(&config.tls) }; - let port = node_info.config.get_port(); - let res = connect_http2(host_addr, &host, port, config.timeout, tls).await; + let port = node_info.conn_info.port; + let res = connect_http2(host_addr, host, port, config.timeout, tls).await; drop(pause); let (client, connection) = permit.release_result(res)?; @@ -609,8 +604,6 @@ impl ConnectMechanism for HyperMechanism { node_info.aux.clone(), )) } - - fn update_connect_config(&self, _config: &mut compute::ConnCfg) {} } async fn connect_http2( diff --git a/proxy/src/serverless/conn_pool.rs b/proxy/src/serverless/conn_pool.rs index 87176ff7d6..dd8cf052c5 100644 --- a/proxy/src/serverless/conn_pool.rs +++ b/proxy/src/serverless/conn_pool.rs @@ -23,12 +23,12 @@ use super::conn_pool_lib::{ Client, ClientDataEnum, ClientInnerCommon, ClientInnerExt, ConnInfo, EndpointConnPool, GlobalConnPool, }; +use crate::config::ComputeConfig; use crate::context::RequestContext; use crate::control_plane::messages::MetricsAuxInfo; use crate::metrics::Metrics; -use crate::tls::postgres_rustls::MakeRustlsConnect; -type TlsStream = >::Stream; +type TlsStream = >::Stream; #[derive(Debug, Clone)] pub(crate) struct ConnInfoWithAuth { diff --git 
a/proxy/src/tls/postgres_rustls.rs b/proxy/src/tls/postgres_rustls.rs index 013b307f0b..9269ad8a06 100644 --- a/proxy/src/tls/postgres_rustls.rs +++ b/proxy/src/tls/postgres_rustls.rs @@ -2,10 +2,11 @@ use std::convert::TryFrom; use std::sync::Arc; use postgres_client::tls::MakeTlsConnect; -use rustls::ClientConfig; -use rustls::pki_types::ServerName; +use rustls::pki_types::{InvalidDnsNameError, ServerName}; use tokio::io::{AsyncRead, AsyncWrite}; +use crate::config::ComputeConfig; + mod private { use std::future::Future; use std::io; @@ -123,36 +124,27 @@ mod private { } } -/// A `MakeTlsConnect` implementation using `rustls`. -/// -/// That way you can connect to PostgreSQL using `rustls` as the TLS stack. -#[derive(Clone)] -pub struct MakeRustlsConnect { - pub config: Arc, -} - -impl MakeRustlsConnect { - /// Creates a new `MakeRustlsConnect` from the provided `ClientConfig`. - #[must_use] - pub fn new(config: Arc) -> Self { - Self { config } - } -} - -impl MakeTlsConnect for MakeRustlsConnect +impl MakeTlsConnect for ComputeConfig where S: AsyncRead + AsyncWrite + Unpin + Send + 'static, { type Stream = private::RustlsStream; type TlsConnect = private::RustlsConnect; - type Error = rustls::pki_types::InvalidDnsNameError; + type Error = InvalidDnsNameError; - fn make_tls_connect(&mut self, hostname: &str) -> Result { - ServerName::try_from(hostname).map(|dns_name| { - private::RustlsConnect(private::RustlsConnectData { - hostname: dns_name.to_owned(), - connector: Arc::clone(&self.config).into(), - }) - }) + fn make_tls_connect(&self, hostname: &str) -> Result { + make_tls_connect(&self.tls, hostname) } } + +pub fn make_tls_connect( + tls: &Arc, + hostname: &str, +) -> Result { + ServerName::try_from(hostname).map(|dns_name| { + private::RustlsConnect(private::RustlsConnectData { + hostname: dns_name.to_owned(), + connector: tls.clone().into(), + }) + }) +} From 470c7d5e0e5f70fefeca9a3e9b7dbd380a78acc1 Mon Sep 17 00:00:00 2001 From: Mikhail Date: Fri, 6 Jun 2025 12:48:01 +0100 Subject: [PATCH 08/12] endpoint_storage: default listen port, allow inline config (#12152) Related: https://github.com/neondatabase/cloud/issues/27195 --- endpoint_storage/src/main.rs | 30 ++++++++++++++++++++++-------- 1 file changed, 22 insertions(+), 8 deletions(-) diff --git a/endpoint_storage/src/main.rs b/endpoint_storage/src/main.rs index 3d1f05575d..399a4ec31e 100644 --- a/endpoint_storage/src/main.rs +++ b/endpoint_storage/src/main.rs @@ -3,7 +3,8 @@ //! This service is deployed either as a separate component or as part of compute image //! for large computes. 
mod app; -use anyhow::Context; +use anyhow::{Context, bail}; +use std::net::{IpAddr, Ipv4Addr, SocketAddr}; use tracing::info; use utils::logging; @@ -12,9 +13,14 @@ const fn max_upload_file_limit() -> usize { 100 * 1024 * 1024 } +const fn listen() -> SocketAddr { + SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 51243) +} + #[derive(serde::Deserialize)] #[serde(tag = "type")] struct Config { + #[serde(default = "listen")] listen: std::net::SocketAddr, pemfile: camino::Utf8PathBuf, #[serde(flatten)] @@ -31,13 +37,21 @@ async fn main() -> anyhow::Result<()> { logging::Output::Stdout, )?; - let config: String = std::env::args().skip(1).take(1).collect(); - if config.is_empty() { - anyhow::bail!("Usage: endpoint_storage config.json") - } - info!("Reading config from {config}"); - let config = std::fs::read_to_string(config.clone())?; - let config: Config = serde_json::from_str(&config).context("parsing config")?; + // Allow either passing filename or inline config (for k8s helm chart) + let args: Vec = std::env::args().skip(1).collect(); + let config: Config = if args.len() == 1 && args[0].ends_with(".json") { + info!("Reading config from {}", args[0]); + let config = std::fs::read_to_string(args[0].clone())?; + serde_json::from_str(&config).context("parsing config")? + } else if !args.is_empty() && args[0].starts_with("--config=") { + info!("Reading inline config"); + let config = args.join(" "); + let config = config.strip_prefix("--config=").unwrap(); + serde_json::from_str(config).context("parsing config")? + } else { + bail!("Usage: endpoint_storage config.json or endpoint_storage --config=JSON"); + }; + info!("Reading pemfile from {}", config.pemfile.clone()); let pemfile = std::fs::read(config.pemfile.clone())?; info!("Loading public key from {}", config.pemfile.clone()); From df7e301a5401ac1da2792b00ace7323f913b4fc1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Arpad=20M=C3=BCller?= Date: Fri, 6 Jun 2025 13:54:07 +0200 Subject: [PATCH 09/12] safekeeper: special error if a timeline has been deleted (#12155) We might delete timelines on safekeepers before we are deleting them on pageservers. This should be an exceptional situation, but can occur. As the first step to improve behaviour here, emit a special error that is less scary/obscure than "was not found in global map". It is for example emitted when the pageserver tries to run `IDENTIFY_SYSTEM` on a timeline that has been deleted on the safekeeper. Found when analyzing the failure of `test_scrubber_physical_gc_timeline_deletion` when enabling `--timelines-onto-safekeepers` on the pytests. Due to safekeeper restarts, there is no hard guarantee that we will keep issuing this error, so we need to think of something better if we start encountering this in staging/prod. But I would say that the introduction of `--timelines-onto-safekeepers` in the pytests and into staging won't change much about this: we are already deleting timelines from there. In `test_scrubber_physical_gc_timeline_deletion`, we'd just be leaking the timeline before on the safekeepers. 
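For illustration, here is a minimal, self-contained sketch of the lookup behaviour this patch introduces (hypothetical simplified types, not the real safekeeper structs; the actual change to `timelines_global_map.rs` is in the diff below): a tombstoned timeline now maps to a dedicated `Deleted` error instead of the generic `NotFound`.

```rust
// Sketch only: simplified stand-ins for the safekeeper's global timeline map.
use std::collections::{HashMap, HashSet};

#[derive(Debug)]
enum TimelineError {
    NotFound(u64),
    Deleted(u64),
}

struct GlobalState {
    timelines: HashMap<u64, String>,
    // Tombstones record timelines that existed but have since been deleted.
    tombstones: HashSet<u64>,
}

impl GlobalState {
    fn get(&self, ttid: u64) -> Result<&String, TimelineError> {
        match self.timelines.get(&ttid) {
            Some(tl) => Ok(tl),
            // Deleted timelines get their own, less alarming error variant.
            None if self.tombstones.contains(&ttid) => Err(TimelineError::Deleted(ttid)),
            None => Err(TimelineError::NotFound(ttid)),
        }
    }
}

fn main() {
    let mut state = GlobalState { timelines: HashMap::new(), tombstones: HashSet::new() };
    state.tombstones.insert(42);
    // Prints Err(Deleted(42)) rather than Err(NotFound(42)).
    println!("{:?}", state.get(42));
}
```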
Part of #11712 --- safekeeper/src/timeline.rs | 2 ++ safekeeper/src/timelines_global_map.rs | 8 +++++++- test_runner/regress/test_safekeeper_deletion.py | 2 ++ test_runner/regress/test_wal_acceptor.py | 2 ++ 4 files changed, 13 insertions(+), 1 deletion(-) diff --git a/safekeeper/src/timeline.rs b/safekeeper/src/timeline.rs index 588bd4f2c9..2bee41537f 100644 --- a/safekeeper/src/timeline.rs +++ b/safekeeper/src/timeline.rs @@ -395,6 +395,8 @@ pub enum TimelineError { Cancelled(TenantTimelineId), #[error("Timeline {0} was not found in global map")] NotFound(TenantTimelineId), + #[error("Timeline {0} has been deleted")] + Deleted(TenantTimelineId), #[error("Timeline {0} creation is in progress")] CreationInProgress(TenantTimelineId), #[error("Timeline {0} exists on disk, but wasn't loaded on startup")] diff --git a/safekeeper/src/timelines_global_map.rs b/safekeeper/src/timelines_global_map.rs index e3f7d88f7c..6e41ada1b3 100644 --- a/safekeeper/src/timelines_global_map.rs +++ b/safekeeper/src/timelines_global_map.rs @@ -78,7 +78,13 @@ impl GlobalTimelinesState { Some(GlobalMapTimeline::CreationInProgress) => { Err(TimelineError::CreationInProgress(*ttid)) } - None => Err(TimelineError::NotFound(*ttid)), + None => { + if self.has_tombstone(ttid) { + Err(TimelineError::Deleted(*ttid)) + } else { + Err(TimelineError::NotFound(*ttid)) + } + } } } diff --git a/test_runner/regress/test_safekeeper_deletion.py b/test_runner/regress/test_safekeeper_deletion.py index b681a86103..bc79969e9a 100644 --- a/test_runner/regress/test_safekeeper_deletion.py +++ b/test_runner/regress/test_safekeeper_deletion.py @@ -30,6 +30,7 @@ def test_safekeeper_delete_timeline(neon_env_builder: NeonEnvBuilder, auth_enabl env.pageserver.allowed_errors.extend( [ ".*Timeline .* was not found in global map.*", + ".*Timeline .* has been deleted.*", ".*Timeline .* was cancelled and cannot be used anymore.*", ] ) @@ -198,6 +199,7 @@ def test_safekeeper_delete_timeline_under_load(neon_env_builder: NeonEnvBuilder) env.pageserver.allowed_errors.extend( [ ".*Timeline.*was cancelled.*", + ".*Timeline.*has been deleted.*", ".*Timeline.*was not found.*", ] ) diff --git a/test_runner/regress/test_wal_acceptor.py b/test_runner/regress/test_wal_acceptor.py index 6a7c7a8bef..b9183286af 100644 --- a/test_runner/regress/test_wal_acceptor.py +++ b/test_runner/regress/test_wal_acceptor.py @@ -433,6 +433,7 @@ def test_wal_backup(neon_env_builder: NeonEnvBuilder): env.pageserver.allowed_errors.extend( [ ".*Timeline .* was not found in global map.*", + ".*Timeline .* has been deleted.*", ".*Timeline .* was cancelled and cannot be used anymore.*", ] ) @@ -1934,6 +1935,7 @@ def test_membership_api(neon_env_builder: NeonEnvBuilder): env.pageserver.allowed_errors.extend( [ ".*Timeline .* was not found in global map.*", + ".*Timeline .* has been deleted.*", ".*Timeline .* was cancelled and cannot be used anymore.*", ] ) From 6dd84041a1b93e8033abe75e93867db11069e91d Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Fri, 6 Jun 2025 13:49:29 +0100 Subject: [PATCH 10/12] refactor and simplify the invalidation notification structure (#12154) The current cache invalidation messages are far too specific. They should be more generic since it only ends up triggering a `GetEndpointAccessControl` message anyway. Mappings: * `/allowed_ips_updated`, `/block_public_or_vpc_access_updated`, and `/allowed_vpc_endpoints_updated_for_projects` -> `/project_settings_update`. * `/allowed_vpc_endpoints_updated_for_org` -> `/account_settings_update`. 
* `/password_updated` -> `/role_setting_update`. I've also introduced `/endpoint_settings_update`. All message types support singular or multiple entries, which allows us to simplify things both on our side and on cplane side. I'm opening a PR to cplane to apply the above mappings, but for now using the old phrases to allow both to roll out independently. This change is inspired by my need to add yet another cached entry to `GetEndpointAccessControl` for https://github.com/neondatabase/cloud/issues/28333 --- proxy/src/cache/project_info.rs | 8 + proxy/src/metrics.rs | 10 +- proxy/src/redis/notifications.rs | 241 +++++++++++++++++-------------- 3 files changed, 145 insertions(+), 114 deletions(-) diff --git a/proxy/src/cache/project_info.rs b/proxy/src/cache/project_info.rs index 81c88e3ddd..9a4be2f904 100644 --- a/proxy/src/cache/project_info.rs +++ b/proxy/src/cache/project_info.rs @@ -18,6 +18,7 @@ use crate::types::{EndpointId, RoleName}; #[async_trait] pub(crate) trait ProjectInfoCache { + fn invalidate_endpoint_access(&self, endpoint_id: EndpointIdInt); fn invalidate_endpoint_access_for_project(&self, project_id: ProjectIdInt); fn invalidate_endpoint_access_for_org(&self, account_id: AccountIdInt); fn invalidate_role_secret_for_project(&self, project_id: ProjectIdInt, role_name: RoleNameInt); @@ -100,6 +101,13 @@ pub struct ProjectInfoCacheImpl { #[async_trait] impl ProjectInfoCache for ProjectInfoCacheImpl { + fn invalidate_endpoint_access(&self, endpoint_id: EndpointIdInt) { + info!("invalidating endpoint access for `{endpoint_id}`"); + if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) { + endpoint_info.invalidate_endpoint(); + } + } + fn invalidate_endpoint_access_for_project(&self, project_id: ProjectIdInt) { info!("invalidating endpoint access for project `{project_id}`"); let endpoints = self diff --git a/proxy/src/metrics.rs b/proxy/src/metrics.rs index 4b22c912eb..4c340edfd5 100644 --- a/proxy/src/metrics.rs +++ b/proxy/src/metrics.rs @@ -610,11 +610,11 @@ pub enum RedisEventsCount { BranchCreated, ProjectCreated, CancelSession, - PasswordUpdate, - AllowedIpsUpdate, - AllowedVpcEndpointIdsUpdateForProjects, - AllowedVpcEndpointIdsUpdateForAllProjectsInOrg, - BlockPublicOrVpcAccessUpdate, + InvalidateRole, + InvalidateEndpoint, + InvalidateProject, + InvalidateProjects, + InvalidateOrg, } pub struct ThreadPoolWorkers(usize); diff --git a/proxy/src/redis/notifications.rs b/proxy/src/redis/notifications.rs index a9d6b40603..6c8260027f 100644 --- a/proxy/src/redis/notifications.rs +++ b/proxy/src/redis/notifications.rs @@ -3,12 +3,12 @@ use std::sync::Arc; use futures::StreamExt; use redis::aio::PubSub; -use serde::{Deserialize, Serialize}; +use serde::Deserialize; use tokio_util::sync::CancellationToken; use super::connection_with_credentials_provider::ConnectionWithCredentialsProvider; use crate::cache::project_info::ProjectInfoCache; -use crate::intern::{AccountIdInt, ProjectIdInt, RoleNameInt}; +use crate::intern::{AccountIdInt, EndpointIdInt, ProjectIdInt, RoleNameInt}; use crate::metrics::{Metrics, RedisErrors, RedisEventsCount}; const CPLANE_CHANNEL_NAME: &str = "neondb-proxy-ws-updates"; @@ -27,42 +27,37 @@ struct NotificationHeader<'a> { topic: &'a str, } -#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)] +#[derive(Clone, Debug, Deserialize, Eq, PartialEq)] #[serde(tag = "topic", content = "data")] -pub(crate) enum Notification { +enum Notification { #[serde( - rename = "/allowed_ips_updated", + rename = "/account_settings_update", + alias = 
"/allowed_vpc_endpoints_updated_for_org", deserialize_with = "deserialize_json_string" )] - AllowedIpsUpdate { - allowed_ips_update: AllowedIpsUpdate, - }, + AccountSettingsUpdate(InvalidateAccount), + #[serde( - rename = "/block_public_or_vpc_access_updated", + rename = "/endpoint_settings_update", deserialize_with = "deserialize_json_string" )] - BlockPublicOrVpcAccessUpdated { - block_public_or_vpc_access_updated: BlockPublicOrVpcAccessUpdated, - }, + EndpointSettingsUpdate(InvalidateEndpoint), + #[serde( - rename = "/allowed_vpc_endpoints_updated_for_org", + rename = "/project_settings_update", + alias = "/allowed_ips_updated", + alias = "/block_public_or_vpc_access_updated", + alias = "/allowed_vpc_endpoints_updated_for_projects", deserialize_with = "deserialize_json_string" )] - AllowedVpcEndpointsUpdatedForOrg { - allowed_vpc_endpoints_updated_for_org: AllowedVpcEndpointsUpdatedForOrg, - }, + ProjectSettingsUpdate(InvalidateProject), + #[serde( - rename = "/allowed_vpc_endpoints_updated_for_projects", + rename = "/role_setting_update", + alias = "/password_updated", deserialize_with = "deserialize_json_string" )] - AllowedVpcEndpointsUpdatedForProjects { - allowed_vpc_endpoints_updated_for_projects: AllowedVpcEndpointsUpdatedForProjects, - }, - #[serde( - rename = "/password_updated", - deserialize_with = "deserialize_json_string" - )] - PasswordUpdate { password_update: PasswordUpdate }, + RoleSettingUpdate(InvalidateRole), #[serde( other, @@ -72,28 +67,56 @@ pub(crate) enum Notification { UnknownTopic, } -#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)] -pub(crate) struct AllowedIpsUpdate { - project_id: ProjectIdInt, +#[derive(Clone, Debug, Deserialize, Eq, PartialEq)] +#[serde(rename_all = "snake_case")] +enum InvalidateEndpoint { + EndpointId(EndpointIdInt), + EndpointIds(Vec), +} +impl std::ops::Deref for InvalidateEndpoint { + type Target = [EndpointIdInt]; + fn deref(&self) -> &Self::Target { + match self { + Self::EndpointId(id) => std::slice::from_ref(id), + Self::EndpointIds(ids) => ids, + } + } } -#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)] -pub(crate) struct BlockPublicOrVpcAccessUpdated { - project_id: ProjectIdInt, +#[derive(Clone, Debug, Deserialize, Eq, PartialEq)] +#[serde(rename_all = "snake_case")] +enum InvalidateProject { + ProjectId(ProjectIdInt), + ProjectIds(Vec), +} +impl std::ops::Deref for InvalidateProject { + type Target = [ProjectIdInt]; + fn deref(&self) -> &Self::Target { + match self { + Self::ProjectId(id) => std::slice::from_ref(id), + Self::ProjectIds(ids) => ids, + } + } } -#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)] -pub(crate) struct AllowedVpcEndpointsUpdatedForOrg { - account_id: AccountIdInt, +#[derive(Clone, Debug, Deserialize, Eq, PartialEq)] +#[serde(rename_all = "snake_case")] +enum InvalidateAccount { + AccountId(AccountIdInt), + AccountIds(Vec), +} +impl std::ops::Deref for InvalidateAccount { + type Target = [AccountIdInt]; + fn deref(&self) -> &Self::Target { + match self { + Self::AccountId(id) => std::slice::from_ref(id), + Self::AccountIds(ids) => ids, + } + } } -#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)] -pub(crate) struct AllowedVpcEndpointsUpdatedForProjects { - project_ids: Vec, -} - -#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)] -pub(crate) struct PasswordUpdate { +#[derive(Clone, Debug, Deserialize, Eq, PartialEq)] +struct InvalidateRole { project_id: ProjectIdInt, role_name: RoleNameInt, } @@ -177,41 +200,29 @@ impl MessageHandler { 
tracing::debug!(?msg, "received a message"); match msg { - Notification::AllowedIpsUpdate { .. } - | Notification::PasswordUpdate { .. } - | Notification::BlockPublicOrVpcAccessUpdated { .. } - | Notification::AllowedVpcEndpointsUpdatedForOrg { .. } - | Notification::AllowedVpcEndpointsUpdatedForProjects { .. } => { + Notification::RoleSettingUpdate { .. } + | Notification::EndpointSettingsUpdate { .. } + | Notification::ProjectSettingsUpdate { .. } + | Notification::AccountSettingsUpdate { .. } => { invalidate_cache(self.cache.clone(), msg.clone()); - if matches!(msg, Notification::AllowedIpsUpdate { .. }) { - Metrics::get() - .proxy - .redis_events_count - .inc(RedisEventsCount::AllowedIpsUpdate); - } else if matches!(msg, Notification::PasswordUpdate { .. }) { - Metrics::get() - .proxy - .redis_events_count - .inc(RedisEventsCount::PasswordUpdate); - } else if matches!( - msg, - Notification::AllowedVpcEndpointsUpdatedForProjects { .. } - ) { - Metrics::get() - .proxy - .redis_events_count - .inc(RedisEventsCount::AllowedVpcEndpointIdsUpdateForProjects); - } else if matches!(msg, Notification::AllowedVpcEndpointsUpdatedForOrg { .. }) { - Metrics::get() - .proxy - .redis_events_count - .inc(RedisEventsCount::AllowedVpcEndpointIdsUpdateForAllProjectsInOrg); - } else if matches!(msg, Notification::BlockPublicOrVpcAccessUpdated { .. }) { - Metrics::get() - .proxy - .redis_events_count - .inc(RedisEventsCount::BlockPublicOrVpcAccessUpdate); + + let m = &Metrics::get().proxy.redis_events_count; + match msg { + Notification::RoleSettingUpdate { .. } => { + m.inc(RedisEventsCount::InvalidateRole); + } + Notification::EndpointSettingsUpdate { .. } => { + m.inc(RedisEventsCount::InvalidateEndpoint); + } + Notification::ProjectSettingsUpdate { .. } => { + m.inc(RedisEventsCount::InvalidateProject); + } + Notification::AccountSettingsUpdate { .. } => { + m.inc(RedisEventsCount::InvalidateOrg); + } + Notification::UnknownTopic => {} } + // TODO: add additional metrics for the other event types. // It might happen that the invalid entry is on the way to be cached. 
@@ -233,30 +244,23 @@ impl MessageHandler { fn invalidate_cache(cache: Arc, msg: Notification) { match msg { - Notification::AllowedIpsUpdate { - allowed_ips_update: AllowedIpsUpdate { project_id }, - } - | Notification::BlockPublicOrVpcAccessUpdated { - block_public_or_vpc_access_updated: BlockPublicOrVpcAccessUpdated { project_id }, - } => cache.invalidate_endpoint_access_for_project(project_id), - Notification::AllowedVpcEndpointsUpdatedForOrg { - allowed_vpc_endpoints_updated_for_org: AllowedVpcEndpointsUpdatedForOrg { account_id }, - } => cache.invalidate_endpoint_access_for_org(account_id), - Notification::AllowedVpcEndpointsUpdatedForProjects { - allowed_vpc_endpoints_updated_for_projects: - AllowedVpcEndpointsUpdatedForProjects { project_ids }, - } => { - for project in project_ids { - cache.invalidate_endpoint_access_for_project(project); - } - } - Notification::PasswordUpdate { - password_update: - PasswordUpdate { - project_id, - role_name, - }, - } => cache.invalidate_role_secret_for_project(project_id, role_name), + Notification::EndpointSettingsUpdate(ids) => ids + .iter() + .for_each(|&id| cache.invalidate_endpoint_access(id)), + + Notification::AccountSettingsUpdate(ids) => ids + .iter() + .for_each(|&id| cache.invalidate_endpoint_access_for_org(id)), + + Notification::ProjectSettingsUpdate(ids) => ids + .iter() + .for_each(|&id| cache.invalidate_endpoint_access_for_project(id)), + + Notification::RoleSettingUpdate(InvalidateRole { + project_id, + role_name, + }) => cache.invalidate_role_secret_for_project(project_id, role_name), + Notification::UnknownTopic => unreachable!(), } } @@ -353,11 +357,32 @@ mod tests { let result: Notification = serde_json::from_str(&text)?; assert_eq!( result, - Notification::AllowedIpsUpdate { - allowed_ips_update: AllowedIpsUpdate { - project_id: (&project_id).into() - } - } + Notification::ProjectSettingsUpdate(InvalidateProject::ProjectId((&project_id).into())) + ); + + Ok(()) + } + + #[test] + fn parse_multiple_projects() -> anyhow::Result<()> { + let project_id1: ProjectId = "new_project1".into(); + let project_id2: ProjectId = "new_project2".into(); + let data = format!("{{\"project_ids\": [\"{project_id1}\",\"{project_id2}\"]}}"); + let text = json!({ + "type": "message", + "topic": "/allowed_vpc_endpoints_updated_for_projects", + "data": data, + "extre_fields": "something" + }) + .to_string(); + + let result: Notification = serde_json::from_str(&text)?; + assert_eq!( + result, + Notification::ProjectSettingsUpdate(InvalidateProject::ProjectIds(vec![ + (&project_id1).into(), + (&project_id2).into() + ])) ); Ok(()) @@ -379,12 +404,10 @@ mod tests { let result: Notification = serde_json::from_str(&text)?; assert_eq!( result, - Notification::PasswordUpdate { - password_update: PasswordUpdate { - project_id: (&project_id).into(), - role_name: (&role_name).into(), - } - } + Notification::RoleSettingUpdate(InvalidateRole { + project_id: (&project_id).into(), + role_name: (&role_name).into(), + }) ); Ok(()) From 3c7235669a3655edf12a47fc3d3b19e75826e6bd Mon Sep 17 00:00:00 2001 From: Erik Grinaker Date: Fri, 6 Jun 2025 17:55:14 +0200 Subject: [PATCH 11/12] pageserver: don't delete parent shard files until split is committed (#12146) ## Problem If a shard split fails and must roll back, the tenant may hit a cold start as the parent shard's files have already been removed from local disk. External contribution with minor adjustments, see https://neondb.slack.com/archives/C08TE3203RQ/p1748246398269309. 
## Summary of changes Keep the parent shard's files on local disk until the split has been committed, such that they are available if the split is rolled back. If all else fails, the files will be removed on the next Pageserver restart. This should also be fine in a mixed-version deployment: * New storcon, old Pageserver: the Pageserver will delete the files during the split, storcon will log an error when the cleanup detach fails. * Old storcon, new Pageserver: the Pageserver will leave the parent's files around until the next Pageserver restart. The change looks good to me, but shard splits are delicate so I'd like some extra eyes on this. --- pageserver/src/tenant/mgr.rs | 50 ++++++++--- pageserver/src/tenant/timeline.rs | 4 +- storage_controller/src/service.rs | 18 +++- test_runner/regress/test_sharding.py | 87 +++++++++++++++++++ .../regress/test_storage_controller.py | 4 +- 5 files changed, 146 insertions(+), 17 deletions(-) diff --git a/pageserver/src/tenant/mgr.rs b/pageserver/src/tenant/mgr.rs index 86aef9b42c..186e0f4cdb 100644 --- a/pageserver/src/tenant/mgr.rs +++ b/pageserver/src/tenant/mgr.rs @@ -1671,7 +1671,12 @@ impl TenantManager { } } - // Phase 5: Shut down the parent shard, and erase it from disk + // Phase 5: Shut down the parent shard. We leave it on disk in case the split fails and we + // have to roll back to the parent shard, avoiding a cold start. It will be cleaned up once + // the storage controller commits the split, or if all else fails, on the next restart. + // + // TODO: We don't flush the ephemeral layer here, because the split is likely to succeed and + // catching up the parent should be reasonably quick. Consider using FreezeAndFlush instead. let (_guard, progress) = completion::channel(); match parent.shutdown(progress, ShutdownMode::Hard).await { Ok(()) => {} @@ -1679,11 +1684,6 @@ impl TenantManager { other.wait().await; } } - let local_tenant_directory = self.conf.tenant_path(&tenant_shard_id); - let tmp_path = safe_rename_tenant_dir(&local_tenant_directory) - .await - .with_context(|| format!("local tenant directory {local_tenant_directory:?} rename"))?; - self.background_purges.spawn(tmp_path); fail::fail_point!("shard-split-pre-finish", |_| Err(anyhow::anyhow!( "failpoint" @@ -1846,42 +1846,70 @@ impl TenantManager { shutdown_all_tenants0(self.tenants).await } + /// Detaches a tenant, and removes its local files asynchronously. + /// + /// File removal is idempotent: even if the tenant has already been removed, this will still + /// remove any local files. This is used during shard splits, where we leave the parent shard's + /// files around in case we have to roll back the split. pub(crate) async fn detach_tenant( &self, conf: &'static PageServerConf, tenant_shard_id: TenantShardId, deletion_queue_client: &DeletionQueueClient, ) -> Result<(), TenantStateError> { - let tmp_path = self + if let Some(tmp_path) = self .detach_tenant0(conf, tenant_shard_id, deletion_queue_client) - .await?; - self.background_purges.spawn(tmp_path); + .await? + { + self.background_purges.spawn(tmp_path); + } Ok(()) } + /// Detaches a tenant. This renames the tenant directory to a temporary path and returns it, + /// allowing the caller to delete it asynchronously. Returns None if the dir is already removed.
async fn detach_tenant0( &self, conf: &'static PageServerConf, tenant_shard_id: TenantShardId, deletion_queue_client: &DeletionQueueClient, - ) -> Result<Utf8PathBuf, TenantStateError> { + ) -> Result<Option<Utf8PathBuf>, TenantStateError> { let tenant_dir_rename_operation = |tenant_id_to_clean: TenantShardId| async move { let local_tenant_directory = conf.tenant_path(&tenant_id_to_clean); + if !tokio::fs::try_exists(&local_tenant_directory).await? { + // If the tenant directory doesn't exist, it's already cleaned up. + return Ok(None); + } safe_rename_tenant_dir(&local_tenant_directory) .await .with_context(|| { format!("local tenant directory {local_tenant_directory:?} rename") }) + .map(Some) }; - let removal_result = remove_tenant_from_memory( + let mut removal_result = remove_tenant_from_memory( self.tenants, tenant_shard_id, tenant_dir_rename_operation(tenant_shard_id), ) .await; + // If the tenant was not found, it was likely already removed. Attempt to remove the tenant + // directory on disk anyway. For example, during shard splits, we shut down and remove the + // parent shard, but leave its directory on disk in case we have to roll back the split. + // + // TODO: it would be better to leave the parent shard attached until the split is committed. + // This will be needed by the gRPC page service too, such that a compute can continue to + // read from the parent shard until it's notified about the new child shards. See: + // . + if let Err(TenantStateError::SlotError(TenantSlotError::NotFound(_))) = removal_result { + removal_result = tenant_dir_rename_operation(tenant_shard_id) + .await + .map_err(TenantStateError::Other); + } + // Flush pending deletions, so that they have a good chance of passing validation // before this tenant is potentially re-attached elsewhere. deletion_queue_client.flush_advisory(); diff --git a/pageserver/src/tenant/timeline.rs b/pageserver/src/tenant/timeline.rs index 3522af2de0..0ff005fbb9 100644 --- a/pageserver/src/tenant/timeline.rs +++ b/pageserver/src/tenant/timeline.rs @@ -1055,8 +1055,8 @@ pub(crate) enum WaitLsnWaiter<'a> { /// Argument to [`Timeline::shutdown`]. #[derive(Debug, Clone, Copy)] pub(crate) enum ShutdownMode { - /// Graceful shutdown, may do a lot of I/O as we flush any open layers to disk and then - /// also to remote storage. This method can easily take multiple seconds for a busy timeline. + /// Graceful shutdown, may do a lot of I/O as we flush any open layers to disk. This method can + /// take multiple seconds for a busy timeline. /// /// While we are flushing, we continue to accept read I/O for LSNs ingested before /// the call to [`Timeline::shutdown`]. diff --git a/storage_controller/src/service.rs b/storage_controller/src/service.rs index cb29993e8c..06318a01b5 100644 --- a/storage_controller/src/service.rs +++ b/storage_controller/src/service.rs @@ -1108,7 +1108,8 @@ impl Service { observed } - /// Used during [`Self::startup_reconcile`]: detach a list of unknown-to-us tenants from pageservers. + /// Used during [`Self::startup_reconcile`] and shard splits: detach a list of unknown-to-us + /// tenants from pageservers.
/// /// This is safe to run in the background, because if we don't have this TenantShardId in our map of /// tenants, then it is probably something incompletely deleted before: we will not fight with any @@ -6211,7 +6212,11 @@ impl Service { } } - pausable_failpoint!("shard-split-pre-complete"); + fail::fail_point!("shard-split-pre-complete", |_| Err(ApiError::Conflict( + "failpoint".to_string() + ))); + + pausable_failpoint!("shard-split-pre-complete-pause"); // TODO: if the pageserver restarted concurrently with our split API call, // the actual generation of the child shard might differ from the generation @@ -6233,6 +6238,15 @@ impl Service { let (response, child_locations, waiters) = self.tenant_shard_split_commit_inmem(tenant_id, new_shard_count, new_stripe_size); + // Notify all page servers to detach and clean up the old shards because they will no longer + // be needed. This is best-effort: if it fails, it will be cleaned up on a subsequent + // Pageserver re-attach/startup. + let shards_to_cleanup = targets + .iter() + .map(|target| (target.parent_id, target.node.get_id())) + .collect(); + self.cleanup_locations(shards_to_cleanup).await; + // Send compute notifications for all the new shards let mut failed_notifications = Vec::new(); for (child_id, child_ps, stripe_size) in child_locations { diff --git a/test_runner/regress/test_sharding.py b/test_runner/regress/test_sharding.py index 4c9887fb92..522e257ea5 100644 --- a/test_runner/regress/test_sharding.py +++ b/test_runner/regress/test_sharding.py @@ -1836,3 +1836,90 @@ def test_sharding_gc( shard_gc_cutoff_lsn = Lsn(shard_index["metadata_bytes"]["latest_gc_cutoff_lsn"]) log.info(f"Shard {shard_number} cutoff LSN: {shard_gc_cutoff_lsn}") assert shard_gc_cutoff_lsn == shard_0_gc_cutoff_lsn + + +def test_split_ps_delete_old_shard_after_commit(neon_env_builder: NeonEnvBuilder): + """ + Check that PageServer only deletes old shards after the split is committed such that it doesn't + have to download a lot of files during abort. + """ + DBNAME = "regression" + + init_shard_count = 4 + neon_env_builder.num_pageservers = init_shard_count + stripe_size = 32 + + env = neon_env_builder.init_start( + initial_tenant_shard_count=init_shard_count, initial_tenant_shard_stripe_size=stripe_size + ) + + env.storage_controller.allowed_errors.extend( + [ + # All split failures log a warning when they enqueue the abort operation + ".*Enqueuing background abort.*", + # Tolerate any error logs that mention a failpoint + ".*failpoint.*", + ] + ) + + endpoint = env.endpoints.create("main") + endpoint.respec(skip_pg_catalog_updates=False) + endpoint.start() + + # Write some initial data. + endpoint.safe_psql(f"CREATE DATABASE {DBNAME}") + endpoint.safe_psql("CREATE TABLE usertable ( YCSB_KEY INT, FIELD0 TEXT);") + + for _ in range(1000): + endpoint.safe_psql( + "INSERT INTO usertable SELECT random(), repeat('a', 1000);", log_query=False + ) + + # Record how many bytes we've downloaded before the split. + def collect_downloaded_bytes() -> list[float | None]: + downloaded_bytes = [] + for page_server in env.pageservers: + metric = page_server.http_client().get_metric_value( + "pageserver_remote_ondemand_downloaded_bytes_total" + ) + downloaded_bytes.append(metric) + return downloaded_bytes + + downloaded_bytes_before = collect_downloaded_bytes() + + # Attempt to split the tenant, but fail the split before it completes. 
+ env.storage_controller.configure_failpoints(("shard-split-pre-complete", "return(1)")) + with pytest.raises(StorageControllerApiException): + env.storage_controller.tenant_shard_split(env.initial_tenant, shard_count=16) + + # Wait until split is aborted. + def check_split_is_aborted(): + tenants = env.storage_controller.tenant_list() + assert len(tenants) == 1 + shards = tenants[0]["shards"] + assert len(shards) == 4 + for shard in shards: + assert not shard["is_splitting"] + assert not shard["is_reconciling"] + + # Make sure all new shards have been deleted. + valid_shards = 0 + for ps in env.pageservers: + for tenant_dir in os.listdir(ps.workdir / "tenants"): + try: + tenant_shard_id = TenantShardId.parse(tenant_dir) + valid_shards += 1 + assert tenant_shard_id.shard_count == 4 + except ValueError: + log.info(f"{tenant_dir} is not valid tenant shard id") + assert valid_shards >= 4 + + wait_until(check_split_is_aborted) + + endpoint.safe_psql("SELECT count(*) from usertable;", log_query=False) + + # Make sure we didn't download anything following the aborted split. + downloaded_bytes_after = collect_downloaded_bytes() + + assert downloaded_bytes_before == downloaded_bytes_after + endpoint.stop_and_destroy() diff --git a/test_runner/regress/test_storage_controller.py b/test_runner/regress/test_storage_controller.py index 5e0dd780c3..8f3aa010e3 100644 --- a/test_runner/regress/test_storage_controller.py +++ b/test_runner/regress/test_storage_controller.py @@ -2956,7 +2956,7 @@ def test_storage_controller_leadership_transfer_during_split( env.storage_controller.allowed_errors.extend( [".*Unexpected child shard count.*", ".*Enqueuing background abort.*"] ) - pause_failpoint = "shard-split-pre-complete" + pause_failpoint = "shard-split-pre-complete-pause" env.storage_controller.configure_failpoints((pause_failpoint, "pause")) split_fut = executor.submit( @@ -3003,7 +3003,7 @@ def test_storage_controller_leadership_transfer_during_split( env.storage_controller.request( "PUT", f"http://127.0.0.1:{storage_controller_1_port}/debug/v1/failpoints", - json=[{"name": "shard-split-pre-complete", "actions": "off"}], + json=[{"name": pause_failpoint, "actions": "off"}], headers=env.storage_controller.headers(TokenScope.ADMIN), ) From 7efd4554ab8e905cfd0cbf3e8a59f34fdfee7345 Mon Sep 17 00:00:00 2001 From: Mikhail Date: Fri, 6 Jun 2025 19:08:02 +0100 Subject: [PATCH 12/12] endpoint_storage: allow bypassing s3 write check on startup (#12165) Related: https://github.com/neondatabase/cloud/issues/27195 --- Cargo.lock | 1 + endpoint_storage/Cargo.toml | 1 + endpoint_storage/src/main.rs | 36 ++++++++++++++++++++++++------------ 3 files changed, 26 insertions(+), 12 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 5f71af118c..3ee261e885 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2055,6 +2055,7 @@ dependencies = [ "axum-extra", "camino", "camino-tempfile", + "clap", "futures", "http-body-util", "itertools 0.10.5", diff --git a/endpoint_storage/Cargo.toml b/endpoint_storage/Cargo.toml index b2c9d51551..c2e21d02e2 100644 --- a/endpoint_storage/Cargo.toml +++ b/endpoint_storage/Cargo.toml @@ -8,6 +8,7 @@ anyhow.workspace = true axum-extra.workspace = true axum.workspace = true camino.workspace = true +clap.workspace = true futures.workspace = true jsonwebtoken.workspace = true prometheus.workspace = true diff --git a/endpoint_storage/src/main.rs b/endpoint_storage/src/main.rs index 399a4ec31e..23b7343ff3 100644 --- a/endpoint_storage/src/main.rs +++ b/endpoint_storage/src/main.rs @@ -3,7 +3,8 @@ //! 
This service is deployed either as a separate component or as part of compute image //! for large computes. mod app; -use anyhow::{Context, bail}; +use anyhow::Context; +use clap::Parser; use std::net::{IpAddr, Ipv4Addr, SocketAddr}; use tracing::info; use utils::logging; @@ -17,6 +18,18 @@ const fn listen() -> SocketAddr { SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 51243) } +#[derive(Parser)] +struct Args { + #[arg(exclusive = true)] + config_file: Option<String>, + #[arg(long, default_value = "false", requires = "config")] + /// to allow testing k8s helm chart where we don't have s3 credentials + no_s3_check_on_startup: bool, + #[arg(long, value_name = "FILE")] + /// inline config mode for k8s helm chart + config: Option<String>, +} + #[derive(serde::Deserialize)] #[serde(tag = "type")] struct Config { @@ -37,19 +50,16 @@ async fn main() -> anyhow::Result<()> { logging::Output::Stdout, )?; - // Allow either passing filename or inline config (for k8s helm chart) - let args: Vec<String> = std::env::args().skip(1).collect(); - let config: Config = if args.len() == 1 && args[0].ends_with(".json") { - info!("Reading config from {}", args[0]); - let config = std::fs::read_to_string(args[0].clone())?; + let args = Args::parse(); + let config: Config = if let Some(config_path) = args.config_file { + info!("Reading config from {config_path}"); + let config = std::fs::read_to_string(config_path)?; serde_json::from_str(&config).context("parsing config")? - } else if !args.is_empty() && args[0].starts_with("--config=") { + } else if let Some(config) = args.config { info!("Reading inline config"); - let config = args.join(" "); - let config = config.strip_prefix("--config=").unwrap(); - serde_json::from_str(config).context("parsing config")? + serde_json::from_str(&config).context("parsing config")? } else { - bail!("Usage: endpoint_storage config.json or endpoint_storage --config=JSON"); + anyhow::bail!("Supply either config file path or --config=inline-config"); }; info!("Reading pemfile from {}", config.pemfile.clone()); @@ -62,7 +72,9 @@ async fn main() -> anyhow::Result<()> { let storage = remote_storage::GenericRemoteStorage::from_config(&config.storage_config).await?; let cancel = tokio_util::sync::CancellationToken::new(); - app::check_storage_permissions(&storage, cancel.clone()).await?; + if !args.no_s3_check_on_startup { + app::check_storage_permissions(&storage, cancel.clone()).await?; + } let proxy = std::sync::Arc::new(endpoint_storage::Storage { auth,