mirror of
https://github.com/neondatabase/neon.git
synced 2026-01-10 23:12:54 +00:00
Add a websockets tunnel and a test for the proxy's websockets support. (#3823)
For testing the proxy's websockets support. I wrote this to test https://github.com/neondatabase/neon/issues/3822. Unfortunately, that bug can *not* be reproduced with this tunnel. The bug only appears when the client pipelines the first query with the authentication messages. The tunnel doesn't do that. --- Update (@conradludgate 2025-01-10): We have since added some websocket tests, but they manually implemented a very simplistic setup of the postgres protocol. Introducing the tunnel would make more complex testing simpler in the future. --------- Co-authored-by: Conrad Ludgate <conradludgate@gmail.com>
This commit is contained in:
committed by
GitHub
parent
12053cf832
commit
09fe3b025c
@@ -1,10 +1,15 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import ssl
|
||||
|
||||
import asyncpg
|
||||
import pytest
|
||||
import websocket_tunnel
|
||||
import websockets
|
||||
from fixtures.log_helper import log
|
||||
from fixtures.neon_fixtures import NeonProxy
|
||||
from fixtures.port_distributor import PortDistributor
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -196,3 +201,53 @@ async def test_websockets_pipelined(static_proxy: NeonProxy):
|
||||
# close
|
||||
await websocket.send(b"X\x00\x00\x00\x04")
|
||||
await websocket.wait_closed()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_websockets_tunneled(static_proxy: NeonProxy, port_distributor: PortDistributor):
|
||||
static_proxy.safe_psql("create user ws_auth with password 'ws' superuser")
|
||||
|
||||
user = "ws_auth"
|
||||
password = "ws"
|
||||
|
||||
ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
|
||||
ssl_context.load_verify_locations(str(static_proxy.test_output_dir / "proxy.crt"))
|
||||
|
||||
# Launch a tunnel service so that we can speak the websockets protocol to
|
||||
# the proxy
|
||||
tunnel_port = port_distributor.get_port()
|
||||
tunnel_server = await websocket_tunnel.start_server(
|
||||
"127.0.0.1",
|
||||
tunnel_port,
|
||||
f"wss://{static_proxy.domain}:{static_proxy.external_http_port}/sql",
|
||||
ssl_context,
|
||||
)
|
||||
log.info(f"websockets tunnel listening for connections on port {tunnel_port}")
|
||||
|
||||
async with tunnel_server:
|
||||
|
||||
async def run_tunnel():
|
||||
try:
|
||||
async with tunnel_server:
|
||||
await tunnel_server.serve_forever()
|
||||
except Exception as e:
|
||||
log.error(f"Error in tunnel task: {e}")
|
||||
|
||||
tunnel_task = asyncio.create_task(run_tunnel())
|
||||
|
||||
# Ok, the tunnel is now running. Check that we can connect to the proxy's
|
||||
# websocket interface, through the tunnel
|
||||
tunnel_connstring = f"postgres://{user}:{password}@127.0.0.1:{tunnel_port}/postgres"
|
||||
|
||||
log.info(f"connecting to {tunnel_connstring}")
|
||||
conn = await asyncpg.connect(tunnel_connstring)
|
||||
res = await conn.fetchval("SELECT 123")
|
||||
assert res == 123
|
||||
await conn.close()
|
||||
log.info("Ran a query successfully through the tunnel")
|
||||
|
||||
tunnel_server.close()
|
||||
try:
|
||||
await tunnel_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
154
test_runner/websocket_tunnel.py
Executable file
154
test_runner/websocket_tunnel.py
Executable file
@@ -0,0 +1,154 @@
|
||||
#!/usr/bin/env python3
|
||||
#
|
||||
# This program helps to test the WebSocket tunneling in proxy. It listens for a TCP
|
||||
# connection on a port, and when you connect to it, it opens a websocket connection,
|
||||
# and forwards all the traffic to the websocket connection, wrapped in WebSocket binary
|
||||
# frames.
|
||||
#
|
||||
# This is used in the test_proxy::test_websockets test, but it is handy for manual testing too.
|
||||
#
|
||||
# Usage for manual testing:
|
||||
#
|
||||
# ## Launch Posgres on port 3000:
|
||||
# postgres -D data -p3000
|
||||
#
|
||||
# ## Launch proxy with WSS enabled:
|
||||
# openssl req -new -x509 -days 365 -nodes -text -out server.crt -keyout server.key -subj '/CN=*.neon.localtest.me'
|
||||
# ./target/debug/proxy --wss 127.0.0.1:40433 --http 127.0.0.1:28080 --mgmt 127.0.0.1:9099 --proxy 127.0.0.1:4433 --tls-key server.key --tls-cert server.crt --auth-backend postgres
|
||||
#
|
||||
# ## Launch the tunnel:
|
||||
#
|
||||
# poetry run ./test_runner/websocket_tunnel.py --ws-port 40433 --ws-url "wss://ep-test.neon.localtest.me"
|
||||
#
|
||||
# ## Now you can connect with psql:
|
||||
# psql "postgresql://heikki@localhost:40433/postgres"
|
||||
#
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import logging
|
||||
import ssl
|
||||
from ssl import Purpose
|
||||
|
||||
import websockets
|
||||
from fixtures.log_helper import log
|
||||
|
||||
|
||||
# Enable verbose logging of all the traffic
|
||||
def enable_verbose_logging():
|
||||
logger = logging.getLogger("websockets")
|
||||
logger.setLevel(logging.DEBUG)
|
||||
logger.addHandler(logging.StreamHandler())
|
||||
|
||||
|
||||
async def start_server(tcp_listen_host, tcp_listen_port, ws_url, ctx):
|
||||
server = await asyncio.start_server(
|
||||
lambda r, w: handle_client(r, w, ws_url, ctx), tcp_listen_host, tcp_listen_port
|
||||
)
|
||||
return server
|
||||
|
||||
|
||||
async def handle_tcp_to_websocket(tcp_reader, ws):
|
||||
try:
|
||||
while not tcp_reader.at_eof():
|
||||
data = await tcp_reader.read(1024)
|
||||
|
||||
await ws.send(data)
|
||||
except websockets.exceptions.ConnectionClosedError as e:
|
||||
log.debug(f"connection closed: {e}")
|
||||
except websockets.exceptions.ConnectionClosedOK:
|
||||
log.debug("connection closed")
|
||||
except Exception as e:
|
||||
log.error(e)
|
||||
|
||||
|
||||
async def handle_websocket_to_tcp(ws, tcp_writer):
|
||||
try:
|
||||
async for message in ws:
|
||||
tcp_writer.write(message)
|
||||
await tcp_writer.drain()
|
||||
except websockets.exceptions.ConnectionClosedError as e:
|
||||
log.debug(f"connection closed: {e}")
|
||||
except websockets.exceptions.ConnectionClosedOK:
|
||||
log.debug("connection closed")
|
||||
except Exception as e:
|
||||
log.error(e)
|
||||
|
||||
|
||||
async def handle_client(tcp_reader, tcp_writer, ws_url: str, ctx: ssl.SSLContext):
|
||||
try:
|
||||
log.info("Received TCP connection. Connecting to websockets proxy.")
|
||||
|
||||
async with websockets.connect(ws_url, ssl=ctx) as ws:
|
||||
try:
|
||||
log.info("Connected to websockets proxy")
|
||||
|
||||
async with asyncio.TaskGroup() as tg:
|
||||
task1 = tg.create_task(handle_tcp_to_websocket(tcp_reader, ws))
|
||||
task2 = tg.create_task(handle_websocket_to_tcp(ws, tcp_writer))
|
||||
|
||||
done, pending = await asyncio.wait(
|
||||
[task1, task2], return_when=asyncio.FIRST_COMPLETED
|
||||
)
|
||||
tcp_writer.close()
|
||||
await ws.close()
|
||||
|
||||
except* Exception as ex:
|
||||
log.error(ex.exceptions)
|
||||
except Exception as e:
|
||||
log.error(e)
|
||||
|
||||
|
||||
async def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--tcp-listen-addr",
|
||||
default="localhost",
|
||||
help="TCP addr to listen on",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tcp-listen-port",
|
||||
default="40444",
|
||||
help="TCP port to listen on",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--ws-url",
|
||||
default="wss://localhost/",
|
||||
help="websocket URL to connect to. This determines the Host header sent to the server",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ws-host",
|
||||
default="127.0.0.1",
|
||||
help="websockets host to connect to",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ws-port",
|
||||
type=int,
|
||||
default=443,
|
||||
help="websockets port to connect to",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--verbose",
|
||||
action="store_true",
|
||||
help="enable verbose logging",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.verbose:
|
||||
enable_verbose_logging()
|
||||
|
||||
ctx = ssl.create_default_context(Purpose.SERVER_AUTH)
|
||||
ctx.check_hostname = False
|
||||
ctx.verify_mode = ssl.CERT_NONE
|
||||
|
||||
server = await start_server(args.tcp_listen_addr, args.tcp_listen_port, args.ws_url, ctx)
|
||||
print(
|
||||
f"Listening for connections at {args.tcp_listen_addr}:{args.tcp_listen_port}, forwarding them to {args.ws_host}:{args.ws_port}"
|
||||
)
|
||||
async with server:
|
||||
await server.serve_forever()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
Reference in New Issue
Block a user