Compare commits

...

2 Commits

Author SHA1 Message Date
Heikki Linnakangas
328408b925 Replace pgbench with python app 2024-12-05 22:14:41 +02:00
Heikki Linnakangas
a98fab8b1c tests: WIP: MITM proxy between pageserver and compute for fault testing 2024-12-05 15:00:43 +02:00
2 changed files with 400 additions and 1 deletions

View File

@@ -0,0 +1,201 @@
# Intercept compute -> pageserver connections, to simulate various failure modes
from __future__ import annotations
import asyncio
import struct
import threading
import traceback
from asyncio import TaskGroup
from enum import Enum
from fixtures.log_helper import log
class ConnectionState(Enum):
HANDSHAKE = (1,)
AUTHOK = (2,)
COPYBOTH = (3,)
class BreakConnectionException(Exception):
def __init__(self, message):
super().__init__(message)
self.message = message
class ProxyShutdownException(Exception):
"""Exception raised to shut down connection when the proxy is shut down."""
# The handshake flow:
#
# 1. compute establishes TCP connection
# 2. libpq handshake and auth
# 3. Enter CopyBoth mode
#
# From then on, CopyData messages are exchanged in both directions
class Connection:
def __init__(
self,
conn_id,
compute_reader,
compute_writer,
shutdown_event,
dest_port: int,
response_cb=None,
):
self.conn_id = conn_id
self.compute_reader = compute_reader
self.compute_writer = compute_writer
self.shutdown_event = shutdown_event
self.response_cb = response_cb
self.dest_port = dest_port
self.state = ConnectionState.HANDSHAKE
async def run(self):
async def wait_for_shutdown():
await self.shutdown_event.wait()
raise ProxyShutdownException
try:
addr = self.compute_writer.get_extra_info("peername")
log.info(f"[{self.conn_id}] connection received from {addr}")
async with TaskGroup() as group:
group.create_task(wait_for_shutdown())
self.ps_reader, self.ps_writer = await asyncio.open_connection(
"localhost", self.dest_port
)
group.create_task(self.handle_compute_to_pageserver())
group.create_task(self.handle_pageserver_to_compute())
except* ProxyShutdownException:
log.info(f"[{self.conn_id}] proxy shutting down")
except* asyncio.exceptions.IncompleteReadError as e:
log.info(f"[{self.conn_id}] EOF reached: {e}")
except* BreakConnectionException as eg:
for e in eg.exceptions:
log.info(f"[{self.conn_id}] callback breaks connection: {e}")
except* Exception as e:
s = "\n".join(traceback.format_exception(e))
log.info(f"[{self.conn_id}] {s}")
self.compute_writer.close()
self.ps_writer.close()
await self.compute_writer.wait_closed()
await self.ps_writer.wait_closed()
async def handle_compute_to_pageserver(self):
while self.state == ConnectionState.HANDSHAKE:
rawmsg = await self.compute_reader.read(1000)
log.debug(f"[{self.conn_id}] C -> PS: handshake msg len {len(rawmsg)}")
self.ps_writer.write(rawmsg)
await self.ps_writer.drain()
while True:
msgtype = await self.compute_reader.readexactly(1)
msglen_bytes = await self.compute_reader.readexactly(4)
(msglen,) = struct.unpack("!L", msglen_bytes)
payload = await self.compute_reader.readexactly(msglen - 4)
# request_callback()
# CopyData
if msgtype == b"d":
# TODO: call callback
log.debug(f"[{self.conn_id}] C -> PS: CopyData ({msglen} bytes)")
pass
else:
log.debug(f"[{self.conn_id}] C -> PS: message of type '{msgtype}' ({msglen} bytes)")
self.ps_writer.write(msgtype)
self.ps_writer.write(msglen_bytes)
self.ps_writer.write(payload)
await self.ps_writer.drain()
async def handle_pageserver_to_compute(self):
while True:
msgtype = await self.ps_reader.readexactly(1)
# response to SSLRequest
if msgtype == b"N" and self.state == ConnectionState.HANDSHAKE:
log.debug(f"[{self.conn_id}] PS -> C: N")
self.compute_writer.write(msgtype)
await self.compute_writer.drain()
continue
msglen_bytes = await self.ps_reader.readexactly(4)
(msglen,) = struct.unpack("!L", msglen_bytes)
payload = await self.ps_reader.readexactly(msglen - 4)
# AuthenticationOK
if msgtype == b"R":
self.state = ConnectionState.AUTHOK
log.debug(f"[{self.conn_id}] PS -> C: AuthenticationOK ({msglen} bytes)")
# CopyBothresponse
elif msgtype == b"W":
self.state = ConnectionState.COPYBOTH
log.debug(f"[{self.conn_id}] PS -> C: CopyBothResponse ({msglen} bytes)")
# CopyData
elif msgtype == b"d":
log.debug(f"[{self.conn_id}] PS -> C: CopyData ({msglen} bytes)")
if self.response_cb is not None:
await self.response_cb(self.conn_id)
pass
else:
log.debug(f"[{self.conn_id}] PS -> C: message of type '{msgtype}' ({msglen} bytes)")
self.compute_writer.write(msgtype)
self.compute_writer.write(msglen_bytes)
self.compute_writer.write(payload)
await self.compute_writer.drain()
class PageserverProxy:
def __init__(self, listen_port: int, dest_port: int, response_cb=None):
self.listen_port = listen_port
self.dest_port = dest_port
self.response_cb = response_cb
self.conn_counter = 0
self.shutdown_event = asyncio.Event()
def shutdown(self):
self.serve_task.cancel()
self.shutdown_event.set()
async def handle_client(self, compute_reader, compute_writer):
self.conn_counter += 1
conn_id = self.conn_counter
conn = Connection(
conn_id,
compute_reader,
compute_writer,
self.shutdown_event,
self.dest_port,
self.response_cb,
)
await conn.run()
async def run_server(self):
log.info("run_server called")
server = await asyncio.start_server(self.handle_client, "localhost", self.listen_port)
self.serve_task = asyncio.create_task(server.serve_forever())
try:
await self.serve_task
except asyncio.CancelledError:
log.info("proxy shutting down")
def launch_server_in_thread(self):
t1 = threading.Thread(target=asyncio.run, args=self.run_server)
t1.start()

