mirror of
https://github.com/neondatabase/neon.git
synced 2025-12-27 08:09:58 +00:00
## Problem

`TYPE_CHECKING` is used inconsistently across Python tests.

## Summary of changes

- Update `ruff`: 0.7.0 -> 0.11.2
- Enable TC (flake8-type-checking): https://docs.astral.sh/ruff/rules/#flake8-type-checking-tc
- (Auto)fix all new issues
257 lines
10 KiB
Python
257 lines
10 KiB
Python
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import ssl
|
|
from typing import TYPE_CHECKING
|
|
|
|
import asyncpg
|
|
import pytest
|
|
import websocket_tunnel
|
|
import websockets
|
|
from fixtures.log_helper import log
|
|
|
|
if TYPE_CHECKING:
|
|
from fixtures.neon_fixtures import NeonProxy
|
|
from fixtures.port_distributor import PortDistributor
|
|
|
|
|
|
# Postgres wire protocol version 3.0, the first field of a StartupMessage.
_PROTOCOL_VERSION = b"\x00\x03\x00\x00"

# Authentication request codes (bytes 5..9 of an 'R' message).
_AUTH_OK = b"\x00\x00\x00\x00"
_AUTH_CLEARTEXT_PASSWORD = b"\x00\x00\x00\x03"


def _length_prefix(payload: bytes) -> bytes:
    """Return the 4-byte big-endian Postgres length field for *payload*.

    The length counts the length field itself (4 bytes) plus the payload, but
    not the 1-byte message tag that precedes it on tagged messages.
    """
    return (4 + len(payload)).to_bytes(4, byteorder="big")


def _startup_payload(params: dict[str, str]) -> bytes:
    """Build the StartupMessage body, i.e. everything after its length field.

    The body is the protocol version, then each key and value as a
    NUL-terminated ASCII string, then a final NUL terminator.
    """
    payload = bytearray(_PROTOCOL_VERSION)
    for key, value in params.items():
        payload += key.encode("ascii") + b"\0" + value.encode("ascii") + b"\0"
    payload += b"\0"
    return bytes(payload)


def _split_message(buf: bytes) -> tuple[bytes, bytes]:
    """Split the first tagged backend message off the front of *buf*.

    Returns ``(message, rest)``. ``message`` includes the tag byte and the
    length field.
    """
    # +1 for the tag byte, which the length field does not count.
    end = int.from_bytes(buf[1:5], byteorder="big") + 1
    return buf[:end], buf[end:]


def _assert_authentication_message(
    response: str | bytes, expected_code: bytes, description: str
) -> None:
    """Assert that *response* is an 8-byte authentication ('R') message.

    Its request code must equal *expected_code*; *description* is the
    assertion message used when the code differs.
    """
    assert isinstance(response, bytes)
    assert response[0:1] == b"R", "should be authentication message"
    assert response[1:5] == b"\x00\x00\x00\x08", "should be 8 bytes long message"
    assert response[5:9] == expected_code, description


def _assert_select_1_response(query_response: str | bytes) -> None:
    """Assert that *query_response* is the complete reply to ``SELECT 1``.

    The whole reply is expected in one websocket message: RowDescription,
    DataRow, CommandComplete and then ReadyForQuery (idle).
    """
    assert isinstance(query_response, bytes)
    # 'T\x00\x00\x00!\x00\x01?column?\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x17\x00\x04\xff\xff\xff\xff\x00\x00'
    # 'D\x00\x00\x00\x0b\x00\x01\x00\x00\x00\x011'
    # 'C\x00\x00\x00\rSELECT 1\x00'
    # 'Z\x00\x00\x00\x05I'

    assert query_response[0:1] == b"T", "should be row description message"
    row_description, rest = _split_message(query_response)
    assert row_description[5:7] == b"\x00\x01", "should have 1 column"
    assert row_description[7:16] == b"?column?\0", "column should be named ?column?"
    # Skip table OID (4 bytes) and attribute number (2 bytes) to reach the type OID.
    assert row_description[22:26] == b"\x00\x00\x00\x17", "column should be an int4"

    assert rest[0:1] == b"D", "should be data row message"
    data_row, rest = _split_message(rest)
    assert data_row == b"D\x00\x00\x00\x0b\x00\x01\x00\x00\x00\x011", (
        "should contain 1 column with text value 1"
    )

    assert rest[0:1] == b"C", "should be command complete message"
    command_complete, rest = _split_message(rest)
    assert command_complete == b"C\x00\x00\x00\x0dSELECT 1\0"

    assert rest[0:6] == b"Z\x00\x00\x00\x05I", "should be ready for query (idle)"


def _ws_connect(static_proxy: NeonProxy):
    """Return a ``websockets.connect`` context for the proxy's ``/sql`` endpoint.

    The connection trusts the proxy's test certificate (``proxy.crt``).
    """
    ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
    ssl_context.load_verify_locations(str(static_proxy.test_output_dir / "proxy.crt"))
    return websockets.connect(
        f"wss://{static_proxy.domain}:{static_proxy.external_http_port}/sql",
        ssl=ssl_context,
    )


