tests: WIP: MITM proxy between pageserver and compute for fault testing

This commit is contained in:
Heikki Linnakangas
2024-12-05 15:00:43 +02:00
parent ffc9c33eb2
commit a98fab8b1c
2 changed files with 267 additions and 0 deletions

View File

@@ -0,0 +1,201 @@
# Intercept compute -> pageserver connections, to simulate various failure modes
from __future__ import annotations
import asyncio
import struct
import threading
import traceback
from asyncio import TaskGroup
from enum import Enum
from fixtures.log_helper import log
class ConnectionState(Enum):
HANDSHAKE = (1,)
AUTHOK = (2,)
COPYBOTH = (3,)
class BreakConnectionException(Exception):
def __init__(self, message):
super().__init__(message)
self.message = message
class ProxyShutdownException(Exception):
"""Exception raised to shut down connection when the proxy is shut down."""
# The handshake flow:
#
# 1. compute establishes TCP connection
# 2. libpq handshake and auth
# 3. Enter CopyBoth mode
#
# From then on, CopyData messages are exchanged in both directions
class Connection:
def __init__(
self,
conn_id,
compute_reader,
compute_writer,
shutdown_event,
dest_port: int,
response_cb=None,
):
self.conn_id = conn_id
self.compute_reader = compute_reader
self.compute_writer = compute_writer
self.shutdown_event = shutdown_event
self.response_cb = response_cb
self.dest_port = dest_port
self.state = ConnectionState.HANDSHAKE
async def run(self):
async def wait_for_shutdown():
await self.shutdown_event.wait()
raise ProxyShutdownException
try:
addr = self.compute_writer.get_extra_info("peername")
log.info(f"[{self.conn_id}] connection received from {addr}")
async with TaskGroup() as group:
group.create_task(wait_for_shutdown())
self.ps_reader, self.ps_writer = await asyncio.open_connection(
"localhost", self.dest_port
)
group.create_task(self.handle_compute_to_pageserver())
group.create_task(self.handle_pageserver_to_compute())
except* ProxyShutdownException:
log.info(f"[{self.conn_id}] proxy shutting down")
except* asyncio.exceptions.IncompleteReadError as e:
log.info(f"[{self.conn_id}] EOF reached: {e}")
except* BreakConnectionException as eg:
for e in eg.exceptions:
log.info(f"[{self.conn_id}] callback breaks connection: {e}")
except* Exception as e:
s = "\n".join(traceback.format_exception(e))
log.info(f"[{self.conn_id}] {s}")
self.compute_writer.close()
self.ps_writer.close()
await self.compute_writer.wait_closed()
await self.ps_writer.wait_closed()
async def handle_compute_to_pageserver(self):
while self.state == ConnectionState.HANDSHAKE:
rawmsg = await self.compute_reader.read(1000)
log.debug(f"[{self.conn_id}] C -> PS: handshake msg len {len(rawmsg)}")
self.ps_writer.write(rawmsg)
await self.ps_writer.drain()
while True:
msgtype = await self.compute_reader.readexactly(1)
msglen_bytes = await self.compute_reader.readexactly(4)
(msglen,) = struct.unpack("!L", msglen_bytes)
payload = await self.compute_reader.readexactly(msglen - 4)
# request_callback()
# CopyData
if msgtype == b"d":
# TODO: call callback
log.debug(f"[{self.conn_id}] C -> PS: CopyData ({msglen} bytes)")
pass
else:
log.debug(f"[{self.conn_id}] C -> PS: message of type '{msgtype}' ({msglen} bytes)")
self.ps_writer.write(msgtype)
self.ps_writer.write(msglen_bytes)
self.ps_writer.write(payload)
await self.ps_writer.drain()
async def handle_pageserver_to_compute(self):
while True:
msgtype = await self.ps_reader.readexactly(1)
# response to SSLRequest
if msgtype == b"N" and self.state == ConnectionState.HANDSHAKE:
log.debug(f"[{self.conn_id}] PS -> C: N")
self.compute_writer.write(msgtype)
await self.compute_writer.drain()
continue
msglen_bytes = await self.ps_reader.readexactly(4)
(msglen,) = struct.unpack("!L", msglen_bytes)
payload = await self.ps_reader.readexactly(msglen - 4)
# AuthenticationOK
if msgtype == b"R":
self.state = ConnectionState.AUTHOK
log.debug(f"[{self.conn_id}] PS -> C: AuthenticationOK ({msglen} bytes)")
# CopyBothresponse
elif msgtype == b"W":
self.state = ConnectionState.COPYBOTH
log.debug(f"[{self.conn_id}] PS -> C: CopyBothResponse ({msglen} bytes)")
# CopyData
elif msgtype == b"d":
log.debug(f"[{self.conn_id}] PS -> C: CopyData ({msglen} bytes)")
if self.response_cb is not None:
await self.response_cb(self.conn_id)
pass
else:
log.debug(f"[{self.conn_id}] PS -> C: message of type '{msgtype}' ({msglen} bytes)")
self.compute_writer.write(msgtype)
self.compute_writer.write(msglen_bytes)
self.compute_writer.write(payload)
await self.compute_writer.drain()
class PageserverProxy:
def __init__(self, listen_port: int, dest_port: int, response_cb=None):
self.listen_port = listen_port
self.dest_port = dest_port
self.response_cb = response_cb
self.conn_counter = 0
self.shutdown_event = asyncio.Event()
def shutdown(self):
self.serve_task.cancel()
self.shutdown_event.set()
async def handle_client(self, compute_reader, compute_writer):
self.conn_counter += 1
conn_id = self.conn_counter
conn = Connection(
conn_id,
compute_reader,
compute_writer,
self.shutdown_event,
self.dest_port,
self.response_cb,
)
await conn.run()
async def run_server(self):
log.info("run_server called")
server = await asyncio.start_server(self.handle_client, "localhost", self.listen_port)
self.serve_task = asyncio.create_task(server.serve_forever())
try:
await self.serve_task
except asyncio.CancelledError:
log.info("proxy shutting down")
def launch_server_in_thread(self):
t1 = threading.Thread(target=asyncio.run, args=self.run_server)
t1.start()

