neon/scripts/force_layer_download.py

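# Force on-demand download of all remote layers for the given tenants and
# timelines on a pageserver: spawn one download_remote_layers task per
# timeline via the pageserver HTTP API, poll the tasks until they reach a
# terminal state, and write a JSON report of the outcomes.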
from __future__ import annotations

import argparse
import asyncio
import json
import logging
import signal
import sys
from collections import defaultdict
from dataclasses import dataclass
from typing import TYPE_CHECKING

import aiohttp

if TYPE_CHECKING:
    from collections.abc import Awaitable
    from typing import Any


class ClientException(Exception):
    pass
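

# Thin async client for the pageserver HTTP API; only the endpoints this
# script needs are implemented.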
class Client:
    def __init__(self, pageserver_api_endpoint: str, max_concurrent_layer_downloads: int):
        self.endpoint = pageserver_api_endpoint
        self.max_concurrent_layer_downloads = max_concurrent_layer_downloads
        self.sess = aiohttp.ClientSession()

    async def close(self):
        await self.sess.close()

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_t, exc_v, exc_tb):
        await self.close()
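
    # Decode the JSON body and check its top-level type, raising
    # ClientException for HTTP errors and unexpected shapes.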
    async def parse_response(self, resp, expected_type):
        body = await resp.json()
        if not resp.ok:
            raise ClientException(f"Response: {resp} Body: {body}")
        if not isinstance(body, expected_type):
            raise ClientException(f"expecting {expected_type.__name__}")
        return body

    async def get_tenant_ids(self):
        resp = await self.sess.get(f"{self.endpoint}/v1/tenant")
        payload = await self.parse_response(resp=resp, expected_type=list)
        return [t["id"] for t in payload]

    async def get_timeline_ids(self, tenant_id):
        resp = await self.sess.get(f"{self.endpoint}/v1/tenant/{tenant_id}/timeline")
        payload = await self.parse_response(resp=resp, expected_type=list)
        return [t["timeline_id"] for t in payload]
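
    # Spawn a download_remote_layers task for the timeline. The pageserver
    # answers 409 if a task is already running; unless ongoing_ok is set,
    # that is treated as an error.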
    async def timeline_spawn_download_remote_layers(
        self, tenant_id, timeline_id, ongoing_ok=False
    ):
        resp = await self.sess.post(
            f"{self.endpoint}/v1/tenant/{tenant_id}/timeline/{timeline_id}/download_remote_layers",
            json={"max_concurrent_downloads": self.max_concurrent_layer_downloads},
        )
        body = await resp.json()
        if resp.status == 409:
            if not ongoing_ok:
                raise ClientException("download already ongoing")
            # response body has same shape for ongoing and newly created
        elif not resp.ok:
            raise ClientException(f"Response: {resp} Body: {body}")
        if not isinstance(body, dict):
            raise ClientException("expecting dict")
        return body
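
    # Fetch the current task status; returns None if no task exists yet (404).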
    async def timeline_poll_download_remote_layers_status(
        self,
        tenant_id,
        timeline_id,
    ):
        resp = await self.sess.get(
            f"{self.endpoint}/v1/tenant/{tenant_id}/timeline/{timeline_id}/download_remote_layers",
        )
        body = await resp.json()
        if resp.status == 404:
            return None
        elif not resp.ok:
            raise ClientException(f"Response: {resp} Body: {body}")
        return body


@dataclass
class Completed:
    """The status dict returned by the API."""

    status: dict[str, Any]
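

# Set by the SIGINT handler; do_timeline checks it before starting new downloads.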
sigint_received = asyncio.Event()


async def do_timeline(client: Client, tenant_id, timeline_id):
    """
    Spawn a download_remote_layers task for the given timeline, then poll
    until the task reaches a terminal state.

    Raises an exception if a task is already running when this coroutine is
    invoked, or if the terminal state is not 'Completed'. Note that a
    'Completed' task may still have failed to download individual layers;
    the caller is responsible for inspecting `failed_download_count`.
    """
    # Don't start new downloads if the user pressed SIGINT.
    # This task will show up as "raised_exception" in the report.
    if sigint_received.is_set():
        raise Exception("not starting because SIGINT received")

    # run downloads to completion
    status = await client.timeline_poll_download_remote_layers_status(tenant_id, timeline_id)
    if status is not None and status["state"] == "Running":
        raise Exception("download is already running")

    spawned = await client.timeline_spawn_download_remote_layers(
        tenant_id, timeline_id, ongoing_ok=False
    )

    while True:
        st = await client.timeline_poll_download_remote_layers_status(tenant_id, timeline_id)
        logging.info(f"{tenant_id}:{timeline_id} state is: {st}")
        if st is None:
            raise ClientException("task status disappeared while polling")
        if spawned["task_id"] != st["task_id"]:
            raise ClientException("download task ids changed while polling")
        if st["state"] == "Running":
            await asyncio.sleep(10)
            continue
        if st["state"] != "Completed":
            raise ClientException(
                f"download task reached terminal state != Completed: {st['state']}"
            )
        return Completed(st)


def handle_sigint():
    logging.info("SIGINT received, asyncio event set. Will not start new downloads.")
    sigint_received.set()


async def main(args):
    # asyncio.run() creates a fresh event loop, so the SIGINT handler must be
    # registered here on the running loop; registering it on a loop created
    # before asyncio.run() would have no effect.
    asyncio.get_running_loop().add_signal_handler(signal.SIGINT, handle_sigint)
    async with Client(args.pageserver_http_endpoint, args.max_concurrent_layer_downloads) as client:
        exit_code = await main_impl(args, args.report_output, client)
    return exit_code
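

# Worker coroutine: drain (task_id, coroutine) pairs from task_q, await each
# one, and push (task_id, result-or-exception) onto result_q. Returns once
# task_q is empty.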
async def taskq_handler(task_q, result_q):
    while True:
        try:
            (task_id, fut) = task_q.get_nowait()
        except asyncio.QueueEmpty:
            logging.debug("taskq_handler observed empty task_q, returning")
            return
        logging.info(f"starting task {task_id}")
        try:
            res = await fut
        except Exception as e:
            res = e
        result_q.put_nowait((task_id, res))
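

# Periodically log overall progress while the workers run.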
async def print_progress(result_q, tasks):
    while True:
        await asyncio.sleep(10)
        logging.info(f"{result_q.qsize()} / {len(tasks)} tasks done")


async def main_impl(args, report_out, client: Client):
    """
    Expand the what-specs into (tenant_id, timeline_id) pairs, run
    do_timeline for each pair at the requested concurrency, and write the
    JSON report. Returns the OS exit status.
    """
    tenant_and_timeline_ids: list[tuple[str, str]] = []
    # fill tenant_and_timeline_ids based on spec
    for spec in args.what:
        comps = spec.split(":")
        if comps == ["ALL"]:
            logging.info("get tenant list")
            tenant_ids = await client.get_tenant_ids()
            get_timeline_id_coros = [client.get_timeline_ids(tenant_id) for tenant_id in tenant_ids]
            gathered = await asyncio.gather(*get_timeline_id_coros, return_exceptions=True)
            tenant_and_timeline_ids = []
            for tid, tlids in zip(tenant_ids, gathered, strict=True):
                # TODO: add error handling if tlids isinstance(Exception)
                assert isinstance(tlids, list)
                for tlid in tlids:
                    tenant_and_timeline_ids.append((tid, tlid))
        elif len(comps) == 1:
            tid = comps[0]
            tlids = await client.get_timeline_ids(tid)
            for tlid in tlids:
                tenant_and_timeline_ids.append((tid, tlid))
        elif len(comps) == 2:
            tenant_and_timeline_ids.append((comps[0], comps[1]))
        else:
            raise ValueError(f"invalid what-spec: {spec}")

    logging.info("expanded spec:")
    for tid, tlid in tenant_and_timeline_ids:
        logging.info(f"{tid}:{tlid}")

    logging.info("remove duplicates after expanding spec")
    tmp = list(set(tenant_and_timeline_ids))
    assert len(tmp) <= len(tenant_and_timeline_ids)
    if len(tmp) != len(tenant_and_timeline_ids):
        logging.info(f"spec had {len(tenant_and_timeline_ids) - len(tmp)} duplicates")
    tenant_and_timeline_ids = tmp
logging.info("create tasks and process them at specified concurrency")
task_q: asyncio.Queue[tuple[str, Awaitable[Any]]] = asyncio.Queue()
tasks = {
f"{tid}:{tlid}": do_timeline(client, tid, tlid) for tid, tlid in tenant_and_timline_ids
}
for task in tasks.items():
task_q.put_nowait(task)
result_q: asyncio.Queue[tuple[str, Any]] = asyncio.Queue()
taskq_handlers = []
for _ in range(0, args.concurrent_tasks):
taskq_handlers.append(taskq_handler(task_q, result_q))
print_progress_task = asyncio.create_task(print_progress(result_q, tasks))
await asyncio.gather(*taskq_handlers)
print_progress_task.cancel()
logging.info("all tasks handled, generating report")
results = []
while True:
try:
results.append(result_q.get_nowait())
except asyncio.QueueEmpty:
break
assert task_q.empty()
report = defaultdict(list)
for id, result in results:
logging.info(f"result for {id}: {result}")
if isinstance(result, Completed):
if result.status["failed_download_count"] == 0:
report["completed_without_errors"].append(id)
else:
report["completed_with_download_errors"].append(id)
elif isinstance(result, Exception):
report["raised_exception"].append(id)
else:
raise ValueError("unexpected result type")
json.dump(report, report_out)
logging.info("--------------------------------------------------------------------------------")
report_success = len(report["completed_without_errors"]) == len(tenant_and_timline_ids)
if not report_success:
logging.error("One or more tasks encountered errors.")
else:
logging.info("All tasks reported success.")
logging.info("Inspect log for details and report file for JSON summary.")
return report_success
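

# Example invocations (tenant and timeline IDs are placeholders):
#   ./force_layer_download.py ALL
#   ./force_layer_download.py <tenant_id>
#   ./force_layer_download.py <tenant_id>:<timeline_id>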
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--report-output",
        type=argparse.FileType("w"),
        default="-",
        help="where to write report output (default: stdout)",
    )
    parser.add_argument(
        "--pageserver-http-endpoint",
        default="http://localhost:9898",
        help="pageserver http endpoint (default: http://localhost:9898)",
    )
    parser.add_argument(
        "--concurrent-tasks",
        required=False,
        default=5,
        type=int,
        help="Max concurrent download tasks created & polled by this script",
    )
    parser.add_argument(
        "--max-concurrent-layer-downloads",
        dest="max_concurrent_layer_downloads",
        required=False,
        default=8,
        type=int,
        help="Max concurrent download tasks spawned by the pageserver. Each layer is a separate task.",
    )
    parser.add_argument(
        "what",
        nargs="+",
        help="what to download: ALL|tenant_id|tenant_id:timeline_id",
    )
    parser.add_argument(
        "--verbose",
        action="store_true",
        help="enable verbose logging",
    )
    args = parser.parse_args()

    level = logging.INFO
    if args.verbose:
        level = logging.DEBUG
    logging.basicConfig(
        format="%(asctime)s,%(msecs)03d %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s",
        datefmt="%Y-%m-%d:%H:%M:%S",
        level=level,
    )

    # The SIGINT handler is registered inside main() on the event loop that
    # asyncio.run() creates; see handle_sigint().
    sys.exit(asyncio.run(main(args)))