Mirror of https://github.com/neondatabase/neon.git, synced 2025-12-27 08:09:58 +00:00
Improve some typing in test_runner
Fixes some types, adds some types, and adds some override annotations.

Signed-off-by: Tristan Partin <tristan@neon.tech>
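The hunks below repeat a few typing idioms rather than introducing new logic: imports moved under a TYPE_CHECKING guard, comparison dunders typed as `object` that return NotImplemented, and `@override` markers from typing_extensions. As a quick reference, here is a minimal, self-contained sketch of those idioms; the `Money` class and its `cents` field are illustrative only and are not part of the Neon test fixtures.

from __future__ import annotations

from typing import TYPE_CHECKING

from typing_extensions import override

if TYPE_CHECKING:
    # Only needed for annotations, so keep it out of the runtime import graph.
    from typing import Any


class Money:
    def __init__(self, cents: int, tags: dict[str, Any] | None = None):
        self.cents = cents
        self.tags = tags or {}

    # @override asks the type checker to verify that a base class (here,
    # `object`) really defines the method being overridden.
    @override
    def __str__(self) -> str:
        return f"{self.cents / 100:.2f}"

    # Typing `other` as `object` (rather than `Any`) matches the signature of
    # `object.__eq__`; returning NotImplemented lets Python fall back to the
    # other operand's implementation.
    @override
    def __eq__(self, other: object) -> bool:
        if not isinstance(other, Money):
            return NotImplemented
        return self.cents == other.cents

    def __lt__(self, other: object) -> bool:
        if not isinstance(other, Money):
            return NotImplemented
        return self.cents < other.cents

    # A class that defines __eq__ must re-declare __hash__ to stay hashable.
    @override
    def __hash__(self) -> int:
        return hash(self.cents)

With `other: object` plus an isinstance check, mypy and pyright can verify the comparison body; typing it as `Any` would silence those checks instead.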
@@ -6,6 +6,8 @@ from enum import Enum
 from functools import total_ordering
 from typing import TYPE_CHECKING, TypeVar
 
+from typing_extensions import override
+
 if TYPE_CHECKING:
     from typing import Any, Union
 
@@ -31,33 +33,36 @@ class Lsn:
         self.lsn_int = (int(left, 16) << 32) + int(right, 16)
         assert 0 <= self.lsn_int <= 0xFFFFFFFF_FFFFFFFF
 
+    @override
     def __str__(self) -> str:
         """Convert lsn from int to standard hex notation."""
         return f"{(self.lsn_int >> 32):X}/{(self.lsn_int & 0xFFFFFFFF):X}"
 
+    @override
     def __repr__(self) -> str:
         return f'Lsn("{str(self)}")'
 
     def __int__(self) -> int:
         return self.lsn_int
 
-    def __lt__(self, other: Any) -> bool:
+    def __lt__(self, other: object) -> bool:
         if not isinstance(other, Lsn):
             return NotImplemented
         return self.lsn_int < other.lsn_int
 
-    def __gt__(self, other: Any) -> bool:
+    def __gt__(self, other: object) -> bool:
         if not isinstance(other, Lsn):
             raise NotImplementedError
         return self.lsn_int > other.lsn_int
 
-    def __eq__(self, other: Any) -> bool:
+    @override
+    def __eq__(self, other: object) -> bool:
         if not isinstance(other, Lsn):
             return NotImplemented
         return self.lsn_int == other.lsn_int
 
     # Returns the difference between two Lsns, in bytes
-    def __sub__(self, other: Any) -> int:
+    def __sub__(self, other: object) -> int:
         if not isinstance(other, Lsn):
             return NotImplemented
         return self.lsn_int - other.lsn_int
@@ -70,6 +75,7 @@ class Lsn:
         else:
             raise NotImplementedError
 
+    @override
     def __hash__(self) -> int:
         return hash(self.lsn_int)
 
@@ -116,19 +122,22 @@ class Id:
         self.id = bytearray.fromhex(x)
         assert len(self.id) == 16
 
+    @override
     def __str__(self) -> str:
         return self.id.hex()
 
-    def __lt__(self, other) -> bool:
+    def __lt__(self, other: object) -> bool:
         if not isinstance(other, type(self)):
             return NotImplemented
         return self.id < other.id
 
-    def __eq__(self, other) -> bool:
+    @override
+    def __eq__(self, other: object) -> bool:
         if not isinstance(other, type(self)):
             return NotImplemented
         return self.id == other.id
 
+    @override
     def __hash__(self) -> int:
         return hash(str(self.id))
 
@@ -139,25 +148,31 @@ class Id:
 
 
 class TenantId(Id):
+    @override
     def __repr__(self) -> str:
         return f'`TenantId("{self.id.hex()}")'
 
+    @override
     def __str__(self) -> str:
         return self.id.hex()
 
 
 class NodeId(Id):
+    @override
     def __repr__(self) -> str:
         return f'`NodeId("{self.id.hex()}")'
 
+    @override
     def __str__(self) -> str:
         return self.id.hex()
 
 
 class TimelineId(Id):
+    @override
     def __repr__(self) -> str:
         return f'TimelineId("{self.id.hex()}")'
 
+    @override
     def __str__(self) -> str:
         return self.id.hex()
 
@@ -187,7 +202,7 @@ class TenantShardId:
         assert self.shard_number < self.shard_count or self.shard_count == 0
 
     @classmethod
-    def parse(cls: type[TTenantShardId], input) -> TTenantShardId:
+    def parse(cls: type[TTenantShardId], input: str) -> TTenantShardId:
         if len(input) == 32:
             return cls(
                 tenant_id=TenantId(input),
@@ -203,6 +218,7 @@ class TenantShardId:
         else:
             raise ValueError(f"Invalid TenantShardId '{input}'")
 
+    @override
     def __str__(self):
         if self.shard_count > 0:
             return f"{self.tenant_id}-{self.shard_number:02x}{self.shard_count:02x}"
@@ -210,22 +226,25 @@ class TenantShardId:
             # Unsharded case: equivalent of Rust TenantShardId::unsharded(tenant_id)
             return str(self.tenant_id)
 
+    @override
     def __repr__(self):
         return self.__str__()
 
     def _tuple(self) -> tuple[TenantId, int, int]:
         return (self.tenant_id, self.shard_number, self.shard_count)
 
-    def __lt__(self, other) -> bool:
+    def __lt__(self, other: object) -> bool:
         if not isinstance(other, type(self)):
             return NotImplemented
         return self._tuple() < other._tuple()
 
-    def __eq__(self, other) -> bool:
+    @override
+    def __eq__(self, other: object) -> bool:
         if not isinstance(other, type(self)):
             return NotImplemented
         return self._tuple() == other._tuple()
 
+    @override
     def __hash__(self) -> int:
         return hash(self._tuple())
 
@@ -8,9 +8,11 @@ from contextlib import _GeneratorContextManager, contextmanager
|
||||
|
||||
# Type-related stuff
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import pytest
|
||||
from _pytest.fixtures import FixtureRequest
|
||||
from typing_extensions import override
|
||||
|
||||
from fixtures.benchmark_fixture import MetricReport, NeonBenchmarker
|
||||
from fixtures.log_helper import log
|
||||
@@ -24,6 +26,9 @@ from fixtures.neon_fixtures import (
|
||||
)
|
||||
from fixtures.pg_stats import PgStatTable
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Iterator
|
||||
|
||||
|
||||
class PgCompare(ABC):
|
||||
"""Common interface of all postgres implementations, useful for benchmarks.
|
||||
@@ -65,12 +70,12 @@ class PgCompare(ABC):
|
||||
|
||||
@contextmanager
|
||||
@abstractmethod
|
||||
def record_pageserver_writes(self, out_name):
|
||||
def record_pageserver_writes(self, out_name: str):
|
||||
pass
|
||||
|
||||
@contextmanager
|
||||
@abstractmethod
|
||||
def record_duration(self, out_name):
|
||||
def record_duration(self, out_name: str):
|
||||
pass
|
||||
|
||||
@contextmanager
|
||||
@@ -122,28 +127,34 @@ class NeonCompare(PgCompare):
|
||||
self._pg = self.env.endpoints.create_start("main", "main", self.tenant)
|
||||
|
||||
@property
|
||||
@override
|
||||
def pg(self) -> PgProtocol:
|
||||
return self._pg
|
||||
|
||||
@property
|
||||
@override
|
||||
def zenbenchmark(self) -> NeonBenchmarker:
|
||||
return self._zenbenchmark
|
||||
|
||||
@property
|
||||
@override
|
||||
def pg_bin(self) -> PgBin:
|
||||
return self._pg_bin
|
||||
|
||||
@override
|
||||
def flush(self, compact: bool = True, gc: bool = True):
|
||||
wait_for_last_flush_lsn(self.env, self._pg, self.tenant, self.timeline)
|
||||
self.pageserver_http_client.timeline_checkpoint(self.tenant, self.timeline, compact=compact)
|
||||
if gc:
|
||||
self.pageserver_http_client.timeline_gc(self.tenant, self.timeline, 0)
|
||||
|
||||
@override
|
||||
def compact(self):
|
||||
self.pageserver_http_client.timeline_compact(
|
||||
self.tenant, self.timeline, wait_until_uploaded=True
|
||||
)
|
||||
|
||||
@override
|
||||
def report_peak_memory_use(self):
|
||||
self.zenbenchmark.record(
|
||||
"peak_mem",
|
||||
@@ -152,6 +163,7 @@ class NeonCompare(PgCompare):
|
||||
report=MetricReport.LOWER_IS_BETTER,
|
||||
)
|
||||
|
||||
@override
|
||||
def report_size(self):
|
||||
timeline_size = self.zenbenchmark.get_timeline_size(
|
||||
self.env.repo_dir, self.tenant, self.timeline
|
||||
@@ -185,9 +197,11 @@ class NeonCompare(PgCompare):
|
||||
"num_files_uploaded", total_files, "", report=MetricReport.LOWER_IS_BETTER
|
||||
)
|
||||
|
||||
@override
|
||||
def record_pageserver_writes(self, out_name: str) -> _GeneratorContextManager[None]:
|
||||
return self.zenbenchmark.record_pageserver_writes(self.env.pageserver, out_name)
|
||||
|
||||
@override
|
||||
def record_duration(self, out_name: str) -> _GeneratorContextManager[None]:
|
||||
return self.zenbenchmark.record_duration(out_name)
|
||||
|
||||
@@ -211,26 +225,33 @@ class VanillaCompare(PgCompare):
|
||||
self.cur = self.conn.cursor()
|
||||
|
||||
@property
|
||||
@override
|
||||
def pg(self) -> VanillaPostgres:
|
||||
return self._pg
|
||||
|
||||
@property
|
||||
@override
|
||||
def zenbenchmark(self) -> NeonBenchmarker:
|
||||
return self._zenbenchmark
|
||||
|
||||
@property
|
||||
@override
|
||||
def pg_bin(self) -> PgBin:
|
||||
return self._pg.pg_bin
|
||||
|
||||
@override
|
||||
def flush(self, compact: bool = False, gc: bool = False):
|
||||
self.cur.execute("checkpoint")
|
||||
|
||||
@override
|
||||
def compact(self):
|
||||
pass
|
||||
|
||||
@override
|
||||
def report_peak_memory_use(self):
|
||||
pass # TODO find something
|
||||
|
||||
@override
|
||||
def report_size(self):
|
||||
data_size = self.pg.get_subdir_size(Path("base"))
|
||||
self.zenbenchmark.record(
|
||||
@@ -245,6 +266,7 @@ class VanillaCompare(PgCompare):
|
||||
def record_pageserver_writes(self, out_name: str) -> Iterator[None]:
|
||||
yield # Do nothing
|
||||
|
||||
@override
|
||||
def record_duration(self, out_name: str) -> _GeneratorContextManager[None]:
|
||||
return self.zenbenchmark.record_duration(out_name)
|
||||
|
||||
@@ -261,28 +283,35 @@ class RemoteCompare(PgCompare):
|
||||
self.cur = self.conn.cursor()
|
||||
|
||||
@property
|
||||
@override
|
||||
def pg(self) -> PgProtocol:
|
||||
return self._pg
|
||||
|
||||
@property
|
||||
@override
|
||||
def zenbenchmark(self) -> NeonBenchmarker:
|
||||
return self._zenbenchmark
|
||||
|
||||
@property
|
||||
@override
|
||||
def pg_bin(self) -> PgBin:
|
||||
return self._pg.pg_bin
|
||||
|
||||
def flush(self):
|
||||
@override
|
||||
def flush(self, compact: bool = False, gc: bool = False):
|
||||
# TODO: flush the remote pageserver
|
||||
pass
|
||||
|
||||
@override
|
||||
def compact(self):
|
||||
pass
|
||||
|
||||
@override
|
||||
def report_peak_memory_use(self):
|
||||
# TODO: get memory usage from remote pageserver
|
||||
pass
|
||||
|
||||
@override
|
||||
def report_size(self):
|
||||
# TODO: get storage size from remote pageserver
|
||||
pass
|
||||
@@ -291,6 +320,7 @@ class RemoteCompare(PgCompare):
|
||||
def record_pageserver_writes(self, out_name: str) -> Iterator[None]:
|
||||
yield # Do nothing
|
||||
|
||||
@override
|
||||
def record_duration(self, out_name: str) -> _GeneratorContextManager[None]:
|
||||
return self.zenbenchmark.record_duration(out_name)
|
||||
|
||||
|
||||
@@ -1,27 +1,31 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import concurrent.futures
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import pytest
|
||||
from pytest_httpserver import HTTPServer
|
||||
from werkzeug.wrappers.request import Request
|
||||
from werkzeug.wrappers.response import Response
|
||||
|
||||
from fixtures.common_types import TenantId
|
||||
from fixtures.log_helper import log
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
|
||||
class ComputeReconfigure:
|
||||
def __init__(self, server):
|
||||
def __init__(self, server: HTTPServer):
|
||||
self.server = server
|
||||
self.control_plane_compute_hook_api = f"http://{server.host}:{server.port}/notify-attach"
|
||||
self.workloads = {}
|
||||
self.on_notify = None
|
||||
self.workloads: dict[TenantId, Any] = {}
|
||||
self.on_notify: Optional[Callable[[Any], None]] = None
|
||||
|
||||
def register_workload(self, workload):
|
||||
def register_workload(self, workload: Any):
|
||||
self.workloads[workload.tenant_id] = workload
|
||||
|
||||
def register_on_notify(self, fn):
|
||||
def register_on_notify(self, fn: Optional[Callable[[Any], None]]):
|
||||
"""
|
||||
Add some extra work during a notification, like sleeping to slow things down, or
|
||||
logging what was notified.
|
||||
@@ -30,7 +34,7 @@ class ComputeReconfigure:
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def compute_reconfigure_listener(make_httpserver):
|
||||
def compute_reconfigure_listener(make_httpserver: HTTPServer):
|
||||
"""
|
||||
This fixture exposes an HTTP listener for the storage controller to submit
|
||||
compute notifications to us, instead of updating neon_local endpoints itself.
|
||||
@@ -48,7 +52,7 @@ def compute_reconfigure_listener(make_httpserver):
|
||||
# accept a healthy rate of calls into notify-attach.
|
||||
reconfigure_threads = concurrent.futures.ThreadPoolExecutor(max_workers=1)
|
||||
|
||||
def handler(request: Request):
|
||||
def handler(request: Request) -> Response:
|
||||
assert request.json is not None
|
||||
body: dict[str, Any] = request.json
|
||||
log.info(f"notify-attach request: {body}")
|
||||
|
||||
@@ -14,8 +14,10 @@ from allure_pytest.utils import allure_name, allure_suite_labels
|
||||
from fixtures.log_helper import log
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import MutableMapping
|
||||
from typing import Any
|
||||
|
||||
|
||||
"""
|
||||
The plugin reruns flaky tests.
|
||||
It uses `pytest.mark.flaky` provided by `pytest-rerunfailures` plugin and flaky tests detected by `scripts/flaky_tests.py`
|
||||
|
||||
@@ -1,8 +1,15 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import pytest
|
||||
from pytest_httpserver import HTTPServer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Iterator
|
||||
|
||||
from fixtures.port_distributor import PortDistributor
|
||||
|
||||
# TODO: mypy fails with:
|
||||
# Module "fixtures.neon_fixtures" does not explicitly export attribute "PortDistributor" [attr-defined]
|
||||
# from fixtures.neon_fixtures import PortDistributor
|
||||
@@ -17,7 +24,7 @@ def httpserver_ssl_context():
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def make_httpserver(httpserver_listen_address, httpserver_ssl_context):
|
||||
def make_httpserver(httpserver_listen_address, httpserver_ssl_context) -> Iterator[HTTPServer]:
|
||||
host, port = httpserver_listen_address
|
||||
if not host:
|
||||
host = HTTPServer.DEFAULT_LISTEN_HOST
|
||||
@@ -33,13 +40,13 @@ def make_httpserver(httpserver_listen_address, httpserver_ssl_context):
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def httpserver(make_httpserver):
|
||||
def httpserver(make_httpserver: HTTPServer) -> Iterator[HTTPServer]:
|
||||
server = make_httpserver
|
||||
yield server
|
||||
server.clear()
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def httpserver_listen_address(port_distributor) -> tuple[str, int]:
|
||||
def httpserver_listen_address(port_distributor: PortDistributor) -> tuple[str, int]:
|
||||
port = port_distributor.get_port()
|
||||
return ("localhost", port)
|
||||
|
||||
@@ -31,7 +31,7 @@ LOGGING = {
 }
 
 
-def getLogger(name="root") -> logging.Logger:
+def getLogger(name: str = "root") -> logging.Logger:
     """Method to get logger for tests.
 
     Should be used to get correctly initialized logger."""
@@ -22,7 +22,7 @@ class Metrics:
|
||||
|
||||
def query_all(self, name: str, filter: Optional[dict[str, str]] = None) -> list[Sample]:
|
||||
filter = filter or {}
|
||||
res = []
|
||||
res: list[Sample] = []
|
||||
|
||||
for sample in self.metrics[name]:
|
||||
try:
|
||||
@@ -59,7 +59,7 @@ class MetricsGetter:
|
||||
return results[0].value
|
||||
|
||||
def get_metrics_values(
|
||||
self, names: list[str], filter: Optional[dict[str, str]] = None, absence_ok=False
|
||||
self, names: list[str], filter: Optional[dict[str, str]] = None, absence_ok: bool = False
|
||||
) -> dict[str, float]:
|
||||
"""
|
||||
When fetching multiple named metrics, it is more efficient to use this
|
||||
|
||||
@@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, cast
|
||||
import requests
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Any, Literal, Optional, Union
|
||||
from typing import Any, Literal, Optional
|
||||
|
||||
from fixtures.pg_version import PgVersion
|
||||
|
||||
@@ -25,9 +25,7 @@ class NeonAPI:
|
||||
self.__neon_api_key = neon_api_key
|
||||
self.__neon_api_base_url = neon_api_base_url.strip("/")
|
||||
|
||||
def __request(
|
||||
self, method: Union[str, bytes], endpoint: str, **kwargs: Any
|
||||
) -> requests.Response:
|
||||
def __request(self, method: str | bytes, endpoint: str, **kwargs: Any) -> requests.Response:
|
||||
if "headers" not in kwargs:
|
||||
kwargs["headers"] = {}
|
||||
kwargs["headers"]["Authorization"] = f"Bearer {self.__neon_api_key}"
|
||||
@@ -187,8 +185,8 @@ class NeonAPI:
|
||||
def get_connection_uri(
|
||||
self,
|
||||
project_id: str,
|
||||
branch_id: Optional[str] = None,
|
||||
endpoint_id: Optional[str] = None,
|
||||
branch_id: str | None = None,
|
||||
endpoint_id: str | None = None,
|
||||
database_name: str = "neondb",
|
||||
role_name: str = "neondb_owner",
|
||||
pooled: bool = True,
|
||||
@@ -264,7 +262,7 @@ class NeonAPI:
|
||||
|
||||
|
||||
class NeonApiEndpoint:
|
||||
def __init__(self, neon_api: NeonAPI, pg_version: PgVersion, project_id: Optional[str]):
|
||||
def __init__(self, neon_api: NeonAPI, pg_version: PgVersion, project_id: str | None):
|
||||
self.neon_api = neon_api
|
||||
if project_id is None:
|
||||
project = neon_api.create_project(pg_version)
|
||||
|
||||
@@ -3657,7 +3657,7 @@ class Endpoint(PgProtocol, LogUtils):
|
||||
config_lines: Optional[list[str]] = None,
|
||||
remote_ext_config: Optional[str] = None,
|
||||
pageserver_id: Optional[int] = None,
|
||||
allow_multiple=False,
|
||||
allow_multiple: bool = False,
|
||||
basebackup_request_tries: Optional[int] = None,
|
||||
) -> Endpoint:
|
||||
"""
|
||||
@@ -3998,7 +3998,7 @@ class Safekeeper(LogUtils):
|
||||
def timeline_dir(self, tenant_id, timeline_id) -> Path:
|
||||
return self.data_dir / str(tenant_id) / str(timeline_id)
|
||||
|
||||
# List partial uploaded segments of this safekeeper. Works only for
|
||||
# list partial uploaded segments of this safekeeper. Works only for
|
||||
# RemoteStorageKind.LOCAL_FS.
|
||||
def list_uploaded_segments(self, tenant_id: TenantId, timeline_id: TimelineId):
|
||||
tline_path = (
|
||||
@@ -4293,7 +4293,7 @@ def pytest_addoption(parser: Parser):
|
||||
)
|
||||
|
||||
|
||||
SMALL_DB_FILE_NAME_REGEX: re.Pattern = re.compile( # type: ignore[type-arg]
|
||||
SMALL_DB_FILE_NAME_REGEX: re.Pattern[str] = re.compile(
|
||||
r"config-v1|heatmap-v1|metadata|.+\.(?:toml|pid|json|sql|conf)"
|
||||
)
|
||||
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterator
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import psutil
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Iterator
|
||||
|
||||
|
||||
def iter_mounts_beneath(topdir: Path) -> Iterator[Path]:
|
||||
"""
|
||||
|
||||
@@ -9,7 +9,12 @@ import toml
|
||||
from _pytest.python import Metafunc
|
||||
|
||||
from fixtures.pg_version import PgVersion
|
||||
from fixtures.utils import AuxFileStore
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Any, Optional
|
||||
|
||||
from fixtures.utils import AuxFileStore
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Any, Optional
|
||||
|
||||
@@ -2,9 +2,14 @@ from __future__ import annotations
|
||||
|
||||
import enum
|
||||
import os
|
||||
from typing import Optional
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import pytest
|
||||
from typing_extensions import override
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Optional
|
||||
|
||||
|
||||
"""
|
||||
This fixture is used to determine which version of Postgres to use for tests.
|
||||
@@ -24,10 +29,12 @@ class PgVersion(str, enum.Enum):
|
||||
NOT_SET = "<-POSTRGRES VERSION IS NOT SET->"
|
||||
|
||||
# Make it less confusing in logs
|
||||
@override
|
||||
def __repr__(self) -> str:
|
||||
return f"'{self.value}'"
|
||||
|
||||
# Make this explicit for Python 3.11 compatibility, which changes the behavior of enums
|
||||
@override
|
||||
def __str__(self) -> str:
|
||||
return self.value
|
||||
|
||||
@@ -38,7 +45,8 @@ class PgVersion(str, enum.Enum):
|
||||
return f"v{self.value}"
|
||||
|
||||
@classmethod
|
||||
def _missing_(cls, value) -> Optional[PgVersion]:
|
||||
@override
|
||||
def _missing_(cls, value: object) -> Optional[PgVersion]:
|
||||
known_values = {v.value for _, v in cls.__members__.items()}
|
||||
|
||||
# Allow passing version as a string with "v" prefix (e.g. "v14")
|
||||
|
||||
@@ -59,10 +59,7 @@ class PortDistributor:
|
||||
if isinstance(value, int):
|
||||
return self._replace_port_int(value)
|
||||
|
||||
if isinstance(value, str):
|
||||
return self._replace_port_str(value)
|
||||
|
||||
raise TypeError(f"unsupported type {type(value)} of {value=}")
|
||||
return self._replace_port_str(value)
|
||||
|
||||
def _replace_port_int(self, value: int) -> int:
|
||||
known_port = self.port_map.get(value)
|
||||
@@ -75,7 +72,7 @@ class PortDistributor:
|
||||
# Use regex to find port in a string
|
||||
# urllib.parse.urlparse produces inconvenient results for cases without scheme like "localhost:5432"
|
||||
# See https://bugs.python.org/issue27657
|
||||
ports = re.findall(r":(\d+)(?:/|$)", value)
|
||||
ports: list[str] = re.findall(r":(\d+)(?:/|$)", value)
|
||||
assert len(ports) == 1, f"can't find port in {value}"
|
||||
port_int = int(ports[0])
|
||||
|
||||
|
||||
@@ -13,6 +13,7 @@ import boto3
|
||||
import toml
|
||||
from moto.server import ThreadedMotoServer
|
||||
from mypy_boto3_s3 import S3Client
|
||||
from typing_extensions import override
|
||||
|
||||
from fixtures.common_types import TenantId, TenantShardId, TimelineId
|
||||
from fixtures.log_helper import log
|
||||
@@ -36,6 +37,7 @@ class RemoteStorageUser(str, enum.Enum):
|
||||
EXTENSIONS = "ext"
|
||||
SAFEKEEPER = "safekeeper"
|
||||
|
||||
@override
|
||||
def __str__(self) -> str:
|
||||
return self.value
|
||||
|
||||
@@ -81,11 +83,13 @@ class LocalFsStorage:
|
||||
def timeline_path(self, tenant_id: TenantId, timeline_id: TimelineId) -> Path:
|
||||
return self.tenant_path(tenant_id) / "timelines" / str(timeline_id)
|
||||
|
||||
def timeline_latest_generation(self, tenant_id, timeline_id):
|
||||
def timeline_latest_generation(
|
||||
self, tenant_id: TenantId, timeline_id: TimelineId
|
||||
) -> Optional[int]:
|
||||
timeline_files = os.listdir(self.timeline_path(tenant_id, timeline_id))
|
||||
index_parts = [f for f in timeline_files if f.startswith("index_part")]
|
||||
|
||||
def parse_gen(filename):
|
||||
def parse_gen(filename: str) -> Optional[int]:
|
||||
log.info(f"parsing index_part '{filename}'")
|
||||
parts = filename.split("-")
|
||||
if len(parts) == 2:
|
||||
@@ -93,7 +97,7 @@ class LocalFsStorage:
|
||||
else:
|
||||
return None
|
||||
|
||||
generations = sorted([parse_gen(f) for f in index_parts])
|
||||
generations = sorted([parse_gen(f) for f in index_parts]) # type: ignore
|
||||
if len(generations) == 0:
|
||||
raise RuntimeError(f"No index_part found for {tenant_id}/{timeline_id}")
|
||||
return generations[-1]
|
||||
@@ -122,14 +126,14 @@ class LocalFsStorage:
|
||||
filename = f"{local_name}-{generation:08x}"
|
||||
return self.timeline_path(tenant_id, timeline_id) / filename
|
||||
|
||||
def index_content(self, tenant_id: TenantId, timeline_id: TimelineId):
|
||||
def index_content(self, tenant_id: TenantId, timeline_id: TimelineId) -> Any:
|
||||
with self.index_path(tenant_id, timeline_id).open("r") as f:
|
||||
return json.load(f)
|
||||
|
||||
def heatmap_path(self, tenant_id: TenantId) -> Path:
|
||||
return self.tenant_path(tenant_id) / TENANT_HEATMAP_FILE_NAME
|
||||
|
||||
def heatmap_content(self, tenant_id):
|
||||
def heatmap_content(self, tenant_id: TenantId) -> Any:
|
||||
with self.heatmap_path(tenant_id).open("r") as f:
|
||||
return json.load(f)
|
||||
|
||||
@@ -297,7 +301,7 @@ class S3Storage:
|
||||
def heatmap_key(self, tenant_id: TenantId) -> str:
|
||||
return f"{self.tenant_path(tenant_id)}/{TENANT_HEATMAP_FILE_NAME}"
|
||||
|
||||
def heatmap_content(self, tenant_id: TenantId):
|
||||
def heatmap_content(self, tenant_id: TenantId) -> Any:
|
||||
r = self.client.get_object(Bucket=self.bucket_name, Key=self.heatmap_key(tenant_id))
|
||||
return json.loads(r["Body"].read().decode("utf-8"))
|
||||
|
||||
@@ -317,7 +321,7 @@ class RemoteStorageKind(str, enum.Enum):
|
||||
def configure(
|
||||
self,
|
||||
repo_dir: Path,
|
||||
mock_s3_server,
|
||||
mock_s3_server: MockS3Server,
|
||||
run_id: str,
|
||||
test_name: str,
|
||||
user: RemoteStorageUser,
|
||||
@@ -451,15 +455,9 @@ def default_remote_storage() -> RemoteStorageKind:
|
||||
|
||||
|
||||
def remote_storage_to_toml_dict(remote_storage: RemoteStorage) -> dict[str, Any]:
|
||||
if not isinstance(remote_storage, (LocalFsStorage, S3Storage)):
|
||||
raise Exception("invalid remote storage type")
|
||||
|
||||
return remote_storage.to_toml_dict()
|
||||
|
||||
|
||||
# serialize as toml inline table
|
||||
def remote_storage_to_toml_inline_table(remote_storage: RemoteStorage) -> str:
|
||||
if not isinstance(remote_storage, (LocalFsStorage, S3Storage)):
|
||||
raise Exception("invalid remote storage type")
|
||||
|
||||
return remote_storage.to_toml_inline_table()
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from typing import Any, Optional
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
@@ -12,6 +12,9 @@ from werkzeug.wrappers.response import Response
|
||||
|
||||
from fixtures.log_helper import log
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Any, Optional
|
||||
|
||||
|
||||
class StorageControllerProxy:
|
||||
def __init__(self, server: HTTPServer):
|
||||
@@ -34,7 +37,7 @@ def proxy_request(method: str, url: str, **kwargs) -> requests.Response:
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def storage_controller_proxy(make_httpserver):
|
||||
def storage_controller_proxy(make_httpserver: HTTPServer):
|
||||
"""
|
||||
Proxies requests into the storage controller to the currently
|
||||
selected storage controller instance via `StorageControllerProxy.route_to`.
|
||||
@@ -48,7 +51,7 @@ def storage_controller_proxy(make_httpserver):
|
||||
|
||||
log.info(f"Storage controller proxy listening on {self.listen}")
|
||||
|
||||
def handler(request: Request):
|
||||
def handler(request: Request) -> Response:
|
||||
if self.route_to is None:
|
||||
log.info(f"Storage controller proxy has no routing configured for {request.url}")
|
||||
return Response("Routing not configured", status=503)
|
||||
|
||||
@@ -18,6 +18,7 @@ from urllib.parse import urlencode
|
||||
import allure
|
||||
import zstandard
|
||||
from psycopg2.extensions import cursor
|
||||
from typing_extensions import override
|
||||
|
||||
from fixtures.log_helper import log
|
||||
from fixtures.pageserver.common_types import (
|
||||
@@ -26,14 +27,14 @@ from fixtures.pageserver.common_types import (
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import (
|
||||
IO,
|
||||
Optional,
|
||||
Union,
|
||||
)
|
||||
from collections.abc import Iterable
|
||||
from typing import IO, Optional
|
||||
|
||||
from fixtures.common_types import TimelineId
|
||||
from fixtures.neon_fixtures import PgBin
|
||||
from fixtures.common_types import TimelineId
|
||||
|
||||
WaitUntilRet = TypeVar("WaitUntilRet")
|
||||
|
||||
|
||||
Fn = TypeVar("Fn", bound=Callable[..., Any])
|
||||
|
||||
@@ -42,12 +43,12 @@ def subprocess_capture(
|
||||
capture_dir: Path,
|
||||
cmd: list[str],
|
||||
*,
|
||||
check=False,
|
||||
echo_stderr=False,
|
||||
echo_stdout=False,
|
||||
capture_stdout=False,
|
||||
timeout=None,
|
||||
with_command_header=True,
|
||||
check: bool = False,
|
||||
echo_stderr: bool = False,
|
||||
echo_stdout: bool = False,
|
||||
capture_stdout: bool = False,
|
||||
timeout: Optional[float] = None,
|
||||
with_command_header: bool = True,
|
||||
**popen_kwargs: Any,
|
||||
) -> tuple[str, Optional[str], int]:
|
||||
"""Run a process and bifurcate its output to files and the `log` logger
|
||||
@@ -84,6 +85,7 @@ def subprocess_capture(
|
||||
self.capture = capture
|
||||
self.captured = ""
|
||||
|
||||
@override
|
||||
def run(self):
|
||||
first = with_command_header
|
||||
for line in self.in_file:
|
||||
@@ -165,10 +167,10 @@ def global_counter() -> int:
|
||||
def print_gc_result(row: dict[str, Any]):
|
||||
log.info("GC duration {elapsed} ms".format_map(row))
|
||||
log.info(
|
||||
" total: {layers_total}, needed_by_cutoff {layers_needed_by_cutoff}, needed_by_pitr {layers_needed_by_pitr}"
|
||||
" needed_by_branches: {layers_needed_by_branches}, not_updated: {layers_not_updated}, removed: {layers_removed}".format_map(
|
||||
row
|
||||
)
|
||||
(
|
||||
" total: {layers_total}, needed_by_cutoff {layers_needed_by_cutoff}, needed_by_pitr {layers_needed_by_pitr}"
|
||||
" needed_by_branches: {layers_needed_by_branches}, not_updated: {layers_not_updated}, removed: {layers_removed}"
|
||||
).format_map(row)
|
||||
)
|
||||
|
||||
|
||||
@@ -226,7 +228,7 @@ def get_scale_for_db(size_mb: int) -> int:
|
||||
return round(0.06689 * size_mb - 0.5)
|
||||
|
||||
|
||||
ATTACHMENT_NAME_REGEX: re.Pattern = re.compile( # type: ignore[type-arg]
|
||||
ATTACHMENT_NAME_REGEX: re.Pattern[str] = re.compile(
|
||||
r"regression\.(diffs|out)|.+\.(?:log|stderr|stdout|filediff|metrics|html|walredo)"
|
||||
)
|
||||
|
||||
@@ -289,7 +291,7 @@ LOGS_STAGING_DATASOURCE_ID = "xHHYY0dVz"
|
||||
|
||||
def allure_add_grafana_links(host: str, timeline_id: TimelineId, start_ms: int, end_ms: int):
|
||||
"""Add links to server logs in Grafana to Allure report"""
|
||||
links = {}
|
||||
links: dict[str, str] = {}
|
||||
# We expect host to be in format like ep-divine-night-159320.us-east-2.aws.neon.build
|
||||
endpoint_id, region_id, _ = host.split(".", 2)
|
||||
|
||||
@@ -341,7 +343,7 @@ def allure_add_grafana_links(host: str, timeline_id: TimelineId, start_ms: int,
|
||||
|
||||
|
||||
def start_in_background(
|
||||
command: list[str], cwd: Path, log_file_name: str, is_started: Fn
|
||||
command: list[str], cwd: Path, log_file_name: str, is_started: Callable[[], WaitUntilRet]
|
||||
) -> subprocess.Popen[bytes]:
|
||||
"""Starts a process, creates the logfile and redirects stderr and stdout there. Runs the start checks before the process is started, or errors."""
|
||||
|
||||
@@ -376,14 +378,11 @@ def start_in_background(
|
||||
return spawned_process
|
||||
|
||||
|
||||
WaitUntilRet = TypeVar("WaitUntilRet")
|
||||
|
||||
|
||||
def wait_until(
|
||||
number_of_iterations: int,
|
||||
interval: float,
|
||||
func: Callable[[], WaitUntilRet],
|
||||
show_intermediate_error=False,
|
||||
show_intermediate_error: bool = False,
|
||||
) -> WaitUntilRet:
|
||||
"""
|
||||
Wait until 'func' returns successfully, without exception. Returns the
|
||||
@@ -464,7 +463,7 @@ def humantime_to_ms(humantime: str) -> float:
|
||||
def scan_log_for_errors(input: Iterable[str], allowed_errors: list[str]) -> list[tuple[int, str]]:
|
||||
# FIXME: this duplicates test_runner/fixtures/pageserver/allowed_errors.py
|
||||
error_or_warn = re.compile(r"\s(ERROR|WARN)")
|
||||
errors = []
|
||||
errors: list[tuple[int, str]] = []
|
||||
for lineno, line in enumerate(input, start=1):
|
||||
if len(line) == 0:
|
||||
continue
|
||||
@@ -484,7 +483,7 @@ def scan_log_for_errors(input: Iterable[str], allowed_errors: list[str]) -> list
|
||||
return errors
|
||||
|
||||
|
||||
def assert_no_errors(log_file, service, allowed_errors):
|
||||
def assert_no_errors(log_file: Path, service: str, allowed_errors: list[str]):
|
||||
if not log_file.exists():
|
||||
log.warning(f"Skipping {service} log check: {log_file} does not exist")
|
||||
return
|
||||
@@ -504,9 +503,11 @@ class AuxFileStore(str, enum.Enum):
|
||||
V2 = "v2"
|
||||
CrossValidation = "cross-validation"
|
||||
|
||||
@override
|
||||
def __repr__(self) -> str:
|
||||
return f"'aux-{self.value}'"
|
||||
|
||||
@override
|
||||
def __str__(self) -> str:
|
||||
return f"'aux-{self.value}'"
|
||||
|
||||
@@ -525,7 +526,7 @@ def assert_pageserver_backups_equal(left: Path, right: Path, skip_files: set[str
|
||||
"""
|
||||
started_at = time.time()
|
||||
|
||||
def hash_extracted(reader: Union[IO[bytes], None]) -> bytes:
|
||||
def hash_extracted(reader: Optional[IO[bytes]]) -> bytes:
|
||||
assert reader is not None
|
||||
digest = sha256(usedforsecurity=False)
|
||||
while True:
|
||||
@@ -550,7 +551,7 @@ def assert_pageserver_backups_equal(left: Path, right: Path, skip_files: set[str
|
||||
right_list
|
||||
), f"unexpected number of files on tar files, {len(left_list)} != {len(right_list)}"
|
||||
|
||||
mismatching = set()
|
||||
mismatching: set[str] = set()
|
||||
|
||||
for left_tuple, right_tuple in zip(left_list, right_list):
|
||||
left_path, left_hash = left_tuple
|
||||
@@ -575,6 +576,7 @@ class PropagatingThread(threading.Thread):
|
||||
Simple Thread wrapper with join() propagating the possible exception in the thread.
|
||||
"""
|
||||
|
||||
@override
|
||||
def run(self):
|
||||
self.exc = None
|
||||
try:
|
||||
@@ -582,7 +584,8 @@ class PropagatingThread(threading.Thread):
|
||||
except BaseException as e:
|
||||
self.exc = e
|
||||
|
||||
def join(self, timeout=None):
|
||||
@override
|
||||
def join(self, timeout: Optional[float] = None) -> Any:
|
||||
super().join(timeout)
|
||||
if self.exc:
|
||||
raise self.exc
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import threading
|
||||
from typing import Any, Optional
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from fixtures.common_types import TenantId, TimelineId
|
||||
from fixtures.log_helper import log
|
||||
@@ -14,6 +14,9 @@ from fixtures.neon_fixtures import (
|
||||
)
|
||||
from fixtures.pageserver.utils import wait_for_last_record_lsn
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Any, Optional
|
||||
|
||||
# neon_local doesn't handle creating/modifying endpoints concurrently, so we use a mutex
|
||||
# to ensure we don't do that: this enables running lots of Workloads in parallel safely.
|
||||
ENDPOINT_LOCK = threading.Lock()
|
||||
@@ -100,7 +103,7 @@ class Workload:
|
||||
self.env, endpoint, self.tenant_id, self.timeline_id, pageserver_id=pageserver_id
|
||||
)
|
||||
|
||||
def write_rows(self, n, pageserver_id: Optional[int] = None, upload: bool = True):
|
||||
def write_rows(self, n: int, pageserver_id: Optional[int] = None, upload: bool = True):
|
||||
endpoint = self.endpoint(pageserver_id)
|
||||
start = self.expect_rows
|
||||
end = start + n - 1
|
||||
@@ -121,7 +124,9 @@ class Workload:
|
||||
else:
|
||||
return False
|
||||
|
||||
def churn_rows(self, n, pageserver_id: Optional[int] = None, upload=True, ingest=True):
|
||||
def churn_rows(
|
||||
self, n: int, pageserver_id: Optional[int] = None, upload: bool = True, ingest: bool = True
|
||||
):
|
||||
assert self.expect_rows >= n
|
||||
|
||||
max_iters = 10
|
||||
|
||||
@@ -4,7 +4,7 @@ import enum
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from typing import Optional
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import pytest
|
||||
from fixtures.log_helper import log
|
||||
@@ -16,6 +16,10 @@ from fixtures.pageserver.http import PageserverApiException
|
||||
from fixtures.utils import wait_until
|
||||
from fixtures.workload import Workload
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Optional
|
||||
|
||||
|
||||
AGGRESIVE_COMPACTION_TENANT_CONF = {
|
||||
# Disable gc and compaction. The test runs compaction manually.
|
||||
"gc_period": "0s",
|
||||
|
||||
@@ -15,7 +15,7 @@ import enum
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
from typing import Optional
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import pytest
|
||||
from fixtures.common_types import TenantId, TimelineId
|
||||
@@ -40,6 +40,10 @@ from fixtures.remote_storage import (
|
||||
from fixtures.utils import wait_until
|
||||
from fixtures.workload import Workload
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Optional
|
||||
|
||||
|
||||
# A tenant configuration that is convenient for generating uploads and deletions
|
||||
# without a large amount of postgres traffic.
|
||||
TENANT_CONF = {
|
||||
|
||||
@@ -23,6 +23,7 @@ from fixtures.remote_storage import s3_storage
|
||||
from fixtures.utils import wait_until
|
||||
from fixtures.workload import Workload
|
||||
from pytest_httpserver import HTTPServer
|
||||
from typing_extensions import override
|
||||
from werkzeug.wrappers.request import Request
|
||||
from werkzeug.wrappers.response import Response
|
||||
|
||||
@@ -954,6 +955,7 @@ class PageserverFailpoint(Failure):
|
||||
self.pageserver_id = pageserver_id
|
||||
self._mitigate = mitigate
|
||||
|
||||
@override
|
||||
def apply(self, env: NeonEnv):
|
||||
pageserver = env.get_pageserver(self.pageserver_id)
|
||||
pageserver.allowed_errors.extend(
|
||||
@@ -961,19 +963,23 @@ class PageserverFailpoint(Failure):
|
||||
)
|
||||
pageserver.http_client().configure_failpoints((self.failpoint, "return(1)"))
|
||||
|
||||
@override
|
||||
def clear(self, env: NeonEnv):
|
||||
pageserver = env.get_pageserver(self.pageserver_id)
|
||||
pageserver.http_client().configure_failpoints((self.failpoint, "off"))
|
||||
if self._mitigate:
|
||||
env.storage_controller.node_configure(self.pageserver_id, {"availability": "Active"})
|
||||
|
||||
@override
|
||||
def expect_available(self):
|
||||
return True
|
||||
|
||||
@override
|
||||
def can_mitigate(self):
|
||||
return self._mitigate
|
||||
|
||||
def mitigate(self, env):
|
||||
@override
|
||||
def mitigate(self, env: NeonEnv):
|
||||
env.storage_controller.node_configure(self.pageserver_id, {"availability": "Offline"})
|
||||
|
||||
|
||||
@@ -983,9 +989,11 @@ class StorageControllerFailpoint(Failure):
|
||||
self.pageserver_id = None
|
||||
self.action = action
|
||||
|
||||
@override
|
||||
def apply(self, env: NeonEnv):
|
||||
env.storage_controller.configure_failpoints((self.failpoint, self.action))
|
||||
|
||||
@override
|
||||
def clear(self, env: NeonEnv):
|
||||
if "panic" in self.action:
|
||||
log.info("Restarting storage controller after panic")
|
||||
@@ -994,16 +1002,19 @@ class StorageControllerFailpoint(Failure):
|
||||
else:
|
||||
env.storage_controller.configure_failpoints((self.failpoint, "off"))
|
||||
|
||||
@override
|
||||
def expect_available(self):
|
||||
# Controller panics _do_ leave pageservers available, but our test code relies
|
||||
# on using the locate API to update configurations in Workload, so we must skip
|
||||
# these actions when the controller has been panicked.
|
||||
return "panic" not in self.action
|
||||
|
||||
@override
|
||||
def can_mitigate(self):
|
||||
return False
|
||||
|
||||
def fails_forward(self, env):
|
||||
@override
|
||||
def fails_forward(self, env: NeonEnv):
|
||||
# Edge case: the very last failpoint that simulates a DB connection error, where
|
||||
# the abort path will fail-forward and result in a complete split.
|
||||
fail_forward = self.failpoint == "shard-split-post-complete"
|
||||
@@ -1017,6 +1028,7 @@ class StorageControllerFailpoint(Failure):
|
||||
|
||||
return fail_forward
|
||||
|
||||
@override
|
||||
def expect_exception(self):
|
||||
if "panic" in self.action:
|
||||
return requests.exceptions.ConnectionError
|
||||
@@ -1029,18 +1041,22 @@ class NodeKill(Failure):
|
||||
self.pageserver_id = pageserver_id
|
||||
self._mitigate = mitigate
|
||||
|
||||
@override
|
||||
def apply(self, env: NeonEnv):
|
||||
pageserver = env.get_pageserver(self.pageserver_id)
|
||||
pageserver.stop(immediate=True)
|
||||
|
||||
@override
|
||||
def clear(self, env: NeonEnv):
|
||||
pageserver = env.get_pageserver(self.pageserver_id)
|
||||
pageserver.start()
|
||||
|
||||
@override
|
||||
def expect_available(self):
|
||||
return False
|
||||
|
||||
def mitigate(self, env):
|
||||
@override
|
||||
def mitigate(self, env: NeonEnv):
|
||||
env.storage_controller.node_configure(self.pageserver_id, {"availability": "Offline"})
|
||||
|
||||
|
||||
@@ -1059,21 +1075,26 @@ class CompositeFailure(Failure):
|
||||
self.pageserver_id = f.pageserver_id
|
||||
break
|
||||
|
||||
@override
|
||||
def apply(self, env: NeonEnv):
|
||||
for f in self.failures:
|
||||
f.apply(env)
|
||||
|
||||
def clear(self, env):
|
||||
@override
|
||||
def clear(self, env: NeonEnv):
|
||||
for f in self.failures:
|
||||
f.clear(env)
|
||||
|
||||
@override
|
||||
def expect_available(self):
|
||||
return all(f.expect_available() for f in self.failures)
|
||||
|
||||
def mitigate(self, env):
|
||||
@override
|
||||
def mitigate(self, env: NeonEnv):
|
||||
for f in self.failures:
|
||||
f.mitigate(env)
|
||||
|
||||
@override
|
||||
def expect_exception(self):
|
||||
expect = set(f.expect_exception() for f in self.failures)
|
||||
|
||||
@@ -1211,7 +1232,7 @@ def test_sharding_split_failures(
|
||||
|
||||
assert attached_count == initial_shard_count
|
||||
|
||||
def assert_split_done(exclude_ps_id=None) -> None:
|
||||
def assert_split_done(exclude_ps_id: Optional[int] = None) -> None:
|
||||
secondary_count = 0
|
||||
attached_count = 0
|
||||
for ps in env.pageservers:
|
||||
|
||||
@@ -1038,7 +1038,7 @@ def test_storage_controller_tenant_deletion(
|
||||
)
|
||||
|
||||
# Break the compute hook: we are checking that deletion does not depend on the compute hook being available
|
||||
def break_hook():
|
||||
def break_hook(_body: Any):
|
||||
raise RuntimeError("Unexpected call to compute hook")
|
||||
|
||||
compute_reconfigure_listener.register_on_notify(break_hook)
|
||||
|
||||
@@ -6,7 +6,7 @@ import shutil
|
||||
import threading
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Optional
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import pytest
|
||||
from fixtures.common_types import TenantId, TenantShardId, TimelineId
|
||||
@@ -20,6 +20,9 @@ from fixtures.remote_storage import S3Storage, s3_storage
|
||||
from fixtures.utils import wait_until
|
||||
from fixtures.workload import Workload
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Optional
|
||||
|
||||
|
||||
@pytest.mark.parametrize("shard_count", [None, 4])
|
||||
def test_scrubber_tenant_snapshot(neon_env_builder: NeonEnvBuilder, shard_count: Optional[int]):