Improve some typing in test_runner

Fixes some types, adds some types, and adds some override annotations.

Signed-off-by: Tristan Partin <tristan@neon.tech>
Author:  Tristan Partin
Date:    2024-10-09 15:42:22 -05:00 (committed via GitHub)
Parent:  878135fe9c
Commit:  d3464584a6

22 changed files with 216 additions and 102 deletions

@@ -18,6 +18,7 @@ from urllib.parse import urlencode
 import allure
 import zstandard
 from psycopg2.extensions import cursor
+from typing_extensions import override

 from fixtures.log_helper import log
 from fixtures.pageserver.common_types import (
@@ -26,14 +27,14 @@ from fixtures.pageserver.common_types import (
 )

 if TYPE_CHECKING:
-    from typing import (
-        IO,
-        Optional,
-        Union,
-    )
+    from collections.abc import Iterable
+    from typing import IO, Optional

+    from fixtures.common_types import TimelineId
     from fixtures.neon_fixtures import PgBin
-    from fixtures.common_types import TimelineId

+    WaitUntilRet = TypeVar("WaitUntilRet")

 Fn = TypeVar("Fn", bound=Callable[..., Any])
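
For context: imports guarded by if TYPE_CHECKING: are evaluated only by static type checkers, never at runtime, so they are free for annotation-only names such as Iterable, IO, and Optional. The pattern relies on deferred annotation evaluation; a minimal sketch with hypothetical names:

    from __future__ import annotations  # PEP 563: annotations stay unevaluated

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Seen only by mypy/pyright; the module is never imported at runtime.
        from collections.abc import Iterable

    def count_lines(lines: Iterable[str]) -> int:
        # Safe: the Iterable annotation is never evaluated at runtime.
        return sum(1 for _ in lines)
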
@@ -42,12 +43,12 @@ def subprocess_capture(
     capture_dir: Path,
     cmd: list[str],
     *,
-    check=False,
-    echo_stderr=False,
-    echo_stdout=False,
-    capture_stdout=False,
-    timeout=None,
-    with_command_header=True,
+    check: bool = False,
+    echo_stderr: bool = False,
+    echo_stdout: bool = False,
+    capture_stdout: bool = False,
+    timeout: Optional[float] = None,
+    with_command_header: bool = True,
     **popen_kwargs: Any,
 ) -> tuple[str, Optional[str], int]:
     """Run a process and bifurcate its output to files and the `log` logger
@@ -84,6 +85,7 @@ def subprocess_capture(
             self.capture = capture
             self.captured = ""

+        @override
         def run(self):
             first = with_command_header
             for line in self.in_file:
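
The @override decorator (PEP 698, backported in typing_extensions for Python versions before 3.12) asks the type checker to verify that the method really overrides one in a base class, catching renames and typos. A small self-contained illustration, assuming typing_extensions is installed:

    import threading

    from typing_extensions import override

    class Worker(threading.Thread):
        @override
        def run(self) -> None:  # OK: threading.Thread defines run()
            print("working")

        # A misspelled name such as "runn" decorated with @override would be
        # rejected by the checker, since Thread has no such method to override.
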
@@ -165,10 +167,10 @@ def global_counter() -> int:
 def print_gc_result(row: dict[str, Any]):
     log.info("GC duration {elapsed} ms".format_map(row))
     log.info(
-        " total: {layers_total}, needed_by_cutoff {layers_needed_by_cutoff}, needed_by_pitr {layers_needed_by_pitr}"
-        " needed_by_branches: {layers_needed_by_branches}, not_updated: {layers_not_updated}, removed: {layers_removed}".format_map(
-            row
-        )
+        (
+            " total: {layers_total}, needed_by_cutoff {layers_needed_by_cutoff}, needed_by_pitr {layers_needed_by_pitr}"
+            " needed_by_branches: {layers_needed_by_branches}, not_updated: {layers_not_updated}, removed: {layers_removed}"
+        ).format_map(row)
     )
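
Wrapping the implicitly concatenated string literals in parentheses makes it clear that .format_map() applies to the whole two-line message rather than visually dangling off the last fragment; behavior is unchanged. str.format_map fills placeholders directly from a mapping, as in this standalone snippet:

    row = {"layers_total": 10, "layers_removed": 3}
    msg = (
        "total: {layers_total}, "
        "removed: {layers_removed}"
    ).format_map(row)
    assert msg == "total: 10, removed: 3"
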
@@ -226,7 +228,7 @@ def get_scale_for_db(size_mb: int) -> int:
     return round(0.06689 * size_mb - 0.5)


-ATTACHMENT_NAME_REGEX: re.Pattern = re.compile(  # type: ignore[type-arg]
+ATTACHMENT_NAME_REGEX: re.Pattern[str] = re.compile(
     r"regression\.(diffs|out)|.+\.(?:log|stderr|stdout|filediff|metrics|html|walredo)"
 )
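
re.Pattern is generic over str and bytes, so spelling the annotation as re.Pattern[str] expresses what the pattern matches and removes the need for the type-arg ignore. For example:

    import re

    WORD: re.Pattern[str] = re.compile(r"\w+")
    match = WORD.search("hello world")
    assert match is not None and match.group() == "hello"  # group() typed as str
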
@@ -289,7 +291,7 @@ LOGS_STAGING_DATASOURCE_ID = "xHHYY0dVz"
 def allure_add_grafana_links(host: str, timeline_id: TimelineId, start_ms: int, end_ms: int):
     """Add links to server logs in Grafana to Allure report"""
-    links = {}
+    links: dict[str, str] = {}

     # We expect host to be in format like ep-divine-night-159320.us-east-2.aws.neon.build
     endpoint_id, region_id, _ = host.split(".", 2)
@@ -341,7 +343,7 @@ def allure_add_grafana_links(host: str, timeline_id: TimelineId, start_ms: int,
 def start_in_background(
-    command: list[str], cwd: Path, log_file_name: str, is_started: Fn
+    command: list[str], cwd: Path, log_file_name: str, is_started: Callable[[], WaitUntilRet]
 ) -> subprocess.Popen[bytes]:
     """Starts a process, creates the logfile and redirects stderr and stdout there. Runs the start checks before the process is started, or errors."""
@@ -376,14 +378,11 @@ def start_in_background(
     return spawned_process


-WaitUntilRet = TypeVar("WaitUntilRet")
-
-
 def wait_until(
     number_of_iterations: int,
     interval: float,
     func: Callable[[], WaitUntilRet],
-    show_intermediate_error=False,
+    show_intermediate_error: bool = False,
 ) -> WaitUntilRet:
     """
     Wait until 'func' returns successfully, without exception. Returns the
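
Because wait_until is generic in WaitUntilRet, the checker infers its result type from the callback that is passed in; a hedged usage sketch (read_flag is hypothetical):

    def read_flag() -> bool:
        # Some check that raises until the system reaches the desired state.
        return True

    # Inferred as bool: the TypeVar ties func's return type to wait_until's.
    flag = wait_until(10, 0.5, read_flag)
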
@@ -464,7 +463,7 @@ def humantime_to_ms(humantime: str) -> float:
 def scan_log_for_errors(input: Iterable[str], allowed_errors: list[str]) -> list[tuple[int, str]]:
     # FIXME: this duplicates test_runner/fixtures/pageserver/allowed_errors.py
     error_or_warn = re.compile(r"\s(ERROR|WARN)")
-    errors = []
+    errors: list[tuple[int, str]] = []
     for lineno, line in enumerate(input, start=1):
         if len(line) == 0:
             continue
@@ -484,7 +483,7 @@ def scan_log_for_errors(input: Iterable[str], allowed_errors: list[str]) -> list
     return errors


-def assert_no_errors(log_file, service, allowed_errors):
+def assert_no_errors(log_file: Path, service: str, allowed_errors: list[str]):
     if not log_file.exists():
         log.warning(f"Skipping {service} log check: {log_file} does not exist")
         return
@@ -504,9 +503,11 @@ class AuxFileStore(str, enum.Enum):
     V2 = "v2"
     CrossValidation = "cross-validation"

+    @override
     def __repr__(self) -> str:
         return f"'aux-{self.value}'"

+    @override
     def __str__(self) -> str:
         return f"'aux-{self.value}'"
@@ -525,7 +526,7 @@ def assert_pageserver_backups_equal(left: Path, right: Path, skip_files: set[str
     """
     started_at = time.time()

-    def hash_extracted(reader: Union[IO[bytes], None]) -> bytes:
+    def hash_extracted(reader: Optional[IO[bytes]]) -> bytes:
         assert reader is not None
         digest = sha256(usedforsecurity=False)
         while True:
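
Optional[IO[bytes]] and Union[IO[bytes], None] are the same type; the change standardizes on the shorter spelling, which is also what let Union drop out of the imports above. A quick demonstration:

    from typing import IO, Optional, Union

    # Interchangeable to a type checker, and equal at runtime:
    assert Optional[IO[bytes]] == Union[IO[bytes], None]
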
@@ -550,7 +551,7 @@ def assert_pageserver_backups_equal(left: Path, right: Path, skip_files: set[str
         right_list
     ), f"unexpected number of files on tar files, {len(left_list)} != {len(right_list)}"

-    mismatching = set()
+    mismatching: set[str] = set()

     for left_tuple, right_tuple in zip(left_list, right_list):
         left_path, left_hash = left_tuple
@@ -575,6 +576,7 @@ class PropagatingThread(threading.Thread):
     Simple Thread wrapper with join() propagating the possible exception in the thread.
     """

+    @override
     def run(self):
         self.exc = None
         try:
@@ -582,7 +584,8 @@
         except BaseException as e:
             self.exc = e

-    def join(self, timeout=None):
+    @override
+    def join(self, timeout: Optional[float] = None) -> Any:
         super().join(timeout)
         if self.exc:
             raise self.exc
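
With join() typed and marked as an override, PropagatingThread still behaves like a plain Thread except that an exception raised inside the thread resurfaces in the caller of join(). A hedged usage sketch, assuming run() delegates to the usual Thread target as its try/except wrapper suggests:

    def fail() -> None:
        raise RuntimeError("boom")

    t = PropagatingThread(target=fail)
    t.start()
    try:
        t.join()
    except RuntimeError as e:
        print(f"propagated: {e}")  # re-raised from the worker thread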