Files
neon/test_runner/regress/test_proxy_websockets.py
Conrad Ludgate 3e4265d706 ruff fmt
2024-01-12 13:04:40 +00:00

63 lines
2.3 KiB
Python

import ssl
import pytest
import websockets
from fixtures.neon_fixtures import NeonProxy
@pytest.mark.asyncio
async def test_websockets(static_proxy: NeonProxy):
static_proxy.safe_psql("create user ws_auth with password 'ws' superuser")
user = "ws_auth"
password = "ws"
version = b"\x00\x03\x00\x00"
params = {
"user": user,
"database": "postgres",
"client_encoding": "UTF8",
}
ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
ssl_context.load_verify_locations(str(static_proxy.test_output_dir / "proxy.crt"))
async with websockets.connect(
f"wss://{static_proxy.domain}:{static_proxy.external_http_port}/sql",
ssl=ssl_context,
) as websocket:
startup_message = bytearray(version)
for key, value in params.items():
startup_message.extend(key.encode("ascii"))
startup_message.extend(b"\0")
startup_message.extend(value.encode("ascii"))
startup_message.extend(b"\0")
startup_message.extend(b"\0")
length = (4 + len(startup_message)).to_bytes(4, byteorder="big")
await websocket.send([length, startup_message])
startup_response = await websocket.recv()
assert startup_response[0:1] == b"R", "should be authentication message"
assert startup_response[1:5] == b"\x00\x00\x00\x08", "should be 8 bytes long message"
assert startup_response[5:9] == b"\x00\x00\x00\x03", "should be cleartext"
auth_message = password.encode("utf-8") + b"\0"
length = (4 + len(auth_message)).to_bytes(4, byteorder="big")
await websocket.send([b"p", length, auth_message])
auth_response = await websocket.recv()
assert auth_response[0:1] == b"R", "should be authentication message"
assert auth_response[1:5] == b"\x00\x00\x00\x08", "should be 8 bytes long message"
assert auth_response[5:9] == b"\x00\x00\x00\x00", "should be authenticated"
query_message = "SELECT 1".encode("utf-8") + b"\0"
length = (4 + len(query_message)).to_bytes(4, byteorder="big")
await websocket.send([b"Q", length, query_message])
_query_response = await websocket.recv()
# close
await websocket.send(b"X\x00\x00\x00\x04")
await websocket.wait_closed()