From f720dd735e704b21a7ec702b39104335209c48c2 Mon Sep 17 00:00:00 2001
From: Vadim Kharitonov
Date: Tue, 8 Nov 2022 10:51:34 +0100
Subject: [PATCH] Stricter mypy linters for `test_runner/fixtures/*`

---
 test_runner/fixtures/benchmark_fixture.py | 88 ++++----
 test_runner/fixtures/compare_fixtures.py | 87 ++++----
 test_runner/fixtures/metrics.py | 12 +-
 test_runner/fixtures/neon_fixtures.py | 188 ++++++++++--------
 test_runner/fixtures/pg_stats.py | 4 +-
 test_runner/fixtures/slow.py | 10 +-
 test_runner/fixtures/types.py | 38 ++--
 test_runner/fixtures/utils.py | 20 +-
 .../performance/test_wal_backpressure.py | 5 +-
 .../python/asyncpg/asyncpg_example.py | 3 +-
 test_runner/regress/test_proxy.py | 1 +
 11 files changed, 255 insertions(+), 201 deletions(-)

diff --git a/test_runner/fixtures/benchmark_fixture.py b/test_runner/fixtures/benchmark_fixture.py
index b5565dab0f..27fb0a60b2 100644
--- a/test_runner/fixtures/benchmark_fixture.py
+++ b/test_runner/fixtures/benchmark_fixture.py
@@ -11,39 +11,37 @@ from datetime import datetime
 from pathlib import Path

 # Type-related stuff
-from typing import Iterator, Optional
+from typing import Callable, ClassVar, Iterator, Optional

 import pytest
 from _pytest.config import Config
+from _pytest.config.argparsing import Parser
 from _pytest.terminal import TerminalReporter
+from fixtures.neon_fixtures import NeonPageserver
 from fixtures.types import TenantId, TimelineId

 """
 This file contains fixtures for micro-benchmarks.

-To use, declare the 'zenbenchmark' fixture in the test function. Run the
-bencmark, and then record the result by calling zenbenchmark.record. For example:
+To use, declare the `zenbenchmark` fixture in the test function. Run the
+benchmark, and then record the result by calling `zenbenchmark.record`. For example:

-import timeit
-from fixtures.neon_fixtures import NeonEnv
-
-def test_mybench(neon_simple_env: env, zenbenchmark):
-
-    # Initialize the test
-    ...
-
-    # Run the test, timing how long it takes
-    with zenbenchmark.record_duration('test_query'):
-        cur.execute('SELECT test_query(...)')
-
-    # Record another measurement
-    zenbenchmark.record('speed_of_light', 300000, 'km/s')
+>>> import timeit
+>>> from fixtures.neon_fixtures import NeonEnv
+>>> def test_mybench(neon_simple_env: NeonEnv, zenbenchmark):
+...     # Initialize the test
+...     ...
+...     # Run the test, timing how long it takes
+...     with zenbenchmark.record_duration('test_query'):
+...         cur.execute('SELECT test_query(...)')
+...     # Record another measurement
+...     zenbenchmark.record('speed_of_light', 300000, 'km/s')

 There's no need to import this file to use it. It should be declared as a plugin
-inside conftest.py, and that makes it available to all tests.
+inside `conftest.py`, and that makes it available to all tests.

 You can measure multiple things in one test, and record each one with a separate
-call to zenbenchmark. For example, you could time the bulk loading that happens
+call to `zenbenchmark`. For example, you could time the bulk loading that happens
 in the test initialization, or measure disk usage after the test query. 
""" @@ -117,7 +115,7 @@ class PgBenchRunResult: # tps = 309.281539 (without initial connection time) if line.startswith("tps = ") and ( "(excluding connections establishing)" in line - or "(without initial connection time)" + or "(without initial connection time)" in line ): tps = float(line.split()[2]) @@ -137,6 +135,17 @@ class PgBenchRunResult: @dataclasses.dataclass class PgBenchInitResult: + REGEX: ClassVar[re.Pattern] = re.compile( # type: ignore[type-arg] + r"done in (\d+\.\d+) s " + r"\(" + r"(?:drop tables (\d+\.\d+) s)?(?:, )?" + r"(?:create tables (\d+\.\d+) s)?(?:, )?" + r"(?:client-side generate (\d+\.\d+) s)?(?:, )?" + r"(?:vacuum (\d+\.\d+) s)?(?:, )?" + r"(?:primary keys (\d+\.\d+) s)?(?:, )?" + r"\)\." + ) + total: float drop_tables: Optional[float] create_tables: Optional[float] @@ -160,18 +169,7 @@ class PgBenchInitResult: last_line = stderr.splitlines()[-1] - regex = re.compile( - r"done in (\d+\.\d+) s " - r"\(" - r"(?:drop tables (\d+\.\d+) s)?(?:, )?" - r"(?:create tables (\d+\.\d+) s)?(?:, )?" - r"(?:client-side generate (\d+\.\d+) s)?(?:, )?" - r"(?:vacuum (\d+\.\d+) s)?(?:, )?" - r"(?:primary keys (\d+\.\d+) s)?(?:, )?" - r"\)\." - ) - - if (m := regex.match(last_line)) is not None: + if (m := cls.REGEX.match(last_line)) is not None: total, drop_tables, create_tables, client_side_generate, vacuum, primary_keys = [ float(v) for v in m.groups() if v is not None ] @@ -208,7 +206,7 @@ class NeonBenchmarker: function by the zenbenchmark fixture """ - def __init__(self, property_recorder): + def __init__(self, property_recorder: Callable[[str, object], None]): # property recorder here is a pytest fixture provided by junitxml module # https://docs.pytest.org/en/6.2.x/reference.html#pytest.junitxml.record_property self.property_recorder = property_recorder @@ -236,7 +234,7 @@ class NeonBenchmarker: ) @contextmanager - def record_duration(self, metric_name: str): + def record_duration(self, metric_name: str) -> Iterator[None]: """ Record a duration. 
Usage: @@ -337,21 +335,21 @@ class NeonBenchmarker: f"{prefix}.{metric}", value, unit="s", report=MetricReport.LOWER_IS_BETTER ) - def get_io_writes(self, pageserver) -> int: + def get_io_writes(self, pageserver: NeonPageserver) -> int: """ Fetch the "cumulative # of bytes written" metric from the pageserver """ metric_name = r'libmetrics_disk_io_bytes_total{io_operation="write"}' return self.get_int_counter_value(pageserver, metric_name) - def get_peak_mem(self, pageserver) -> int: + def get_peak_mem(self, pageserver: NeonPageserver) -> int: """ Fetch the "maxrss" metric from the pageserver """ metric_name = r"libmetrics_maxrss_kb" return self.get_int_counter_value(pageserver, metric_name) - def get_int_counter_value(self, pageserver, metric_name) -> int: + def get_int_counter_value(self, pageserver: NeonPageserver, metric_name: str) -> int: """Fetch the value of given int counter from pageserver metrics.""" # TODO: If we start to collect more of the prometheus metrics in the # performance test suite like this, we should refactor this to load and @@ -365,7 +363,9 @@ class NeonBenchmarker: assert matches, f"metric {metric_name} not found" return int(round(float(matches.group(1)))) - def get_timeline_size(self, repo_dir: Path, tenant_id: TenantId, timeline_id: TimelineId): + def get_timeline_size( + self, repo_dir: Path, tenant_id: TenantId, timeline_id: TimelineId + ) -> int: """ Calculate the on-disk size of a timeline """ @@ -379,7 +379,9 @@ class NeonBenchmarker: return totalbytes @contextmanager - def record_pageserver_writes(self, pageserver, metric_name): + def record_pageserver_writes( + self, pageserver: NeonPageserver, metric_name: str + ) -> Iterator[None]: """ Record bytes written by the pageserver during a test. """ @@ -396,7 +398,7 @@ class NeonBenchmarker: @pytest.fixture(scope="function") -def zenbenchmark(record_property) -> Iterator[NeonBenchmarker]: +def zenbenchmark(record_property: Callable[[str, object], None]) -> Iterator[NeonBenchmarker]: """ This is a python decorator for benchmark fixtures. It contains functions for recording measurements, and prints them out at the end. 
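A note on the typing idiom at work in these hunks: a generator function decorated with @contextmanager is annotated as yielding via Iterator[None], while a plain wrapper that returns the manager unchanged (as compare_fixtures.py does further down) must name the concrete _GeneratorContextManager[None] type. A minimal runnable sketch of the idiom, where the Recorder class is a hypothetical stand-in and not part of the fixtures:

from contextlib import _GeneratorContextManager, contextmanager
from timeit import default_timer
from typing import Iterator


class Recorder:
    """Hypothetical stand-in for NeonBenchmarker, for illustration only."""

    @contextmanager
    def record_duration(self, metric_name: str) -> Iterator[None]:
        # The generator is annotated as Iterator[None]; @contextmanager
        # turns each call into a _GeneratorContextManager[None].
        start = default_timer()
        yield
        print(f"{metric_name}: {default_timer() - start:.6f} s")

    def record_duration_alias(self, metric_name: str) -> _GeneratorContextManager[None]:
        # An undecorated method that just forwards the manager has to spell
        # out the concrete (private) type; subscripting it at runtime
        # requires Python 3.9+.
        return self.record_duration(metric_name)


with Recorder().record_duration("test_query"):
    sum(range(1_000_000))

Naming the private _GeneratorContextManager type is a pragmatic trade-off: it keeps mypy strict without forcing every call site through the decorator.
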
@@ -405,7 +407,7 @@ def zenbenchmark(record_property) -> Iterator[NeonBenchmarker]: yield benchmarker -def pytest_addoption(parser): +def pytest_addoption(parser: Parser): parser.addoption( "--out-dir", dest="out_dir", @@ -429,7 +431,9 @@ def get_out_path(target_dir: Path, revision: str) -> Path: # Hook to print the results at the end @pytest.hookimpl(hookwrapper=True) -def pytest_terminal_summary(terminalreporter: TerminalReporter, exitstatus: int, config: Config): +def pytest_terminal_summary( + terminalreporter: TerminalReporter, exitstatus: int, config: Config +) -> Iterator[None]: yield revision = os.getenv("GITHUB_SHA", "local") platform = os.getenv("PLATFORM", "local") diff --git a/test_runner/fixtures/compare_fixtures.py b/test_runner/fixtures/compare_fixtures.py index 2d36d90bd6..291f924379 100644 --- a/test_runner/fixtures/compare_fixtures.py +++ b/test_runner/fixtures/compare_fixtures.py @@ -1,10 +1,11 @@ from abc import ABC, abstractmethod -from contextlib import contextmanager +from contextlib import _GeneratorContextManager, contextmanager # Type-related stuff -from typing import Dict, List +from typing import Dict, Iterator, List import pytest +from _pytest.fixtures import FixtureRequest from fixtures.benchmark_fixture import MetricReport, NeonBenchmarker from fixtures.neon_fixtures import NeonEnv, PgBin, PgProtocol, RemotePostgres, VanillaPostgres from fixtures.pg_stats import PgStatTable @@ -28,19 +29,20 @@ class PgCompare(ABC): pass @property + @abstractmethod def zenbenchmark(self) -> NeonBenchmarker: pass @abstractmethod - def flush(self) -> None: + def flush(self): pass @abstractmethod - def report_peak_memory_use(self) -> None: + def report_peak_memory_use(self): pass @abstractmethod - def report_size(self) -> None: + def report_size(self): pass @contextmanager @@ -54,7 +56,7 @@ class PgCompare(ABC): pass @contextmanager - def record_pg_stats(self, pg_stats: List[PgStatTable]): + def record_pg_stats(self, pg_stats: List[PgStatTable]) -> Iterator[None]: init_data = self._retrieve_pg_stats(pg_stats) yield @@ -84,7 +86,11 @@ class NeonCompare(PgCompare): """PgCompare interface for the neon stack.""" def __init__( - self, zenbenchmark: NeonBenchmarker, neon_simple_env: NeonEnv, pg_bin: PgBin, branch_name + self, + zenbenchmark: NeonBenchmarker, + neon_simple_env: NeonEnv, + pg_bin: PgBin, + branch_name: str, ): self.env = neon_simple_env self._zenbenchmark = zenbenchmark @@ -97,15 +103,15 @@ class NeonCompare(PgCompare): self.timeline = self.pg.safe_psql("SHOW neon.timeline_id")[0][0] @property - def pg(self): + def pg(self) -> PgProtocol: return self._pg @property - def zenbenchmark(self): + def zenbenchmark(self) -> NeonBenchmarker: return self._zenbenchmark @property - def pg_bin(self): + def pg_bin(self) -> PgBin: return self._pg_bin def flush(self): @@ -114,7 +120,7 @@ class NeonCompare(PgCompare): def compact(self): self.pageserver_http_client.timeline_compact(self.env.initial_tenant, self.timeline) - def report_peak_memory_use(self) -> None: + def report_peak_memory_use(self): self.zenbenchmark.record( "peak_mem", self.zenbenchmark.get_peak_mem(self.env.pageserver) / 1024, @@ -122,7 +128,7 @@ class NeonCompare(PgCompare): report=MetricReport.LOWER_IS_BETTER, ) - def report_size(self) -> None: + def report_size(self): timeline_size = self.zenbenchmark.get_timeline_size( self.env.repo_dir, self.env.initial_tenant, self.timeline ) @@ -144,17 +150,17 @@ class NeonCompare(PgCompare): "num_files_uploaded", total_files, "", report=MetricReport.LOWER_IS_BETTER ) - def 
record_pageserver_writes(self, out_name): + def record_pageserver_writes(self, out_name: str) -> _GeneratorContextManager[None]: return self.zenbenchmark.record_pageserver_writes(self.env.pageserver, out_name) - def record_duration(self, out_name): + def record_duration(self, out_name: str) -> _GeneratorContextManager[None]: return self.zenbenchmark.record_duration(out_name) class VanillaCompare(PgCompare): """PgCompare interface for vanilla postgres.""" - def __init__(self, zenbenchmark, vanilla_pg: VanillaPostgres): + def __init__(self, zenbenchmark: NeonBenchmarker, vanilla_pg: VanillaPostgres): self._pg = vanilla_pg self._zenbenchmark = zenbenchmark vanilla_pg.configure( @@ -170,24 +176,24 @@ class VanillaCompare(PgCompare): self.cur = self.conn.cursor() @property - def pg(self): + def pg(self) -> PgProtocol: return self._pg @property - def zenbenchmark(self): + def zenbenchmark(self) -> NeonBenchmarker: return self._zenbenchmark @property - def pg_bin(self): + def pg_bin(self) -> PgBin: return self._pg.pg_bin def flush(self): self.cur.execute("checkpoint") - def report_peak_memory_use(self) -> None: + def report_peak_memory_use(self): pass # TODO find something - def report_size(self) -> None: + def report_size(self): data_size = self.pg.get_subdir_size("base") self.zenbenchmark.record( "data_size", data_size / (1024 * 1024), "MB", report=MetricReport.LOWER_IS_BETTER @@ -198,17 +204,17 @@ class VanillaCompare(PgCompare): ) @contextmanager - def record_pageserver_writes(self, out_name): + def record_pageserver_writes(self, out_name: str) -> Iterator[None]: yield # Do nothing - def record_duration(self, out_name): + def record_duration(self, out_name: str) -> _GeneratorContextManager[None]: return self.zenbenchmark.record_duration(out_name) class RemoteCompare(PgCompare): """PgCompare interface for a remote postgres instance.""" - def __init__(self, zenbenchmark, remote_pg: RemotePostgres): + def __init__(self, zenbenchmark: NeonBenchmarker, remote_pg: RemotePostgres): self._pg = remote_pg self._zenbenchmark = zenbenchmark @@ -217,55 +223,60 @@ class RemoteCompare(PgCompare): self.cur = self.conn.cursor() @property - def pg(self): + def pg(self) -> PgProtocol: return self._pg @property - def zenbenchmark(self): + def zenbenchmark(self) -> NeonBenchmarker: return self._zenbenchmark @property - def pg_bin(self): + def pg_bin(self) -> PgBin: return self._pg.pg_bin def flush(self): # TODO: flush the remote pageserver pass - def report_peak_memory_use(self) -> None: + def report_peak_memory_use(self): # TODO: get memory usage from remote pageserver pass - def report_size(self) -> None: + def report_size(self): # TODO: get storage size from remote pageserver pass @contextmanager - def record_pageserver_writes(self, out_name): + def record_pageserver_writes(self, out_name: str) -> Iterator[None]: yield # Do nothing - def record_duration(self, out_name): + def record_duration(self, out_name: str) -> _GeneratorContextManager[None]: return self.zenbenchmark.record_duration(out_name) @pytest.fixture(scope="function") -def neon_compare(request, zenbenchmark, pg_bin, neon_simple_env) -> NeonCompare: +def neon_compare( + request: FixtureRequest, + zenbenchmark: NeonBenchmarker, + pg_bin: PgBin, + neon_simple_env: NeonEnv, +) -> NeonCompare: branch_name = request.node.name return NeonCompare(zenbenchmark, neon_simple_env, pg_bin, branch_name) @pytest.fixture(scope="function") -def vanilla_compare(zenbenchmark, vanilla_pg) -> VanillaCompare: +def vanilla_compare(zenbenchmark: NeonBenchmarker, 
vanilla_pg: VanillaPostgres) -> VanillaCompare: return VanillaCompare(zenbenchmark, vanilla_pg) @pytest.fixture(scope="function") -def remote_compare(zenbenchmark, remote_pg) -> RemoteCompare: +def remote_compare(zenbenchmark: NeonBenchmarker, remote_pg: RemotePostgres) -> RemoteCompare: return RemoteCompare(zenbenchmark, remote_pg) @pytest.fixture(params=["vanilla_compare", "neon_compare"], ids=["vanilla", "neon"]) -def neon_with_baseline(request) -> PgCompare: +def neon_with_baseline(request: FixtureRequest) -> PgCompare: """Parameterized fixture that helps compare neon against vanilla postgres. A test that uses this fixture turns into a parameterized test that runs against: @@ -286,8 +297,6 @@ def neon_with_baseline(request) -> PgCompare: implementation-specific logic is widely useful across multiple tests, it might make sense to add methods to the PgCompare class. """ - fixture = request.getfixturevalue(request.param) - if isinstance(fixture, PgCompare): - return fixture - else: - raise AssertionError(f"test error: fixture {request.param} is not PgCompare") + fixture = request.getfixturevalue(request.param) # type: ignore + assert isinstance(fixture, PgCompare), f"test error: fixture {fixture} is not PgCompare" + return fixture diff --git a/test_runner/fixtures/metrics.py b/test_runner/fixtures/metrics.py index 62e3cbbe99..86ab4425ed 100644 --- a/test_runner/fixtures/metrics.py +++ b/test_runner/fixtures/metrics.py @@ -1,5 +1,5 @@ from collections import defaultdict -from typing import Dict, List +from typing import Dict, List, Optional, Tuple from prometheus_client.parser import text_string_to_metric_families from prometheus_client.samples import Sample @@ -23,13 +23,13 @@ class Metrics: pass return res - def query_one(self, name: str, filter: Dict[str, str] = {}) -> Sample: - res = self.query_all(name, filter) + def query_one(self, name: str, filter: Optional[Dict[str, str]] = None) -> Sample: + res = self.query_all(name, filter or {}) assert len(res) == 1, f"expected single sample for {name} {filter}, found {res}" return res[0] -def parse_metrics(text: str, name: str = ""): +def parse_metrics(text: str, name: str = "") -> Metrics: metrics = Metrics(name) gen = text_string_to_metric_families(text) for family in gen: @@ -39,7 +39,7 @@ def parse_metrics(text: str, name: str = ""): return metrics -PAGESERVER_PER_TENANT_METRICS = [ +PAGESERVER_PER_TENANT_METRICS: Tuple[str, ...] 
= ( "pageserver_current_logical_size", "pageserver_current_physical_size", "pageserver_getpage_reconstruct_seconds_bucket", @@ -62,4 +62,4 @@ PAGESERVER_PER_TENANT_METRICS = [ "pageserver_wait_lsn_seconds_sum", "pageserver_created_persistent_files_total", "pageserver_written_persistent_bytes_total", -] +) diff --git a/test_runner/fixtures/neon_fixtures.py b/test_runner/fixtures/neon_fixtures.py index 7a46a08f08..f68c6a25db 100644 --- a/test_runner/fixtures/neon_fixtures.py +++ b/test_runner/fixtures/neon_fixtures.py @@ -19,7 +19,8 @@ from dataclasses import dataclass, field from enum import Flag, auto from functools import cached_property from pathlib import Path -from typing import Any, Dict, Iterator, List, Optional, Tuple, Union, cast +from types import TracebackType +from typing import Any, Dict, Iterator, List, Optional, Tuple, Type, Union, cast import asyncpg import backoff # type: ignore @@ -28,16 +29,18 @@ import jwt import psycopg2 import pytest import requests +from _pytest.config import Config +from _pytest.fixtures import FixtureRequest from fixtures.log_helper import log from fixtures.types import Lsn, TenantId, TimelineId +from fixtures.utils import Fn, allure_attach_from_dir, etcd_path, get_self_dir, subprocess_capture # Type-related stuff from psycopg2.extensions import connection as PgConnection +from psycopg2.extensions import cursor as PgCursor from psycopg2.extensions import make_dsn, parse_dsn from typing_extensions import Literal -from .utils import Fn, allure_attach_from_dir, etcd_path, get_self_dir, subprocess_capture - """ This file contains pytest fixtures. A fixture is a test resource that can be summoned by placing its name in the test's arguments. @@ -57,15 +60,15 @@ put directly-importable functions into utils.py or another separate file. Env = Dict[str, str] -DEFAULT_OUTPUT_DIR = "test_output" -DEFAULT_BRANCH_NAME = "main" -DEFAULT_PG_VERSION_DEFAULT = "14" +DEFAULT_OUTPUT_DIR: str = "test_output" +DEFAULT_BRANCH_NAME: str = "main" +DEFAULT_PG_VERSION_DEFAULT: str = "14" -BASE_PORT = 15000 -WORKER_PORT_NUM = 1000 +BASE_PORT: int = 15000 +WORKER_PORT_NUM: int = 1000 -def pytest_configure(config): +def pytest_configure(config: Config): """ Check that we do not overflow available ports range. """ @@ -154,14 +157,14 @@ def versioned_pg_distrib_dir(pg_distrib_dir: Path, pg_version: str) -> Iterator[ if not psql_bin_path.exists(): raise Exception(f"psql not found at '{psql_bin_path}'") else: - if not postgres_bin_path.exists: + if not postgres_bin_path.exists(): raise Exception(f"postgres not found at '{postgres_bin_path}'") log.info(f"versioned_pg_distrib_dir is {versioned_dir}") yield versioned_dir -def shareable_scope(fixture_name, config) -> Literal["session", "function"]: +def shareable_scope(fixture_name: str, config: Config) -> Literal["session", "function"]: """Return either session of function scope, depending on TEST_SHARED_FIXTURES envvar. 
This function can be used as a scope like this: @@ -173,7 +176,7 @@ def shareable_scope(fixture_name, config) -> Literal["session", "function"]: @pytest.fixture(scope="session") -def worker_seq_no(worker_id: str): +def worker_seq_no(worker_id: str) -> int: # worker_id is a pytest-xdist fixture # it can be master or gw # parse it to always get a number @@ -184,7 +187,7 @@ def worker_seq_no(worker_id: str): @pytest.fixture(scope="session") -def worker_base_port(worker_seq_no: int): +def worker_base_port(worker_seq_no: int) -> int: # so we divide ports in ranges of 100 ports # so workers have disjoint set of ports for services return BASE_PORT + worker_seq_no * WORKER_PORT_NUM @@ -234,10 +237,9 @@ class PortDistributor: for port in self.iterator: if can_bind("localhost", port): return port - else: - raise RuntimeError( - "port range configured for test is exhausted, consider enlarging the range" - ) + raise RuntimeError( + "port range configured for test is exhausted, consider enlarging the range" + ) def replace_with_new_port(self, value: Union[int, str]) -> Union[int, str]: """ @@ -273,12 +275,14 @@ class PortDistributor: @pytest.fixture(scope="session") -def port_distributor(worker_base_port): +def port_distributor(worker_base_port: int) -> PortDistributor: return PortDistributor(base_port=worker_base_port, port_number=WORKER_PORT_NUM) @pytest.fixture(scope="session") -def default_broker(request: Any, port_distributor: PortDistributor, top_output_dir: Path): +def default_broker( + request: FixtureRequest, port_distributor: PortDistributor, top_output_dir: Path +) -> Iterator[Etcd]: client_port = port_distributor.get_port() # multiple pytest sessions could get launched in parallel, get them different datadirs etcd_datadir = get_test_output_dir(request, top_output_dir) / f"etcd_datadir_{client_port}" @@ -293,12 +297,12 @@ def default_broker(request: Any, port_distributor: PortDistributor, top_output_d @pytest.fixture(scope="session") -def run_id(): +def run_id() -> Iterator[uuid.UUID]: yield uuid.uuid4() @pytest.fixture(scope="session") -def mock_s3_server(port_distributor: PortDistributor): +def mock_s3_server(port_distributor: PortDistributor) -> Iterator[MockS3Server]: mock_s3_server = MockS3Server(port_distributor.get_port()) yield mock_s3_server mock_s3_server.kill() @@ -307,16 +311,16 @@ def mock_s3_server(port_distributor: PortDistributor): class PgProtocol: """Reusable connection logic""" - def __init__(self, **kwargs): + def __init__(self, **kwargs: Any): self.default_options = kwargs - def connstr(self, **kwargs) -> str: + def connstr(self, **kwargs: Any) -> str: """ Build a libpq connection string for the Postgres instance. """ return str(make_dsn(**self.conn_options(**kwargs))) - def conn_options(self, **kwargs): + def conn_options(self, **kwargs: Any) -> Dict[str, Any]: """ Construct a dictionary of connection options from default values and extra parameters. An option can be dropped from the returning dictionary by None-valued extra parameter. @@ -338,7 +342,7 @@ class PgProtocol: return result # autocommit=True here by default because that's what we need most of the time - def connect(self, autocommit=True, **kwargs) -> PgConnection: + def connect(self, autocommit: bool = True, **kwargs: Any) -> PgConnection: """ Connect to the node. Returns psycopg2's connection object. 
@@ -351,7 +355,7 @@ class PgProtocol: return conn @contextmanager - def cursor(self, autocommit=True, **kwargs): + def cursor(self, autocommit: bool = True, **kwargs: Any) -> Iterator[PgCursor]: """ Shorthand for pg.connect().cursor(). The cursor and connection are closed when the context is exited. @@ -359,7 +363,7 @@ class PgProtocol: with closing(self.connect(autocommit=autocommit, **kwargs)) as conn: yield conn.cursor() - async def connect_async(self, **kwargs) -> asyncpg.Connection: + async def connect_async(self, **kwargs: Any) -> asyncpg.Connection: """ Connect to the node from async python. Returns asyncpg's connection object. @@ -413,10 +417,10 @@ class PgProtocol: @dataclass class AuthKeys: - pub: bytes - priv: bytes + pub: str + priv: str - def generate_management_token(self): + def generate_management_token(self) -> str: token = jwt.encode({"scope": "pageserverapi"}, self.priv, algorithm="RS256") # jwt.encode can return 'bytes' or 'str', depending on Python version or type @@ -427,9 +431,11 @@ class AuthKeys: return token - def generate_tenant_token(self, tenant_id): + def generate_tenant_token(self, tenant_id: TenantId) -> str: token = jwt.encode( - {"scope": "tenant", "tenant_id": str(tenant_id)}, self.priv, algorithm="RS256" + {"scope": "tenant", "tenant_id": str(tenant_id)}, + self.priv, + algorithm="RS256", ) if isinstance(token, bytes): @@ -485,7 +491,7 @@ class MockS3Server: @enum.unique -class RemoteStorageKind(enum.Enum): +class RemoteStorageKind(str, enum.Enum): LOCAL_FS = "local_fs" MOCK_S3 = "mock_s3" REAL_S3 = "real_s3" @@ -529,7 +535,7 @@ RemoteStorage = Union[LocalFsStorage, S3Storage] # serialize as toml inline table -def remote_storage_to_toml_inline_table(remote_storage): +def remote_storage_to_toml_inline_table(remote_storage: RemoteStorage) -> str: if isinstance(remote_storage, LocalFsStorage): remote_storage_config = f"local_path='{remote_storage.root}'" elif isinstance(remote_storage, S3Storage): @@ -582,7 +588,7 @@ class NeonEnvBuilder: safekeepers_enable_fsync: bool = False, auth_enabled: bool = False, rust_log_override: Optional[str] = None, - default_branch_name=DEFAULT_BRANCH_NAME, + default_branch_name: str = DEFAULT_BRANCH_NAME, ): self.repo_dir = repo_dir self.rust_log_override = rust_log_override @@ -636,7 +642,7 @@ class NeonEnvBuilder: else: raise RuntimeError(f"Unknown storage type: {remote_storage_kind}") - def enable_local_fs_remote_storage(self, force_enable=True): + def enable_local_fs_remote_storage(self, force_enable: bool = True): """ Sets up the pageserver to use the local fs at the `test_dir/local_fs_remote_storage` path. Errors, if the pageserver has some remote storage configuration already, unless `force_enable` is not set to `True`. @@ -644,7 +650,7 @@ class NeonEnvBuilder: assert force_enable or self.remote_storage is None, "remote storage is enabled already" self.remote_storage = LocalFsStorage(Path(self.repo_dir / "local_fs_remote_storage")) - def enable_mock_s3_remote_storage(self, bucket_name: str, force_enable=True): + def enable_mock_s3_remote_storage(self, bucket_name: str, force_enable: bool = True): """ Sets up the pageserver to use the S3 mock server, creates the bucket, if it's not present already. Starts up the mock server, if that does not run yet. 
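The hunks that follow repeat one pattern many times: every bare __enter__/__exit__ pair gains the fully typed context-manager protocol. A minimal sketch of what those signatures buy you, where ManagedResource is a hypothetical stand-in for NeonEnvBuilder, NeonPageserver, VanillaPostgres and friends:

from types import TracebackType
from typing import Optional, Type


class ManagedResource:
    """Hypothetical resource illustrating the typed context-manager protocol."""

    def __enter__(self) -> "ManagedResource":
        # Returning the concrete type (not a bare object) lets mypy type
        # the target of `with ManagedResource() as r:` precisely.
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_value: Optional[BaseException],
        traceback: Optional[TracebackType],
    ) -> None:
        # The implicit None return is falsy, so exceptions raised inside
        # the `with` block still propagate after cleanup runs.
        self.cleanup()

    def cleanup(self) -> None:
        pass


with ManagedResource():
    pass
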
@@ -671,7 +677,7 @@ class NeonEnvBuilder:
             secret_key=self.mock_s3_server.secret_key(),
         )

-    def enable_real_s3_remote_storage(self, test_name: str, force_enable=True):
+    def enable_real_s3_remote_storage(self, test_name: str, force_enable: bool = True):
         """
         Sets up configuration to use real s3 endpoint without mock server
         """
@@ -759,10 +765,15 @@ class NeonEnvBuilder:

         log.info("deleted %s objects from remote storage", cnt)

-    def __enter__(self):
+    def __enter__(self) -> "NeonEnvBuilder":
         return self

-    def __exit__(self, exc_type, exc_value, traceback):
+    def __exit__(
+        self,
+        exc_type: Optional[Type[BaseException]],
+        exc_value: Optional[BaseException],
+        traceback: Optional[TracebackType],
+    ):
         # Stop all the nodes.
         if self.env:
             log.info("Cleaning up all storage and compute nodes")
@@ -909,7 +920,7 @@ class NeonEnv:

     def get_safekeeper_connstrs(self) -> str:
         """Get list of safekeeper endpoints suitable for safekeepers GUC"""
-        return ",".join([f"localhost:{wa.port.pg}" for wa in self.safekeepers])
+        return ",".join(f"localhost:{wa.port.pg}" for wa in self.safekeepers)

     def timeline_dir(self, tenant_id: TenantId, timeline_id: TimelineId) -> Path:
         """Get a timeline directory's path based on the repo directory of the test environment"""
@@ -928,14 +939,14 @@ class NeonEnv:

     @cached_property
     def auth_keys(self) -> AuthKeys:
-        pub = (Path(self.repo_dir) / "auth_public_key.pem").read_bytes()
-        priv = (Path(self.repo_dir) / "auth_private_key.pem").read_bytes()
+        pub = (Path(self.repo_dir) / "auth_public_key.pem").read_text()
+        priv = (Path(self.repo_dir) / "auth_private_key.pem").read_text()

         return AuthKeys(pub=pub, priv=priv)


 @pytest.fixture(scope=shareable_scope)
 def _shared_simple_env(
-    request: Any,
+    request: FixtureRequest,
     port_distributor: PortDistributor,
     mock_s3_server: MockS3Server,
     default_broker: Etcd,
@@ -993,7 +1004,7 @@ def neon_simple_env(_shared_simple_env: NeonEnv) -> Iterator[NeonEnv]:

 @pytest.fixture(scope="function")
 def neon_env_builder(
-    test_output_dir,
+    test_output_dir: str,
     port_distributor: PortDistributor,
     mock_s3_server: MockS3Server,
     neon_binpath: Path,
@@ -1059,7 +1070,7 @@ class PageserverHttpClient(requests.Session):
     def check_status(self):
         self.get(f"http://localhost:{self.port}/v1/status").raise_for_status()

-    def configure_failpoints(self, config_strings: tuple[str, str] | list[tuple[str, str]]) -> None:
+    def configure_failpoints(self, config_strings: Tuple[str, str] | List[Tuple[str, str]]):
         self.is_testing_enabled_or_skip()

         if isinstance(config_strings, tuple):
@@ -1189,7 +1200,6 @@ class PageserverHttpClient(requests.Session):
         self.verbose_error(res)
         res_json = res.json()
         assert res_json is None
-        return res_json

     def timeline_gc(
         self, tenant_id: TenantId, timeline_id: TimelineId, gc_horizon: Optional[int]
@@ -1221,7 +1231,6 @@ class PageserverHttpClient(requests.Session):
         self.verbose_error(res)
         res_json = res.json()
         assert res_json is None
-        return res_json

     def timeline_get_lsn_by_timestamp(
         self, tenant_id: TenantId, timeline_id: TimelineId, timestamp
@@ -1247,7 +1256,6 @@ class PageserverHttpClient(requests.Session):
         self.verbose_error(res)
         res_json = res.json()
         assert res_json is None
-        return res_json

     def get_metrics(self) -> str:
         res = self.get(f"http://localhost:{self.port}/metrics")
@@ -1261,13 +1269,10 @@ class PageserverPort:
     http: int


-CREATE_TIMELINE_ID_EXTRACTOR = re.compile(
+CREATE_TIMELINE_ID_EXTRACTOR: re.Pattern = re.compile(  # type: ignore[type-arg]
     r"^Created timeline '(?P<timeline_id>[^']+)'", re.MULTILINE
 )
-CREATE_TIMELINE_ID_EXTRACTOR = 
re.compile(
-    r"^Created timeline '(?P<timeline_id>[^']+)'", re.MULTILINE
-)
-TIMELINE_DATA_EXTRACTOR = re.compile(
+TIMELINE_DATA_EXTRACTOR: re.Pattern = re.compile(  # type: ignore[type-arg]
     r"\s?(?P<branch_name>[^\s]+)\s\[(?P<timeline_id>[^\]]+)\]", re.MULTILINE
 )

@@ -1560,7 +1565,7 @@ class NeonCli(AbstractNeonCli):

     def pageserver_start(
         self,
-        overrides=(),
+        overrides: Tuple[str, ...] = (),
     ) -> "subprocess.CompletedProcess[str]":
         start_args = ["pageserver", "start", *overrides]
         append_pageserver_param_overrides(
@@ -1718,7 +1723,7 @@ class NeonPageserver(PgProtocol):
         self.config_override = config_override
         self.version = env.get_pageserver_version()

-    def start(self, overrides=()) -> "NeonPageserver":
+    def start(self, overrides: Tuple[str, ...] = ()) -> "NeonPageserver":
         """
         Start the page server.
         `overrides` allows to add some config to this pageserver start.
@@ -1730,7 +1735,7 @@
         self.running = True
         return self

-    def stop(self, immediate=False) -> "NeonPageserver":
+    def stop(self, immediate: bool = False) -> "NeonPageserver":
         """
         Stop the page server.
         Returns self.
@@ -1740,10 +1745,15 @@
         self.running = False
         return self

-    def __enter__(self):
+    def __enter__(self) -> "NeonPageserver":
         return self

-    def __exit__(self, exc_type, exc, tb):
+    def __exit__(
+        self,
+        exc_type: Optional[Type[BaseException]],
+        exc: Optional[BaseException],
+        tb: Optional[TracebackType],
+    ):
         self.stop(immediate=True)

     def is_testing_enabled_or_skip(self):
@@ -1855,7 +1865,7 @@ def pg_bin(test_output_dir: Path, pg_distrib_dir: Path, pg_version: str) -> PgBi


 class VanillaPostgres(PgProtocol):
-    def __init__(self, pgdatadir: Path, pg_bin: PgBin, port: int, init=True):
+    def __init__(self, pgdatadir: Path, pg_bin: PgBin, port: int, init: bool = True):
         super().__init__(host="localhost", port=port, dbname="postgres")
         self.pgdatadir = pgdatadir
         self.pg_bin = pg_bin
@@ -1890,10 +1900,15 @@ class VanillaPostgres(PgProtocol):
         """Return size of pgdatadir subdirectory in bytes."""
         return get_dir_size(os.path.join(self.pgdatadir, subdir))

-    def __enter__(self):
+    def __enter__(self) -> "VanillaPostgres":
         return self

-    def __exit__(self, exc_type, exc, tb):
+    def __exit__(
+        self,
+        exc_type: Optional[Type[BaseException]],
+        exc: Optional[BaseException],
+        tb: Optional[TracebackType],
+    ):
         if self.running:
             self.stop()

@@ -1933,10 +1948,15 @@ class RemotePostgres(PgProtocol):
         # See https://www.postgresql.org/docs/14/functions-admin.html#FUNCTIONS-ADMIN-GENFILE
         raise Exception("cannot get size of a Postgres instance")

-    def __enter__(self):
+    def __enter__(self) -> "RemotePostgres":
         return self

-    def __exit__(self, exc_type, exc, tb):
+    def __exit__(
+        self,
+        exc_type: Optional[Type[BaseException]],
+        exc: Optional[BaseException],
+        tb: Optional[TracebackType],
+    ):
         # do nothing
         pass

@@ -1975,7 +1995,7 @@ class PSQL:
         self.path = path
         self.database_url = f"postgres://{host}:{port}/main?options=project%3Dgeneric-project-name"

-    async def run(self, query=None):
+    async def run(self, query: Optional[str] = None) -> asyncio.subprocess.Process:
         run_args = [self.path, "--no-psqlrc", "--quiet", "--tuples-only", self.database_url]
         if query is not None:
             run_args += ["--command", query]
@@ -2008,7 +2028,7 @@ class NeonProxy(PgProtocol):
         self._popen: Optional[subprocess.Popen[bytes]] = None
         self.link_auth_uri_prefix = "http://dummy-uri"

-    def start(self) -> None:
+    def start(self):
         """
         Starts a proxy with option '--auth-backend postgres' and a postgres instance already provided though 
'--auth-endpoint '." """ @@ -2026,7 +2046,7 @@ class NeonProxy(PgProtocol): self._popen = subprocess.Popen(args) self._wait_until_ready() - def start_with_link_auth(self) -> None: + def start_with_link_auth(self): """ Starts a proxy with option '--auth-backend link' and a dummy authentication link '--uri dummy-auth-link'." """ @@ -2054,10 +2074,15 @@ class NeonProxy(PgProtocol): request_result.raise_for_status() return request_result.text - def __enter__(self): + def __enter__(self) -> "NeonProxy": return self - def __exit__(self, exc_type, exc, tb): + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc: Optional[BaseException], + tb: Optional[TracebackType], + ): if self._popen is not None: # NOTE the process will die when we're done with tests anyway, because # it's a child process. This is mostly to clean up in between different tests. @@ -2065,7 +2090,7 @@ class NeonProxy(PgProtocol): @pytest.fixture(scope="function") -def link_proxy(port_distributor, neon_binpath: Path) -> Iterator[NeonProxy]: +def link_proxy(port_distributor: PortDistributor, neon_binpath: Path) -> Iterator[NeonProxy]: """Neon proxy that routes through link auth.""" http_port = port_distributor.get_port() proxy_port = port_distributor.get_port() @@ -2076,7 +2101,9 @@ def link_proxy(port_distributor, neon_binpath: Path) -> Iterator[NeonProxy]: @pytest.fixture(scope="function") -def static_proxy(vanilla_pg, port_distributor, neon_binpath: Path) -> Iterator[NeonProxy]: +def static_proxy( + vanilla_pg: VanillaPostgres, port_distributor: PortDistributor, neon_binpath: Path +) -> Iterator[NeonProxy]: """Neon proxy that routes directly to vanilla postgres.""" # For simplicity, we use the same user for both `--auth-endpoint` and `safe_psql` @@ -2276,10 +2303,15 @@ class Postgres(PgProtocol): return self - def __enter__(self): + def __enter__(self) -> "Postgres": return self - def __exit__(self, exc_type, exc, tb): + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc: Optional[BaseException], + tb: Optional[TracebackType], + ): self.stop() @@ -2288,7 +2320,7 @@ class PostgresFactory: def __init__(self, env: NeonEnv): self.env = env - self.num_instances = 0 + self.num_instances: int = 0 self.instances: List[Postgres] = [] def create_start( @@ -2383,7 +2415,7 @@ class Safekeeper: break # success return self - def stop(self, immediate=False) -> "Safekeeper": + def stop(self, immediate: bool = False) -> "Safekeeper": log.info("Stopping safekeeper {}".format(self.id)) self.env.neon_cli.safekeeper_stop(self.id, immediate) self.running = False @@ -2598,7 +2630,7 @@ class Etcd: self.handle.wait() -def get_test_output_dir(request: Any, top_output_dir: Path) -> Path: +def get_test_output_dir(request: FixtureRequest, top_output_dir: Path) -> Path: """Compute the working directory for an individual test.""" test_name = request.node.name test_dir = top_output_dir / test_name.replace("/", "-") @@ -2618,7 +2650,7 @@ def get_test_output_dir(request: Any, top_output_dir: Path) -> Path: # this fixture ensures that the directory exists. That works because # 'autouse' fixtures are run before other fixtures. 
@pytest.fixture(scope="function", autouse=True) -def test_output_dir(request: Any, top_output_dir: Path) -> Iterator[Path]: +def test_output_dir(request: FixtureRequest, top_output_dir: Path) -> Iterator[Path]: """Create the working directory for an individual test.""" # one directory per test @@ -2682,7 +2714,7 @@ def should_skip_file(filename: str) -> bool: # # Test helpers # -def list_files_to_compare(pgdata_dir: Path): +def list_files_to_compare(pgdata_dir: Path) -> List[str]: pgdata_files = [] for root, _file, filenames in os.walk(pgdata_dir): for filename in filenames: diff --git a/test_runner/fixtures/pg_stats.py b/test_runner/fixtures/pg_stats.py index b2e6886eb3..adb3a7730e 100644 --- a/test_runner/fixtures/pg_stats.py +++ b/test_runner/fixtures/pg_stats.py @@ -1,3 +1,4 @@ +from functools import cached_property from typing import List import pytest @@ -13,7 +14,7 @@ class PgStatTable: self.columns = columns self.additional_query = filter_query - @property + @cached_property def query(self) -> str: return f"SELECT {','.join(self.columns)} FROM {self.table} {self.additional_query}" @@ -55,6 +56,5 @@ def pg_stats_wal() -> List[PgStatTable]: PgStatTable( "pg_stat_wal", ["wal_records", "wal_fpi", "wal_bytes", "wal_buffers_full", "wal_write"], - "", ) ] diff --git a/test_runner/fixtures/slow.py b/test_runner/fixtures/slow.py index 94199ae785..ae0e87b553 100644 --- a/test_runner/fixtures/slow.py +++ b/test_runner/fixtures/slow.py @@ -1,4 +1,8 @@ +from typing import Any, List + import pytest +from _pytest.config import Config +from _pytest.config.argparsing import Parser """ This plugin allows tests to be marked as slow using pytest.mark.slow. By default slow @@ -9,15 +13,15 @@ Copied from here: https://docs.pytest.org/en/latest/example/simple.html """ -def pytest_addoption(parser): +def pytest_addoption(parser: Parser): parser.addoption("--runslow", action="store_true", default=False, help="run slow tests") -def pytest_configure(config): +def pytest_configure(config: Config): config.addinivalue_line("markers", "slow: mark test as slow to run") -def pytest_collection_modifyitems(config, items): +def pytest_collection_modifyitems(config: Config, items: List[Any]): if config.getoption("--runslow"): # --runslow given in cli: do not skip slow tests return diff --git a/test_runner/fixtures/types.py b/test_runner/fixtures/types.py index de2e131b79..2bb962d44a 100644 --- a/test_runner/fixtures/types.py +++ b/test_runner/fixtures/types.py @@ -1,6 +1,8 @@ import random from functools import total_ordering -from typing import Union +from typing import Any, Type, TypeVar, Union + +T = TypeVar("T", bound="Id") @total_ordering @@ -17,31 +19,35 @@ class Lsn: """Convert lsn from hex notation to int.""" l, r = x.split("/") self.lsn_int = (int(l, 16) << 32) + int(r, 16) - # FIXME: error if it doesn't look like a valid LSN + assert 0 <= self.lsn_int <= 0xFFFFFFFF_FFFFFFFF - def __str__(self): + def __str__(self) -> str: """Convert lsn from int to standard hex notation.""" - return "{:X}/{:X}".format(self.lsn_int >> 32, self.lsn_int & 0xFFFFFFFF) + return f"{(self.lsn_int >> 32):X}/{(self.lsn_int & 0xFFFFFFFF):X}" - def __repr__(self): - return 'Lsn("{:X}/{:X}")'.format(self.lsn_int >> 32, self.lsn_int & 0xFFFFFFFF) + def __repr__(self) -> str: + return f'Lsn("{str(self)}")' - def __int__(self): + def __int__(self) -> int: return self.lsn_int - def __lt__(self, other: "Lsn") -> bool: + def __lt__(self, other: Any) -> bool: + if not isinstance(other, Lsn): + return NotImplemented return self.lsn_int < 
other.lsn_int - def __eq__(self, other) -> bool: + def __eq__(self, other: Any) -> bool: if not isinstance(other, Lsn): return NotImplemented return self.lsn_int == other.lsn_int # Returns the difference between two Lsns, in bytes - def __sub__(self, other: "Lsn") -> int: + def __sub__(self, other: Any) -> int: + if not isinstance(other, Lsn): + return NotImplemented return self.lsn_int - other.lsn_int - def __hash__(self): + def __hash__(self) -> int: return hash(self.lsn_int) @@ -57,7 +63,7 @@ class Id: self.id = bytearray.fromhex(x) assert len(self.id) == 16 - def __str__(self): + def __str__(self) -> str: return self.id.hex() def __lt__(self, other) -> bool: @@ -70,20 +76,20 @@ class Id: return NotImplemented return self.id == other.id - def __hash__(self): + def __hash__(self) -> int: return hash(str(self.id)) @classmethod - def generate(cls): + def generate(cls: Type[T]) -> T: """Generate a random ID""" return cls(random.randbytes(16).hex()) class TenantId(Id): - def __repr__(self): + def __repr__(self) -> str: return f'`TenantId("{self.id.hex()}")' class TimelineId(Id): - def __repr__(self): + def __repr__(self) -> str: return f'TimelineId("{self.id.hex()}")' diff --git a/test_runner/fixtures/utils.py b/test_runner/fixtures/utils.py index b04e02d3b8..506fe6f9da 100644 --- a/test_runner/fixtures/utils.py +++ b/test_runner/fixtures/utils.py @@ -6,7 +6,7 @@ import subprocess import tarfile import time from pathlib import Path -from typing import Any, Callable, List, Tuple, TypeVar +from typing import Any, Callable, Dict, List, Tuple, TypeVar import allure # type: ignore from fixtures.log_helper import log @@ -30,11 +30,11 @@ def subprocess_capture(capture_dir: Path, cmd: List[str], **kwargs: Any) -> str: If those files already exist, we will overwrite them. Returns basepath for files with captured output. 
""" - assert type(cmd) is list - base = os.path.basename(cmd[0]) + "_{}".format(global_counter()) + assert isinstance(cmd, list) + base = f"{os.path.basename(cmd[0])}_{global_counter()}" basepath = os.path.join(capture_dir, base) - stdout_filename = basepath + ".stdout" - stderr_filename = basepath + ".stderr" + stdout_filename = f"{basepath}.stdout" + stderr_filename = f"{basepath}.stderr" try: with open(stdout_filename, "w") as stdout_f: @@ -64,7 +64,7 @@ def global_counter() -> int: return _global_counter -def print_gc_result(row): +def print_gc_result(row: Dict[str, Any]): log.info("GC duration {elapsed} ms".format_map(row)) log.info( " total: {layers_total}, needed_by_cutoff {layers_needed_by_cutoff}, needed_by_pitr {layers_needed_by_pitr}" @@ -78,8 +78,7 @@ def etcd_path() -> Path: path_output = shutil.which("etcd") if path_output is None: raise RuntimeError("etcd not found in PATH") - else: - return Path(path_output) + return Path(path_output) def query_scalar(cur: cursor, query: str) -> Any: @@ -124,7 +123,6 @@ def get_timeline_dir_size(path: Path) -> int: # file is a delta layer _ = parse_delta_layer(dir_entry.name) sz += dir_entry.stat().st_size - continue return sz @@ -157,8 +155,8 @@ def get_scale_for_db(size_mb: int) -> int: return round(0.06689 * size_mb - 0.5) -ATTACHMENT_NAME_REGEX = re.compile( - r".+\.log|.+\.stderr|.+\.stdout|.+\.filediff|.+\.metrics|flamegraph\.svg|regression\.diffs|.+\.html" +ATTACHMENT_NAME_REGEX: re.Pattern = re.compile( # type: ignore[type-arg] + r"flamegraph\.svg|regression\.diffs|.+\.(?:log|stderr|stdout|filediff|metrics|html)" ) diff --git a/test_runner/performance/test_wal_backpressure.py b/test_runner/performance/test_wal_backpressure.py index 47e2435052..cb35cad46b 100644 --- a/test_runner/performance/test_wal_backpressure.py +++ b/test_runner/performance/test_wal_backpressure.py @@ -2,7 +2,7 @@ import statistics import threading import time import timeit -from typing import Callable +from typing import Any, Callable, List import pytest from fixtures.benchmark_fixture import MetricReport, NeonBenchmarker @@ -197,7 +197,7 @@ def record_lsn_write_lag(env: PgCompare, run_cond: Callable[[], bool], pool_inte if not isinstance(env, NeonCompare): return - lsn_write_lags = [] + lsn_write_lags: List[Any] = [] last_received_lsn = Lsn(0) last_pg_flush_lsn = Lsn(0) @@ -216,6 +216,7 @@ def record_lsn_write_lag(env: PgCompare, run_cond: Callable[[], bool], pool_inte ) res = cur.fetchone() + assert isinstance(res, list) lsn_write_lags.append(res[0]) curr_received_lsn = Lsn(res[3]) diff --git a/test_runner/pg_clients/python/asyncpg/asyncpg_example.py b/test_runner/pg_clients/python/asyncpg/asyncpg_example.py index 7f579ce672..4d9dfb09c1 100755 --- a/test_runner/pg_clients/python/asyncpg/asyncpg_example.py +++ b/test_runner/pg_clients/python/asyncpg/asyncpg_example.py @@ -24,7 +24,6 @@ if __name__ == "__main__": if (v := os.environ.get(k, None)) is not None } - loop = asyncio.new_event_loop() - row = loop.run_until_complete(run(**kwargs)) + row = asyncio.run(run(**kwargs)) print(row[0]) diff --git a/test_runner/regress/test_proxy.py b/test_runner/regress/test_proxy.py index bd02841dc0..b4647ebbe9 100644 --- a/test_runner/regress/test_proxy.py +++ b/test_runner/regress/test_proxy.py @@ -129,6 +129,7 @@ async def test_psql_session_id(vanilla_pg: VanillaPostgres, link_proxy: NeonProx create_and_send_db_info(vanilla_pg, psql_session_id, link_proxy.mgmt_port) + assert proc.stdout is not None out = (await proc.stdout.read()).decode("utf-8").strip() assert out == "42"