Mirror of https://github.com/neondatabase/neon.git, synced 2025-12-23 06:09:59 +00:00
add script to download all remote layers (#3294)

For use in production, in case on-demand download turns out to be problematic during tenant_attach, or when we eventually introduce layer eviction.
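A usage sketch (the flags and the what-spec grammar come from the script's argparse setup below; the endpoint shown is its default): running python3 scripts/force_layer_download.py --pageserver-http-endpoint http://localhost:9898 ALL downloads every remote layer of every timeline on the pageserver; pass tenant_id or tenant_id:timeline_id in place of ALL to narrow the scope, and --concurrent-tasks / --max-concurrent-layer-downloads to tune parallelism.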
Co-authored-by: Dmitry Rodionov <dmitry@neon.tech>
committed by GitHub
parent 01b4b0c2f3
commit 8963d830fb
324 scripts/force_layer_download.py (new file)
@@ -0,0 +1,324 @@
import argparse
import asyncio
import json
import logging
import signal
import sys
from collections import defaultdict
from dataclasses import dataclass
from typing import Any, Awaitable, Dict, List, Tuple

import aiohttp


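# Minimal client for the pageserver HTTP API; only the endpoints this script
# needs are wrapped here.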
class ClientException(Exception):
    pass


class Client:
    def __init__(self, pageserver_api_endpoint: str, max_concurrent_layer_downloads: int):
        self.endpoint = pageserver_api_endpoint
        self.max_concurrent_layer_downloads = max_concurrent_layer_downloads
        self.sess = aiohttp.ClientSession()

    async def close(self):
        await self.sess.close()

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_t, exc_v, exc_tb):
        await self.close()

    async def parse_response(self, resp, expected_type):
        body = await resp.json()
        if not resp.ok:
            raise ClientException(f"Response: {resp} Body: {body}")

        if not isinstance(body, expected_type):
            raise ClientException(f"expecting {expected_type.__name__}")
        return body

    async def get_tenant_ids(self):
        resp = await self.sess.get(f"{self.endpoint}/v1/tenant")
        payload = await self.parse_response(resp=resp, expected_type=list)
        return [t["id"] for t in payload]

    async def get_timeline_ids(self, tenant_id):
        resp = await self.sess.get(f"{self.endpoint}/v1/tenant/{tenant_id}/timeline")
        payload = await self.parse_response(resp=resp, expected_type=list)
        return [t["timeline_id"] for t in payload]

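    # The POST below spawns a download_remote_layers task on the pageserver
    # (409 means one is already running); the GET polls that task's status.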
    async def timeline_spawn_download_remote_layers(self, tenant_id, timeline_id, ongoing_ok=False):
        resp = await self.sess.post(
            f"{self.endpoint}/v1/tenant/{tenant_id}/timeline/{timeline_id}/download_remote_layers",
            json={"max_concurrent_downloads": self.max_concurrent_layer_downloads},
        )
        body = await resp.json()
        if resp.status == 409:
            if not ongoing_ok:
                raise ClientException("download already ongoing")
            # response body has the same shape for ongoing and newly created tasks
        elif not resp.ok:
            raise ClientException(f"Response: {resp} Body: {body}")

        if not isinstance(body, dict):
            raise ClientException("expecting dict")

        return body

    async def timeline_poll_download_remote_layers_status(
        self,
        tenant_id,
        timeline_id,
    ):
        resp = await self.sess.get(
            f"{self.endpoint}/v1/tenant/{tenant_id}/timeline/{timeline_id}/download_remote_layers",
        )
        body = await resp.json()

        if resp.status == 404:
            return None
        elif not resp.ok:
            raise ClientException(f"Response: {resp} Body: {body}")

        return body


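# For reference, a status body as consumed by this script looks roughly like
#   {"task_id": 42, "state": "Running", "failed_download_count": 0}
# Only these three fields are relied on below; the 42 and any further fields
# are illustrative assumptions, and "Running"/"Completed" are the states the
# polling loop handles.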
@dataclass
class Completed:
    """The status dict returned by the API"""

    status: Dict[str, Any]


sigint_received = asyncio.Event()


async def do_timeline(client: Client, tenant_id, timeline_id):
    """
    Spawn a download_remote_layers task for the given timeline,
    then poll until the download has reached a terminal state.

    If the terminal state is not 'Completed', this function raises an exception.
    The caller is responsible for inspecting `failed_download_count`.

    If there is already a task ongoing when this function is invoked,
    it raises an exception.
    """

    # Don't start new downloads if the user pressed SIGINT.
    # This task will show up as "raised_exception" in the report.
    if sigint_received.is_set():
        raise Exception("not starting because SIGINT received")

    # run downloads to completion

    status = await client.timeline_poll_download_remote_layers_status(tenant_id, timeline_id)
    if status is not None and status["state"] == "Running":
        raise Exception("download is already running")

    spawned = await client.timeline_spawn_download_remote_layers(
        tenant_id, timeline_id, ongoing_ok=False
    )

    while True:
        st = await client.timeline_poll_download_remote_layers_status(tenant_id, timeline_id)
        logging.info(f"{tenant_id}:{timeline_id} state is: {st}")

        # the poll returns None on 404, which would be unexpected right after
        # a successful spawn
        if st is None:
            raise ClientException("download task status disappeared while polling")

        if spawned["task_id"] != st["task_id"]:
            raise ClientException("download task ids changed while polling")

        if st["state"] == "Running":
            await asyncio.sleep(10)
            continue

        if st["state"] != "Completed":
            raise ClientException(
                f"download task reached terminal state != Completed: {st['state']}"
            )

        return Completed(st)


def handle_sigint():
    logging.info("SIGINT received, asyncio event set. Will not start new downloads.")
    sigint_received.set()


async def main(args):
    # asyncio.run() creates a fresh event loop, so the SIGINT handler must be
    # registered on the running loop from inside the coroutine, not before
    # asyncio.run() is called.
    asyncio.get_running_loop().add_signal_handler(signal.SIGINT, handle_sigint)

    async with Client(args.pageserver_http_endpoint, args.max_concurrent_layer_downloads) as client:
        exit_code = await main_impl(args, args.report_output, client)

    return exit_code


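# Fixed-size worker pool over asyncio queues: each taskq_handler pulls an
# (id, coroutine) pair, awaits it, and pushes (id, result-or-exception) to
# result_q, so one failing timeline never kills the pool.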
async def taskq_handler(task_q, result_q):
    while True:
        try:
            (id, fut) = task_q.get_nowait()
        except asyncio.QueueEmpty:
            logging.debug("taskq_handler observed empty task_q, returning")
            return
        logging.info(f"starting task {id}")
        try:
            res = await fut
        except Exception as e:
            res = e
        result_q.put_nowait((id, res))


async def print_progress(result_q, tasks):
    while True:
        await asyncio.sleep(10)
        logging.info(f"{result_q.qsize()} / {len(tasks)} tasks done")


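# main_impl() expands the positional what-specs into (tenant_id, timeline_id)
# pairs, pushes one do_timeline() coroutine per pair through the worker pool,
# and reduces the results into the JSON report.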
async def main_impl(args, report_out, client: Client):
    """
    Returns the OS exit status.
    """
    tenant_and_timeline_ids: List[Tuple[str, str]] = []
    # fill tenant_and_timeline_ids based on the spec
    for spec in args.what:
        comps = spec.split(":")
        if comps == ["ALL"]:
            logging.info("get tenant list")
            tenant_ids = await client.get_tenant_ids()
            get_timeline_id_coros = [client.get_timeline_ids(tenant_id) for tenant_id in tenant_ids]
            gathered = await asyncio.gather(*get_timeline_id_coros, return_exceptions=True)
            assert len(tenant_ids) == len(gathered)
            tenant_and_timeline_ids = []
            for tid, tlids in zip(tenant_ids, gathered):
                # gather(..., return_exceptions=True) hands exceptions back as
                # values; surface them instead of trying to iterate them
                if isinstance(tlids, BaseException):
                    raise tlids
                for tlid in tlids:
                    tenant_and_timeline_ids.append((tid, tlid))
        elif len(comps) == 1:
            tid = comps[0]
            tlids = await client.get_timeline_ids(tid)
            for tlid in tlids:
                tenant_and_timeline_ids.append((tid, tlid))
        elif len(comps) == 2:
            tenant_and_timeline_ids.append((comps[0], comps[1]))
        else:
            raise ValueError(f"invalid what-spec: {spec}")

logging.info("expanded spec:")
|
||||
for tid, tlid in tenant_and_timline_ids:
|
||||
logging.info(f"{tid}:{tlid}")
|
||||
|
||||
logging.info("remove duplicates after expanding spec")
|
||||
tmp = list(set(tenant_and_timline_ids))
|
||||
assert len(tmp) <= len(tenant_and_timline_ids)
|
||||
if len(tmp) != len(tenant_and_timline_ids):
|
||||
logging.info(f"spec had {len(tenant_and_timline_ids) - len(tmp)} duplicates")
|
||||
tenant_and_timline_ids = tmp
|
||||
|
||||
logging.info("create tasks and process them at specified concurrency")
|
||||
task_q: asyncio.Queue[Tuple[str, Awaitable[Any]]] = asyncio.Queue()
|
||||
tasks = {
|
||||
f"{tid}:{tlid}": do_timeline(client, tid, tlid) for tid, tlid in tenant_and_timline_ids
|
||||
}
|
||||
for task in tasks.items():
|
||||
task_q.put_nowait(task)
|
||||
|
||||
result_q: asyncio.Queue[Tuple[str, Any]] = asyncio.Queue()
|
||||
taskq_handlers = []
|
||||
for _ in range(0, args.concurrent_tasks):
|
||||
taskq_handlers.append(taskq_handler(task_q, result_q))
|
||||
|
||||
print_progress_task = asyncio.create_task(print_progress(result_q, tasks))
|
||||
|
||||
await asyncio.gather(*taskq_handlers)
|
||||
print_progress_task.cancel()
|
||||
|
||||
logging.info("all tasks handled, generating report")
|
||||
|
||||
    results = []
    while True:
        try:
            results.append(result_q.get_nowait())
        except asyncio.QueueEmpty:
            break
    assert task_q.empty()

    report = defaultdict(list)
    for id, result in results:
        logging.info(f"result for {id}: {result}")
        if isinstance(result, Completed):
            if result.status["failed_download_count"] == 0:
                report["completed_without_errors"].append(id)
            else:
                report["completed_with_download_errors"].append(id)
        elif isinstance(result, Exception):
            report["raised_exception"].append(id)
        else:
            raise ValueError("unexpected result type")
    json.dump(report, report_out)

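    # The report is a JSON object keyed by outcome, roughly:
    #   {"completed_without_errors": ["<tenant>:<timeline>", ...],
    #    "completed_with_download_errors": [...],
    #    "raised_exception": [...]}
    # (keys as populated above; the ID strings are placeholders)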
logging.info("--------------------------------------------------------------------------------")
|
||||
|
||||
report_success = len(report["completed_without_errors"]) == len(tenant_and_timline_ids)
|
||||
if not report_success:
|
||||
logging.error("One or more tasks encountered errors.")
|
||||
else:
|
||||
logging.info("All tasks reported success.")
|
||||
logging.info("Inspect log for details and report file for JSON summary.")
|
||||
|
||||
return report_success
|
||||
|
||||
|
||||
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--report-output",
        type=argparse.FileType("w"),
        default="-",
        help="where to write report output (default: stdout)",
    )
    parser.add_argument(
        "--pageserver-http-endpoint",
        default="http://localhost:9898",
        help="pageserver http endpoint (default: http://localhost:9898)",
    )
    parser.add_argument(
        "--concurrent-tasks",
        required=False,
        default=5,
        type=int,
        help="Max concurrent download tasks created & polled by this script",
    )
    parser.add_argument(
        "--max-concurrent-layer-downloads",
        dest="max_concurrent_layer_downloads",
        required=False,
        default=8,
        type=int,
        help="Max concurrent download tasks spawned by the pageserver. Each layer is a separate task.",
    )

    parser.add_argument(
        "what",
        nargs="+",
        help="what to download: ALL|tenant_id|tenant_id:timeline_id",
    )
    parser.add_argument(
        "--verbose",
        action="store_true",
        help="enable verbose logging",
    )
    args = parser.parse_args()

    level = logging.INFO
    if args.verbose:
        level = logging.DEBUG
    logging.basicConfig(
        format="%(asctime)s,%(msecs)03d %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s",
        datefmt="%Y-%m-%d:%H:%M:%S",
        level=level,
    )

    # main() registers the SIGINT handler on the loop created by asyncio.run()
    sys.exit(asyncio.run(main(args)))