mirror of
https://github.com/neondatabase/neon.git
synced 2026-05-26 09:30:37 +00:00
slight refactor
This commit is contained in:
@@ -3095,6 +3095,31 @@ class PSQL:
|
||||
)
|
||||
|
||||
|
||||
def generate_proxy_tls_certs(common_name: str, key_path: Path, crt_path: Path):
|
||||
if not key_path.exists():
|
||||
r = subprocess.run(
|
||||
[
|
||||
"openssl",
|
||||
"req",
|
||||
"-new",
|
||||
"-x509",
|
||||
"-days",
|
||||
"365",
|
||||
"-nodes",
|
||||
"-text",
|
||||
"-out",
|
||||
str(crt_path),
|
||||
"-keyout",
|
||||
str(key_path),
|
||||
"-subj",
|
||||
f"/CN={common_name}",
|
||||
"-addext",
|
||||
f"subjectAltName = DNS:{common_name}",
|
||||
]
|
||||
)
|
||||
assert r.returncode == 0
|
||||
|
||||
|
||||
class NeonProxy(PgProtocol):
|
||||
link_auth_uri: str = "http://dummy-uri"
|
||||
|
||||
@@ -3142,20 +3167,6 @@ class NeonProxy(PgProtocol):
|
||||
]
|
||||
return args
|
||||
|
||||
class AuthBroker(AuthBackend):
|
||||
def __init__(self, endpoint: str):
|
||||
self.endpoint = endpoint
|
||||
|
||||
def extra_args(self) -> list[str]:
|
||||
args = [
|
||||
# Console auth backend params
|
||||
*["--auth-backend", "console"],
|
||||
*["--auth-endpoint", self.endpoint],
|
||||
*["--sql-over-http-pool-opt-in", "false"],
|
||||
*["--is-auth-broker", "true"],
|
||||
]
|
||||
return args
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Postgres(AuthBackend):
|
||||
pg_conn_url: str
|
||||
@@ -3207,29 +3218,7 @@ class NeonProxy(PgProtocol):
|
||||
# generate key of it doesn't exist
|
||||
crt_path = self.test_output_dir / "proxy.crt"
|
||||
key_path = self.test_output_dir / "proxy.key"
|
||||
|
||||
if not key_path.exists():
|
||||
r = subprocess.run(
|
||||
[
|
||||
"openssl",
|
||||
"req",
|
||||
"-new",
|
||||
"-x509",
|
||||
"-days",
|
||||
"365",
|
||||
"-nodes",
|
||||
"-text",
|
||||
"-out",
|
||||
str(crt_path),
|
||||
"-keyout",
|
||||
str(key_path),
|
||||
"-subj",
|
||||
"/CN=*.localtest.me",
|
||||
"-addext",
|
||||
"subjectAltName = DNS:*.localtest.me",
|
||||
]
|
||||
)
|
||||
assert r.returncode == 0
|
||||
generate_proxy_tls_certs("*.localtest.me", key_path, crt_path)
|
||||
|
||||
args = [
|
||||
str(self.neon_binpath / "proxy"),
|
||||
@@ -3328,29 +3317,6 @@ class NeonProxy(PgProtocol):
|
||||
assert response.status_code == expected_code, f"response: {response.json()}"
|
||||
return response.json()
|
||||
|
||||
async def auth_broker_query(self, query, args, **kwargs):
|
||||
# TODO maybe use default values if not provided
|
||||
user = kwargs["user"]
|
||||
token = kwargs["token"]
|
||||
expected_code = kwargs.get("expected_code")
|
||||
|
||||
log.info(f"Executing http query: {query}")
|
||||
|
||||
connstr = f"postgresql://{user}@{self.domain}:{self.proxy_port}/postgres"
|
||||
async with httpx.AsyncClient(verify=str(self.test_output_dir / "proxy.crt")) as client:
|
||||
response = await client.post(
|
||||
f"https://{self.domain}:{self.external_http_port}/sql",
|
||||
json={"query": query, "params": args},
|
||||
headers={
|
||||
"Neon-Connection-String": connstr,
|
||||
"Authorization": f"Bearer {token}",
|
||||
},
|
||||
)
|
||||
|
||||
if expected_code is not None:
|
||||
assert response.status_code == expected_code, f"response: {response.json()}"
|
||||
return response.json()
|
||||
|
||||
def get_metrics(self) -> str:
|
||||
request_result = requests.get(f"http://{self.host}:{self.http_port}/metrics")
|
||||
return request_result.text
|
||||
@@ -3432,6 +3398,125 @@ class NeonProxy(PgProtocol):
|
||||
assert out == "ok"
|
||||
|
||||
|
||||
class NeonAuthBroker:
|
||||
class ControlPlane:
|
||||
def __init__(self, endpoint: str):
|
||||
self.endpoint = endpoint
|
||||
|
||||
def extra_args(self) -> list[str]:
|
||||
args = [
|
||||
*["--auth-backend", "console"],
|
||||
*["--auth-endpoint", self.endpoint],
|
||||
]
|
||||
return args
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
neon_binpath: Path,
|
||||
test_output_dir: Path,
|
||||
http_port: int,
|
||||
mgmt_port: int,
|
||||
external_http_port: int,
|
||||
auth_backend: NeonAuthBroker.ControlPlane,
|
||||
):
|
||||
self.domain = "apiauth.localtest.me" # resolves to 127.0.0.1
|
||||
self.host = "127.0.0.1"
|
||||
self.http_port = http_port
|
||||
self.external_http_port = external_http_port
|
||||
self.neon_binpath = neon_binpath
|
||||
self.test_output_dir = test_output_dir
|
||||
self.mgmt_port = mgmt_port
|
||||
self.auth_backend = auth_backend
|
||||
self.http_timeout_seconds = 15
|
||||
self._popen: Optional[subprocess.Popen[bytes]] = None
|
||||
|
||||
def start(self) -> NeonAuthBroker:
|
||||
assert self._popen is None
|
||||
|
||||
# generate key of it doesn't exist
|
||||
crt_path = self.test_output_dir / "proxy.crt"
|
||||
key_path = self.test_output_dir / "proxy.key"
|
||||
generate_proxy_tls_certs("apiauth.localtest.me", key_path, crt_path)
|
||||
|
||||
args = [
|
||||
str(self.neon_binpath / "proxy"),
|
||||
*["--http", f"{self.host}:{self.http_port}"],
|
||||
*["--mgmt", f"{self.host}:{self.mgmt_port}"],
|
||||
*["--wss", f"{self.host}:{self.external_http_port}"],
|
||||
*["-c", str(crt_path)],
|
||||
*["-k", str(key_path)],
|
||||
*["--sql-over-http-pool-opt-in", "false"],
|
||||
*["--is-auth-broker", "true"],
|
||||
*self.auth_backend.extra_args(),
|
||||
]
|
||||
|
||||
logfile = open(self.test_output_dir / "proxy.log", "w")
|
||||
self._popen = subprocess.Popen(args, stdout=logfile, stderr=logfile)
|
||||
self._wait_until_ready()
|
||||
return self
|
||||
|
||||
# Sends SIGTERM to the proxy if it has been started
|
||||
def terminate(self):
|
||||
if self._popen:
|
||||
self._popen.terminate()
|
||||
|
||||
# Waits for proxy to exit if it has been opened with a default timeout of
|
||||
# two seconds. Raises subprocess.TimeoutExpired if the proxy does not exit in time.
|
||||
def wait_for_exit(self, timeout=2):
|
||||
if self._popen:
|
||||
self._popen.wait(timeout=timeout)
|
||||
|
||||
@backoff.on_exception(backoff.expo, requests.exceptions.RequestException, max_time=10)
|
||||
def _wait_until_ready(self):
|
||||
assert (
|
||||
self._popen and self._popen.poll() is None
|
||||
), "Proxy exited unexpectedly. Check test log."
|
||||
requests.get(f"http://{self.host}:{self.http_port}/v1/status")
|
||||
|
||||
async def query(self, query, args, **kwargs):
|
||||
user = kwargs["user"]
|
||||
token = kwargs["token"]
|
||||
expected_code = kwargs.get("expected_code")
|
||||
|
||||
log.info(f"Executing http query: {query}")
|
||||
|
||||
connstr = f"postgresql://{user}@{self.domain}/postgres"
|
||||
async with httpx.AsyncClient(verify=str(self.test_output_dir / "proxy.crt")) as client:
|
||||
response = await client.post(
|
||||
f"https://{self.domain}:{self.external_http_port}/sql",
|
||||
json={"query": query, "params": args},
|
||||
headers={
|
||||
"Neon-Connection-String": connstr,
|
||||
"Authorization": f"Bearer {token}",
|
||||
},
|
||||
)
|
||||
|
||||
if expected_code is not None:
|
||||
assert response.status_code == expected_code, f"response: {response.json()}"
|
||||
return response.json()
|
||||
|
||||
def get_metrics(self) -> str:
|
||||
request_result = requests.get(f"http://{self.host}:{self.http_port}/metrics")
|
||||
return request_result.text
|
||||
|
||||
def __enter__(self) -> NeonAuthBroker:
|
||||
return self
|
||||
|
||||
def __exit__(
|
||||
self,
|
||||
_exc_type: Optional[type[BaseException]],
|
||||
_exc_value: Optional[BaseException],
|
||||
_traceback: Optional[TracebackType],
|
||||
):
|
||||
if self._popen is not None:
|
||||
self._popen.terminate()
|
||||
try:
|
||||
self._popen.wait(timeout=5)
|
||||
except subprocess.TimeoutExpired:
|
||||
log.warning("failed to gracefully terminate proxy; killing")
|
||||
self._popen.kill()
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def link_proxy(
|
||||
port_distributor: PortDistributor, neon_binpath: Path, test_output_dir: Path
|
||||
@@ -3497,10 +3582,10 @@ def static_proxy(
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def neon_authorize_jwk() -> Iterator[jwk.JWK]:
|
||||
def neon_authorize_jwk() -> jwk.JWK:
|
||||
kid = str(uuid.uuid4())
|
||||
key = jwk.JWK.generate(kty="RSA", size=2048, alg="RS256", use="sig", kid=kid)
|
||||
yield key
|
||||
return key
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
@@ -3511,14 +3596,12 @@ def static_auth_broker(
|
||||
httpserver: HTTPServer,
|
||||
neon_authorize_jwk: jwk.JWK,
|
||||
http2_echoserver: H2Server,
|
||||
) -> Iterable[NeonProxy]:
|
||||
"""Neon proxy that routes directly to vanilla postgres."""
|
||||
) -> Iterable[NeonAuthBroker]:
|
||||
"""Neon Auth Broker that routes to a mocked local_proxy and a mocked cplane HTTP API."""
|
||||
|
||||
# local_proxy_endpoint = httpserver.url_for("/sql")
|
||||
# local_proxy_addr = local_proxy_endpoint.removeprefix("http://").removesuffix("/sql")
|
||||
local_proxy_addr = f"{http2_echoserver.host}:{http2_echoserver.port}"
|
||||
log.info(f"local_proxy {local_proxy_addr}")
|
||||
|
||||
# return local_proxy addr on ProxyWakeCompute.
|
||||
httpserver.expect_request("/cplane/proxy_wake_compute").respond_with_json(
|
||||
{
|
||||
"address": local_proxy_addr,
|
||||
@@ -3529,6 +3612,8 @@ def static_auth_broker(
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
# return jwks mock addr on GetEndpointJwks
|
||||
httpserver.expect_request(re.compile("^/cplane/endpoints/.+/jwks$")).respond_with_json(
|
||||
{
|
||||
"jwks": [
|
||||
@@ -3542,24 +3627,22 @@ def static_auth_broker(
|
||||
]
|
||||
}
|
||||
)
|
||||
auth_endpoint = httpserver.url_for("/cplane")
|
||||
|
||||
# return static fixture jwks.
|
||||
jwk = neon_authorize_jwk.export_public(as_dict=True)
|
||||
httpserver.expect_request("/authorize/jwks.json").respond_with_json({"keys": [jwk]})
|
||||
|
||||
proxy_port = port_distributor.get_port()
|
||||
mgmt_port = port_distributor.get_port()
|
||||
http_port = port_distributor.get_port()
|
||||
external_http_port = port_distributor.get_port()
|
||||
|
||||
with NeonProxy(
|
||||
with NeonAuthBroker(
|
||||
neon_binpath=neon_binpath,
|
||||
test_output_dir=test_output_dir,
|
||||
proxy_port=proxy_port,
|
||||
http_port=http_port,
|
||||
mgmt_port=mgmt_port,
|
||||
external_http_port=external_http_port,
|
||||
auth_backend=NeonProxy.AuthBroker(auth_endpoint),
|
||||
auth_backend=NeonAuthBroker.ControlPlane(httpserver.url_for("/cplane")),
|
||||
) as proxy:
|
||||
proxy.start()
|
||||
yield proxy
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
import json
|
||||
|
||||
import pytest
|
||||
from fixtures.neon_fixtures import NeonProxy
|
||||
from fixtures.neon_fixtures import NeonAuthBroker
|
||||
from jwcrypto import jwk, jwt
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auth_broker_happy(
|
||||
static_auth_broker: NeonProxy,
|
||||
static_auth_broker: NeonAuthBroker,
|
||||
neon_authorize_jwk: jwk.JWK,
|
||||
):
|
||||
"""
|
||||
@@ -18,9 +18,7 @@ async def test_auth_broker_happy(
|
||||
header={"kid": neon_authorize_jwk.key_id, "alg": "RS256"}, claims={"sub": "user1"}
|
||||
)
|
||||
token.make_signed_token(neon_authorize_jwk)
|
||||
res = await static_auth_broker.auth_broker_query(
|
||||
"foo", ["arg1"], user="anonymous", token=token.serialize()
|
||||
)
|
||||
res = await static_auth_broker.query("foo", ["arg1"], user="anonymous", token=token.serialize())
|
||||
|
||||
# local proxy mock just echos back the request
|
||||
# check that we forward the correct data
|
||||
|
||||
Reference in New Issue
Block a user