mirror of
https://github.com/neondatabase/neon.git
synced 2026-01-16 18:02:56 +00:00
63 lines
2.3 KiB
Python
63 lines
2.3 KiB
Python
import ssl
|
|
|
|
import pytest
|
|
import websockets
|
|
from fixtures.neon_fixtures import NeonProxy
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_websockets(static_proxy: NeonProxy):
|
|
static_proxy.safe_psql("create user ws_auth with password 'ws' superuser")
|
|
|
|
user = "ws_auth"
|
|
password = "ws"
|
|
|
|
version = b"\x00\x03\x00\x00"
|
|
params = {
|
|
"user": user,
|
|
"database": "postgres",
|
|
"client_encoding": "UTF8",
|
|
}
|
|
|
|
ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
|
|
ssl_context.load_verify_locations(str(static_proxy.test_output_dir / "proxy.crt"))
|
|
|
|
async with websockets.connect(
|
|
f"wss://{static_proxy.domain}:{static_proxy.external_http_port}/sql",
|
|
ssl=ssl_context,
|
|
) as websocket:
|
|
startup_message = bytearray(version)
|
|
for key, value in params.items():
|
|
startup_message.extend(key.encode("ascii"))
|
|
startup_message.extend(b"\0")
|
|
startup_message.extend(value.encode("ascii"))
|
|
startup_message.extend(b"\0")
|
|
startup_message.extend(b"\0")
|
|
length = (4 + len(startup_message)).to_bytes(4, byteorder="big")
|
|
|
|
await websocket.send([length, startup_message])
|
|
|
|
startup_response = await websocket.recv()
|
|
assert startup_response[0:1] == b"R", "should be authentication message"
|
|
assert startup_response[1:5] == b"\x00\x00\x00\x08", "should be 8 bytes long message"
|
|
assert startup_response[5:9] == b"\x00\x00\x00\x03", "should be cleartext"
|
|
|
|
auth_message = password.encode("utf-8") + b"\0"
|
|
length = (4 + len(auth_message)).to_bytes(4, byteorder="big")
|
|
await websocket.send([b"p", length, auth_message])
|
|
|
|
auth_response = await websocket.recv()
|
|
assert auth_response[0:1] == b"R", "should be authentication message"
|
|
assert auth_response[1:5] == b"\x00\x00\x00\x08", "should be 8 bytes long message"
|
|
assert auth_response[5:9] == b"\x00\x00\x00\x00", "should be authenticated"
|
|
|
|
query_message = "SELECT 1".encode("utf-8") + b"\0"
|
|
length = (4 + len(query_message)).to_bytes(4, byteorder="big")
|
|
await websocket.send([b"Q", length, query_message])
|
|
|
|
_query_response = await websocket.recv()
|
|
|
|
# close
|
|
await websocket.send(b"X\x00\x00\x00\x04")
|
|
await websocket.wait_closed()
|