Fix regress tests: set NEON_INTERNAL_CA_FILE

This commit is contained in:
Ivan Efremov
2025-05-12 17:48:06 +03:00
parent 8c2ba51dce
commit 067f80e8db
3 changed files with 111 additions and 39 deletions

View File

@@ -15,7 +15,7 @@ import time
import uuid
from collections import defaultdict
from collections.abc import Mapping
from contextlib import closing, contextmanager
from contextlib import AbstractContextManager, closing, contextmanager, nullcontext
from dataclasses import dataclass
from datetime import datetime
from enum import StrEnum
@@ -3290,17 +3290,35 @@ def pg_bin(test_output_dir: Path, pg_distrib_dir: Path, pg_version: PgVersion) -
# TODO make port an optional argument
class VanillaPostgres(PgProtocol):
def __init__(self, pgdatadir: Path, pg_bin: PgBin, port: int, init: bool = True):
cert_path: Path | None
def __init__(
self,
pgdatadir: Path,
pg_bin: PgBin,
port: int,
cert_data_dir: Path | None = None,
init: bool = True,
):
super().__init__(host="localhost", port=port, dbname="postgres")
self.pgdatadir = pgdatadir
self.cert_path = cert_data_dir
self.pg_bin = pg_bin
self.running = False
if init:
self.pg_bin.run_capture(["initdb", "--pgdata", str(pgdatadir)])
self.configure([f"port = {port}\n"])
if cert_data_dir is not None:
self.create_certs()
def create_certs(self):
if self.cert_path is None:
raise ValueError("cert_path for VanillaPostgres is not set")
cert_file = self.cert_path / "server.crt"
key_file = self.cert_path / "server.key"
common_name = "*.local.neon.build"
def enable_tls(self):
assert not self.running
# generate self-signed certificate
subprocess.run(
[
@@ -3313,19 +3331,37 @@ class VanillaPostgres(PgProtocol):
"-nodes",
"-text",
"-out",
self.pgdatadir / "server.crt",
cert_file,
"-keyout",
self.pgdatadir / "server.key",
key_file,
"-subj",
"/CN=localhost",
f"/CN={common_name}",
"-addext",
"basicConstraints=critical,CA:FALSE",
"-addext",
"keyUsage=digitalSignature,keyEncipherment",
"-addext",
"extendedKeyUsage=serverAuth",
"-addext",
"subjectAltName=DNS:localhost,IP:127.0.0.1",
]
)
def enable_tls(self):
assert not self.running
if self.cert_path is None:
raise ValueError("cert_path for VanillaPostgres is not set")
cert_file = self.cert_path / "server.crt"
key_file = self.cert_path / "server.key"
# configure postgresql.conf
self.configure(
[
"ssl = on",
"ssl_cert_file = 'server.crt'",
"ssl_key_file = 'server.key'",
f"ssl_cert_file = '{cert_file}'",
f"ssl_key_file = '{key_file}'",
]
)
@@ -3390,7 +3426,7 @@ def vanilla_pg(
pgdatadir = test_output_dir / "pgdata-vanilla"
pg_bin = PgBin(test_output_dir, pg_distrib_dir, pg_version)
port = port_distributor.get_port()
with VanillaPostgres(pgdatadir, pg_bin, port) as vanilla_pg:
with VanillaPostgres(pgdatadir, pg_bin, port, test_output_dir) as vanilla_pg:
vanilla_pg.configure(["shared_preload_libraries='neon_rmgr'"])
yield vanilla_pg
@@ -3532,7 +3568,7 @@ def generate_proxy_tls_certs(common_name: str, key_path: Path, crt_path: Path):
"-subj",
f"/CN={common_name}",
"-addext",
f"subjectAltName = DNS:{common_name}",
f"subjectAltName=DNS:localhost,IP:127.0.0.1,DNS:{common_name}",
]
)
assert r.returncode == 0
@@ -3952,6 +3988,19 @@ class NeonAuthBroker:
self._popen.kill()
@contextmanager
def temp_env(key: str, val: str):
prev = os.environ.get(key)
os.environ[key] = val
try:
yield
finally:
if prev is None:
del os.environ[key]
else:
os.environ[key] = prev
@pytest.fixture(scope="function")
def link_proxy(
port_distributor: PortDistributor, neon_binpath: Path, test_output_dir: Path
@@ -3964,18 +4013,28 @@ def link_proxy(
external_http_port = port_distributor.get_port()
router_port = port_distributor.get_port()
router_tls_port = port_distributor.get_port()
compute_cert_path = test_output_dir / "server.crt"
with NeonProxy(
neon_binpath=neon_binpath,
test_output_dir=test_output_dir,
proxy_port=proxy_port,
http_port=http_port,
mgmt_port=mgmt_port,
router_port=router_port,
router_tls_port=router_tls_port,
external_http_port=external_http_port,
auth_backend=NeonProxy.Link(),
) as proxy:
ca_ctx: AbstractContextManager[None]
if compute_cert_path.exists():
ca_ctx = temp_env("NEON_INTERNAL_CA_FILE", str(compute_cert_path))
else:
ca_ctx = nullcontext()
with (
ca_ctx,
NeonProxy(
neon_binpath=neon_binpath,
test_output_dir=test_output_dir,
proxy_port=proxy_port,
http_port=http_port,
mgmt_port=mgmt_port,
router_port=router_port,
router_tls_port=router_tls_port,
external_http_port=external_http_port,
auth_backend=NeonProxy.Link(),
) as proxy,
):
proxy.start()
yield proxy

View File

@@ -546,7 +546,7 @@ def test_fast_import_with_pageserver_ingest(
# Sanity check that data in pgdata is expected:
pgbin = PgBin(test_output_dir, fast_import.pg_distrib_dir, fast_import.pg_version)
with VanillaPostgres(
fast_import.workdir / "pgdata", pgbin, pg_port, False
fast_import.workdir / "pgdata", pgbin, pg_port, None, False
) as new_pgdata_vanilla_pg:
new_pgdata_vanilla_pg.start()
@@ -629,7 +629,7 @@ def test_fast_import_binary(
pgbin = PgBin(test_output_dir, fast_import.pg_distrib_dir, fast_import.pg_version)
with VanillaPostgres(
fast_import.workdir / "pgdata", pgbin, pg_port, False
fast_import.workdir / "pgdata", pgbin, pg_port, None, False
) as new_pgdata_vanilla_pg:
new_pgdata_vanilla_pg.start()
@@ -678,7 +678,7 @@ def test_fast_import_event_triggers(
pgbin = PgBin(test_output_dir, fast_import.pg_distrib_dir, fast_import.pg_version)
with VanillaPostgres(
fast_import.workdir / "pgdata", pgbin, pg_port, False
fast_import.workdir / "pgdata", pgbin, pg_port, None, False
) as new_pgdata_vanilla_pg:
new_pgdata_vanilla_pg.start()

View File

@@ -1,5 +1,6 @@
from __future__ import annotations
from contextlib import AbstractContextManager, nullcontext
from typing import TYPE_CHECKING
import pytest
@@ -8,6 +9,7 @@ from fixtures.neon_fixtures import (
PSQL,
NeonProxy,
VanillaPostgres,
temp_env,
)
from werkzeug.wrappers.response import Response
@@ -59,19 +61,30 @@ def proxy_with_metric_collector(
metric_collection_endpoint = f"http://{host}:{port}/billing/api/v1/usage_events"
metric_collection_interval = "5s"
with NeonProxy(
neon_binpath=neon_binpath,
test_output_dir=test_output_dir,
proxy_port=proxy_port,
http_port=http_port,
mgmt_port=mgmt_port,
router_port=router_port,
router_tls_port=router_tls_port,
external_http_port=external_http_port,
metric_collection_endpoint=metric_collection_endpoint,
metric_collection_interval=metric_collection_interval,
auth_backend=NeonProxy.Link(),
) as proxy:
compute_cert_path = test_output_dir / "server.crt"
ca_ctx: AbstractContextManager[None]
if compute_cert_path.exists():
ca_ctx = temp_env("NEON_INTERNAL_CA_FILE", str(compute_cert_path))
else:
ca_ctx = nullcontext()
with (
ca_ctx,
NeonProxy(
neon_binpath=neon_binpath,
test_output_dir=test_output_dir,
proxy_port=proxy_port,
http_port=http_port,
mgmt_port=mgmt_port,
router_port=router_port,
router_tls_port=router_tls_port,
external_http_port=external_http_port,
metric_collection_endpoint=metric_collection_endpoint,
metric_collection_interval=metric_collection_interval,
auth_backend=NeonProxy.Link(),
) as proxy,
):
proxy.start()
yield proxy
@@ -79,8 +92,8 @@ def proxy_with_metric_collector(
@pytest.mark.asyncio
async def test_proxy_metric_collection(
httpserver: HTTPServer,
vanilla_pg: VanillaPostgres, # we should create compute certificates first
proxy_with_metric_collector: NeonProxy,
vanilla_pg: VanillaPostgres,
):
# mock http server that returns OK for the metrics
httpserver.expect_request("/billing/api/v1/usage_events", method="POST").respond_with_handler(