@pytest.mark.asyncio
async def test_websockets(static_proxy: NeonProxy):
    """
    Run startup, cleartext auth and ``SELECT 1`` over the websocket, one step at a time.

    Each client message is sent as a list of chunks, which the websockets library
    transmits as a fragmented message. Each reply is awaited before the next
    message is sent.
    """
    static_proxy.safe_psql("create user ws_auth with password 'ws' superuser")

    user = "ws_auth"
    password = "ws"

    params = {
        "user": user,
        "database": "postgres",
        "client_encoding": "UTF8",
    }

    async with _ws_connect(static_proxy) as websocket:
        startup_payload = _startup_payload(params)
        await websocket.send([_length_prefix(startup_payload), startup_payload])
        _assert_authentication_message(
            await websocket.recv(), _AUTH_CLEARTEXT_PASSWORD, "should be cleartext"
        )

        auth_message = password.encode("utf-8") + b"\0"
        await websocket.send([b"p", _length_prefix(auth_message), auth_message])
        _assert_authentication_message(
            await websocket.recv(), _AUTH_OK, "should be authenticated"
        )

        query_message = b"SELECT 1" + b"\0"
        await websocket.send([b"Q", _length_prefix(query_message), query_message])
        _assert_select_1_response(await websocket.recv())

        # close: send Terminate ('X', length 4) and wait for the server to hang up
        await websocket.send(b"X\x00\x00\x00\x04")
        await websocket.wait_closed()


@pytest.mark.asyncio
async def test_websockets_pipelined(static_proxy: NeonProxy):
    """
    Test whether we can send the startup + auth + query all in one go
    """
    static_proxy.safe_psql("create user ws_auth with password 'ws' superuser")

    user = "ws_auth"
    password = "ws"

    params = {
        "user": user,
        "database": "postgres",
        "client_encoding": "UTF8",
    }

    async with _ws_connect(static_proxy) as websocket:
        startup_payload = _startup_payload(params)
        auth_message = password.encode("utf-8") + b"\0"
        query_message = b"SELECT 1" + b"\0"

        # All three client messages go out as one single (unfragmented) websocket message.
        await websocket.send(
            b"".join(
                [
                    _length_prefix(startup_payload),
                    startup_payload,
                    b"p",
                    _length_prefix(auth_message),
                    auth_message,
                    b"Q",
                    _length_prefix(query_message),
                    query_message,
                ]
            )
        )

        _assert_authentication_message(
            await websocket.recv(), _AUTH_CLEARTEXT_PASSWORD, "should be cleartext"
        )
        _assert_authentication_message(
            await websocket.recv(), _AUTH_OK, "should be authenticated"
        )
        _assert_select_1_response(await websocket.recv())

        # close: send Terminate ('X', length 4) and wait for the server to hang up
        await websocket.send(b"X\x00\x00\x00\x04")
        await websocket.wait_closed()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_websockets_tunneled(static_proxy: NeonProxy, port_distributor: PortDistributor):
    """
    Run a query with a regular Postgres client (asyncpg) through a local tunnel.

    The tunnel (``websocket_tunnel``) listens on 127.0.0.1 and speaks the
    websockets protocol to the proxy's ``/sql`` endpoint. The asyncpg connection
    and the tunnel task are cleaned up even if the query or its assertion fails.
    """
    static_proxy.safe_psql("create user ws_auth with password 'ws' superuser")

    user = "ws_auth"
    password = "ws"

    ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
    ssl_context.load_verify_locations(str(static_proxy.test_output_dir / "proxy.crt"))

    # Launch a tunnel service so that we can speak the websockets protocol to
    # the proxy
    tunnel_port = port_distributor.get_port()
    tunnel_server = await websocket_tunnel.start_server(
        "127.0.0.1",
        tunnel_port,
        f"wss://{static_proxy.domain}:{static_proxy.external_http_port}/sql",
        ssl_context,
    )
    log.info(f"websockets tunnel listening for connections on port {tunnel_port}")

    async def run_tunnel():
        try:
            async with tunnel_server:
                await tunnel_server.serve_forever()
        except Exception as e:
            # Best effort: log the tunnel failure. The query below then fails with
            # its own connection error.
            log.error(f"Error in tunnel task: {e}")

    async with tunnel_server:
        tunnel_task = asyncio.create_task(run_tunnel())
        try:
            # Ok, the tunnel is now running. Check that we can connect to the proxy's
            # websocket interface, through the tunnel
            tunnel_connstring = f"postgres://{user}:{password}@127.0.0.1:{tunnel_port}/postgres"

            log.info(f"connecting to {tunnel_connstring}")
            conn = await asyncpg.connect(tunnel_connstring)
            try:
                res = await conn.fetchval("SELECT 123")
                assert res == 123
            finally:
                # Close the client connection even when the query or assert fails.
                await conn.close()
            log.info("Ran a query successfully through the tunnel")
        finally:
            # Stop the tunnel on both success and failure so the task is not left
            # pending when the test's event loop shuts down. Assuming tunnel_server
            # is an asyncio.Server, close() cancels serve_forever(), so the task
            # ends cancelled.
            tunnel_server.close()
            try:
                await tunnel_task
            except asyncio.CancelledError:
                pass
|