Improve some typing in test_runner

Fixes some existing type annotations, adds missing ones, and adds @override annotations where methods override a base class.

Signed-off-by: Tristan Partin <tristan@neon.tech>
Tristan Partin
2024-10-09 15:42:22 -05:00
committed by GitHub
parent 878135fe9c
commit d3464584a6
22 changed files with 216 additions and 102 deletions

View File

@@ -6,6 +6,8 @@ from enum import Enum
from functools import total_ordering
from typing import TYPE_CHECKING, TypeVar
from typing_extensions import override
if TYPE_CHECKING:
from typing import Any, Union
@@ -31,33 +33,36 @@ class Lsn:
self.lsn_int = (int(left, 16) << 32) + int(right, 16)
assert 0 <= self.lsn_int <= 0xFFFFFFFF_FFFFFFFF
@override
def __str__(self) -> str:
"""Convert lsn from int to standard hex notation."""
return f"{(self.lsn_int >> 32):X}/{(self.lsn_int & 0xFFFFFFFF):X}"
@override
def __repr__(self) -> str:
return f'Lsn("{str(self)}")'
def __int__(self) -> int:
return self.lsn_int
def __lt__(self, other: Any) -> bool:
def __lt__(self, other: object) -> bool:
if not isinstance(other, Lsn):
return NotImplemented
return self.lsn_int < other.lsn_int
def __gt__(self, other: Any) -> bool:
def __gt__(self, other: object) -> bool:
if not isinstance(other, Lsn):
raise NotImplementedError
return self.lsn_int > other.lsn_int
def __eq__(self, other: Any) -> bool:
@override
def __eq__(self, other: object) -> bool:
if not isinstance(other, Lsn):
return NotImplemented
return self.lsn_int == other.lsn_int
# Returns the difference between two Lsns, in bytes
def __sub__(self, other: Any) -> int:
def __sub__(self, other: object) -> int:
if not isinstance(other, Lsn):
return NotImplemented
return self.lsn_int - other.lsn_int
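
The hunks above switch the comparison dunders from `Any` to `object` and return `NotImplemented` when the other operand is not an `Lsn`. A minimal sketch of that pattern on a hypothetical value type (not part of the diff):

```python
from __future__ import annotations


class Byte:
    """Hypothetical value type showing object-typed comparison dunders."""

    def __init__(self, value: int) -> None:
        self.value = value

    def __eq__(self, other: object) -> bool:
        # Returning NotImplemented (rather than raising) lets Python try the
        # reflected operation on `other` before giving up.
        if not isinstance(other, Byte):
            return NotImplemented
        return self.value == other.value

    def __lt__(self, other: object) -> bool:
        if not isinstance(other, Byte):
            return NotImplemented
        return self.value < other.value

    def __hash__(self) -> int:
        # A class that defines __eq__ must redefine __hash__ to stay hashable.
        return hash(self.value)


assert Byte(1) == Byte(1)
assert (Byte(1) == "1") is False  # both sides return NotImplemented -> identity fallback
assert Byte(1) < Byte(2)
```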
@@ -70,6 +75,7 @@ class Lsn:
else:
raise NotImplementedError
@override
def __hash__(self) -> int:
return hash(self.lsn_int)
@@ -116,19 +122,22 @@ class Id:
self.id = bytearray.fromhex(x)
assert len(self.id) == 16
@override
def __str__(self) -> str:
return self.id.hex()
def __lt__(self, other) -> bool:
def __lt__(self, other: object) -> bool:
if not isinstance(other, type(self)):
return NotImplemented
return self.id < other.id
def __eq__(self, other) -> bool:
@override
def __eq__(self, other: object) -> bool:
if not isinstance(other, type(self)):
return NotImplemented
return self.id == other.id
@override
def __hash__(self) -> int:
return hash(str(self.id))
@@ -139,25 +148,31 @@ class Id:
class TenantId(Id):
@override
def __repr__(self) -> str:
return f'TenantId("{self.id.hex()}")'
@override
def __str__(self) -> str:
return self.id.hex()
class NodeId(Id):
@override
def __repr__(self) -> str:
return f'NodeId("{self.id.hex()}")'
@override
def __str__(self) -> str:
return self.id.hex()
class TimelineId(Id):
@override
def __repr__(self) -> str:
return f'TimelineId("{self.id.hex()}")'
@override
def __str__(self) -> str:
return self.id.hex()
@@ -187,7 +202,7 @@ class TenantShardId:
assert self.shard_number < self.shard_count or self.shard_count == 0
@classmethod
def parse(cls: type[TTenantShardId], input) -> TTenantShardId:
def parse(cls: type[TTenantShardId], input: str) -> TTenantShardId:
if len(input) == 32:
return cls(
tenant_id=TenantId(input),
@@ -203,6 +218,7 @@ class TenantShardId:
else:
raise ValueError(f"Invalid TenantShardId '{input}'")
@override
def __str__(self):
if self.shard_count > 0:
return f"{self.tenant_id}-{self.shard_number:02x}{self.shard_count:02x}"
@@ -210,22 +226,25 @@ class TenantShardId:
# Unsharded case: equivalent of Rust TenantShardId::unsharded(tenant_id)
return str(self.tenant_id)
@override
def __repr__(self):
return self.__str__()
def _tuple(self) -> tuple[TenantId, int, int]:
return (self.tenant_id, self.shard_number, self.shard_count)
def __lt__(self, other) -> bool:
def __lt__(self, other: object) -> bool:
if not isinstance(other, type(self)):
return NotImplemented
return self._tuple() < other._tuple()
def __eq__(self, other) -> bool:
@override
def __eq__(self, other: object) -> bool:
if not isinstance(other, type(self)):
return NotImplemented
return self._tuple() == other._tuple()
@override
def __hash__(self) -> int:
return hash(self._tuple())
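
Most hunks in this commit add `@override`, imported from `typing_extensions` (standardized in Python 3.12's `typing`). A small sketch, with simplified stand-ins for the classes above, of what the decorator buys:

```python
from typing_extensions import override


class Id:
    """Simplified stand-in for the Id base class in the hunks above."""

    def __init__(self, hex_id: str) -> None:
        self.id = bytearray.fromhex(hex_id)

    def __repr__(self) -> str:
        return f'Id("{self.id.hex()}")'


class TenantId(Id):
    # @override is a no-op at runtime (it only sets fn.__override__); type
    # checkers use it to verify that a matching method exists in a base class,
    # catching renames and typos such as __rper__.
    @override
    def __repr__(self) -> str:
        return f'TenantId("{self.id.hex()}")'


print(repr(TenantId("0123456789abcdef0123456789abcdef")))
```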

View File

@@ -8,9 +8,11 @@ from contextlib import _GeneratorContextManager, contextmanager
# Type-related stuff
from pathlib import Path
from typing import TYPE_CHECKING
import pytest
from _pytest.fixtures import FixtureRequest
from typing_extensions import override
from fixtures.benchmark_fixture import MetricReport, NeonBenchmarker
from fixtures.log_helper import log
@@ -24,6 +26,9 @@ from fixtures.neon_fixtures import (
)
from fixtures.pg_stats import PgStatTable
if TYPE_CHECKING:
from collections.abc import Iterator
class PgCompare(ABC):
"""Common interface of all postgres implementations, useful for benchmarks.
@@ -65,12 +70,12 @@ class PgCompare(ABC):
@contextmanager
@abstractmethod
def record_pageserver_writes(self, out_name):
def record_pageserver_writes(self, out_name: str):
pass
@contextmanager
@abstractmethod
def record_duration(self, out_name):
def record_duration(self, out_name: str):
pass
@contextmanager
@@ -122,28 +127,34 @@ class NeonCompare(PgCompare):
self._pg = self.env.endpoints.create_start("main", "main", self.tenant)
@property
@override
def pg(self) -> PgProtocol:
return self._pg
@property
@override
def zenbenchmark(self) -> NeonBenchmarker:
return self._zenbenchmark
@property
@override
def pg_bin(self) -> PgBin:
return self._pg_bin
@override
def flush(self, compact: bool = True, gc: bool = True):
wait_for_last_flush_lsn(self.env, self._pg, self.tenant, self.timeline)
self.pageserver_http_client.timeline_checkpoint(self.tenant, self.timeline, compact=compact)
if gc:
self.pageserver_http_client.timeline_gc(self.tenant, self.timeline, 0)
@override
def compact(self):
self.pageserver_http_client.timeline_compact(
self.tenant, self.timeline, wait_until_uploaded=True
)
@override
def report_peak_memory_use(self):
self.zenbenchmark.record(
"peak_mem",
@@ -152,6 +163,7 @@ class NeonCompare(PgCompare):
report=MetricReport.LOWER_IS_BETTER,
)
@override
def report_size(self):
timeline_size = self.zenbenchmark.get_timeline_size(
self.env.repo_dir, self.tenant, self.timeline
@@ -185,9 +197,11 @@ class NeonCompare(PgCompare):
"num_files_uploaded", total_files, "", report=MetricReport.LOWER_IS_BETTER
)
@override
def record_pageserver_writes(self, out_name: str) -> _GeneratorContextManager[None]:
return self.zenbenchmark.record_pageserver_writes(self.env.pageserver, out_name)
@override
def record_duration(self, out_name: str) -> _GeneratorContextManager[None]:
return self.zenbenchmark.record_duration(out_name)
@@ -211,26 +225,33 @@ class VanillaCompare(PgCompare):
self.cur = self.conn.cursor()
@property
@override
def pg(self) -> VanillaPostgres:
return self._pg
@property
@override
def zenbenchmark(self) -> NeonBenchmarker:
return self._zenbenchmark
@property
@override
def pg_bin(self) -> PgBin:
return self._pg.pg_bin
@override
def flush(self, compact: bool = False, gc: bool = False):
self.cur.execute("checkpoint")
@override
def compact(self):
pass
@override
def report_peak_memory_use(self):
pass # TODO find something
@override
def report_size(self):
data_size = self.pg.get_subdir_size(Path("base"))
self.zenbenchmark.record(
@@ -245,6 +266,7 @@ class VanillaCompare(PgCompare):
def record_pageserver_writes(self, out_name: str) -> Iterator[None]:
yield # Do nothing
@override
def record_duration(self, out_name: str) -> _GeneratorContextManager[None]:
return self.zenbenchmark.record_duration(out_name)
@@ -261,28 +283,35 @@ class RemoteCompare(PgCompare):
self.cur = self.conn.cursor()
@property
@override
def pg(self) -> PgProtocol:
return self._pg
@property
@override
def zenbenchmark(self) -> NeonBenchmarker:
return self._zenbenchmark
@property
@override
def pg_bin(self) -> PgBin:
return self._pg.pg_bin
def flush(self):
@override
def flush(self, compact: bool = False, gc: bool = False):
# TODO: flush the remote pageserver
pass
@override
def compact(self):
pass
@override
def report_peak_memory_use(self):
# TODO: get memory usage from remote pageserver
pass
@override
def report_size(self):
# TODO: get storage size from remote pageserver
pass
@@ -291,6 +320,7 @@ class RemoteCompare(PgCompare):
def record_pageserver_writes(self, out_name: str) -> Iterator[None]:
yield # Do nothing
@override
def record_duration(self, out_name: str) -> _GeneratorContextManager[None]:
return self.zenbenchmark.record_duration(out_name)
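
The benchmark fixtures above use two different return annotations for context-manager methods. A hypothetical helper (not from the diff) mirroring both styles:

```python
from collections.abc import Iterator
from contextlib import _GeneratorContextManager, contextmanager


class Recorder:
    """Hypothetical helper mirroring the two annotation styles used above."""

    @contextmanager
    def record_duration(self, out_name: str) -> Iterator[None]:
        # A @contextmanager-decorated generator is annotated with what it
        # yields (Iterator[None] here), not with a context-manager type.
        print(f"start {out_name}")
        yield
        print(f"stop {out_name}")

    def record_duration_delegate(self, out_name: str) -> _GeneratorContextManager[None]:
        # A plain method that merely forwards the already-wrapped context
        # manager is annotated with contextmanager's concrete return type.
        return self.record_duration(out_name)


with Recorder().record_duration_delegate("example"):
    pass
```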

View File

@@ -1,27 +1,31 @@
from __future__ import annotations
import concurrent.futures
from typing import Any
from typing import TYPE_CHECKING
import pytest
from pytest_httpserver import HTTPServer
from werkzeug.wrappers.request import Request
from werkzeug.wrappers.response import Response
from fixtures.common_types import TenantId
from fixtures.log_helper import log
if TYPE_CHECKING:
from typing import Any, Callable, Optional
class ComputeReconfigure:
def __init__(self, server):
def __init__(self, server: HTTPServer):
self.server = server
self.control_plane_compute_hook_api = f"http://{server.host}:{server.port}/notify-attach"
self.workloads = {}
self.on_notify = None
self.workloads: dict[TenantId, Any] = {}
self.on_notify: Optional[Callable[[Any], None]] = None
def register_workload(self, workload):
def register_workload(self, workload: Any):
self.workloads[workload.tenant_id] = workload
def register_on_notify(self, fn):
def register_on_notify(self, fn: Optional[Callable[[Any], None]]):
"""
Add some extra work during a notification, like sleeping to slow things down, or
logging what was notified.
@@ -30,7 +34,7 @@ class ComputeReconfigure:
@pytest.fixture(scope="function")
def compute_reconfigure_listener(make_httpserver):
def compute_reconfigure_listener(make_httpserver: HTTPServer):
"""
This fixture exposes an HTTP listener for the storage controller to submit
compute notifications to us, instead of updating neon_local endpoints itself.
@@ -48,7 +52,7 @@ def compute_reconfigure_listener(make_httpserver):
# accept a healthy rate of calls into notify-attach.
reconfigure_threads = concurrent.futures.ThreadPoolExecutor(max_workers=1)
def handler(request: Request):
def handler(request: Request) -> Response:
assert request.json is not None
body: dict[str, Any] = request.json
log.info(f"notify-attach request: {body}")
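These hunks annotate the attributes at assignment and move `Any`, `Callable`, and `Optional` under `if TYPE_CHECKING:`, which is safe because the module already has `from __future__ import annotations`. A trimmed-down, hypothetical version of the pattern:

```python
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Needed only by the type checker; with deferred annotations these names
    # are never looked up at runtime.
    from typing import Any, Callable, Optional


class ReconfigureSketch:
    """Hypothetical trimmed-down version of ComputeReconfigure above."""

    def __init__(self) -> None:
        # Annotating the empty dict and the None default keeps mypy from
        # inferring overly loose (or overly narrow) attribute types.
        self.workloads: dict[str, Any] = {}
        self.on_notify: Optional[Callable[[Any], None]] = None

    def register_on_notify(self, fn: Optional[Callable[[Any], None]]) -> None:
        self.on_notify = fn


r = ReconfigureSketch()
r.register_on_notify(lambda body: None)
```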

View File

@@ -14,8 +14,10 @@ from allure_pytest.utils import allure_name, allure_suite_labels
from fixtures.log_helper import log
if TYPE_CHECKING:
from collections.abc import MutableMapping
from typing import Any
"""
The plugin reruns flaky tests.
It uses `pytest.mark.flaky` provided by `pytest-rerunfailures` plugin and flaky tests detected by `scripts/flaky_tests.py`

View File

@@ -1,8 +1,15 @@
from __future__ import annotations
from typing import TYPE_CHECKING
import pytest
from pytest_httpserver import HTTPServer
if TYPE_CHECKING:
from collections.abc import Iterator
from fixtures.port_distributor import PortDistributor
# TODO: mypy fails with:
# Module "fixtures.neon_fixtures" does not explicitly export attribute "PortDistributor" [attr-defined]
# from fixtures.neon_fixtures import PortDistributor
@@ -17,7 +24,7 @@ def httpserver_ssl_context():
@pytest.fixture(scope="function")
def make_httpserver(httpserver_listen_address, httpserver_ssl_context):
def make_httpserver(httpserver_listen_address, httpserver_ssl_context) -> Iterator[HTTPServer]:
host, port = httpserver_listen_address
if not host:
host = HTTPServer.DEFAULT_LISTEN_HOST
@@ -33,13 +40,13 @@ def make_httpserver(httpserver_listen_address, httpserver_ssl_context):
@pytest.fixture(scope="function")
def httpserver(make_httpserver):
def httpserver(make_httpserver: HTTPServer) -> Iterator[HTTPServer]:
server = make_httpserver
yield server
server.clear()
@pytest.fixture(scope="function")
def httpserver_listen_address(port_distributor) -> tuple[str, int]:
def httpserver_listen_address(port_distributor: PortDistributor) -> tuple[str, int]:
port = port_distributor.get_port()
return ("localhost", port)
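
The httpserver fixtures are now annotated as returning `Iterator[HTTPServer]` because they yield the server and then clean it up. A minimal, hypothetical fixture in the same style (assuming pytest and pytest_httpserver, which this suite already uses):

```python
from collections.abc import Iterator

import pytest
from pytest_httpserver import HTTPServer


@pytest.fixture(scope="function")
def sketch_httpserver() -> Iterator[HTTPServer]:
    # A fixture that yields is a generator function, so it is annotated with
    # Iterator[T] rather than T; pytest hands the yielded value to the test.
    server = HTTPServer(host="localhost", port=0)
    server.start()
    yield server
    server.clear()
    server.stop()
```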

View File

@@ -31,7 +31,7 @@ LOGGING = {
}
def getLogger(name="root") -> logging.Logger:
def getLogger(name: str = "root") -> logging.Logger:
"""Method to get logger for tests.
Should be used to get correctly initialized logger."""

View File

@@ -22,7 +22,7 @@ class Metrics:
def query_all(self, name: str, filter: Optional[dict[str, str]] = None) -> list[Sample]:
filter = filter or {}
res = []
res: list[Sample] = []
for sample in self.metrics[name]:
try:
@@ -59,7 +59,7 @@ class MetricsGetter:
return results[0].value
def get_metrics_values(
self, names: list[str], filter: Optional[dict[str, str]] = None, absence_ok=False
self, names: list[str], filter: Optional[dict[str, str]] = None, absence_ok: bool = False
) -> dict[str, float]:
"""
When fetching multiple named metrics, it is more efficient to use this

View File

@@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, cast
import requests
if TYPE_CHECKING:
from typing import Any, Literal, Optional, Union
from typing import Any, Literal, Optional
from fixtures.pg_version import PgVersion
@@ -25,9 +25,7 @@ class NeonAPI:
self.__neon_api_key = neon_api_key
self.__neon_api_base_url = neon_api_base_url.strip("/")
def __request(
self, method: Union[str, bytes], endpoint: str, **kwargs: Any
) -> requests.Response:
def __request(self, method: str | bytes, endpoint: str, **kwargs: Any) -> requests.Response:
if "headers" not in kwargs:
kwargs["headers"] = {}
kwargs["headers"]["Authorization"] = f"Bearer {self.__neon_api_key}"
@@ -187,8 +185,8 @@ class NeonAPI:
def get_connection_uri(
self,
project_id: str,
branch_id: Optional[str] = None,
endpoint_id: Optional[str] = None,
branch_id: str | None = None,
endpoint_id: str | None = None,
database_name: str = "neondb",
role_name: str = "neondb_owner",
pooled: bool = True,
@@ -264,7 +262,7 @@ class NeonAPI:
class NeonApiEndpoint:
def __init__(self, neon_api: NeonAPI, pg_version: PgVersion, project_id: Optional[str]):
def __init__(self, neon_api: NeonAPI, pg_version: PgVersion, project_id: str | None):
self.neon_api = neon_api
if project_id is None:
project = neon_api.create_project(pg_version)
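
These hunks replace `Union[str, bytes]` and `Optional[str]` with PEP 604 unions (`str | bytes`, `str | None`). Assuming the module uses `from __future__ import annotations` like the other files in this diff, the `|` syntax is never evaluated at runtime, so it works even on older interpreters. A hypothetical signature in the same style:

```python
from __future__ import annotations  # makes `X | Y` in annotations safe on Python 3.8/3.9 too


def get_connection_uri_sketch(
    project_id: str,
    branch_id: str | None = None,
    endpoint_id: str | None = None,
) -> str:
    """Hypothetical signature mirroring the Optional[...] -> `| None` change above."""
    return "/".join(p for p in (project_id, branch_id, endpoint_id) if p is not None)


print(get_connection_uri_sketch("proj-123", endpoint_id="ep-456"))
```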

View File

@@ -3657,7 +3657,7 @@ class Endpoint(PgProtocol, LogUtils):
config_lines: Optional[list[str]] = None,
remote_ext_config: Optional[str] = None,
pageserver_id: Optional[int] = None,
allow_multiple=False,
allow_multiple: bool = False,
basebackup_request_tries: Optional[int] = None,
) -> Endpoint:
"""
@@ -3998,7 +3998,7 @@ class Safekeeper(LogUtils):
def timeline_dir(self, tenant_id, timeline_id) -> Path:
return self.data_dir / str(tenant_id) / str(timeline_id)
# List partial uploaded segments of this safekeeper. Works only for
# list partial uploaded segments of this safekeeper. Works only for
# RemoteStorageKind.LOCAL_FS.
def list_uploaded_segments(self, tenant_id: TenantId, timeline_id: TimelineId):
tline_path = (
@@ -4293,7 +4293,7 @@ def pytest_addoption(parser: Parser):
)
SMALL_DB_FILE_NAME_REGEX: re.Pattern = re.compile( # type: ignore[type-arg]
SMALL_DB_FILE_NAME_REGEX: re.Pattern[str] = re.compile(
r"config-v1|heatmap-v1|metadata|.+\.(?:toml|pid|json|sql|conf)"
)
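
Parameterizing the pattern as `re.Pattern[str]` removes the need for the previous `# type: ignore[type-arg]`; `re.Pattern` is subscriptable at runtime since Python 3.9, and under `from __future__ import annotations` the subscript is never evaluated anyway. A self-contained sketch reusing the regex from the hunk:

```python
from __future__ import annotations

import re

# Parameterizing the pattern documents that it matches str (not bytes) and
# satisfies mypy's type-arg check without an ignore comment.
SMALL_DB_FILE_NAME_REGEX: re.Pattern[str] = re.compile(
    r"config-v1|heatmap-v1|metadata|.+\.(?:toml|pid|json|sql|conf)"
)

assert SMALL_DB_FILE_NAME_REGEX.match("metadata") is not None
assert SMALL_DB_FILE_NAME_REGEX.match("pageserver.toml") is not None
```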

View File

@@ -1,10 +1,13 @@
from __future__ import annotations
from collections.abc import Iterator
from pathlib import Path
from typing import TYPE_CHECKING
import psutil
if TYPE_CHECKING:
from collections.abc import Iterator
def iter_mounts_beneath(topdir: Path) -> Iterator[Path]:
"""

View File

@@ -9,7 +9,12 @@ import toml
from _pytest.python import Metafunc
from fixtures.pg_version import PgVersion
from fixtures.utils import AuxFileStore
if TYPE_CHECKING:
from typing import Any, Optional
from fixtures.utils import AuxFileStore
if TYPE_CHECKING:
from typing import Any, Optional

View File

@@ -2,9 +2,14 @@ from __future__ import annotations
import enum
import os
from typing import Optional
from typing import TYPE_CHECKING
import pytest
from typing_extensions import override
if TYPE_CHECKING:
from typing import Optional
"""
This fixture is used to determine which version of Postgres to use for tests.
@@ -24,10 +29,12 @@ class PgVersion(str, enum.Enum):
NOT_SET = "<-POSTRGRES VERSION IS NOT SET->"
# Make it less confusing in logs
@override
def __repr__(self) -> str:
return f"'{self.value}'"
# Make this explicit for Python 3.11 compatibility, which changes the behavior of enums
@override
def __str__(self) -> str:
return self.value
@@ -38,7 +45,8 @@ class PgVersion(str, enum.Enum):
return f"v{self.value}"
@classmethod
def _missing_(cls, value) -> Optional[PgVersion]:
@override
def _missing_(cls, value: object) -> Optional[PgVersion]:
known_values = {v.value for _, v in cls.__members__.items()}
# Allow passing version as a string with "v" prefix (e.g. "v14")
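
As the hunk's own comment notes, Python 3.11 changed how members of a `str`-mixin enum render via `str()`/`format()`, so the explicit `__str__` keeps output stable. A hypothetical cut-down enum in the same style:

```python
import enum

from typing_extensions import override


class PgVersionSketch(str, enum.Enum):
    """Hypothetical cut-down enum showing the explicit __str__/__repr__ overrides."""

    V14 = "14"
    V15 = "15"

    # Python 3.11 changed the default rendering of str-mixin enum members;
    # pinning __str__ (as the hunk above does) keeps log output stable.
    @override
    def __str__(self) -> str:
        return self.value

    @override
    def __repr__(self) -> str:
        return f"'{self.value}'"


assert str(PgVersionSketch.V14) == "14"  # same on every supported Python version
```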

View File

@@ -59,10 +59,7 @@ class PortDistributor:
if isinstance(value, int):
return self._replace_port_int(value)
if isinstance(value, str):
return self._replace_port_str(value)
raise TypeError(f"unsupported type {type(value)} of {value=}")
return self._replace_port_str(value)
def _replace_port_int(self, value: int) -> int:
known_port = self.port_map.get(value)
@@ -75,7 +72,7 @@ class PortDistributor:
# Use regex to find port in a string
# urllib.parse.urlparse produces inconvenient results for cases without scheme like "localhost:5432"
# See https://bugs.python.org/issue27657
ports = re.findall(r":(\d+)(?:/|$)", value)
ports: list[str] = re.findall(r":(\d+)(?:/|$)", value)
assert len(ports) == 1, f"can't find port in {value}"
port_int = int(ports[0])

View File

@@ -13,6 +13,7 @@ import boto3
import toml
from moto.server import ThreadedMotoServer
from mypy_boto3_s3 import S3Client
from typing_extensions import override
from fixtures.common_types import TenantId, TenantShardId, TimelineId
from fixtures.log_helper import log
@@ -36,6 +37,7 @@ class RemoteStorageUser(str, enum.Enum):
EXTENSIONS = "ext"
SAFEKEEPER = "safekeeper"
@override
def __str__(self) -> str:
return self.value
@@ -81,11 +83,13 @@ class LocalFsStorage:
def timeline_path(self, tenant_id: TenantId, timeline_id: TimelineId) -> Path:
return self.tenant_path(tenant_id) / "timelines" / str(timeline_id)
def timeline_latest_generation(self, tenant_id, timeline_id):
def timeline_latest_generation(
self, tenant_id: TenantId, timeline_id: TimelineId
) -> Optional[int]:
timeline_files = os.listdir(self.timeline_path(tenant_id, timeline_id))
index_parts = [f for f in timeline_files if f.startswith("index_part")]
def parse_gen(filename):
def parse_gen(filename: str) -> Optional[int]:
log.info(f"parsing index_part '{filename}'")
parts = filename.split("-")
if len(parts) == 2:
@@ -93,7 +97,7 @@ class LocalFsStorage:
else:
return None
generations = sorted([parse_gen(f) for f in index_parts])
generations = sorted([parse_gen(f) for f in index_parts]) # type: ignore
if len(generations) == 0:
raise RuntimeError(f"No index_part found for {tenant_id}/{timeline_id}")
return generations[-1]
@@ -122,14 +126,14 @@ class LocalFsStorage:
filename = f"{local_name}-{generation:08x}"
return self.timeline_path(tenant_id, timeline_id) / filename
def index_content(self, tenant_id: TenantId, timeline_id: TimelineId):
def index_content(self, tenant_id: TenantId, timeline_id: TimelineId) -> Any:
with self.index_path(tenant_id, timeline_id).open("r") as f:
return json.load(f)
def heatmap_path(self, tenant_id: TenantId) -> Path:
return self.tenant_path(tenant_id) / TENANT_HEATMAP_FILE_NAME
def heatmap_content(self, tenant_id):
def heatmap_content(self, tenant_id: TenantId) -> Any:
with self.heatmap_path(tenant_id).open("r") as f:
return json.load(f)
@@ -297,7 +301,7 @@ class S3Storage:
def heatmap_key(self, tenant_id: TenantId) -> str:
return f"{self.tenant_path(tenant_id)}/{TENANT_HEATMAP_FILE_NAME}"
def heatmap_content(self, tenant_id: TenantId):
def heatmap_content(self, tenant_id: TenantId) -> Any:
r = self.client.get_object(Bucket=self.bucket_name, Key=self.heatmap_key(tenant_id))
return json.loads(r["Body"].read().decode("utf-8"))
@@ -317,7 +321,7 @@ class RemoteStorageKind(str, enum.Enum):
def configure(
self,
repo_dir: Path,
mock_s3_server,
mock_s3_server: MockS3Server,
run_id: str,
test_name: str,
user: RemoteStorageUser,
@@ -451,15 +455,9 @@ def default_remote_storage() -> RemoteStorageKind:
def remote_storage_to_toml_dict(remote_storage: RemoteStorage) -> dict[str, Any]:
if not isinstance(remote_storage, (LocalFsStorage, S3Storage)):
raise Exception("invalid remote storage type")
return remote_storage.to_toml_dict()
# serialize as toml inline table
def remote_storage_to_toml_inline_table(remote_storage: RemoteStorage) -> str:
if not isinstance(remote_storage, (LocalFsStorage, S3Storage)):
raise Exception("invalid remote storage type")
return remote_storage.to_toml_inline_table()
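
The `timeline_latest_generation` hunk types the nested `parse_gen` as returning `Optional[int]` and then silences mypy on the `sorted(...)` call with `# type: ignore`, since `None` is not orderable against `int`. As a hedged alternative (not what the diff does), filtering the `None`s first type-checks cleanly; a sketch with a simplified stand-in for the helper:

```python
from __future__ import annotations

from typing import Optional


def parse_gen(filename: str) -> Optional[int]:
    """Simplified stand-in for the nested helper above (hex generation suffix)."""
    parts = filename.split("-")
    if len(parts) == 2:
        try:
            return int(parts[1], 16)
        except ValueError:
            return None
    return None


index_parts = ["index_part.json-0000000a", "index_part.json", "index_part.json-00000003"]

maybe_generations = [parse_gen(f) for f in index_parts]
# Dropping the Nones keeps the list homogeneous, so sorted() needs no ignore.
generations: list[int] = sorted(g for g in maybe_generations if g is not None)
assert generations[-1] == 0x0A
```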

View File

@@ -1,7 +1,7 @@
from __future__ import annotations
import re
from typing import Any, Optional
from typing import TYPE_CHECKING
import pytest
import requests
@@ -12,6 +12,9 @@ from werkzeug.wrappers.response import Response
from fixtures.log_helper import log
if TYPE_CHECKING:
from typing import Any, Optional
class StorageControllerProxy:
def __init__(self, server: HTTPServer):
@@ -34,7 +37,7 @@ def proxy_request(method: str, url: str, **kwargs) -> requests.Response:
@pytest.fixture(scope="function")
def storage_controller_proxy(make_httpserver):
def storage_controller_proxy(make_httpserver: HTTPServer):
"""
Proxies requests into the storage controller to the currently
selected storage controller instance via `StorageControllerProxy.route_to`.
@@ -48,7 +51,7 @@ def storage_controller_proxy(make_httpserver):
log.info(f"Storage controller proxy listening on {self.listen}")
def handler(request: Request):
def handler(request: Request) -> Response:
if self.route_to is None:
log.info(f"Storage controller proxy has no routing configured for {request.url}")
return Response("Routing not configured", status=503)

View File

@@ -18,6 +18,7 @@ from urllib.parse import urlencode
import allure
import zstandard
from psycopg2.extensions import cursor
from typing_extensions import override
from fixtures.log_helper import log
from fixtures.pageserver.common_types import (
@@ -26,14 +27,14 @@ from fixtures.pageserver.common_types import (
)
if TYPE_CHECKING:
from typing import (
IO,
Optional,
Union,
)
from collections.abc import Iterable
from typing import IO, Optional
from fixtures.common_types import TimelineId
from fixtures.neon_fixtures import PgBin
from fixtures.common_types import TimelineId
WaitUntilRet = TypeVar("WaitUntilRet")
Fn = TypeVar("Fn", bound=Callable[..., Any])
@@ -42,12 +43,12 @@ def subprocess_capture(
capture_dir: Path,
cmd: list[str],
*,
check=False,
echo_stderr=False,
echo_stdout=False,
capture_stdout=False,
timeout=None,
with_command_header=True,
check: bool = False,
echo_stderr: bool = False,
echo_stdout: bool = False,
capture_stdout: bool = False,
timeout: Optional[float] = None,
with_command_header: bool = True,
**popen_kwargs: Any,
) -> tuple[str, Optional[str], int]:
"""Run a process and bifurcate its output to files and the `log` logger
@@ -84,6 +85,7 @@ def subprocess_capture(
self.capture = capture
self.captured = ""
@override
def run(self):
first = with_command_header
for line in self.in_file:
@@ -165,10 +167,10 @@ def global_counter() -> int:
def print_gc_result(row: dict[str, Any]):
log.info("GC duration {elapsed} ms".format_map(row))
log.info(
" total: {layers_total}, needed_by_cutoff {layers_needed_by_cutoff}, needed_by_pitr {layers_needed_by_pitr}"
" needed_by_branches: {layers_needed_by_branches}, not_updated: {layers_not_updated}, removed: {layers_removed}".format_map(
row
)
(
" total: {layers_total}, needed_by_cutoff {layers_needed_by_cutoff}, needed_by_pitr {layers_needed_by_pitr}"
" needed_by_branches: {layers_needed_by_branches}, not_updated: {layers_not_updated}, removed: {layers_removed}"
).format_map(row)
)
@@ -226,7 +228,7 @@ def get_scale_for_db(size_mb: int) -> int:
return round(0.06689 * size_mb - 0.5)
ATTACHMENT_NAME_REGEX: re.Pattern = re.compile( # type: ignore[type-arg]
ATTACHMENT_NAME_REGEX: re.Pattern[str] = re.compile(
r"regression\.(diffs|out)|.+\.(?:log|stderr|stdout|filediff|metrics|html|walredo)"
)
@@ -289,7 +291,7 @@ LOGS_STAGING_DATASOURCE_ID = "xHHYY0dVz"
def allure_add_grafana_links(host: str, timeline_id: TimelineId, start_ms: int, end_ms: int):
"""Add links to server logs in Grafana to Allure report"""
links = {}
links: dict[str, str] = {}
# We expect host to be in format like ep-divine-night-159320.us-east-2.aws.neon.build
endpoint_id, region_id, _ = host.split(".", 2)
@@ -341,7 +343,7 @@ def allure_add_grafana_links(host: str, timeline_id: TimelineId, start_ms: int,
def start_in_background(
command: list[str], cwd: Path, log_file_name: str, is_started: Fn
command: list[str], cwd: Path, log_file_name: str, is_started: Callable[[], WaitUntilRet]
) -> subprocess.Popen[bytes]:
"""Starts a process, creates the logfile and redirects stderr and stdout there. Runs the start checks before the process is started, or errors."""
@@ -376,14 +378,11 @@ def start_in_background(
return spawned_process
WaitUntilRet = TypeVar("WaitUntilRet")
def wait_until(
number_of_iterations: int,
interval: float,
func: Callable[[], WaitUntilRet],
show_intermediate_error=False,
show_intermediate_error: bool = False,
) -> WaitUntilRet:
"""
Wait until 'func' returns successfully, without exception. Returns the
@@ -464,7 +463,7 @@ def humantime_to_ms(humantime: str) -> float:
def scan_log_for_errors(input: Iterable[str], allowed_errors: list[str]) -> list[tuple[int, str]]:
# FIXME: this duplicates test_runner/fixtures/pageserver/allowed_errors.py
error_or_warn = re.compile(r"\s(ERROR|WARN)")
errors = []
errors: list[tuple[int, str]] = []
for lineno, line in enumerate(input, start=1):
if len(line) == 0:
continue
@@ -484,7 +483,7 @@ def scan_log_for_errors(input: Iterable[str], allowed_errors: list[str]) -> list
return errors
def assert_no_errors(log_file, service, allowed_errors):
def assert_no_errors(log_file: Path, service: str, allowed_errors: list[str]):
if not log_file.exists():
log.warning(f"Skipping {service} log check: {log_file} does not exist")
return
@@ -504,9 +503,11 @@ class AuxFileStore(str, enum.Enum):
V2 = "v2"
CrossValidation = "cross-validation"
@override
def __repr__(self) -> str:
return f"'aux-{self.value}'"
@override
def __str__(self) -> str:
return f"'aux-{self.value}'"
@@ -525,7 +526,7 @@ def assert_pageserver_backups_equal(left: Path, right: Path, skip_files: set[str
"""
started_at = time.time()
def hash_extracted(reader: Union[IO[bytes], None]) -> bytes:
def hash_extracted(reader: Optional[IO[bytes]]) -> bytes:
assert reader is not None
digest = sha256(usedforsecurity=False)
while True:
@@ -550,7 +551,7 @@ def assert_pageserver_backups_equal(left: Path, right: Path, skip_files: set[str
right_list
), f"unexpected number of files on tar files, {len(left_list)} != {len(right_list)}"
mismatching = set()
mismatching: set[str] = set()
for left_tuple, right_tuple in zip(left_list, right_list):
left_path, left_hash = left_tuple
@@ -575,6 +576,7 @@ class PropagatingThread(threading.Thread):
Simple Thread wrapper with join() propagating the possible exception in the thread.
"""
@override
def run(self):
self.exc = None
try:
@@ -582,7 +584,8 @@ class PropagatingThread(threading.Thread):
except BaseException as e:
self.exc = e
def join(self, timeout=None):
@override
def join(self, timeout: Optional[float] = None) -> Any:
super().join(timeout)
if self.exc:
raise self.exc
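
The utils hunks move `WaitUntilRet = TypeVar("WaitUntilRet")` above its first use and type callbacks as `Callable[[], WaitUntilRet]`, so callers of `wait_until` get back whatever their callback returns, fully typed. A simplified, hypothetical stand-in for that helper (the real one logs and re-raises differently):

```python
from __future__ import annotations

import time
from typing import Callable, TypeVar

WaitUntilRet = TypeVar("WaitUntilRet")


def wait_until_sketch(
    number_of_iterations: int,
    interval: float,
    func: Callable[[], WaitUntilRet],
    show_intermediate_error: bool = False,
) -> WaitUntilRet:
    """Simplified stand-in: retry func until it stops raising, then return its result."""
    last_exception: Exception | None = None
    for _ in range(number_of_iterations):
        try:
            # Because func is Callable[[], WaitUntilRet], the return type of
            # this helper follows the callback's return type.
            return func()
        except Exception as e:  # deliberately broad: retry on any failure
            last_exception = e
            if show_intermediate_error:
                print(f"retrying after: {e}")
            time.sleep(interval)
    raise RuntimeError(f"timed out after {number_of_iterations} attempts") from last_exception


count: int = wait_until_sketch(3, 0.01, lambda: 42)
```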

View File

@@ -1,7 +1,7 @@
from __future__ import annotations
import threading
from typing import Any, Optional
from typing import TYPE_CHECKING
from fixtures.common_types import TenantId, TimelineId
from fixtures.log_helper import log
@@ -14,6 +14,9 @@ from fixtures.neon_fixtures import (
)
from fixtures.pageserver.utils import wait_for_last_record_lsn
if TYPE_CHECKING:
from typing import Any, Optional
# neon_local doesn't handle creating/modifying endpoints concurrently, so we use a mutex
# to ensure we don't do that: this enables running lots of Workloads in parallel safely.
ENDPOINT_LOCK = threading.Lock()
@@ -100,7 +103,7 @@ class Workload:
self.env, endpoint, self.tenant_id, self.timeline_id, pageserver_id=pageserver_id
)
def write_rows(self, n, pageserver_id: Optional[int] = None, upload: bool = True):
def write_rows(self, n: int, pageserver_id: Optional[int] = None, upload: bool = True):
endpoint = self.endpoint(pageserver_id)
start = self.expect_rows
end = start + n - 1
@@ -121,7 +124,9 @@ class Workload:
else:
return False
def churn_rows(self, n, pageserver_id: Optional[int] = None, upload=True, ingest=True):
def churn_rows(
self, n: int, pageserver_id: Optional[int] = None, upload: bool = True, ingest: bool = True
):
assert self.expect_rows >= n
max_iters = 10

View File

@@ -4,7 +4,7 @@ import enum
import json
import os
import time
from typing import Optional
from typing import TYPE_CHECKING
import pytest
from fixtures.log_helper import log
@@ -16,6 +16,10 @@ from fixtures.pageserver.http import PageserverApiException
from fixtures.utils import wait_until
from fixtures.workload import Workload
if TYPE_CHECKING:
from typing import Optional
AGGRESIVE_COMPACTION_TENANT_CONF = {
# Disable gc and compaction. The test runs compaction manually.
"gc_period": "0s",

View File

@@ -15,7 +15,7 @@ import enum
import os
import re
import time
from typing import Optional
from typing import TYPE_CHECKING
import pytest
from fixtures.common_types import TenantId, TimelineId
@@ -40,6 +40,10 @@ from fixtures.remote_storage import (
from fixtures.utils import wait_until
from fixtures.workload import Workload
if TYPE_CHECKING:
from typing import Optional
# A tenant configuration that is convenient for generating uploads and deletions
# without a large amount of postgres traffic.
TENANT_CONF = {

View File

@@ -23,6 +23,7 @@ from fixtures.remote_storage import s3_storage
from fixtures.utils import wait_until
from fixtures.workload import Workload
from pytest_httpserver import HTTPServer
from typing_extensions import override
from werkzeug.wrappers.request import Request
from werkzeug.wrappers.response import Response
@@ -954,6 +955,7 @@ class PageserverFailpoint(Failure):
self.pageserver_id = pageserver_id
self._mitigate = mitigate
@override
def apply(self, env: NeonEnv):
pageserver = env.get_pageserver(self.pageserver_id)
pageserver.allowed_errors.extend(
@@ -961,19 +963,23 @@ class PageserverFailpoint(Failure):
)
pageserver.http_client().configure_failpoints((self.failpoint, "return(1)"))
@override
def clear(self, env: NeonEnv):
pageserver = env.get_pageserver(self.pageserver_id)
pageserver.http_client().configure_failpoints((self.failpoint, "off"))
if self._mitigate:
env.storage_controller.node_configure(self.pageserver_id, {"availability": "Active"})
@override
def expect_available(self):
return True
@override
def can_mitigate(self):
return self._mitigate
def mitigate(self, env):
@override
def mitigate(self, env: NeonEnv):
env.storage_controller.node_configure(self.pageserver_id, {"availability": "Offline"})
@@ -983,9 +989,11 @@ class StorageControllerFailpoint(Failure):
self.pageserver_id = None
self.action = action
@override
def apply(self, env: NeonEnv):
env.storage_controller.configure_failpoints((self.failpoint, self.action))
@override
def clear(self, env: NeonEnv):
if "panic" in self.action:
log.info("Restarting storage controller after panic")
@@ -994,16 +1002,19 @@ class StorageControllerFailpoint(Failure):
else:
env.storage_controller.configure_failpoints((self.failpoint, "off"))
@override
def expect_available(self):
# Controller panics _do_ leave pageservers available, but our test code relies
# on using the locate API to update configurations in Workload, so we must skip
# these actions when the controller has been panicked.
return "panic" not in self.action
@override
def can_mitigate(self):
return False
def fails_forward(self, env):
@override
def fails_forward(self, env: NeonEnv):
# Edge case: the very last failpoint that simulates a DB connection error, where
# the abort path will fail-forward and result in a complete split.
fail_forward = self.failpoint == "shard-split-post-complete"
@@ -1017,6 +1028,7 @@ class StorageControllerFailpoint(Failure):
return fail_forward
@override
def expect_exception(self):
if "panic" in self.action:
return requests.exceptions.ConnectionError
@@ -1029,18 +1041,22 @@ class NodeKill(Failure):
self.pageserver_id = pageserver_id
self._mitigate = mitigate
@override
def apply(self, env: NeonEnv):
pageserver = env.get_pageserver(self.pageserver_id)
pageserver.stop(immediate=True)
@override
def clear(self, env: NeonEnv):
pageserver = env.get_pageserver(self.pageserver_id)
pageserver.start()
@override
def expect_available(self):
return False
def mitigate(self, env):
@override
def mitigate(self, env: NeonEnv):
env.storage_controller.node_configure(self.pageserver_id, {"availability": "Offline"})
@@ -1059,21 +1075,26 @@ class CompositeFailure(Failure):
self.pageserver_id = f.pageserver_id
break
@override
def apply(self, env: NeonEnv):
for f in self.failures:
f.apply(env)
def clear(self, env):
@override
def clear(self, env: NeonEnv):
for f in self.failures:
f.clear(env)
@override
def expect_available(self):
return all(f.expect_available() for f in self.failures)
def mitigate(self, env):
@override
def mitigate(self, env: NeonEnv):
for f in self.failures:
f.mitigate(env)
@override
def expect_exception(self):
expect = set(f.expect_exception() for f in self.failures)
@@ -1211,7 +1232,7 @@ def test_sharding_split_failures(
assert attached_count == initial_shard_count
def assert_split_done(exclude_ps_id=None) -> None:
def assert_split_done(exclude_ps_id: Optional[int] = None) -> None:
secondary_count = 0
attached_count = 0
for ps in env.pageservers:

View File

@@ -1038,7 +1038,7 @@ def test_storage_controller_tenant_deletion(
)
# Break the compute hook: we are checking that deletion does not depend on the compute hook being available
def break_hook():
def break_hook(_body: Any):
raise RuntimeError("Unexpected call to compute hook")
compute_reconfigure_listener.register_on_notify(break_hook)

View File

@@ -6,7 +6,7 @@ import shutil
import threading
import time
from concurrent.futures import ThreadPoolExecutor
from typing import Optional
from typing import TYPE_CHECKING
import pytest
from fixtures.common_types import TenantId, TenantShardId, TimelineId
@@ -20,6 +20,9 @@ from fixtures.remote_storage import S3Storage, s3_storage
from fixtures.utils import wait_until
from fixtures.workload import Workload
if TYPE_CHECKING:
from typing import Optional
@pytest.mark.parametrize("shard_count", [None, 4])
def test_scrubber_tenant_snapshot(neon_env_builder: NeonEnvBuilder, shard_count: Optional[int]):