Merge remote-tracking branch 'origin/main' into communicator-rewrite

This commit is contained in:
Heikki Linnakangas
2025-07-23 00:00:10 +03:00
190 changed files with 5161 additions and 1877 deletions

View File

@@ -66,6 +66,12 @@ class EndpointHttpClient(requests.Session):
res.raise_for_status()
return res.json()
def autoscaling_metrics(self):
    """
    Fetch the compute's `/autoscaling_metrics` endpoint and return the raw
    response body as text.

    Raises the `requests` HTTP error produced by `raise_for_status()` if the
    endpoint answers with an error status.
    """
    url = f"http://localhost:{self.external_port}/autoscaling_metrics"
    response = self.get(url)
    response.raise_for_status()
    body = response.text
    # Log the unparsed payload; callers parse it themselves.
    log.debug("raw compute metrics: %s", body)
    return body
def prewarm_lfc_status(self) -> dict[str, str]:
res = self.get(self.prewarm_url)
res.raise_for_status()

View File

@@ -24,6 +24,7 @@ def connection_parameters_to_env(params: dict[str, str]) -> dict[str, str]:
# Some API calls not yet implemented.
# You may want to copy not-yet-implemented methods from the PR https://github.com/neondatabase/neon/pull/11305
@final
class NeonAPI:
def __init__(self, neon_api_key: str, neon_api_base_url: str):
self.__neon_api_key = neon_api_key
@@ -170,7 +171,7 @@ class NeonAPI:
protected: bool | None = None,
archived: bool | None = None,
init_source: str | None = None,
add_endpoint=True,
add_endpoint: bool = True,
) -> dict[str, Any]:
data: dict[str, Any] = {}
if add_endpoint:

View File

@@ -400,6 +400,7 @@ class NeonLocalCli(AbstractNeonCli):
timeout_in_seconds: int | None = None,
instance_id: int | None = None,
base_port: int | None = None,
handle_ps_local_disk_loss: bool | None = None,
):
cmd = ["storage_controller", "start"]
if timeout_in_seconds is not None:
@@ -408,6 +409,10 @@ class NeonLocalCli(AbstractNeonCli):
cmd.append(f"--instance-id={instance_id}")
if base_port is not None:
cmd.append(f"--base-port={base_port}")
if handle_ps_local_disk_loss is not None:
cmd.append(
f"--handle-ps-local-disk-loss={'true' if handle_ps_local_disk_loss else 'false'}"
)
return self.raw_cli(cmd)
def storage_controller_stop(self, immediate: bool, instance_id: int | None = None):

View File

