From 067f80e8db55c14e3f00012e9b282e13cf0758a8 Mon Sep 17 00:00:00 2001 From: Ivan Efremov Date: Mon, 12 May 2025 17:48:06 +0300 Subject: [PATCH] Fix regress tests: set NEON_INTERNAL_CA_FILE --- test_runner/fixtures/neon_fixtures.py | 103 ++++++++++++++---- test_runner/regress/test_import_pgdata.py | 6 +- .../regress/test_proxy_metric_collection.py | 41 ++++--- 3 files changed, 111 insertions(+), 39 deletions(-) diff --git a/test_runner/fixtures/neon_fixtures.py b/test_runner/fixtures/neon_fixtures.py index 2801a0e867..215d98fd9e 100644 --- a/test_runner/fixtures/neon_fixtures.py +++ b/test_runner/fixtures/neon_fixtures.py @@ -15,7 +15,7 @@ import time import uuid from collections import defaultdict from collections.abc import Mapping -from contextlib import closing, contextmanager +from contextlib import AbstractContextManager, closing, contextmanager, nullcontext from dataclasses import dataclass from datetime import datetime from enum import StrEnum @@ -3290,17 +3290,35 @@ def pg_bin(test_output_dir: Path, pg_distrib_dir: Path, pg_version: PgVersion) - # TODO make port an optional argument class VanillaPostgres(PgProtocol): - def __init__(self, pgdatadir: Path, pg_bin: PgBin, port: int, init: bool = True): + cert_path: Path | None + + def __init__( + self, + pgdatadir: Path, + pg_bin: PgBin, + port: int, + cert_data_dir: Path | None = None, + init: bool = True, + ): super().__init__(host="localhost", port=port, dbname="postgres") self.pgdatadir = pgdatadir + self.cert_path = cert_data_dir self.pg_bin = pg_bin self.running = False if init: self.pg_bin.run_capture(["initdb", "--pgdata", str(pgdatadir)]) self.configure([f"port = {port}\n"]) + if cert_data_dir is not None: + self.create_certs() + + def create_certs(self): + if self.cert_path is None: + raise ValueError("cert_path for VanillaPostgres is not set") + + cert_file = self.cert_path / "server.crt" + key_file = self.cert_path / "server.key" + common_name = "*.local.neon.build" - def enable_tls(self): - assert not self.running # generate self-signed certificate subprocess.run( [ @@ -3313,19 +3331,37 @@ class VanillaPostgres(PgProtocol): "-nodes", "-text", "-out", - self.pgdatadir / "server.crt", + cert_file, "-keyout", - self.pgdatadir / "server.key", + key_file, "-subj", - "/CN=localhost", + f"/CN={common_name}", + "-addext", + "basicConstraints=critical,CA:FALSE", + "-addext", + "keyUsage=digitalSignature,keyEncipherment", + "-addext", + "extendedKeyUsage=serverAuth", + "-addext", + "subjectAltName=DNS:localhost,IP:127.0.0.1", ] ) + + def enable_tls(self): + assert not self.running + + if self.cert_path is None: + raise ValueError("cert_path for VanillaPostgres is not set") + + cert_file = self.cert_path / "server.crt" + key_file = self.cert_path / "server.key" + # configure postgresql.conf self.configure( [ "ssl = on", - "ssl_cert_file = 'server.crt'", - "ssl_key_file = 'server.key'", + f"ssl_cert_file = '{cert_file}'", + f"ssl_key_file = '{key_file}'", ] ) @@ -3390,7 +3426,7 @@ def vanilla_pg( pgdatadir = test_output_dir / "pgdata-vanilla" pg_bin = PgBin(test_output_dir, pg_distrib_dir, pg_version) port = port_distributor.get_port() - with VanillaPostgres(pgdatadir, pg_bin, port) as vanilla_pg: + with VanillaPostgres(pgdatadir, pg_bin, port, test_output_dir) as vanilla_pg: vanilla_pg.configure(["shared_preload_libraries='neon_rmgr'"]) yield vanilla_pg @@ -3532,7 +3568,7 @@ def generate_proxy_tls_certs(common_name: str, key_path: Path, crt_path: Path): "-subj", f"/CN={common_name}", "-addext", - f"subjectAltName = DNS:{common_name}", + f"subjectAltName=DNS:localhost,IP:127.0.0.1,DNS:{common_name}", ] ) assert r.returncode == 0 @@ -3952,6 +3988,19 @@ class NeonAuthBroker: self._popen.kill() +@contextmanager +def temp_env(key: str, val: str): + prev = os.environ.get(key) + os.environ[key] = val + try: + yield + finally: + if prev is None: + del os.environ[key] + else: + os.environ[key] = prev + + @pytest.fixture(scope="function") def link_proxy( port_distributor: PortDistributor, neon_binpath: Path, test_output_dir: Path @@ -3964,18 +4013,28 @@ def link_proxy( external_http_port = port_distributor.get_port() router_port = port_distributor.get_port() router_tls_port = port_distributor.get_port() + compute_cert_path = test_output_dir / "server.crt" - with NeonProxy( - neon_binpath=neon_binpath, - test_output_dir=test_output_dir, - proxy_port=proxy_port, - http_port=http_port, - mgmt_port=mgmt_port, - router_port=router_port, - router_tls_port=router_tls_port, - external_http_port=external_http_port, - auth_backend=NeonProxy.Link(), - ) as proxy: + ca_ctx: AbstractContextManager[None] + if compute_cert_path.exists(): + ca_ctx = temp_env("NEON_INTERNAL_CA_FILE", str(compute_cert_path)) + else: + ca_ctx = nullcontext() + + with ( + ca_ctx, + NeonProxy( + neon_binpath=neon_binpath, + test_output_dir=test_output_dir, + proxy_port=proxy_port, + http_port=http_port, + mgmt_port=mgmt_port, + router_port=router_port, + router_tls_port=router_tls_port, + external_http_port=external_http_port, + auth_backend=NeonProxy.Link(), + ) as proxy, + ): proxy.start() yield proxy diff --git a/test_runner/regress/test_import_pgdata.py b/test_runner/regress/test_import_pgdata.py index 05e63ad955..5daea26ec7 100644 --- a/test_runner/regress/test_import_pgdata.py +++ b/test_runner/regress/test_import_pgdata.py @@ -546,7 +546,7 @@ def test_fast_import_with_pageserver_ingest( # Sanity check that data in pgdata is expected: pgbin = PgBin(test_output_dir, fast_import.pg_distrib_dir, fast_import.pg_version) with VanillaPostgres( - fast_import.workdir / "pgdata", pgbin, pg_port, False + fast_import.workdir / "pgdata", pgbin, pg_port, None, False ) as new_pgdata_vanilla_pg: new_pgdata_vanilla_pg.start() @@ -629,7 +629,7 @@ def test_fast_import_binary( pgbin = PgBin(test_output_dir, fast_import.pg_distrib_dir, fast_import.pg_version) with VanillaPostgres( - fast_import.workdir / "pgdata", pgbin, pg_port, False + fast_import.workdir / "pgdata", pgbin, pg_port, None, False ) as new_pgdata_vanilla_pg: new_pgdata_vanilla_pg.start() @@ -678,7 +678,7 @@ def test_fast_import_event_triggers( pgbin = PgBin(test_output_dir, fast_import.pg_distrib_dir, fast_import.pg_version) with VanillaPostgres( - fast_import.workdir / "pgdata", pgbin, pg_port, False + fast_import.workdir / "pgdata", pgbin, pg_port, None, False ) as new_pgdata_vanilla_pg: new_pgdata_vanilla_pg.start() diff --git a/test_runner/regress/test_proxy_metric_collection.py b/test_runner/regress/test_proxy_metric_collection.py index 7442d50f68..561da81bf0 100644 --- a/test_runner/regress/test_proxy_metric_collection.py +++ b/test_runner/regress/test_proxy_metric_collection.py @@ -1,5 +1,6 @@ from __future__ import annotations +from contextlib import AbstractContextManager, nullcontext from typing import TYPE_CHECKING import pytest @@ -8,6 +9,7 @@ from fixtures.neon_fixtures import ( PSQL, NeonProxy, VanillaPostgres, + temp_env, ) from werkzeug.wrappers.response import Response @@ -59,19 +61,30 @@ def proxy_with_metric_collector( metric_collection_endpoint = f"http://{host}:{port}/billing/api/v1/usage_events" metric_collection_interval = "5s" - with NeonProxy( - neon_binpath=neon_binpath, - test_output_dir=test_output_dir, - proxy_port=proxy_port, - http_port=http_port, - mgmt_port=mgmt_port, - router_port=router_port, - router_tls_port=router_tls_port, - external_http_port=external_http_port, - metric_collection_endpoint=metric_collection_endpoint, - metric_collection_interval=metric_collection_interval, - auth_backend=NeonProxy.Link(), - ) as proxy: + compute_cert_path = test_output_dir / "server.crt" + + ca_ctx: AbstractContextManager[None] + if compute_cert_path.exists(): + ca_ctx = temp_env("NEON_INTERNAL_CA_FILE", str(compute_cert_path)) + else: + ca_ctx = nullcontext() + + with ( + ca_ctx, + NeonProxy( + neon_binpath=neon_binpath, + test_output_dir=test_output_dir, + proxy_port=proxy_port, + http_port=http_port, + mgmt_port=mgmt_port, + router_port=router_port, + router_tls_port=router_tls_port, + external_http_port=external_http_port, + metric_collection_endpoint=metric_collection_endpoint, + metric_collection_interval=metric_collection_interval, + auth_backend=NeonProxy.Link(), + ) as proxy, + ): proxy.start() yield proxy @@ -79,8 +92,8 @@ def proxy_with_metric_collector( @pytest.mark.asyncio async def test_proxy_metric_collection( httpserver: HTTPServer, + vanilla_pg: VanillaPostgres, # we should create compute certificates first proxy_with_metric_collector: NeonProxy, - vanilla_pg: VanillaPostgres, ): # mock http server that returns OK for the metrics httpserver.expect_request("/billing/api/v1/usage_events", method="POST").respond_with_handler(