diff --git a/test_runner/regress/test_sni_router.py b/test_runner/regress/test_sni_router.py new file mode 100644 index 0000000000..334b587c38 --- /dev/null +++ b/test_runner/regress/test_sni_router.py @@ -0,0 +1,139 @@ +import socket +import subprocess +from pathlib import Path +from types import TracebackType +from typing import Any, Dict, Iterator, List, Optional, Tuple, Type, Union, cast + +import backoff # type: ignore +import psycopg2 +import pytest +from fixtures.log_helper import log +from fixtures.neon_fixtures import PgProtocol, PortDistributor, VanillaPostgres + + +def generate_tls_cert(cn, certout, keyout): + subprocess.run( + [ + "openssl", + "req", + "-new", + "-x509", + "-days", + "365", + "-nodes", + "-out", + certout, + "-keyout", + keyout, + "-subj", + f"/CN={cn}", + ] + ) + + +class PgSniRouter(PgProtocol): + def __init__( + self, + neon_binpath: Path, + port: int, + destination: str, + destination_port: int, + tls_cert: Path, + tls_key: Path, + ): + # Must use a hostname rather than IP here, for SNI to work + host = "localhost" + super().__init__(host=host, port=port) + + self.host = host + self.neon_binpath = neon_binpath + self.port = port + self.destination = destination + self.destination_port = destination_port + self.tls_cert = tls_cert + self.tls_key = tls_key + self._popen: Optional[subprocess.Popen[bytes]] = None + + def start(self) -> "PgSniRouter": + assert self._popen is None + args = [ + str(self.neon_binpath / "pg_sni_router"), + *["--listen", f"127.0.0.1:{self.port}"], + *["--tls-cert", self.tls_cert], + *["--tls-key", self.tls_key], + *["--destination", self.destination], + *["--destination-port", str(self.destination_port)], + ] + + self._popen = subprocess.Popen(args) + self._wait_until_ready() + return self + + @backoff.on_exception(backoff.expo, OSError, max_time=10) + def _wait_until_ready(self): + socket.create_connection((self.host, self.port)) + + # Sends SIGTERM to the proxy if it has been started + def terminate(self): + if self._popen: + self._popen.terminate() + + # Waits for proxy to exit if it has been opened with a default timeout of + # two seconds. Raises subprocess.TimeoutExpired if the proxy does not exit in time. + def wait_for_exit(self, timeout=2): + if self._popen: + self._popen.wait(timeout=2) + + def __enter__(self) -> "PgSniRouter": + return self + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc: Optional[BaseException], + tb: Optional[TracebackType], + ): + if self._popen is not None: + self._popen.terminate() + try: + self._popen.wait(timeout=5) + except subprocess.TimeoutExpired: + log.warn("failed to gracefully terminate pg_sni_router; killing") + self._popen.kill() + + +def test_pg_sni_router( + vanilla_pg: VanillaPostgres, + port_distributor: PortDistributor, + neon_binpath: Path, + test_output_dir: Path, +): + + generate_tls_cert( + "external.test", test_output_dir / "router.crt", test_output_dir / "router.key" + ) + + # Start a stand-alone Postgres to test with + vanilla_pg.start() + pg_port = vanilla_pg.default_options["port"] + + router_port = port_distributor.get_port() + + with PgSniRouter( + neon_binpath=neon_binpath, + port=router_port, + destination="localhost", + destination_port=pg_port, + tls_cert=test_output_dir / "router.crt", + tls_key=test_output_dir / "router.key", + ) as router: + router.start() + + out = router.safe_psql( + "select 1", + dbname="postgres", + sslmode="require", + host="localhost.external.test", + hostaddr="127.0.0.1", + ) + assert out[0][0] == 1