@@ -1938,9 +1938,12 @@ class NeonStorageController(MetricsGetter, LogUtils):
timeout_in_seconds: int | None = None,
instance_id: int | None = None,
base_port: int | None = None,
handle_ps_local_disk_loss: bool | None = None,
) -> Self:
assert not self.running
self.env.neon_cli.storage_controller_start(timeout_in_seconds, instance_id, base_port)
self.env.neon_cli.storage_controller_start(
timeout_in_seconds, instance_id, base_port, handle_ps_local_disk_loss
)
self.running = True
return self
@@ -2838,10 +2841,13 @@ class NeonProxiedStorageController(NeonStorageController):
timeout_in_seconds: int | None = None,
instance_id: int | None = None,
base_port: int | None = None,
handle_ps_local_disk_loss: bool | None = None,
) -> Self:
assert instance_id is not None and base_port is not None
self.env.neon_cli.storage_controller_start(timeout_in_seconds, instance_id, base_port)
self.env.neon_cli.storage_controller_start(
timeout_in_seconds, instance_id, base_port, handle_ps_local_disk_loss
)
self.instances[instance_id] = {"running": True}
self.running = True
@@ -4121,6 +4127,294 @@ class NeonAuthBroker:
self._popen.kill()
class NeonLocalProxy(LogUtils):
    """
    An object managing a local_proxy instance for rest broker testing.
    The local_proxy serves as a direct connection to VanillaPostgres.

    Usable as a context manager: leaving the `with` block calls `stop()`.
    Note that entering the context does NOT start the process; call `start()`.
    """

    def __init__(
        self,
        neon_binpath: Path,
        test_output_dir: Path,
        http_port: int,
        metrics_port: int,
        vanilla_pg: VanillaPostgres,
        config_path: Path | None = None,
    ):
        self.neon_binpath = neon_binpath
        self.test_output_dir = test_output_dir
        self.http_port = http_port
        self.metrics_port = metrics_port
        self.vanilla_pg = vanilla_pg
        # Default config location lives next to the other test artifacts.
        self.config_path = config_path or (test_output_dir / "local_proxy.json")
        self.host = "127.0.0.1"
        self.running = False
        self.logfile = test_output_dir / "local_proxy.log"
        self._popen: subprocess.Popen[bytes] | None = None
        super().__init__(logfile=self.logfile)

    def start(self) -> Self:
        """
        Launch the local_proxy binary (starting vanilla_pg first if needed) and
        block until it answers HTTP requests. Returns self for chaining.
        """
        assert self._popen is None
        assert not self.running

        # Ensure vanilla_pg is running
        if not self.vanilla_pg.is_running():
            self.vanilla_pg.start()

        args = [
            str(self.neon_binpath / "local_proxy"),
            "--http",
            f"{self.host}:{self.http_port}",
            "--metrics",
            f"{self.host}:{self.metrics_port}",
            "--postgres",
            f"127.0.0.1:{self.vanilla_pg.default_options['port']}",
            "--config-path",
            str(self.config_path),
            "--disable-pg-session-jwt",
        ]

        # The child process gets its own copy of the file descriptor, so the
        # parent's handle can (and should) be closed once Popen has returned.
        with open(self.logfile, "w") as logfile:
            self._popen = subprocess.Popen(args, stdout=logfile, stderr=logfile)
        self.running = True
        self._wait_until_ready()
        return self

    def stop(self) -> Self:
        """
        Terminate the process if it is running: SIGTERM first, then kill if it
        has not exited within 5 seconds. Safe to call when not started.
        """
        if self._popen is not None and self.running:
            self._popen.terminate()
            try:
                self._popen.wait(timeout=5)
            except subprocess.TimeoutExpired:
                log.warning("failed to gracefully terminate local_proxy; killing")
                self._popen.kill()
                # Reap the killed child so it does not linger as a zombie.
                self._popen.wait()
            self.running = False
        return self

    def get_binary_version(self) -> str:
        """Get the version string of the local_proxy binary"""
        try:
            result = subprocess.run(
                [str(self.neon_binpath / "local_proxy"), "--version"],
                capture_output=True,
                text=True,
                timeout=10,
            )
            return result.stdout.strip()
        # NOTE: without check=True, subprocess.run does not raise
        # CalledProcessError; only the timeout path can trigger here.
        except (subprocess.TimeoutExpired, subprocess.CalledProcessError):
            return ""

    @backoff.on_exception(backoff.expo, requests.exceptions.RequestException, max_time=10)
    def _wait_until_ready(self):
        """
        Probe the HTTP port until it responds. Connection failures are retried
        by the backoff decorator; a dead child fails immediately because
        AssertionError is not a RequestException.
        """
        assert self._popen and self._popen.poll() is None, (
            "Local proxy exited unexpectedly. Check test log."
        )
        # A per-request timeout keeps one hung connection from blocking past
        # the overall backoff budget (Timeout is a RequestException, so it is
        # retried like any other connection failure).
        requests.get(f"http://{self.host}:{self.http_port}/metrics", timeout=1)

    def get_metrics(self) -> str:
        """Return the raw body served at /metrics on the metrics port."""
        response = requests.get(f"http://{self.host}:{self.metrics_port}/metrics")
        return response.text

    def assert_no_errors(self):
        """
        Assert that none of the disallowed patterns appear in the log file,
        except for patterns explicitly listed in `allowed_errors`.
        """
        # Define allowed error patterns for local_proxy
        allowed_errors: list[str] = [
            # Add patterns as needed
        ]
        not_allowed = [
            "error",
            "panic",
            "failed",
        ]

        for na in not_allowed:
            if na not in allowed_errors:
                assert not self.log_contains(na), f"Found disallowed error pattern: {na}"

    def __enter__(self) -> Self:
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: TracebackType | None,
    ):
        self.stop()
