Mirror of https://github.com/neondatabase/neon.git (synced 2026-01-16 18:02:56 +00:00)

Stricter mypy linters for test_runner/fixtures/*

Commit: f720dd735e (parent: c4f9f1dc6d), committed by Vadim Kharitonov
@@ -11,39 +11,37 @@ from datetime import datetime
from pathlib import Path

# Type-related stuff
from typing import Iterator, Optional
from typing import Callable, ClassVar, Iterator, Optional

import pytest
from _pytest.config import Config
from _pytest.config.argparsing import Parser
from _pytest.terminal import TerminalReporter
from fixtures.neon_fixtures import NeonPageserver
from fixtures.types import TenantId, TimelineId

"""
This file contains fixtures for micro-benchmarks.

To use, declare the 'zenbenchmark' fixture in the test function. Run the
bencmark, and then record the result by calling zenbenchmark.record. For example:
To use, declare the `zenbenchmark` fixture in the test function. Run the
bencmark, and then record the result by calling `zenbenchmark.record`. For example:

import timeit
from fixtures.neon_fixtures import NeonEnv

def test_mybench(neon_simple_env: env, zenbenchmark):

# Initialize the test
...

# Run the test, timing how long it takes
with zenbenchmark.record_duration('test_query'):
cur.execute('SELECT test_query(...)')

# Record another measurement
zenbenchmark.record('speed_of_light', 300000, 'km/s')
>>> import timeit
>>> from fixtures.neon_fixtures import NeonEnv
>>> def test_mybench(neon_simple_env: NeonEnv, zenbenchmark):
... # Initialize the test
... ...
... # Run the test, timing how long it takes
... with zenbenchmark.record_duration('test_query'):
... cur.execute('SELECT test_query(...)')
... # Record another measurement
... zenbenchmark.record('speed_of_light', 300000, 'km/s')

There's no need to import this file to use it. It should be declared as a plugin
inside conftest.py, and that makes it available to all tests.
inside `conftest.py`, and that makes it available to all tests.

You can measure multiple things in one test, and record each one with a separate
call to zenbenchmark. For example, you could time the bulk loading that happens
call to `zenbenchmark`. For example, you could time the bulk loading that happens
in the test initialization, or measure disk usage after the test query.

"""

@@ -117,7 +115,7 @@ class PgBenchRunResult:
# tps = 309.281539 (without initial connection time)
if line.startswith("tps = ") and (
"(excluding connections establishing)" in line
or "(without initial connection time)"
or "(without initial connection time)" in line
):
tps = float(line.split()[2])
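The hunk above also fixes a subtle truthiness bug: in the old condition the bare string "(without initial connection time)" is always truthy, so any "tps = ..." line matched regardless of its suffix. A minimal, self-contained illustration (the sample line is invented for the example):

    line = "tps = 100.000000 (including connections establishing)"

    buggy = line.startswith("tps = ") and (
        "(excluding connections establishing)" in line
        or "(without initial connection time)"          # non-empty string: always truthy
    )
    fixed = line.startswith("tps = ") and (
        "(excluding connections establishing)" in line
        or "(without initial connection time)" in line  # proper membership test
    )
    print(bool(buggy), fixed)  # True False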
@@ -137,6 +135,17 @@ class PgBenchRunResult:

@dataclasses.dataclass
class PgBenchInitResult:
REGEX: ClassVar[re.Pattern] = re.compile( # type: ignore[type-arg]
r"done in (\d+\.\d+) s "
r"\("
r"(?:drop tables (\d+\.\d+) s)?(?:, )?"
r"(?:create tables (\d+\.\d+) s)?(?:, )?"
r"(?:client-side generate (\d+\.\d+) s)?(?:, )?"
r"(?:vacuum (\d+\.\d+) s)?(?:, )?"
r"(?:primary keys (\d+\.\d+) s)?(?:, )?"
r"\)\."
)

total: float
drop_tables: Optional[float]
create_tables: Optional[float]
@@ -160,18 +169,7 @@ class PgBenchInitResult:

last_line = stderr.splitlines()[-1]

regex = re.compile(
r"done in (\d+\.\d+) s "
r"\("
r"(?:drop tables (\d+\.\d+) s)?(?:, )?"
r"(?:create tables (\d+\.\d+) s)?(?:, )?"
r"(?:client-side generate (\d+\.\d+) s)?(?:, )?"
r"(?:vacuum (\d+\.\d+) s)?(?:, )?"
r"(?:primary keys (\d+\.\d+) s)?(?:, )?"
r"\)\."
)

if (m := regex.match(last_line)) is not None:
if (m := cls.REGEX.match(last_line)) is not None:
total, drop_tables, create_tables, client_side_generate, vacuum, primary_keys = [
float(v) for v in m.groups() if v is not None
]
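For reference, a sketch of how a pattern in the style of PgBenchInitResult.REGEX above pulls the phase timings out of the final stderr line of `pgbench -i`. The sample line mimics typical pgbench 14 output, and the unpacking here differs slightly from the code above in that it keeps None for phases that did not run:

    import re

    REGEX = re.compile(
        r"done in (\d+\.\d+) s "
        r"\("
        r"(?:drop tables (\d+\.\d+) s)?(?:, )?"
        r"(?:create tables (\d+\.\d+) s)?(?:, )?"
        r"(?:client-side generate (\d+\.\d+) s)?(?:, )?"
        r"(?:vacuum (\d+\.\d+) s)?(?:, )?"
        r"(?:primary keys (\d+\.\d+) s)?(?:, )?"
        r"\)\."
    )

    last_line = (
        "done in 4.53 s (drop tables 0.01 s, create tables 0.02 s, "
        "client-side generate 2.92 s, vacuum 0.64 s, primary keys 0.94 s)."
    )

    m = REGEX.match(last_line)
    assert m is not None
    total, drop_tables, create_tables, generate, vacuum, primary_keys = (
        float(v) if v is not None else None for v in m.groups()
    )
    print(total, vacuum)  # 4.53 0.64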
@@ -208,7 +206,7 @@ class NeonBenchmarker:
function by the zenbenchmark fixture
"""

def __init__(self, property_recorder):
def __init__(self, property_recorder: Callable[[str, object], None]):
# property recorder here is a pytest fixture provided by junitxml module
# https://docs.pytest.org/en/6.2.x/reference.html#pytest.junitxml.record_property
self.property_recorder = property_recorder

@@ -236,7 +234,7 @@ class NeonBenchmarker:
)

@contextmanager
def record_duration(self, metric_name: str):
def record_duration(self, metric_name: str) -> Iterator[None]:
"""
Record a duration. Usage:
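The Iterator[None] annotation above reflects that record_duration is a generator function wrapped by @contextmanager. The body of the method is not shown in this hunk, so the following is only a plausible, runnable sketch of the same shape, with a stand-in recorder (the class and callback names are illustrative, not the fixture's real internals):

    import timeit
    from contextlib import contextmanager
    from typing import Callable, Iterator

    class BenchmarkerSketch:
        def __init__(self, record: Callable[[str, float], None]):
            self.record = record

        @contextmanager
        def record_duration(self, metric_name: str) -> Iterator[None]:
            # The generator yields exactly once; @contextmanager turns it into
            # a context manager, hence the Iterator[None] return annotation.
            start = timeit.default_timer()
            yield
            self.record(metric_name, timeit.default_timer() - start)

    bench = BenchmarkerSketch(lambda name, value: print(f"{name}: {value:.6f} s"))
    with bench.record_duration("busy_loop"):
        sum(range(100_000))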
@@ -337,21 +335,21 @@ class NeonBenchmarker:
f"{prefix}.{metric}", value, unit="s", report=MetricReport.LOWER_IS_BETTER
)

def get_io_writes(self, pageserver) -> int:
def get_io_writes(self, pageserver: NeonPageserver) -> int:
"""
Fetch the "cumulative # of bytes written" metric from the pageserver
"""
metric_name = r'libmetrics_disk_io_bytes_total{io_operation="write"}'
return self.get_int_counter_value(pageserver, metric_name)

def get_peak_mem(self, pageserver) -> int:
def get_peak_mem(self, pageserver: NeonPageserver) -> int:
"""
Fetch the "maxrss" metric from the pageserver
"""
metric_name = r"libmetrics_maxrss_kb"
return self.get_int_counter_value(pageserver, metric_name)

def get_int_counter_value(self, pageserver, metric_name) -> int:
def get_int_counter_value(self, pageserver: NeonPageserver, metric_name: str) -> int:
"""Fetch the value of given int counter from pageserver metrics."""
# TODO: If we start to collect more of the prometheus metrics in the
# performance test suite like this, we should refactor this to load and
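get_int_counter_value scans the pageserver's metrics text for a single sample and rounds it to an int (the `assert matches` / `int(round(float(...)))` lines appear in the next hunk). A self-contained sketch of that idea; the metrics payload here is invented and the real helper's regex may differ:

    import re

    metrics_text = (
        "libmetrics_maxrss_kb 123456\n"
        'libmetrics_disk_io_bytes_total{io_operation="write"} 7.340032e+06\n'
    )

    def get_int_counter_value(text: str, metric_name: str) -> int:
        matches = re.search(r"^" + re.escape(metric_name) + r"\s+(\S+)$", text, re.MULTILINE)
        assert matches, f"metric {metric_name} not found"
        return int(round(float(matches.group(1))))

    print(get_int_counter_value(metrics_text, "libmetrics_maxrss_kb"))  # 123456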
@@ -365,7 +363,9 @@ class NeonBenchmarker:
assert matches, f"metric {metric_name} not found"
return int(round(float(matches.group(1))))

def get_timeline_size(self, repo_dir: Path, tenant_id: TenantId, timeline_id: TimelineId):
def get_timeline_size(
self, repo_dir: Path, tenant_id: TenantId, timeline_id: TimelineId
) -> int:
"""
Calculate the on-disk size of a timeline
"""
@@ -379,7 +379,9 @@ class NeonBenchmarker:
return totalbytes

@contextmanager
def record_pageserver_writes(self, pageserver, metric_name):
def record_pageserver_writes(
self, pageserver: NeonPageserver, metric_name: str
) -> Iterator[None]:
"""
Record bytes written by the pageserver during a test.
"""
@@ -396,7 +398,7 @@ class NeonBenchmarker:

@pytest.fixture(scope="function")
def zenbenchmark(record_property) -> Iterator[NeonBenchmarker]:
def zenbenchmark(record_property: Callable[[str, object], None]) -> Iterator[NeonBenchmarker]:
"""
This is a python decorator for benchmark fixtures. It contains functions for
recording measurements, and prints them out at the end.
@@ -405,7 +407,7 @@ def zenbenchmark(record_property) -> Iterator[NeonBenchmarker]:
yield benchmarker

def pytest_addoption(parser):
def pytest_addoption(parser: Parser):
parser.addoption(
"--out-dir",
dest="out_dir",
@@ -429,7 +431,9 @@ def get_out_path(target_dir: Path, revision: str) -> Path:

# Hook to print the results at the end
@pytest.hookimpl(hookwrapper=True)
def pytest_terminal_summary(terminalreporter: TerminalReporter, exitstatus: int, config: Config):
def pytest_terminal_summary(
terminalreporter: TerminalReporter, exitstatus: int, config: Config
) -> Iterator[None]:
yield
revision = os.getenv("GITHUB_SHA", "local")
platform = os.getenv("PLATFORM", "local")

@@ -1,10 +1,11 @@
from abc import ABC, abstractmethod
from contextlib import contextmanager
from contextlib import _GeneratorContextManager, contextmanager

# Type-related stuff
from typing import Dict, List
from typing import Dict, Iterator, List

import pytest
from _pytest.fixtures import FixtureRequest
from fixtures.benchmark_fixture import MetricReport, NeonBenchmarker
from fixtures.neon_fixtures import NeonEnv, PgBin, PgProtocol, RemotePostgres, VanillaPostgres
from fixtures.pg_stats import PgStatTable
@@ -28,19 +29,20 @@ class PgCompare(ABC):
pass

@property
@abstractmethod
def zenbenchmark(self) -> NeonBenchmarker:
pass

@abstractmethod
def flush(self) -> None:
def flush(self):
pass

@abstractmethod
def report_peak_memory_use(self) -> None:
def report_peak_memory_use(self):
pass

@abstractmethod
def report_size(self) -> None:
def report_size(self):
pass

@contextmanager
@@ -54,7 +56,7 @@ class PgCompare(ABC):
pass

@contextmanager
def record_pg_stats(self, pg_stats: List[PgStatTable]):
def record_pg_stats(self, pg_stats: List[PgStatTable]) -> Iterator[None]:
init_data = self._retrieve_pg_stats(pg_stats)

yield
@@ -84,7 +86,11 @@ class NeonCompare(PgCompare):
"""PgCompare interface for the neon stack."""

def __init__(
self, zenbenchmark: NeonBenchmarker, neon_simple_env: NeonEnv, pg_bin: PgBin, branch_name
self,
zenbenchmark: NeonBenchmarker,
neon_simple_env: NeonEnv,
pg_bin: PgBin,
branch_name: str,
):
self.env = neon_simple_env
self._zenbenchmark = zenbenchmark
@@ -97,15 +103,15 @@ class NeonCompare(PgCompare):
self.timeline = self.pg.safe_psql("SHOW neon.timeline_id")[0][0]

@property
def pg(self):
def pg(self) -> PgProtocol:
return self._pg

@property
def zenbenchmark(self):
def zenbenchmark(self) -> NeonBenchmarker:
return self._zenbenchmark

@property
def pg_bin(self):
def pg_bin(self) -> PgBin:
return self._pg_bin

def flush(self):
@@ -114,7 +120,7 @@ class NeonCompare(PgCompare):
def compact(self):
self.pageserver_http_client.timeline_compact(self.env.initial_tenant, self.timeline)

def report_peak_memory_use(self) -> None:
def report_peak_memory_use(self):
self.zenbenchmark.record(
"peak_mem",
self.zenbenchmark.get_peak_mem(self.env.pageserver) / 1024,
@@ -122,7 +128,7 @@ class NeonCompare(PgCompare):
report=MetricReport.LOWER_IS_BETTER,
)

def report_size(self) -> None:
def report_size(self):
timeline_size = self.zenbenchmark.get_timeline_size(
self.env.repo_dir, self.env.initial_tenant, self.timeline
)
@@ -144,17 +150,17 @@ class NeonCompare(PgCompare):
"num_files_uploaded", total_files, "", report=MetricReport.LOWER_IS_BETTER
)

def record_pageserver_writes(self, out_name):
def record_pageserver_writes(self, out_name: str) -> _GeneratorContextManager[None]:
return self.zenbenchmark.record_pageserver_writes(self.env.pageserver, out_name)

def record_duration(self, out_name):
def record_duration(self, out_name: str) -> _GeneratorContextManager[None]:
return self.zenbenchmark.record_duration(out_name)

class VanillaCompare(PgCompare):
"""PgCompare interface for vanilla postgres."""

def __init__(self, zenbenchmark, vanilla_pg: VanillaPostgres):
def __init__(self, zenbenchmark: NeonBenchmarker, vanilla_pg: VanillaPostgres):
self._pg = vanilla_pg
self._zenbenchmark = zenbenchmark
vanilla_pg.configure(
@@ -170,24 +176,24 @@ class VanillaCompare(PgCompare):
self.cur = self.conn.cursor()

@property
def pg(self):
def pg(self) -> PgProtocol:
return self._pg

@property
def zenbenchmark(self):
def zenbenchmark(self) -> NeonBenchmarker:
return self._zenbenchmark

@property
def pg_bin(self):
def pg_bin(self) -> PgBin:
return self._pg.pg_bin

def flush(self):
self.cur.execute("checkpoint")

def report_peak_memory_use(self) -> None:
def report_peak_memory_use(self):
pass # TODO find something

def report_size(self) -> None:
def report_size(self):
data_size = self.pg.get_subdir_size("base")
self.zenbenchmark.record(
"data_size", data_size / (1024 * 1024), "MB", report=MetricReport.LOWER_IS_BETTER
@@ -198,17 +204,17 @@ class VanillaCompare(PgCompare):
)

@contextmanager
def record_pageserver_writes(self, out_name):
def record_pageserver_writes(self, out_name: str) -> Iterator[None]:
yield # Do nothing

def record_duration(self, out_name):
def record_duration(self, out_name: str) -> _GeneratorContextManager[None]:
return self.zenbenchmark.record_duration(out_name)

class RemoteCompare(PgCompare):
"""PgCompare interface for a remote postgres instance."""

def __init__(self, zenbenchmark, remote_pg: RemotePostgres):
def __init__(self, zenbenchmark: NeonBenchmarker, remote_pg: RemotePostgres):
self._pg = remote_pg
self._zenbenchmark = zenbenchmark

@@ -217,55 +223,60 @@ class RemoteCompare(PgCompare):
self.cur = self.conn.cursor()

@property
def pg(self):
def pg(self) -> PgProtocol:
return self._pg

@property
def zenbenchmark(self):
def zenbenchmark(self) -> NeonBenchmarker:
return self._zenbenchmark

@property
def pg_bin(self):
def pg_bin(self) -> PgBin:
return self._pg.pg_bin

def flush(self):
# TODO: flush the remote pageserver
pass

def report_peak_memory_use(self) -> None:
def report_peak_memory_use(self):
# TODO: get memory usage from remote pageserver
pass

def report_size(self) -> None:
def report_size(self):
# TODO: get storage size from remote pageserver
pass

@contextmanager
def record_pageserver_writes(self, out_name):
def record_pageserver_writes(self, out_name: str) -> Iterator[None]:
yield # Do nothing

def record_duration(self, out_name):
def record_duration(self, out_name: str) -> _GeneratorContextManager[None]:
return self.zenbenchmark.record_duration(out_name)

@pytest.fixture(scope="function")
def neon_compare(request, zenbenchmark, pg_bin, neon_simple_env) -> NeonCompare:
def neon_compare(
request: FixtureRequest,
zenbenchmark: NeonBenchmarker,
pg_bin: PgBin,
neon_simple_env: NeonEnv,
) -> NeonCompare:
branch_name = request.node.name
return NeonCompare(zenbenchmark, neon_simple_env, pg_bin, branch_name)

@pytest.fixture(scope="function")
def vanilla_compare(zenbenchmark, vanilla_pg) -> VanillaCompare:
def vanilla_compare(zenbenchmark: NeonBenchmarker, vanilla_pg: VanillaPostgres) -> VanillaCompare:
return VanillaCompare(zenbenchmark, vanilla_pg)

@pytest.fixture(scope="function")
def remote_compare(zenbenchmark, remote_pg) -> RemoteCompare:
def remote_compare(zenbenchmark: NeonBenchmarker, remote_pg: RemotePostgres) -> RemoteCompare:
return RemoteCompare(zenbenchmark, remote_pg)

@pytest.fixture(params=["vanilla_compare", "neon_compare"], ids=["vanilla", "neon"])
def neon_with_baseline(request) -> PgCompare:
def neon_with_baseline(request: FixtureRequest) -> PgCompare:
"""Parameterized fixture that helps compare neon against vanilla postgres.

A test that uses this fixture turns into a parameterized test that runs against:
@@ -286,8 +297,6 @@ def neon_with_baseline(request) -> PgCompare:
implementation-specific logic is widely useful across multiple tests, it might
make sense to add methods to the PgCompare class.
"""
fixture = request.getfixturevalue(request.param)
if isinstance(fixture, PgCompare):
return fixture
else:
raise AssertionError(f"test error: fixture {request.param} is not PgCompare")
fixture = request.getfixturevalue(request.param) # type: ignore
assert isinstance(fixture, PgCompare), f"test error: fixture {fixture} is not PgCompare"
return fixture

@@ -1,5 +1,5 @@
from collections import defaultdict
from typing import Dict, List
from typing import Dict, List, Optional, Tuple

from prometheus_client.parser import text_string_to_metric_families
from prometheus_client.samples import Sample
@@ -23,13 +23,13 @@ class Metrics:
pass
return res

def query_one(self, name: str, filter: Dict[str, str] = {}) -> Sample:
res = self.query_all(name, filter)
def query_one(self, name: str, filter: Optional[Dict[str, str]] = None) -> Sample:
res = self.query_all(name, filter or {})
assert len(res) == 1, f"expected single sample for {name} {filter}, found {res}"
return res[0]

def parse_metrics(text: str, name: str = ""):
def parse_metrics(text: str, name: str = "") -> Metrics:
metrics = Metrics(name)
gen = text_string_to_metric_families(text)
for family in gen:
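The query_one change above swaps a mutable `{}` default for `Optional[...] = None` plus `filter or {}`. The original code never mutated its default, but the shared-object pitfall that this idiom avoids is easy to hit, since the default value is created once per function definition rather than once per call:

    from typing import List, Optional

    def buggy_append(item: int, acc: List[int] = []) -> List[int]:
        acc.append(item)          # mutates the one default list shared by all calls
        return acc

    def safe_append(item: int, acc: Optional[List[int]] = None) -> List[int]:
        acc = [] if acc is None else acc
        acc.append(item)
        return acc

    a, b = buggy_append(1), buggy_append(2)
    print(a, b, a is b)                    # [1, 2] [1, 2] True
    print(safe_append(1), safe_append(2))  # [1] [2]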
@@ -39,7 +39,7 @@ def parse_metrics(text: str, name: str = ""):
return metrics

PAGESERVER_PER_TENANT_METRICS = [
PAGESERVER_PER_TENANT_METRICS: Tuple[str, ...] = (
"pageserver_current_logical_size",
"pageserver_current_physical_size",
"pageserver_getpage_reconstruct_seconds_bucket",
@@ -62,4 +62,4 @@ PAGESERVER_PER_TENANT_METRICS = [
"pageserver_wait_lsn_seconds_sum",
"pageserver_created_persistent_files_total",
"pageserver_written_persistent_bytes_total",
]
)

@@ -19,7 +19,8 @@ from dataclasses import dataclass, field
from enum import Flag, auto
from functools import cached_property
from pathlib import Path
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union, cast
from types import TracebackType
from typing import Any, Dict, Iterator, List, Optional, Tuple, Type, Union, cast

import asyncpg
import backoff # type: ignore
@@ -28,16 +29,18 @@ import jwt
import psycopg2
import pytest
import requests
from _pytest.config import Config
from _pytest.fixtures import FixtureRequest
from fixtures.log_helper import log
from fixtures.types import Lsn, TenantId, TimelineId
from fixtures.utils import Fn, allure_attach_from_dir, etcd_path, get_self_dir, subprocess_capture

# Type-related stuff
from psycopg2.extensions import connection as PgConnection
from psycopg2.extensions import cursor as PgCursor
from psycopg2.extensions import make_dsn, parse_dsn
from typing_extensions import Literal

from .utils import Fn, allure_attach_from_dir, etcd_path, get_self_dir, subprocess_capture

"""
This file contains pytest fixtures. A fixture is a test resource that can be
summoned by placing its name in the test's arguments.
@@ -57,15 +60,15 @@ put directly-importable functions into utils.py or another separate file.

Env = Dict[str, str]

DEFAULT_OUTPUT_DIR = "test_output"
DEFAULT_BRANCH_NAME = "main"
DEFAULT_PG_VERSION_DEFAULT = "14"
DEFAULT_OUTPUT_DIR: str = "test_output"
DEFAULT_BRANCH_NAME: str = "main"
DEFAULT_PG_VERSION_DEFAULT: str = "14"

BASE_PORT = 15000
WORKER_PORT_NUM = 1000
BASE_PORT: int = 15000
WORKER_PORT_NUM: int = 1000

def pytest_configure(config):
def pytest_configure(config: Config):
"""
Check that we do not overflow available ports range.
"""
@@ -154,14 +157,14 @@ def versioned_pg_distrib_dir(pg_distrib_dir: Path, pg_version: str) -> Iterator[
if not psql_bin_path.exists():
raise Exception(f"psql not found at '{psql_bin_path}'")
else:
if not postgres_bin_path.exists:
if not postgres_bin_path.exists():
raise Exception(f"postgres not found at '{postgres_bin_path}'")

log.info(f"versioned_pg_distrib_dir is {versioned_dir}")
yield versioned_dir

def shareable_scope(fixture_name, config) -> Literal["session", "function"]:
def shareable_scope(fixture_name: str, config: Config) -> Literal["session", "function"]:
"""Return either session of function scope, depending on TEST_SHARED_FIXTURES envvar.

This function can be used as a scope like this:
@@ -173,7 +176,7 @@ def shareable_scope(fixture_name, config) -> Literal["session", "function"]:

@pytest.fixture(scope="session")
def worker_seq_no(worker_id: str):
def worker_seq_no(worker_id: str) -> int:
# worker_id is a pytest-xdist fixture
# it can be master or gw<number>
# parse it to always get a number
@@ -184,7 +187,7 @@ def worker_seq_no(worker_id: str):

@pytest.fixture(scope="session")
def worker_base_port(worker_seq_no: int):
def worker_base_port(worker_seq_no: int) -> int:
# so we divide ports in ranges of 100 ports
# so workers have disjoint set of ports for services
return BASE_PORT + worker_seq_no * WORKER_PORT_NUM
@@ -234,10 +237,9 @@ class PortDistributor:
for port in self.iterator:
if can_bind("localhost", port):
return port
else:
raise RuntimeError(
"port range configured for test is exhausted, consider enlarging the range"
)
raise RuntimeError(
"port range configured for test is exhausted, consider enlarging the range"
)

def replace_with_new_port(self, value: Union[int, str]) -> Union[int, str]:
"""
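PortDistributor above walks its port range and keeps the first port that can_bind accepts; can_bind itself is not part of this diff, so the following is only a plausible sketch of such a helper built on the standard socket module:

    import socket
    from contextlib import closing

    def can_bind(host: str, port: int) -> bool:
        # Try to bind and listen on the port, and report whether that succeeded.
        with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            try:
                sock.bind((host, port))
                sock.listen(1)
                return True
            except OSError:
                return False

    print(can_bind("localhost", 15000))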
@@ -273,12 +275,14 @@ class PortDistributor:

@pytest.fixture(scope="session")
def port_distributor(worker_base_port):
def port_distributor(worker_base_port: int) -> PortDistributor:
return PortDistributor(base_port=worker_base_port, port_number=WORKER_PORT_NUM)

@pytest.fixture(scope="session")
def default_broker(request: Any, port_distributor: PortDistributor, top_output_dir: Path):
def default_broker(
request: FixtureRequest, port_distributor: PortDistributor, top_output_dir: Path
) -> Iterator[Etcd]:
client_port = port_distributor.get_port()
# multiple pytest sessions could get launched in parallel, get them different datadirs
etcd_datadir = get_test_output_dir(request, top_output_dir) / f"etcd_datadir_{client_port}"
@@ -293,12 +297,12 @@ def default_broker(request: Any, port_distributor: PortDistributor, top_output_d

@pytest.fixture(scope="session")
def run_id():
def run_id() -> Iterator[uuid.UUID]:
yield uuid.uuid4()

@pytest.fixture(scope="session")
def mock_s3_server(port_distributor: PortDistributor):
def mock_s3_server(port_distributor: PortDistributor) -> Iterator[MockS3Server]:
mock_s3_server = MockS3Server(port_distributor.get_port())
yield mock_s3_server
mock_s3_server.kill()
@@ -307,16 +311,16 @@ def mock_s3_server(port_distributor: PortDistributor):
class PgProtocol:
"""Reusable connection logic"""

def __init__(self, **kwargs):
def __init__(self, **kwargs: Any):
self.default_options = kwargs

def connstr(self, **kwargs) -> str:
def connstr(self, **kwargs: Any) -> str:
"""
Build a libpq connection string for the Postgres instance.
"""
return str(make_dsn(**self.conn_options(**kwargs)))

def conn_options(self, **kwargs):
def conn_options(self, **kwargs: Any) -> Dict[str, Any]:
"""
Construct a dictionary of connection options from default values and extra parameters.
An option can be dropped from the returning dictionary by None-valued extra parameter.
@@ -338,7 +342,7 @@ class PgProtocol:
return result

# autocommit=True here by default because that's what we need most of the time
def connect(self, autocommit=True, **kwargs) -> PgConnection:
def connect(self, autocommit: bool = True, **kwargs: Any) -> PgConnection:
"""
Connect to the node.
Returns psycopg2's connection object.
@@ -351,7 +355,7 @@ class PgProtocol:
return conn

@contextmanager
def cursor(self, autocommit=True, **kwargs):
def cursor(self, autocommit: bool = True, **kwargs: Any) -> Iterator[PgCursor]:
"""
Shorthand for pg.connect().cursor().
The cursor and connection are closed when the context is exited.
@@ -359,7 +363,7 @@ class PgProtocol:
with closing(self.connect(autocommit=autocommit, **kwargs)) as conn:
yield conn.cursor()

async def connect_async(self, **kwargs) -> asyncpg.Connection:
async def connect_async(self, **kwargs: Any) -> asyncpg.Connection:
"""
Connect to the node from async python.
Returns asyncpg's connection object.
@@ -413,10 +417,10 @@ class PgProtocol:

@dataclass
class AuthKeys:
pub: bytes
priv: bytes
pub: str
priv: str

def generate_management_token(self):
def generate_management_token(self) -> str:
token = jwt.encode({"scope": "pageserverapi"}, self.priv, algorithm="RS256")

# jwt.encode can return 'bytes' or 'str', depending on Python version or type
@@ -427,9 +431,11 @@ class AuthKeys:

return token

def generate_tenant_token(self, tenant_id):
def generate_tenant_token(self, tenant_id: TenantId) -> str:
token = jwt.encode(
{"scope": "tenant", "tenant_id": str(tenant_id)}, self.priv, algorithm="RS256"
{"scope": "tenant", "tenant_id": str(tenant_id)},
self.priv,
algorithm="RS256",
)

if isinstance(token, bytes):
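The AuthKeys methods above use PyJWT's encode(); a tiny round-trip sketch of that call shape follows. It signs with HS256 and a throwaway secret so it runs without the repo's RSA key files, whereas the fixture itself uses RS256 with the keys read from auth_private_key.pem / auth_public_key.pem:

    import jwt

    secret = "not-a-real-key"  # placeholder, for illustration only

    token = jwt.encode({"scope": "tenant", "tenant_id": "0" * 32}, secret, algorithm="HS256")
    # PyJWT 1.x returned bytes, 2.x returns str -- hence the isinstance check above.
    if isinstance(token, bytes):
        token = token.decode()

    claims = jwt.decode(token, secret, algorithms=["HS256"])
    print(claims["scope"], claims["tenant_id"])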
@@ -485,7 +491,7 @@ class MockS3Server:

@enum.unique
class RemoteStorageKind(enum.Enum):
class RemoteStorageKind(str, enum.Enum):
LOCAL_FS = "local_fs"
MOCK_S3 = "mock_s3"
REAL_S3 = "real_s3"
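Mixing str into the Enum, as the new RemoteStorageKind does, lets members be handed to string-typed code and compared against raw values while still being a proper enum. For instance:

    import enum

    @enum.unique
    class RemoteStorageKind(str, enum.Enum):
        LOCAL_FS = "local_fs"
        MOCK_S3 = "mock_s3"
        REAL_S3 = "real_s3"

    kind = RemoteStorageKind.MOCK_S3
    print(kind == "mock_s3")      # True: the str mixin compares equal to the raw value
    print(f"using {kind.value}")  # "using mock_s3"
    print(isinstance(kind, str))  # True, so it can be passed to string-typed APIs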
@@ -529,7 +535,7 @@ RemoteStorage = Union[LocalFsStorage, S3Storage]

# serialize as toml inline table
def remote_storage_to_toml_inline_table(remote_storage):
def remote_storage_to_toml_inline_table(remote_storage: RemoteStorage) -> str:
if isinstance(remote_storage, LocalFsStorage):
remote_storage_config = f"local_path='{remote_storage.root}'"
elif isinstance(remote_storage, S3Storage):
@@ -582,7 +588,7 @@ class NeonEnvBuilder:
safekeepers_enable_fsync: bool = False,
auth_enabled: bool = False,
rust_log_override: Optional[str] = None,
default_branch_name=DEFAULT_BRANCH_NAME,
default_branch_name: str = DEFAULT_BRANCH_NAME,
):
self.repo_dir = repo_dir
self.rust_log_override = rust_log_override
@@ -636,7 +642,7 @@ class NeonEnvBuilder:
else:
raise RuntimeError(f"Unknown storage type: {remote_storage_kind}")

def enable_local_fs_remote_storage(self, force_enable=True):
def enable_local_fs_remote_storage(self, force_enable: bool = True):
"""
Sets up the pageserver to use the local fs at the `test_dir/local_fs_remote_storage` path.
Errors, if the pageserver has some remote storage configuration already, unless `force_enable` is not set to `True`.
@@ -644,7 +650,7 @@ class NeonEnvBuilder:
assert force_enable or self.remote_storage is None, "remote storage is enabled already"
self.remote_storage = LocalFsStorage(Path(self.repo_dir / "local_fs_remote_storage"))

def enable_mock_s3_remote_storage(self, bucket_name: str, force_enable=True):
def enable_mock_s3_remote_storage(self, bucket_name: str, force_enable: bool = True):
"""
Sets up the pageserver to use the S3 mock server, creates the bucket, if it's not present already.
Starts up the mock server, if that does not run yet.
@@ -671,7 +677,7 @@ class NeonEnvBuilder:
secret_key=self.mock_s3_server.secret_key(),
)

def enable_real_s3_remote_storage(self, test_name: str, force_enable=True):
def enable_real_s3_remote_storage(self, test_name: str, force_enable: bool = True):
"""
Sets up configuration to use real s3 endpoint without mock server
"""
@@ -759,10 +765,15 @@ class NeonEnvBuilder:

log.info("deleted %s objects from remote storage", cnt)

def __enter__(self):
def __enter__(self) -> "NeonEnvBuilder":
return self

def __exit__(self, exc_type, exc_value, traceback):
def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_value: Optional[BaseException],
traceback: Optional[TracebackType],
):
# Stop all the nodes.
if self.env:
log.info("Cleaning up all storage and compute nodes")
@@ -909,7 +920,7 @@ class NeonEnv:

def get_safekeeper_connstrs(self) -> str:
"""Get list of safekeeper endpoints suitable for safekeepers GUC"""
return ",".join([f"localhost:{wa.port.pg}" for wa in self.safekeepers])
return ",".join(f"localhost:{wa.port.pg}" for wa in self.safekeepers)

def timeline_dir(self, tenant_id: TenantId, timeline_id: TimelineId) -> Path:
"""Get a timeline directory's path based on the repo directory of the test environment"""
@@ -928,14 +939,14 @@ class NeonEnv:

@cached_property
def auth_keys(self) -> AuthKeys:
pub = (Path(self.repo_dir) / "auth_public_key.pem").read_bytes()
priv = (Path(self.repo_dir) / "auth_private_key.pem").read_bytes()
pub = (Path(self.repo_dir) / "auth_public_key.pem").read_text()
priv = (Path(self.repo_dir) / "auth_private_key.pem").read_text()
return AuthKeys(pub=pub, priv=priv)

@pytest.fixture(scope=shareable_scope)
def _shared_simple_env(
request: Any,
request: FixtureRequest,
port_distributor: PortDistributor,
mock_s3_server: MockS3Server,
default_broker: Etcd,
@@ -993,7 +1004,7 @@ def neon_simple_env(_shared_simple_env: NeonEnv) -> Iterator[NeonEnv]:

@pytest.fixture(scope="function")
def neon_env_builder(
test_output_dir,
test_output_dir: str,
port_distributor: PortDistributor,
mock_s3_server: MockS3Server,
neon_binpath: Path,
@@ -1059,7 +1070,7 @@ class PageserverHttpClient(requests.Session):
def check_status(self):
self.get(f"http://localhost:{self.port}/v1/status").raise_for_status()

def configure_failpoints(self, config_strings: tuple[str, str] | list[tuple[str, str]]) -> None:
def configure_failpoints(self, config_strings: Tuple[str, str] | List[Tuple[str, str]]):
self.is_testing_enabled_or_skip()

if isinstance(config_strings, tuple):
@@ -1189,7 +1200,6 @@ class PageserverHttpClient(requests.Session):
self.verbose_error(res)
res_json = res.json()
assert res_json is None
return res_json

def timeline_gc(
self, tenant_id: TenantId, timeline_id: TimelineId, gc_horizon: Optional[int]
@@ -1221,7 +1231,6 @@ class PageserverHttpClient(requests.Session):
self.verbose_error(res)
res_json = res.json()
assert res_json is None
return res_json

def timeline_get_lsn_by_timestamp(
self, tenant_id: TenantId, timeline_id: TimelineId, timestamp
@@ -1247,7 +1256,6 @@ class PageserverHttpClient(requests.Session):
self.verbose_error(res)
res_json = res.json()
assert res_json is None
return res_json

def get_metrics(self) -> str:
res = self.get(f"http://localhost:{self.port}/metrics")
@@ -1261,13 +1269,10 @@ class PageserverPort:
http: int

CREATE_TIMELINE_ID_EXTRACTOR = re.compile(
CREATE_TIMELINE_ID_EXTRACTOR: re.Pattern = re.compile( # type: ignore[type-arg]
r"^Created timeline '(?P<timeline_id>[^']+)'", re.MULTILINE
)
CREATE_TIMELINE_ID_EXTRACTOR = re.compile(
r"^Created timeline '(?P<timeline_id>[^']+)'", re.MULTILINE
)
TIMELINE_DATA_EXTRACTOR = re.compile(
TIMELINE_DATA_EXTRACTOR: re.Pattern = re.compile( # type: ignore[type-arg]
r"\s?(?P<branch_name>[^\s]+)\s\[(?P<timeline_id>[^\]]+)\]", re.MULTILINE
)
@@ -1560,7 +1565,7 @@ class NeonCli(AbstractNeonCli):

def pageserver_start(
self,
overrides=(),
overrides: Tuple[str, ...] = (),
) -> "subprocess.CompletedProcess[str]":
start_args = ["pageserver", "start", *overrides]
append_pageserver_param_overrides(
@@ -1718,7 +1723,7 @@ class NeonPageserver(PgProtocol):
self.config_override = config_override
self.version = env.get_pageserver_version()

def start(self, overrides=()) -> "NeonPageserver":
def start(self, overrides: Tuple[str, ...] = ()) -> "NeonPageserver":
"""
Start the page server.
`overrides` allows to add some config to this pageserver start.
@@ -1730,7 +1735,7 @@ class NeonPageserver(PgProtocol):
self.running = True
return self

def stop(self, immediate=False) -> "NeonPageserver":
def stop(self, immediate: bool = False) -> "NeonPageserver":
"""
Stop the page server.
Returns self.
@@ -1740,10 +1745,15 @@ class NeonPageserver(PgProtocol):
self.running = False
return self

def __enter__(self):
def __enter__(self) -> "NeonPageserver":
return self

def __exit__(self, exc_type, exc, tb):
def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc: Optional[BaseException],
tb: Optional[TracebackType],
):
self.stop(immediate=True)

def is_testing_enabled_or_skip(self):
@@ -1855,7 +1865,7 @@ def pg_bin(test_output_dir: Path, pg_distrib_dir: Path, pg_version: str) -> PgBi

class VanillaPostgres(PgProtocol):
def __init__(self, pgdatadir: Path, pg_bin: PgBin, port: int, init=True):
def __init__(self, pgdatadir: Path, pg_bin: PgBin, port: int, init: bool = True):
super().__init__(host="localhost", port=port, dbname="postgres")
self.pgdatadir = pgdatadir
self.pg_bin = pg_bin
@@ -1890,10 +1900,15 @@ class VanillaPostgres(PgProtocol):
"""Return size of pgdatadir subdirectory in bytes."""
return get_dir_size(os.path.join(self.pgdatadir, subdir))

def __enter__(self):
def __enter__(self) -> "VanillaPostgres":
return self

def __exit__(self, exc_type, exc, tb):
def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc: Optional[BaseException],
tb: Optional[TracebackType],
):
if self.running:
self.stop()

@@ -1933,10 +1948,15 @@ class RemotePostgres(PgProtocol):
# See https://www.postgresql.org/docs/14/functions-admin.html#FUNCTIONS-ADMIN-GENFILE
raise Exception("cannot get size of a Postgres instance")

def __enter__(self):
def __enter__(self) -> "RemotePostgres":
return self

def __exit__(self, exc_type, exc, tb):
def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc: Optional[BaseException],
tb: Optional[TracebackType],
):
# do nothing
pass

@@ -1975,7 +1995,7 @@ class PSQL:
self.path = path
self.database_url = f"postgres://{host}:{port}/main?options=project%3Dgeneric-project-name"

async def run(self, query=None):
async def run(self, query: Optional[str] = None) -> asyncio.subprocess.Process:
run_args = [self.path, "--no-psqlrc", "--quiet", "--tuples-only", self.database_url]
if query is not None:
run_args += ["--command", query]
@@ -2008,7 +2028,7 @@ class NeonProxy(PgProtocol):
self._popen: Optional[subprocess.Popen[bytes]] = None
self.link_auth_uri_prefix = "http://dummy-uri"

def start(self) -> None:
def start(self):
"""
Starts a proxy with option '--auth-backend postgres' and a postgres instance already provided though '--auth-endpoint <postgress-instance>'."
"""
@@ -2026,7 +2046,7 @@ class NeonProxy(PgProtocol):
self._popen = subprocess.Popen(args)
self._wait_until_ready()

def start_with_link_auth(self) -> None:
def start_with_link_auth(self):
"""
Starts a proxy with option '--auth-backend link' and a dummy authentication link '--uri dummy-auth-link'."
"""
@@ -2054,10 +2074,15 @@ class NeonProxy(PgProtocol):
request_result.raise_for_status()
return request_result.text

def __enter__(self):
def __enter__(self) -> "NeonProxy":
return self

def __exit__(self, exc_type, exc, tb):
def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc: Optional[BaseException],
tb: Optional[TracebackType],
):
if self._popen is not None:
# NOTE the process will die when we're done with tests anyway, because
# it's a child process. This is mostly to clean up in between different tests.
@@ -2065,7 +2090,7 @@ class NeonProxy(PgProtocol):

@pytest.fixture(scope="function")
def link_proxy(port_distributor, neon_binpath: Path) -> Iterator[NeonProxy]:
def link_proxy(port_distributor: PortDistributor, neon_binpath: Path) -> Iterator[NeonProxy]:
"""Neon proxy that routes through link auth."""
http_port = port_distributor.get_port()
proxy_port = port_distributor.get_port()
@@ -2076,7 +2101,9 @@ def link_proxy(port_distributor, neon_binpath: Path) -> Iterator[NeonProxy]:

@pytest.fixture(scope="function")
def static_proxy(vanilla_pg, port_distributor, neon_binpath: Path) -> Iterator[NeonProxy]:
def static_proxy(
vanilla_pg: VanillaPostgres, port_distributor: PortDistributor, neon_binpath: Path
) -> Iterator[NeonProxy]:
"""Neon proxy that routes directly to vanilla postgres."""

# For simplicity, we use the same user for both `--auth-endpoint` and `safe_psql`
@@ -2276,10 +2303,15 @@ class Postgres(PgProtocol):

return self

def __enter__(self):
def __enter__(self) -> "Postgres":
return self

def __exit__(self, exc_type, exc, tb):
def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc: Optional[BaseException],
tb: Optional[TracebackType],
):
self.stop()

@@ -2288,7 +2320,7 @@ class PostgresFactory:

def __init__(self, env: NeonEnv):
self.env = env
self.num_instances = 0
self.num_instances: int = 0
self.instances: List[Postgres] = []

def create_start(
@@ -2383,7 +2415,7 @@ class Safekeeper:
break # success
return self

def stop(self, immediate=False) -> "Safekeeper":
def stop(self, immediate: bool = False) -> "Safekeeper":
log.info("Stopping safekeeper {}".format(self.id))
self.env.neon_cli.safekeeper_stop(self.id, immediate)
self.running = False
@@ -2598,7 +2630,7 @@ class Etcd:
self.handle.wait()

def get_test_output_dir(request: Any, top_output_dir: Path) -> Path:
def get_test_output_dir(request: FixtureRequest, top_output_dir: Path) -> Path:
"""Compute the working directory for an individual test."""
test_name = request.node.name
test_dir = top_output_dir / test_name.replace("/", "-")
@@ -2618,7 +2650,7 @@ def get_test_output_dir(request: Any, top_output_dir: Path) -> Path:
# this fixture ensures that the directory exists. That works because
# 'autouse' fixtures are run before other fixtures.
@pytest.fixture(scope="function", autouse=True)
def test_output_dir(request: Any, top_output_dir: Path) -> Iterator[Path]:
def test_output_dir(request: FixtureRequest, top_output_dir: Path) -> Iterator[Path]:
"""Create the working directory for an individual test."""

# one directory per test
@@ -2682,7 +2714,7 @@ def should_skip_file(filename: str) -> bool:
#
# Test helpers
#
def list_files_to_compare(pgdata_dir: Path):
def list_files_to_compare(pgdata_dir: Path) -> List[str]:
pgdata_files = []
for root, _file, filenames in os.walk(pgdata_dir):
for filename in filenames:

@@ -1,3 +1,4 @@
from functools import cached_property
from typing import List

import pytest
@@ -13,7 +14,7 @@ class PgStatTable:
self.columns = columns
self.additional_query = filter_query

@property
@cached_property
def query(self) -> str:
return f"SELECT {','.join(self.columns)} FROM {self.table} {self.additional_query}"

@@ -55,6 +56,5 @@ def pg_stats_wal() -> List[PgStatTable]:
PgStatTable(
"pg_stat_wal",
["wal_records", "wal_fpi", "wal_bytes", "wal_buffers_full", "wal_write"],
"",
)
]

@@ -1,4 +1,8 @@
from typing import Any, List

import pytest
from _pytest.config import Config
from _pytest.config.argparsing import Parser

"""
This plugin allows tests to be marked as slow using pytest.mark.slow. By default slow
@@ -9,15 +13,15 @@ Copied from here: https://docs.pytest.org/en/latest/example/simple.html
"""

def pytest_addoption(parser):
def pytest_addoption(parser: Parser):
parser.addoption("--runslow", action="store_true", default=False, help="run slow tests")

def pytest_configure(config):
def pytest_configure(config: Config):
config.addinivalue_line("markers", "slow: mark test as slow to run")

def pytest_collection_modifyitems(config, items):
def pytest_collection_modifyitems(config: Config, items: List[Any]):
if config.getoption("--runslow"):
# --runslow given in cli: do not skip slow tests
return
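The hunk cuts off before the end of pytest_collection_modifyitems. In the upstream pytest example that this plugin's docstring says it was copied from, the function finishes roughly like the sketch below, skip-marking every test carrying the slow marker when --runslow is absent (shown here for context, not as part of the diff):

    import pytest

    def pytest_collection_modifyitems(config, items):
        if config.getoption("--runslow"):
            # --runslow given in cli: do not skip slow tests
            return
        skip_slow = pytest.mark.skip(reason="need --runslow option to run")
        for item in items:
            if "slow" in item.keywords:
                item.add_marker(skip_slow)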
@@ -1,6 +1,8 @@
import random
from functools import total_ordering
from typing import Union
from typing import Any, Type, TypeVar, Union

T = TypeVar("T", bound="Id")

@total_ordering
@@ -17,31 +19,35 @@ class Lsn:
"""Convert lsn from hex notation to int."""
l, r = x.split("/")
self.lsn_int = (int(l, 16) << 32) + int(r, 16)
# FIXME: error if it doesn't look like a valid LSN
assert 0 <= self.lsn_int <= 0xFFFFFFFF_FFFFFFFF

def __str__(self):
def __str__(self) -> str:
"""Convert lsn from int to standard hex notation."""
return "{:X}/{:X}".format(self.lsn_int >> 32, self.lsn_int & 0xFFFFFFFF)
return f"{(self.lsn_int >> 32):X}/{(self.lsn_int & 0xFFFFFFFF):X}"

def __repr__(self):
return 'Lsn("{:X}/{:X}")'.format(self.lsn_int >> 32, self.lsn_int & 0xFFFFFFFF)
def __repr__(self) -> str:
return f'Lsn("{str(self)}")'

def __int__(self):
def __int__(self) -> int:
return self.lsn_int

def __lt__(self, other: "Lsn") -> bool:
def __lt__(self, other: Any) -> bool:
if not isinstance(other, Lsn):
return NotImplemented
return self.lsn_int < other.lsn_int

def __eq__(self, other) -> bool:
def __eq__(self, other: Any) -> bool:
if not isinstance(other, Lsn):
return NotImplemented
return self.lsn_int == other.lsn_int

# Returns the difference between two Lsns, in bytes
def __sub__(self, other: "Lsn") -> int:
def __sub__(self, other: Any) -> int:
if not isinstance(other, Lsn):
return NotImplemented
return self.lsn_int - other.lsn_int

def __hash__(self):
def __hash__(self) -> int:
return hash(self.lsn_int)
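A quick worked example of the LSN encoding used above: the hex digits before the slash are the high 32 bits, the digits after are the low 32 bits.

    lsn_int = (int("1", 16) << 32) + int("4E2C5370", 16)
    print(lsn_int)                                         # 5606495088
    print(f"{lsn_int >> 32:X}/{lsn_int & 0xFFFFFFFF:X}")   # 1/4E2C5370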
@@ -57,7 +63,7 @@ class Id:
self.id = bytearray.fromhex(x)
assert len(self.id) == 16

def __str__(self):
def __str__(self) -> str:
return self.id.hex()

def __lt__(self, other) -> bool:
@@ -70,20 +76,20 @@ class Id:
return NotImplemented
return self.id == other.id

def __hash__(self):
def __hash__(self) -> int:
return hash(str(self.id))

@classmethod
def generate(cls):
def generate(cls: Type[T]) -> T:
"""Generate a random ID"""
return cls(random.randbytes(16).hex())

class TenantId(Id):
def __repr__(self):
def __repr__(self) -> str:
return f'`TenantId("{self.id.hex()}")'

class TimelineId(Id):
def __repr__(self):
def __repr__(self) -> str:
return f'TimelineId("{self.id.hex()}")'
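Typing generate as (cls: Type[T]) -> T with T bound to Id, as the hunk above does, lets mypy infer the concrete subclass at each call site instead of the base class. A minimal sketch:

    import random
    from typing import Type, TypeVar

    T = TypeVar("T", bound="Id")

    class Id:
        def __init__(self, x: str):
            self.id = bytearray.fromhex(x)

        @classmethod
        def generate(cls: Type[T]) -> T:
            return cls(random.randbytes(16).hex())

    class TenantId(Id):
        pass

    tenant = TenantId.generate()  # mypy infers TenantId here, not just Id
    print(type(tenant).__name__, tenant.id.hex())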
@@ -6,7 +6,7 @@ import subprocess
import tarfile
import time
from pathlib import Path
from typing import Any, Callable, List, Tuple, TypeVar
from typing import Any, Callable, Dict, List, Tuple, TypeVar

import allure # type: ignore
from fixtures.log_helper import log
@@ -30,11 +30,11 @@ def subprocess_capture(capture_dir: Path, cmd: List[str], **kwargs: Any) -> str:
If those files already exist, we will overwrite them.
Returns basepath for files with captured output.
"""
assert type(cmd) is list
base = os.path.basename(cmd[0]) + "_{}".format(global_counter())
assert isinstance(cmd, list)
base = f"{os.path.basename(cmd[0])}_{global_counter()}"
basepath = os.path.join(capture_dir, base)
stdout_filename = basepath + ".stdout"
stderr_filename = basepath + ".stderr"
stdout_filename = f"{basepath}.stdout"
stderr_filename = f"{basepath}.stderr"

try:
with open(stdout_filename, "w") as stdout_f:
@@ -64,7 +64,7 @@ def global_counter() -> int:
return _global_counter

def print_gc_result(row):
def print_gc_result(row: Dict[str, Any]):
log.info("GC duration {elapsed} ms".format_map(row))
log.info(
" total: {layers_total}, needed_by_cutoff {layers_needed_by_cutoff}, needed_by_pitr {layers_needed_by_pitr}"
@@ -78,8 +78,7 @@ def etcd_path() -> Path:
path_output = shutil.which("etcd")
if path_output is None:
raise RuntimeError("etcd not found in PATH")
else:
return Path(path_output)
return Path(path_output)

def query_scalar(cur: cursor, query: str) -> Any:
@@ -124,7 +123,6 @@ def get_timeline_dir_size(path: Path) -> int:
# file is a delta layer
_ = parse_delta_layer(dir_entry.name)
sz += dir_entry.stat().st_size
continue
return sz

@@ -157,8 +155,8 @@ def get_scale_for_db(size_mb: int) -> int:
return round(0.06689 * size_mb - 0.5)

ATTACHMENT_NAME_REGEX = re.compile(
r".+\.log|.+\.stderr|.+\.stdout|.+\.filediff|.+\.metrics|flamegraph\.svg|regression\.diffs|.+\.html"
ATTACHMENT_NAME_REGEX: re.Pattern = re.compile( # type: ignore[type-arg]
r"flamegraph\.svg|regression\.diffs|.+\.(?:log|stderr|stdout|filediff|metrics|html)"
)

@@ -2,7 +2,7 @@ import statistics
import threading
import time
import timeit
from typing import Callable
from typing import Any, Callable, List

import pytest
from fixtures.benchmark_fixture import MetricReport, NeonBenchmarker
@@ -197,7 +197,7 @@ def record_lsn_write_lag(env: PgCompare, run_cond: Callable[[], bool], pool_inte
if not isinstance(env, NeonCompare):
return

lsn_write_lags = []
lsn_write_lags: List[Any] = []
last_received_lsn = Lsn(0)
last_pg_flush_lsn = Lsn(0)

@@ -216,6 +216,7 @@ def record_lsn_write_lag(env: PgCompare, run_cond: Callable[[], bool], pool_inte
)

res = cur.fetchone()
assert isinstance(res, list)
lsn_write_lags.append(res[0])

curr_received_lsn = Lsn(res[3])

@@ -24,7 +24,6 @@ if __name__ == "__main__":
if (v := os.environ.get(k, None)) is not None
}

loop = asyncio.new_event_loop()
row = loop.run_until_complete(run(**kwargs))
row = asyncio.run(run(**kwargs))

print(row[0])
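The hunk above replaces the manual event-loop dance with asyncio.run, which creates a fresh loop, runs the coroutine, and closes the loop afterwards. A minimal before/after sketch with a stand-in coroutine:

    import asyncio

    async def run() -> str:
        await asyncio.sleep(0)
        return "42"

    # Old style: create and drive a loop by hand (and remember to close it).
    loop = asyncio.new_event_loop()
    try:
        print(loop.run_until_complete(run()))
    finally:
        loop.close()

    # New style: asyncio.run() owns the loop's whole lifecycle.
    print(asyncio.run(run()))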
@@ -129,6 +129,7 @@ async def test_psql_session_id(vanilla_pg: VanillaPostgres, link_proxy: NeonProx

create_and_send_db_info(vanilla_pg, psql_session_id, link_proxy.mgmt_port)

assert proc.stdout is not None
out = (await proc.stdout.read()).decode("utf-8").strip()
assert out == "42"