add script to download all remote layers (#3294)

For use in production in case on-demand download turns out to be
problematic during tenant_attach, or when we eventually introduce layer
eviction.

Co-authored-by: Dmitry Rodionov <dmitry@neon.tech>
Author: Christian Schwarz
Date: 2023-01-25 14:55:25 +01:00 (committed by GitHub)
Parent: 01b4b0c2f3
Commit: 8963d830fb
3 changed files with 567 additions and 5 deletions

@@ -0,0 +1,324 @@
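# Spawn "download all remote layers" tasks on the pageserver for the given
# tenants/timelines, poll them to completion, and write a JSON report.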
import argparse
import asyncio
import json
import logging
import signal
import sys
from collections import defaultdict
from dataclasses import dataclass
from typing import Any, Awaitable, Dict, List, Tuple

import aiohttp


class ClientException(Exception):
    pass


class Client:
    def __init__(self, pageserver_api_endpoint: str, max_concurrent_layer_downloads: int):
        self.endpoint = pageserver_api_endpoint
        self.max_concurrent_layer_downloads = max_concurrent_layer_downloads
        self.sess = aiohttp.ClientSession()

    async def close(self):
        await self.sess.close()

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_t, exc_v, exc_tb):
        await self.close()
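
    # Decode a JSON response body; raise ClientException on HTTP errors or
    # when the body does not have the expected top-level type.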
    async def parse_response(self, resp, expected_type):
        body = await resp.json()
        if not resp.ok:
            raise ClientException(f"Response: {resp} Body: {body}")
        if not isinstance(body, expected_type):
            raise ClientException(f"expecting {expected_type.__name__}")
        return body

    async def get_tenant_ids(self):
        resp = await self.sess.get(f"{self.endpoint}/v1/tenant")
        payload = await self.parse_response(resp=resp, expected_type=list)
        return [t["id"] for t in payload]

    async def get_timeline_ids(self, tenant_id):
        resp = await self.sess.get(f"{self.endpoint}/v1/tenant/{tenant_id}/timeline")
        payload = await self.parse_response(resp=resp, expected_type=list)
        return [t["timeline_id"] for t in payload]
    async def timeline_spawn_download_remote_layers(self, tenant_id, timeline_id, ongoing_ok=False):
        resp = await self.sess.post(
            f"{self.endpoint}/v1/tenant/{tenant_id}/timeline/{timeline_id}/download_remote_layers",
            json={"max_concurrent_downloads": self.max_concurrent_layer_downloads},
        )
        body = await resp.json()
        if resp.status == 409:
            if not ongoing_ok:
                raise ClientException("download already ongoing")
            # response body has same shape for ongoing and newly created
        elif not resp.ok:
            raise ClientException(f"Response: {resp} Body: {body}")
        if not isinstance(body, dict):
            raise ClientException("expecting dict")
        return body
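
    # GET on the same endpoint returns the status of the most recent download
    # task, or None on 404 (no task was ever spawned for this timeline).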
    async def timeline_poll_download_remote_layers_status(
        self,
        tenant_id,
        timeline_id,
    ):
        resp = await self.sess.get(
            f"{self.endpoint}/v1/tenant/{tenant_id}/timeline/{timeline_id}/download_remote_layers",
        )
        body = await resp.json()
        if resp.status == 404:
            return None
        elif not resp.ok:
            raise ClientException(f"Response: {resp} Body: {body}")
        return body


@dataclass
class Completed:
    """The status dict returned by the API"""

    status: Dict[str, Any]


# Set by handle_sigint(); checked by do_timeline() before starting new downloads.
sigint_received = asyncio.Event()


async def do_timeline(client: Client, tenant_id, timeline_id):
    """
    Spawn download_remote_layers task for given timeline,
    then poll until the download has reached a terminal state.

    If the terminal state is not 'Completed', the method raises an exception.
    The caller is responsible for inspecting `failed_download_count`.

    If there is already a task going on when this method is invoked,
    it raises an exception.
    """

    # Don't start new downloads if user pressed SIGINT.
    # This task will show up as "raised_exception" in the report.
    if sigint_received.is_set():
        raise Exception("not starting because SIGINT received")

    # run downloads to completion
    status = await client.timeline_poll_download_remote_layers_status(tenant_id, timeline_id)
    if status is not None and status["state"] == "Running":
        raise Exception("download is already running")

    spawned = await client.timeline_spawn_download_remote_layers(
        tenant_id, timeline_id, ongoing_ok=False
    )

    while True:
        st = await client.timeline_poll_download_remote_layers_status(tenant_id, timeline_id)
        logging.info(f"{tenant_id}:{timeline_id} state is: {st}")
        if st is None:
            raise ClientException("spawned download task not found when polling")
        if spawned["task_id"] != st["task_id"]:
            raise ClientException("download task ids changed while polling")
        if st["state"] == "Running":
            await asyncio.sleep(10)
            continue
        if st["state"] != "Completed":
            raise ClientException(
                f"download task reached terminal state != Completed: {st['state']}"
            )
        return Completed(st)


def handle_sigint():
    logging.info("SIGINT received, asyncio event set. Will not start new downloads.")
    sigint_received.set()


async def main(args):
    # Install the SIGINT handler on the running loop. Registering it on a loop
    # obtained before asyncio.run() would have no effect, because asyncio.run()
    # creates a fresh event loop.
    loop = asyncio.get_running_loop()
    loop.add_signal_handler(signal.SIGINT, handle_sigint)

    async with Client(args.pageserver_http_endpoint, args.max_concurrent_layer_downloads) as client:
        exit_code = await main_impl(args, args.report_output, client)
    return exit_code
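

# Worker coroutine: drains the task queue, awaiting one do_timeline coroutine
# at a time and pushing each result (or the exception it raised) onto result_q.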
async def taskq_handler(task_q, result_q):
    while True:
        try:
            (id, fut) = task_q.get_nowait()
        except asyncio.QueueEmpty:
            logging.debug("taskq_handler observed empty task_q, returning")
            return
        logging.info(f"starting task {id}")
        try:
            res = await fut
        except Exception as e:
            res = e
        result_q.put_nowait((id, res))


async def print_progress(result_q, tasks):
    while True:
        await asyncio.sleep(10)
        logging.info(f"{result_q.qsize()} / {len(tasks)} tasks done")
async def main_impl(args, report_out, client: Client):
    """
    Returns OS exit status.
    """
    tenant_and_timeline_ids: List[Tuple[str, str]] = []

    # fill tenant_and_timeline_ids based on spec
    for spec in args.what:
        comps = spec.split(":")
        if comps == ["ALL"]:
            logging.info("get tenant list")
            tenant_ids = await client.get_tenant_ids()
            get_timeline_id_coros = [client.get_timeline_ids(tenant_id) for tenant_id in tenant_ids]
            gathered = await asyncio.gather(*get_timeline_id_coros, return_exceptions=True)
            assert len(tenant_ids) == len(gathered)
            tenant_and_timeline_ids = []
            for tid, tlids in zip(tenant_ids, gathered):
                # with return_exceptions=True, a failed coroutine yields its
                # exception instead of a list of timeline ids
                if isinstance(tlids, BaseException):
                    raise tlids
                for tlid in tlids:
                    tenant_and_timeline_ids.append((tid, tlid))
        elif len(comps) == 1:
            tid = comps[0]
            tlids = await client.get_timeline_ids(tid)
            for tlid in tlids:
                tenant_and_timeline_ids.append((tid, tlid))
        elif len(comps) == 2:
            tenant_and_timeline_ids.append((comps[0], comps[1]))
        else:
            raise ValueError(f"invalid what-spec: {spec}")

    logging.info("expanded spec:")
    for tid, tlid in tenant_and_timeline_ids:
        logging.info(f"{tid}:{tlid}")

    logging.info("remove duplicates after expanding spec")
    tmp = list(set(tenant_and_timeline_ids))
    assert len(tmp) <= len(tenant_and_timeline_ids)
    if len(tmp) != len(tenant_and_timeline_ids):
        logging.info(f"spec had {len(tenant_and_timeline_ids) - len(tmp)} duplicates")
    tenant_and_timeline_ids = tmp
logging.info("create tasks and process them at specified concurrency")
task_q: asyncio.Queue[Tuple[str, Awaitable[Any]]] = asyncio.Queue()
tasks = {
f"{tid}:{tlid}": do_timeline(client, tid, tlid) for tid, tlid in tenant_and_timline_ids
}
for task in tasks.items():
task_q.put_nowait(task)
result_q: asyncio.Queue[Tuple[str, Any]] = asyncio.Queue()
taskq_handlers = []
for _ in range(0, args.concurrent_tasks):
taskq_handlers.append(taskq_handler(task_q, result_q))
print_progress_task = asyncio.create_task(print_progress(result_q, tasks))
await asyncio.gather(*taskq_handlers)
print_progress_task.cancel()
logging.info("all tasks handled, generating report")
results = []
while True:
try:
results.append(result_q.get_nowait())
except asyncio.QueueEmpty:
break
assert task_q.empty()
report = defaultdict(list)
for id, result in results:
logging.info(f"result for {id}: {result}")
if isinstance(result, Completed):
if result.status["failed_download_count"] == 0:
report["completed_without_errors"].append(id)
else:
report["completed_with_download_errors"].append(id)
elif isinstance(result, Exception):
report["raised_exception"].append(id)
else:
raise ValueError("unexpected result type")
json.dump(report, report_out)
logging.info("--------------------------------------------------------------------------------")
report_success = len(report["completed_without_errors"]) == len(tenant_and_timline_ids)
if not report_success:
logging.error("One or more tasks encountered errors.")
else:
logging.info("All tasks reported success.")
logging.info("Inspect log for details and report file for JSON summary.")
return report_success


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--report-output",
        type=argparse.FileType("w"),
        default="-",
        help="where to write report output (default: stdout)",
    )
    parser.add_argument(
        "--pageserver-http-endpoint",
        default="http://localhost:9898",
        help="pageserver http endpoint (default: http://localhost:9898)",
    )
    parser.add_argument(
        "--concurrent-tasks",
        required=False,
        default=5,
        type=int,
        help="Max concurrent download tasks created & polled by this script",
    )
    parser.add_argument(
        "--max-concurrent-layer-downloads",
        dest="max_concurrent_layer_downloads",
        required=False,
        default=8,
        type=int,
        help="Max concurrent download tasks spawned by pageserver. Each layer is a separate task.",
    )
    parser.add_argument(
        "what",
        nargs="+",
        help="what to download: ALL|tenant_id|tenant_id:timeline_id",
    )
    parser.add_argument(
        "--verbose",
        action="store_true",
        help="enable verbose logging",
    )
    args = parser.parse_args()

    level = logging.INFO
    if args.verbose:
        level = logging.DEBUG
    logging.basicConfig(
        format="%(asctime)s,%(msecs)03d %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s",
        datefmt="%Y-%m-%d:%H:%M:%S",
        level=level,
    )

    # SIGINT handling is installed inside main(), on the loop that asyncio.run() creates.
    sys.exit(asyncio.run(main(args)))
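

Example invocation (the diff does not show the file name, so the script path
below is hypothetical):

    python3 download_remote_layers.py \
        --pageserver-http-endpoint http://localhost:9898 \
        --concurrent-tasks 5 \
        --report-output report.json \
        ALL

Pass one or more specs instead of ALL to restrict the run: a bare tenant_id
covers all of that tenant's timelines, and tenant_id:timeline_id targets a
single timeline.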