class NeonRestBrokerProxy(LogUtils):
    """
    An object managing a proxy instance configured as both auth broker and rest broker.
    This is the main proxy binary with --is-auth-broker and --is-rest-broker flags.

    Usable as a context manager: leaving the `with` block calls `stop()`.
    Note that entering the context does NOT start the process; call `start()`.
    """

    def __init__(
        self,
        neon_binpath: Path,
        test_output_dir: Path,
        wss_port: int,
        http_port: int,
        mgmt_port: int,
        config_path: Path | None = None,
    ):
        self.neon_binpath = neon_binpath
        self.test_output_dir = test_output_dir
        self.wss_port = wss_port
        self.http_port = http_port
        self.mgmt_port = mgmt_port
        # Default config location lives next to the other test artifacts.
        self.config_path = config_path or (test_output_dir / "rest_broker_proxy.json")
        self.host = "127.0.0.1"
        self.running = False
        self.logfile = test_output_dir / "rest_broker_proxy.log"
        self._popen: subprocess.Popen[Any] | None = None
        # Initialise the LogUtils base the same way NeonLocalProxy does.
        super().__init__(logfile=self.logfile)

    def start(self) -> Self:
        """
        Generate a self-signed TLS certificate if one is not already present in
        the test output directory, launch the proxy, and block until it accepts
        HTTPS connections. A no-op if already running. Returns self.
        """
        if self.running:
            return self

        # Generate self-signed TLS certificates
        cert_path = self.test_output_dir / "server.crt"
        key_path = self.test_output_dir / "server.key"

        if not cert_path.exists() or not key_path.exists():
            # NOTE: do not `import subprocess` locally here. A function-level
            # import makes the name local to the whole function, so skipping
            # this branch would make the Popen call below raise
            # UnboundLocalError.
            log.info("Generating self-signed TLS certificate for rest broker")
            subprocess.run(
                [
                    "openssl",
                    "req",
                    "-new",
                    "-x509",
                    "-days",
                    "365",
                    "-nodes",
                    "-text",
                    "-out",
                    str(cert_path),
                    "-keyout",
                    str(key_path),
                    "-subj",
                    "/CN=*.local.neon.build",
                ],
                check=True,
            )

        log.info(
            f"Starting rest broker proxy on WSS port {self.wss_port}, HTTP port {self.http_port}"
        )

        cmd = [
            str(self.neon_binpath / "proxy"),
            "-c",
            str(cert_path),
            "-k",
            str(key_path),
            "--is-auth-broker",
            "true",
            "--is-rest-broker",
            "true",
            "--wss",
            f"{self.host}:{self.wss_port}",
            "--http",
            f"{self.host}:{self.http_port}",
            "--mgmt",
            f"{self.host}:{self.mgmt_port}",
            "--auth-backend",
            "local",
            "--config-path",
            str(self.config_path),
        ]

        log.info(f"Starting rest broker proxy with command: {' '.join(cmd)}")

        # The child keeps its own copy of the descriptor; the parent's handle is
        # closed when the `with` block exits.
        with open(self.logfile, "w") as logfile:
            self._popen = subprocess.Popen(
                cmd,
                stdout=logfile,
                stderr=subprocess.STDOUT,
                cwd=self.test_output_dir,
                env={
                    **os.environ,
                    "RUST_LOG": "info",
                    "LOGFMT": "text",
                    "OTEL_SDK_DISABLED": "true",
                },
            )

        self.running = True
        self._wait_until_ready()
        return self

    def stop(self) -> Self:
        """
        Terminate the process if it is running: SIGTERM first, then kill if it
        has not exited within 10 seconds. Safe to call when not started.
        """
        if not self.running:
            return self

        log.info("Stopping rest broker proxy")

        if self._popen is not None:
            self._popen.terminate()
            try:
                self._popen.wait(timeout=10)
            except subprocess.TimeoutExpired:
                log.warning("failed to gracefully terminate rest broker proxy; killing")
                self._popen.kill()
                # Reap the killed child so it does not linger as a zombie.
                self._popen.wait()

        self.running = False
        return self

    def get_binary_version(self) -> str:
        """Return the output of `proxy --version`; raises CalledProcessError on failure."""
        cmd = [str(self.neon_binpath / "proxy"), "--version"]
        res = subprocess.run(cmd, capture_output=True, text=True, check=True)
        return res.stdout.strip()

    @backoff.on_exception(backoff.expo, requests.exceptions.RequestException, max_time=10)
    def _wait_until_ready(self):
        """
        Probe the WSS port over HTTPS until it responds. Connection failures
        are retried by the backoff decorator; a dead child fails immediately
        because AssertionError is not a RequestException.
        """
        assert self._popen and self._popen.poll() is None, (
            "Rest broker proxy exited unexpectedly. Check test log."
        )
        # Check if the WSS port is ready using a simple HTTPS request
        # REST API is served on the WSS port with HTTPS
        requests.get(f"https://{self.host}:{self.wss_port}/", timeout=1, verify=False)
        # Any response (even error) means the server is up - we just need to connect

    def get_metrics(self) -> str:
        """Return the raw body served at /metrics; raises on HTTP error status."""
        # Metrics are still on the HTTP port
        response = requests.get(f"http://{self.host}:{self.http_port}/metrics", timeout=5)
        response.raise_for_status()
        return response.text

    def assert_no_errors(self):
        """
        Scan the log file and raise AssertionError on the first ERROR/FATAL line
        that does not contain one of the allowed substrings.
        """
        # Define allowed error patterns for rest broker proxy
        allowed_errors = [
            "connection closed before message completed",
            "connection reset by peer",
            "broken pipe",
            "client disconnected",
            "Authentication failed",
            "connection timed out",
            "no connection available",
            "Pool dropped",
        ]

        with open(self.logfile) as f:
            for line in f:
                if "ERROR" in line or "FATAL" in line:
                    if not any(allowed in line for allowed in allowed_errors):
                        raise AssertionError(
                            f"Found error in rest broker proxy log: {line.strip()}"
                        )

    def __enter__(self) -> Self:
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: TracebackType | None,
    ):
        self.stop()
