mirror of
https://github.com/neondatabase/neon.git
synced 2025-12-23 06:09:59 +00:00
Improve some typing in test_runner
Fixes some types, adds some types, and adds some override annotations. Signed-off-by: Tristan Partin <tristan@neon.tech>
This commit is contained in:
@@ -18,6 +18,7 @@ from urllib.parse import urlencode
|
||||
import allure
|
||||
import zstandard
|
||||
from psycopg2.extensions import cursor
|
||||
from typing_extensions import override
|
||||
|
||||
from fixtures.log_helper import log
|
||||
from fixtures.pageserver.common_types import (
|
||||
@@ -26,14 +27,14 @@ from fixtures.pageserver.common_types import (
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import (
|
||||
IO,
|
||||
Optional,
|
||||
Union,
|
||||
)
|
||||
from collections.abc import Iterable
|
||||
from typing import IO, Optional
|
||||
|
||||
from fixtures.common_types import TimelineId
|
||||
from fixtures.neon_fixtures import PgBin
|
||||
from fixtures.common_types import TimelineId
|
||||
|
||||
WaitUntilRet = TypeVar("WaitUntilRet")
|
||||
|
||||
|
||||
Fn = TypeVar("Fn", bound=Callable[..., Any])
|
||||
|
||||
@@ -42,12 +43,12 @@ def subprocess_capture(
|
||||
capture_dir: Path,
|
||||
cmd: list[str],
|
||||
*,
|
||||
check=False,
|
||||
echo_stderr=False,
|
||||
echo_stdout=False,
|
||||
capture_stdout=False,
|
||||
timeout=None,
|
||||
with_command_header=True,
|
||||
check: bool = False,
|
||||
echo_stderr: bool = False,
|
||||
echo_stdout: bool = False,
|
||||
capture_stdout: bool = False,
|
||||
timeout: Optional[float] = None,
|
||||
with_command_header: bool = True,
|
||||
**popen_kwargs: Any,
|
||||
) -> tuple[str, Optional[str], int]:
|
||||
"""Run a process and bifurcate its output to files and the `log` logger
|
||||
@@ -84,6 +85,7 @@ def subprocess_capture(
|
||||
self.capture = capture
|
||||
self.captured = ""
|
||||
|
||||
@override
|
||||
def run(self):
|
||||
first = with_command_header
|
||||
for line in self.in_file:
|
||||
@@ -165,10 +167,10 @@ def global_counter() -> int:
|
||||
def print_gc_result(row: dict[str, Any]):
|
||||
log.info("GC duration {elapsed} ms".format_map(row))
|
||||
log.info(
|
||||
" total: {layers_total}, needed_by_cutoff {layers_needed_by_cutoff}, needed_by_pitr {layers_needed_by_pitr}"
|
||||
" needed_by_branches: {layers_needed_by_branches}, not_updated: {layers_not_updated}, removed: {layers_removed}".format_map(
|
||||
row
|
||||
)
|
||||
(
|
||||
" total: {layers_total}, needed_by_cutoff {layers_needed_by_cutoff}, needed_by_pitr {layers_needed_by_pitr}"
|
||||
" needed_by_branches: {layers_needed_by_branches}, not_updated: {layers_not_updated}, removed: {layers_removed}"
|
||||
).format_map(row)
|
||||
)
|
||||
|
||||
|
||||
@@ -226,7 +228,7 @@ def get_scale_for_db(size_mb: int) -> int:
|
||||
return round(0.06689 * size_mb - 0.5)
|
||||
|
||||
|
||||
ATTACHMENT_NAME_REGEX: re.Pattern = re.compile( # type: ignore[type-arg]
|
||||
ATTACHMENT_NAME_REGEX: re.Pattern[str] = re.compile(
|
||||
r"regression\.(diffs|out)|.+\.(?:log|stderr|stdout|filediff|metrics|html|walredo)"
|
||||
)
|
||||
|
||||
@@ -289,7 +291,7 @@ LOGS_STAGING_DATASOURCE_ID = "xHHYY0dVz"
|
||||
|
||||
def allure_add_grafana_links(host: str, timeline_id: TimelineId, start_ms: int, end_ms: int):
|
||||
"""Add links to server logs in Grafana to Allure report"""
|
||||
links = {}
|
||||
links: dict[str, str] = {}
|
||||
# We expect host to be in format like ep-divine-night-159320.us-east-2.aws.neon.build
|
||||
endpoint_id, region_id, _ = host.split(".", 2)
|
||||
|
||||
@@ -341,7 +343,7 @@ def allure_add_grafana_links(host: str, timeline_id: TimelineId, start_ms: int,
|
||||
|
||||
|
||||
def start_in_background(
|
||||
command: list[str], cwd: Path, log_file_name: str, is_started: Fn
|
||||
command: list[str], cwd: Path, log_file_name: str, is_started: Callable[[], WaitUntilRet]
|
||||
) -> subprocess.Popen[bytes]:
|
||||
"""Starts a process, creates the logfile and redirects stderr and stdout there. Runs the start checks before the process is started, or errors."""
|
||||
|
||||
@@ -376,14 +378,11 @@ def start_in_background(
|
||||
return spawned_process
|
||||
|
||||
|
||||
WaitUntilRet = TypeVar("WaitUntilRet")
|
||||
|
||||
|
||||
def wait_until(
|
||||
number_of_iterations: int,
|
||||
interval: float,
|
||||
func: Callable[[], WaitUntilRet],
|
||||
show_intermediate_error=False,
|
||||
show_intermediate_error: bool = False,
|
||||
) -> WaitUntilRet:
|
||||
"""
|
||||
Wait until 'func' returns successfully, without exception. Returns the
|
||||
@@ -464,7 +463,7 @@ def humantime_to_ms(humantime: str) -> float:
|
||||
def scan_log_for_errors(input: Iterable[str], allowed_errors: list[str]) -> list[tuple[int, str]]:
|
||||
# FIXME: this duplicates test_runner/fixtures/pageserver/allowed_errors.py
|
||||
error_or_warn = re.compile(r"\s(ERROR|WARN)")
|
||||
errors = []
|
||||
errors: list[tuple[int, str]] = []
|
||||
for lineno, line in enumerate(input, start=1):
|
||||
if len(line) == 0:
|
||||
continue
|
||||
@@ -484,7 +483,7 @@ def scan_log_for_errors(input: Iterable[str], allowed_errors: list[str]) -> list
|
||||
return errors
|
||||
|
||||
|
||||
def assert_no_errors(log_file, service, allowed_errors):
|
||||
def assert_no_errors(log_file: Path, service: str, allowed_errors: list[str]):
|
||||
if not log_file.exists():
|
||||
log.warning(f"Skipping {service} log check: {log_file} does not exist")
|
||||
return
|
||||
@@ -504,9 +503,11 @@ class AuxFileStore(str, enum.Enum):
|
||||
V2 = "v2"
|
||||
CrossValidation = "cross-validation"
|
||||
|
||||
@override
|
||||
def __repr__(self) -> str:
|
||||
return f"'aux-{self.value}'"
|
||||
|
||||
@override
|
||||
def __str__(self) -> str:
|
||||
return f"'aux-{self.value}'"
|
||||
|
||||
@@ -525,7 +526,7 @@ def assert_pageserver_backups_equal(left: Path, right: Path, skip_files: set[str
|
||||
"""
|
||||
started_at = time.time()
|
||||
|
||||
def hash_extracted(reader: Union[IO[bytes], None]) -> bytes:
|
||||
def hash_extracted(reader: Optional[IO[bytes]]) -> bytes:
|
||||
assert reader is not None
|
||||
digest = sha256(usedforsecurity=False)
|
||||
while True:
|
||||
@@ -550,7 +551,7 @@ def assert_pageserver_backups_equal(left: Path, right: Path, skip_files: set[str
|
||||
right_list
|
||||
), f"unexpected number of files on tar files, {len(left_list)} != {len(right_list)}"
|
||||
|
||||
mismatching = set()
|
||||
mismatching: set[str] = set()
|
||||
|
||||
for left_tuple, right_tuple in zip(left_list, right_list):
|
||||
left_path, left_hash = left_tuple
|
||||
@@ -575,6 +576,7 @@ class PropagatingThread(threading.Thread):
|
||||
Simple Thread wrapper with join() propagating the possible exception in the thread.
|
||||
"""
|
||||
|
||||
@override
|
||||
def run(self):
|
||||
self.exc = None
|
||||
try:
|
||||
@@ -582,7 +584,8 @@ class PropagatingThread(threading.Thread):
|
||||
except BaseException as e:
|
||||
self.exc = e
|
||||
|
||||
def join(self, timeout=None):
|
||||
@override
|
||||
def join(self, timeout: Optional[float] = None) -> Any:
|
||||
super().join(timeout)
|
||||
if self.exc:
|
||||
raise self.exc
|
||||
|
||||
Reference in New Issue
Block a user