From d3464584a6fb03b9df264f32c18963388808ba2e Mon Sep 17 00:00:00 2001 From: Tristan Partin Date: Wed, 9 Oct 2024 15:42:22 -0500 Subject: [PATCH] Improve some typing in test_runner Fixes some types, adds some types, and adds some override annotations. Signed-off-by: Tristan Partin --- test_runner/fixtures/common_types.py | 37 +++++++++--- test_runner/fixtures/compare_fixtures.py | 36 ++++++++++- test_runner/fixtures/compute_reconfigure.py | 20 ++++--- test_runner/fixtures/flaky.py | 2 + test_runner/fixtures/httpserver.py | 13 +++- test_runner/fixtures/log_helper.py | 2 +- test_runner/fixtures/metrics.py | 4 +- test_runner/fixtures/neon_api.py | 12 ++-- test_runner/fixtures/neon_fixtures.py | 6 +- test_runner/fixtures/overlayfs.py | 5 +- test_runner/fixtures/parametrize.py | 7 ++- test_runner/fixtures/pg_version.py | 12 +++- test_runner/fixtures/port_distributor.py | 7 +-- test_runner/fixtures/remote_storage.py | 24 ++++---- .../fixtures/storage_controller_proxy.py | 9 ++- test_runner/fixtures/utils.py | 59 ++++++++++--------- test_runner/fixtures/workload.py | 11 +++- test_runner/regress/test_compaction.py | 6 +- .../regress/test_pageserver_generations.py | 6 +- test_runner/regress/test_sharding.py | 33 +++++++++-- .../regress/test_storage_controller.py | 2 +- test_runner/regress/test_storage_scrubber.py | 5 +- 22 files changed, 216 insertions(+), 102 deletions(-) diff --git a/test_runner/fixtures/common_types.py b/test_runner/fixtures/common_types.py index 3022c0279f..0ea7148f50 100644 --- a/test_runner/fixtures/common_types.py +++ b/test_runner/fixtures/common_types.py @@ -6,6 +6,8 @@ from enum import Enum from functools import total_ordering from typing import TYPE_CHECKING, TypeVar +from typing_extensions import override + if TYPE_CHECKING: from typing import Any, Union @@ -31,33 +33,36 @@ class Lsn: self.lsn_int = (int(left, 16) << 32) + int(right, 16) assert 0 <= self.lsn_int <= 0xFFFFFFFF_FFFFFFFF + @override def __str__(self) -> str: """Convert lsn from int to standard hex notation.""" return f"{(self.lsn_int >> 32):X}/{(self.lsn_int & 0xFFFFFFFF):X}" + @override def __repr__(self) -> str: return f'Lsn("{str(self)}")' def __int__(self) -> int: return self.lsn_int - def __lt__(self, other: Any) -> bool: + def __lt__(self, other: object) -> bool: if not isinstance(other, Lsn): return NotImplemented return self.lsn_int < other.lsn_int - def __gt__(self, other: Any) -> bool: + def __gt__(self, other: object) -> bool: if not isinstance(other, Lsn): raise NotImplementedError return self.lsn_int > other.lsn_int - def __eq__(self, other: Any) -> bool: + @override + def __eq__(self, other: object) -> bool: if not isinstance(other, Lsn): return NotImplemented return self.lsn_int == other.lsn_int # Returns the difference between two Lsns, in bytes - def __sub__(self, other: Any) -> int: + def __sub__(self, other: object) -> int: if not isinstance(other, Lsn): return NotImplemented return self.lsn_int - other.lsn_int @@ -70,6 +75,7 @@ class Lsn: else: raise NotImplementedError + @override def __hash__(self) -> int: return hash(self.lsn_int) @@ -116,19 +122,22 @@ class Id: self.id = bytearray.fromhex(x) assert len(self.id) == 16 + @override def __str__(self) -> str: return self.id.hex() - def __lt__(self, other) -> bool: + def __lt__(self, other: object) -> bool: if not isinstance(other, type(self)): return NotImplemented return self.id < other.id - def __eq__(self, other) -> bool: + @override + def __eq__(self, other: object) -> bool: if not isinstance(other, type(self)): return NotImplemented 
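         # Illustrative aside (not from the patch): returning the NotImplemented
         # singleton, rather than raising NotImplementedError as Lsn.__gt__ above
         # still does, lets Python try the reflected operation on the other
         # operand before giving up:
         #
         #     Lsn("0/16B5A50") == "0/16B5A50"  # False: both operands return
         #                                      # NotImplemented, so == falls back
         #                                      # to an identity check
         #     Lsn("0/16B5A50") < "0/16B5A50"   # TypeError, raised by the
         #                                      # interpreter itself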
return self.id == other.id + @override def __hash__(self) -> int: return hash(str(self.id)) @@ -139,25 +148,31 @@ class Id: class TenantId(Id): + @override def __repr__(self) -> str: return f'`TenantId("{self.id.hex()}")' + @override def __str__(self) -> str: return self.id.hex() class NodeId(Id): + @override def __repr__(self) -> str: return f'`NodeId("{self.id.hex()}")' + @override def __str__(self) -> str: return self.id.hex() class TimelineId(Id): + @override def __repr__(self) -> str: return f'TimelineId("{self.id.hex()}")' + @override def __str__(self) -> str: return self.id.hex() @@ -187,7 +202,7 @@ class TenantShardId: assert self.shard_number < self.shard_count or self.shard_count == 0 @classmethod - def parse(cls: type[TTenantShardId], input) -> TTenantShardId: + def parse(cls: type[TTenantShardId], input: str) -> TTenantShardId: if len(input) == 32: return cls( tenant_id=TenantId(input), @@ -203,6 +218,7 @@ class TenantShardId: else: raise ValueError(f"Invalid TenantShardId '{input}'") + @override def __str__(self): if self.shard_count > 0: return f"{self.tenant_id}-{self.shard_number:02x}{self.shard_count:02x}" @@ -210,22 +226,25 @@ class TenantShardId: # Unsharded case: equivalent of Rust TenantShardId::unsharded(tenant_id) return str(self.tenant_id) + @override def __repr__(self): return self.__str__() def _tuple(self) -> tuple[TenantId, int, int]: return (self.tenant_id, self.shard_number, self.shard_count) - def __lt__(self, other) -> bool: + def __lt__(self, other: object) -> bool: if not isinstance(other, type(self)): return NotImplemented return self._tuple() < other._tuple() - def __eq__(self, other) -> bool: + @override + def __eq__(self, other: object) -> bool: if not isinstance(other, type(self)): return NotImplemented return self._tuple() == other._tuple() + @override def __hash__(self) -> int: return hash(self._tuple()) diff --git a/test_runner/fixtures/compare_fixtures.py b/test_runner/fixtures/compare_fixtures.py index ce191ac91c..2195ae8225 100644 --- a/test_runner/fixtures/compare_fixtures.py +++ b/test_runner/fixtures/compare_fixtures.py @@ -8,9 +8,11 @@ from contextlib import _GeneratorContextManager, contextmanager # Type-related stuff from pathlib import Path +from typing import TYPE_CHECKING import pytest from _pytest.fixtures import FixtureRequest +from typing_extensions import override from fixtures.benchmark_fixture import MetricReport, NeonBenchmarker from fixtures.log_helper import log @@ -24,6 +26,9 @@ from fixtures.neon_fixtures import ( ) from fixtures.pg_stats import PgStatTable +if TYPE_CHECKING: + from collections.abc import Iterator + class PgCompare(ABC): """Common interface of all postgres implementations, useful for benchmarks. 
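The `if TYPE_CHECKING:` import added above is the pattern this patch applies across
most files: with `from __future__ import annotations` in effect, annotations are
never evaluated at runtime, so names needed only by the type checker can live in a
block the interpreter skips. A minimal, self-contained sketch (not code from this
repository):

    from __future__ import annotations

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Visible to mypy/pyright only; never imported at runtime.
        from collections.abc import Iterator

    def countdown(n: int) -> Iterator[int]:  # annotation stays an unevaluated string
        while n:
            yield n
            n -= 1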
@@ -65,12 +70,12 @@ class PgCompare(ABC): @contextmanager @abstractmethod - def record_pageserver_writes(self, out_name): + def record_pageserver_writes(self, out_name: str): pass @contextmanager @abstractmethod - def record_duration(self, out_name): + def record_duration(self, out_name: str): pass @contextmanager @@ -122,28 +127,34 @@ class NeonCompare(PgCompare): self._pg = self.env.endpoints.create_start("main", "main", self.tenant) @property + @override def pg(self) -> PgProtocol: return self._pg @property + @override def zenbenchmark(self) -> NeonBenchmarker: return self._zenbenchmark @property + @override def pg_bin(self) -> PgBin: return self._pg_bin + @override def flush(self, compact: bool = True, gc: bool = True): wait_for_last_flush_lsn(self.env, self._pg, self.tenant, self.timeline) self.pageserver_http_client.timeline_checkpoint(self.tenant, self.timeline, compact=compact) if gc: self.pageserver_http_client.timeline_gc(self.tenant, self.timeline, 0) + @override def compact(self): self.pageserver_http_client.timeline_compact( self.tenant, self.timeline, wait_until_uploaded=True ) + @override def report_peak_memory_use(self): self.zenbenchmark.record( "peak_mem", @@ -152,6 +163,7 @@ class NeonCompare(PgCompare): report=MetricReport.LOWER_IS_BETTER, ) + @override def report_size(self): timeline_size = self.zenbenchmark.get_timeline_size( self.env.repo_dir, self.tenant, self.timeline @@ -185,9 +197,11 @@ class NeonCompare(PgCompare): "num_files_uploaded", total_files, "", report=MetricReport.LOWER_IS_BETTER ) + @override def record_pageserver_writes(self, out_name: str) -> _GeneratorContextManager[None]: return self.zenbenchmark.record_pageserver_writes(self.env.pageserver, out_name) + @override def record_duration(self, out_name: str) -> _GeneratorContextManager[None]: return self.zenbenchmark.record_duration(out_name) @@ -211,26 +225,33 @@ class VanillaCompare(PgCompare): self.cur = self.conn.cursor() @property + @override def pg(self) -> VanillaPostgres: return self._pg @property + @override def zenbenchmark(self) -> NeonBenchmarker: return self._zenbenchmark @property + @override def pg_bin(self) -> PgBin: return self._pg.pg_bin + @override def flush(self, compact: bool = False, gc: bool = False): self.cur.execute("checkpoint") + @override def compact(self): pass + @override def report_peak_memory_use(self): pass # TODO find something + @override def report_size(self): data_size = self.pg.get_subdir_size(Path("base")) self.zenbenchmark.record( @@ -245,6 +266,7 @@ class VanillaCompare(PgCompare): def record_pageserver_writes(self, out_name: str) -> Iterator[None]: yield # Do nothing + @override def record_duration(self, out_name: str) -> _GeneratorContextManager[None]: return self.zenbenchmark.record_duration(out_name) @@ -261,28 +283,35 @@ class RemoteCompare(PgCompare): self.cur = self.conn.cursor() @property + @override def pg(self) -> PgProtocol: return self._pg @property + @override def zenbenchmark(self) -> NeonBenchmarker: return self._zenbenchmark @property + @override def pg_bin(self) -> PgBin: return self._pg.pg_bin - def flush(self): + @override + def flush(self, compact: bool = False, gc: bool = False): # TODO: flush the remote pageserver pass + @override def compact(self): pass + @override def report_peak_memory_use(self): # TODO: get memory usage from remote pageserver pass + @override def report_size(self): # TODO: get storage size from remote pageserver pass @@ -291,6 +320,7 @@ class RemoteCompare(PgCompare): def record_pageserver_writes(self, out_name: str) -> 
Iterator[None]: yield # Do nothing + @override def record_duration(self, out_name: str) -> _GeneratorContextManager[None]: return self.zenbenchmark.record_duration(out_name) diff --git a/test_runner/fixtures/compute_reconfigure.py b/test_runner/fixtures/compute_reconfigure.py index d2305ea431..6354b7f833 100644 --- a/test_runner/fixtures/compute_reconfigure.py +++ b/test_runner/fixtures/compute_reconfigure.py @@ -1,27 +1,31 @@ from __future__ import annotations import concurrent.futures -from typing import Any +from typing import TYPE_CHECKING import pytest +from pytest_httpserver import HTTPServer from werkzeug.wrappers.request import Request from werkzeug.wrappers.response import Response from fixtures.common_types import TenantId from fixtures.log_helper import log +if TYPE_CHECKING: + from typing import Any, Callable, Optional + class ComputeReconfigure: - def __init__(self, server): + def __init__(self, server: HTTPServer): self.server = server self.control_plane_compute_hook_api = f"http://{server.host}:{server.port}/notify-attach" - self.workloads = {} - self.on_notify = None + self.workloads: dict[TenantId, Any] = {} + self.on_notify: Optional[Callable[[Any], None]] = None - def register_workload(self, workload): + def register_workload(self, workload: Any): self.workloads[workload.tenant_id] = workload - def register_on_notify(self, fn): + def register_on_notify(self, fn: Optional[Callable[[Any], None]]): """ Add some extra work during a notification, like sleeping to slow things down, or logging what was notified. @@ -30,7 +34,7 @@ class ComputeReconfigure: @pytest.fixture(scope="function") -def compute_reconfigure_listener(make_httpserver): +def compute_reconfigure_listener(make_httpserver: HTTPServer): """ This fixture exposes an HTTP listener for the storage controller to submit compute notifications to us, instead of updating neon_local endpoints itself. @@ -48,7 +52,7 @@ def compute_reconfigure_listener(make_httpserver): # accept a healthy rate of calls into notify-attach. reconfigure_threads = concurrent.futures.ThreadPoolExecutor(max_workers=1) - def handler(request: Request): + def handler(request: Request) -> Response: assert request.json is not None body: dict[str, Any] = request.json log.info(f"notify-attach request: {body}") diff --git a/test_runner/fixtures/flaky.py b/test_runner/fixtures/flaky.py index 4ca87520a0..01634a29c5 100644 --- a/test_runner/fixtures/flaky.py +++ b/test_runner/fixtures/flaky.py @@ -14,8 +14,10 @@ from allure_pytest.utils import allure_name, allure_suite_labels from fixtures.log_helper import log if TYPE_CHECKING: + from collections.abc import MutableMapping from typing import Any + """ The plugin reruns flaky tests. 
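 (A marker of the kind applied here looks like `@pytest.mark.flaky(reruns=2,
 reruns_delay=5)`; the kwarg names come from pytest-rerunfailures, and both
 values are illustrative.)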
It uses `pytest.mark.flaky` provided by `pytest-rerunfailures` plugin and flaky tests detected by `scripts/flaky_tests.py` diff --git a/test_runner/fixtures/httpserver.py b/test_runner/fixtures/httpserver.py index 9d5b5d6422..f653fd804c 100644 --- a/test_runner/fixtures/httpserver.py +++ b/test_runner/fixtures/httpserver.py @@ -1,8 +1,15 @@ from __future__ import annotations +from typing import TYPE_CHECKING + import pytest from pytest_httpserver import HTTPServer +if TYPE_CHECKING: + from collections.abc import Iterator + + from fixtures.port_distributor import PortDistributor + # TODO: mypy fails with: # Module "fixtures.neon_fixtures" does not explicitly export attribute "PortDistributor" [attr-defined] # from fixtures.neon_fixtures import PortDistributor @@ -17,7 +24,7 @@ def httpserver_ssl_context(): @pytest.fixture(scope="function") -def make_httpserver(httpserver_listen_address, httpserver_ssl_context): +def make_httpserver(httpserver_listen_address, httpserver_ssl_context) -> Iterator[HTTPServer]: host, port = httpserver_listen_address if not host: host = HTTPServer.DEFAULT_LISTEN_HOST @@ -33,13 +40,13 @@ def make_httpserver(httpserver_listen_address, httpserver_ssl_context): @pytest.fixture(scope="function") -def httpserver(make_httpserver): +def httpserver(make_httpserver: HTTPServer) -> Iterator[HTTPServer]: server = make_httpserver yield server server.clear() @pytest.fixture(scope="function") -def httpserver_listen_address(port_distributor) -> tuple[str, int]: +def httpserver_listen_address(port_distributor: PortDistributor) -> tuple[str, int]: port = port_distributor.get_port() return ("localhost", port) diff --git a/test_runner/fixtures/log_helper.py b/test_runner/fixtures/log_helper.py index 70d76a39c4..ebf5c8d803 100644 --- a/test_runner/fixtures/log_helper.py +++ b/test_runner/fixtures/log_helper.py @@ -31,7 +31,7 @@ LOGGING = { } -def getLogger(name="root") -> logging.Logger: +def getLogger(name: str = "root") -> logging.Logger: """Method to get logger for tests. 
Should be used to get correctly initialized logger.""" diff --git a/test_runner/fixtures/metrics.py b/test_runner/fixtures/metrics.py index adc90a41d0..e056ea77d4 100644 --- a/test_runner/fixtures/metrics.py +++ b/test_runner/fixtures/metrics.py @@ -22,7 +22,7 @@ class Metrics: def query_all(self, name: str, filter: Optional[dict[str, str]] = None) -> list[Sample]: filter = filter or {} - res = [] + res: list[Sample] = [] for sample in self.metrics[name]: try: @@ -59,7 +59,7 @@ class MetricsGetter: return results[0].value def get_metrics_values( - self, names: list[str], filter: Optional[dict[str, str]] = None, absence_ok=False + self, names: list[str], filter: Optional[dict[str, str]] = None, absence_ok: bool = False ) -> dict[str, float]: """ When fetching multiple named metrics, it is more efficient to use this diff --git a/test_runner/fixtures/neon_api.py b/test_runner/fixtures/neon_api.py index 846a790f1f..683ea3af44 100644 --- a/test_runner/fixtures/neon_api.py +++ b/test_runner/fixtures/neon_api.py @@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, cast import requests if TYPE_CHECKING: - from typing import Any, Literal, Optional, Union + from typing import Any, Literal, Optional from fixtures.pg_version import PgVersion @@ -25,9 +25,7 @@ class NeonAPI: self.__neon_api_key = neon_api_key self.__neon_api_base_url = neon_api_base_url.strip("/") - def __request( - self, method: Union[str, bytes], endpoint: str, **kwargs: Any - ) -> requests.Response: + def __request(self, method: str | bytes, endpoint: str, **kwargs: Any) -> requests.Response: if "headers" not in kwargs: kwargs["headers"] = {} kwargs["headers"]["Authorization"] = f"Bearer {self.__neon_api_key}" @@ -187,8 +185,8 @@ class NeonAPI: def get_connection_uri( self, project_id: str, - branch_id: Optional[str] = None, - endpoint_id: Optional[str] = None, + branch_id: str | None = None, + endpoint_id: str | None = None, database_name: str = "neondb", role_name: str = "neondb_owner", pooled: bool = True, @@ -264,7 +262,7 @@ class NeonAPI: class NeonApiEndpoint: - def __init__(self, neon_api: NeonAPI, pg_version: PgVersion, project_id: Optional[str]): + def __init__(self, neon_api: NeonAPI, pg_version: PgVersion, project_id: str | None): self.neon_api = neon_api if project_id is None: project = neon_api.create_project(pg_version) diff --git a/test_runner/fixtures/neon_fixtures.py b/test_runner/fixtures/neon_fixtures.py index 5cb9821476..f81bc3f5a6 100644 --- a/test_runner/fixtures/neon_fixtures.py +++ b/test_runner/fixtures/neon_fixtures.py @@ -3657,7 +3657,7 @@ class Endpoint(PgProtocol, LogUtils): config_lines: Optional[list[str]] = None, remote_ext_config: Optional[str] = None, pageserver_id: Optional[int] = None, - allow_multiple=False, + allow_multiple: bool = False, basebackup_request_tries: Optional[int] = None, ) -> Endpoint: """ @@ -3998,7 +3998,7 @@ class Safekeeper(LogUtils): def timeline_dir(self, tenant_id, timeline_id) -> Path: return self.data_dir / str(tenant_id) / str(timeline_id) - # List partial uploaded segments of this safekeeper. Works only for + # list partial uploaded segments of this safekeeper. Works only for # RemoteStorageKind.LOCAL_FS. 
def list_uploaded_segments(self, tenant_id: TenantId, timeline_id: TimelineId): tline_path = ( @@ -4293,7 +4293,7 @@ def pytest_addoption(parser: Parser): ) -SMALL_DB_FILE_NAME_REGEX: re.Pattern = re.compile( # type: ignore[type-arg] +SMALL_DB_FILE_NAME_REGEX: re.Pattern[str] = re.compile( r"config-v1|heatmap-v1|metadata|.+\.(?:toml|pid|json|sql|conf)" ) diff --git a/test_runner/fixtures/overlayfs.py b/test_runner/fixtures/overlayfs.py index e0ebfeb8f4..ea11cd272c 100644 --- a/test_runner/fixtures/overlayfs.py +++ b/test_runner/fixtures/overlayfs.py @@ -1,10 +1,13 @@ from __future__ import annotations -from collections.abc import Iterator from pathlib import Path +from typing import TYPE_CHECKING import psutil +if TYPE_CHECKING: + from collections.abc import Iterator + def iter_mounts_beneath(topdir: Path) -> Iterator[Path]: """ diff --git a/test_runner/fixtures/parametrize.py b/test_runner/fixtures/parametrize.py index 3bbac4b8ee..4114c2fcb3 100644 --- a/test_runner/fixtures/parametrize.py +++ b/test_runner/fixtures/parametrize.py @@ -9,7 +9,12 @@ import toml from _pytest.python import Metafunc from fixtures.pg_version import PgVersion -from fixtures.utils import AuxFileStore + +if TYPE_CHECKING: + from typing import Any, Optional + + from fixtures.utils import AuxFileStore + if TYPE_CHECKING: from typing import Any, Optional diff --git a/test_runner/fixtures/pg_version.py b/test_runner/fixtures/pg_version.py index 5820b50a46..01f0245665 100644 --- a/test_runner/fixtures/pg_version.py +++ b/test_runner/fixtures/pg_version.py @@ -2,9 +2,14 @@ from __future__ import annotations import enum import os -from typing import Optional +from typing import TYPE_CHECKING import pytest +from typing_extensions import override + +if TYPE_CHECKING: + from typing import Optional + """ This fixture is used to determine which version of Postgres to use for tests. @@ -24,10 +29,12 @@ class PgVersion(str, enum.Enum): NOT_SET = "<-POSTRGRES VERSION IS NOT SET->" # Make it less confusing in logs + @override def __repr__(self) -> str: return f"'{self.value}'" # Make this explicit for Python 3.11 compatibility, which changes the behavior of enums + @override def __str__(self) -> str: return self.value @@ -38,7 +45,8 @@ class PgVersion(str, enum.Enum): return f"v{self.value}" @classmethod - def _missing_(cls, value) -> Optional[PgVersion]: + @override + def _missing_(cls, value: object) -> Optional[PgVersion]: known_values = {v.value for _, v in cls.__members__.items()} # Allow passing version as a string with "v" prefix (e.g. 
"v14") diff --git a/test_runner/fixtures/port_distributor.py b/test_runner/fixtures/port_distributor.py index 435f452a02..df0eb2a809 100644 --- a/test_runner/fixtures/port_distributor.py +++ b/test_runner/fixtures/port_distributor.py @@ -59,10 +59,7 @@ class PortDistributor: if isinstance(value, int): return self._replace_port_int(value) - if isinstance(value, str): - return self._replace_port_str(value) - - raise TypeError(f"unsupported type {type(value)} of {value=}") + return self._replace_port_str(value) def _replace_port_int(self, value: int) -> int: known_port = self.port_map.get(value) @@ -75,7 +72,7 @@ class PortDistributor: # Use regex to find port in a string # urllib.parse.urlparse produces inconvenient results for cases without scheme like "localhost:5432" # See https://bugs.python.org/issue27657 - ports = re.findall(r":(\d+)(?:/|$)", value) + ports: list[str] = re.findall(r":(\d+)(?:/|$)", value) assert len(ports) == 1, f"can't find port in {value}" port_int = int(ports[0]) diff --git a/test_runner/fixtures/remote_storage.py b/test_runner/fixtures/remote_storage.py index 20e6bd9318..7024953661 100644 --- a/test_runner/fixtures/remote_storage.py +++ b/test_runner/fixtures/remote_storage.py @@ -13,6 +13,7 @@ import boto3 import toml from moto.server import ThreadedMotoServer from mypy_boto3_s3 import S3Client +from typing_extensions import override from fixtures.common_types import TenantId, TenantShardId, TimelineId from fixtures.log_helper import log @@ -36,6 +37,7 @@ class RemoteStorageUser(str, enum.Enum): EXTENSIONS = "ext" SAFEKEEPER = "safekeeper" + @override def __str__(self) -> str: return self.value @@ -81,11 +83,13 @@ class LocalFsStorage: def timeline_path(self, tenant_id: TenantId, timeline_id: TimelineId) -> Path: return self.tenant_path(tenant_id) / "timelines" / str(timeline_id) - def timeline_latest_generation(self, tenant_id, timeline_id): + def timeline_latest_generation( + self, tenant_id: TenantId, timeline_id: TimelineId + ) -> Optional[int]: timeline_files = os.listdir(self.timeline_path(tenant_id, timeline_id)) index_parts = [f for f in timeline_files if f.startswith("index_part")] - def parse_gen(filename): + def parse_gen(filename: str) -> Optional[int]: log.info(f"parsing index_part '{filename}'") parts = filename.split("-") if len(parts) == 2: @@ -93,7 +97,7 @@ class LocalFsStorage: else: return None - generations = sorted([parse_gen(f) for f in index_parts]) + generations = sorted([parse_gen(f) for f in index_parts]) # type: ignore if len(generations) == 0: raise RuntimeError(f"No index_part found for {tenant_id}/{timeline_id}") return generations[-1] @@ -122,14 +126,14 @@ class LocalFsStorage: filename = f"{local_name}-{generation:08x}" return self.timeline_path(tenant_id, timeline_id) / filename - def index_content(self, tenant_id: TenantId, timeline_id: TimelineId): + def index_content(self, tenant_id: TenantId, timeline_id: TimelineId) -> Any: with self.index_path(tenant_id, timeline_id).open("r") as f: return json.load(f) def heatmap_path(self, tenant_id: TenantId) -> Path: return self.tenant_path(tenant_id) / TENANT_HEATMAP_FILE_NAME - def heatmap_content(self, tenant_id): + def heatmap_content(self, tenant_id: TenantId) -> Any: with self.heatmap_path(tenant_id).open("r") as f: return json.load(f) @@ -297,7 +301,7 @@ class S3Storage: def heatmap_key(self, tenant_id: TenantId) -> str: return f"{self.tenant_path(tenant_id)}/{TENANT_HEATMAP_FILE_NAME}" - def heatmap_content(self, tenant_id: TenantId): + def heatmap_content(self, tenant_id: 
TenantId) -> Any: r = self.client.get_object(Bucket=self.bucket_name, Key=self.heatmap_key(tenant_id)) return json.loads(r["Body"].read().decode("utf-8")) @@ -317,7 +321,7 @@ class RemoteStorageKind(str, enum.Enum): def configure( self, repo_dir: Path, - mock_s3_server, + mock_s3_server: MockS3Server, run_id: str, test_name: str, user: RemoteStorageUser, @@ -451,15 +455,9 @@ def default_remote_storage() -> RemoteStorageKind: def remote_storage_to_toml_dict(remote_storage: RemoteStorage) -> dict[str, Any]: - if not isinstance(remote_storage, (LocalFsStorage, S3Storage)): - raise Exception("invalid remote storage type") - return remote_storage.to_toml_dict() # serialize as toml inline table def remote_storage_to_toml_inline_table(remote_storage: RemoteStorage) -> str: - if not isinstance(remote_storage, (LocalFsStorage, S3Storage)): - raise Exception("invalid remote storage type") - return remote_storage.to_toml_inline_table() diff --git a/test_runner/fixtures/storage_controller_proxy.py b/test_runner/fixtures/storage_controller_proxy.py index 02cf6fc33f..c174358ef5 100644 --- a/test_runner/fixtures/storage_controller_proxy.py +++ b/test_runner/fixtures/storage_controller_proxy.py @@ -1,7 +1,7 @@ from __future__ import annotations import re -from typing import Any, Optional +from typing import TYPE_CHECKING import pytest import requests @@ -12,6 +12,9 @@ from werkzeug.wrappers.response import Response from fixtures.log_helper import log +if TYPE_CHECKING: + from typing import Any, Optional + class StorageControllerProxy: def __init__(self, server: HTTPServer): @@ -34,7 +37,7 @@ def proxy_request(method: str, url: str, **kwargs) -> requests.Response: @pytest.fixture(scope="function") -def storage_controller_proxy(make_httpserver): +def storage_controller_proxy(make_httpserver: HTTPServer): """ Proxies requests into the storage controller to the currently selected storage controller instance via `StorageControllerProxy.route_to`. 
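Both this proxy and the compute-reconfigure fixture above hang a plain function off
pytest-httpserver, so the signature being annotated in the next hunk is the usual
Werkzeug shape: `Request` in, `Response` out. A self-contained sketch (the handler
body is invented; the route is borrowed from the compute fixture):

    from pytest_httpserver import HTTPServer
    from werkzeug.wrappers.request import Request
    from werkzeug.wrappers.response import Response

    def handler(request: Request) -> Response:
        # Echo the JSON body back; the real fixtures forward or record it instead.
        return Response(request.get_data(), status=200, content_type="application/json")

    def register(server: HTTPServer) -> None:
        server.expect_request("/notify-attach").respond_with_handler(handler)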
@@ -48,7 +51,7 @@ def storage_controller_proxy(make_httpserver): log.info(f"Storage controller proxy listening on {self.listen}") - def handler(request: Request): + def handler(request: Request) -> Response: if self.route_to is None: log.info(f"Storage controller proxy has no routing configured for {request.url}") return Response("Routing not configured", status=503) diff --git a/test_runner/fixtures/utils.py b/test_runner/fixtures/utils.py index 23381e258a..ca1be35880 100644 --- a/test_runner/fixtures/utils.py +++ b/test_runner/fixtures/utils.py @@ -18,6 +18,7 @@ from urllib.parse import urlencode import allure import zstandard from psycopg2.extensions import cursor +from typing_extensions import override from fixtures.log_helper import log from fixtures.pageserver.common_types import ( @@ -26,14 +27,14 @@ from fixtures.pageserver.common_types import ( ) if TYPE_CHECKING: - from typing import ( - IO, - Optional, - Union, - ) + from collections.abc import Iterable + from typing import IO, Optional + from fixtures.common_types import TimelineId from fixtures.neon_fixtures import PgBin -from fixtures.common_types import TimelineId + + WaitUntilRet = TypeVar("WaitUntilRet") + Fn = TypeVar("Fn", bound=Callable[..., Any]) @@ -42,12 +43,12 @@ def subprocess_capture( capture_dir: Path, cmd: list[str], *, - check=False, - echo_stderr=False, - echo_stdout=False, - capture_stdout=False, - timeout=None, - with_command_header=True, + check: bool = False, + echo_stderr: bool = False, + echo_stdout: bool = False, + capture_stdout: bool = False, + timeout: Optional[float] = None, + with_command_header: bool = True, **popen_kwargs: Any, ) -> tuple[str, Optional[str], int]: """Run a process and bifurcate its output to files and the `log` logger @@ -84,6 +85,7 @@ def subprocess_capture( self.capture = capture self.captured = "" + @override def run(self): first = with_command_header for line in self.in_file: @@ -165,10 +167,10 @@ def global_counter() -> int: def print_gc_result(row: dict[str, Any]): log.info("GC duration {elapsed} ms".format_map(row)) log.info( - " total: {layers_total}, needed_by_cutoff {layers_needed_by_cutoff}, needed_by_pitr {layers_needed_by_pitr}" - " needed_by_branches: {layers_needed_by_branches}, not_updated: {layers_not_updated}, removed: {layers_removed}".format_map( - row - ) + ( + " total: {layers_total}, needed_by_cutoff {layers_needed_by_cutoff}, needed_by_pitr {layers_needed_by_pitr}" + " needed_by_branches: {layers_needed_by_branches}, not_updated: {layers_not_updated}, removed: {layers_removed}" + ).format_map(row) ) @@ -226,7 +228,7 @@ def get_scale_for_db(size_mb: int) -> int: return round(0.06689 * size_mb - 0.5) -ATTACHMENT_NAME_REGEX: re.Pattern = re.compile( # type: ignore[type-arg] +ATTACHMENT_NAME_REGEX: re.Pattern[str] = re.compile( r"regression\.(diffs|out)|.+\.(?:log|stderr|stdout|filediff|metrics|html|walredo)" ) @@ -289,7 +291,7 @@ LOGS_STAGING_DATASOURCE_ID = "xHHYY0dVz" def allure_add_grafana_links(host: str, timeline_id: TimelineId, start_ms: int, end_ms: int): """Add links to server logs in Grafana to Allure report""" - links = {} + links: dict[str, str] = {} # We expect host to be in format like ep-divine-night-159320.us-east-2.aws.neon.build endpoint_id, region_id, _ = host.split(".", 2) @@ -341,7 +343,7 @@ def allure_add_grafana_links(host: str, timeline_id: TimelineId, start_ms: int, def start_in_background( - command: list[str], cwd: Path, log_file_name: str, is_started: Fn + command: list[str], cwd: Path, log_file_name: str, is_started: Callable[[], 
WaitUntilRet] ) -> subprocess.Popen[bytes]: """Starts a process, creates the logfile and redirects stderr and stdout there. Runs the start checks before the process is started, or errors.""" @@ -376,14 +378,11 @@ def start_in_background( return spawned_process -WaitUntilRet = TypeVar("WaitUntilRet") - - def wait_until( number_of_iterations: int, interval: float, func: Callable[[], WaitUntilRet], - show_intermediate_error=False, + show_intermediate_error: bool = False, ) -> WaitUntilRet: """ Wait until 'func' returns successfully, without exception. Returns the @@ -464,7 +463,7 @@ def humantime_to_ms(humantime: str) -> float: def scan_log_for_errors(input: Iterable[str], allowed_errors: list[str]) -> list[tuple[int, str]]: # FIXME: this duplicates test_runner/fixtures/pageserver/allowed_errors.py error_or_warn = re.compile(r"\s(ERROR|WARN)") - errors = [] + errors: list[tuple[int, str]] = [] for lineno, line in enumerate(input, start=1): if len(line) == 0: continue @@ -484,7 +483,7 @@ def scan_log_for_errors(input: Iterable[str], allowed_errors: list[str]) -> list return errors -def assert_no_errors(log_file, service, allowed_errors): +def assert_no_errors(log_file: Path, service: str, allowed_errors: list[str]): if not log_file.exists(): log.warning(f"Skipping {service} log check: {log_file} does not exist") return @@ -504,9 +503,11 @@ class AuxFileStore(str, enum.Enum): V2 = "v2" CrossValidation = "cross-validation" + @override def __repr__(self) -> str: return f"'aux-{self.value}'" + @override def __str__(self) -> str: return f"'aux-{self.value}'" @@ -525,7 +526,7 @@ def assert_pageserver_backups_equal(left: Path, right: Path, skip_files: set[str """ started_at = time.time() - def hash_extracted(reader: Union[IO[bytes], None]) -> bytes: + def hash_extracted(reader: Optional[IO[bytes]]) -> bytes: assert reader is not None digest = sha256(usedforsecurity=False) while True: @@ -550,7 +551,7 @@ def assert_pageserver_backups_equal(left: Path, right: Path, skip_files: set[str right_list ), f"unexpected number of files on tar files, {len(left_list)} != {len(right_list)}" - mismatching = set() + mismatching: set[str] = set() for left_tuple, right_tuple in zip(left_list, right_list): left_path, left_hash = left_tuple @@ -575,6 +576,7 @@ class PropagatingThread(threading.Thread): Simple Thread wrapper with join() propagating the possible exception in the thread. """ + @override def run(self): self.exc = None try: @@ -582,7 +584,8 @@ class PropagatingThread(threading.Thread): except BaseException as e: self.exc = e - def join(self, timeout=None): + @override + def join(self, timeout: Optional[float] = None) -> Any: super().join(timeout) if self.exc: raise self.exc diff --git a/test_runner/fixtures/workload.py b/test_runner/fixtures/workload.py index 4f9c1125bf..e869c43185 100644 --- a/test_runner/fixtures/workload.py +++ b/test_runner/fixtures/workload.py @@ -1,7 +1,7 @@ from __future__ import annotations import threading -from typing import Any, Optional +from typing import TYPE_CHECKING from fixtures.common_types import TenantId, TimelineId from fixtures.log_helper import log @@ -14,6 +14,9 @@ from fixtures.neon_fixtures import ( ) from fixtures.pageserver.utils import wait_for_last_record_lsn +if TYPE_CHECKING: + from typing import Any, Optional + # neon_local doesn't handle creating/modifying endpoints concurrently, so we use a mutex # to ensure we don't do that: this enables running lots of Workloads in parallel safely. 
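 # Locking pattern sketch (abridged; the exact call sites are in Workload):
 #
 #     with ENDPOINT_LOCK:
 #         endpoint = env.endpoints.create_start(...)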
ENDPOINT_LOCK = threading.Lock() @@ -100,7 +103,7 @@ class Workload: self.env, endpoint, self.tenant_id, self.timeline_id, pageserver_id=pageserver_id ) - def write_rows(self, n, pageserver_id: Optional[int] = None, upload: bool = True): + def write_rows(self, n: int, pageserver_id: Optional[int] = None, upload: bool = True): endpoint = self.endpoint(pageserver_id) start = self.expect_rows end = start + n - 1 @@ -121,7 +124,9 @@ class Workload: else: return False - def churn_rows(self, n, pageserver_id: Optional[int] = None, upload=True, ingest=True): + def churn_rows( + self, n: int, pageserver_id: Optional[int] = None, upload: bool = True, ingest: bool = True + ): assert self.expect_rows >= n max_iters = 10 diff --git a/test_runner/regress/test_compaction.py b/test_runner/regress/test_compaction.py index 39d4a3a6d7..420055ac3a 100644 --- a/test_runner/regress/test_compaction.py +++ b/test_runner/regress/test_compaction.py @@ -4,7 +4,7 @@ import enum import json import os import time -from typing import Optional +from typing import TYPE_CHECKING import pytest from fixtures.log_helper import log @@ -16,6 +16,10 @@ from fixtures.pageserver.http import PageserverApiException from fixtures.utils import wait_until from fixtures.workload import Workload +if TYPE_CHECKING: + from typing import Optional + + AGGRESIVE_COMPACTION_TENANT_CONF = { # Disable gc and compaction. The test runs compaction manually. "gc_period": "0s", diff --git a/test_runner/regress/test_pageserver_generations.py b/test_runner/regress/test_pageserver_generations.py index 577a3a25ca..11ebb81023 100644 --- a/test_runner/regress/test_pageserver_generations.py +++ b/test_runner/regress/test_pageserver_generations.py @@ -15,7 +15,7 @@ import enum import os import re import time -from typing import Optional +from typing import TYPE_CHECKING import pytest from fixtures.common_types import TenantId, TimelineId @@ -40,6 +40,10 @@ from fixtures.remote_storage import ( from fixtures.utils import wait_until from fixtures.workload import Workload +if TYPE_CHECKING: + from typing import Optional + + # A tenant configuration that is convenient for generating uploads and deletions # without a large amount of postgres traffic. 
TENANT_CONF = { diff --git a/test_runner/regress/test_sharding.py b/test_runner/regress/test_sharding.py index d1d6b3af75..b1abcaa763 100644 --- a/test_runner/regress/test_sharding.py +++ b/test_runner/regress/test_sharding.py @@ -23,6 +23,7 @@ from fixtures.remote_storage import s3_storage from fixtures.utils import wait_until from fixtures.workload import Workload from pytest_httpserver import HTTPServer +from typing_extensions import override from werkzeug.wrappers.request import Request from werkzeug.wrappers.response import Response @@ -954,6 +955,7 @@ class PageserverFailpoint(Failure): self.pageserver_id = pageserver_id self._mitigate = mitigate + @override def apply(self, env: NeonEnv): pageserver = env.get_pageserver(self.pageserver_id) pageserver.allowed_errors.extend( @@ -961,19 +963,23 @@ class PageserverFailpoint(Failure): ) pageserver.http_client().configure_failpoints((self.failpoint, "return(1)")) + @override def clear(self, env: NeonEnv): pageserver = env.get_pageserver(self.pageserver_id) pageserver.http_client().configure_failpoints((self.failpoint, "off")) if self._mitigate: env.storage_controller.node_configure(self.pageserver_id, {"availability": "Active"}) + @override def expect_available(self): return True + @override def can_mitigate(self): return self._mitigate - def mitigate(self, env): + @override + def mitigate(self, env: NeonEnv): env.storage_controller.node_configure(self.pageserver_id, {"availability": "Offline"}) @@ -983,9 +989,11 @@ class StorageControllerFailpoint(Failure): self.pageserver_id = None self.action = action + @override def apply(self, env: NeonEnv): env.storage_controller.configure_failpoints((self.failpoint, self.action)) + @override def clear(self, env: NeonEnv): if "panic" in self.action: log.info("Restarting storage controller after panic") @@ -994,16 +1002,19 @@ class StorageControllerFailpoint(Failure): else: env.storage_controller.configure_failpoints((self.failpoint, "off")) + @override def expect_available(self): # Controller panics _do_ leave pageservers available, but our test code relies # on using the locate API to update configurations in Workload, so we must skip # these actions when the controller has been panicked. return "panic" not in self.action + @override def can_mitigate(self): return False - def fails_forward(self, env): + @override + def fails_forward(self, env: NeonEnv): # Edge case: the very last failpoint that simulates a DB connection error, where # the abort path will fail-forward and result in a complete split. 
fail_forward = self.failpoint == "shard-split-post-complete" @@ -1017,6 +1028,7 @@ class StorageControllerFailpoint(Failure): return fail_forward + @override def expect_exception(self): if "panic" in self.action: return requests.exceptions.ConnectionError @@ -1029,18 +1041,22 @@ class NodeKill(Failure): self.pageserver_id = pageserver_id self._mitigate = mitigate + @override def apply(self, env: NeonEnv): pageserver = env.get_pageserver(self.pageserver_id) pageserver.stop(immediate=True) + @override def clear(self, env: NeonEnv): pageserver = env.get_pageserver(self.pageserver_id) pageserver.start() + @override def expect_available(self): return False - def mitigate(self, env): + @override + def mitigate(self, env: NeonEnv): env.storage_controller.node_configure(self.pageserver_id, {"availability": "Offline"}) @@ -1059,21 +1075,26 @@ class CompositeFailure(Failure): self.pageserver_id = f.pageserver_id break + @override def apply(self, env: NeonEnv): for f in self.failures: f.apply(env) - def clear(self, env): + @override + def clear(self, env: NeonEnv): for f in self.failures: f.clear(env) + @override def expect_available(self): return all(f.expect_available() for f in self.failures) - def mitigate(self, env): + @override + def mitigate(self, env: NeonEnv): for f in self.failures: f.mitigate(env) + @override def expect_exception(self): expect = set(f.expect_exception() for f in self.failures) @@ -1211,7 +1232,7 @@ def test_sharding_split_failures( assert attached_count == initial_shard_count - def assert_split_done(exclude_ps_id=None) -> None: + def assert_split_done(exclude_ps_id: Optional[int] = None) -> None: secondary_count = 0 attached_count = 0 for ps in env.pageservers: diff --git a/test_runner/regress/test_storage_controller.py b/test_runner/regress/test_storage_controller.py index 202634477c..7be4d2ce0c 100644 --- a/test_runner/regress/test_storage_controller.py +++ b/test_runner/regress/test_storage_controller.py @@ -1038,7 +1038,7 @@ def test_storage_controller_tenant_deletion( ) # Break the compute hook: we are checking that deletion does not depend on the compute hook being available - def break_hook(): + def break_hook(_body: Any): raise RuntimeError("Unexpected call to compute hook") compute_reconfigure_listener.register_on_notify(break_hook) diff --git a/test_runner/regress/test_storage_scrubber.py b/test_runner/regress/test_storage_scrubber.py index f999edc067..05db0fe977 100644 --- a/test_runner/regress/test_storage_scrubber.py +++ b/test_runner/regress/test_storage_scrubber.py @@ -6,7 +6,7 @@ import shutil import threading import time from concurrent.futures import ThreadPoolExecutor -from typing import Optional +from typing import TYPE_CHECKING import pytest from fixtures.common_types import TenantId, TenantShardId, TimelineId @@ -20,6 +20,9 @@ from fixtures.remote_storage import S3Storage, s3_storage from fixtures.utils import wait_until from fixtures.workload import Workload +if TYPE_CHECKING: + from typing import Optional + @pytest.mark.parametrize("shard_count", [None, 4]) def test_scrubber_tenant_snapshot(neon_env_builder: NeonEnvBuilder, shard_count: Optional[int]):
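The through-line of the patch is PEP 698's `@override`, imported from
typing_extensions because `typing.override` only landed in Python 3.12. A minimal
sketch of the check it enables (class names invented):

    from typing_extensions import override

    class Base:
        def flush(self) -> None: ...

    class Child(Base):
        @override
        def flush(self) -> None:  # OK: Base.flush exists and is being overridden
            pass

    # If Base.flush were later renamed, mypy/pyright would flag Child.flush's
    # @override instead of silently leaving behind a stale, never-called method.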