Mirror of https://github.com/neondatabase/neon.git

Compare commits: release-co... -> heikki/mit... (2 commits: 328408b925, a98fab8b1c)
test_runner/fixtures/pageserver_mitm.py (new file, 201 lines)
@@ -0,0 +1,201 @@
# Intercept compute -> pageserver connections, to simulate various failure modes.
# (A usage sketch is at the end of this file.)

from __future__ import annotations

import asyncio
import struct
import threading
import traceback
from asyncio import TaskGroup
from enum import Enum

from fixtures.log_helper import log


class ConnectionState(Enum):
    HANDSHAKE = 1
    AUTHOK = 2
    COPYBOTH = 3


class BreakConnectionException(Exception):
    """Raised by a callback to abruptly break the proxied connection."""

    def __init__(self, message):
        super().__init__(message)
        self.message = message


class ProxyShutdownException(Exception):
    """Raised to shut down a connection when the proxy is shut down."""


# The handshake flow:
#
# 1. compute establishes TCP connection
# 2. libpq handshake and auth
# 3. Enter CopyBoth mode
#
# From then on, CopyData messages are exchanged in both directions.
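#
# Every message after the handshake uses the regular libpq framing, which is
# what the two handler coroutines below parse:
#
#   1 byte   message type, e.g. b"d" for CopyData
#   4 bytes  length in network byte order, including the length word itself
#   ...      payload, of length - 4 bytes
#
# The lone 'N' byte sent in response to SSLRequest during the handshake is the
# exception: it has no length word.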
class Connection:
    def __init__(
        self,
        conn_id,
        compute_reader,
        compute_writer,
        shutdown_event,
        dest_port: int,
        response_cb=None,
    ):
        self.conn_id = conn_id
        self.compute_reader = compute_reader
        self.compute_writer = compute_writer
        self.shutdown_event = shutdown_event
        self.response_cb = response_cb
        self.dest_port = dest_port

        # Established in run(); left as None if connecting to the pageserver fails
        self.ps_reader = None
        self.ps_writer = None

        self.state = ConnectionState.HANDSHAKE

    async def run(self):
        async def wait_for_shutdown():
            await self.shutdown_event.wait()
            raise ProxyShutdownException

        try:
            addr = self.compute_writer.get_extra_info("peername")
            log.info(f"[{self.conn_id}] connection received from {addr}")

            async with TaskGroup() as group:
                group.create_task(wait_for_shutdown())

                self.ps_reader, self.ps_writer = await asyncio.open_connection(
                    "localhost", self.dest_port
                )

                group.create_task(self.handle_compute_to_pageserver())
                group.create_task(self.handle_pageserver_to_compute())

        except* ProxyShutdownException:
            log.info(f"[{self.conn_id}] proxy shutting down")

        except* asyncio.exceptions.IncompleteReadError as e:
            log.info(f"[{self.conn_id}] EOF reached: {e}")

        except* BreakConnectionException as eg:
            for e in eg.exceptions:
                log.info(f"[{self.conn_id}] callback breaks connection: {e}")

        except* Exception as e:
            s = "\n".join(traceback.format_exception(e))
            log.info(f"[{self.conn_id}] {s}")

        self.compute_writer.close()
        await self.compute_writer.wait_closed()
        # The pageserver connection may never have been established
        if self.ps_writer is not None:
            self.ps_writer.close()
            await self.ps_writer.wait_closed()

    async def handle_compute_to_pageserver(self):
        # During the handshake, just forward raw bytes
        while self.state == ConnectionState.HANDSHAKE:
            rawmsg = await self.compute_reader.read(1000)
            log.debug(f"[{self.conn_id}] C -> PS: handshake msg len {len(rawmsg)}")
            self.ps_writer.write(rawmsg)
            await self.ps_writer.drain()

        while True:
            msgtype = await self.compute_reader.readexactly(1)
            msglen_bytes = await self.compute_reader.readexactly(4)
            (msglen,) = struct.unpack("!L", msglen_bytes)
            payload = await self.compute_reader.readexactly(msglen - 4)

            # CopyData
            if msgtype == b"d":
                # TODO: call a request callback here, analogous to response_cb
                log.debug(f"[{self.conn_id}] C -> PS: CopyData ({msglen} bytes)")
            else:
                log.debug(f"[{self.conn_id}] C -> PS: message of type '{msgtype}' ({msglen} bytes)")

            self.ps_writer.write(msgtype)
            self.ps_writer.write(msglen_bytes)
            self.ps_writer.write(payload)
            await self.ps_writer.drain()

    async def handle_pageserver_to_compute(self):
        while True:
            msgtype = await self.ps_reader.readexactly(1)

            # response to SSLRequest: a single byte, with no length word
            if msgtype == b"N" and self.state == ConnectionState.HANDSHAKE:
                log.debug(f"[{self.conn_id}] PS -> C: N")
                self.compute_writer.write(msgtype)
                await self.compute_writer.drain()
                continue

            msglen_bytes = await self.ps_reader.readexactly(4)
            (msglen,) = struct.unpack("!L", msglen_bytes)
            payload = await self.ps_reader.readexactly(msglen - 4)

            # AuthenticationOK
            if msgtype == b"R":
                self.state = ConnectionState.AUTHOK
                log.debug(f"[{self.conn_id}] PS -> C: AuthenticationOK ({msglen} bytes)")

            # CopyBothResponse
            elif msgtype == b"W":
                self.state = ConnectionState.COPYBOTH
                log.debug(f"[{self.conn_id}] PS -> C: CopyBothResponse ({msglen} bytes)")

            # CopyData
            elif msgtype == b"d":
                log.debug(f"[{self.conn_id}] PS -> C: CopyData ({msglen} bytes)")
                if self.response_cb is not None:
                    await self.response_cb(self.conn_id)

            else:
                log.debug(f"[{self.conn_id}] PS -> C: message of type '{msgtype}' ({msglen} bytes)")

            self.compute_writer.write(msgtype)
            self.compute_writer.write(msglen_bytes)
            self.compute_writer.write(payload)
            await self.compute_writer.drain()


class PageserverProxy:
    def __init__(self, listen_port: int, dest_port: int, response_cb=None):
        self.listen_port = listen_port
        self.dest_port = dest_port
        self.response_cb = response_cb
        self.conn_counter = 0
        self.shutdown_event = asyncio.Event()

    def shutdown(self):
        self.serve_task.cancel()
        self.shutdown_event.set()

    async def handle_client(self, compute_reader, compute_writer):
        self.conn_counter += 1
        conn_id = self.conn_counter
        conn = Connection(
            conn_id,
            compute_reader,
            compute_writer,
            self.shutdown_event,
            self.dest_port,
            self.response_cb,
        )
        await conn.run()

    async def run_server(self):
        log.info("run_server called")
        server = await asyncio.start_server(self.handle_client, "localhost", self.listen_port)

        self.serve_task = asyncio.create_task(server.serve_forever())

        try:
            await self.serve_task
        except asyncio.CancelledError:
            log.info("proxy shutting down")

    def launch_server_in_thread(self):
        # Note: Thread's 'args' must be a tuple, and asyncio.run() expects the
        # coroutine object itself, so call run_server() here
        t1 = threading.Thread(target=asyncio.run, args=(self.run_server(),))
        t1.start()
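
# A minimal usage sketch (hypothetical port numbers; see
# test_compute_pageserver_connection_stress2 for real usage):
#
#   proxy = PageserverProxy(listen_port=55433, dest_port=64000)
#   proxy.launch_server_in_thread()  # or: await proxy.run_server()
#   ... point neon.pageserver_connstring at listen_port ...
#   proxy.shutdown()
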
@@ -1,12 +1,16 @@
 from __future__ import annotations
 
+import asyncio
 import random
 import time
+from asyncio import TaskGroup
 
 import psycopg2.errors
 import pytest
 from fixtures.log_helper import log
-from fixtures.neon_fixtures import NeonEnvBuilder
+from fixtures.neon_fixtures import NeonEnvBuilder, PgBin
+from fixtures.pageserver_mitm import BreakConnectionException, PageserverProxy
+from fixtures.port_distributor import PortDistributor
 
 
 @pytest.mark.timeout(600)
@@ -80,3 +84,197 @@ def test_compute_pageserver_connection_stress(neon_env_builder: NeonEnvBuilder):
    # do a graceful shutdown which would have caught the allowed_errors before
    # https://github.com/neondatabase/neon/pull/8632
    env.pageserver.stop()


#
# Observations:
#
# 1. If the backend is waiting for a response to a GetPage request, and the client
#    disconnects, the backend will not immediately abort the GetPage request. It will
#    not notice that the client is gone until it tries to send something back to the
#    client, or a timeout kills the query.
#
# So to reproduce the traffic jam, you need:
#
# - A network glitch, which causes one GetPage request/response to be lost or delayed.
#   It might get lost at the IP level, and TCP retransmits might take a long time. Or
#   there might be a glitch in the pageserver or compute, which causes the request to
#   get "stuck".
#
# - An application with an application-level timeout and retry. If the query doesn't
#   return in a timely fashion, the application kills the connection and retries the
#   query, or runs a similar query that needs the same page.
#
# The first time a GetPage request gets stuck, the disconnect leaves behind a backend
# that's still waiting for the GetPage response, and is holding the buffer lock.
# The client has closed the connection, but the server doesn't get the memo.
# On each subsequent retry, the connection will block waiting for the buffer lock,
# give up, and leave behind another backend blocked indefinitely.
#
# The situation unravels when the original backend doing the GetPage request finally
# gets a response, or gets confirmation that the TCP connection is lost.
#
# This test reproduces the traffic jam using a MITM proxy between pageserver and
# compute, and forcing one GetPage request to get stuck.
#
# Recommendations:
#
# - Set client_connection_check_interval = '10s'. This makes Postgres wake up and
#   check for client connection loss. It's not perfect: it might not notice if the
#   client has e.g. rebooted without sending a RST packet, but there's no downside.
#
# - Add a timeout to GetPage requests. If no response is received from the pageserver
#   in, say, 10 s, terminate the pageserver connection and retry. XXX: Negotiate this
#   behavior with the storage team.
#
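# For example, the first recommendation could be applied via the endpoint's
# config (a sketch; this test does not currently set it):
#
#   endpoint = env.endpoints.create(
#       "main",
#       config_lines=["client_connection_check_interval = '10s'"],
#   )
#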
@pytest.mark.timeout(120)
def test_compute_pageserver_connection_stress2(
    neon_env_builder: NeonEnvBuilder, port_distributor: PortDistributor, pg_bin: PgBin
):
    env = neon_env_builder.init_start()

    # Set up the MITM proxy. error_fraction and delay_fraction control the
    # failure injection; they are module-level globals so that response_cb and
    # main() below can both adjust them as the test progresses.
    global error_fraction
    global delay_fraction

    error_fraction = 0
    delay_fraction = 0

    async def response_cb(conn_id):
        global delay_fraction
        global error_fraction

        # With probability error_fraction, break the connection abruptly
        if random.random() < error_fraction:
            raise BreakConnectionException("unlucky")

        # With probability delay_fraction, hold up this GetPage response for a
        # long time. Clear delay_fraction while delaying, so that only one
        # response at a time gets stuck.
        orig_delay_fraction = delay_fraction
        if random.random() < delay_fraction:
            delay_fraction = 0
            log.info(f"[{conn_id}] making getpage request STUCK")
            try:
                await asyncio.sleep(300)
            finally:
                delay_fraction = orig_delay_fraction
                log.info(f"[{conn_id}] delay finished")

    mitm_listen_port = port_distributor.get_port()
    mitm = PageserverProxy(mitm_listen_port, env.pageserver.service_port.pg, response_cb)

    def main():
        global error_fraction, delay_fraction
        endpoint = env.endpoints.create(
            "main",
            config_lines=[
                "max_connections=1000",
                "shared_buffers=8MB",
                "log_connections=on",
                "log_disconnections=on",
            ],
        )
        endpoint.start()

        # Point the compute at the MITM proxy instead of the real pageserver
        with open(endpoint.pg_data_dir_path() / "postgresql.conf", "a") as conf:
            conf.write(
                f"neon.pageserver_connstring='postgres://localhost:{mitm_listen_port}' # MITM proxy\n"
            )

        pg_conn = endpoint.connect()
        cur = pg_conn.cursor()

        # Apply the neon.pageserver_connstring change
        cur.execute("select pg_reload_conf()")

        scale = 5
        connstr = endpoint.connstr()
        log.info(f"Start a pgbench workload on pg {connstr}")

        # Initialize pgbench with occasional broken pageserver connections
        error_fraction = 0.001
        pg_bin.run_capture(["pgbench", "-i", "-I", "dtGvp", f"-s{scale}", connstr])

        # For the main workload, inject stuck responses instead of broken connections
        error_fraction = 0.00
        delay_fraction = 0.001

cur.execute("select max(aid) from pgbench_accounts")
|
||||
num_accounts = 100000 * scale
|
||||
|
||||
num_clients = 200
|
||||
|
||||
app = WalkingApplication(num_accounts, num_clients, endpoint, 1000000)
|
||||
asyncio.run(app.run())
|
||||
|
||||
mitm.shutdown()
|
||||
|
||||
async def mm():
|
||||
await asyncio.gather(asyncio.to_thread(main), mitm.run_server())
|
||||
|
||||
asyncio.run(mm())
|
||||
|
||||
    # do a graceful shutdown which would have caught the allowed_errors before
    # https://github.com/neondatabase/neon/pull/8632
    env.pageserver.stop()


class WalkingApplication:
    """
    A test application with the following characteristics:

    - It performs single-row lookups in the pgbench_accounts table, just like
      pgbench -S.

    - Whenever a query takes longer than 10 s, the application disconnects,
      reconnects, and retries the query with the same parameter. This way, if
      there's a problem with a single page, the application will keep retrying
      it rather than work around it.

    - The lookups are not randomly distributed, but form a "walking herd" pattern,
      where the queries walk through all the accounts, with some randomness. This
      way, there's a lot of locality of access, but the locality moves throughout
      the table.
    """
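
    # For example, with the test's num_accounts = 500000 (scale 5), transaction
    # number n touches aid = (n * 100 + jitter) % 500000 + 1 with jitter in
    # [0, 100], so the herd sweeps the whole table roughly every 5000 transactions.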
    def __init__(self, num_accounts, num_clients, endpoint, num_xacts):
        self.num_accounts = num_accounts
        self.num_clients = num_clients
        self.endpoint = endpoint
        self.running = True
        self.num_xacts = num_xacts

        self.xacts_started = 0
        self.xacts_performed = 0
        self.xacts_failed = 0

    async def run(self):
        async with TaskGroup() as group:
            # client ids start at 1, so this launches num_clients - 1 tasks
            for i in range(1, self.num_clients):
                group.create_task(self.walking_client(i))

    async def walking_client(self, client_id):
        local_xacts_performed = 0

        conn = None
        stmt = None
        failed = False
        while self.running and self.xacts_started < self.num_xacts:
            self.xacts_started += 1
            # After a timeout, retry with the same aid; otherwise pick the next
            # one in the "walking" pattern
            if not failed:
                aid = (self.xacts_started * 100 + random.randint(0, 100)) % self.num_accounts + 1

            if conn is None:
                conn = await self.endpoint.connect_async()
                await conn.execute("set statement_timeout=0")
                stmt = await conn.prepare("SELECT abalance FROM pgbench_accounts WHERE aid = $1")

            try:
                async with asyncio.timeout(10):
                    res = await stmt.fetchval(aid)
                    if local_xacts_performed % 1000 == 0:
                        log.info(
                            f"[{client_id}] result {self.xacts_performed}/{self.num_xacts}: balance of account {aid}: {res}"
                        )
                    self.xacts_performed += 1
                    local_xacts_performed += 1
                    failed = False
            except TimeoutError:
                log.info(f"[{client_id}] query on aid {aid} timed out. Reconnecting")
                conn.terminate()
                conn = None
                failed = True