@pytest.fixture(scope="function")
def link_proxy(
port_distributor: PortDistributor, neon_binpath: Path, test_output_dir: Path
@@ -4203,6 +4497,81 @@ def static_proxy(
yield proxy
@pytest.fixture(scope="function")
def local_proxy(
    vanilla_pg: VanillaPostgres,
    port_distributor: PortDistributor,
    neon_binpath: Path,
    test_output_dir: Path,
) -> Iterator[NeonLocalProxy]:
    """Local proxy that connects directly to vanilla postgres for rest broker testing."""

    # Bring up vanilla_pg as-is; no database bootstrapping is done here.
    vanilla_pg.start()

    # Ports are allocated in this order: HTTP first, then metrics.
    proxy = NeonLocalProxy(
        neon_binpath=neon_binpath,
        test_output_dir=test_output_dir,
        http_port=port_distributor.get_port(),
        metrics_port=port_distributor.get_port(),
        vanilla_pg=vanilla_pg,
    )
    # The context manager guarantees stop() on teardown, even if start() fails.
    with proxy:
        proxy.start()
        yield proxy
@pytest.fixture(scope="function")
def local_proxy_fixed_port(
    vanilla_pg: VanillaPostgres,
    neon_binpath: Path,
    test_output_dir: Path,
) -> Iterator[NeonLocalProxy]:
    """Local proxy that connects directly to vanilla postgres on the hardcoded port 7432."""

    # Bring up vanilla_pg as-is; no database bootstrapping is done here.
    vanilla_pg.start()

    proxy = NeonLocalProxy(
        neon_binpath=neon_binpath,
        test_output_dir=test_output_dir,
        # The rest broker proxy expects local_proxy on this hardcoded port.
        http_port=7432,
        # Metrics only need a port distinct from the HTTP one.
        metrics_port=7433,
        vanilla_pg=vanilla_pg,
    )
    # The context manager guarantees stop() on teardown, even if start() fails.
    with proxy:
        proxy.start()
        yield proxy
@pytest.fixture(scope="function")
def rest_broker_proxy(
    port_distributor: PortDistributor,
    neon_binpath: Path,
    test_output_dir: Path,
) -> Iterator[NeonRestBrokerProxy]:
    """Rest broker proxy that handles both auth broker and rest broker functionality."""

    # Ports are allocated in this order: WSS, HTTP, then management.
    proxy = NeonRestBrokerProxy(
        neon_binpath=neon_binpath,
        test_output_dir=test_output_dir,
        wss_port=port_distributor.get_port(),
        http_port=port_distributor.get_port(),
        mgmt_port=port_distributor.get_port(),
    )
    # The context manager guarantees stop() on teardown, even if start() fails.
    with proxy:
        proxy.start()
        yield proxy
@pytest.fixture(scope="function")
def neon_authorize_jwk() -> jwk.JWK:
kid = str(uuid.uuid4())
@@ -5435,7 +5804,7 @@ SKIP_FILES = frozenset(
"postmaster.pid",
"pg_control",
"pg_dynshmem",
".metrics.socket",
"neon-communicator.socket",
)
)

View File

@@ -152,6 +152,8 @@ DEFAULT_STORAGE_CONTROLLER_ALLOWED_ERRORS = [
".*reconciler.*neon_local error.*",
# Tenant rate limits may fire in tests that submit lots of API requests.
".*tenant \\S+ is rate limited.*",
# Reconciliations may get stuck/delayed e.g. in chaos tests.
".*background_reconcile: Shard reconciliation is stuck.*",
]

View File

@@ -741,3 +741,29 @@ def shared_buffers_for_max_cu(max_cu: float) -> str:
sharedBuffersMb = int(max(128, (1023 + maxBackends * 256) / 1024))
sharedBuffers = int(sharedBuffersMb * 1024 / 8)
return str(sharedBuffers)
def skip_if_proxy_lacks_rest_broker(reason: str = "proxy was built without 'rest_broker' feature"):
    """
    Return a `pytest.mark.skipif` marker that skips the decorated test when the
    `proxy` binary is missing or its `--help` output does not mention the
    `--is-rest-broker` flag.

    The probe runs when this function is called (i.e. at decoration time), so
    it must never raise: any failure to run the binary is treated as "feature
    not available".
    """

    # Determine the binary path using the same logic as neon_binpath fixture
    def has_rest_broker_feature() -> bool:
        # Find the neon binaries
        if env_neon_bin := os.environ.get("NEON_BIN"):
            binpath = Path(env_neon_bin)
        else:
            base_dir = Path(__file__).parents[2]  # Same as BASE_DIR in paths.py
            build_type = os.environ.get("BUILD_TYPE", "debug")
            binpath = base_dir / "target" / build_type

        proxy_bin = binpath / "proxy"
        if not proxy_bin.exists():
            return False

        try:
            result = subprocess.run(
                [str(proxy_bin), "--help"],
                capture_output=True,
                text=True,
                check=True,
                timeout=10,
            )
        # OSError covers FileNotFoundError as well as PermissionError and exec
        # format errors (e.g. a non-executable or foreign-architecture file),
        # which would otherwise abort test collection.
        except (subprocess.CalledProcessError, subprocess.TimeoutExpired, OSError):
            return False

        return "--is-rest-broker" in result.stdout

    return pytest.mark.skipif(not has_rest_broker_feature(), reason=reason)

View File

@@ -0,0 +1,54 @@
from __future__ import annotations
import os
from typing import TYPE_CHECKING
import pytest
import requests
import requests_unixsocket # type: ignore [import-untyped]
from fixtures.metrics import parse_metrics
if TYPE_CHECKING:
from fixtures.neon_fixtures import NeonEnv
NEON_COMMUNICATOR_SOCKET_NAME = "neon-communicator.socket"
def test_communicator_metrics(neon_simple_env: NeonEnv):
    """
    Test the communicator's built-in HTTP prometheus exporter
    """
    env = neon_simple_env

    endpoint = env.endpoints.create("main")
    endpoint.start()

    metrics_url = f"http+unix://{NEON_COMMUNICATOR_SOCKET_NAME}/metrics"
    panic_url = f"http+unix://{NEON_COMMUNICATOR_SOCKET_NAME}/debug/panic"

    # Change current directory to the data directory, so that we can use
    # a short relative path to refer to the socket. (There's a 100 char
    # limitation on the path.)
    #
    # The working directory is process-wide state, so restore it afterwards to
    # avoid leaking the change into whatever runs next in this process.
    previous_cwd = os.getcwd()
    os.chdir(str(endpoint.pgdata_dir))
    try:
        session = requests_unixsocket.Session()
        r = session.get(metrics_url)
        assert r.status_code == 200, f"got response {r.status_code}: {r.text}"

        # quick test that the endpoint returned something expected. (We don't validate
        # that the metrics returned are sensible.)
        m = parse_metrics(r.text)
        m.query_one("lfc_hits")
        m.query_one("lfc_misses")

        # Test panic handling. The /debug/panic endpoint raises a Rust panic. It's
        # expected to unwind and drop the HTTP connection without response, but not
        # kill the process or the server. No status code is checked here: the
        # request is expected to raise rather than return a response.
        with pytest.raises(
            requests.ConnectionError, match="Remote end closed connection without response"
        ):
            session.get(panic_url)

        # Test that subsequent requests after the panic still work.
        r = session.get(metrics_url)
        assert r.status_code == 200, f"got response {r.status_code}: {r.text}"
        m = parse_metrics(r.text)
        m.query_one("lfc_hits")
        m.query_one("lfc_misses")
    finally:
        os.chdir(previous_cwd)

View File

@@ -197,7 +197,7 @@ def test_create_snapshot(
shutil.copytree(
test_output_dir,
new_compatibility_snapshot_dir,
ignore=shutil.ignore_patterns("pg_dynshmem"),
ignore=shutil.ignore_patterns("pg_dynshmem", "neon-communicator.socket"),
)
log.info(f"Copied new compatibility snapshot dir to: {new_compatibility_snapshot_dir}")

View File

@@ -0,0 +1,47 @@
import shutil
from fixtures.neon_fixtures import NeonEnvBuilder
from fixtures.utils import query_scalar
def test_hcc_handling_ps_data_loss(
    neon_env_builder: NeonEnvBuilder,
):
    """
    Verify automatic recovery after a pageserver loses its local data.

    The pageserver is restarted with its `tenants/` directory removed. It reports in its
    "reattach" request that it came back without any local tenant data; the storage controller
    uses that to detect the data loss and reconfigure the pageserver, which then rehydrates from
    remote storage without manual intervention.
    """
    env = neon_env_builder.init_configs()

    # Bring the services up one by one so the storage controller can be started with
    # local-disk-loss handling enabled.
    env.broker.start()
    env.storage_controller.start(handle_ps_local_disk_loss=True)
    env.pageserver.start()
    for safekeeper in env.safekeepers:
        safekeeper.start()

    # create new tenant
    tenant_id, _ = env.create_tenant(shard_count=4)

    # Populate a database with a known number of rows.
    endpoint = env.endpoints.create_start("main", tenant_id=tenant_id)
    with endpoint.cursor() as setup_cur:
        setup_cur.execute("SELECT pg_logical_emit_message(false, 'neon-test', 'between inserts')")
        setup_cur.execute("CREATE DATABASE testdb")

    with endpoint.cursor(dbname="testdb") as testdb_cur:
        testdb_cur.execute(
            "CREATE TABLE tbl_one_hundred_rows AS SELECT generate_series(1,100)"
        )
    endpoint.stop()

    # Simulate a pageserver that comes back with the same ID but an empty local disk:
    # hard-stop it, wipe the `tenants/` directory, and start it again.
    pageserver = env.pageserver
    pageserver.stop(immediate=True)
    shutil.rmtree(pageserver.tenant_dir())
    pageserver.start()

    # If the endpoint can start and read back every row, the pageserver must have
    # rehydrated the lost tenant data from remote storage on its own.
    endpoint.start()
    with endpoint.cursor(dbname="testdb") as verify_cur:
        row_count = query_scalar(verify_cur, "SELECT count(*) FROM tbl_one_hundred_rows")
        assert row_count == 100

View File

@@ -6,6 +6,7 @@ from typing import TYPE_CHECKING
import pytest
from fixtures.log_helper import log
from fixtures.metrics import parse_metrics
from fixtures.utils import USE_LFC, query_scalar
if TYPE_CHECKING:
@@ -75,10 +76,24 @@ WITH (fillfactor='100');
cur.execute("SELECT abalance FROM pgbench_accounts WHERE aid = 104242")
cur.execute("SELECT abalance FROM pgbench_accounts WHERE aid = 204242")
# verify working set size after some index access of a few select pages only
blocks = query_scalar(cur, "select approximate_working_set_size(true)")
blocks = query_scalar(cur, "select approximate_working_set_size(false)")
log.info(f"working set size after some index access of a few select pages only {blocks}")
assert blocks < 20
# Also test the metrics from the /autoscaling_metrics endpoint
autoscaling_metrics = endpoint.http_client().autoscaling_metrics()
log.debug(f"Raw metrics: {autoscaling_metrics}")
m = parse_metrics(autoscaling_metrics)
http_estimate = m.query_one(
"lfc_approximate_working_set_size_windows",
{
"duration_seconds": "60",
},
).value
log.info(f"http estimate: {http_estimate}, blocks: {blocks}")
assert http_estimate > 0 and http_estimate < 20
@pytest.mark.skipif(not USE_LFC, reason="LFC is disabled, skipping")
def test_sliding_working_set_approximation(neon_simple_env: NeonEnv):

View File

@@ -3,6 +3,7 @@
#
from __future__ import annotations
import time
from concurrent.futures import ThreadPoolExecutor
from typing import TYPE_CHECKING, Any, cast
@@ -356,6 +357,81 @@ def test_sql_regress(
post_checks(env, test_output_dir, DBNAME, endpoint)
def test_max_wal_rate(neon_simple_env: NeonEnv):
    """
    Check that the databricks.max_wal_mb_per_second GUC throttles WAL
    generation: unlimited by default, backpressure at 0 MB/s, and roughly
    1 MB per second at 1 MB/s.
    """
    env = neon_simple_env

    DBNAME = "regression"
    superuser_name = "databricks_superuser"

    # Each execution of this statement inserts one row of roughly 1 KB.
    insert_row = "INSERT INTO usertable SELECT random(), repeat('a', 1000);"
    throttling_time_query = "SELECT backpressure_throttling_time();"

    def set_max_wal_rate(mb_per_second: int) -> None:
        # Persist the new limit and make the running server pick it up.
        endpoint.safe_psql_many(
            [
                f"ALTER SYSTEM SET databricks.max_wal_mb_per_second = {mb_per_second};",
                "SELECT pg_reload_conf();",
            ]
        )

    # Start the compute, then create the role, the "regression" database and the extension.
    endpoint = env.endpoints.create_start(
        "main",
        config_lines=[
            # we need this option because default max_cluster_size < 0 will disable throttling completely
            "neon.max_cluster_size=10GB",
        ],
    )

    endpoint.safe_psql_many(
        [
            f"CREATE ROLE {superuser_name}",
            f"CREATE DATABASE {DBNAME}",
            "CREATE EXTENSION neon",
        ]
    )

    endpoint.safe_psql("CREATE TABLE usertable (YCSB_KEY INT, FIELD0 TEXT);", dbname=DBNAME)

    # Write ~1 MB data.
    with endpoint.cursor(dbname=DBNAME) as cur:
        for _ in range(1000):
            cur.execute(insert_row)

    # With no limit configured there must be no backpressure at all.
    throttled = endpoint.safe_psql(throttling_time_query)[0][0]
    assert throttled == 0, "Backpressure throttling detected"

    # 0 MB/s max_wal_rate. WAL proposer can still push some WALs but will be super slow.
    set_max_wal_rate(0)

    # Write ~10 KB data should hit backpressure.
    with endpoint.cursor(dbname=DBNAME) as cur:
        cur.execute("SET databricks.max_wal_mb_per_second = 0;")
        for _ in range(10):
            cur.execute(insert_row)

    throttled = endpoint.safe_psql(throttling_time_query)[0][0]
    assert throttled > 0, "No backpressure throttling detected"

    # 1 MB/s max_wal_rate.
    set_max_wal_rate(1)

    # Write 10 MB data and time it in whole wall-clock seconds.
    with endpoint.cursor(dbname=DBNAME) as cur:
        start = int(time.time())
        for _ in range(10000):
            cur.execute(insert_row)

        end = int(time.time())
        elapsed = end - start
        assert elapsed >= 10, (
            "Throttling should cause the previous inserts to take greater than or equal to 10 seconds"
        )
@skip_in_debug_build("only run with release build")
@pytest.mark.parametrize("reldir_type", ["v1", "v2"])
def test_tx_abort_with_many_relations(

View File

@@ -0,0 +1,137 @@
import json
import signal
import time
import requests
from fixtures.utils import skip_if_proxy_lacks_rest_broker
from jwcrypto import jwt
@skip_if_proxy_lacks_rest_broker()
def test_rest_broker_happy(
    local_proxy_fixed_port, rest_broker_proxy, vanilla_pg, neon_authorize_jwk, httpserver
):
    """
    Test REST API endpoint using local_proxy and rest_broker_proxy.

    Sets up the roles, `pgrst` configuration and a `test.items` table in
    vanilla postgres, serves the JWKS from a local HTTP server, points both
    proxies at it, and then issues an authenticated REST query through the rest
    broker expecting the single inserted row back.
    """

    # Use the fixed port local proxy
    local_proxy = local_proxy_fixed_port

    # Create the required roles for PostgREST authentication
    vanilla_pg.safe_psql("CREATE ROLE authenticator LOGIN")
    vanilla_pg.safe_psql("CREATE ROLE authenticated")
    vanilla_pg.safe_psql("CREATE ROLE anon")
    vanilla_pg.safe_psql("GRANT authenticated TO authenticator")
    vanilla_pg.safe_psql("GRANT anon TO authenticator")

    # Create the pgrst schema and configuration function required by the rest broker
    vanilla_pg.safe_psql("CREATE SCHEMA IF NOT EXISTS pgrst")
    vanilla_pg.safe_psql("""
CREATE OR REPLACE FUNCTION pgrst.pre_config()
RETURNS VOID AS $$
SELECT
set_config('pgrst.db_schemas', 'test', true)
, set_config('pgrst.db_aggregates_enabled', 'true', true)
, set_config('pgrst.db_anon_role', 'anon', true)
, set_config('pgrst.jwt_aud', '', true)
, set_config('pgrst.jwt_secret', '', true)
, set_config('pgrst.jwt_role_claim_key', '."role"', true)
$$ LANGUAGE SQL;
""")
    vanilla_pg.safe_psql("GRANT USAGE ON SCHEMA pgrst TO authenticator")
    vanilla_pg.safe_psql("GRANT EXECUTE ON ALL FUNCTIONS IN SCHEMA pgrst TO authenticator")

    # Bootstrap the database with test data
    vanilla_pg.safe_psql("CREATE SCHEMA IF NOT EXISTS test")
    vanilla_pg.safe_psql("""
CREATE TABLE IF NOT EXISTS test.items (
id SERIAL PRIMARY KEY,
name TEXT NOT NULL
)
""")
    vanilla_pg.safe_psql("INSERT INTO test.items (name) VALUES ('test_item')")

    # Grant access to the test schema for the authenticated role
    vanilla_pg.safe_psql("GRANT USAGE ON SCHEMA test TO authenticated")
    vanilla_pg.safe_psql("GRANT SELECT ON ALL TABLES IN SCHEMA test TO authenticated")

    # Set up HTTP server to serve JWKS (like static_auth_broker)
    # Generate public key from the JWK
    public_key = neon_authorize_jwk.export_public(as_dict=True)

    # Set up the httpserver to serve the JWKS
    httpserver.expect_request("/.well-known/jwks.json").respond_with_json({"keys": [public_key]})

    # Create JWKS configuration for the rest broker proxy
    jwks_config = {
        "jwks": [
            {
                "id": "1",
                "role_names": ["authenticator", "authenticated", "anon"],
                "jwks_url": httpserver.url_for("/.well-known/jwks.json"),
                "provider_name": "foo",
                "jwt_audience": None,
            }
        ]
    }

    # Write the JWKS config to the config file that rest_broker_proxy expects
    config_file = rest_broker_proxy.config_path
    with open(config_file, "w") as f:
        json.dump(jwks_config, f)

    # Write the same config to the local_proxy config file
    local_config_file = local_proxy.config_path
    with open(local_config_file, "w") as f:
        json.dump(jwks_config, f)

    # Signal both proxies to reload their config
    if rest_broker_proxy._popen is not None:
        rest_broker_proxy._popen.send_signal(signal.SIGHUP)
    if local_proxy._popen is not None:
        local_proxy._popen.send_signal(signal.SIGHUP)

    # Wait a bit for config to reload
    time.sleep(0.5)

    # Generate a proper JWT token using the JWK (similar to test_auth_broker.py)
    token = jwt.JWT(
        header={"kid": neon_authorize_jwk.key_id, "alg": "RS256"},
        claims={
            "sub": "user",
            "role": "authenticated",  # role that's in role_names
            "exp": 9999999999,  # expires far in the future
            "iat": 1000000000,  # issued at
        },
    )
    token.make_signed_token(neon_authorize_jwk)

    # Debug: Print the JWT claims and config for troubleshooting
    print(f"JWT claims: {token.claims}")
    print(f"JWT header: {token.header}")
    print(f"Config file contains: {jwks_config}")
    print(f"Public key kid: {public_key.get('kid')}")

    # Test REST API call - following SUBZERO.md pattern
    # REST API is served on the WSS port with HTTPS and includes database name
    # ep-purple-glitter-adqior4l-pooler.c-2.us-east-1.aws.neon.tech
    url = f"https://foo.apirest.c-2.local.neon.build:{rest_broker_proxy.wss_port}/postgres/rest/v1/items"

    response = requests.get(
        url,
        headers={
            "Authorization": f"Bearer {token.serialize()}",
        },
        params={"id": "eq.1", "select": "name"},
        verify=False,  # Skip SSL verification for self-signed certs
        timeout=30,  # Fail the test instead of hanging if the proxy never answers
    )

    print(f"Response status: {response.status_code}")
    print(f"Response headers: {response.headers}")
    print(f"Response body: {response.text}")

    # The happy path must succeed outright; include the body to ease debugging.
    assert response.status_code == 200, (
        f"unexpected status {response.status_code}: {response.text}"
    )

    # check the response body
    assert response.json() == [{"name": "test_item"}]

View File

@@ -3,11 +3,22 @@ from __future__ import annotations
from typing import TYPE_CHECKING
import pytest
import requests
from fixtures.log_helper import log
from fixtures.neon_fixtures import StorageControllerApiException
if TYPE_CHECKING:
from fixtures.neon_fixtures import NeonEnvBuilder
# TODO(diko): pageserver spams with various errors during safekeeper migration.
# Fix the code so it handles the migration better.
ALLOWED_PAGESERVER_ERRORS = [
".*Timeline .* was cancelled and cannot be used anymore.*",
".*Timeline .* has been deleted.*",
".*Timeline .* was not found in global map.*",
".*wal receiver task finished with an error.*",
]
def test_safekeeper_migration_simple(neon_env_builder: NeonEnvBuilder):
"""
@@ -24,16 +35,7 @@ def test_safekeeper_migration_simple(neon_env_builder: NeonEnvBuilder):
"timeline_safekeeper_count": 1,
}
env = neon_env_builder.init_start()
# TODO(diko): pageserver spams with various errors during safekeeper migration.
# Fix the code so it handles the migration better.
env.pageserver.allowed_errors.extend(
[
".*Timeline .* was cancelled and cannot be used anymore.*",
".*Timeline .* has been deleted.*",
".*Timeline .* was not found in global map.*",
".*wal receiver task finished with an error.*",
]
)
env.pageserver.allowed_errors.extend(ALLOWED_PAGESERVER_ERRORS)
ep = env.endpoints.create("main", tenant_id=env.initial_tenant)
@@ -42,15 +44,23 @@ def test_safekeeper_migration_simple(neon_env_builder: NeonEnvBuilder):
assert len(mconf["sk_set"]) == 1
assert mconf["generation"] == 1
current_sk = mconf["sk_set"][0]
ep.start(safekeeper_generation=1, safekeepers=mconf["sk_set"])
ep.safe_psql("CREATE EXTENSION neon_test_utils;")
ep.safe_psql("CREATE TABLE t(a int)")
expected_gen = 1
for active_sk in range(1, 4):
env.storage_controller.migrate_safekeepers(
env.initial_tenant, env.initial_timeline, [active_sk]
)
if active_sk != current_sk:
expected_gen += 2
current_sk = active_sk
other_sks = [sk for sk in range(1, 4) if sk != active_sk]
for sk in other_sks:
@@ -65,9 +75,6 @@ def test_safekeeper_migration_simple(neon_env_builder: NeonEnvBuilder):
assert ep.safe_psql("SELECT * FROM t") == [(i,) for i in range(1, 4)]
# 1 initial generation + 2 migrations on each loop iteration.
expected_gen = 1 + 2 * 3
mconf = env.storage_controller.timeline_locate(env.initial_tenant, env.initial_timeline)
assert mconf["generation"] == expected_gen
@@ -113,3 +120,79 @@ def test_new_sk_set_validation(neon_env_builder: NeonEnvBuilder):
env.storage_controller.safekeeper_scheduling_policy(decom_sk, "Decomissioned")
expect_fail([sk_set[0], decom_sk], "decomissioned")
def test_safekeeper_migration_common_set_failpoints(neon_env_builder: NeonEnvBuilder):
    """
    Test that safekeeper migration copes with failures.

    Two main conditions are checked:
    1. the safekeeper migration handler can be retried after different failures.
    2. writes do not get stuck if sk_set and new_sk_set have a quorum in common.
    """
    neon_env_builder.num_safekeepers = 4
    neon_env_builder.storage_controller_config = {
        "timelines_onto_safekeepers": True,
        "timeline_safekeeper_count": 3,
    }
    env = neon_env_builder.init_start()
    env.pageserver.allowed_errors.extend(ALLOWED_PAGESERVER_ERRORS)

    controller = env.storage_controller
    tenant_id = env.initial_tenant
    timeline_id = env.initial_timeline

    mconf = controller.timeline_locate(tenant_id, timeline_id)
    current_sk_set = mconf["sk_set"]
    assert len(current_sk_set) == 3
    assert mconf["generation"] == 1

    ep = env.endpoints.create("main", tenant_id=tenant_id)
    ep.start(safekeeper_generation=1, safekeepers=current_sk_set)
    ep.safe_psql("CREATE EXTENSION neon_test_utils;")
    ep.safe_psql("CREATE TABLE t(a int)")

    # Swap the last member of the current set for the one safekeeper not in it, so
    # the old and new sets share the remaining members.
    excluded_sk = current_sk_set[-1]
    unused_sk_ids = [sk.id for sk in env.safekeepers if sk.id not in current_sk_set]
    added_sk = unused_sk_ids[0]
    new_sk_set = current_sk_set[:-1] + [added_sk]
    log.info(f"migrating sk set from {mconf['sk_set']} to {new_sk_set}")

    failpoints = [
        "sk-migration-after-step-3",
        "sk-migration-after-step-4",
        "sk-migration-after-step-5",
        "sk-migration-after-step-7",
        "sk-migration-after-step-8",
        "sk-migration-step-9-after-set-membership",
        "sk-migration-step-9-mid-exclude",
        "sk-migration-step-9-after-exclude",
        "sk-migration-after-step-9",
    ]

    # Arm one failpoint at a time: the migration must fail at it, and a write must
    # still go through before the failpoint is disarmed again.
    for row_value, failpoint in enumerate(failpoints):
        controller.configure_failpoints((failpoint, "return(1)"))

        with pytest.raises(StorageControllerApiException, match=f"failpoint {failpoint}"):
            controller.migrate_safekeepers(tenant_id, timeline_id, new_sk_set)
        ep.safe_psql(f"INSERT INTO t VALUES ({row_value})")

        controller.configure_failpoints((failpoint, "off"))

    # No failpoints, migration should succeed.
    controller.migrate_safekeepers(tenant_id, timeline_id, new_sk_set)

    mconf = controller.timeline_locate(tenant_id, timeline_id)
    assert mconf["new_sk_set"] is None
    assert mconf["sk_set"] == new_sk_set
    assert mconf["generation"] == 3

    # Every row written during the failed attempts must still be readable.
    ep.clear_buffers()
    expected_rows = [(i,) for i in range(len(failpoints))]
    assert ep.safe_psql("SELECT * FROM t") == expected_rows
    assert ep.safe_psql("SHOW neon.safekeepers")[0][0].startswith("g#3:")

    # Check that we didn't forget to remove the timeline on the excluded safekeeper.
    excluded_client = env.safekeepers[excluded_sk - 1].http_client()
    with pytest.raises(requests.exceptions.HTTPError) as exc:
        excluded_client.timeline_status(tenant_id, timeline_id)
    error_response = exc.value.response
    assert error_response.status_code == 404
    assert f"timeline {tenant_id}/{timeline_id} deleted" in error_response.text

View File

@@ -1810,6 +1810,8 @@ def test_sharding_backpressure(neon_env_builder: NeonEnvBuilder):
"config_lines": [
# Tip: set to 100MB to make the test fail
"max_replication_write_lag=1MB",
# Hadron: Need to set max_cluster_size to some value to enable any backpressure at all.
"neon.max_cluster_size=1GB",
],
# We need `neon` extension for calling backpressure functions,
# this flag instructs `compute_ctl` to pre-install it.