mirror of
https://github.com/neondatabase/neon.git
synced 2026-05-21 07:00:38 +00:00
[auth_broker]: regress test
This commit is contained in:
31
poetry.lock
generated
31
poetry.lock
generated
@@ -1521,6 +1521,21 @@ files = [
|
||||
[package.dependencies]
|
||||
six = "*"
|
||||
|
||||
[[package]]
|
||||
name = "jwcrypto"
|
||||
version = "1.5.6"
|
||||
description = "Implementation of JOSE Web standards"
|
||||
optional = false
|
||||
python-versions = ">= 3.8"
|
||||
files = [
|
||||
{file = "jwcrypto-1.5.6-py3-none-any.whl", hash = "sha256:150d2b0ebbdb8f40b77f543fb44ffd2baeff48788be71f67f03566692fd55789"},
|
||||
{file = "jwcrypto-1.5.6.tar.gz", hash = "sha256:771a87762a0c081ae6166958a954f80848820b2ab066937dc8b8379d65b1b039"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
cryptography = ">=3.4"
|
||||
typing-extensions = ">=4.5.0"
|
||||
|
||||
[[package]]
|
||||
name = "kafka-python"
|
||||
version = "2.0.2"
|
||||
@@ -2111,7 +2126,6 @@ files = [
|
||||
{file = "psycopg2_binary-2.9.9-cp311-cp311-win32.whl", hash = "sha256:dc4926288b2a3e9fd7b50dc6a1909a13bbdadfc67d93f3374d984e56f885579d"},
|
||||
{file = "psycopg2_binary-2.9.9-cp311-cp311-win_amd64.whl", hash = "sha256:b76bedd166805480ab069612119ea636f5ab8f8771e640ae103e05a4aae3e417"},
|
||||
{file = "psycopg2_binary-2.9.9-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:8532fd6e6e2dc57bcb3bc90b079c60de896d2128c5d9d6f24a63875a95a088cf"},
|
||||
{file = "psycopg2_binary-2.9.9-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:b0605eaed3eb239e87df0d5e3c6489daae3f7388d455d0c0b4df899519c6a38d"},
|
||||
{file = "psycopg2_binary-2.9.9-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8f8544b092a29a6ddd72f3556a9fcf249ec412e10ad28be6a0c0d948924f2212"},
|
||||
{file = "psycopg2_binary-2.9.9-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2d423c8d8a3c82d08fe8af900ad5b613ce3632a1249fd6a223941d0735fce493"},
|
||||
{file = "psycopg2_binary-2.9.9-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2e5afae772c00980525f6d6ecf7cbca55676296b580c0e6abb407f15f3706996"},
|
||||
@@ -2120,8 +2134,6 @@ files = [
|
||||
{file = "psycopg2_binary-2.9.9-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:cb16c65dcb648d0a43a2521f2f0a2300f40639f6f8c1ecbc662141e4e3e1ee07"},
|
||||
{file = "psycopg2_binary-2.9.9-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:911dda9c487075abd54e644ccdf5e5c16773470a6a5d3826fda76699410066fb"},
|
||||
{file = "psycopg2_binary-2.9.9-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:57fede879f08d23c85140a360c6a77709113efd1c993923c59fde17aa27599fe"},
|
||||
{file = "psycopg2_binary-2.9.9-cp312-cp312-win32.whl", hash = "sha256:64cf30263844fa208851ebb13b0732ce674d8ec6a0c86a4e160495d299ba3c93"},
|
||||
{file = "psycopg2_binary-2.9.9-cp312-cp312-win_amd64.whl", hash = "sha256:81ff62668af011f9a48787564ab7eded4e9fb17a4a6a74af5ffa6a457400d2ab"},
|
||||
{file = "psycopg2_binary-2.9.9-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:2293b001e319ab0d869d660a704942c9e2cce19745262a8aba2115ef41a0a42a"},
|
||||
{file = "psycopg2_binary-2.9.9-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:03ef7df18daf2c4c07e2695e8cfd5ee7f748a1d54d802330985a78d2a5a6dca9"},
|
||||
{file = "psycopg2_binary-2.9.9-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0a602ea5aff39bb9fac6308e9c9d82b9a35c2bf288e184a816002c9fae930b77"},
|
||||
@@ -2603,7 +2615,6 @@ files = [
|
||||
{file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"},
|
||||
{file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"},
|
||||
{file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"},
|
||||
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"},
|
||||
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"},
|
||||
{file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"},
|
||||
{file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"},
|
||||
@@ -3159,16 +3170,6 @@ files = [
|
||||
{file = "wrapt-1.14.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:8ad85f7f4e20964db4daadcab70b47ab05c7c1cf2a7c1e51087bfaa83831854c"},
|
||||
{file = "wrapt-1.14.1-cp310-cp310-win32.whl", hash = "sha256:a9a52172be0b5aae932bef82a79ec0a0ce87288c7d132946d645eba03f0ad8a8"},
|
||||
{file = "wrapt-1.14.1-cp310-cp310-win_amd64.whl", hash = "sha256:6d323e1554b3d22cfc03cd3243b5bb815a51f5249fdcbb86fda4bf62bab9e164"},
|
||||
{file = "wrapt-1.14.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:ecee4132c6cd2ce5308e21672015ddfed1ff975ad0ac8d27168ea82e71413f55"},
|
||||
{file = "wrapt-1.14.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:2020f391008ef874c6d9e208b24f28e31bcb85ccff4f335f15a3251d222b92d9"},
|
||||
{file = "wrapt-1.14.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2feecf86e1f7a86517cab34ae6c2f081fd2d0dac860cb0c0ded96d799d20b335"},
|
||||
{file = "wrapt-1.14.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:240b1686f38ae665d1b15475966fe0472f78e71b1b4903c143a842659c8e4cb9"},
|
||||
{file = "wrapt-1.14.1-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a9008dad07d71f68487c91e96579c8567c98ca4c3881b9b113bc7b33e9fd78b8"},
|
||||
{file = "wrapt-1.14.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:6447e9f3ba72f8e2b985a1da758767698efa72723d5b59accefd716e9e8272bf"},
|
||||
{file = "wrapt-1.14.1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:acae32e13a4153809db37405f5eba5bac5fbe2e2ba61ab227926a22901051c0a"},
|
||||
{file = "wrapt-1.14.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:49ef582b7a1152ae2766557f0550a9fcbf7bbd76f43fbdc94dd3bf07cc7168be"},
|
||||
{file = "wrapt-1.14.1-cp311-cp311-win32.whl", hash = "sha256:358fe87cc899c6bb0ddc185bf3dbfa4ba646f05b1b0b9b5a27c2cb92c2cea204"},
|
||||
{file = "wrapt-1.14.1-cp311-cp311-win_amd64.whl", hash = "sha256:26046cd03936ae745a502abf44dac702a5e6880b2b01c29aea8ddf3353b68224"},
|
||||
{file = "wrapt-1.14.1-cp35-cp35m-manylinux1_i686.whl", hash = "sha256:43ca3bbbe97af00f49efb06e352eae40434ca9d915906f77def219b88e85d907"},
|
||||
{file = "wrapt-1.14.1-cp35-cp35m-manylinux1_x86_64.whl", hash = "sha256:6b1a564e6cb69922c7fe3a678b9f9a3c54e72b469875aa8018f18b4d1dd1adf3"},
|
||||
{file = "wrapt-1.14.1-cp35-cp35m-manylinux2010_i686.whl", hash = "sha256:00b6d4ea20a906c0ca56d84f93065b398ab74b927a7a3dbd470f6fc503f95dc3"},
|
||||
@@ -3406,4 +3407,4 @@ cffi = ["cffi (>=1.11)"]
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = "^3.9"
|
||||
content-hash = "0f4804119f417edf8e1fbd6d715d2e8d70ad731334fa9570304a2203f83339cf"
|
||||
content-hash = "f767eaa9cb906a47372540aef37446ae55d37011be844b652eec8fb27a49d866"
|
||||
|
||||
@@ -42,6 +42,8 @@ pytest-repeat = "^0.9.3"
|
||||
websockets = "^12.0"
|
||||
clickhouse-connect = "^0.7.16"
|
||||
kafka-python = "^2.0.2"
|
||||
jwcrypto = "^1.5.6"
|
||||
h2 = "^4.1.0"
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
mypy = "==1.3.0"
|
||||
|
||||
@@ -3,6 +3,7 @@ from __future__ import annotations
|
||||
pytest_plugins = (
|
||||
"fixtures.pg_version",
|
||||
"fixtures.parametrize",
|
||||
"fixtures.h2server",
|
||||
"fixtures.httpserver",
|
||||
"fixtures.compute_reconfigure",
|
||||
"fixtures.storage_controller_proxy",
|
||||
|
||||
216
test_runner/fixtures/h2server.py
Normal file
216
test_runner/fixtures/h2server.py
Normal file
@@ -0,0 +1,216 @@
|
||||
"""
|
||||
https://python-hyper.org/projects/hyper-h2/en/stable/asyncio-example.html
|
||||
|
||||
auth-broker -> local-proxy needs a h2 connection, so we need a h2 server :)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import collections
|
||||
import io
|
||||
import json
|
||||
from collections.abc import AsyncGenerator, Iterator
|
||||
from typing import List, Tuple
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from h2.config import H2Configuration
|
||||
from h2.connection import H2Connection
|
||||
from h2.errors import ErrorCodes
|
||||
from h2.events import (
|
||||
ConnectionTerminated,
|
||||
DataReceived,
|
||||
RemoteSettingsChanged,
|
||||
RequestReceived,
|
||||
StreamEnded,
|
||||
StreamReset,
|
||||
WindowUpdated,
|
||||
)
|
||||
from h2.exceptions import ProtocolError, StreamClosedError
|
||||
from h2.settings import SettingCodes
|
||||
|
||||
from fixtures.port_distributor import PortDistributor
|
||||
|
||||
RequestData = collections.namedtuple('RequestData', ['headers', 'data'])
|
||||
|
||||
class H2Server:
|
||||
def __init__(self, host, port) -> None:
|
||||
self.host = host
|
||||
self.port = port
|
||||
|
||||
|
||||
class H2Protocol(asyncio.Protocol):
|
||||
def __init__(self):
|
||||
config = H2Configuration(client_side=False, header_encoding='utf-8')
|
||||
self.conn = H2Connection(config=config)
|
||||
self.transport = None
|
||||
self.stream_data = {}
|
||||
self.flow_control_futures = {}
|
||||
|
||||
def connection_made(self, transport: asyncio.Transport):
|
||||
self.transport = transport
|
||||
self.conn.initiate_connection()
|
||||
self.transport.write(self.conn.data_to_send())
|
||||
|
||||
def connection_lost(self, exc):
|
||||
for future in self.flow_control_futures.values():
|
||||
future.cancel()
|
||||
self.flow_control_futures = {}
|
||||
|
||||
def data_received(self, data: bytes):
|
||||
assert self.transport is not None
|
||||
try:
|
||||
events = self.conn.receive_data(data)
|
||||
except ProtocolError as e:
|
||||
self.transport.write(self.conn.data_to_send())
|
||||
self.transport.close()
|
||||
else:
|
||||
self.transport.write(self.conn.data_to_send())
|
||||
for event in events:
|
||||
if isinstance(event, RequestReceived):
|
||||
self.request_received(event.headers, event.stream_id)
|
||||
elif isinstance(event, DataReceived):
|
||||
self.receive_data(event.data, event.stream_id)
|
||||
elif isinstance(event, StreamEnded):
|
||||
self.stream_complete(event.stream_id)
|
||||
elif isinstance(event, ConnectionTerminated):
|
||||
self.transport.close()
|
||||
elif isinstance(event, StreamReset):
|
||||
self.stream_reset(event.stream_id)
|
||||
elif isinstance(event, WindowUpdated):
|
||||
self.window_updated(event.stream_id, event.delta)
|
||||
elif isinstance(event, RemoteSettingsChanged):
|
||||
if SettingCodes.INITIAL_WINDOW_SIZE in event.changed_settings:
|
||||
self.window_updated(None, 0)
|
||||
|
||||
self.transport.write(self.conn.data_to_send())
|
||||
|
||||
def request_received(self, headers: List[Tuple[str, str]], stream_id: int):
|
||||
headers_map = collections.OrderedDict(headers)
|
||||
# method = headers_map[':method']
|
||||
|
||||
# Store off the request data.
|
||||
request_data = RequestData(headers_map, io.BytesIO())
|
||||
self.stream_data[stream_id] = request_data
|
||||
|
||||
def stream_complete(self, stream_id: int):
|
||||
"""
|
||||
When a stream is complete, we can send our response.
|
||||
"""
|
||||
try:
|
||||
request_data = self.stream_data[stream_id]
|
||||
except KeyError:
|
||||
# Just return, we probably 405'd this already
|
||||
return
|
||||
|
||||
headers = request_data.headers
|
||||
body = request_data.data.getvalue().decode('utf-8')
|
||||
|
||||
data = json.dumps(
|
||||
{"headers": headers, "body": body}, indent=4
|
||||
).encode("utf8")
|
||||
|
||||
response_headers = (
|
||||
(':status', '200'),
|
||||
('content-type', 'application/json'),
|
||||
('content-length', str(len(data))),
|
||||
)
|
||||
self.conn.send_headers(stream_id, response_headers)
|
||||
asyncio.ensure_future(self.send_data(data, stream_id))
|
||||
|
||||
def receive_data(self, data: bytes, stream_id: int):
|
||||
"""
|
||||
We've received some data on a stream. If that stream is one we're
|
||||
expecting data on, save it off. Otherwise, reset the stream.
|
||||
"""
|
||||
try:
|
||||
stream_data = self.stream_data[stream_id]
|
||||
except KeyError:
|
||||
self.conn.reset_stream(
|
||||
stream_id, error_code=ErrorCodes.PROTOCOL_ERROR
|
||||
)
|
||||
else:
|
||||
stream_data.data.write(data)
|
||||
|
||||
def stream_reset(self, stream_id):
|
||||
"""
|
||||
A stream reset was sent. Stop sending data.
|
||||
"""
|
||||
if stream_id in self.flow_control_futures:
|
||||
future = self.flow_control_futures.pop(stream_id)
|
||||
future.cancel()
|
||||
|
||||
async def send_data(self, data, stream_id):
|
||||
"""
|
||||
Send data according to the flow control rules.
|
||||
"""
|
||||
while data:
|
||||
while self.conn.local_flow_control_window(stream_id) < 1:
|
||||
try:
|
||||
await self.wait_for_flow_control(stream_id)
|
||||
except asyncio.CancelledError:
|
||||
return
|
||||
|
||||
chunk_size = min(
|
||||
self.conn.local_flow_control_window(stream_id),
|
||||
len(data),
|
||||
self.conn.max_outbound_frame_size,
|
||||
)
|
||||
|
||||
try:
|
||||
self.conn.send_data(
|
||||
stream_id,
|
||||
data[:chunk_size],
|
||||
end_stream=(chunk_size == len(data))
|
||||
)
|
||||
except (StreamClosedError, ProtocolError):
|
||||
# The stream got closed and we didn't get told. We're done
|
||||
# here.
|
||||
break
|
||||
|
||||
assert self.transport is not None
|
||||
self.transport.write(self.conn.data_to_send())
|
||||
data = data[chunk_size:]
|
||||
|
||||
async def wait_for_flow_control(self, stream_id):
|
||||
"""
|
||||
Waits for a Future that fires when the flow control window is opened.
|
||||
"""
|
||||
f = asyncio.Future()
|
||||
self.flow_control_futures[stream_id] = f
|
||||
await f
|
||||
|
||||
def window_updated(self, stream_id, delta):
|
||||
"""
|
||||
A window update frame was received. Unblock some number of flow control
|
||||
Futures.
|
||||
"""
|
||||
if stream_id and stream_id in self.flow_control_futures:
|
||||
f = self.flow_control_futures.pop(stream_id)
|
||||
f.set_result(delta)
|
||||
elif not stream_id:
|
||||
for f in self.flow_control_futures.values():
|
||||
f.set_result(delta)
|
||||
|
||||
self.flow_control_futures = {}
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="function")
|
||||
async def http2_echoserver(http2_echoserver_listen_address: tuple[str, int]) -> AsyncGenerator[H2Server]:
|
||||
host, port = http2_echoserver_listen_address
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
serve = await loop.create_server(H2Protocol, host, port)
|
||||
asyncio.create_task(serve.wait_closed())
|
||||
|
||||
server = H2Server(host, port)
|
||||
yield server
|
||||
|
||||
serve.close()
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def http2_echoserver_listen_address(port_distributor: PortDistributor) -> tuple[str, int]:
|
||||
port = port_distributor.get_port()
|
||||
return ("localhost", port)
|
||||
@@ -13,7 +13,7 @@ import threading
|
||||
import time
|
||||
import uuid
|
||||
from collections import defaultdict
|
||||
from collections.abc import Iterable, Iterator
|
||||
from collections.abc import AsyncGenerator, Iterable, Iterator
|
||||
from contextlib import closing, contextmanager
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
@@ -35,11 +35,13 @@ import toml
|
||||
from _pytest.config import Config
|
||||
from _pytest.config.argparsing import Parser
|
||||
from _pytest.fixtures import FixtureRequest
|
||||
from jwcrypto import jwk
|
||||
|
||||
# Type-related stuff
|
||||
from psycopg2.extensions import connection as PgConnection
|
||||
from psycopg2.extensions import cursor as PgCursor
|
||||
from psycopg2.extensions import make_dsn, parse_dsn
|
||||
from pytest_httpserver import HTTPServer
|
||||
from urllib3.util.retry import Retry
|
||||
|
||||
from fixtures import overlayfs
|
||||
@@ -53,6 +55,7 @@ from fixtures.common_types import (
|
||||
TimelineId,
|
||||
)
|
||||
from fixtures.endpoint.http import EndpointHttpClient
|
||||
from fixtures.h2server import H2Server
|
||||
from fixtures.log_helper import log
|
||||
from fixtures.metrics import Metrics, MetricsGetter, parse_metrics
|
||||
from fixtures.neon_cli import NeonLocalCli, Pagectl
|
||||
@@ -3139,6 +3142,20 @@ class NeonProxy(PgProtocol):
|
||||
]
|
||||
return args
|
||||
|
||||
class AuthBroker(AuthBackend):
|
||||
def __init__(self, endpoint: str):
|
||||
self.endpoint = endpoint
|
||||
|
||||
def extra_args(self) -> list[str]:
|
||||
args = [
|
||||
# Console auth backend params
|
||||
*["--auth-backend", "console"],
|
||||
*["--auth-endpoint", self.endpoint],
|
||||
*["--sql-over-http-pool-opt-in", "false"],
|
||||
*["--is-auth-broker", "true"],
|
||||
]
|
||||
return args
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Postgres(AuthBackend):
|
||||
pg_conn_url: str
|
||||
@@ -3311,6 +3328,29 @@ class NeonProxy(PgProtocol):
|
||||
assert response.status_code == expected_code, f"response: {response.json()}"
|
||||
return response.json()
|
||||
|
||||
async def auth_broker_query(self, query, args, **kwargs):
|
||||
# TODO maybe use default values if not provided
|
||||
user = kwargs["user"]
|
||||
token = kwargs["token"]
|
||||
expected_code = kwargs.get("expected_code")
|
||||
|
||||
log.info(f"Executing http query: {query}")
|
||||
|
||||
connstr = f"postgresql://{user}@{self.domain}:{self.proxy_port}/postgres"
|
||||
async with httpx.AsyncClient(verify=str(self.test_output_dir / "proxy.crt")) as client:
|
||||
response = await client.post(
|
||||
f"https://{self.domain}:{self.external_http_port}/sql",
|
||||
json={"query": query, "params": args},
|
||||
headers={
|
||||
"Neon-Connection-String": connstr,
|
||||
"Authorization": f"Bearer {token}",
|
||||
},
|
||||
)
|
||||
|
||||
if expected_code is not None:
|
||||
assert response.status_code == expected_code, f"response: {response.json()}"
|
||||
return response.json()
|
||||
|
||||
def get_metrics(self) -> str:
|
||||
request_result = requests.get(f"http://{self.host}:{self.http_port}/metrics")
|
||||
return request_result.text
|
||||
@@ -3456,6 +3496,75 @@ def static_proxy(
|
||||
yield proxy
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def neon_authorize_jwk() -> Iterator[jwk.JWK]:
|
||||
kid = str(uuid.uuid4())
|
||||
key = jwk.JWK.generate(kty="RSA", size=2048, alg="RS256", use="sig", kid=kid)
|
||||
yield key
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def static_auth_broker(
|
||||
port_distributor: PortDistributor,
|
||||
neon_binpath: Path,
|
||||
test_output_dir: Path,
|
||||
httpserver: HTTPServer,
|
||||
neon_authorize_jwk: jwk.JWK,
|
||||
http2_echoserver: H2Server,
|
||||
) -> Iterable[NeonProxy]:
|
||||
"""Neon proxy that routes directly to vanilla postgres."""
|
||||
|
||||
# local_proxy_endpoint = httpserver.url_for("/sql")
|
||||
# local_proxy_addr = local_proxy_endpoint.removeprefix("http://").removesuffix("/sql")
|
||||
local_proxy_addr = f"{http2_echoserver.host}:{http2_echoserver.port}"
|
||||
log.info(f"local_proxy {local_proxy_addr}")
|
||||
|
||||
httpserver.expect_request("/cplane/proxy_wake_compute").respond_with_json(
|
||||
{
|
||||
"address": local_proxy_addr,
|
||||
"aux": {
|
||||
"endpoint_id": "ep-foo-bar-1234",
|
||||
"branch_id": "br-foo-bar",
|
||||
"project_id": "foo-bar",
|
||||
},
|
||||
}
|
||||
)
|
||||
httpserver.expect_request(re.compile("^/cplane/endpoints/.+/jwks$")).respond_with_json(
|
||||
{
|
||||
"jwks": [
|
||||
{
|
||||
"id": "foo",
|
||||
"jwks_url": httpserver.url_for("/authorize/jwks.json"),
|
||||
"provider_name": "test",
|
||||
"jwt_audience": None,
|
||||
"role_names": ["anonymous", "authenticated"],
|
||||
}
|
||||
]
|
||||
}
|
||||
)
|
||||
auth_endpoint = httpserver.url_for("/cplane")
|
||||
|
||||
jwk = neon_authorize_jwk.export_public(as_dict=True)
|
||||
httpserver.expect_request("/authorize/jwks.json").respond_with_json({"keys": [jwk]})
|
||||
|
||||
proxy_port = port_distributor.get_port()
|
||||
mgmt_port = port_distributor.get_port()
|
||||
http_port = port_distributor.get_port()
|
||||
external_http_port = port_distributor.get_port()
|
||||
|
||||
with NeonProxy(
|
||||
neon_binpath=neon_binpath,
|
||||
test_output_dir=test_output_dir,
|
||||
proxy_port=proxy_port,
|
||||
http_port=http_port,
|
||||
mgmt_port=mgmt_port,
|
||||
external_http_port=external_http_port,
|
||||
auth_backend=NeonProxy.AuthBroker(auth_endpoint),
|
||||
) as proxy:
|
||||
proxy.start()
|
||||
yield proxy
|
||||
|
||||
|
||||
class Endpoint(PgProtocol, LogUtils):
|
||||
"""An object representing a Postgres compute endpoint managed by the control plane."""
|
||||
|
||||
|
||||
39
test_runner/regress/test_auth_broker.py
Normal file
39
test_runner/regress/test_auth_broker.py
Normal file
@@ -0,0 +1,39 @@
|
||||
import json
|
||||
|
||||
import pytest
|
||||
from fixtures.neon_fixtures import NeonProxy
|
||||
from jwcrypto import jwk, jwt
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auth_broker_happy(
|
||||
static_auth_broker: NeonProxy,
|
||||
neon_authorize_jwk: jwk.JWK,
|
||||
):
|
||||
"""
|
||||
Signs a JWT and uses it to authorize a query to local_proxy.
|
||||
"""
|
||||
|
||||
token = jwt.JWT(
|
||||
header={"kid": neon_authorize_jwk.key_id, "alg": "RS256"}, claims={"sub": "user1"}
|
||||
)
|
||||
token.make_signed_token(neon_authorize_jwk)
|
||||
res = await static_auth_broker.auth_broker_query(
|
||||
"foo", ["arg1"], user="anonymous", token=token.serialize()
|
||||
)
|
||||
|
||||
# local proxy mock just echos back the request
|
||||
# check that we forward the correct data
|
||||
|
||||
assert (
|
||||
res["headers"]["authorization"] == f"Bearer {token.serialize()}"
|
||||
), "JWT should be forwarded"
|
||||
|
||||
assert (
|
||||
"anonymous" in res["headers"]["neon-connection-string"]
|
||||
), "conn string should be forwarded"
|
||||
|
||||
assert json.loads(res["body"]) == {
|
||||
"query": "foo",
|
||||
"params": ["arg1"],
|
||||
}, "Query body should be forwarded"
|
||||
Reference in New Issue
Block a user