View File

@@ -1,12 +1,16 @@
from __future__ import annotations
import asyncio
import random
import time
from asyncio import TaskGroup
import psycopg2.errors
import pytest
from fixtures.log_helper import log
from fixtures.neon_fixtures import NeonEnvBuilder
from fixtures.neon_fixtures import NeonEnvBuilder, PgBin
from fixtures.pageserver_mitm import BreakConnectionException, PageserverProxy
from fixtures.port_distributor import PortDistributor
@pytest.mark.timeout(600)
@@ -80,3 +84,197 @@ def test_compute_pageserver_connection_stress(neon_env_builder: NeonEnvBuilder):
# do a graceful shutdown which would had caught the allowed_errors before
# https://github.com/neondatabase/neon/pull/8632
env.pageserver.stop()
#
# Observations:
#
# 1. If the backend is waiting for response to GetPage request, and the client disconnects,
# the backend will not immediately abort the GetPage request. It will not notice that the
# client is gone, until it tries to send something back to the client, or if a timeout
# kills the query.
#
# So to reproduce the traffic jam, you need:
#
# - A network glitch, which causes one GetPage request/response to be lost or delayed.
# It might get lost at IP level, and TCP retransmits might take a long time. Or there might
# be a glitch in the pageserver or compute, which causes the request to be "stuck".
#
# - An application with a application level timeout and retry. If the
# query doesn't return in a timely a fashion, the application kills the connection and
# retries the query, or a runs similar query that needs the same page.
#
# The first time the GetPage request is stuck and it disconnects, it leaves behind a
# backend that's still waiting for the GetPage response, and is holding the buffer lock.
# The client has closed the connection, but the server doesn't get the memo.
# On each subsequent retry, the connection will block waiting for the buffer lock, give up,
# and leave behind another backend blocked indefinitely.
#
# The situation unravels when the original backend doing the GetPage request finally
# gets a response, or it gets confirmation that the TCP connection is lost.
#
# This test reproduces the traffic jam using a MITM proxy between pageserver and compute,
# and forcing one GetPage request to get stuck.
#
# Recommendations:
# - set client_connection_check_interval = '10s'. This makes Postgres wake up and check
# for client connection loss. It's not perfect, it might not notice if the client has
# e.g rebooted without sending a RST packet, but there's no downside
#
# - Add a timeout to GetPage requests. If no response is received from the pageserver
# in, say, 10 s, terminate the pageserver connection and retry. XXX: Negotiate this
# behavior with the storage team
#
#
@pytest.mark.timeout(120)
def test_compute_pageserver_connection_stress2(
neon_env_builder: NeonEnvBuilder, port_distributor: PortDistributor, pg_bin: PgBin
):
env = neon_env_builder.init_start()
# Set up the MITM proxy
global error_fraction
global delay_fraction
error_fraction = 0
delay_fraction = 0
async def response_cb(conn_id):
global delay_fraction
global error_fraction
if random.random() < error_fraction:
raise BreakConnectionException("unlucky")
orig_delay_fraction = delay_fraction
if random.random() < delay_fraction:
delay_fraction = 0
log.info(f"[{conn_id}] making getpage request STUCK")
try:
await asyncio.sleep(300)
finally:
delay_fraction = orig_delay_fraction
log.info(f"[{conn_id}] delay finished")
mitm_listen_port = port_distributor.get_port()
mitm = PageserverProxy(mitm_listen_port, env.pageserver.service_port.pg, response_cb)
def main():
global error_fraction, delay_fraction
endpoint = env.endpoints.create(
"main",
config_lines=[
"max_connections=1000",
"shared_buffers=8MB",
"log_connections=on",
"log_disconnections=on",
],
)
endpoint.start()
with open(endpoint.pg_data_dir_path() / "postgresql.conf", "a") as conf:
conf.write(
f"neon.pageserver_connstring='postgres://localhost:{mitm_listen_port}' # MITM proxy\n"
)
pg_conn = endpoint.connect()
cur = pg_conn.cursor()
cur.execute("select pg_reload_conf()")
scale = 5
connstr = endpoint.connstr()
log.info(f"Start a pgbench workload on pg {connstr}")
error_fraction = 0.001
pg_bin.run_capture(["pgbench", "-i", "-I", "dtGvp", f"-s{scale}", connstr])
error_fraction = 0.00
delay_fraction = 0.001
cur.execute("select max(aid) from pgbench_accounts")
num_accounts = 100000 * scale
num_clients = 200
app = WalkingApplication(num_accounts, num_clients, endpoint, 1000000)
asyncio.run(app.run())
mitm.shutdown()
async def mm():
await asyncio.gather(asyncio.to_thread(main), mitm.run_server())
asyncio.run(mm())
# do a graceful shutdown which would had caught the allowed_errors before
# https://github.com/neondatabase/neon/pull/8632
env.pageserver.stop()
class WalkingApplication:
"""
A test application with following characteristics:
- It performs single-row lookups in pgbench_accounts table, just like pgbench -S
- Whenever a query takes longer than 10s, the application disconnects, reconnects,
and retries the query, with the same parameter. This way, if there's a problem
with a single page, the application will keep retrying it rather than work
around it.
- The lookups are not randomly distributed, but form a "walking herd" pattern,
where the queries walk through all accounts, with some randomness. This way,
there's a lot of locality of access, but the locality moves throughout the
table.
"""
def __init__(self, num_accounts, num_clients, endpoint, num_xacts):
self.num_accounts = num_accounts
self.num_clients = num_clients
self.endpoint = endpoint
self.running = True
self.num_xacts = num_xacts
self.xacts_started = 0
self.xacts_performed = 0
self.xacts_failed = 0
async def run(self):
async with TaskGroup() as group:
for i in range(1, self.num_clients):
group.create_task(self.walking_client(i))
async def walking_client(self, client_id):
local_xacts_performed = 0
conn = None
stmt = None
failed = False
while self.running and self.xacts_started < self.num_xacts:
self.xacts_started += 1
if not failed:
aid = (self.xacts_started * 100 + random.randint(0, 100)) % self.num_accounts + 1
if conn is None:
conn = await self.endpoint.connect_async()
await conn.execute("set statement_timeout=0")
stmt = await conn.prepare("SELECT abalance FROM pgbench_accounts WHERE aid = $1")
try:
async with asyncio.timeout(10):
res = await stmt.fetchval(aid)
if local_xacts_performed % 1000 == 0:
log.info(
f"[{client_id}] result {self.xacts_performed}/{self.num_xacts}: balance of account {aid}: {res}"
)
self.xacts_performed += 1
local_xacts_performed += 1
failed = False
except TimeoutError:
log.info(f"[{client_id}] query on aid {aid} timed out. Reconnecting")
conn.terminate()
conn = None
failed = True