Stricter mypy linters for test_runner/fixtures/*

This commit is contained in:
Vadim Kharitonov
2022-11-08 10:51:34 +01:00
committed by Vadim Kharitonov
parent c4f9f1dc6d
commit f720dd735e
11 changed files with 255 additions and 201 deletions

View File

@@ -11,39 +11,37 @@ from datetime import datetime
from pathlib import Path
# Type-related stuff
from typing import Iterator, Optional
from typing import Callable, ClassVar, Iterator, Optional
import pytest
from _pytest.config import Config
from _pytest.config.argparsing import Parser
from _pytest.terminal import TerminalReporter
from fixtures.neon_fixtures import NeonPageserver
from fixtures.types import TenantId, TimelineId
"""
This file contains fixtures for micro-benchmarks.
To use, declare the 'zenbenchmark' fixture in the test function. Run the
bencmark, and then record the result by calling zenbenchmark.record. For example:
To use, declare the `zenbenchmark` fixture in the test function. Run the
bencmark, and then record the result by calling `zenbenchmark.record`. For example:
import timeit
from fixtures.neon_fixtures import NeonEnv
def test_mybench(neon_simple_env: env, zenbenchmark):
# Initialize the test
...
# Run the test, timing how long it takes
with zenbenchmark.record_duration('test_query'):
cur.execute('SELECT test_query(...)')
# Record another measurement
zenbenchmark.record('speed_of_light', 300000, 'km/s')
>>> import timeit
>>> from fixtures.neon_fixtures import NeonEnv
>>> def test_mybench(neon_simple_env: NeonEnv, zenbenchmark):
... # Initialize the test
... ...
... # Run the test, timing how long it takes
... with zenbenchmark.record_duration('test_query'):
... cur.execute('SELECT test_query(...)')
... # Record another measurement
... zenbenchmark.record('speed_of_light', 300000, 'km/s')
There's no need to import this file to use it. It should be declared as a plugin
inside conftest.py, and that makes it available to all tests.
inside `conftest.py`, and that makes it available to all tests.
You can measure multiple things in one test, and record each one with a separate
call to zenbenchmark. For example, you could time the bulk loading that happens
call to `zenbenchmark`. For example, you could time the bulk loading that happens
in the test initialization, or measure disk usage after the test query.
"""
@@ -117,7 +115,7 @@ class PgBenchRunResult:
# tps = 309.281539 (without initial connection time)
if line.startswith("tps = ") and (
"(excluding connections establishing)" in line
or "(without initial connection time)"
or "(without initial connection time)" in line
):
tps = float(line.split()[2])
@@ -137,6 +135,17 @@ class PgBenchRunResult:
@dataclasses.dataclass
class PgBenchInitResult:
REGEX: ClassVar[re.Pattern] = re.compile( # type: ignore[type-arg]
r"done in (\d+\.\d+) s "
r"\("
r"(?:drop tables (\d+\.\d+) s)?(?:, )?"
r"(?:create tables (\d+\.\d+) s)?(?:, )?"
r"(?:client-side generate (\d+\.\d+) s)?(?:, )?"
r"(?:vacuum (\d+\.\d+) s)?(?:, )?"
r"(?:primary keys (\d+\.\d+) s)?(?:, )?"
r"\)\."
)
total: float
drop_tables: Optional[float]
create_tables: Optional[float]
@@ -160,18 +169,7 @@ class PgBenchInitResult:
last_line = stderr.splitlines()[-1]
regex = re.compile(
r"done in (\d+\.\d+) s "
r"\("
r"(?:drop tables (\d+\.\d+) s)?(?:, )?"
r"(?:create tables (\d+\.\d+) s)?(?:, )?"
r"(?:client-side generate (\d+\.\d+) s)?(?:, )?"
r"(?:vacuum (\d+\.\d+) s)?(?:, )?"
r"(?:primary keys (\d+\.\d+) s)?(?:, )?"
r"\)\."
)
if (m := regex.match(last_line)) is not None:
if (m := cls.REGEX.match(last_line)) is not None:
total, drop_tables, create_tables, client_side_generate, vacuum, primary_keys = [
float(v) for v in m.groups() if v is not None
]
@@ -208,7 +206,7 @@ class NeonBenchmarker:
function by the zenbenchmark fixture
"""
def __init__(self, property_recorder):
def __init__(self, property_recorder: Callable[[str, object], None]):
# property recorder here is a pytest fixture provided by junitxml module
# https://docs.pytest.org/en/6.2.x/reference.html#pytest.junitxml.record_property
self.property_recorder = property_recorder
@@ -236,7 +234,7 @@ class NeonBenchmarker:
)
@contextmanager
def record_duration(self, metric_name: str):
def record_duration(self, metric_name: str) -> Iterator[None]:
"""
Record a duration. Usage:
@@ -337,21 +335,21 @@ class NeonBenchmarker:
f"{prefix}.{metric}", value, unit="s", report=MetricReport.LOWER_IS_BETTER
)
def get_io_writes(self, pageserver) -> int:
def get_io_writes(self, pageserver: NeonPageserver) -> int:
"""
Fetch the "cumulative # of bytes written" metric from the pageserver
"""
metric_name = r'libmetrics_disk_io_bytes_total{io_operation="write"}'
return self.get_int_counter_value(pageserver, metric_name)
def get_peak_mem(self, pageserver) -> int:
def get_peak_mem(self, pageserver: NeonPageserver) -> int:
"""
Fetch the "maxrss" metric from the pageserver
"""
metric_name = r"libmetrics_maxrss_kb"
return self.get_int_counter_value(pageserver, metric_name)
def get_int_counter_value(self, pageserver, metric_name) -> int:
def get_int_counter_value(self, pageserver: NeonPageserver, metric_name: str) -> int:
"""Fetch the value of given int counter from pageserver metrics."""
# TODO: If we start to collect more of the prometheus metrics in the
# performance test suite like this, we should refactor this to load and
@@ -365,7 +363,9 @@ class NeonBenchmarker:
assert matches, f"metric {metric_name} not found"
return int(round(float(matches.group(1))))
def get_timeline_size(self, repo_dir: Path, tenant_id: TenantId, timeline_id: TimelineId):
def get_timeline_size(
self, repo_dir: Path, tenant_id: TenantId, timeline_id: TimelineId
) -> int:
"""
Calculate the on-disk size of a timeline
"""
@@ -379,7 +379,9 @@ class NeonBenchmarker:
return totalbytes
@contextmanager
def record_pageserver_writes(self, pageserver, metric_name):
def record_pageserver_writes(
self, pageserver: NeonPageserver, metric_name: str
) -> Iterator[None]:
"""
Record bytes written by the pageserver during a test.
"""
@@ -396,7 +398,7 @@ class NeonBenchmarker:
@pytest.fixture(scope="function")
def zenbenchmark(record_property) -> Iterator[NeonBenchmarker]:
def zenbenchmark(record_property: Callable[[str, object], None]) -> Iterator[NeonBenchmarker]:
"""
This is a python decorator for benchmark fixtures. It contains functions for
recording measurements, and prints them out at the end.
@@ -405,7 +407,7 @@ def zenbenchmark(record_property) -> Iterator[NeonBenchmarker]:
yield benchmarker
def pytest_addoption(parser):
def pytest_addoption(parser: Parser):
parser.addoption(
"--out-dir",
dest="out_dir",
@@ -429,7 +431,9 @@ def get_out_path(target_dir: Path, revision: str) -> Path:
# Hook to print the results at the end
@pytest.hookimpl(hookwrapper=True)
def pytest_terminal_summary(terminalreporter: TerminalReporter, exitstatus: int, config: Config):
def pytest_terminal_summary(
terminalreporter: TerminalReporter, exitstatus: int, config: Config
) -> Iterator[None]:
yield
revision = os.getenv("GITHUB_SHA", "local")
platform = os.getenv("PLATFORM", "local")

View File

@@ -1,10 +1,11 @@
from abc import ABC, abstractmethod
from contextlib import contextmanager
from contextlib import _GeneratorContextManager, contextmanager
# Type-related stuff
from typing import Dict, List
from typing import Dict, Iterator, List
import pytest
from _pytest.fixtures import FixtureRequest
from fixtures.benchmark_fixture import MetricReport, NeonBenchmarker
from fixtures.neon_fixtures import NeonEnv, PgBin, PgProtocol, RemotePostgres, VanillaPostgres
from fixtures.pg_stats import PgStatTable
@@ -28,19 +29,20 @@ class PgCompare(ABC):
pass
@property
@abstractmethod
def zenbenchmark(self) -> NeonBenchmarker:
pass
@abstractmethod
def flush(self) -> None:
def flush(self):
pass
@abstractmethod
def report_peak_memory_use(self) -> None:
def report_peak_memory_use(self):
pass
@abstractmethod
def report_size(self) -> None:
def report_size(self):
pass
@contextmanager
@@ -54,7 +56,7 @@ class PgCompare(ABC):
pass
@contextmanager
def record_pg_stats(self, pg_stats: List[PgStatTable]):
def record_pg_stats(self, pg_stats: List[PgStatTable]) -> Iterator[None]:
init_data = self._retrieve_pg_stats(pg_stats)
yield
@@ -84,7 +86,11 @@ class NeonCompare(PgCompare):
"""PgCompare interface for the neon stack."""
def __init__(
self, zenbenchmark: NeonBenchmarker, neon_simple_env: NeonEnv, pg_bin: PgBin, branch_name
self,
zenbenchmark: NeonBenchmarker,
neon_simple_env: NeonEnv,
pg_bin: PgBin,
branch_name: str,
):
self.env = neon_simple_env
self._zenbenchmark = zenbenchmark
@@ -97,15 +103,15 @@ class NeonCompare(PgCompare):
self.timeline = self.pg.safe_psql("SHOW neon.timeline_id")[0][0]
@property
def pg(self):
def pg(self) -> PgProtocol:
return self._pg
@property
def zenbenchmark(self):
def zenbenchmark(self) -> NeonBenchmarker:
return self._zenbenchmark
@property
def pg_bin(self):
def pg_bin(self) -> PgBin:
return self._pg_bin
def flush(self):
@@ -114,7 +120,7 @@ class NeonCompare(PgCompare):
def compact(self):
self.pageserver_http_client.timeline_compact(self.env.initial_tenant, self.timeline)
def report_peak_memory_use(self) -> None:
def report_peak_memory_use(self):
self.zenbenchmark.record(
"peak_mem",
self.zenbenchmark.get_peak_mem(self.env.pageserver) / 1024,
@@ -122,7 +128,7 @@ class NeonCompare(PgCompare):
report=MetricReport.LOWER_IS_BETTER,
)
def report_size(self) -> None:
def report_size(self):
timeline_size = self.zenbenchmark.get_timeline_size(
self.env.repo_dir, self.env.initial_tenant, self.timeline
)
@@ -144,17 +150,17 @@ class NeonCompare(PgCompare):
"num_files_uploaded", total_files, "", report=MetricReport.LOWER_IS_BETTER
)
def record_pageserver_writes(self, out_name):
def record_pageserver_writes(self, out_name: str) -> _GeneratorContextManager[None]:
return self.zenbenchmark.record_pageserver_writes(self.env.pageserver, out_name)
def record_duration(self, out_name):
def record_duration(self, out_name: str) -> _GeneratorContextManager[None]:
return self.zenbenchmark.record_duration(out_name)
class VanillaCompare(PgCompare):
"""PgCompare interface for vanilla postgres."""
def __init__(self, zenbenchmark, vanilla_pg: VanillaPostgres):
def __init__(self, zenbenchmark: NeonBenchmarker, vanilla_pg: VanillaPostgres):
self._pg = vanilla_pg
self._zenbenchmark = zenbenchmark
vanilla_pg.configure(
@@ -170,24 +176,24 @@ class VanillaCompare(PgCompare):
self.cur = self.conn.cursor()
@property
def pg(self):
def pg(self) -> PgProtocol:
return self._pg
@property
def zenbenchmark(self):
def zenbenchmark(self) -> NeonBenchmarker:
return self._zenbenchmark
@property
def pg_bin(self):
def pg_bin(self) -> PgBin:
return self._pg.pg_bin
def flush(self):
self.cur.execute("checkpoint")
def report_peak_memory_use(self) -> None:
def report_peak_memory_use(self):
pass # TODO find something
def report_size(self) -> None:
def report_size(self):
data_size = self.pg.get_subdir_size("base")
self.zenbenchmark.record(
"data_size", data_size / (1024 * 1024), "MB", report=MetricReport.LOWER_IS_BETTER
@@ -198,17 +204,17 @@ class VanillaCompare(PgCompare):
)
@contextmanager
def record_pageserver_writes(self, out_name):
def record_pageserver_writes(self, out_name: str) -> Iterator[None]:
yield # Do nothing
def record_duration(self, out_name):
def record_duration(self, out_name: str) -> _GeneratorContextManager[None]:
return self.zenbenchmark.record_duration(out_name)
class RemoteCompare(PgCompare):
"""PgCompare interface for a remote postgres instance."""
def __init__(self, zenbenchmark, remote_pg: RemotePostgres):
def __init__(self, zenbenchmark: NeonBenchmarker, remote_pg: RemotePostgres):
self._pg = remote_pg
self._zenbenchmark = zenbenchmark
@@ -217,55 +223,60 @@ class RemoteCompare(PgCompare):
self.cur = self.conn.cursor()
@property
def pg(self):
def pg(self) -> PgProtocol:
return self._pg
@property
def zenbenchmark(self):
def zenbenchmark(self) -> NeonBenchmarker:
return self._zenbenchmark
@property
def pg_bin(self):
def pg_bin(self) -> PgBin:
return self._pg.pg_bin
def flush(self):
# TODO: flush the remote pageserver
pass
def report_peak_memory_use(self) -> None:
def report_peak_memory_use(self):
# TODO: get memory usage from remote pageserver
pass
def report_size(self) -> None:
def report_size(self):
# TODO: get storage size from remote pageserver
pass
@contextmanager
def record_pageserver_writes(self, out_name):
def record_pageserver_writes(self, out_name: str) -> Iterator[None]:
yield # Do nothing
def record_duration(self, out_name):
def record_duration(self, out_name: str) -> _GeneratorContextManager[None]:
return self.zenbenchmark.record_duration(out_name)
@pytest.fixture(scope="function")
def neon_compare(request, zenbenchmark, pg_bin, neon_simple_env) -> NeonCompare:
def neon_compare(
request: FixtureRequest,
zenbenchmark: NeonBenchmarker,
pg_bin: PgBin,
neon_simple_env: NeonEnv,
) -> NeonCompare:
branch_name = request.node.name
return NeonCompare(zenbenchmark, neon_simple_env, pg_bin, branch_name)
@pytest.fixture(scope="function")
def vanilla_compare(zenbenchmark, vanilla_pg) -> VanillaCompare:
def vanilla_compare(zenbenchmark: NeonBenchmarker, vanilla_pg: VanillaPostgres) -> VanillaCompare:
return VanillaCompare(zenbenchmark, vanilla_pg)
@pytest.fixture(scope="function")
def remote_compare(zenbenchmark, remote_pg) -> RemoteCompare:
def remote_compare(zenbenchmark: NeonBenchmarker, remote_pg: RemotePostgres) -> RemoteCompare:
return RemoteCompare(zenbenchmark, remote_pg)
@pytest.fixture(params=["vanilla_compare", "neon_compare"], ids=["vanilla", "neon"])
def neon_with_baseline(request) -> PgCompare:
def neon_with_baseline(request: FixtureRequest) -> PgCompare:
"""Parameterized fixture that helps compare neon against vanilla postgres.
A test that uses this fixture turns into a parameterized test that runs against:
@@ -286,8 +297,6 @@ def neon_with_baseline(request) -> PgCompare:
implementation-specific logic is widely useful across multiple tests, it might
make sense to add methods to the PgCompare class.
"""
fixture = request.getfixturevalue(request.param)
if isinstance(fixture, PgCompare):
return fixture
else:
raise AssertionError(f"test error: fixture {request.param} is not PgCompare")
fixture = request.getfixturevalue(request.param) # type: ignore
assert isinstance(fixture, PgCompare), f"test error: fixture {fixture} is not PgCompare"
return fixture

View File

@@ -1,5 +1,5 @@
from collections import defaultdict
from typing import Dict, List
from typing import Dict, List, Optional, Tuple
from prometheus_client.parser import text_string_to_metric_families
from prometheus_client.samples import Sample
@@ -23,13 +23,13 @@ class Metrics:
pass
return res
def query_one(self, name: str, filter: Dict[str, str] = {}) -> Sample:
res = self.query_all(name, filter)
def query_one(self, name: str, filter: Optional[Dict[str, str]] = None) -> Sample:
res = self.query_all(name, filter or {})
assert len(res) == 1, f"expected single sample for {name} {filter}, found {res}"
return res[0]
def parse_metrics(text: str, name: str = ""):
def parse_metrics(text: str, name: str = "") -> Metrics:
metrics = Metrics(name)
gen = text_string_to_metric_families(text)
for family in gen:
@@ -39,7 +39,7 @@ def parse_metrics(text: str, name: str = ""):
return metrics
PAGESERVER_PER_TENANT_METRICS = [
PAGESERVER_PER_TENANT_METRICS: Tuple[str, ...] = (
"pageserver_current_logical_size",
"pageserver_current_physical_size",
"pageserver_getpage_reconstruct_seconds_bucket",
@@ -62,4 +62,4 @@ PAGESERVER_PER_TENANT_METRICS = [
"pageserver_wait_lsn_seconds_sum",
"pageserver_created_persistent_files_total",
"pageserver_written_persistent_bytes_total",
]
)

View File

@@ -19,7 +19,8 @@ from dataclasses import dataclass, field
from enum import Flag, auto
from functools import cached_property
from pathlib import Path
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union, cast
from types import TracebackType
from typing import Any, Dict, Iterator, List, Optional, Tuple, Type, Union, cast
import asyncpg
import backoff # type: ignore
@@ -28,16 +29,18 @@ import jwt
import psycopg2
import pytest
import requests
from _pytest.config import Config
from _pytest.fixtures import FixtureRequest
from fixtures.log_helper import log
from fixtures.types import Lsn, TenantId, TimelineId
from fixtures.utils import Fn, allure_attach_from_dir, etcd_path, get_self_dir, subprocess_capture
# Type-related stuff
from psycopg2.extensions import connection as PgConnection
from psycopg2.extensions import cursor as PgCursor
from psycopg2.extensions import make_dsn, parse_dsn
from typing_extensions import Literal
from .utils import Fn, allure_attach_from_dir, etcd_path, get_self_dir, subprocess_capture
"""
This file contains pytest fixtures. A fixture is a test resource that can be
summoned by placing its name in the test's arguments.
@@ -57,15 +60,15 @@ put directly-importable functions into utils.py or another separate file.
Env = Dict[str, str]
DEFAULT_OUTPUT_DIR = "test_output"
DEFAULT_BRANCH_NAME = "main"
DEFAULT_PG_VERSION_DEFAULT = "14"
DEFAULT_OUTPUT_DIR: str = "test_output"
DEFAULT_BRANCH_NAME: str = "main"
DEFAULT_PG_VERSION_DEFAULT: str = "14"
BASE_PORT = 15000
WORKER_PORT_NUM = 1000
BASE_PORT: int = 15000
WORKER_PORT_NUM: int = 1000
def pytest_configure(config):
def pytest_configure(config: Config):
"""
Check that we do not overflow available ports range.
"""
@@ -154,14 +157,14 @@ def versioned_pg_distrib_dir(pg_distrib_dir: Path, pg_version: str) -> Iterator[
if not psql_bin_path.exists():
raise Exception(f"psql not found at '{psql_bin_path}'")
else:
if not postgres_bin_path.exists:
if not postgres_bin_path.exists():
raise Exception(f"postgres not found at '{postgres_bin_path}'")
log.info(f"versioned_pg_distrib_dir is {versioned_dir}")
yield versioned_dir
def shareable_scope(fixture_name, config) -> Literal["session", "function"]:
def shareable_scope(fixture_name: str, config: Config) -> Literal["session", "function"]:
"""Return either session of function scope, depending on TEST_SHARED_FIXTURES envvar.
This function can be used as a scope like this:
@@ -173,7 +176,7 @@ def shareable_scope(fixture_name, config) -> Literal["session", "function"]:
@pytest.fixture(scope="session")
def worker_seq_no(worker_id: str):
def worker_seq_no(worker_id: str) -> int:
# worker_id is a pytest-xdist fixture
# it can be master or gw<number>
# parse it to always get a number
@@ -184,7 +187,7 @@ def worker_seq_no(worker_id: str):
@pytest.fixture(scope="session")
def worker_base_port(worker_seq_no: int):
def worker_base_port(worker_seq_no: int) -> int:
# so we divide ports in ranges of 100 ports
# so workers have disjoint set of ports for services
return BASE_PORT + worker_seq_no * WORKER_PORT_NUM
@@ -234,10 +237,9 @@ class PortDistributor:
for port in self.iterator:
if can_bind("localhost", port):
return port
else:
raise RuntimeError(
"port range configured for test is exhausted, consider enlarging the range"
)
raise RuntimeError(
"port range configured for test is exhausted, consider enlarging the range"
)
def replace_with_new_port(self, value: Union[int, str]) -> Union[int, str]:
"""
@@ -273,12 +275,14 @@ class PortDistributor:
@pytest.fixture(scope="session")
def port_distributor(worker_base_port):
def port_distributor(worker_base_port: int) -> PortDistributor:
return PortDistributor(base_port=worker_base_port, port_number=WORKER_PORT_NUM)
@pytest.fixture(scope="session")
def default_broker(request: Any, port_distributor: PortDistributor, top_output_dir: Path):
def default_broker(
request: FixtureRequest, port_distributor: PortDistributor, top_output_dir: Path
) -> Iterator[Etcd]:
client_port = port_distributor.get_port()
# multiple pytest sessions could get launched in parallel, get them different datadirs
etcd_datadir = get_test_output_dir(request, top_output_dir) / f"etcd_datadir_{client_port}"
@@ -293,12 +297,12 @@ def default_broker(request: Any, port_distributor: PortDistributor, top_output_d
@pytest.fixture(scope="session")
def run_id():
def run_id() -> Iterator[uuid.UUID]:
yield uuid.uuid4()
@pytest.fixture(scope="session")
def mock_s3_server(port_distributor: PortDistributor):
def mock_s3_server(port_distributor: PortDistributor) -> Iterator[MockS3Server]:
mock_s3_server = MockS3Server(port_distributor.get_port())
yield mock_s3_server
mock_s3_server.kill()
@@ -307,16 +311,16 @@ def mock_s3_server(port_distributor: PortDistributor):
class PgProtocol:
"""Reusable connection logic"""
def __init__(self, **kwargs):
def __init__(self, **kwargs: Any):
self.default_options = kwargs
def connstr(self, **kwargs) -> str:
def connstr(self, **kwargs: Any) -> str:
"""
Build a libpq connection string for the Postgres instance.
"""
return str(make_dsn(**self.conn_options(**kwargs)))
def conn_options(self, **kwargs):
def conn_options(self, **kwargs: Any) -> Dict[str, Any]:
"""
Construct a dictionary of connection options from default values and extra parameters.
An option can be dropped from the returning dictionary by None-valued extra parameter.
@@ -338,7 +342,7 @@ class PgProtocol:
return result
# autocommit=True here by default because that's what we need most of the time
def connect(self, autocommit=True, **kwargs) -> PgConnection:
def connect(self, autocommit: bool = True, **kwargs: Any) -> PgConnection:
"""
Connect to the node.
Returns psycopg2's connection object.
@@ -351,7 +355,7 @@ class PgProtocol:
return conn
@contextmanager
def cursor(self, autocommit=True, **kwargs):
def cursor(self, autocommit: bool = True, **kwargs: Any) -> Iterator[PgCursor]:
"""
Shorthand for pg.connect().cursor().
The cursor and connection are closed when the context is exited.
@@ -359,7 +363,7 @@ class PgProtocol:
with closing(self.connect(autocommit=autocommit, **kwargs)) as conn:
yield conn.cursor()
async def connect_async(self, **kwargs) -> asyncpg.Connection:
async def connect_async(self, **kwargs: Any) -> asyncpg.Connection:
"""
Connect to the node from async python.
Returns asyncpg's connection object.
@@ -413,10 +417,10 @@ class PgProtocol:
@dataclass
class AuthKeys:
pub: bytes
priv: bytes
pub: str
priv: str
def generate_management_token(self):
def generate_management_token(self) -> str:
token = jwt.encode({"scope": "pageserverapi"}, self.priv, algorithm="RS256")
# jwt.encode can return 'bytes' or 'str', depending on Python version or type
@@ -427,9 +431,11 @@ class AuthKeys:
return token
def generate_tenant_token(self, tenant_id):
def generate_tenant_token(self, tenant_id: TenantId) -> str:
token = jwt.encode(
{"scope": "tenant", "tenant_id": str(tenant_id)}, self.priv, algorithm="RS256"
{"scope": "tenant", "tenant_id": str(tenant_id)},
self.priv,
algorithm="RS256",
)
if isinstance(token, bytes):
@@ -485,7 +491,7 @@ class MockS3Server:
@enum.unique
class RemoteStorageKind(enum.Enum):
class RemoteStorageKind(str, enum.Enum):
LOCAL_FS = "local_fs"
MOCK_S3 = "mock_s3"
REAL_S3 = "real_s3"
@@ -529,7 +535,7 @@ RemoteStorage = Union[LocalFsStorage, S3Storage]
# serialize as toml inline table
def remote_storage_to_toml_inline_table(remote_storage):
def remote_storage_to_toml_inline_table(remote_storage: RemoteStorage) -> str:
if isinstance(remote_storage, LocalFsStorage):
remote_storage_config = f"local_path='{remote_storage.root}'"
elif isinstance(remote_storage, S3Storage):
@@ -582,7 +588,7 @@ class NeonEnvBuilder:
safekeepers_enable_fsync: bool = False,
auth_enabled: bool = False,
rust_log_override: Optional[str] = None,
default_branch_name=DEFAULT_BRANCH_NAME,
default_branch_name: str = DEFAULT_BRANCH_NAME,
):
self.repo_dir = repo_dir
self.rust_log_override = rust_log_override
@@ -636,7 +642,7 @@ class NeonEnvBuilder:
else:
raise RuntimeError(f"Unknown storage type: {remote_storage_kind}")
def enable_local_fs_remote_storage(self, force_enable=True):
def enable_local_fs_remote_storage(self, force_enable: bool = True):
"""
Sets up the pageserver to use the local fs at the `test_dir/local_fs_remote_storage` path.
Errors, if the pageserver has some remote storage configuration already, unless `force_enable` is not set to `True`.
@@ -644,7 +650,7 @@ class NeonEnvBuilder:
assert force_enable or self.remote_storage is None, "remote storage is enabled already"
self.remote_storage = LocalFsStorage(Path(self.repo_dir / "local_fs_remote_storage"))
def enable_mock_s3_remote_storage(self, bucket_name: str, force_enable=True):
def enable_mock_s3_remote_storage(self, bucket_name: str, force_enable: bool = True):
"""
Sets up the pageserver to use the S3 mock server, creates the bucket, if it's not present already.
Starts up the mock server, if that does not run yet.
@@ -671,7 +677,7 @@ class NeonEnvBuilder:
secret_key=self.mock_s3_server.secret_key(),
)
def enable_real_s3_remote_storage(self, test_name: str, force_enable=True):
def enable_real_s3_remote_storage(self, test_name: str, force_enable: bool = True):
"""
Sets up configuration to use real s3 endpoint without mock server
"""
@@ -759,10 +765,15 @@ class NeonEnvBuilder:
log.info("deleted %s objects from remote storage", cnt)
def __enter__(self):
def __enter__(self) -> "NeonEnvBuilder":
return self
def __exit__(self, exc_type, exc_value, traceback):
def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_value: Optional[BaseException],
traceback: Optional[TracebackType],
):
# Stop all the nodes.
if self.env:
log.info("Cleaning up all storage and compute nodes")
@@ -909,7 +920,7 @@ class NeonEnv:
def get_safekeeper_connstrs(self) -> str:
"""Get list of safekeeper endpoints suitable for safekeepers GUC"""
return ",".join([f"localhost:{wa.port.pg}" for wa in self.safekeepers])
return ",".join(f"localhost:{wa.port.pg}" for wa in self.safekeepers)
def timeline_dir(self, tenant_id: TenantId, timeline_id: TimelineId) -> Path:
"""Get a timeline directory's path based on the repo directory of the test environment"""
@@ -928,14 +939,14 @@ class NeonEnv:
@cached_property
def auth_keys(self) -> AuthKeys:
pub = (Path(self.repo_dir) / "auth_public_key.pem").read_bytes()
priv = (Path(self.repo_dir) / "auth_private_key.pem").read_bytes()
pub = (Path(self.repo_dir) / "auth_public_key.pem").read_text()
priv = (Path(self.repo_dir) / "auth_private_key.pem").read_text()
return AuthKeys(pub=pub, priv=priv)
@pytest.fixture(scope=shareable_scope)
def _shared_simple_env(
request: Any,
request: FixtureRequest,
port_distributor: PortDistributor,
mock_s3_server: MockS3Server,
default_broker: Etcd,
@@ -993,7 +1004,7 @@ def neon_simple_env(_shared_simple_env: NeonEnv) -> Iterator[NeonEnv]:
@pytest.fixture(scope="function")
def neon_env_builder(
test_output_dir,
test_output_dir: str,
port_distributor: PortDistributor,
mock_s3_server: MockS3Server,
neon_binpath: Path,
@@ -1059,7 +1070,7 @@ class PageserverHttpClient(requests.Session):
def check_status(self):
self.get(f"http://localhost:{self.port}/v1/status").raise_for_status()
def configure_failpoints(self, config_strings: tuple[str, str] | list[tuple[str, str]]) -> None:
def configure_failpoints(self, config_strings: Tuple[str, str] | List[Tuple[str, str]]):
self.is_testing_enabled_or_skip()
if isinstance(config_strings, tuple):
@@ -1189,7 +1200,6 @@ class PageserverHttpClient(requests.Session):
self.verbose_error(res)
res_json = res.json()
assert res_json is None
return res_json
def timeline_gc(
self, tenant_id: TenantId, timeline_id: TimelineId, gc_horizon: Optional[int]
@@ -1221,7 +1231,6 @@ class PageserverHttpClient(requests.Session):
self.verbose_error(res)
res_json = res.json()
assert res_json is None
return res_json
def timeline_get_lsn_by_timestamp(
self, tenant_id: TenantId, timeline_id: TimelineId, timestamp
@@ -1247,7 +1256,6 @@ class PageserverHttpClient(requests.Session):
self.verbose_error(res)
res_json = res.json()
assert res_json is None
return res_json
def get_metrics(self) -> str:
res = self.get(f"http://localhost:{self.port}/metrics")
@@ -1261,13 +1269,10 @@ class PageserverPort:
http: int
CREATE_TIMELINE_ID_EXTRACTOR = re.compile(
CREATE_TIMELINE_ID_EXTRACTOR: re.Pattern = re.compile( # type: ignore[type-arg]
r"^Created timeline '(?P<timeline_id>[^']+)'", re.MULTILINE
)
CREATE_TIMELINE_ID_EXTRACTOR = re.compile(
r"^Created timeline '(?P<timeline_id>[^']+)'", re.MULTILINE
)
TIMELINE_DATA_EXTRACTOR = re.compile(
TIMELINE_DATA_EXTRACTOR: re.Pattern = re.compile( # type: ignore[type-arg]
r"\s?(?P<branch_name>[^\s]+)\s\[(?P<timeline_id>[^\]]+)\]", re.MULTILINE
)
@@ -1560,7 +1565,7 @@ class NeonCli(AbstractNeonCli):
def pageserver_start(
self,
overrides=(),
overrides: Tuple[str, ...] = (),
) -> "subprocess.CompletedProcess[str]":
start_args = ["pageserver", "start", *overrides]
append_pageserver_param_overrides(
@@ -1718,7 +1723,7 @@ class NeonPageserver(PgProtocol):
self.config_override = config_override
self.version = env.get_pageserver_version()
def start(self, overrides=()) -> "NeonPageserver":
def start(self, overrides: Tuple[str, ...] = ()) -> "NeonPageserver":
"""
Start the page server.
`overrides` allows to add some config to this pageserver start.
@@ -1730,7 +1735,7 @@ class NeonPageserver(PgProtocol):
self.running = True
return self
def stop(self, immediate=False) -> "NeonPageserver":
def stop(self, immediate: bool = False) -> "NeonPageserver":
"""
Stop the page server.
Returns self.
@@ -1740,10 +1745,15 @@ class NeonPageserver(PgProtocol):
self.running = False
return self
def __enter__(self):
def __enter__(self) -> "NeonPageserver":
return self
def __exit__(self, exc_type, exc, tb):
def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc: Optional[BaseException],
tb: Optional[TracebackType],
):
self.stop(immediate=True)
def is_testing_enabled_or_skip(self):
@@ -1855,7 +1865,7 @@ def pg_bin(test_output_dir: Path, pg_distrib_dir: Path, pg_version: str) -> PgBi
class VanillaPostgres(PgProtocol):
def __init__(self, pgdatadir: Path, pg_bin: PgBin, port: int, init=True):
def __init__(self, pgdatadir: Path, pg_bin: PgBin, port: int, init: bool = True):
super().__init__(host="localhost", port=port, dbname="postgres")
self.pgdatadir = pgdatadir
self.pg_bin = pg_bin
@@ -1890,10 +1900,15 @@ class VanillaPostgres(PgProtocol):
"""Return size of pgdatadir subdirectory in bytes."""
return get_dir_size(os.path.join(self.pgdatadir, subdir))
def __enter__(self):
def __enter__(self) -> "VanillaPostgres":
return self
def __exit__(self, exc_type, exc, tb):
def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc: Optional[BaseException],
tb: Optional[TracebackType],
):
if self.running:
self.stop()
@@ -1933,10 +1948,15 @@ class RemotePostgres(PgProtocol):
# See https://www.postgresql.org/docs/14/functions-admin.html#FUNCTIONS-ADMIN-GENFILE
raise Exception("cannot get size of a Postgres instance")
def __enter__(self):
def __enter__(self) -> "RemotePostgres":
return self
def __exit__(self, exc_type, exc, tb):
def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc: Optional[BaseException],
tb: Optional[TracebackType],
):
# do nothing
pass
@@ -1975,7 +1995,7 @@ class PSQL:
self.path = path
self.database_url = f"postgres://{host}:{port}/main?options=project%3Dgeneric-project-name"
async def run(self, query=None):
async def run(self, query: Optional[str] = None) -> asyncio.subprocess.Process:
run_args = [self.path, "--no-psqlrc", "--quiet", "--tuples-only", self.database_url]
if query is not None:
run_args += ["--command", query]
@@ -2008,7 +2028,7 @@ class NeonProxy(PgProtocol):
self._popen: Optional[subprocess.Popen[bytes]] = None
self.link_auth_uri_prefix = "http://dummy-uri"
def start(self) -> None:
def start(self):
"""
Starts a proxy with option '--auth-backend postgres' and a postgres instance already provided though '--auth-endpoint <postgress-instance>'."
"""
@@ -2026,7 +2046,7 @@ class NeonProxy(PgProtocol):
self._popen = subprocess.Popen(args)
self._wait_until_ready()
def start_with_link_auth(self) -> None:
def start_with_link_auth(self):
"""
Starts a proxy with option '--auth-backend link' and a dummy authentication link '--uri dummy-auth-link'."
"""
@@ -2054,10 +2074,15 @@ class NeonProxy(PgProtocol):
request_result.raise_for_status()
return request_result.text
def __enter__(self):
def __enter__(self) -> "NeonProxy":
return self
def __exit__(self, exc_type, exc, tb):
def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc: Optional[BaseException],
tb: Optional[TracebackType],
):
if self._popen is not None:
# NOTE the process will die when we're done with tests anyway, because
# it's a child process. This is mostly to clean up in between different tests.
@@ -2065,7 +2090,7 @@ class NeonProxy(PgProtocol):
@pytest.fixture(scope="function")
def link_proxy(port_distributor, neon_binpath: Path) -> Iterator[NeonProxy]:
def link_proxy(port_distributor: PortDistributor, neon_binpath: Path) -> Iterator[NeonProxy]:
"""Neon proxy that routes through link auth."""
http_port = port_distributor.get_port()
proxy_port = port_distributor.get_port()
@@ -2076,7 +2101,9 @@ def link_proxy(port_distributor, neon_binpath: Path) -> Iterator[NeonProxy]:
@pytest.fixture(scope="function")
def static_proxy(vanilla_pg, port_distributor, neon_binpath: Path) -> Iterator[NeonProxy]:
def static_proxy(
vanilla_pg: VanillaPostgres, port_distributor: PortDistributor, neon_binpath: Path
) -> Iterator[NeonProxy]:
"""Neon proxy that routes directly to vanilla postgres."""
# For simplicity, we use the same user for both `--auth-endpoint` and `safe_psql`
@@ -2276,10 +2303,15 @@ class Postgres(PgProtocol):
return self
def __enter__(self):
def __enter__(self) -> "Postgres":
return self
def __exit__(self, exc_type, exc, tb):
def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc: Optional[BaseException],
tb: Optional[TracebackType],
):
self.stop()
@@ -2288,7 +2320,7 @@ class PostgresFactory:
def __init__(self, env: NeonEnv):
self.env = env
self.num_instances = 0
self.num_instances: int = 0
self.instances: List[Postgres] = []
def create_start(
@@ -2383,7 +2415,7 @@ class Safekeeper:
break # success
return self
def stop(self, immediate=False) -> "Safekeeper":
def stop(self, immediate: bool = False) -> "Safekeeper":
log.info("Stopping safekeeper {}".format(self.id))
self.env.neon_cli.safekeeper_stop(self.id, immediate)
self.running = False
@@ -2598,7 +2630,7 @@ class Etcd:
self.handle.wait()
def get_test_output_dir(request: Any, top_output_dir: Path) -> Path:
def get_test_output_dir(request: FixtureRequest, top_output_dir: Path) -> Path:
"""Compute the working directory for an individual test."""
test_name = request.node.name
test_dir = top_output_dir / test_name.replace("/", "-")
@@ -2618,7 +2650,7 @@ def get_test_output_dir(request: Any, top_output_dir: Path) -> Path:
# this fixture ensures that the directory exists. That works because
# 'autouse' fixtures are run before other fixtures.
@pytest.fixture(scope="function", autouse=True)
def test_output_dir(request: Any, top_output_dir: Path) -> Iterator[Path]:
def test_output_dir(request: FixtureRequest, top_output_dir: Path) -> Iterator[Path]:
"""Create the working directory for an individual test."""
# one directory per test
@@ -2682,7 +2714,7 @@ def should_skip_file(filename: str) -> bool:
#
# Test helpers
#
def list_files_to_compare(pgdata_dir: Path):
def list_files_to_compare(pgdata_dir: Path) -> List[str]:
pgdata_files = []
for root, _file, filenames in os.walk(pgdata_dir):
for filename in filenames:

View File

@@ -1,3 +1,4 @@
from functools import cached_property
from typing import List
import pytest
@@ -13,7 +14,7 @@ class PgStatTable:
self.columns = columns
self.additional_query = filter_query
@property
@cached_property
def query(self) -> str:
return f"SELECT {','.join(self.columns)} FROM {self.table} {self.additional_query}"
@@ -55,6 +56,5 @@ def pg_stats_wal() -> List[PgStatTable]:
PgStatTable(
"pg_stat_wal",
["wal_records", "wal_fpi", "wal_bytes", "wal_buffers_full", "wal_write"],
"",
)
]

View File

@@ -1,4 +1,8 @@
from typing import Any, List
import pytest
from _pytest.config import Config
from _pytest.config.argparsing import Parser
"""
This plugin allows tests to be marked as slow using pytest.mark.slow. By default slow
@@ -9,15 +13,15 @@ Copied from here: https://docs.pytest.org/en/latest/example/simple.html
"""
def pytest_addoption(parser):
def pytest_addoption(parser: Parser):
parser.addoption("--runslow", action="store_true", default=False, help="run slow tests")
def pytest_configure(config):
def pytest_configure(config: Config):
config.addinivalue_line("markers", "slow: mark test as slow to run")
def pytest_collection_modifyitems(config, items):
def pytest_collection_modifyitems(config: Config, items: List[Any]):
if config.getoption("--runslow"):
# --runslow given in cli: do not skip slow tests
return

View File

@@ -1,6 +1,8 @@
import random
from functools import total_ordering
from typing import Union
from typing import Any, Type, TypeVar, Union
T = TypeVar("T", bound="Id")
@total_ordering
@@ -17,31 +19,35 @@ class Lsn:
"""Convert lsn from hex notation to int."""
l, r = x.split("/")
self.lsn_int = (int(l, 16) << 32) + int(r, 16)
# FIXME: error if it doesn't look like a valid LSN
assert 0 <= self.lsn_int <= 0xFFFFFFFF_FFFFFFFF
def __str__(self):
def __str__(self) -> str:
"""Convert lsn from int to standard hex notation."""
return "{:X}/{:X}".format(self.lsn_int >> 32, self.lsn_int & 0xFFFFFFFF)
return f"{(self.lsn_int >> 32):X}/{(self.lsn_int & 0xFFFFFFFF):X}"
def __repr__(self):
return 'Lsn("{:X}/{:X}")'.format(self.lsn_int >> 32, self.lsn_int & 0xFFFFFFFF)
def __repr__(self) -> str:
return f'Lsn("{str(self)}")'
def __int__(self):
def __int__(self) -> int:
return self.lsn_int
def __lt__(self, other: "Lsn") -> bool:
def __lt__(self, other: Any) -> bool:
if not isinstance(other, Lsn):
return NotImplemented
return self.lsn_int < other.lsn_int
def __eq__(self, other) -> bool:
def __eq__(self, other: Any) -> bool:
if not isinstance(other, Lsn):
return NotImplemented
return self.lsn_int == other.lsn_int
# Returns the difference between two Lsns, in bytes
def __sub__(self, other: "Lsn") -> int:
def __sub__(self, other: Any) -> int:
if not isinstance(other, Lsn):
return NotImplemented
return self.lsn_int - other.lsn_int
def __hash__(self):
def __hash__(self) -> int:
return hash(self.lsn_int)
@@ -57,7 +63,7 @@ class Id:
self.id = bytearray.fromhex(x)
assert len(self.id) == 16
def __str__(self):
def __str__(self) -> str:
return self.id.hex()
def __lt__(self, other) -> bool:
@@ -70,20 +76,20 @@ class Id:
return NotImplemented
return self.id == other.id
def __hash__(self):
def __hash__(self) -> int:
return hash(str(self.id))
@classmethod
def generate(cls):
def generate(cls: Type[T]) -> T:
"""Generate a random ID"""
return cls(random.randbytes(16).hex())
class TenantId(Id):
def __repr__(self):
def __repr__(self) -> str:
return f'`TenantId("{self.id.hex()}")'
class TimelineId(Id):
def __repr__(self):
def __repr__(self) -> str:
return f'TimelineId("{self.id.hex()}")'

View File

@@ -6,7 +6,7 @@ import subprocess
import tarfile
import time
from pathlib import Path
from typing import Any, Callable, List, Tuple, TypeVar
from typing import Any, Callable, Dict, List, Tuple, TypeVar
import allure # type: ignore
from fixtures.log_helper import log
@@ -30,11 +30,11 @@ def subprocess_capture(capture_dir: Path, cmd: List[str], **kwargs: Any) -> str:
If those files already exist, we will overwrite them.
Returns basepath for files with captured output.
"""
assert type(cmd) is list
base = os.path.basename(cmd[0]) + "_{}".format(global_counter())
assert isinstance(cmd, list)
base = f"{os.path.basename(cmd[0])}_{global_counter()}"
basepath = os.path.join(capture_dir, base)
stdout_filename = basepath + ".stdout"
stderr_filename = basepath + ".stderr"
stdout_filename = f"{basepath}.stdout"
stderr_filename = f"{basepath}.stderr"
try:
with open(stdout_filename, "w") as stdout_f:
@@ -64,7 +64,7 @@ def global_counter() -> int:
return _global_counter
def print_gc_result(row):
def print_gc_result(row: Dict[str, Any]):
log.info("GC duration {elapsed} ms".format_map(row))
log.info(
" total: {layers_total}, needed_by_cutoff {layers_needed_by_cutoff}, needed_by_pitr {layers_needed_by_pitr}"
@@ -78,8 +78,7 @@ def etcd_path() -> Path:
path_output = shutil.which("etcd")
if path_output is None:
raise RuntimeError("etcd not found in PATH")
else:
return Path(path_output)
return Path(path_output)
def query_scalar(cur: cursor, query: str) -> Any:
@@ -124,7 +123,6 @@ def get_timeline_dir_size(path: Path) -> int:
# file is a delta layer
_ = parse_delta_layer(dir_entry.name)
sz += dir_entry.stat().st_size
continue
return sz
@@ -157,8 +155,8 @@ def get_scale_for_db(size_mb: int) -> int:
return round(0.06689 * size_mb - 0.5)
ATTACHMENT_NAME_REGEX = re.compile(
r".+\.log|.+\.stderr|.+\.stdout|.+\.filediff|.+\.metrics|flamegraph\.svg|regression\.diffs|.+\.html"
ATTACHMENT_NAME_REGEX: re.Pattern = re.compile( # type: ignore[type-arg]
r"flamegraph\.svg|regression\.diffs|.+\.(?:log|stderr|stdout|filediff|metrics|html)"
)

View File

@@ -2,7 +2,7 @@ import statistics
import threading
import time
import timeit
from typing import Callable
from typing import Any, Callable, List
import pytest
from fixtures.benchmark_fixture import MetricReport, NeonBenchmarker
@@ -197,7 +197,7 @@ def record_lsn_write_lag(env: PgCompare, run_cond: Callable[[], bool], pool_inte
if not isinstance(env, NeonCompare):
return
lsn_write_lags = []
lsn_write_lags: List[Any] = []
last_received_lsn = Lsn(0)
last_pg_flush_lsn = Lsn(0)
@@ -216,6 +216,7 @@ def record_lsn_write_lag(env: PgCompare, run_cond: Callable[[], bool], pool_inte
)
res = cur.fetchone()
assert isinstance(res, list)
lsn_write_lags.append(res[0])
curr_received_lsn = Lsn(res[3])

View File

@@ -24,7 +24,6 @@ if __name__ == "__main__":
if (v := os.environ.get(k, None)) is not None
}
loop = asyncio.new_event_loop()
row = loop.run_until_complete(run(**kwargs))
row = asyncio.run(run(**kwargs))
print(row[0])

View File

@@ -129,6 +129,7 @@ async def test_psql_session_id(vanilla_pg: VanillaPostgres, link_proxy: NeonProx
create_and_send_db_info(vanilla_pg, psql_session_id, link_proxy.mgmt_port)
assert proc.stdout is not None
out = (await proc.stdout.read()).decode("utf-8").strip()
assert out == "42"