From 878135fe9cabbed4f1a3e15633bffb5ead61f31e Mon Sep 17 00:00:00 2001 From: Tristan Partin Date: Wed, 9 Oct 2024 14:02:09 -0500 Subject: [PATCH 01/18] Move PgBenchInitResult.EXTRACTORS to a private module constant This seems to paper over a behavioral difference in Python 3.9 and Python 3.12 with how dataclasses work with mutable variables. On Python 3.12, I get the following error: ValueError: mutable default for field EXTRACTORS is not allowed: use default_factory This obviously doesn't occur in our testing environment. When I do what the error tells me, EXTRACTORS doesn't seem to exist as an attribute on the class in at least Python 3.9. The solution provided in this commit seems like the least amount of friction to keep the wheels turning. Signed-off-by: Tristan Partin --- test_runner/fixtures/benchmark_fixture.py | 42 ++++++++++++++--------- 1 file changed, 25 insertions(+), 17 deletions(-) diff --git a/test_runner/fixtures/benchmark_fixture.py b/test_runner/fixtures/benchmark_fixture.py index 88f9ec1cd0..74fe39ef53 100644 --- a/test_runner/fixtures/benchmark_fixture.py +++ b/test_runner/fixtures/benchmark_fixture.py @@ -7,7 +7,6 @@ import json import os import re import timeit -from collections.abc import Iterator from contextlib import contextmanager from datetime import datetime from pathlib import Path @@ -25,7 +24,8 @@ from fixtures.log_helper import log from fixtures.neon_fixtures import NeonPageserver if TYPE_CHECKING: - from typing import Callable, ClassVar, Optional + from collections.abc import Iterator, Mapping + from typing import Callable, Optional """ @@ -141,6 +141,28 @@ class PgBenchRunResult: ) +# Taken from https://github.com/postgres/postgres/blob/REL_15_1/src/bin/pgbench/pgbench.c#L5144-L5171 +# +# This used to be a class variable on PgBenchInitResult. However later versions +# of Python complain: +# +# ValueError: mutable default for field EXTRACTORS is not allowed: use default_factory +# +# When you do what the error tells you to do, it seems to fail our Python 3.9 +# test environment. So let's just move it to a private module constant, and move +# on. +_PGBENCH_INIT_EXTRACTORS: Mapping[str, re.Pattern[str]] = { + "drop_tables": re.compile(r"drop tables (\d+\.\d+) s"), + "create_tables": re.compile(r"create tables (\d+\.\d+) s"), + "client_side_generate": re.compile(r"client-side generate (\d+\.\d+) s"), + "server_side_generate": re.compile(r"server-side generate (\d+\.\d+) s"), + "vacuum": re.compile(r"vacuum (\d+\.\d+) s"), + "primary_keys": re.compile(r"primary keys (\d+\.\d+) s"), + "foreign_keys": re.compile(r"foreign keys (\d+\.\d+) s"), + "total": re.compile(r"done in (\d+\.\d+) s"), # Total time printed by pgbench +} + + @dataclasses.dataclass class PgBenchInitResult: total: Optional[float] @@ -155,20 +177,6 @@ class PgBenchInitResult: start_timestamp: int end_timestamp: int - # Taken from https://github.com/postgres/postgres/blob/REL_15_1/src/bin/pgbench/pgbench.c#L5144-L5171 - EXTRACTORS: ClassVar[dict[str, re.Pattern[str]]] = dataclasses.field( - default_factory=lambda: { - "drop_tables": re.compile(r"drop tables (\d+\.\d+) s"), - "create_tables": re.compile(r"create tables (\d+\.\d+) s"), - "client_side_generate": re.compile(r"client-side generate (\d+\.\d+) s"), - "server_side_generate": re.compile(r"server-side generate (\d+\.\d+) s"), - "vacuum": re.compile(r"vacuum (\d+\.\d+) s"), - "primary_keys": re.compile(r"primary keys (\d+\.\d+) s"), - "foreign_keys": re.compile(r"foreign keys (\d+\.\d+) s"), - "total": re.compile(r"done in (\d+\.\d+) s"), # Total time printed by pgbench - } - ) - @classmethod def parse_from_stderr( cls, @@ -185,7 +193,7 @@ class PgBenchInitResult: timings: dict[str, Optional[float]] = {} last_line_items = re.split(r"\(|\)|,", last_line) for item in last_line_items: - for key, regex in cls.EXTRACTORS.items(): + for key, regex in _PGBENCH_INIT_EXTRACTORS.items(): if (m := regex.match(item.strip())) is not None: if key in timings: raise RuntimeError( From d3464584a6fb03b9df264f32c18963388808ba2e Mon Sep 17 00:00:00 2001 From: Tristan Partin Date: Wed, 9 Oct 2024 15:42:22 -0500 Subject: [PATCH 02/18] Improve some typing in test_runner Fixes some types, adds some types, and adds some override annotations. Signed-off-by: Tristan Partin --- test_runner/fixtures/common_types.py | 37 +++++++++--- test_runner/fixtures/compare_fixtures.py | 36 ++++++++++- test_runner/fixtures/compute_reconfigure.py | 20 ++++--- test_runner/fixtures/flaky.py | 2 + test_runner/fixtures/httpserver.py | 13 +++- test_runner/fixtures/log_helper.py | 2 +- test_runner/fixtures/metrics.py | 4 +- test_runner/fixtures/neon_api.py | 12 ++-- test_runner/fixtures/neon_fixtures.py | 6 +- test_runner/fixtures/overlayfs.py | 5 +- test_runner/fixtures/parametrize.py | 7 ++- test_runner/fixtures/pg_version.py | 12 +++- test_runner/fixtures/port_distributor.py | 7 +-- test_runner/fixtures/remote_storage.py | 24 ++++---- .../fixtures/storage_controller_proxy.py | 9 ++- test_runner/fixtures/utils.py | 59 ++++++++++--------- test_runner/fixtures/workload.py | 11 +++- test_runner/regress/test_compaction.py | 6 +- .../regress/test_pageserver_generations.py | 6 +- test_runner/regress/test_sharding.py | 33 +++++++++-- .../regress/test_storage_controller.py | 2 +- test_runner/regress/test_storage_scrubber.py | 5 +- 22 files changed, 216 insertions(+), 102 deletions(-) diff --git a/test_runner/fixtures/common_types.py b/test_runner/fixtures/common_types.py index 3022c0279f..0ea7148f50 100644 --- a/test_runner/fixtures/common_types.py +++ b/test_runner/fixtures/common_types.py @@ -6,6 +6,8 @@ from enum import Enum from functools import total_ordering from typing import TYPE_CHECKING, TypeVar +from typing_extensions import override + if TYPE_CHECKING: from typing import Any, Union @@ -31,33 +33,36 @@ class Lsn: self.lsn_int = (int(left, 16) << 32) + int(right, 16) assert 0 <= self.lsn_int <= 0xFFFFFFFF_FFFFFFFF + @override def __str__(self) -> str: """Convert lsn from int to standard hex notation.""" return f"{(self.lsn_int >> 32):X}/{(self.lsn_int & 0xFFFFFFFF):X}" + @override def __repr__(self) -> str: return f'Lsn("{str(self)}")' def __int__(self) -> int: return self.lsn_int - def __lt__(self, other: Any) -> bool: + def __lt__(self, other: object) -> bool: if not isinstance(other, Lsn): return NotImplemented return self.lsn_int < other.lsn_int - def __gt__(self, other: Any) -> bool: + def __gt__(self, other: object) -> bool: if not isinstance(other, Lsn): raise NotImplementedError return self.lsn_int > other.lsn_int - def __eq__(self, other: Any) -> bool: + @override + def __eq__(self, other: object) -> bool: if not isinstance(other, Lsn): return NotImplemented return self.lsn_int == other.lsn_int # Returns the difference between two Lsns, in bytes - def __sub__(self, other: Any) -> int: + def __sub__(self, other: object) -> int: if not isinstance(other, Lsn): return NotImplemented return self.lsn_int - other.lsn_int @@ -70,6 +75,7 @@ class Lsn: else: raise NotImplementedError + @override def __hash__(self) -> int: return hash(self.lsn_int) @@ -116,19 +122,22 @@ class Id: self.id = bytearray.fromhex(x) assert len(self.id) == 16 + @override def __str__(self) -> str: return self.id.hex() - def __lt__(self, other) -> bool: + def __lt__(self, other: object) -> bool: if not isinstance(other, type(self)): return NotImplemented return self.id < other.id - def __eq__(self, other) -> bool: + @override + def __eq__(self, other: object) -> bool: if not isinstance(other, type(self)): return NotImplemented return self.id == other.id + @override def __hash__(self) -> int: return hash(str(self.id)) @@ -139,25 +148,31 @@ class Id: class TenantId(Id): + @override def __repr__(self) -> str: return f'`TenantId("{self.id.hex()}")' + @override def __str__(self) -> str: return self.id.hex() class NodeId(Id): + @override def __repr__(self) -> str: return f'`NodeId("{self.id.hex()}")' + @override def __str__(self) -> str: return self.id.hex() class TimelineId(Id): + @override def __repr__(self) -> str: return f'TimelineId("{self.id.hex()}")' + @override def __str__(self) -> str: return self.id.hex() @@ -187,7 +202,7 @@ class TenantShardId: assert self.shard_number < self.shard_count or self.shard_count == 0 @classmethod - def parse(cls: type[TTenantShardId], input) -> TTenantShardId: + def parse(cls: type[TTenantShardId], input: str) -> TTenantShardId: if len(input) == 32: return cls( tenant_id=TenantId(input), @@ -203,6 +218,7 @@ class TenantShardId: else: raise ValueError(f"Invalid TenantShardId '{input}'") + @override def __str__(self): if self.shard_count > 0: return f"{self.tenant_id}-{self.shard_number:02x}{self.shard_count:02x}" @@ -210,22 +226,25 @@ class TenantShardId: # Unsharded case: equivalent of Rust TenantShardId::unsharded(tenant_id) return str(self.tenant_id) + @override def __repr__(self): return self.__str__() def _tuple(self) -> tuple[TenantId, int, int]: return (self.tenant_id, self.shard_number, self.shard_count) - def __lt__(self, other) -> bool: + def __lt__(self, other: object) -> bool: if not isinstance(other, type(self)): return NotImplemented return self._tuple() < other._tuple() - def __eq__(self, other) -> bool: + @override + def __eq__(self, other: object) -> bool: if not isinstance(other, type(self)): return NotImplemented return self._tuple() == other._tuple() + @override def __hash__(self) -> int: return hash(self._tuple()) diff --git a/test_runner/fixtures/compare_fixtures.py b/test_runner/fixtures/compare_fixtures.py index ce191ac91c..2195ae8225 100644 --- a/test_runner/fixtures/compare_fixtures.py +++ b/test_runner/fixtures/compare_fixtures.py @@ -8,9 +8,11 @@ from contextlib import _GeneratorContextManager, contextmanager # Type-related stuff from pathlib import Path +from typing import TYPE_CHECKING import pytest from _pytest.fixtures import FixtureRequest +from typing_extensions import override from fixtures.benchmark_fixture import MetricReport, NeonBenchmarker from fixtures.log_helper import log @@ -24,6 +26,9 @@ from fixtures.neon_fixtures import ( ) from fixtures.pg_stats import PgStatTable +if TYPE_CHECKING: + from collections.abc import Iterator + class PgCompare(ABC): """Common interface of all postgres implementations, useful for benchmarks. @@ -65,12 +70,12 @@ class PgCompare(ABC): @contextmanager @abstractmethod - def record_pageserver_writes(self, out_name): + def record_pageserver_writes(self, out_name: str): pass @contextmanager @abstractmethod - def record_duration(self, out_name): + def record_duration(self, out_name: str): pass @contextmanager @@ -122,28 +127,34 @@ class NeonCompare(PgCompare): self._pg = self.env.endpoints.create_start("main", "main", self.tenant) @property + @override def pg(self) -> PgProtocol: return self._pg @property + @override def zenbenchmark(self) -> NeonBenchmarker: return self._zenbenchmark @property + @override def pg_bin(self) -> PgBin: return self._pg_bin + @override def flush(self, compact: bool = True, gc: bool = True): wait_for_last_flush_lsn(self.env, self._pg, self.tenant, self.timeline) self.pageserver_http_client.timeline_checkpoint(self.tenant, self.timeline, compact=compact) if gc: self.pageserver_http_client.timeline_gc(self.tenant, self.timeline, 0) + @override def compact(self): self.pageserver_http_client.timeline_compact( self.tenant, self.timeline, wait_until_uploaded=True ) + @override def report_peak_memory_use(self): self.zenbenchmark.record( "peak_mem", @@ -152,6 +163,7 @@ class NeonCompare(PgCompare): report=MetricReport.LOWER_IS_BETTER, ) + @override def report_size(self): timeline_size = self.zenbenchmark.get_timeline_size( self.env.repo_dir, self.tenant, self.timeline @@ -185,9 +197,11 @@ class NeonCompare(PgCompare): "num_files_uploaded", total_files, "", report=MetricReport.LOWER_IS_BETTER ) + @override def record_pageserver_writes(self, out_name: str) -> _GeneratorContextManager[None]: return self.zenbenchmark.record_pageserver_writes(self.env.pageserver, out_name) + @override def record_duration(self, out_name: str) -> _GeneratorContextManager[None]: return self.zenbenchmark.record_duration(out_name) @@ -211,26 +225,33 @@ class VanillaCompare(PgCompare): self.cur = self.conn.cursor() @property + @override def pg(self) -> VanillaPostgres: return self._pg @property + @override def zenbenchmark(self) -> NeonBenchmarker: return self._zenbenchmark @property + @override def pg_bin(self) -> PgBin: return self._pg.pg_bin + @override def flush(self, compact: bool = False, gc: bool = False): self.cur.execute("checkpoint") + @override def compact(self): pass + @override def report_peak_memory_use(self): pass # TODO find something + @override def report_size(self): data_size = self.pg.get_subdir_size(Path("base")) self.zenbenchmark.record( @@ -245,6 +266,7 @@ class VanillaCompare(PgCompare): def record_pageserver_writes(self, out_name: str) -> Iterator[None]: yield # Do nothing + @override def record_duration(self, out_name: str) -> _GeneratorContextManager[None]: return self.zenbenchmark.record_duration(out_name) @@ -261,28 +283,35 @@ class RemoteCompare(PgCompare): self.cur = self.conn.cursor() @property + @override def pg(self) -> PgProtocol: return self._pg @property + @override def zenbenchmark(self) -> NeonBenchmarker: return self._zenbenchmark @property + @override def pg_bin(self) -> PgBin: return self._pg.pg_bin - def flush(self): + @override + def flush(self, compact: bool = False, gc: bool = False): # TODO: flush the remote pageserver pass + @override def compact(self): pass + @override def report_peak_memory_use(self): # TODO: get memory usage from remote pageserver pass + @override def report_size(self): # TODO: get storage size from remote pageserver pass @@ -291,6 +320,7 @@ class RemoteCompare(PgCompare): def record_pageserver_writes(self, out_name: str) -> Iterator[None]: yield # Do nothing + @override def record_duration(self, out_name: str) -> _GeneratorContextManager[None]: return self.zenbenchmark.record_duration(out_name) diff --git a/test_runner/fixtures/compute_reconfigure.py b/test_runner/fixtures/compute_reconfigure.py index d2305ea431..6354b7f833 100644 --- a/test_runner/fixtures/compute_reconfigure.py +++ b/test_runner/fixtures/compute_reconfigure.py @@ -1,27 +1,31 @@ from __future__ import annotations import concurrent.futures -from typing import Any +from typing import TYPE_CHECKING import pytest +from pytest_httpserver import HTTPServer from werkzeug.wrappers.request import Request from werkzeug.wrappers.response import Response from fixtures.common_types import TenantId from fixtures.log_helper import log +if TYPE_CHECKING: + from typing import Any, Callable, Optional + class ComputeReconfigure: - def __init__(self, server): + def __init__(self, server: HTTPServer): self.server = server self.control_plane_compute_hook_api = f"http://{server.host}:{server.port}/notify-attach" - self.workloads = {} - self.on_notify = None + self.workloads: dict[TenantId, Any] = {} + self.on_notify: Optional[Callable[[Any], None]] = None - def register_workload(self, workload): + def register_workload(self, workload: Any): self.workloads[workload.tenant_id] = workload - def register_on_notify(self, fn): + def register_on_notify(self, fn: Optional[Callable[[Any], None]]): """ Add some extra work during a notification, like sleeping to slow things down, or logging what was notified. @@ -30,7 +34,7 @@ class ComputeReconfigure: @pytest.fixture(scope="function") -def compute_reconfigure_listener(make_httpserver): +def compute_reconfigure_listener(make_httpserver: HTTPServer): """ This fixture exposes an HTTP listener for the storage controller to submit compute notifications to us, instead of updating neon_local endpoints itself. @@ -48,7 +52,7 @@ def compute_reconfigure_listener(make_httpserver): # accept a healthy rate of calls into notify-attach. reconfigure_threads = concurrent.futures.ThreadPoolExecutor(max_workers=1) - def handler(request: Request): + def handler(request: Request) -> Response: assert request.json is not None body: dict[str, Any] = request.json log.info(f"notify-attach request: {body}") diff --git a/test_runner/fixtures/flaky.py b/test_runner/fixtures/flaky.py index 4ca87520a0..01634a29c5 100644 --- a/test_runner/fixtures/flaky.py +++ b/test_runner/fixtures/flaky.py @@ -14,8 +14,10 @@ from allure_pytest.utils import allure_name, allure_suite_labels from fixtures.log_helper import log if TYPE_CHECKING: + from collections.abc import MutableMapping from typing import Any + """ The plugin reruns flaky tests. It uses `pytest.mark.flaky` provided by `pytest-rerunfailures` plugin and flaky tests detected by `scripts/flaky_tests.py` diff --git a/test_runner/fixtures/httpserver.py b/test_runner/fixtures/httpserver.py index 9d5b5d6422..f653fd804c 100644 --- a/test_runner/fixtures/httpserver.py +++ b/test_runner/fixtures/httpserver.py @@ -1,8 +1,15 @@ from __future__ import annotations +from typing import TYPE_CHECKING + import pytest from pytest_httpserver import HTTPServer +if TYPE_CHECKING: + from collections.abc import Iterator + + from fixtures.port_distributor import PortDistributor + # TODO: mypy fails with: # Module "fixtures.neon_fixtures" does not explicitly export attribute "PortDistributor" [attr-defined] # from fixtures.neon_fixtures import PortDistributor @@ -17,7 +24,7 @@ def httpserver_ssl_context(): @pytest.fixture(scope="function") -def make_httpserver(httpserver_listen_address, httpserver_ssl_context): +def make_httpserver(httpserver_listen_address, httpserver_ssl_context) -> Iterator[HTTPServer]: host, port = httpserver_listen_address if not host: host = HTTPServer.DEFAULT_LISTEN_HOST @@ -33,13 +40,13 @@ def make_httpserver(httpserver_listen_address, httpserver_ssl_context): @pytest.fixture(scope="function") -def httpserver(make_httpserver): +def httpserver(make_httpserver: HTTPServer) -> Iterator[HTTPServer]: server = make_httpserver yield server server.clear() @pytest.fixture(scope="function") -def httpserver_listen_address(port_distributor) -> tuple[str, int]: +def httpserver_listen_address(port_distributor: PortDistributor) -> tuple[str, int]: port = port_distributor.get_port() return ("localhost", port) diff --git a/test_runner/fixtures/log_helper.py b/test_runner/fixtures/log_helper.py index 70d76a39c4..ebf5c8d803 100644 --- a/test_runner/fixtures/log_helper.py +++ b/test_runner/fixtures/log_helper.py @@ -31,7 +31,7 @@ LOGGING = { } -def getLogger(name="root") -> logging.Logger: +def getLogger(name: str = "root") -> logging.Logger: """Method to get logger for tests. Should be used to get correctly initialized logger.""" diff --git a/test_runner/fixtures/metrics.py b/test_runner/fixtures/metrics.py index adc90a41d0..e056ea77d4 100644 --- a/test_runner/fixtures/metrics.py +++ b/test_runner/fixtures/metrics.py @@ -22,7 +22,7 @@ class Metrics: def query_all(self, name: str, filter: Optional[dict[str, str]] = None) -> list[Sample]: filter = filter or {} - res = [] + res: list[Sample] = [] for sample in self.metrics[name]: try: @@ -59,7 +59,7 @@ class MetricsGetter: return results[0].value def get_metrics_values( - self, names: list[str], filter: Optional[dict[str, str]] = None, absence_ok=False + self, names: list[str], filter: Optional[dict[str, str]] = None, absence_ok: bool = False ) -> dict[str, float]: """ When fetching multiple named metrics, it is more efficient to use this diff --git a/test_runner/fixtures/neon_api.py b/test_runner/fixtures/neon_api.py index 846a790f1f..683ea3af44 100644 --- a/test_runner/fixtures/neon_api.py +++ b/test_runner/fixtures/neon_api.py @@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, cast import requests if TYPE_CHECKING: - from typing import Any, Literal, Optional, Union + from typing import Any, Literal, Optional from fixtures.pg_version import PgVersion @@ -25,9 +25,7 @@ class NeonAPI: self.__neon_api_key = neon_api_key self.__neon_api_base_url = neon_api_base_url.strip("/") - def __request( - self, method: Union[str, bytes], endpoint: str, **kwargs: Any - ) -> requests.Response: + def __request(self, method: str | bytes, endpoint: str, **kwargs: Any) -> requests.Response: if "headers" not in kwargs: kwargs["headers"] = {} kwargs["headers"]["Authorization"] = f"Bearer {self.__neon_api_key}" @@ -187,8 +185,8 @@ class NeonAPI: def get_connection_uri( self, project_id: str, - branch_id: Optional[str] = None, - endpoint_id: Optional[str] = None, + branch_id: str | None = None, + endpoint_id: str | None = None, database_name: str = "neondb", role_name: str = "neondb_owner", pooled: bool = True, @@ -264,7 +262,7 @@ class NeonAPI: class NeonApiEndpoint: - def __init__(self, neon_api: NeonAPI, pg_version: PgVersion, project_id: Optional[str]): + def __init__(self, neon_api: NeonAPI, pg_version: PgVersion, project_id: str | None): self.neon_api = neon_api if project_id is None: project = neon_api.create_project(pg_version) diff --git a/test_runner/fixtures/neon_fixtures.py b/test_runner/fixtures/neon_fixtures.py index 5cb9821476..f81bc3f5a6 100644 --- a/test_runner/fixtures/neon_fixtures.py +++ b/test_runner/fixtures/neon_fixtures.py @@ -3657,7 +3657,7 @@ class Endpoint(PgProtocol, LogUtils): config_lines: Optional[list[str]] = None, remote_ext_config: Optional[str] = None, pageserver_id: Optional[int] = None, - allow_multiple=False, + allow_multiple: bool = False, basebackup_request_tries: Optional[int] = None, ) -> Endpoint: """ @@ -3998,7 +3998,7 @@ class Safekeeper(LogUtils): def timeline_dir(self, tenant_id, timeline_id) -> Path: return self.data_dir / str(tenant_id) / str(timeline_id) - # List partial uploaded segments of this safekeeper. Works only for + # list partial uploaded segments of this safekeeper. Works only for # RemoteStorageKind.LOCAL_FS. def list_uploaded_segments(self, tenant_id: TenantId, timeline_id: TimelineId): tline_path = ( @@ -4293,7 +4293,7 @@ def pytest_addoption(parser: Parser): ) -SMALL_DB_FILE_NAME_REGEX: re.Pattern = re.compile( # type: ignore[type-arg] +SMALL_DB_FILE_NAME_REGEX: re.Pattern[str] = re.compile( r"config-v1|heatmap-v1|metadata|.+\.(?:toml|pid|json|sql|conf)" ) diff --git a/test_runner/fixtures/overlayfs.py b/test_runner/fixtures/overlayfs.py index e0ebfeb8f4..ea11cd272c 100644 --- a/test_runner/fixtures/overlayfs.py +++ b/test_runner/fixtures/overlayfs.py @@ -1,10 +1,13 @@ from __future__ import annotations -from collections.abc import Iterator from pathlib import Path +from typing import TYPE_CHECKING import psutil +if TYPE_CHECKING: + from collections.abc import Iterator + def iter_mounts_beneath(topdir: Path) -> Iterator[Path]: """ diff --git a/test_runner/fixtures/parametrize.py b/test_runner/fixtures/parametrize.py index 3bbac4b8ee..4114c2fcb3 100644 --- a/test_runner/fixtures/parametrize.py +++ b/test_runner/fixtures/parametrize.py @@ -9,7 +9,12 @@ import toml from _pytest.python import Metafunc from fixtures.pg_version import PgVersion -from fixtures.utils import AuxFileStore + +if TYPE_CHECKING: + from typing import Any, Optional + + from fixtures.utils import AuxFileStore + if TYPE_CHECKING: from typing import Any, Optional diff --git a/test_runner/fixtures/pg_version.py b/test_runner/fixtures/pg_version.py index 5820b50a46..01f0245665 100644 --- a/test_runner/fixtures/pg_version.py +++ b/test_runner/fixtures/pg_version.py @@ -2,9 +2,14 @@ from __future__ import annotations import enum import os -from typing import Optional +from typing import TYPE_CHECKING import pytest +from typing_extensions import override + +if TYPE_CHECKING: + from typing import Optional + """ This fixture is used to determine which version of Postgres to use for tests. @@ -24,10 +29,12 @@ class PgVersion(str, enum.Enum): NOT_SET = "<-POSTRGRES VERSION IS NOT SET->" # Make it less confusing in logs + @override def __repr__(self) -> str: return f"'{self.value}'" # Make this explicit for Python 3.11 compatibility, which changes the behavior of enums + @override def __str__(self) -> str: return self.value @@ -38,7 +45,8 @@ class PgVersion(str, enum.Enum): return f"v{self.value}" @classmethod - def _missing_(cls, value) -> Optional[PgVersion]: + @override + def _missing_(cls, value: object) -> Optional[PgVersion]: known_values = {v.value for _, v in cls.__members__.items()} # Allow passing version as a string with "v" prefix (e.g. "v14") diff --git a/test_runner/fixtures/port_distributor.py b/test_runner/fixtures/port_distributor.py index 435f452a02..df0eb2a809 100644 --- a/test_runner/fixtures/port_distributor.py +++ b/test_runner/fixtures/port_distributor.py @@ -59,10 +59,7 @@ class PortDistributor: if isinstance(value, int): return self._replace_port_int(value) - if isinstance(value, str): - return self._replace_port_str(value) - - raise TypeError(f"unsupported type {type(value)} of {value=}") + return self._replace_port_str(value) def _replace_port_int(self, value: int) -> int: known_port = self.port_map.get(value) @@ -75,7 +72,7 @@ class PortDistributor: # Use regex to find port in a string # urllib.parse.urlparse produces inconvenient results for cases without scheme like "localhost:5432" # See https://bugs.python.org/issue27657 - ports = re.findall(r":(\d+)(?:/|$)", value) + ports: list[str] = re.findall(r":(\d+)(?:/|$)", value) assert len(ports) == 1, f"can't find port in {value}" port_int = int(ports[0]) diff --git a/test_runner/fixtures/remote_storage.py b/test_runner/fixtures/remote_storage.py index 20e6bd9318..7024953661 100644 --- a/test_runner/fixtures/remote_storage.py +++ b/test_runner/fixtures/remote_storage.py @@ -13,6 +13,7 @@ import boto3 import toml from moto.server import ThreadedMotoServer from mypy_boto3_s3 import S3Client +from typing_extensions import override from fixtures.common_types import TenantId, TenantShardId, TimelineId from fixtures.log_helper import log @@ -36,6 +37,7 @@ class RemoteStorageUser(str, enum.Enum): EXTENSIONS = "ext" SAFEKEEPER = "safekeeper" + @override def __str__(self) -> str: return self.value @@ -81,11 +83,13 @@ class LocalFsStorage: def timeline_path(self, tenant_id: TenantId, timeline_id: TimelineId) -> Path: return self.tenant_path(tenant_id) / "timelines" / str(timeline_id) - def timeline_latest_generation(self, tenant_id, timeline_id): + def timeline_latest_generation( + self, tenant_id: TenantId, timeline_id: TimelineId + ) -> Optional[int]: timeline_files = os.listdir(self.timeline_path(tenant_id, timeline_id)) index_parts = [f for f in timeline_files if f.startswith("index_part")] - def parse_gen(filename): + def parse_gen(filename: str) -> Optional[int]: log.info(f"parsing index_part '{filename}'") parts = filename.split("-") if len(parts) == 2: @@ -93,7 +97,7 @@ class LocalFsStorage: else: return None - generations = sorted([parse_gen(f) for f in index_parts]) + generations = sorted([parse_gen(f) for f in index_parts]) # type: ignore if len(generations) == 0: raise RuntimeError(f"No index_part found for {tenant_id}/{timeline_id}") return generations[-1] @@ -122,14 +126,14 @@ class LocalFsStorage: filename = f"{local_name}-{generation:08x}" return self.timeline_path(tenant_id, timeline_id) / filename - def index_content(self, tenant_id: TenantId, timeline_id: TimelineId): + def index_content(self, tenant_id: TenantId, timeline_id: TimelineId) -> Any: with self.index_path(tenant_id, timeline_id).open("r") as f: return json.load(f) def heatmap_path(self, tenant_id: TenantId) -> Path: return self.tenant_path(tenant_id) / TENANT_HEATMAP_FILE_NAME - def heatmap_content(self, tenant_id): + def heatmap_content(self, tenant_id: TenantId) -> Any: with self.heatmap_path(tenant_id).open("r") as f: return json.load(f) @@ -297,7 +301,7 @@ class S3Storage: def heatmap_key(self, tenant_id: TenantId) -> str: return f"{self.tenant_path(tenant_id)}/{TENANT_HEATMAP_FILE_NAME}" - def heatmap_content(self, tenant_id: TenantId): + def heatmap_content(self, tenant_id: TenantId) -> Any: r = self.client.get_object(Bucket=self.bucket_name, Key=self.heatmap_key(tenant_id)) return json.loads(r["Body"].read().decode("utf-8")) @@ -317,7 +321,7 @@ class RemoteStorageKind(str, enum.Enum): def configure( self, repo_dir: Path, - mock_s3_server, + mock_s3_server: MockS3Server, run_id: str, test_name: str, user: RemoteStorageUser, @@ -451,15 +455,9 @@ def default_remote_storage() -> RemoteStorageKind: def remote_storage_to_toml_dict(remote_storage: RemoteStorage) -> dict[str, Any]: - if not isinstance(remote_storage, (LocalFsStorage, S3Storage)): - raise Exception("invalid remote storage type") - return remote_storage.to_toml_dict() # serialize as toml inline table def remote_storage_to_toml_inline_table(remote_storage: RemoteStorage) -> str: - if not isinstance(remote_storage, (LocalFsStorage, S3Storage)): - raise Exception("invalid remote storage type") - return remote_storage.to_toml_inline_table() diff --git a/test_runner/fixtures/storage_controller_proxy.py b/test_runner/fixtures/storage_controller_proxy.py index 02cf6fc33f..c174358ef5 100644 --- a/test_runner/fixtures/storage_controller_proxy.py +++ b/test_runner/fixtures/storage_controller_proxy.py @@ -1,7 +1,7 @@ from __future__ import annotations import re -from typing import Any, Optional +from typing import TYPE_CHECKING import pytest import requests @@ -12,6 +12,9 @@ from werkzeug.wrappers.response import Response from fixtures.log_helper import log +if TYPE_CHECKING: + from typing import Any, Optional + class StorageControllerProxy: def __init__(self, server: HTTPServer): @@ -34,7 +37,7 @@ def proxy_request(method: str, url: str, **kwargs) -> requests.Response: @pytest.fixture(scope="function") -def storage_controller_proxy(make_httpserver): +def storage_controller_proxy(make_httpserver: HTTPServer): """ Proxies requests into the storage controller to the currently selected storage controller instance via `StorageControllerProxy.route_to`. @@ -48,7 +51,7 @@ def storage_controller_proxy(make_httpserver): log.info(f"Storage controller proxy listening on {self.listen}") - def handler(request: Request): + def handler(request: Request) -> Response: if self.route_to is None: log.info(f"Storage controller proxy has no routing configured for {request.url}") return Response("Routing not configured", status=503) diff --git a/test_runner/fixtures/utils.py b/test_runner/fixtures/utils.py index 23381e258a..ca1be35880 100644 --- a/test_runner/fixtures/utils.py +++ b/test_runner/fixtures/utils.py @@ -18,6 +18,7 @@ from urllib.parse import urlencode import allure import zstandard from psycopg2.extensions import cursor +from typing_extensions import override from fixtures.log_helper import log from fixtures.pageserver.common_types import ( @@ -26,14 +27,14 @@ from fixtures.pageserver.common_types import ( ) if TYPE_CHECKING: - from typing import ( - IO, - Optional, - Union, - ) + from collections.abc import Iterable + from typing import IO, Optional + from fixtures.common_types import TimelineId from fixtures.neon_fixtures import PgBin -from fixtures.common_types import TimelineId + + WaitUntilRet = TypeVar("WaitUntilRet") + Fn = TypeVar("Fn", bound=Callable[..., Any]) @@ -42,12 +43,12 @@ def subprocess_capture( capture_dir: Path, cmd: list[str], *, - check=False, - echo_stderr=False, - echo_stdout=False, - capture_stdout=False, - timeout=None, - with_command_header=True, + check: bool = False, + echo_stderr: bool = False, + echo_stdout: bool = False, + capture_stdout: bool = False, + timeout: Optional[float] = None, + with_command_header: bool = True, **popen_kwargs: Any, ) -> tuple[str, Optional[str], int]: """Run a process and bifurcate its output to files and the `log` logger @@ -84,6 +85,7 @@ def subprocess_capture( self.capture = capture self.captured = "" + @override def run(self): first = with_command_header for line in self.in_file: @@ -165,10 +167,10 @@ def global_counter() -> int: def print_gc_result(row: dict[str, Any]): log.info("GC duration {elapsed} ms".format_map(row)) log.info( - " total: {layers_total}, needed_by_cutoff {layers_needed_by_cutoff}, needed_by_pitr {layers_needed_by_pitr}" - " needed_by_branches: {layers_needed_by_branches}, not_updated: {layers_not_updated}, removed: {layers_removed}".format_map( - row - ) + ( + " total: {layers_total}, needed_by_cutoff {layers_needed_by_cutoff}, needed_by_pitr {layers_needed_by_pitr}" + " needed_by_branches: {layers_needed_by_branches}, not_updated: {layers_not_updated}, removed: {layers_removed}" + ).format_map(row) ) @@ -226,7 +228,7 @@ def get_scale_for_db(size_mb: int) -> int: return round(0.06689 * size_mb - 0.5) -ATTACHMENT_NAME_REGEX: re.Pattern = re.compile( # type: ignore[type-arg] +ATTACHMENT_NAME_REGEX: re.Pattern[str] = re.compile( r"regression\.(diffs|out)|.+\.(?:log|stderr|stdout|filediff|metrics|html|walredo)" ) @@ -289,7 +291,7 @@ LOGS_STAGING_DATASOURCE_ID = "xHHYY0dVz" def allure_add_grafana_links(host: str, timeline_id: TimelineId, start_ms: int, end_ms: int): """Add links to server logs in Grafana to Allure report""" - links = {} + links: dict[str, str] = {} # We expect host to be in format like ep-divine-night-159320.us-east-2.aws.neon.build endpoint_id, region_id, _ = host.split(".", 2) @@ -341,7 +343,7 @@ def allure_add_grafana_links(host: str, timeline_id: TimelineId, start_ms: int, def start_in_background( - command: list[str], cwd: Path, log_file_name: str, is_started: Fn + command: list[str], cwd: Path, log_file_name: str, is_started: Callable[[], WaitUntilRet] ) -> subprocess.Popen[bytes]: """Starts a process, creates the logfile and redirects stderr and stdout there. Runs the start checks before the process is started, or errors.""" @@ -376,14 +378,11 @@ def start_in_background( return spawned_process -WaitUntilRet = TypeVar("WaitUntilRet") - - def wait_until( number_of_iterations: int, interval: float, func: Callable[[], WaitUntilRet], - show_intermediate_error=False, + show_intermediate_error: bool = False, ) -> WaitUntilRet: """ Wait until 'func' returns successfully, without exception. Returns the @@ -464,7 +463,7 @@ def humantime_to_ms(humantime: str) -> float: def scan_log_for_errors(input: Iterable[str], allowed_errors: list[str]) -> list[tuple[int, str]]: # FIXME: this duplicates test_runner/fixtures/pageserver/allowed_errors.py error_or_warn = re.compile(r"\s(ERROR|WARN)") - errors = [] + errors: list[tuple[int, str]] = [] for lineno, line in enumerate(input, start=1): if len(line) == 0: continue @@ -484,7 +483,7 @@ def scan_log_for_errors(input: Iterable[str], allowed_errors: list[str]) -> list return errors -def assert_no_errors(log_file, service, allowed_errors): +def assert_no_errors(log_file: Path, service: str, allowed_errors: list[str]): if not log_file.exists(): log.warning(f"Skipping {service} log check: {log_file} does not exist") return @@ -504,9 +503,11 @@ class AuxFileStore(str, enum.Enum): V2 = "v2" CrossValidation = "cross-validation" + @override def __repr__(self) -> str: return f"'aux-{self.value}'" + @override def __str__(self) -> str: return f"'aux-{self.value}'" @@ -525,7 +526,7 @@ def assert_pageserver_backups_equal(left: Path, right: Path, skip_files: set[str """ started_at = time.time() - def hash_extracted(reader: Union[IO[bytes], None]) -> bytes: + def hash_extracted(reader: Optional[IO[bytes]]) -> bytes: assert reader is not None digest = sha256(usedforsecurity=False) while True: @@ -550,7 +551,7 @@ def assert_pageserver_backups_equal(left: Path, right: Path, skip_files: set[str right_list ), f"unexpected number of files on tar files, {len(left_list)} != {len(right_list)}" - mismatching = set() + mismatching: set[str] = set() for left_tuple, right_tuple in zip(left_list, right_list): left_path, left_hash = left_tuple @@ -575,6 +576,7 @@ class PropagatingThread(threading.Thread): Simple Thread wrapper with join() propagating the possible exception in the thread. """ + @override def run(self): self.exc = None try: @@ -582,7 +584,8 @@ class PropagatingThread(threading.Thread): except BaseException as e: self.exc = e - def join(self, timeout=None): + @override + def join(self, timeout: Optional[float] = None) -> Any: super().join(timeout) if self.exc: raise self.exc diff --git a/test_runner/fixtures/workload.py b/test_runner/fixtures/workload.py index 4f9c1125bf..e869c43185 100644 --- a/test_runner/fixtures/workload.py +++ b/test_runner/fixtures/workload.py @@ -1,7 +1,7 @@ from __future__ import annotations import threading -from typing import Any, Optional +from typing import TYPE_CHECKING from fixtures.common_types import TenantId, TimelineId from fixtures.log_helper import log @@ -14,6 +14,9 @@ from fixtures.neon_fixtures import ( ) from fixtures.pageserver.utils import wait_for_last_record_lsn +if TYPE_CHECKING: + from typing import Any, Optional + # neon_local doesn't handle creating/modifying endpoints concurrently, so we use a mutex # to ensure we don't do that: this enables running lots of Workloads in parallel safely. ENDPOINT_LOCK = threading.Lock() @@ -100,7 +103,7 @@ class Workload: self.env, endpoint, self.tenant_id, self.timeline_id, pageserver_id=pageserver_id ) - def write_rows(self, n, pageserver_id: Optional[int] = None, upload: bool = True): + def write_rows(self, n: int, pageserver_id: Optional[int] = None, upload: bool = True): endpoint = self.endpoint(pageserver_id) start = self.expect_rows end = start + n - 1 @@ -121,7 +124,9 @@ class Workload: else: return False - def churn_rows(self, n, pageserver_id: Optional[int] = None, upload=True, ingest=True): + def churn_rows( + self, n: int, pageserver_id: Optional[int] = None, upload: bool = True, ingest: bool = True + ): assert self.expect_rows >= n max_iters = 10 diff --git a/test_runner/regress/test_compaction.py b/test_runner/regress/test_compaction.py index 39d4a3a6d7..420055ac3a 100644 --- a/test_runner/regress/test_compaction.py +++ b/test_runner/regress/test_compaction.py @@ -4,7 +4,7 @@ import enum import json import os import time -from typing import Optional +from typing import TYPE_CHECKING import pytest from fixtures.log_helper import log @@ -16,6 +16,10 @@ from fixtures.pageserver.http import PageserverApiException from fixtures.utils import wait_until from fixtures.workload import Workload +if TYPE_CHECKING: + from typing import Optional + + AGGRESIVE_COMPACTION_TENANT_CONF = { # Disable gc and compaction. The test runs compaction manually. "gc_period": "0s", diff --git a/test_runner/regress/test_pageserver_generations.py b/test_runner/regress/test_pageserver_generations.py index 577a3a25ca..11ebb81023 100644 --- a/test_runner/regress/test_pageserver_generations.py +++ b/test_runner/regress/test_pageserver_generations.py @@ -15,7 +15,7 @@ import enum import os import re import time -from typing import Optional +from typing import TYPE_CHECKING import pytest from fixtures.common_types import TenantId, TimelineId @@ -40,6 +40,10 @@ from fixtures.remote_storage import ( from fixtures.utils import wait_until from fixtures.workload import Workload +if TYPE_CHECKING: + from typing import Optional + + # A tenant configuration that is convenient for generating uploads and deletions # without a large amount of postgres traffic. TENANT_CONF = { diff --git a/test_runner/regress/test_sharding.py b/test_runner/regress/test_sharding.py index d1d6b3af75..b1abcaa763 100644 --- a/test_runner/regress/test_sharding.py +++ b/test_runner/regress/test_sharding.py @@ -23,6 +23,7 @@ from fixtures.remote_storage import s3_storage from fixtures.utils import wait_until from fixtures.workload import Workload from pytest_httpserver import HTTPServer +from typing_extensions import override from werkzeug.wrappers.request import Request from werkzeug.wrappers.response import Response @@ -954,6 +955,7 @@ class PageserverFailpoint(Failure): self.pageserver_id = pageserver_id self._mitigate = mitigate + @override def apply(self, env: NeonEnv): pageserver = env.get_pageserver(self.pageserver_id) pageserver.allowed_errors.extend( @@ -961,19 +963,23 @@ class PageserverFailpoint(Failure): ) pageserver.http_client().configure_failpoints((self.failpoint, "return(1)")) + @override def clear(self, env: NeonEnv): pageserver = env.get_pageserver(self.pageserver_id) pageserver.http_client().configure_failpoints((self.failpoint, "off")) if self._mitigate: env.storage_controller.node_configure(self.pageserver_id, {"availability": "Active"}) + @override def expect_available(self): return True + @override def can_mitigate(self): return self._mitigate - def mitigate(self, env): + @override + def mitigate(self, env: NeonEnv): env.storage_controller.node_configure(self.pageserver_id, {"availability": "Offline"}) @@ -983,9 +989,11 @@ class StorageControllerFailpoint(Failure): self.pageserver_id = None self.action = action + @override def apply(self, env: NeonEnv): env.storage_controller.configure_failpoints((self.failpoint, self.action)) + @override def clear(self, env: NeonEnv): if "panic" in self.action: log.info("Restarting storage controller after panic") @@ -994,16 +1002,19 @@ class StorageControllerFailpoint(Failure): else: env.storage_controller.configure_failpoints((self.failpoint, "off")) + @override def expect_available(self): # Controller panics _do_ leave pageservers available, but our test code relies # on using the locate API to update configurations in Workload, so we must skip # these actions when the controller has been panicked. return "panic" not in self.action + @override def can_mitigate(self): return False - def fails_forward(self, env): + @override + def fails_forward(self, env: NeonEnv): # Edge case: the very last failpoint that simulates a DB connection error, where # the abort path will fail-forward and result in a complete split. fail_forward = self.failpoint == "shard-split-post-complete" @@ -1017,6 +1028,7 @@ class StorageControllerFailpoint(Failure): return fail_forward + @override def expect_exception(self): if "panic" in self.action: return requests.exceptions.ConnectionError @@ -1029,18 +1041,22 @@ class NodeKill(Failure): self.pageserver_id = pageserver_id self._mitigate = mitigate + @override def apply(self, env: NeonEnv): pageserver = env.get_pageserver(self.pageserver_id) pageserver.stop(immediate=True) + @override def clear(self, env: NeonEnv): pageserver = env.get_pageserver(self.pageserver_id) pageserver.start() + @override def expect_available(self): return False - def mitigate(self, env): + @override + def mitigate(self, env: NeonEnv): env.storage_controller.node_configure(self.pageserver_id, {"availability": "Offline"}) @@ -1059,21 +1075,26 @@ class CompositeFailure(Failure): self.pageserver_id = f.pageserver_id break + @override def apply(self, env: NeonEnv): for f in self.failures: f.apply(env) - def clear(self, env): + @override + def clear(self, env: NeonEnv): for f in self.failures: f.clear(env) + @override def expect_available(self): return all(f.expect_available() for f in self.failures) - def mitigate(self, env): + @override + def mitigate(self, env: NeonEnv): for f in self.failures: f.mitigate(env) + @override def expect_exception(self): expect = set(f.expect_exception() for f in self.failures) @@ -1211,7 +1232,7 @@ def test_sharding_split_failures( assert attached_count == initial_shard_count - def assert_split_done(exclude_ps_id=None) -> None: + def assert_split_done(exclude_ps_id: Optional[int] = None) -> None: secondary_count = 0 attached_count = 0 for ps in env.pageservers: diff --git a/test_runner/regress/test_storage_controller.py b/test_runner/regress/test_storage_controller.py index 202634477c..7be4d2ce0c 100644 --- a/test_runner/regress/test_storage_controller.py +++ b/test_runner/regress/test_storage_controller.py @@ -1038,7 +1038,7 @@ def test_storage_controller_tenant_deletion( ) # Break the compute hook: we are checking that deletion does not depend on the compute hook being available - def break_hook(): + def break_hook(_body: Any): raise RuntimeError("Unexpected call to compute hook") compute_reconfigure_listener.register_on_notify(break_hook) diff --git a/test_runner/regress/test_storage_scrubber.py b/test_runner/regress/test_storage_scrubber.py index f999edc067..05db0fe977 100644 --- a/test_runner/regress/test_storage_scrubber.py +++ b/test_runner/regress/test_storage_scrubber.py @@ -6,7 +6,7 @@ import shutil import threading import time from concurrent.futures import ThreadPoolExecutor -from typing import Optional +from typing import TYPE_CHECKING import pytest from fixtures.common_types import TenantId, TenantShardId, TimelineId @@ -20,6 +20,9 @@ from fixtures.remote_storage import S3Storage, s3_storage from fixtures.utils import wait_until from fixtures.workload import Workload +if TYPE_CHECKING: + from typing import Optional + @pytest.mark.parametrize("shard_count", [None, 4]) def test_scrubber_tenant_snapshot(neon_env_builder: NeonEnvBuilder, shard_count: Optional[int]): From 306094a87d505c2319307e4a3f8929c849be1abe Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Wed, 9 Oct 2024 22:43:35 +0100 Subject: [PATCH 03/18] add local-proxy suffix to wake-compute requests, respect the returned port (#9298) https://github.com/neondatabase/cloud/issues/18349 Use the `-local-proxy` suffix to make sure we get the 10432 local_proxy port back from cplane. --- proxy/src/serverless/backend.rs | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/proxy/src/serverless/backend.rs b/proxy/src/serverless/backend.rs index 8a8f38d181..f54476b51d 100644 --- a/proxy/src/serverless/backend.rs +++ b/proxy/src/serverless/backend.rs @@ -28,7 +28,7 @@ use crate::{ retry::{CouldRetry, ShouldRetryWakeCompute}, }, rate_limiter::EndpointRateLimiter, - Host, + EndpointId, Host, }; use super::{ @@ -222,7 +222,14 @@ impl PoolingBackend { .auth_backend .as_ref() .map(|()| ComputeCredentials { - info: conn_info.user_info.clone(), + info: ComputeUserInfo { + user: conn_info.user_info.user.clone(), + endpoint: EndpointId::from(format!( + "{}-local-proxy", + conn_info.user_info.endpoint + )), + options: conn_info.user_info.options.clone(), + }, keys: crate::auth::backend::ComputeCredentialKeys::None, }); crate::proxy::connect_compute::connect_to_compute( @@ -507,8 +514,12 @@ impl ConnectMechanism for HyperMechanism { let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute); - // let port = node_info.config.get_ports().first().unwrap_or_else(10432); - let res = connect_http2(&host, 10432, timeout).await; + let port = *node_info.config.get_ports().first().ok_or_else(|| { + HttpConnError::WakeCompute(WakeComputeError::BadComputeAddress( + "local-proxy port missing on compute address".into(), + )) + })?; + let res = connect_http2(&host, port, timeout).await; drop(pause); let (client, connection) = permit.release_result(res)?; From 426b1c5f0887f45cc731f8786c457fb02573e0cc Mon Sep 17 00:00:00 2001 From: John Spray Date: Thu, 10 Oct 2024 12:26:43 +0100 Subject: [PATCH 04/18] storage controller: use 'infra' JWT scope for node registration (#9343) ## Problem Storage controller `/control` API mostly requires admin tokens, for interactive use by engineers. But for endpoints used by scripts, we should not require admin tokens. Discussion at https://neondb.slack.com/archives/C033RQ5SPDH/p1728550081788989?thread_ts=1728548232.265019&cid=C033RQ5SPDH ## Summary of changes - Introduce the 'infra' JWT scope, which was not previously used in the neon repo - For pageserver & safekeeper node registrations, require infra scope instead of admin Note that admin will still work, as the controller auth checks permit admin tokens for all endpoints irrespective of what scope they require. --- libs/utils/src/auth.rs | 5 ++++- pageserver/src/auth.rs | 23 ++++++++++++++--------- safekeeper/src/auth.rs | 23 ++++++++++++++--------- storage_controller/src/http.rs | 4 ++-- 4 files changed, 34 insertions(+), 21 deletions(-) diff --git a/libs/utils/src/auth.rs b/libs/utils/src/auth.rs index 7b735875b7..5bd6f4bedc 100644 --- a/libs/utils/src/auth.rs +++ b/libs/utils/src/auth.rs @@ -31,9 +31,12 @@ pub enum Scope { /// The scope used by pageservers in upcalls to storage controller and cloud control plane #[serde(rename = "generations_api")] GenerationsApi, - /// Allows access to control plane managment API and some storage controller endpoints. + /// Allows access to control plane managment API and all storage controller endpoints. Admin, + /// Allows access to control plane & storage controller endpoints used in infrastructure automation (e.g. node registration) + Infra, + /// Allows access to storage controller APIs used by the scrubber, to interrogate the state /// of a tenant & post scrub results. Scrubber, diff --git a/pageserver/src/auth.rs b/pageserver/src/auth.rs index 9e3dedb75a..5c931fcfdb 100644 --- a/pageserver/src/auth.rs +++ b/pageserver/src/auth.rs @@ -14,14 +14,19 @@ pub fn check_permission(claims: &Claims, tenant_id: Option) -> Result< } (Scope::PageServerApi, None) => Ok(()), // access to management api for PageServerApi scope (Scope::PageServerApi, Some(_)) => Ok(()), // access to tenant api using PageServerApi scope - (Scope::Admin | Scope::SafekeeperData | Scope::GenerationsApi | Scope::Scrubber, _) => { - Err(AuthError( - format!( - "JWT scope '{:?}' is ineligible for Pageserver auth", - claims.scope - ) - .into(), - )) - } + ( + Scope::Admin + | Scope::SafekeeperData + | Scope::GenerationsApi + | Scope::Infra + | Scope::Scrubber, + _, + ) => Err(AuthError( + format!( + "JWT scope '{:?}' is ineligible for Pageserver auth", + claims.scope + ) + .into(), + )), } } diff --git a/safekeeper/src/auth.rs b/safekeeper/src/auth.rs index c5c9393c00..fdd0830b02 100644 --- a/safekeeper/src/auth.rs +++ b/safekeeper/src/auth.rs @@ -15,15 +15,20 @@ pub fn check_permission(claims: &Claims, tenant_id: Option) -> Result< } Ok(()) } - (Scope::Admin | Scope::PageServerApi | Scope::GenerationsApi | Scope::Scrubber, _) => { - Err(AuthError( - format!( - "JWT scope '{:?}' is ineligible for Safekeeper auth", - claims.scope - ) - .into(), - )) - } + ( + Scope::Admin + | Scope::PageServerApi + | Scope::GenerationsApi + | Scope::Infra + | Scope::Scrubber, + _, + ) => Err(AuthError( + format!( + "JWT scope '{:?}' is ineligible for Safekeeper auth", + claims.scope + ) + .into(), + )), (Scope::SafekeeperData, _) => Ok(()), } } diff --git a/storage_controller/src/http.rs b/storage_controller/src/http.rs index 4dd8badd03..46b6f4f2bf 100644 --- a/storage_controller/src/http.rs +++ b/storage_controller/src/http.rs @@ -636,7 +636,7 @@ async fn handle_tenant_list( } async fn handle_node_register(req: Request) -> Result, ApiError> { - check_permissions(&req, Scope::Admin)?; + check_permissions(&req, Scope::Infra)?; let mut req = match maybe_forward(req).await { ForwardOutcome::Forwarded(res) => { @@ -1182,7 +1182,7 @@ async fn handle_get_safekeeper(req: Request) -> Result, Api /// Assumes information is only relayed to storage controller after first selecting an unique id on /// control plane database, which means we have an id field in the request and payload. async fn handle_upsert_safekeeper(mut req: Request) -> Result, ApiError> { - check_permissions(&req, Scope::Admin)?; + check_permissions(&req, Scope::Infra)?; let body = json_request::(&mut req).await?; let id = parse_request_param::(&req, "id")?; From c2623ffef454378b2602f494e459b32028aa04a0 Mon Sep 17 00:00:00 2001 From: Erik Grinaker Date: Thu, 10 Oct 2024 12:40:35 +0100 Subject: [PATCH 05/18] CODEOWNERS: assign `storage_scrubber` to storage (#9346) --- CODEOWNERS | 1 + 1 file changed, 1 insertion(+) diff --git a/CODEOWNERS b/CODEOWNERS index 606dbb4e22..f8ed4be816 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -1,5 +1,6 @@ /compute_tools/ @neondatabase/control-plane @neondatabase/compute /storage_controller @neondatabase/storage +/storage_scrubber @neondatabase/storage /libs/pageserver_api/ @neondatabase/storage /libs/postgres_ffi/ @neondatabase/compute @neondatabase/storage /libs/remote_storage/ @neondatabase/storage From 9dd80b9b4ce94addc0acc6200be22cd7b09ba562 Mon Sep 17 00:00:00 2001 From: Erik Grinaker Date: Thu, 10 Oct 2024 14:09:53 +0100 Subject: [PATCH 06/18] storage_scrubber: fix faulty assertion when no timelines (#9345) When there are no timelines in remote storage, the storage scrubber would incorrectly trip an assertion with "Must be set if results are present", referring to the last processed tenant ID. When there are no timelines we don't expect there to be a tenant ID either. The assertion was introduced in 37aa6fd. Only apply the assertion when any timelines are present. --- storage_scrubber/src/scan_pageserver_metadata.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/storage_scrubber/src/scan_pageserver_metadata.rs b/storage_scrubber/src/scan_pageserver_metadata.rs index c1ea589f7f..cb3299d413 100644 --- a/storage_scrubber/src/scan_pageserver_metadata.rs +++ b/storage_scrubber/src/scan_pageserver_metadata.rs @@ -317,9 +317,8 @@ pub async fn scan_pageserver_metadata( tenant_timeline_results.push((ttid, data)); } - let tenant_id = tenant_id.expect("Must be set if results are present"); - if !tenant_timeline_results.is_empty() { + let tenant_id = tenant_id.expect("Must be set if results are present"); analyze_tenant( &remote_client, tenant_id, From 264c34dfb7b90e619677fc04fdb957b270c40e8e Mon Sep 17 00:00:00 2001 From: Tristan Partin Date: Thu, 10 Oct 2024 10:26:23 -0500 Subject: [PATCH 07/18] Move path-related fixtures into their own module (#9304) neon_fixtures.py has grown into quite a beast. Signed-off-by: Tristan Partin --- test_runner/conftest.py | 1 + test_runner/fixtures/neon_fixtures.py | 240 +---------------------- test_runner/fixtures/paths.py | 269 ++++++++++++++++++++++++++ 3 files changed, 273 insertions(+), 237 deletions(-) create mode 100644 test_runner/fixtures/paths.py diff --git a/test_runner/conftest.py b/test_runner/conftest.py index d6e7fcf7ca..4a3194c691 100644 --- a/test_runner/conftest.py +++ b/test_runner/conftest.py @@ -6,6 +6,7 @@ pytest_plugins = ( "fixtures.httpserver", "fixtures.compute_reconfigure", "fixtures.storage_controller_proxy", + "fixtures.paths", "fixtures.neon_fixtures", "fixtures.benchmark_fixture", "fixtures.pg_stats", diff --git a/test_runner/fixtures/neon_fixtures.py b/test_runner/fixtures/neon_fixtures.py index f81bc3f5a6..9a60de922c 100644 --- a/test_runner/fixtures/neon_fixtures.py +++ b/test_runner/fixtures/neon_fixtures.py @@ -18,7 +18,6 @@ from contextlib import closing, contextmanager from dataclasses import dataclass from datetime import datetime from enum import Enum -from fcntl import LOCK_EX, LOCK_UN, flock from functools import cached_property from pathlib import Path from types import TracebackType @@ -59,6 +58,7 @@ from fixtures.pageserver.http import PageserverHttpClient from fixtures.pageserver.utils import ( wait_for_last_record_lsn, ) +from fixtures.paths import get_test_repo_dir, shared_snapshot_dir from fixtures.pg_version import PgVersion from fixtures.port_distributor import PortDistributor from fixtures.remote_storage import ( @@ -76,7 +76,6 @@ from fixtures.safekeeper.utils import wait_walreceivers_absent from fixtures.utils import ( ATTACHMENT_NAME_REGEX, allure_add_grafana_links, - allure_attach_from_dir, assert_no_errors, get_dir_size, print_gc_result, @@ -96,6 +95,8 @@ if TYPE_CHECKING: Union, ) + from fixtures.paths import SnapshotDirLocked + T = TypeVar("T") @@ -118,65 +119,11 @@ put directly-importable functions into utils.py or another separate file. Env = dict[str, str] -DEFAULT_OUTPUT_DIR: str = "test_output" DEFAULT_BRANCH_NAME: str = "main" BASE_PORT: int = 15000 -@pytest.fixture(scope="session") -def base_dir() -> Iterator[Path]: - # find the base directory (currently this is the git root) - base_dir = Path(__file__).parents[2] - log.info(f"base_dir is {base_dir}") - - yield base_dir - - -@pytest.fixture(scope="function") -def neon_binpath(base_dir: Path, build_type: str) -> Iterator[Path]: - if os.getenv("REMOTE_ENV"): - # we are in remote env and do not have neon binaries locally - # this is the case for benchmarks run on self-hosted runner - return - - # Find the neon binaries. - if env_neon_bin := os.environ.get("NEON_BIN"): - binpath = Path(env_neon_bin) - else: - binpath = base_dir / "target" / build_type - log.info(f"neon_binpath is {binpath}") - - if not (binpath / "pageserver").exists(): - raise Exception(f"neon binaries not found at '{binpath}'") - - yield binpath - - -@pytest.fixture(scope="session") -def pg_distrib_dir(base_dir: Path) -> Iterator[Path]: - if env_postgres_bin := os.environ.get("POSTGRES_DISTRIB_DIR"): - distrib_dir = Path(env_postgres_bin).resolve() - else: - distrib_dir = base_dir / "pg_install" - - log.info(f"pg_distrib_dir is {distrib_dir}") - yield distrib_dir - - -@pytest.fixture(scope="session") -def top_output_dir(base_dir: Path) -> Iterator[Path]: - # Compute the top-level directory for all tests. - if env_test_output := os.environ.get("TEST_OUTPUT"): - output_dir = Path(env_test_output).resolve() - else: - output_dir = base_dir / DEFAULT_OUTPUT_DIR - output_dir.mkdir(exist_ok=True) - - log.info(f"top_output_dir is {output_dir}") - yield output_dir - - @pytest.fixture(scope="session") def neon_api_key() -> str: api_key = os.getenv("NEON_API_KEY") @@ -4246,44 +4193,6 @@ class StorageScrubber: raise -def _get_test_dir(request: FixtureRequest, top_output_dir: Path, prefix: str) -> Path: - """Compute the path to a working directory for an individual test.""" - test_name = request.node.name - test_dir = top_output_dir / f"{prefix}{test_name.replace('/', '-')}" - - # We rerun flaky tests multiple times, use a separate directory for each run. - if (suffix := getattr(request.node, "execution_count", None)) is not None: - test_dir = test_dir.parent / f"{test_dir.name}-{suffix}" - - log.info(f"get_test_output_dir is {test_dir}") - # make mypy happy - assert isinstance(test_dir, Path) - return test_dir - - -def get_test_output_dir(request: FixtureRequest, top_output_dir: Path) -> Path: - """ - The working directory for a test. - """ - return _get_test_dir(request, top_output_dir, "") - - -def get_test_overlay_dir(request: FixtureRequest, top_output_dir: Path) -> Path: - """ - Directory that contains `upperdir` and `workdir` for overlayfs mounts - that a test creates. See `NeonEnvBuilder.overlay_mount`. - """ - return _get_test_dir(request, top_output_dir, "overlay-") - - -def get_shared_snapshot_dir_path(top_output_dir: Path, snapshot_name: str) -> Path: - return top_output_dir / "shared-snapshots" / snapshot_name - - -def get_test_repo_dir(request: FixtureRequest, top_output_dir: Path) -> Path: - return get_test_output_dir(request, top_output_dir) / "repo" - - def pytest_addoption(parser: Parser): parser.addoption( "--preserve-database-files", @@ -4298,149 +4207,6 @@ SMALL_DB_FILE_NAME_REGEX: re.Pattern[str] = re.compile( ) -# This is autouse, so the test output directory always gets created, even -# if a test doesn't put anything there. -# -# NB: we request the overlay dir fixture so the fixture does its cleanups -@pytest.fixture(scope="function", autouse=True) -def test_output_dir( - request: FixtureRequest, top_output_dir: Path, test_overlay_dir: Path -) -> Iterator[Path]: - """Create the working directory for an individual test.""" - - # one directory per test - test_dir = get_test_output_dir(request, top_output_dir) - log.info(f"test_output_dir is {test_dir}") - shutil.rmtree(test_dir, ignore_errors=True) - test_dir.mkdir() - - yield test_dir - - # Allure artifacts creation might involve the creation of `.tar.zst` archives, - # which aren't going to be used if Allure results collection is not enabled - # (i.e. --alluredir is not set). - # Skip `allure_attach_from_dir` in this case - if not request.config.getoption("--alluredir"): - return - - preserve_database_files = False - for k, v in request.node.user_properties: - # NB: the neon_env_builder fixture uses this fixture (test_output_dir). - # So, neon_env_builder's cleanup runs before here. - # The cleanup propagates NeonEnvBuilder.preserve_database_files into this user property. - if k == "preserve_database_files": - assert isinstance(v, bool) - preserve_database_files = v - - allure_attach_from_dir(test_dir, preserve_database_files) - - -class FileAndThreadLock: - def __init__(self, path: Path): - self.path = path - self.thread_lock = threading.Lock() - self.fd: Optional[int] = None - - def __enter__(self): - self.fd = os.open(self.path, os.O_CREAT | os.O_WRONLY) - # lock thread lock before file lock so that there's no race - # around flocking / funlocking the file lock - self.thread_lock.acquire() - flock(self.fd, LOCK_EX) - - def __exit__(self, exc_type, exc_value, exc_traceback): - assert self.fd is not None - assert self.thread_lock.locked() # ... by us - flock(self.fd, LOCK_UN) - self.thread_lock.release() - os.close(self.fd) - self.fd = None - - -class SnapshotDirLocked: - def __init__(self, parent: SnapshotDir): - self._parent = parent - - def is_initialized(self): - # TODO: in the future, take a `tag` as argument and store it in the marker in set_initialized. - # Then, in this function, compare marker file contents with the tag to invalidate the snapshot if the tag changed. - return self._parent._marker_file_path.exists() - - def set_initialized(self): - self._parent._marker_file_path.write_text("") - - @property - def path(self) -> Path: - return self._parent._path / "snapshot" - - -class SnapshotDir: - _path: Path - - def __init__(self, path: Path): - self._path = path - assert self._path.is_dir() - self._lock = FileAndThreadLock(self._lock_file_path) - - @property - def _lock_file_path(self) -> Path: - return self._path / "initializing.flock" - - @property - def _marker_file_path(self) -> Path: - return self._path / "initialized.marker" - - def __enter__(self) -> SnapshotDirLocked: - self._lock.__enter__() - return SnapshotDirLocked(self) - - def __exit__(self, exc_type, exc_value, exc_traceback): - self._lock.__exit__(exc_type, exc_value, exc_traceback) - - -def shared_snapshot_dir(top_output_dir, ident: str) -> SnapshotDir: - snapshot_dir_path = get_shared_snapshot_dir_path(top_output_dir, ident) - snapshot_dir_path.mkdir(exist_ok=True, parents=True) - return SnapshotDir(snapshot_dir_path) - - -@pytest.fixture(scope="function") -def test_overlay_dir(request: FixtureRequest, top_output_dir: Path) -> Optional[Path]: - """ - Idempotently create a test's overlayfs mount state directory. - If the functionality isn't enabled via env var, returns None. - - The procedure cleans up after previous runs that were aborted (e.g. due to Ctrl-C, OOM kills, etc). - """ - - if os.getenv("NEON_ENV_BUILDER_USE_OVERLAYFS_FOR_SNAPSHOTS") is None: - return None - - overlay_dir = get_test_overlay_dir(request, top_output_dir) - log.info(f"test_overlay_dir is {overlay_dir}") - - overlay_dir.mkdir(exist_ok=True) - # unmount stale overlayfs mounts which subdirectories of `overlay_dir/*` as the overlayfs `upperdir` and `workdir` - for mountpoint in overlayfs.iter_mounts_beneath(get_test_output_dir(request, top_output_dir)): - cmd = ["sudo", "umount", str(mountpoint)] - log.info( - f"Unmounting stale overlayfs mount probably created during earlier test run: {cmd}" - ) - subprocess.run(cmd, capture_output=True, check=True) - # the overlayfs `workdir`` is owned by `root`, shutil.rmtree won't work. - cmd = ["sudo", "rm", "-rf", str(overlay_dir)] - subprocess.run(cmd, capture_output=True, check=True) - - overlay_dir.mkdir() - - return overlay_dir - - # no need to clean up anything: on clean shutdown, - # NeonEnvBuilder.overlay_cleanup_teardown takes care of cleanup - # and on unclean shutdown, this function will take care of it - # on the next test run - - SKIP_DIRS = frozenset( ( "pg_wal", diff --git a/test_runner/fixtures/paths.py b/test_runner/fixtures/paths.py new file mode 100644 index 0000000000..0712d241db --- /dev/null +++ b/test_runner/fixtures/paths.py @@ -0,0 +1,269 @@ +from __future__ import annotations + +import os +import shutil +import subprocess +import threading +from fcntl import LOCK_EX, LOCK_UN, flock +from pathlib import Path +from types import TracebackType +from typing import TYPE_CHECKING + +import pytest +from pytest import FixtureRequest + +from fixtures import overlayfs +from fixtures.log_helper import log +from fixtures.utils import allure_attach_from_dir + +if TYPE_CHECKING: + from collections.abc import Iterator + from typing import Optional + + +DEFAULT_OUTPUT_DIR: str = "test_output" + + +def get_test_dir( + request: FixtureRequest, top_output_dir: Path, prefix: Optional[str] = None +) -> Path: + """Compute the path to a working directory for an individual test.""" + test_name = request.node.name + test_dir = top_output_dir / f"{prefix or ''}{test_name.replace('/', '-')}" + + # We rerun flaky tests multiple times, use a separate directory for each run. + if (suffix := getattr(request.node, "execution_count", None)) is not None: + test_dir = test_dir.parent / f"{test_dir.name}-{suffix}" + + return test_dir + + +def get_test_output_dir(request: FixtureRequest, top_output_dir: Path) -> Path: + """ + The working directory for a test. + """ + return get_test_dir(request, top_output_dir) + + +def get_test_overlay_dir(request: FixtureRequest, top_output_dir: Path) -> Path: + """ + Directory that contains `upperdir` and `workdir` for overlayfs mounts + that a test creates. See `NeonEnvBuilder.overlay_mount`. + """ + return get_test_dir(request, top_output_dir, "overlay-") + + +def get_shared_snapshot_dir_path(top_output_dir: Path, snapshot_name: str) -> Path: + return top_output_dir / "shared-snapshots" / snapshot_name + + +def get_test_repo_dir(request: FixtureRequest, top_output_dir: Path) -> Path: + return get_test_output_dir(request, top_output_dir) / "repo" + + +@pytest.fixture(scope="session") +def base_dir() -> Iterator[Path]: + # find the base directory (currently this is the git root) + base_dir = Path(__file__).parents[2] + log.info(f"base_dir is {base_dir}") + + yield base_dir + + +@pytest.fixture(scope="function") +def neon_binpath(base_dir: Path, build_type: str) -> Iterator[Path]: + if os.getenv("REMOTE_ENV"): + # we are in remote env and do not have neon binaries locally + # this is the case for benchmarks run on self-hosted runner + return + + # Find the neon binaries. + if env_neon_bin := os.environ.get("NEON_BIN"): + binpath = Path(env_neon_bin) + else: + binpath = base_dir / "target" / build_type + log.info(f"neon_binpath is {binpath}") + + if not (binpath / "pageserver").exists(): + raise Exception(f"neon binaries not found at '{binpath}'") + + yield binpath + + +@pytest.fixture(scope="session") +def pg_distrib_dir(base_dir: Path) -> Iterator[Path]: + if env_postgres_bin := os.environ.get("POSTGRES_DISTRIB_DIR"): + distrib_dir = Path(env_postgres_bin).resolve() + else: + distrib_dir = base_dir / "pg_install" + + log.info(f"pg_distrib_dir is {distrib_dir}") + yield distrib_dir + + +@pytest.fixture(scope="session") +def top_output_dir(base_dir: Path) -> Iterator[Path]: + # Compute the top-level directory for all tests. + if env_test_output := os.environ.get("TEST_OUTPUT"): + output_dir = Path(env_test_output).resolve() + else: + output_dir = base_dir / DEFAULT_OUTPUT_DIR + output_dir.mkdir(exist_ok=True) + + log.info(f"top_output_dir is {output_dir}") + yield output_dir + + +# This is autouse, so the test output directory always gets created, even +# if a test doesn't put anything there. +# +# NB: we request the overlay dir fixture so the fixture does its cleanups +@pytest.fixture(scope="function", autouse=True) +def test_output_dir(request: pytest.FixtureRequest, top_output_dir: Path) -> Iterator[Path]: + """Create the working directory for an individual test.""" + + # one directory per test + test_dir = get_test_output_dir(request, top_output_dir) + log.info(f"test_output_dir is {test_dir}") + shutil.rmtree(test_dir, ignore_errors=True) + test_dir.mkdir() + + yield test_dir + + # Allure artifacts creation might involve the creation of `.tar.zst` archives, + # which aren't going to be used if Allure results collection is not enabled + # (i.e. --alluredir is not set). + # Skip `allure_attach_from_dir` in this case + if not request.config.getoption("--alluredir"): + return + + preserve_database_files = False + for k, v in request.node.user_properties: + # NB: the neon_env_builder fixture uses this fixture (test_output_dir). + # So, neon_env_builder's cleanup runs before here. + # The cleanup propagates NeonEnvBuilder.preserve_database_files into this user property. + if k == "preserve_database_files": + assert isinstance(v, bool) + preserve_database_files = v + + allure_attach_from_dir(test_dir, preserve_database_files) + + +class FileAndThreadLock: + def __init__(self, path: Path): + self.path = path + self.thread_lock = threading.Lock() + self.fd: Optional[int] = None + + def __enter__(self): + self.fd = os.open(self.path, os.O_CREAT | os.O_WRONLY) + # lock thread lock before file lock so that there's no race + # around flocking / funlocking the file lock + self.thread_lock.acquire() + flock(self.fd, LOCK_EX) + + def __exit__( + self, + exc_type: Optional[type[BaseException]], + exc_value: Optional[BaseException], + exc_traceback: Optional[TracebackType], + ): + assert self.fd is not None + assert self.thread_lock.locked() # ... by us + flock(self.fd, LOCK_UN) + self.thread_lock.release() + os.close(self.fd) + self.fd = None + + +class SnapshotDirLocked: + def __init__(self, parent: SnapshotDir): + self._parent = parent + + def is_initialized(self): + # TODO: in the future, take a `tag` as argument and store it in the marker in set_initialized. + # Then, in this function, compare marker file contents with the tag to invalidate the snapshot if the tag changed. + return self._parent.marker_file_path.exists() + + def set_initialized(self): + self._parent.marker_file_path.write_text("") + + @property + def path(self) -> Path: + return self._parent.path / "snapshot" + + +class SnapshotDir: + _path: Path + + def __init__(self, path: Path): + self._path = path + assert self._path.is_dir() + self._lock = FileAndThreadLock(self.lock_file_path) + + @property + def path(self) -> Path: + return self._path + + @property + def lock_file_path(self) -> Path: + return self._path / "initializing.flock" + + @property + def marker_file_path(self) -> Path: + return self._path / "initialized.marker" + + def __enter__(self) -> SnapshotDirLocked: + self._lock.__enter__() + return SnapshotDirLocked(self) + + def __exit__( + self, + exc_type: Optional[type[BaseException]], + exc_value: Optional[BaseException], + exc_traceback: Optional[TracebackType], + ): + self._lock.__exit__(exc_type, exc_value, exc_traceback) + + +def shared_snapshot_dir(top_output_dir: Path, ident: str) -> SnapshotDir: + snapshot_dir_path = get_shared_snapshot_dir_path(top_output_dir, ident) + snapshot_dir_path.mkdir(exist_ok=True, parents=True) + return SnapshotDir(snapshot_dir_path) + + +@pytest.fixture(scope="function") +def test_overlay_dir(request: FixtureRequest, top_output_dir: Path) -> Optional[Path]: + """ + Idempotently create a test's overlayfs mount state directory. + If the functionality isn't enabled via env var, returns None. + + The procedure cleans up after previous runs that were aborted (e.g. due to Ctrl-C, OOM kills, etc). + """ + + if os.getenv("NEON_ENV_BUILDER_USE_OVERLAYFS_FOR_SNAPSHOTS") is None: + return None + + overlay_dir = get_test_overlay_dir(request, top_output_dir) + log.info(f"test_overlay_dir is {overlay_dir}") + + overlay_dir.mkdir(exist_ok=True) + # unmount stale overlayfs mounts which subdirectories of `overlay_dir/*` as the overlayfs `upperdir` and `workdir` + for mountpoint in overlayfs.iter_mounts_beneath(get_test_output_dir(request, top_output_dir)): + cmd = ["sudo", "umount", str(mountpoint)] + log.info( + f"Unmounting stale overlayfs mount probably created during earlier test run: {cmd}" + ) + subprocess.run(cmd, capture_output=True, check=True) + # the overlayfs `workdir`` is owned by `root`, shutil.rmtree won't work. + cmd = ["sudo", "rm", "-rf", str(overlay_dir)] + subprocess.run(cmd, capture_output=True, check=True) + + overlay_dir.mkdir() + + return overlay_dir + + # no need to clean up anything: on clean shutdown, + # NeonEnvBuilder.overlay_cleanup_teardown takes care of cleanup + # and on unclean shutdown, this function will take care of it + # on the next test run From 07c714343f793eeb866232e23b4c1c7409fa7f61 Mon Sep 17 00:00:00 2001 From: John Spray Date: Thu, 10 Oct 2024 17:06:42 +0100 Subject: [PATCH 08/18] tests: allow a log warning in test_cli_start_stop_multi (#9320) ## Problem This test restarts services in an undefined order (whatever neon_local does), which means we should be tolerant of warnings that come from restarting the storage controller while a pageserver is running. We can see failures with warnings from dropped requests, e.g. https://neon-github-public-dev.s3.amazonaws.com/reports/pr-9307/11229000712/index.html#/testresult/d33d5cb206331e28 ``` WARN request{method=GET path=/v1/location_config request_id=b7dbda15-6efb-4610-8b19-a3772b65455f}: request was dropped before completing\n') ``` ## Summary of changes - allow-list the `request was dropped before completing` message on pageservers before restarting services --- test_runner/regress/test_neon_cli.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/test_runner/regress/test_neon_cli.py b/test_runner/regress/test_neon_cli.py index 3a0a4b10bf..783fb813cf 100644 --- a/test_runner/regress/test_neon_cli.py +++ b/test_runner/regress/test_neon_cli.py @@ -162,6 +162,11 @@ def test_cli_start_stop_multi(neon_env_builder: NeonEnvBuilder): env.neon_cli.pageserver_stop(env.BASE_PAGESERVER_ID) env.neon_cli.pageserver_stop(env.BASE_PAGESERVER_ID + 1) + # We will stop the storage controller while it may have requests in + # flight, and the pageserver complains when requests are abandoned. + for ps in env.pageservers: + ps.allowed_errors.append(".*request was dropped before completing.*") + # Keep NeonEnv state up to date, it usually owns starting/stopping services env.pageservers[0].running = False env.pageservers[1].running = False From 1f7904c917503a95f6297ae9df705e22fd5daba4 Mon Sep 17 00:00:00 2001 From: Tristan Partin Date: Thu, 10 Oct 2024 12:40:30 -0500 Subject: [PATCH 09/18] Enable cargo caching in check-codestyle-rust This job takes an extraordinary amount of time for what I understand it to do. The obvious win is caching dependencies. Rory disabled caching in cd5732d9d8ccd291f39ed41250072acdce3012e6. I assume this was to get gen3 runners up and running. Signed-off-by: Tristan Partin --- .github/workflows/build_and_test.yml | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/.github/workflows/build_and_test.yml b/.github/workflows/build_and_test.yml index a759efb56c..e7193cfe19 100644 --- a/.github/workflows/build_and_test.yml +++ b/.github/workflows/build_and_test.yml @@ -193,16 +193,15 @@ jobs: with: submodules: true -# Disabled for now -# - name: Restore cargo deps cache -# id: cache_cargo -# uses: actions/cache@v4 -# with: -# path: | -# !~/.cargo/registry/src -# ~/.cargo/git/ -# target/ -# key: v1-${{ runner.os }}-${{ runner.arch }}-cargo-clippy-${{ hashFiles('rust-toolchain.toml') }}-${{ hashFiles('Cargo.lock') }} + - name: Cache cargo deps + uses: actions/cache@v4 + with: + path: | + ~/.cargo/registry + !~/.cargo/registry/src + ~/.cargo/git + target + key: v1-${{ runner.os }}-${{ runner.arch }}-cargo-${{ hashFiles('./Cargo.lock') }}-${{ hashFiles('./rust-toolchain.toml') }}-rust # Some of our rust modules use FFI and need those to be checked - name: Get postgres headers From 006d9dfb6bde9473c14719cab8ecebec77dd65c7 Mon Sep 17 00:00:00 2001 From: Tristan Partin Date: Thu, 10 Oct 2024 12:43:40 -0500 Subject: [PATCH 10/18] Add compute_config_dir fixture Allows easy access to various compute config files. Signed-off-by: Tristan Partin --- test_runner/fixtures/paths.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/test_runner/fixtures/paths.py b/test_runner/fixtures/paths.py index 0712d241db..cffeb47ee8 100644 --- a/test_runner/fixtures/paths.py +++ b/test_runner/fixtures/paths.py @@ -70,6 +70,14 @@ def base_dir() -> Iterator[Path]: yield base_dir +@pytest.fixture(scope="session") +def compute_config_dir(base_dir: Path) -> Iterator[Path]: + """ + Retrieve the path to the compute configuration directory. + """ + yield base_dir / "compute" / "etc" + + @pytest.fixture(scope="function") def neon_binpath(base_dir: Path, build_type: str) -> Iterator[Path]: if os.getenv("REMOTE_ENV"): From 53147b51f90ba854605e49edd28b7d7895930c92 Mon Sep 17 00:00:00 2001 From: Tristan Partin Date: Thu, 10 Oct 2024 13:00:25 -0500 Subject: [PATCH 11/18] Use valid type hints for Python 3.9 I have no idea how this made it past the linters. Signed-off-by: Tristan Partin --- test_runner/fixtures/neon_api.py | 6 +++--- test_runner/fixtures/pageserver/http.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/test_runner/fixtures/neon_api.py b/test_runner/fixtures/neon_api.py index 683ea3af44..5934baccff 100644 --- a/test_runner/fixtures/neon_api.py +++ b/test_runner/fixtures/neon_api.py @@ -185,8 +185,8 @@ class NeonAPI: def get_connection_uri( self, project_id: str, - branch_id: str | None = None, - endpoint_id: str | None = None, + branch_id: Optional[str] = None, + endpoint_id: Optional[str] = None, database_name: str = "neondb", role_name: str = "neondb_owner", pooled: bool = True, @@ -262,7 +262,7 @@ class NeonAPI: class NeonApiEndpoint: - def __init__(self, neon_api: NeonAPI, pg_version: PgVersion, project_id: str | None): + def __init__(self, neon_api: NeonAPI, pg_version: PgVersion, project_id: Optional[str]): self.neon_api = neon_api if project_id is None: project = neon_api.create_project(pg_version) diff --git a/test_runner/fixtures/pageserver/http.py b/test_runner/fixtures/pageserver/http.py index 84a7e5f0a2..aa4435af4e 100644 --- a/test_runner/fixtures/pageserver/http.py +++ b/test_runner/fixtures/pageserver/http.py @@ -886,7 +886,7 @@ class PageserverHttpClient(requests.Session, MetricsGetter): self, tenant_id: Union[TenantId, TenantShardId], timeline_id: TimelineId, - batch_size: int | None = None, + batch_size: Optional[int] = None, **kwargs, ) -> set[TimelineId]: params = {} From b2ecbf3e80804123b216cb3242d0e165936db120 Mon Sep 17 00:00:00 2001 From: Ivan Efremov Date: Fri, 11 Oct 2024 10:45:55 +0300 Subject: [PATCH 12/18] Introduce "quota" ErrorKind (#9300) ## Problem Fixes #8340 ## Summary of changes Introduced ErrorKind::quota to handle quota-related errors ## Checklist before requesting a review - [x] I have performed a self-review of my code. - [ ] If it is a core feature, I have added thorough tests. - [ ] Do we need to implement analytics? if so did you add the relevant metrics to the dashboard? - [ ] If this PR requires public announcement, mark it with /release-notes label and add several sentences in this section. ## Checklist before merging - [ ] Do not forget to reformat commit message to not include the above checklist --- proxy/src/control_plane/provider/mod.rs | 16 ++++++++-------- proxy/src/error.rs | 5 +++++ 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/proxy/src/control_plane/provider/mod.rs b/proxy/src/control_plane/provider/mod.rs index 01d93dee43..6cc525a324 100644 --- a/proxy/src/control_plane/provider/mod.rs +++ b/proxy/src/control_plane/provider/mod.rs @@ -81,12 +81,12 @@ pub(crate) mod errors { Reason::EndpointNotFound => ErrorKind::User, Reason::BranchNotFound => ErrorKind::User, Reason::RateLimitExceeded => ErrorKind::ServiceRateLimit, - Reason::NonDefaultBranchComputeTimeExceeded => ErrorKind::User, - Reason::ActiveTimeQuotaExceeded => ErrorKind::User, - Reason::ComputeTimeQuotaExceeded => ErrorKind::User, - Reason::WrittenDataQuotaExceeded => ErrorKind::User, - Reason::DataTransferQuotaExceeded => ErrorKind::User, - Reason::LogicalSizeQuotaExceeded => ErrorKind::User, + Reason::NonDefaultBranchComputeTimeExceeded => ErrorKind::Quota, + Reason::ActiveTimeQuotaExceeded => ErrorKind::Quota, + Reason::ComputeTimeQuotaExceeded => ErrorKind::Quota, + Reason::WrittenDataQuotaExceeded => ErrorKind::Quota, + Reason::DataTransferQuotaExceeded => ErrorKind::Quota, + Reason::LogicalSizeQuotaExceeded => ErrorKind::Quota, Reason::ConcurrencyLimitReached => ErrorKind::ControlPlane, Reason::LockAlreadyTaken => ErrorKind::ControlPlane, Reason::RunningOperations => ErrorKind::ControlPlane, @@ -103,7 +103,7 @@ pub(crate) mod errors { } if error .contains("compute time quota of non-primary branches is exceeded") => { - crate::error::ErrorKind::User + crate::error::ErrorKind::Quota } ControlPlaneError { http_status_code: http::StatusCode::LOCKED, @@ -112,7 +112,7 @@ pub(crate) mod errors { } if error.contains("quota exceeded") || error.contains("the limit for current plan reached") => { - crate::error::ErrorKind::User + crate::error::ErrorKind::Quota } ControlPlaneError { http_status_code: http::StatusCode::TOO_MANY_REQUESTS, diff --git a/proxy/src/error.rs b/proxy/src/error.rs index 53f9f75c5b..1cd4dc2c22 100644 --- a/proxy/src/error.rs +++ b/proxy/src/error.rs @@ -49,6 +49,10 @@ pub enum ErrorKind { #[label(rename = "serviceratelimit")] ServiceRateLimit, + /// Proxy quota limit violation + #[label(rename = "quota")] + Quota, + /// internal errors Service, @@ -70,6 +74,7 @@ impl ErrorKind { ErrorKind::ClientDisconnect => "clientdisconnect", ErrorKind::RateLimit => "ratelimit", ErrorKind::ServiceRateLimit => "serviceratelimit", + ErrorKind::Quota => "quota", ErrorKind::Service => "service", ErrorKind::ControlPlane => "controlplane", ErrorKind::Postgres => "postgres", From 184935619e55bbd9c025b5a057f36362b1a60dd2 Mon Sep 17 00:00:00 2001 From: John Spray Date: Fri, 11 Oct 2024 09:41:08 +0100 Subject: [PATCH 13/18] tests: stabilize test_storage_controller_heartbeats (#9347) ## Problem This could fail with `reconciliation in progress` if running on a slow test node such that background reconciliation happens at the same time as we call consistency_check. Example: https://neon-github-public-dev.s3.amazonaws.com/reports/main/11258171952/index.html#/testresult/54889c9469afb232 ## Summary of changes - Call reconcile_until_idle before calling consistency check once, rather than calling consistency check until it passes --- test_runner/regress/test_storage_controller.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test_runner/regress/test_storage_controller.py b/test_runner/regress/test_storage_controller.py index 7be4d2ce0c..1dcf0b254d 100644 --- a/test_runner/regress/test_storage_controller.py +++ b/test_runner/regress/test_storage_controller.py @@ -1300,11 +1300,11 @@ def test_storage_controller_heartbeats( node_to_tenants = build_node_to_tenants_map(env) log.info(f"Back online: {node_to_tenants=}") - # ... expecting the storage controller to reach a consistent state - def storage_controller_consistent(): - env.storage_controller.consistency_check() + # ... background reconciliation may need to run to clean up the location on the node that was offline + env.storage_controller.reconcile_until_idle() - wait_until(30, 1, storage_controller_consistent) + # ... expecting the storage controller to reach a consistent state + env.storage_controller.consistency_check() def test_storage_controller_re_attach(neon_env_builder: NeonEnvBuilder): From 6baf1aae3315c10b20f8e5e27239d3604484b895 Mon Sep 17 00:00:00 2001 From: Folke Behrens Date: Fri, 11 Oct 2024 11:29:08 +0200 Subject: [PATCH 14/18] proxy: Demote some errors to warnings in logs (#9354) --- proxy/src/control_plane/provider/neon.rs | 4 ++-- proxy/src/proxy/mod.rs | 12 ++++++------ proxy/src/proxy/passthrough.rs | 2 +- .../redis/connection_with_credentials_provider.rs | 6 +++--- proxy/src/redis/notifications.rs | 2 +- proxy/src/serverless/mod.rs | 6 +++--- proxy/src/serverless/sql_over_http.rs | 4 ++-- proxy/src/usage_metrics.rs | 4 ++-- 8 files changed, 20 insertions(+), 20 deletions(-) diff --git a/proxy/src/control_plane/provider/neon.rs b/proxy/src/control_plane/provider/neon.rs index e5f8b5c741..d01878741c 100644 --- a/proxy/src/control_plane/provider/neon.rs +++ b/proxy/src/control_plane/provider/neon.rs @@ -22,7 +22,7 @@ use futures::TryFutureExt; use std::{sync::Arc, time::Duration}; use tokio::time::Instant; use tokio_postgres::config::SslMode; -use tracing::{debug, error, info, info_span, warn, Instrument}; +use tracing::{debug, info, info_span, warn, Instrument}; const X_REQUEST_ID: HeaderName = HeaderName::from_static("x-request-id"); @@ -456,7 +456,7 @@ async fn parse_body serde::Deserialize<'a>>( }); body.http_status_code = status; - error!("console responded with an error ({status}): {body:?}"); + warn!("console responded with an error ({status}): {body:?}"); Err(ApiError::ControlPlane(body)) } diff --git a/proxy/src/proxy/mod.rs b/proxy/src/proxy/mod.rs index 7003af2aba..9e1af88f41 100644 --- a/proxy/src/proxy/mod.rs +++ b/proxy/src/proxy/mod.rs @@ -35,7 +35,7 @@ use std::sync::Arc; use thiserror::Error; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; use tokio_util::sync::CancellationToken; -use tracing::{error, info, Instrument}; +use tracing::{error, info, warn, Instrument}; use self::{ connect_compute::{connect_to_compute, TcpMechanism}, @@ -95,15 +95,15 @@ pub async fn task_main( connections.spawn(async move { let (socket, peer_addr) = match read_proxy_protocol(socket).await { Err(e) => { - error!("per-client task finished with an error: {e:#}"); + warn!("per-client task finished with an error: {e:#}"); return; } Ok((_socket, None)) if config.proxy_protocol_v2 == ProxyProtocolV2::Required => { - error!("missing required proxy protocol header"); + warn!("missing required proxy protocol header"); return; } Ok((_socket, Some(_))) if config.proxy_protocol_v2 == ProxyProtocolV2::Rejected => { - error!("proxy protocol header not supported"); + warn!("proxy protocol header not supported"); return; } Ok((socket, Some(addr))) => (socket, addr.ip()), @@ -144,7 +144,7 @@ pub async fn task_main( Err(e) => { // todo: log and push to ctx the error kind ctx.set_error_kind(e.get_error_kind()); - error!(parent: &span, "per-client task finished with an error: {e:#}"); + warn!(parent: &span, "per-client task finished with an error: {e:#}"); } Ok(None) => { ctx.set_success(); @@ -155,7 +155,7 @@ pub async fn task_main( match p.proxy_pass().instrument(span.clone()).await { Ok(()) => {} Err(ErrorSource::Client(e)) => { - error!(parent: &span, "per-client task finished with an IO error from the client: {e:#}"); + warn!(parent: &span, "per-client task finished with an IO error from the client: {e:#}"); } Err(ErrorSource::Compute(e)) => { error!(parent: &span, "per-client task finished with an IO error from the compute: {e:#}"); diff --git a/proxy/src/proxy/passthrough.rs b/proxy/src/proxy/passthrough.rs index bbea47f8af..497cf4bfd5 100644 --- a/proxy/src/proxy/passthrough.rs +++ b/proxy/src/proxy/passthrough.rs @@ -71,7 +71,7 @@ impl ProxyPassthrough { pub(crate) async fn proxy_pass(self) -> Result<(), ErrorSource> { let res = proxy_pass(self.client, self.compute.stream, self.aux).await; if let Err(err) = self.compute.cancel_closure.try_cancel_query().await { - tracing::error!(?err, "could not cancel the query in the database"); + tracing::warn!(?err, "could not cancel the query in the database"); } res } diff --git a/proxy/src/redis/connection_with_credentials_provider.rs b/proxy/src/redis/connection_with_credentials_provider.rs index 2de66b58b1..ccd48f1481 100644 --- a/proxy/src/redis/connection_with_credentials_provider.rs +++ b/proxy/src/redis/connection_with_credentials_provider.rs @@ -6,7 +6,7 @@ use redis::{ ConnectionInfo, IntoConnectionInfo, RedisConnectionInfo, RedisResult, }; use tokio::task::JoinHandle; -use tracing::{debug, error, info}; +use tracing::{debug, error, info, warn}; use super::elasticache::CredentialsProvider; @@ -89,7 +89,7 @@ impl ConnectionWithCredentialsProvider { return Ok(()); } Err(e) => { - error!("Error during PING: {e:?}"); + warn!("Error during PING: {e:?}"); } } } else { @@ -121,7 +121,7 @@ impl ConnectionWithCredentialsProvider { info!("Connection succesfully established"); } Err(e) => { - error!("Connection is broken. Error during PING: {e:?}"); + warn!("Connection is broken. Error during PING: {e:?}"); } } self.con = Some(con); diff --git a/proxy/src/redis/notifications.rs b/proxy/src/redis/notifications.rs index 36a3443603..c3af6740cb 100644 --- a/proxy/src/redis/notifications.rs +++ b/proxy/src/redis/notifications.rs @@ -146,7 +146,7 @@ impl MessageHandler { { Ok(()) => {} Err(e) => { - tracing::error!("failed to cancel session: {e}"); + tracing::warn!("failed to cancel session: {e}"); } } } diff --git a/proxy/src/serverless/mod.rs b/proxy/src/serverless/mod.rs index 9be6b592bd..b5820b0535 100644 --- a/proxy/src/serverless/mod.rs +++ b/proxy/src/serverless/mod.rs @@ -48,7 +48,7 @@ use std::pin::{pin, Pin}; use std::sync::Arc; use tokio::net::{TcpListener, TcpStream}; use tokio_util::sync::CancellationToken; -use tracing::{error, info, warn, Instrument}; +use tracing::{info, warn, Instrument}; use utils::http::error::ApiError; pub(crate) const SERVERLESS_DRIVER_SNI: &str = "api"; @@ -241,7 +241,7 @@ async fn connection_startup( let (conn, peer) = match read_proxy_protocol(conn).await { Ok(c) => c, Err(e) => { - tracing::error!(?session_id, %peer_addr, "failed to accept TCP connection: invalid PROXY protocol V2 header: {e:#}"); + tracing::warn!(?session_id, %peer_addr, "failed to accept TCP connection: invalid PROXY protocol V2 header: {e:#}"); return None; } }; @@ -405,7 +405,7 @@ async fn request_handler( ) .await { - error!("error in websocket connection: {e:#}"); + warn!("error in websocket connection: {e:#}"); } } .instrument(span), diff --git a/proxy/src/serverless/sql_over_http.rs b/proxy/src/serverless/sql_over_http.rs index f7c3b26917..646e7f8a52 100644 --- a/proxy/src/serverless/sql_over_http.rs +++ b/proxy/src/serverless/sql_over_http.rs @@ -831,7 +831,7 @@ impl QueryData { Either::Right((_cancelled, query)) => { tracing::info!("cancelling query"); if let Err(err) = cancel_token.cancel_query(NoTls).await { - tracing::error!(?err, "could not cancel query"); + tracing::warn!(?err, "could not cancel query"); } // wait for the query cancellation match time::timeout(time::Duration::from_millis(100), query).await { @@ -920,7 +920,7 @@ impl BatchQueryData { } Err(SqlOverHttpError::Cancelled(_)) => { if let Err(err) = cancel_token.cancel_query(NoTls).await { - tracing::error!(?err, "could not cancel query"); + tracing::warn!(?err, "could not cancel query"); } // TODO: after cancelling, wait to see if we can get a status. maybe the connection is still safe. discard.discard(); diff --git a/proxy/src/usage_metrics.rs b/proxy/src/usage_metrics.rs index bd3e62bc12..ee36ed462d 100644 --- a/proxy/src/usage_metrics.rs +++ b/proxy/src/usage_metrics.rs @@ -27,7 +27,7 @@ use std::{ }; use tokio::io::AsyncWriteExt; use tokio_util::sync::CancellationToken; -use tracing::{error, info, instrument, trace}; +use tracing::{error, info, instrument, trace, warn}; use utils::backoff; use uuid::{NoContext, Timestamp}; @@ -346,7 +346,7 @@ async fn collect_metrics_iteration( error!("metrics endpoint refused the sent metrics: {:?}", res); for metric in chunk.events.iter().filter(|e| e.value > (1u64 << 40)) { // Report if the metric value is suspiciously large - error!("potentially abnormal metric value: {:?}", metric); + warn!("potentially abnormal metric value: {:?}", metric); } } } From 326cd80f0dd8b60e5780d184bd55dab769a9f0b1 Mon Sep 17 00:00:00 2001 From: Fedor Dikarev Date: Fri, 11 Oct 2024 14:46:45 +0200 Subject: [PATCH 15/18] ci: gh-workflow-stats-action v0.1.4: remove debug output and proper pagination (#9356) ## Problem In previous version pagination didn't work so we collect information only for first 30 jobs in WorkflowRun --- .github/workflows/report-workflow-stats.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/report-workflow-stats.yml b/.github/workflows/report-workflow-stats.yml index 1afe896600..6abeff7695 100644 --- a/.github/workflows/report-workflow-stats.yml +++ b/.github/workflows/report-workflow-stats.yml @@ -33,7 +33,7 @@ jobs: actions: read steps: - name: Export GH Workflow Stats - uses: fedordikarev/gh-workflow-stats-action@v0.1.2 + uses: neondatabase/gh-workflow-stats-action@v0.1.4 with: DB_URI: ${{ secrets.GH_REPORT_STATS_DB_RW_CONNSTR }} DB_TABLE: "gh_workflow_stats_neon" From 091a175a3e02d319468efe15bb9765a3a9e29f4b Mon Sep 17 00:00:00 2001 From: a-masterov <72613290+a-masterov@users.noreply.github.com> Date: Fri, 11 Oct 2024 15:29:54 +0200 Subject: [PATCH 16/18] Test versions mismatch (#9167) ## Problem We faced the problem of incompatibility of the different components of different versions. This should be detected automatically to prevent production bugs. ## Summary of changes The test for this situation was implemented Co-authored-by: Alexander Bayandin --- test_runner/README.md | 12 +++ test_runner/fixtures/neon_fixtures.py | 52 +++++++++++ test_runner/fixtures/paths.py | 37 +++++++- test_runner/fixtures/utils.py | 33 +++++++ test_runner/regress/test_compatibility.py | 91 +++++++++++++------ .../regress/test_storage_controller.py | 12 ++- 6 files changed, 206 insertions(+), 31 deletions(-) diff --git a/test_runner/README.md b/test_runner/README.md index d754e60d17..e087241c1f 100644 --- a/test_runner/README.md +++ b/test_runner/README.md @@ -64,10 +64,12 @@ By default performance tests are excluded. To run them explicitly pass performan Useful environment variables: `NEON_BIN`: The directory where neon binaries can be found. +`COMPATIBILITY_NEON_BIN`: The directory where the previous version of Neon binaries can be found `POSTGRES_DISTRIB_DIR`: The directory where postgres distribution can be found. Since pageserver supports several postgres versions, `POSTGRES_DISTRIB_DIR` must contain a subdirectory for each version with naming convention `v{PG_VERSION}/`. Inside that dir, a `bin/postgres` binary should be present. +`COMPATIBILITY_POSTGRES_DISTRIB_DIR`: The directory where the prevoius version of postgres distribution can be found. `DEFAULT_PG_VERSION`: The version of Postgres to use, This is used to construct full path to the postgres binaries. Format is 2-digit major version nubmer, i.e. `DEFAULT_PG_VERSION=16` @@ -294,6 +296,16 @@ def test_foobar2(neon_env_builder: NeonEnvBuilder): client.timeline_detail(tenant_id=tenant_id, timeline_id=timeline_id) ``` +All the test which rely on NeonEnvBuilder, can check the various version combinations of the components. +To do this yuo may want to add the parametrize decorator with the function fixtures.utils.allpairs_versions() +E.g. + +```python +@pytest.mark.parametrize(**fixtures.utils.allpairs_versions()) +def test_something( +... +``` + For more information about pytest fixtures, see https://docs.pytest.org/en/stable/fixture.html At the end of a test, all the nodes in the environment are automatically stopped, so you diff --git a/test_runner/fixtures/neon_fixtures.py b/test_runner/fixtures/neon_fixtures.py index 9a60de922c..7789855fe4 100644 --- a/test_runner/fixtures/neon_fixtures.py +++ b/test_runner/fixtures/neon_fixtures.py @@ -75,6 +75,7 @@ from fixtures.safekeeper.http import SafekeeperHttpClient from fixtures.safekeeper.utils import wait_walreceivers_absent from fixtures.utils import ( ATTACHMENT_NAME_REGEX, + COMPONENT_BINARIES, allure_add_grafana_links, assert_no_errors, get_dir_size, @@ -316,11 +317,14 @@ class NeonEnvBuilder: run_id: uuid.UUID, mock_s3_server: MockS3Server, neon_binpath: Path, + compatibility_neon_binpath: Path, pg_distrib_dir: Path, + compatibility_pg_distrib_dir: Path, pg_version: PgVersion, test_name: str, top_output_dir: Path, test_output_dir: Path, + combination, test_overlay_dir: Optional[Path] = None, pageserver_remote_storage: Optional[RemoteStorage] = None, # toml that will be decomposed into `--config-override` flags during `pageserver --init` @@ -402,6 +406,19 @@ class NeonEnvBuilder: "test_" ), "Unexpectedly instantiated from outside a test function" self.test_name = test_name + self.compatibility_neon_binpath = compatibility_neon_binpath + self.compatibility_pg_distrib_dir = compatibility_pg_distrib_dir + self.version_combination = combination + self.mixdir = self.test_output_dir / "mixdir_neon" + if self.version_combination is not None: + assert ( + self.compatibility_neon_binpath is not None + ), "the environment variable COMPATIBILITY_NEON_BIN is required when using mixed versions" + assert ( + self.compatibility_pg_distrib_dir is not None + ), "the environment variable COMPATIBILITY_POSTGRES_DISTRIB_DIR is required when using mixed versions" + self.mixdir.mkdir(mode=0o755, exist_ok=True) + self._mix_versions() def init_configs(self, default_remote_storage_if_missing: bool = True) -> NeonEnv: # Cannot create more than one environment from one builder @@ -602,6 +619,21 @@ class NeonEnvBuilder: return self.env + def _mix_versions(self): + assert self.version_combination is not None, "version combination must be set" + for component, paths in COMPONENT_BINARIES.items(): + directory = ( + self.neon_binpath + if self.version_combination[component] == "new" + else self.compatibility_neon_binpath + ) + for filename in paths: + destination = self.mixdir / filename + destination.symlink_to(directory / filename) + if self.version_combination["compute"] == "old": + self.pg_distrib_dir = self.compatibility_pg_distrib_dir + self.neon_binpath = self.mixdir + def overlay_mount(self, ident: str, srcdir: Path, dstdir: Path): """ Mount `srcdir` as an overlayfs mount at `dstdir`. @@ -1350,7 +1382,9 @@ def neon_simple_env( top_output_dir: Path, test_output_dir: Path, neon_binpath: Path, + compatibility_neon_binpath: Path, pg_distrib_dir: Path, + compatibility_pg_distrib_dir: Path, pg_version: PgVersion, pageserver_virtual_file_io_engine: str, pageserver_aux_file_policy: Optional[AuxFileStore], @@ -1365,6 +1399,11 @@ def neon_simple_env( # Create the environment in the per-test output directory repo_dir = get_test_repo_dir(request, top_output_dir) + combination = ( + request._pyfuncitem.callspec.params["combination"] + if "combination" in request._pyfuncitem.callspec.params + else None + ) with NeonEnvBuilder( top_output_dir=top_output_dir, @@ -1372,7 +1411,9 @@ def neon_simple_env( port_distributor=port_distributor, mock_s3_server=mock_s3_server, neon_binpath=neon_binpath, + compatibility_neon_binpath=compatibility_neon_binpath, pg_distrib_dir=pg_distrib_dir, + compatibility_pg_distrib_dir=compatibility_pg_distrib_dir, pg_version=pg_version, run_id=run_id, preserve_database_files=cast(bool, pytestconfig.getoption("--preserve-database-files")), @@ -1382,6 +1423,7 @@ def neon_simple_env( pageserver_aux_file_policy=pageserver_aux_file_policy, pageserver_default_tenant_config_compaction_algorithm=pageserver_default_tenant_config_compaction_algorithm, pageserver_virtual_file_io_mode=pageserver_virtual_file_io_mode, + combination=combination, ) as builder: env = builder.init_start() @@ -1395,7 +1437,9 @@ def neon_env_builder( port_distributor: PortDistributor, mock_s3_server: MockS3Server, neon_binpath: Path, + compatibility_neon_binpath: Path, pg_distrib_dir: Path, + compatibility_pg_distrib_dir: Path, pg_version: PgVersion, run_id: uuid.UUID, request: FixtureRequest, @@ -1422,6 +1466,11 @@ def neon_env_builder( # Create the environment in the test-specific output dir repo_dir = os.path.join(test_output_dir, "repo") + combination = ( + request._pyfuncitem.callspec.params["combination"] + if "combination" in request._pyfuncitem.callspec.params + else None + ) # Return the builder to the caller with NeonEnvBuilder( @@ -1430,7 +1479,10 @@ def neon_env_builder( port_distributor=port_distributor, mock_s3_server=mock_s3_server, neon_binpath=neon_binpath, + compatibility_neon_binpath=compatibility_neon_binpath, pg_distrib_dir=pg_distrib_dir, + compatibility_pg_distrib_dir=compatibility_pg_distrib_dir, + combination=combination, pg_version=pg_version, run_id=run_id, preserve_database_files=cast(bool, pytestconfig.getoption("--preserve-database-files")), diff --git a/test_runner/fixtures/paths.py b/test_runner/fixtures/paths.py index cffeb47ee8..65f8e432b0 100644 --- a/test_runner/fixtures/paths.py +++ b/test_runner/fixtures/paths.py @@ -95,7 +95,29 @@ def neon_binpath(base_dir: Path, build_type: str) -> Iterator[Path]: if not (binpath / "pageserver").exists(): raise Exception(f"neon binaries not found at '{binpath}'") - yield binpath + yield binpath.absolute() + + +@pytest.fixture(scope="session") +def compatibility_snapshot_dir() -> Iterator[Path]: + if os.getenv("REMOTE_ENV"): + return + compatibility_snapshot_dir_env = os.environ.get("COMPATIBILITY_SNAPSHOT_DIR") + assert ( + compatibility_snapshot_dir_env is not None + ), "COMPATIBILITY_SNAPSHOT_DIR is not set. It should be set to `compatibility_snapshot_pg(PG_VERSION)` path generateted by test_create_snapshot (ideally generated by the previous version of Neon)" + compatibility_snapshot_dir = Path(compatibility_snapshot_dir_env).resolve() + yield compatibility_snapshot_dir + + +@pytest.fixture(scope="session") +def compatibility_neon_binpath() -> Optional[Iterator[Path]]: + if os.getenv("REMOTE_ENV"): + return + comp_binpath = None + if env_compatibility_neon_binpath := os.environ.get("COMPATIBILITY_NEON_BIN"): + comp_binpath = Path(env_compatibility_neon_binpath).resolve().absolute() + yield comp_binpath @pytest.fixture(scope="session") @@ -109,6 +131,19 @@ def pg_distrib_dir(base_dir: Path) -> Iterator[Path]: yield distrib_dir +@pytest.fixture(scope="session") +def compatibility_pg_distrib_dir() -> Optional[Iterator[Path]]: + compat_distrib_dir = None + if env_compat_postgres_bin := os.environ.get("COMPATIBILITY_POSTGRES_DISTRIB_DIR"): + compat_distrib_dir = Path(env_compat_postgres_bin).resolve() + if not compat_distrib_dir.exists(): + raise Exception(f"compatibility postgres directory not found at {compat_distrib_dir}") + + if compat_distrib_dir: + log.info(f"compatibility_pg_distrib_dir is {compat_distrib_dir}") + yield compat_distrib_dir + + @pytest.fixture(scope="session") def top_output_dir(base_dir: Path) -> Iterator[Path]: # Compute the top-level directory for all tests. diff --git a/test_runner/fixtures/utils.py b/test_runner/fixtures/utils.py index ca1be35880..76575d330c 100644 --- a/test_runner/fixtures/utils.py +++ b/test_runner/fixtures/utils.py @@ -37,6 +37,23 @@ if TYPE_CHECKING: Fn = TypeVar("Fn", bound=Callable[..., Any]) +COMPONENT_BINARIES = { + "storage_controller": ("storage_controller",), + "storage_broker": ("storage_broker",), + "compute": ("compute_ctl",), + "safekeeper": ("safekeeper",), + "pageserver": ("pageserver", "pagectl"), +} +# Disable auto-formatting for better readability +# fmt: off +VERSIONS_COMBINATIONS = ( + {"storage_controller": "new", "storage_broker": "new", "compute": "new", "safekeeper": "new", "pageserver": "new"}, + {"storage_controller": "new", "storage_broker": "new", "compute": "old", "safekeeper": "old", "pageserver": "old"}, + {"storage_controller": "new", "storage_broker": "new", "compute": "old", "safekeeper": "old", "pageserver": "new"}, + {"storage_controller": "new", "storage_broker": "new", "compute": "old", "safekeeper": "new", "pageserver": "new"}, + {"storage_controller": "old", "storage_broker": "old", "compute": "new", "safekeeper": "new", "pageserver": "new"}, +) +# fmt: on def subprocess_capture( @@ -607,3 +624,19 @@ def human_bytes(amt: float) -> str: amt = amt / 1024 raise RuntimeError("unreachable") + + +def allpairs_versions(): + """ + Returns a dictionary with arguments for pytest parametrize + to test the compatibility with the previous version of Neon components + combinations were pre-computed to test all the pairs of the components with + the different versions. + """ + ids = [] + for pair in VERSIONS_COMBINATIONS: + cur_id = [] + for component in sorted(pair.keys()): + cur_id.append(pair[component][0]) + ids.append(f"combination_{''.join(cur_id)}") + return {"argnames": "combination", "argvalues": VERSIONS_COMBINATIONS, "ids": ids} diff --git a/test_runner/regress/test_compatibility.py b/test_runner/regress/test_compatibility.py index 791e38383e..96ba3dd5a4 100644 --- a/test_runner/regress/test_compatibility.py +++ b/test_runner/regress/test_compatibility.py @@ -9,6 +9,7 @@ from dataclasses import dataclass from pathlib import Path from typing import TYPE_CHECKING +import fixtures.utils import pytest import toml from fixtures.common_types import TenantId, TimelineId @@ -93,6 +94,34 @@ if TYPE_CHECKING: # # Run forward compatibility test # ./scripts/pytest -k test_forward_compatibility # +# +# How to run `test_version_mismatch` locally: +# +# export DEFAULT_PG_VERSION=16 +# export BUILD_TYPE=release +# export CHECK_ONDISK_DATA_COMPATIBILITY=true +# export COMPATIBILITY_NEON_BIN=neon_previous/target/${BUILD_TYPE} +# export COMPATIBILITY_POSTGRES_DISTRIB_DIR=neon_previous/pg_install +# export NEON_BIN=target/release +# export POSTGRES_DISTRIB_DIR=pg_install +# +# # Build previous version of binaries and store them somewhere: +# rm -rf pg_install target +# git checkout +# CARGO_BUILD_FLAGS="--features=testing" make -s -j`nproc` +# mkdir -p neon_previous/target +# cp -a target/${BUILD_TYPE} ./neon_previous/target/${BUILD_TYPE} +# cp -a pg_install ./neon_previous/pg_install +# +# # Build current version of binaries and create a data snapshot: +# rm -rf pg_install target +# git checkout +# CARGO_BUILD_FLAGS="--features=testing" make -s -j`nproc` +# ./scripts/pytest -k test_create_snapshot +# +# # Run the version mismatch test +# ./scripts/pytest -k test_version_mismatch + check_ondisk_data_compatibility_if_enabled = pytest.mark.skipif( os.environ.get("CHECK_ONDISK_DATA_COMPATIBILITY") is None, @@ -166,16 +195,11 @@ def test_backward_compatibility( neon_env_builder: NeonEnvBuilder, test_output_dir: Path, pg_version: PgVersion, + compatibility_snapshot_dir: Path, ): """ Test that the new binaries can read old data """ - compatibility_snapshot_dir_env = os.environ.get("COMPATIBILITY_SNAPSHOT_DIR") - assert ( - compatibility_snapshot_dir_env is not None - ), f"COMPATIBILITY_SNAPSHOT_DIR is not set. It should be set to `compatibility_snapshot_pg{pg_version.v_prefixed}` path generateted by test_create_snapshot (ideally generated by the previous version of Neon)" - compatibility_snapshot_dir = Path(compatibility_snapshot_dir_env).resolve() - breaking_changes_allowed = ( os.environ.get("ALLOW_BACKWARD_COMPATIBILITY_BREAKAGE", "false").lower() == "true" ) @@ -214,27 +238,11 @@ def test_forward_compatibility( test_output_dir: Path, top_output_dir: Path, pg_version: PgVersion, + compatibility_snapshot_dir: Path, ): """ Test that the old binaries can read new data """ - compatibility_neon_bin_env = os.environ.get("COMPATIBILITY_NEON_BIN") - assert compatibility_neon_bin_env is not None, ( - "COMPATIBILITY_NEON_BIN is not set. It should be set to a path with Neon binaries " - "(ideally generated by the previous version of Neon)" - ) - compatibility_neon_bin = Path(compatibility_neon_bin_env).resolve() - - compatibility_postgres_distrib_dir_env = os.environ.get("COMPATIBILITY_POSTGRES_DISTRIB_DIR") - assert ( - compatibility_postgres_distrib_dir_env is not None - ), "COMPATIBILITY_POSTGRES_DISTRIB_DIR is not set. It should be set to a pg_install directrory (ideally generated by the previous version of Neon)" - compatibility_postgres_distrib_dir = Path(compatibility_postgres_distrib_dir_env).resolve() - - compatibility_snapshot_dir = ( - top_output_dir / f"compatibility_snapshot_pg{pg_version.v_prefixed}" - ) - breaking_changes_allowed = ( os.environ.get("ALLOW_FORWARD_COMPATIBILITY_BREAKAGE", "false").lower() == "true" ) @@ -245,9 +253,14 @@ def test_forward_compatibility( # Use previous version's production binaries (pageserver, safekeeper, pg_distrib_dir, etc.). # But always use the current version's neon_local binary. # This is because we want to test the compatibility of the data format, not the compatibility of the neon_local CLI. - neon_env_builder.neon_binpath = compatibility_neon_bin - neon_env_builder.pg_distrib_dir = compatibility_postgres_distrib_dir - neon_env_builder.neon_local_binpath = neon_env_builder.neon_local_binpath + assert ( + neon_env_builder.compatibility_neon_binpath is not None + ), "the environment variable COMPATIBILITY_NEON_BIN is required" + assert ( + neon_env_builder.compatibility_pg_distrib_dir is not None + ), "the environment variable COMPATIBILITY_POSTGRES_DISTRIB_DIR is required" + neon_env_builder.neon_binpath = neon_env_builder.compatibility_neon_binpath + neon_env_builder.pg_distrib_dir = neon_env_builder.compatibility_pg_distrib_dir env = neon_env_builder.from_repo_dir( compatibility_snapshot_dir / "repo", @@ -558,3 +571,29 @@ def test_historic_storage_formats( env.pageserver.http_client().timeline_compact( dataset.tenant_id, existing_timeline_id, force_image_layer_creation=True ) + + +@check_ondisk_data_compatibility_if_enabled +@pytest.mark.xdist_group("compatibility") +@pytest.mark.parametrize(**fixtures.utils.allpairs_versions()) +def test_versions_mismatch( + neon_env_builder: NeonEnvBuilder, + test_output_dir: Path, + pg_version: PgVersion, + compatibility_snapshot_dir, + combination, +): + """ + Checks compatibility of different combinations of versions of the components + """ + neon_env_builder.num_safekeepers = 3 + env = neon_env_builder.from_repo_dir( + compatibility_snapshot_dir / "repo", + ) + env.pageserver.allowed_errors.extend( + [".*ingesting record with timestamp lagging more than wait_lsn_timeout.+"] + ) + env.start() + check_neon_works( + env, test_output_dir, compatibility_snapshot_dir / "dump.sql", test_output_dir / "repo" + ) diff --git a/test_runner/regress/test_storage_controller.py b/test_runner/regress/test_storage_controller.py index 1dcf0b254d..1dcc37c407 100644 --- a/test_runner/regress/test_storage_controller.py +++ b/test_runner/regress/test_storage_controller.py @@ -9,6 +9,7 @@ from datetime import datetime, timezone from enum import Enum from typing import TYPE_CHECKING +import fixtures.utils import pytest from fixtures.auth_tokens import TokenScope from fixtures.common_types import TenantId, TenantShardId, TimelineId @@ -38,7 +39,11 @@ from fixtures.pg_version import PgVersion, run_only_on_default_postgres from fixtures.port_distributor import PortDistributor from fixtures.remote_storage import RemoteStorageKind, s3_storage from fixtures.storage_controller_proxy import StorageControllerProxy -from fixtures.utils import run_pg_bench_small, subprocess_capture, wait_until +from fixtures.utils import ( + run_pg_bench_small, + subprocess_capture, + wait_until, +) from fixtures.workload import Workload from mypy_boto3_s3.type_defs import ( ObjectTypeDef, @@ -60,9 +65,8 @@ def get_node_shard_counts(env: NeonEnv, tenant_ids): return counts -def test_storage_controller_smoke( - neon_env_builder: NeonEnvBuilder, -): +@pytest.mark.parametrize(**fixtures.utils.allpairs_versions()) +def test_storage_controller_smoke(neon_env_builder: NeonEnvBuilder, combination): """ Test the basic lifecycle of a storage controller: - Restarting From 5ef805e12c0e14e222609c51337cf9afcddf3b92 Mon Sep 17 00:00:00 2001 From: Alexander Bayandin Date: Fri, 11 Oct 2024 16:58:41 +0100 Subject: [PATCH 17/18] CI(run-python-test-set): allow to skip missing compatibility snapshot (#9365) ## Problem Action `run-python-test-set` fails if it is not used for `regress_tests` on release PR, because it expects `test_compatibility.py::test_create_snapshot` to generate a snapshot, and the test exists only in `regress_tests` suite. For example, in https://github.com/neondatabase/neon/pull/9291 [`test-postgres-client-libs`](https://github.com/neondatabase/neon/actions/runs/11209615321/job/31155111544) job failed. ## Summary of changes - Add `skip-if-does-not-exist` input to `.github/actions/upload` action (the same way we do for `.github/actions/download`) - Set `skip-if-does-not-exist=true` for "Upload compatibility snapshot" step in `run-python-test-set` action --- .github/actions/run-python-test-set/action.yml | 3 +++ .github/actions/upload/action.yml | 18 ++++++++++++++++-- 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/.github/actions/run-python-test-set/action.yml b/.github/actions/run-python-test-set/action.yml index 4008cd0d36..330e875d56 100644 --- a/.github/actions/run-python-test-set/action.yml +++ b/.github/actions/run-python-test-set/action.yml @@ -218,6 +218,9 @@ runs: name: compatibility-snapshot-${{ runner.arch }}-${{ inputs.build_type }}-pg${{ inputs.pg_version }} # Directory is created by test_compatibility.py::test_create_snapshot, keep the path in sync with the test path: /tmp/test_output/compatibility_snapshot_pg${{ inputs.pg_version }}/ + # The lack of compatibility snapshot shouldn't fail the job + # (for example if we didn't run the test for non build-and-test workflow) + skip-if-does-not-exist: true - name: Upload test results if: ${{ !cancelled() }} diff --git a/.github/actions/upload/action.yml b/.github/actions/upload/action.yml index edcece7d2b..8a4cfe2eff 100644 --- a/.github/actions/upload/action.yml +++ b/.github/actions/upload/action.yml @@ -7,6 +7,10 @@ inputs: path: description: "A directory or file to upload" required: true + skip-if-does-not-exist: + description: "Allow to skip if path doesn't exist, fail otherwise" + default: false + required: false prefix: description: "S3 prefix. Default is '${GITHUB_SHA}/${GITHUB_RUN_ID}/${GITHUB_RUN_ATTEMPT}'" required: false @@ -15,10 +19,12 @@ runs: using: "composite" steps: - name: Prepare artifact + id: prepare-artifact shell: bash -euxo pipefail {0} env: SOURCE: ${{ inputs.path }} ARCHIVE: /tmp/uploads/${{ inputs.name }}.tar.zst + SKIP_IF_DOES_NOT_EXIST: ${{ inputs.skip-if-does-not-exist }} run: | mkdir -p $(dirname $ARCHIVE) @@ -33,14 +39,22 @@ runs: elif [ -f ${SOURCE} ]; then time tar -cf ${ARCHIVE} --zstd ${SOURCE} elif ! ls ${SOURCE} > /dev/null 2>&1; then - echo >&2 "${SOURCE} does not exist" - exit 2 + if [ "${SKIP_IF_DOES_NOT_EXIST}" = "true" ]; then + echo 'SKIPPED=true' >> $GITHUB_OUTPUT + exit 0 + else + echo >&2 "${SOURCE} does not exist" + exit 2 + fi else echo >&2 "${SOURCE} is neither a directory nor a file, do not know how to handle it" exit 3 fi + echo 'SKIPPED=false' >> $GITHUB_OUTPUT + - name: Upload artifact + if: ${{ steps.prepare-artifact.outputs.SKIPPED == 'false' }} shell: bash -euxo pipefail {0} env: SOURCE: ${{ inputs.path }} From ab5bbb445bcd76410d884f3431a4dcba3ec8fb37 Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Fri, 11 Oct 2024 21:14:52 +0200 Subject: [PATCH 18/18] proxy: refactor auth backends (#9271) preliminary for #9270 The auth::Backend didn't need to be in the mega ProxyConfig object, so I split it off and passed it manually in the few places it was necessary. I've also refined some of the uses of config I saw while doing this small refactor. I've also followed the trend and make the console redirect backend it's own struct, same as LocalBackend and ControlPlaneBackend. --- proxy/src/auth/backend/console_redirect.rs | 25 +++- proxy/src/auth/backend/mod.rs | 19 ++- proxy/src/bin/local_proxy.rs | 25 +++- proxy/src/bin/proxy.rs | 154 +++++++++++---------- proxy/src/config.rs | 6 +- proxy/src/proxy/mod.rs | 7 +- proxy/src/serverless/backend.rs | 60 ++++---- proxy/src/serverless/mod.rs | 3 + proxy/src/serverless/sql_over_http.rs | 44 ++---- proxy/src/serverless/websocket.rs | 2 + 10 files changed, 186 insertions(+), 159 deletions(-) diff --git a/proxy/src/auth/backend/console_redirect.rs b/proxy/src/auth/backend/console_redirect.rs index a7cc678187..127be545e1 100644 --- a/proxy/src/auth/backend/console_redirect.rs +++ b/proxy/src/auth/backend/console_redirect.rs @@ -25,6 +25,10 @@ pub(crate) enum WebAuthError { Io(#[from] std::io::Error), } +pub struct ConsoleRedirectBackend { + console_uri: reqwest::Url, +} + impl UserFacingError for WebAuthError { fn to_string_client(&self) -> String { "Internal error".to_string() @@ -57,7 +61,26 @@ pub(crate) fn new_psql_session_id() -> String { hex::encode(rand::random::<[u8; 8]>()) } -pub(super) async fn authenticate( +impl ConsoleRedirectBackend { + pub fn new(console_uri: reqwest::Url) -> Self { + Self { console_uri } + } + + pub(super) fn url(&self) -> &reqwest::Url { + &self.console_uri + } + + pub(crate) async fn authenticate( + &self, + ctx: &RequestMonitoring, + auth_config: &'static AuthenticationConfig, + client: &mut PqStream, + ) -> auth::Result { + authenticate(ctx, auth_config, &self.console_uri, client).await + } +} + +async fn authenticate( ctx: &RequestMonitoring, auth_config: &'static AuthenticationConfig, link_uri: &reqwest::Url, diff --git a/proxy/src/auth/backend/mod.rs b/proxy/src/auth/backend/mod.rs index c9aa5b7e61..27c9f1876e 100644 --- a/proxy/src/auth/backend/mod.rs +++ b/proxy/src/auth/backend/mod.rs @@ -8,6 +8,7 @@ use std::net::IpAddr; use std::sync::Arc; use std::time::Duration; +pub use console_redirect::ConsoleRedirectBackend; pub(crate) use console_redirect::WebAuthError; use ipnet::{Ipv4Net, Ipv6Net}; use local::LocalBackend; @@ -36,7 +37,7 @@ use crate::{ provider::{CachedAllowedIps, CachedNodeInfo}, Api, }, - stream, url, + stream, }; use crate::{scram, EndpointCacheKey, EndpointId, RoleName}; @@ -69,7 +70,7 @@ pub enum Backend<'a, T, D> { /// Cloud API (V2). ControlPlane(MaybeOwned<'a, ControlPlaneBackend>, T), /// Authentication via a web browser. - ConsoleRedirect(MaybeOwned<'a, url::ApiUrl>, D), + ConsoleRedirect(MaybeOwned<'a, ConsoleRedirectBackend>, D), /// Local proxy uses configured auth credentials and does not wake compute Local(MaybeOwned<'a, LocalBackend>), } @@ -106,9 +107,9 @@ impl std::fmt::Display for Backend<'_, (), ()> { #[cfg(test)] ControlPlaneBackend::Test(_) => fmt.debug_tuple("ControlPlane::Test").finish(), }, - Self::ConsoleRedirect(url, ()) => fmt + Self::ConsoleRedirect(backend, ()) => fmt .debug_tuple("ConsoleRedirect") - .field(&url.as_str()) + .field(&backend.url().as_str()) .finish(), Self::Local(_) => fmt.debug_tuple("Local").finish(), } @@ -241,7 +242,6 @@ impl AuthenticationConfig { pub(crate) fn check_rate_limit( &self, ctx: &RequestMonitoring, - config: &AuthenticationConfig, secret: AuthSecret, endpoint: &EndpointId, is_cleartext: bool, @@ -265,7 +265,7 @@ impl AuthenticationConfig { let limit_not_exceeded = self.rate_limiter.check( ( endpoint_int, - MaskedIp::new(ctx.peer_addr(), config.rate_limit_ip_subnet), + MaskedIp::new(ctx.peer_addr(), self.rate_limit_ip_subnet), ), password_weight, ); @@ -339,7 +339,6 @@ async fn auth_quirks( let secret = if let Some(secret) = secret { config.check_rate_limit( ctx, - config, secret, &info.endpoint, unauthenticated_password.is_some() || allow_cleartext, @@ -456,12 +455,12 @@ impl<'a> Backend<'a, ComputeUserInfoMaybeEndpoint, &()> { Backend::ControlPlane(api, credentials) } // NOTE: this auth backend doesn't use client credentials. - Self::ConsoleRedirect(url, ()) => { + Self::ConsoleRedirect(backend, ()) => { info!("performing web authentication"); - let info = console_redirect::authenticate(ctx, config, &url, client).await?; + let info = backend.authenticate(ctx, config, client).await?; - Backend::ConsoleRedirect(url, info) + Backend::ConsoleRedirect(backend, info) } Self::Local(_) => { return Err(auth::AuthError::bad_auth_method("invalid for local proxy")) diff --git a/proxy/src/bin/local_proxy.rs b/proxy/src/bin/local_proxy.rs index ae8a7f0841..c781af846a 100644 --- a/proxy/src/bin/local_proxy.rs +++ b/proxy/src/bin/local_proxy.rs @@ -6,9 +6,12 @@ use compute_api::spec::LocalProxySpec; use dashmap::DashMap; use futures::future::Either; use proxy::{ - auth::backend::{ - jwt::JwkCache, - local::{LocalBackend, JWKS_ROLE_MAP}, + auth::{ + self, + backend::{ + jwt::JwkCache, + local::{LocalBackend, JWKS_ROLE_MAP}, + }, }, cancellation::CancellationHandlerMain, config::{self, AuthenticationConfig, HttpConfig, ProxyConfig, RetryConfig}, @@ -132,6 +135,7 @@ async fn main() -> anyhow::Result<()> { let args = LocalProxyCliArgs::parse(); let config = build_config(&args)?; + let auth_backend = build_auth_backend(&args)?; // before we bind to any ports, write the process ID to a file // so that compute-ctl can find our process later @@ -193,6 +197,7 @@ async fn main() -> anyhow::Result<()> { let task = serverless::task_main( config, + auth_backend, http_listener, shutdown.clone(), Arc::new(CancellationHandlerMain::new( @@ -257,9 +262,6 @@ fn build_config(args: &LocalProxyCliArgs) -> anyhow::Result<&'static ProxyConfig Ok(Box::leak(Box::new(ProxyConfig { tls_config: None, - auth_backend: proxy::auth::Backend::Local(proxy::auth::backend::MaybeOwned::Owned( - LocalBackend::new(args.compute), - )), metric_collection: None, allow_self_signed_compute: false, http_config, @@ -286,6 +288,17 @@ fn build_config(args: &LocalProxyCliArgs) -> anyhow::Result<&'static ProxyConfig }))) } +/// auth::Backend is created at proxy startup, and lives forever. +fn build_auth_backend( + args: &LocalProxyCliArgs, +) -> anyhow::Result<&'static auth::Backend<'static, (), ()>> { + let auth_backend = proxy::auth::Backend::Local(proxy::auth::backend::MaybeOwned::Owned( + LocalBackend::new(args.compute), + )); + + Ok(Box::leak(Box::new(auth_backend))) +} + async fn refresh_config_loop(path: Utf8PathBuf, rx: Arc) { loop { rx.notified().await; diff --git a/proxy/src/bin/proxy.rs b/proxy/src/bin/proxy.rs index 7488cce3c4..3f4c2df809 100644 --- a/proxy/src/bin/proxy.rs +++ b/proxy/src/bin/proxy.rs @@ -10,6 +10,7 @@ use futures::future::Either; use proxy::auth; use proxy::auth::backend::jwt::JwkCache; use proxy::auth::backend::AuthRateLimiter; +use proxy::auth::backend::ConsoleRedirectBackend; use proxy::auth::backend::MaybeOwned; use proxy::cancellation::CancelMap; use proxy::cancellation::CancellationHandler; @@ -311,8 +312,9 @@ async fn main() -> anyhow::Result<()> { let args = ProxyCliArgs::parse(); let config = build_config(&args)?; + let auth_backend = build_auth_backend(&args)?; - info!("Authentication backend: {}", config.auth_backend); + info!("Authentication backend: {}", auth_backend); info!("Using region: {}", args.aws_region); let region_provider = @@ -462,6 +464,7 @@ async fn main() -> anyhow::Result<()> { if let Some(proxy_listener) = proxy_listener { client_tasks.spawn(proxy::proxy::task_main( config, + auth_backend, proxy_listener, cancellation_token.clone(), cancellation_handler.clone(), @@ -472,6 +475,7 @@ async fn main() -> anyhow::Result<()> { if let Some(serverless_listener) = serverless_listener { client_tasks.spawn(serverless::task_main( config, + auth_backend, serverless_listener, cancellation_token.clone(), cancellation_handler.clone(), @@ -506,7 +510,7 @@ async fn main() -> anyhow::Result<()> { )); } - if let auth::Backend::ControlPlane(api, _) = &config.auth_backend { + if let auth::Backend::ControlPlane(api, _) = auth_backend { if let proxy::control_plane::provider::ControlPlaneBackend::Management(api) = &**api { match (redis_notifications_client, regional_redis_client.clone()) { (None, None) => {} @@ -610,6 +614,80 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> { bail!("dynamic rate limiter should be disabled"); } + let config::ConcurrencyLockOptions { + shards, + limiter, + epoch, + timeout, + } = args.connect_compute_lock.parse()?; + info!( + ?limiter, + shards, + ?epoch, + "Using NodeLocks (connect_compute)" + ); + let connect_compute_locks = control_plane::locks::ApiLocks::new( + "connect_compute_lock", + limiter, + shards, + timeout, + epoch, + &Metrics::get().proxy.connect_compute_lock, + )?; + + let http_config = HttpConfig { + accept_websockets: !args.is_auth_broker, + pool_options: GlobalConnPoolOptions { + max_conns_per_endpoint: args.sql_over_http.sql_over_http_pool_max_conns_per_endpoint, + gc_epoch: args.sql_over_http.sql_over_http_pool_gc_epoch, + pool_shards: args.sql_over_http.sql_over_http_pool_shards, + idle_timeout: args.sql_over_http.sql_over_http_idle_timeout, + opt_in: args.sql_over_http.sql_over_http_pool_opt_in, + max_total_conns: args.sql_over_http.sql_over_http_pool_max_total_conns, + }, + cancel_set: CancelSet::new(args.sql_over_http.sql_over_http_cancel_set_shards), + client_conn_threshold: args.sql_over_http.sql_over_http_client_conn_threshold, + max_request_size_bytes: args.sql_over_http.sql_over_http_max_request_size_bytes, + max_response_size_bytes: args.sql_over_http.sql_over_http_max_response_size_bytes, + }; + let authentication_config = AuthenticationConfig { + jwks_cache: JwkCache::default(), + thread_pool, + scram_protocol_timeout: args.scram_protocol_timeout, + rate_limiter_enabled: args.auth_rate_limit_enabled, + rate_limiter: AuthRateLimiter::new(args.auth_rate_limit.clone()), + rate_limit_ip_subnet: args.auth_rate_limit_ip_subnet, + ip_allowlist_check_enabled: !args.is_private_access_proxy, + is_auth_broker: args.is_auth_broker, + accept_jwts: args.is_auth_broker, + webauth_confirmation_timeout: args.webauth_confirmation_timeout, + }; + + let config = Box::leak(Box::new(ProxyConfig { + tls_config, + metric_collection, + allow_self_signed_compute: args.allow_self_signed_compute, + http_config, + authentication_config, + proxy_protocol_v2: args.proxy_protocol_v2, + handshake_timeout: args.handshake_timeout, + region: args.region.clone(), + wake_compute_retry_config: config::RetryConfig::parse(&args.wake_compute_retry)?, + connect_compute_locks, + connect_to_compute_retry_config: config::RetryConfig::parse( + &args.connect_to_compute_retry, + )?, + })); + + tokio::spawn(config.connect_compute_locks.garbage_collect_worker()); + + Ok(config) +} + +/// auth::Backend is created at proxy startup, and lives forever. +fn build_auth_backend( + args: &ProxyCliArgs, +) -> anyhow::Result<&'static auth::Backend<'static, (), ()>> { let auth_backend = match &args.auth_backend { AuthBackendType::Console => { let wake_compute_cache_config: CacheOptions = args.wake_compute_cache.parse()?; @@ -665,7 +743,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> { AuthBackendType::Web => { let url = args.uri.parse()?; - auth::Backend::ConsoleRedirect(MaybeOwned::Owned(url), ()) + auth::Backend::ConsoleRedirect(MaybeOwned::Owned(ConsoleRedirectBackend::new(url)), ()) } #[cfg(feature = "testing")] @@ -677,75 +755,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> { } }; - let config::ConcurrencyLockOptions { - shards, - limiter, - epoch, - timeout, - } = args.connect_compute_lock.parse()?; - info!( - ?limiter, - shards, - ?epoch, - "Using NodeLocks (connect_compute)" - ); - let connect_compute_locks = control_plane::locks::ApiLocks::new( - "connect_compute_lock", - limiter, - shards, - timeout, - epoch, - &Metrics::get().proxy.connect_compute_lock, - )?; - - let http_config = HttpConfig { - accept_websockets: !args.is_auth_broker, - pool_options: GlobalConnPoolOptions { - max_conns_per_endpoint: args.sql_over_http.sql_over_http_pool_max_conns_per_endpoint, - gc_epoch: args.sql_over_http.sql_over_http_pool_gc_epoch, - pool_shards: args.sql_over_http.sql_over_http_pool_shards, - idle_timeout: args.sql_over_http.sql_over_http_idle_timeout, - opt_in: args.sql_over_http.sql_over_http_pool_opt_in, - max_total_conns: args.sql_over_http.sql_over_http_pool_max_total_conns, - }, - cancel_set: CancelSet::new(args.sql_over_http.sql_over_http_cancel_set_shards), - client_conn_threshold: args.sql_over_http.sql_over_http_client_conn_threshold, - max_request_size_bytes: args.sql_over_http.sql_over_http_max_request_size_bytes, - max_response_size_bytes: args.sql_over_http.sql_over_http_max_response_size_bytes, - }; - let authentication_config = AuthenticationConfig { - jwks_cache: JwkCache::default(), - thread_pool, - scram_protocol_timeout: args.scram_protocol_timeout, - rate_limiter_enabled: args.auth_rate_limit_enabled, - rate_limiter: AuthRateLimiter::new(args.auth_rate_limit.clone()), - rate_limit_ip_subnet: args.auth_rate_limit_ip_subnet, - ip_allowlist_check_enabled: !args.is_private_access_proxy, - is_auth_broker: args.is_auth_broker, - accept_jwts: args.is_auth_broker, - webauth_confirmation_timeout: args.webauth_confirmation_timeout, - }; - - let config = Box::leak(Box::new(ProxyConfig { - tls_config, - auth_backend, - metric_collection, - allow_self_signed_compute: args.allow_self_signed_compute, - http_config, - authentication_config, - proxy_protocol_v2: args.proxy_protocol_v2, - handshake_timeout: args.handshake_timeout, - region: args.region.clone(), - wake_compute_retry_config: config::RetryConfig::parse(&args.wake_compute_retry)?, - connect_compute_locks, - connect_to_compute_retry_config: config::RetryConfig::parse( - &args.connect_to_compute_retry, - )?, - })); - - tokio::spawn(config.connect_compute_locks.garbage_collect_worker()); - - Ok(config) + Ok(Box::leak(Box::new(auth_backend))) } #[cfg(test)] diff --git a/proxy/src/config.rs b/proxy/src/config.rs index 55d0b6374c..c068fc50fb 100644 --- a/proxy/src/config.rs +++ b/proxy/src/config.rs @@ -1,8 +1,5 @@ use crate::{ - auth::{ - self, - backend::{jwt::JwkCache, AuthRateLimiter}, - }, + auth::backend::{jwt::JwkCache, AuthRateLimiter}, control_plane::locks::ApiLocks, rate_limiter::{RateBucketInfo, RateLimitAlgorithm, RateLimiterConfig}, scram::threadpool::ThreadPool, @@ -29,7 +26,6 @@ use x509_parser::oid_registry; pub struct ProxyConfig { pub tls_config: Option, - pub auth_backend: auth::Backend<'static, (), ()>, pub metric_collection: Option, pub allow_self_signed_compute: bool, pub http_config: HttpConfig, diff --git a/proxy/src/proxy/mod.rs b/proxy/src/proxy/mod.rs index 9e1af88f41..3a43ccb74a 100644 --- a/proxy/src/proxy/mod.rs +++ b/proxy/src/proxy/mod.rs @@ -61,6 +61,7 @@ pub async fn run_until_cancelled( pub async fn task_main( config: &'static ProxyConfig, + auth_backend: &'static auth::Backend<'static, (), ()>, listener: tokio::net::TcpListener, cancellation_token: CancellationToken, cancellation_handler: Arc, @@ -129,6 +130,7 @@ pub async fn task_main( let startup = Box::pin( handle_client( config, + auth_backend, &ctx, cancellation_handler, socket, @@ -243,8 +245,10 @@ impl ReportableError for ClientRequestError { } } +#[allow(clippy::too_many_arguments)] pub(crate) async fn handle_client( config: &'static ProxyConfig, + auth_backend: &'static auth::Backend<'static, (), ()>, ctx: &RequestMonitoring, cancellation_handler: Arc, stream: S, @@ -285,8 +289,7 @@ pub(crate) async fn handle_client( let common_names = tls.map(|tls| &tls.common_names); // Extract credentials which we're going to use for auth. - let result = config - .auth_backend + let result = auth_backend .as_ref() .map(|()| auth::ComputeUserInfoMaybeEndpoint::parse(ctx, ¶ms, hostname, common_names)) .transpose(); diff --git a/proxy/src/serverless/backend.rs b/proxy/src/serverless/backend.rs index f54476b51d..9e49478cf3 100644 --- a/proxy/src/serverless/backend.rs +++ b/proxy/src/serverless/backend.rs @@ -13,7 +13,7 @@ use crate::{ check_peer_addr_is_in_list, AuthError, }, compute, - config::{AuthenticationConfig, ProxyConfig}, + config::ProxyConfig, context::RequestMonitoring, control_plane::{ errors::{GetAuthInfoError, WakeComputeError}, @@ -42,6 +42,7 @@ pub(crate) struct PoolingBackend { pub(crate) local_pool: Arc>, pub(crate) pool: Arc>, pub(crate) config: &'static ProxyConfig, + pub(crate) auth_backend: &'static crate::auth::Backend<'static, (), ()>, pub(crate) endpoint_rate_limiter: Arc, } @@ -49,18 +50,13 @@ impl PoolingBackend { pub(crate) async fn authenticate_with_password( &self, ctx: &RequestMonitoring, - config: &AuthenticationConfig, user_info: &ComputeUserInfo, password: &[u8], ) -> Result { let user_info = user_info.clone(); - let backend = self - .config - .auth_backend - .as_ref() - .map(|()| user_info.clone()); + let backend = self.auth_backend.as_ref().map(|()| user_info.clone()); let (allowed_ips, maybe_secret) = backend.get_allowed_ips_and_secret(ctx).await?; - if config.ip_allowlist_check_enabled + if self.config.authentication_config.ip_allowlist_check_enabled && !check_peer_addr_is_in_list(&ctx.peer_addr(), &allowed_ips) { return Err(AuthError::ip_address_not_allowed(ctx.peer_addr())); @@ -79,7 +75,6 @@ impl PoolingBackend { let secret = match cached_secret.value.clone() { Some(secret) => self.config.authentication_config.check_rate_limit( ctx, - config, secret, &user_info.endpoint, true, @@ -91,9 +86,13 @@ impl PoolingBackend { } }; let ep = EndpointIdInt::from(&user_info.endpoint); - let auth_outcome = - crate::auth::validate_password_and_exchange(&config.thread_pool, ep, password, secret) - .await?; + let auth_outcome = crate::auth::validate_password_and_exchange( + &self.config.authentication_config.thread_pool, + ep, + password, + secret, + ) + .await?; let res = match auth_outcome { crate::sasl::Outcome::Success(key) => { info!("user successfully authenticated"); @@ -113,13 +112,13 @@ impl PoolingBackend { pub(crate) async fn authenticate_with_jwt( &self, ctx: &RequestMonitoring, - config: &AuthenticationConfig, user_info: &ComputeUserInfo, jwt: String, ) -> Result { - match &self.config.auth_backend { + match &self.auth_backend { crate::auth::Backend::ControlPlane(console, ()) => { - config + self.config + .authentication_config .jwks_cache .check_jwt( ctx, @@ -140,7 +139,9 @@ impl PoolingBackend { "JWT login over web auth proxy is not supported", )), crate::auth::Backend::Local(_) => { - let keys = config + let keys = self + .config + .authentication_config .jwks_cache .check_jwt( ctx, @@ -185,7 +186,7 @@ impl PoolingBackend { let conn_id = uuid::Uuid::new_v4(); tracing::Span::current().record("conn_id", display(conn_id)); info!(%conn_id, "pool: opening a new connection '{conn_info}'"); - let backend = self.config.auth_backend.as_ref().map(|()| keys); + let backend = self.auth_backend.as_ref().map(|()| keys); crate::proxy::connect_compute::connect_to_compute( ctx, &TokioMechanism { @@ -217,21 +218,14 @@ impl PoolingBackend { let conn_id = uuid::Uuid::new_v4(); tracing::Span::current().record("conn_id", display(conn_id)); info!(%conn_id, "pool: opening a new connection '{conn_info}'"); - let backend = self - .config - .auth_backend - .as_ref() - .map(|()| ComputeCredentials { - info: ComputeUserInfo { - user: conn_info.user_info.user.clone(), - endpoint: EndpointId::from(format!( - "{}-local-proxy", - conn_info.user_info.endpoint - )), - options: conn_info.user_info.options.clone(), - }, - keys: crate::auth::backend::ComputeCredentialKeys::None, - }); + let backend = self.auth_backend.as_ref().map(|()| ComputeCredentials { + info: ComputeUserInfo { + user: conn_info.user_info.user.clone(), + endpoint: EndpointId::from(format!("{}-local-proxy", conn_info.user_info.endpoint)), + options: conn_info.user_info.options.clone(), + }, + keys: crate::auth::backend::ComputeCredentialKeys::None, + }); crate::proxy::connect_compute::connect_to_compute( ctx, &HyperMechanism { @@ -269,7 +263,7 @@ impl PoolingBackend { tracing::Span::current().record("conn_id", display(conn_id)); info!(%conn_id, "local_pool: opening a new connection '{conn_info}'"); - let mut node_info = match &self.config.auth_backend { + let mut node_info = match &self.auth_backend { auth::Backend::ControlPlane(_, ()) | auth::Backend::ConsoleRedirect(_, ()) => { unreachable!("only local_proxy can connect to local postgres") } diff --git a/proxy/src/serverless/mod.rs b/proxy/src/serverless/mod.rs index b5820b0535..95f64e972c 100644 --- a/proxy/src/serverless/mod.rs +++ b/proxy/src/serverless/mod.rs @@ -55,6 +55,7 @@ pub(crate) const SERVERLESS_DRIVER_SNI: &str = "api"; pub async fn task_main( config: &'static ProxyConfig, + auth_backend: &'static crate::auth::Backend<'static, (), ()>, ws_listener: TcpListener, cancellation_token: CancellationToken, cancellation_handler: Arc, @@ -110,6 +111,7 @@ pub async fn task_main( local_pool, pool: Arc::clone(&conn_pool), config, + auth_backend, endpoint_rate_limiter: Arc::clone(&endpoint_rate_limiter), }); let tls_acceptor: Arc = match config.tls_config.as_ref() { @@ -397,6 +399,7 @@ async fn request_handler( async move { if let Err(e) = websocket::serve_websocket( config, + backend.auth_backend, ctx, websocket, cancellation_handler, diff --git a/proxy/src/serverless/sql_over_http.rs b/proxy/src/serverless/sql_over_http.rs index 646e7f8a52..cf3324926c 100644 --- a/proxy/src/serverless/sql_over_http.rs +++ b/proxy/src/serverless/sql_over_http.rs @@ -45,6 +45,7 @@ use crate::auth::backend::ComputeUserInfo; use crate::auth::endpoint_sni; use crate::auth::ComputeUserInfoParseError; use crate::config::AuthenticationConfig; +use crate::config::HttpConfig; use crate::config::ProxyConfig; use crate::config::TlsConfig; use crate::context::RequestMonitoring; @@ -554,7 +555,7 @@ async fn handle_inner( match conn_info.auth { AuthData::Jwt(jwt) if config.authentication_config.is_auth_broker => { - handle_auth_broker_inner(config, ctx, request, conn_info.conn_info, jwt, backend).await + handle_auth_broker_inner(ctx, request, conn_info.conn_info, jwt, backend).await } auth => { handle_db_inner( @@ -622,28 +623,17 @@ async fn handle_db_inner( let authenticate_and_connect = Box::pin( async { - let is_local_proxy = - matches!(backend.config.auth_backend, crate::auth::Backend::Local(_)); + let is_local_proxy = matches!(backend.auth_backend, crate::auth::Backend::Local(_)); let keys = match auth { AuthData::Password(pw) => { backend - .authenticate_with_password( - ctx, - &config.authentication_config, - &conn_info.user_info, - &pw, - ) + .authenticate_with_password(ctx, &conn_info.user_info, &pw) .await? } AuthData::Jwt(jwt) => { backend - .authenticate_with_jwt( - ctx, - &config.authentication_config, - &conn_info.user_info, - jwt, - ) + .authenticate_with_jwt(ctx, &conn_info.user_info, jwt) .await? } }; @@ -691,7 +681,7 @@ async fn handle_db_inner( // Now execute the query and return the result. let json_output = match payload { Payload::Single(stmt) => { - stmt.process(config, cancel, &mut client, parsed_headers) + stmt.process(&config.http_config, cancel, &mut client, parsed_headers) .await? } Payload::Batch(statements) => { @@ -709,7 +699,7 @@ async fn handle_db_inner( } statements - .process(config, cancel, &mut client, parsed_headers) + .process(&config.http_config, cancel, &mut client, parsed_headers) .await? } }; @@ -749,7 +739,6 @@ static HEADERS_TO_FORWARD: &[&HeaderName] = &[ ]; async fn handle_auth_broker_inner( - config: &'static ProxyConfig, ctx: &RequestMonitoring, request: Request, conn_info: ConnInfo, @@ -757,12 +746,7 @@ async fn handle_auth_broker_inner( backend: Arc, ) -> Result>, SqlOverHttpError> { backend - .authenticate_with_jwt( - ctx, - &config.authentication_config, - &conn_info.user_info, - jwt, - ) + .authenticate_with_jwt(ctx, &conn_info.user_info, jwt) .await .map_err(HttpConnError::from)?; @@ -800,7 +784,7 @@ async fn handle_auth_broker_inner( impl QueryData { async fn process( self, - config: &'static ProxyConfig, + config: &'static HttpConfig, cancel: CancellationToken, client: &mut Client, parsed_headers: HttpHeaders, @@ -874,7 +858,7 @@ impl QueryData { impl BatchQueryData { async fn process( self, - config: &'static ProxyConfig, + config: &'static HttpConfig, cancel: CancellationToken, client: &mut Client, parsed_headers: HttpHeaders, @@ -944,7 +928,7 @@ impl BatchQueryData { } async fn query_batch( - config: &'static ProxyConfig, + config: &'static HttpConfig, cancel: CancellationToken, transaction: &Transaction<'_>, queries: BatchQueryData, @@ -983,7 +967,7 @@ async fn query_batch( } async fn query_to_json( - config: &'static ProxyConfig, + config: &'static HttpConfig, client: &T, data: QueryData, current_size: &mut usize, @@ -1004,9 +988,9 @@ async fn query_to_json( rows.push(row); // we don't have a streaming response support yet so this is to prevent OOM // from a malicious query (eg a cross join) - if *current_size > config.http_config.max_response_size_bytes { + if *current_size > config.max_response_size_bytes { return Err(SqlOverHttpError::ResponseTooLarge( - config.http_config.max_response_size_bytes, + config.max_response_size_bytes, )); } } diff --git a/proxy/src/serverless/websocket.rs b/proxy/src/serverless/websocket.rs index 08d5da9bef..fd0f0cac7f 100644 --- a/proxy/src/serverless/websocket.rs +++ b/proxy/src/serverless/websocket.rs @@ -129,6 +129,7 @@ impl AsyncBufRead for WebSocketRw { pub(crate) async fn serve_websocket( config: &'static ProxyConfig, + auth_backend: &'static crate::auth::Backend<'static, (), ()>, ctx: RequestMonitoring, websocket: OnUpgrade, cancellation_handler: Arc, @@ -145,6 +146,7 @@ pub(crate) async fn serve_websocket( let res = Box::pin(handle_client( config, + auth_backend, &ctx, cancellation_handler, WebSocketRw::new(websocket),