View File

@@ -1,5 +1,6 @@
from __future__ import annotations
import asyncio
import random
import time
@@ -7,6 +8,8 @@ import psycopg2.errors
import pytest
from fixtures.log_helper import log
from fixtures.neon_fixtures import NeonEnvBuilder
from fixtures.pageserver_mitm import BreakConnectionException, PageserverProxy
from fixtures.port_distributor import PortDistributor
@pytest.mark.timeout(600)
@@ -80,3 +83,66 @@ def test_compute_pageserver_connection_stress(neon_env_builder: NeonEnvBuilder):
# do a graceful shutdown which would had caught the allowed_errors before
# https://github.com/neondatabase/neon/pull/8632
env.pageserver.stop()
@pytest.mark.timeout(600)
def test_compute_pageserver_connection_stress2(
neon_env_builder: NeonEnvBuilder, port_distributor: PortDistributor,
pg_bin: PgBin
):
env = neon_env_builder.init_start()
# Set up the MITM proxy
error_fraction = 0
async def response_cb(conn_id):
global error_fraction
if random.random() < error_fraction:
raise BreakConnectionException("unlucky")
mitm_listen_port = port_distributor.get_port()
mitm = PageserverProxy(mitm_listen_port, env.pageserver.service_port.pg, response_cb)
def main():
global error_fraction
endpoint = env.endpoints.create(
"main",
config_lines=[
"max_connections=1000",
])
endpoint.start()
with open(endpoint.pg_data_dir_path() / "postgresql.conf", "a") as conf:
conf.write(
f"neon.pageserver_connstring='postgres://localhost:{mitm_listen_port}' # MITM proxy\n"
)
pg_conn = endpoint.connect()
cur = pg_conn.cursor()
cur.execute("select pg_reload_conf()")
scale = 5
connstr = endpoint.connstr()
log.info(f"Start a pgbench workload on pg {connstr}")
error_fraction=0.001
pg_bin.run_capture(["pgbench", "-i", "-I", "dtGvp", f"-s{scale}", connstr])
error_fraction=0.01
pg_bin.run_capture(["pgbench", "-S", "-c80", "-j4", "-P1", "-T60", connstr])
mitm.shutdown()
async def mm():
await asyncio.gather(
asyncio.to_thread(main),
mitm.run_server()
)
asyncio.run(mm())
# do a graceful shutdown which would had caught the allowed_errors before
# https://github.com/neondatabase/neon/pull/8632
env.pageserver.stop()