diff --git a/test_runner/fixtures/neon_fixtures.py b/test_runner/fixtures/neon_fixtures.py index 8c1f252e0b..8b10bee5a3 100644 --- a/test_runner/fixtures/neon_fixtures.py +++ b/test_runner/fixtures/neon_fixtures.py @@ -3095,6 +3095,31 @@ class PSQL: ) +def generate_proxy_tls_certs(common_name: str, key_path: Path, crt_path: Path): + if not key_path.exists(): + r = subprocess.run( + [ + "openssl", + "req", + "-new", + "-x509", + "-days", + "365", + "-nodes", + "-text", + "-out", + str(crt_path), + "-keyout", + str(key_path), + "-subj", + f"/CN={common_name}", + "-addext", + f"subjectAltName = DNS:{common_name}", + ] + ) + assert r.returncode == 0 + + class NeonProxy(PgProtocol): link_auth_uri: str = "http://dummy-uri" @@ -3142,20 +3167,6 @@ class NeonProxy(PgProtocol): ] return args - class AuthBroker(AuthBackend): - def __init__(self, endpoint: str): - self.endpoint = endpoint - - def extra_args(self) -> list[str]: - args = [ - # Console auth backend params - *["--auth-backend", "console"], - *["--auth-endpoint", self.endpoint], - *["--sql-over-http-pool-opt-in", "false"], - *["--is-auth-broker", "true"], - ] - return args - @dataclass(frozen=True) class Postgres(AuthBackend): pg_conn_url: str @@ -3207,29 +3218,7 @@ class NeonProxy(PgProtocol): # generate key of it doesn't exist crt_path = self.test_output_dir / "proxy.crt" key_path = self.test_output_dir / "proxy.key" - - if not key_path.exists(): - r = subprocess.run( - [ - "openssl", - "req", - "-new", - "-x509", - "-days", - "365", - "-nodes", - "-text", - "-out", - str(crt_path), - "-keyout", - str(key_path), - "-subj", - "/CN=*.localtest.me", - "-addext", - "subjectAltName = DNS:*.localtest.me", - ] - ) - assert r.returncode == 0 + generate_proxy_tls_certs("*.localtest.me", key_path, crt_path) args = [ str(self.neon_binpath / "proxy"), @@ -3328,29 +3317,6 @@ class NeonProxy(PgProtocol): assert response.status_code == expected_code, f"response: {response.json()}" return response.json() - async def auth_broker_query(self, query, args, **kwargs): - # TODO maybe use default values if not provided - user = kwargs["user"] - token = kwargs["token"] - expected_code = kwargs.get("expected_code") - - log.info(f"Executing http query: {query}") - - connstr = f"postgresql://{user}@{self.domain}:{self.proxy_port}/postgres" - async with httpx.AsyncClient(verify=str(self.test_output_dir / "proxy.crt")) as client: - response = await client.post( - f"https://{self.domain}:{self.external_http_port}/sql", - json={"query": query, "params": args}, - headers={ - "Neon-Connection-String": connstr, - "Authorization": f"Bearer {token}", - }, - ) - - if expected_code is not None: - assert response.status_code == expected_code, f"response: {response.json()}" - return response.json() - def get_metrics(self) -> str: request_result = requests.get(f"http://{self.host}:{self.http_port}/metrics") return request_result.text @@ -3432,6 +3398,125 @@ class NeonProxy(PgProtocol): assert out == "ok" +class NeonAuthBroker: + class ControlPlane: + def __init__(self, endpoint: str): + self.endpoint = endpoint + + def extra_args(self) -> list[str]: + args = [ + *["--auth-backend", "console"], + *["--auth-endpoint", self.endpoint], + ] + return args + + def __init__( + self, + neon_binpath: Path, + test_output_dir: Path, + http_port: int, + mgmt_port: int, + external_http_port: int, + auth_backend: NeonAuthBroker.ControlPlane, + ): + self.domain = "apiauth.localtest.me" # resolves to 127.0.0.1 + self.host = "127.0.0.1" + self.http_port = http_port + self.external_http_port = external_http_port + self.neon_binpath = neon_binpath + self.test_output_dir = test_output_dir + self.mgmt_port = mgmt_port + self.auth_backend = auth_backend + self.http_timeout_seconds = 15 + self._popen: Optional[subprocess.Popen[bytes]] = None + + def start(self) -> NeonAuthBroker: + assert self._popen is None + + # generate key of it doesn't exist + crt_path = self.test_output_dir / "proxy.crt" + key_path = self.test_output_dir / "proxy.key" + generate_proxy_tls_certs("apiauth.localtest.me", key_path, crt_path) + + args = [ + str(self.neon_binpath / "proxy"), + *["--http", f"{self.host}:{self.http_port}"], + *["--mgmt", f"{self.host}:{self.mgmt_port}"], + *["--wss", f"{self.host}:{self.external_http_port}"], + *["-c", str(crt_path)], + *["-k", str(key_path)], + *["--sql-over-http-pool-opt-in", "false"], + *["--is-auth-broker", "true"], + *self.auth_backend.extra_args(), + ] + + logfile = open(self.test_output_dir / "proxy.log", "w") + self._popen = subprocess.Popen(args, stdout=logfile, stderr=logfile) + self._wait_until_ready() + return self + + # Sends SIGTERM to the proxy if it has been started + def terminate(self): + if self._popen: + self._popen.terminate() + + # Waits for proxy to exit if it has been opened with a default timeout of + # two seconds. Raises subprocess.TimeoutExpired if the proxy does not exit in time. + def wait_for_exit(self, timeout=2): + if self._popen: + self._popen.wait(timeout=timeout) + + @backoff.on_exception(backoff.expo, requests.exceptions.RequestException, max_time=10) + def _wait_until_ready(self): + assert ( + self._popen and self._popen.poll() is None + ), "Proxy exited unexpectedly. Check test log." + requests.get(f"http://{self.host}:{self.http_port}/v1/status") + + async def query(self, query, args, **kwargs): + user = kwargs["user"] + token = kwargs["token"] + expected_code = kwargs.get("expected_code") + + log.info(f"Executing http query: {query}") + + connstr = f"postgresql://{user}@{self.domain}/postgres" + async with httpx.AsyncClient(verify=str(self.test_output_dir / "proxy.crt")) as client: + response = await client.post( + f"https://{self.domain}:{self.external_http_port}/sql", + json={"query": query, "params": args}, + headers={ + "Neon-Connection-String": connstr, + "Authorization": f"Bearer {token}", + }, + ) + + if expected_code is not None: + assert response.status_code == expected_code, f"response: {response.json()}" + return response.json() + + def get_metrics(self) -> str: + request_result = requests.get(f"http://{self.host}:{self.http_port}/metrics") + return request_result.text + + def __enter__(self) -> NeonAuthBroker: + return self + + def __exit__( + self, + _exc_type: Optional[type[BaseException]], + _exc_value: Optional[BaseException], + _traceback: Optional[TracebackType], + ): + if self._popen is not None: + self._popen.terminate() + try: + self._popen.wait(timeout=5) + except subprocess.TimeoutExpired: + log.warning("failed to gracefully terminate proxy; killing") + self._popen.kill() + + @pytest.fixture(scope="function") def link_proxy( port_distributor: PortDistributor, neon_binpath: Path, test_output_dir: Path @@ -3497,10 +3582,10 @@ def static_proxy( @pytest.fixture(scope="function") -def neon_authorize_jwk() -> Iterator[jwk.JWK]: +def neon_authorize_jwk() -> jwk.JWK: kid = str(uuid.uuid4()) key = jwk.JWK.generate(kty="RSA", size=2048, alg="RS256", use="sig", kid=kid) - yield key + return key @pytest.fixture(scope="function") @@ -3511,14 +3596,12 @@ def static_auth_broker( httpserver: HTTPServer, neon_authorize_jwk: jwk.JWK, http2_echoserver: H2Server, -) -> Iterable[NeonProxy]: - """Neon proxy that routes directly to vanilla postgres.""" +) -> Iterable[NeonAuthBroker]: + """Neon Auth Broker that routes to a mocked local_proxy and a mocked cplane HTTP API.""" - # local_proxy_endpoint = httpserver.url_for("/sql") - # local_proxy_addr = local_proxy_endpoint.removeprefix("http://").removesuffix("/sql") local_proxy_addr = f"{http2_echoserver.host}:{http2_echoserver.port}" - log.info(f"local_proxy {local_proxy_addr}") + # return local_proxy addr on ProxyWakeCompute. httpserver.expect_request("/cplane/proxy_wake_compute").respond_with_json( { "address": local_proxy_addr, @@ -3529,6 +3612,8 @@ def static_auth_broker( }, } ) + + # return jwks mock addr on GetEndpointJwks httpserver.expect_request(re.compile("^/cplane/endpoints/.+/jwks$")).respond_with_json( { "jwks": [ @@ -3542,24 +3627,22 @@ def static_auth_broker( ] } ) - auth_endpoint = httpserver.url_for("/cplane") + # return static fixture jwks. jwk = neon_authorize_jwk.export_public(as_dict=True) httpserver.expect_request("/authorize/jwks.json").respond_with_json({"keys": [jwk]}) - proxy_port = port_distributor.get_port() mgmt_port = port_distributor.get_port() http_port = port_distributor.get_port() external_http_port = port_distributor.get_port() - with NeonProxy( + with NeonAuthBroker( neon_binpath=neon_binpath, test_output_dir=test_output_dir, - proxy_port=proxy_port, http_port=http_port, mgmt_port=mgmt_port, external_http_port=external_http_port, - auth_backend=NeonProxy.AuthBroker(auth_endpoint), + auth_backend=NeonAuthBroker.ControlPlane(httpserver.url_for("/cplane")), ) as proxy: proxy.start() yield proxy diff --git a/test_runner/regress/test_auth_broker.py b/test_runner/regress/test_auth_broker.py index de7964965c..11dc7d56b5 100644 --- a/test_runner/regress/test_auth_broker.py +++ b/test_runner/regress/test_auth_broker.py @@ -1,13 +1,13 @@ import json import pytest -from fixtures.neon_fixtures import NeonProxy +from fixtures.neon_fixtures import NeonAuthBroker from jwcrypto import jwk, jwt @pytest.mark.asyncio async def test_auth_broker_happy( - static_auth_broker: NeonProxy, + static_auth_broker: NeonAuthBroker, neon_authorize_jwk: jwk.JWK, ): """ @@ -18,9 +18,7 @@ async def test_auth_broker_happy( header={"kid": neon_authorize_jwk.key_id, "alg": "RS256"}, claims={"sub": "user1"} ) token.make_signed_token(neon_authorize_jwk) - res = await static_auth_broker.auth_broker_query( - "foo", ["arg1"], user="anonymous", token=token.serialize() - ) + res = await static_auth_broker.query("foo", ["arg1"], user="anonymous", token=token.serialize()) # local proxy mock just echos back the request # check that we forward the correct data