Mirror of https://github.com/neondatabase/neon.git (synced 2026-01-07 05:22:56 +00:00)
Improve type safety according to pyright
Pyright found many issues that mypy either doesn't catch or isn't configured to catch.

Signed-off-by: Tristan Partin <tristan@neon.tech>
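Many of the hunks below share one pattern: helpers such as env.get_tenant_pageserver(...) are typed as returning an Optional value, and pyright reports every use of the result that lacks a None check, so the tests now assert the value is not None before using it. A minimal, self-contained sketch of that narrowing pattern (the names are illustrative, not Neon's API):

from typing import Optional


def find_port(services: dict[str, int], name: str) -> Optional[int]:
    # dict.get returns None for a missing key, so the result is Optional[int].
    return services.get(name)


def connect(services: dict[str, int]) -> str:
    port = find_port(services, "pageserver")
    # Without this assert, a strict checker reports that `port` may be None
    # in the f-string below; the assert narrows it to int.
    assert port is not None
    return f"localhost:{port}"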
@@ -80,7 +80,13 @@ class PgBenchRunResult:
    ):
        stdout_lines = stdout.splitlines()

        number_of_clients = 0
        number_of_threads = 0
        number_of_transactions_actually_processed = 0
        latency_average = 0.0
        latency_stddev = None
        tps = 0.0
        scale = 0

        # we know significant parts of these values from test input
        # but to be precise take them from output
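The hunk above gives every parsed field a starting value before the loop over stdout_lines; otherwise the fields are only assigned inside conditional branches and pyright reports them as possibly unbound when the result is built. A small sketch of the same fix on a made-up parser (not the fixture's real code):

def parse_tps(stdout: str) -> float:
    # Bind the variable on every path; pyright flags `tps` as possibly
    # unbound if it is only assigned inside the conditional below.
    tps = 0.0
    for line in stdout.splitlines():
        if line.startswith("tps = "):
            tps = float(line.removeprefix("tps = ").split()[0])
    return tps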
@@ -8,7 +8,7 @@ from contextlib import _GeneratorContextManager, contextmanager

# Type-related stuff
from pathlib import Path
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, final

import pytest
from _pytest.fixtures import FixtureRequest
@@ -70,12 +70,12 @@ class PgCompare(ABC):

    @contextmanager
    @abstractmethod
    def record_pageserver_writes(self, out_name: str):
    def record_pageserver_writes(self, out_name: str) -> Iterator[None]:
        pass

    @contextmanager
    @abstractmethod
    def record_duration(self, out_name: str):
    def record_duration(self, out_name: str) -> Iterator[None]:
        pass

    @contextmanager
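Because these abstract methods are wrapped in @contextmanager, annotating them as generators of None (Iterator[None]) is what lets the decorator, and pyright, type `with compare.record_duration(...)` correctly. A standalone sketch of the shape, not the real PgCompare class:

from abc import ABC, abstractmethod
from collections.abc import Iterator
from contextlib import contextmanager


class Recorder(ABC):
    @contextmanager
    @abstractmethod
    def record_duration(self, out_name: str) -> Iterator[None]:
        # Implementations yield exactly once; Iterator[None] is the
        # generator type that @contextmanager consumes.
        pass


class PrintingRecorder(Recorder):
    @contextmanager
    def record_duration(self, out_name: str) -> Iterator[None]:
        print(f"start {out_name}")
        yield
        print(f"end {out_name}")


with PrintingRecorder().record_duration("step"):
    pass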
@@ -105,6 +105,7 @@ class PgCompare(ABC):
        return results


@final
class NeonCompare(PgCompare):
    """PgCompare interface for the neon stack."""


@@ -206,6 +207,7 @@ class NeonCompare(PgCompare):
        return self.zenbenchmark.record_duration(out_name)


@final
class VanillaCompare(PgCompare):
    """PgCompare interface for vanilla postgres."""


@@ -271,6 +273,7 @@ class VanillaCompare(PgCompare):
        return self.zenbenchmark.record_duration(out_name)


@final
class RemoteCompare(PgCompare):
    """PgCompare interface for a remote postgres instance."""
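Marking the concrete comparison classes with @final tells the type checker they are not meant to be subclassed, so an accidental subclass or override becomes an error rather than a silent surprise. A minimal illustration, independent of the fixtures:

from typing import final


@final
class VanillaRunner:
    def run(self) -> None:
        print("running against vanilla postgres")


# A type checker reports an error here, because VanillaRunner is @final:
# class CustomRunner(VanillaRunner):
#     ...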
@@ -4,11 +4,14 @@ https://python-hyper.org/projects/hyper-h2/en/stable/asyncio-example.html
auth-broker -> local-proxy needs a h2 connection, so we need a h2 server :)
"""

from __future__ import annotations

import asyncio
import collections
import io
import json
from collections.abc import AsyncIterable
from typing import TYPE_CHECKING, final

import pytest_asyncio
from h2.config import H2Configuration
@@ -25,34 +28,45 @@ from h2.events import (
)
from h2.exceptions import ProtocolError, StreamClosedError
from h2.settings import SettingCodes
from typing_extensions import override

if TYPE_CHECKING:
    from typing import Any, Optional


RequestData = collections.namedtuple("RequestData", ["headers", "data"])


@final
class H2Server:
    def __init__(self, host, port) -> None:
    def __init__(self, host: str, port: int) -> None:
        self.host = host
        self.port = port


@final
class H2Protocol(asyncio.Protocol):
    def __init__(self):
        config = H2Configuration(client_side=False, header_encoding="utf-8")
        self.conn = H2Connection(config=config)
        self.transport = None
        self.stream_data = {}
        self.flow_control_futures = {}
        self.transport: Optional[asyncio.Transport] = None
        self.stream_data: dict[int, RequestData] = {}
        self.flow_control_futures: dict[int, asyncio.Future[Any]] = {}

    def connection_made(self, transport: asyncio.Transport):  # type: ignore[override]
    @override
    def connection_made(self, transport: asyncio.BaseTransport):
        assert isinstance(transport, asyncio.Transport)
        self.transport = transport
        self.conn.initiate_connection()
        self.transport.write(self.conn.data_to_send())

    def connection_lost(self, _exc):
    @override
    def connection_lost(self, exc: Optional[Exception]):
        for future in self.flow_control_futures.values():
            future.cancel()
        self.flow_control_futures = {}

    @override
    def data_received(self, data: bytes):
        assert self.transport is not None
        try:
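The H2Protocol changes bundle several patterns pyright rewards: instance attributes get explicit types up front, connection_made accepts the base class's asyncio.BaseTransport parameter and narrows it with an isinstance assert instead of a `# type: ignore[override]`, and typing_extensions.override marks each overridden callback. A condensed, standalone sketch of those patterns (not the test's actual protocol):

from __future__ import annotations

import asyncio
from typing import Optional

from typing_extensions import override


class EchoProtocol(asyncio.Protocol):
    def __init__(self) -> None:
        # Explicit attribute types so the checker knows what each holds.
        self.transport: Optional[asyncio.Transport] = None
        self.buffers: dict[int, bytes] = {}

    @override
    def connection_made(self, transport: asyncio.BaseTransport) -> None:
        # Match the base-class signature, then narrow to the Transport
        # this protocol actually needs.
        assert isinstance(transport, asyncio.Transport)
        self.transport = transport

    @override
    def data_received(self, data: bytes) -> None:
        assert self.transport is not None
        self.transport.write(data)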
@@ -77,7 +91,7 @@ class H2Protocol(asyncio.Protocol):
                self.window_updated(event.stream_id, event.delta)
            elif isinstance(event, RemoteSettingsChanged):
                if SettingCodes.INITIAL_WINDOW_SIZE in event.changed_settings:
                    self.window_updated(None, 0)
                    self.window_updated(0, 0)

        self.transport.write(self.conn.data_to_send())

@@ -123,7 +137,7 @@
        else:
            stream_data.data.write(data)

    def stream_reset(self, stream_id):
    def stream_reset(self, stream_id: int):
        """
        A stream reset was sent. Stop sending data.
        """

@@ -131,7 +145,7 @@
            future = self.flow_control_futures.pop(stream_id)
            future.cancel()

    async def send_data(self, data, stream_id):
    async def send_data(self, data: bytes, stream_id: int):
        """
        Send data according to the flow control rules.
        """

@@ -161,7 +175,7 @@
            self.transport.write(self.conn.data_to_send())
            data = data[chunk_size:]

    async def wait_for_flow_control(self, stream_id):
    async def wait_for_flow_control(self, stream_id: int):
        """
        Waits for a Future that fires when the flow control window is opened.
        """

@@ -169,7 +183,7 @@
        self.flow_control_futures[stream_id] = f
        await f

    def window_updated(self, stream_id, delta):
    def window_updated(self, stream_id: int, delta):
        """
        A window update frame was received. Unblock some number of flow control
        Futures.
@@ -1857,7 +1857,7 @@ class NeonStorageController(MetricsGetter, LogUtils):
        shard_count: Optional[int] = None,
        shard_stripe_size: Optional[int] = None,
        tenant_config: Optional[dict[Any, Any]] = None,
        placement_policy: Optional[Union[dict[Any, Any] | str]] = None,
        placement_policy: Optional[Union[dict[Any, Any], str]] = None,
    ):
        """
        Use this rather than pageserver_api() when you need to include shard parameters
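The old annotation nested the PEP 604 union operator inside a Union subscription (`Union[dict[Any, Any] | str]`); the fix spells the union with a comma so Union gets two ordinary members. A tiny standalone illustration of the corrected parameter in use (the policy values are placeholders, not Neon's):

from typing import Any, Optional, Union


def describe_policy(policy: Optional[Union[dict[Any, Any], str]]) -> str:
    # The parameter accepts a structured policy, a shorthand string, or None.
    if policy is None:
        return "default"
    if isinstance(policy, str):
        return policy
    return ",".join(f"{k}={v}" for k, v in policy.items())


print(describe_policy(None))
print(describe_policy("spread"))
print(describe_policy({"attached": 1}))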
@@ -316,7 +316,7 @@ class PageserverHttpClient(requests.Session, MetricsGetter):
    def tenant_location_conf(
        self,
        tenant_id: Union[TenantId, TenantShardId],
        location_conf=dict[str, Any],
        location_conf: dict[str, Any],
        flush_ms=None,
        lazy: Optional[bool] = None,
    ):
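The fix here is a single character: `location_conf=dict[str, Any]` made the type object the parameter's default value, while `location_conf: dict[str, Any]` is the intended required, annotated parameter. A small sketch of the difference (hypothetical function names):

from typing import Any


# Bug shape: dict[str, Any] is a default value, so callers may omit the
# argument and the parameter's type is inferred from that default object.
def configure_buggy(location_conf=dict[str, Any]):
    return location_conf


# Fixed shape: a required parameter annotated as dict[str, Any].
def configure(location_conf: dict[str, Any]) -> dict[str, Any]:
    return location_conf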
@@ -56,6 +56,8 @@ def wait_for_upload(
    lsn: Lsn,
):
    """waits for local timeline upload up to specified lsn"""

    current_lsn = Lsn(0)
    for i in range(20):
        current_lsn = remote_consistent_lsn(pageserver_http, tenant, timeline)
        if current_lsn >= lsn:

@@ -203,6 +205,8 @@ def wait_for_last_record_lsn(
    lsn: Lsn,
) -> Lsn:
    """waits for pageserver to catch up to a certain lsn, returns the last observed lsn."""

    current_lsn = Lsn(0)
    for i in range(1000):
        current_lsn = last_record_lsn(pageserver_http, tenant, timeline)
        if current_lsn >= lsn:
@@ -112,7 +112,7 @@ def compatibility_snapshot_dir() -> Iterator[Path]:


@pytest.fixture(scope="session")
def compatibility_neon_binpath() -> Optional[Iterator[Path]]:
def compatibility_neon_binpath() -> Iterator[Optional[Path]]:
    if os.getenv("REMOTE_ENV"):
        return
    comp_binpath = None

@@ -133,7 +133,7 @@ def pg_distrib_dir(base_dir: Path) -> Iterator[Path]:


@pytest.fixture(scope="session")
def compatibility_pg_distrib_dir() -> Optional[Iterator[Path]]:
def compatibility_pg_distrib_dir() -> Iterator[Optional[Path]]:
    compat_distrib_dir = None
    if env_compat_postgres_bin := os.environ.get("COMPATIBILITY_POSTGRES_DISTRIB_DIR"):
        compat_distrib_dir = Path(env_compat_postgres_bin).resolve()
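In these fixtures the value that gets yielded may be absent, so the item type is what is optional: `Iterator[Optional[Path]]` describes a generator whose yielded value may be None, while `Optional[Iterator[Path]]` would claim the fixture itself might not be an iterator at all. A hypothetical session fixture with the corrected shape (the environment variable name is made up):

import os
from collections.abc import Iterator
from pathlib import Path
from typing import Optional

import pytest


@pytest.fixture(scope="session")
def compat_binpath() -> Iterator[Optional[Path]]:
    # Yields exactly once; the value is None when the variable is unset,
    # so the item type is Optional[Path].
    env = os.environ.get("COMPAT_BINPATH")
    yield Path(env).resolve() if env else None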
@@ -2,11 +2,13 @@ from __future__ import annotations

from contextlib import closing
from io import BufferedReader, RawIOBase
from typing import Optional
from typing import Optional, final

from fixtures.compare_fixtures import PgCompare
from typing_extensions import override


@final
class CopyTestData(RawIOBase):
    def __init__(self, rows: int):
        self.rows = rows

@@ -14,6 +16,7 @@ class CopyTestData(RawIOBase):
        self.linebuf: Optional[bytes] = None
        self.ptr = 0

    @override
    def readable(self):
        return True
@@ -656,6 +656,7 @@ def test_upgrade_generationless_local_file_paths(
    workload.write_rows(1000)

    attached_pageserver = env.get_tenant_pageserver(tenant_id)
    assert attached_pageserver is not None
    secondary_pageserver = list([ps for ps in env.pageservers if ps.id != attached_pageserver.id])[
        0
    ]
@@ -37,7 +37,7 @@ async def test_websockets(static_proxy: NeonProxy):
    startup_message.extend(b"\0")
    length = (4 + len(startup_message)).to_bytes(4, byteorder="big")

    await websocket.send([length, startup_message])
    await websocket.send([length, bytes(startup_message)])

    startup_response = await websocket.recv()
    assert isinstance(startup_response, bytes)
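The test assembles the PostgreSQL startup packet in a mutable bytearray; converting it with bytes(...) before sending matches the bytes items that the websocket client's annotations describe and freezes the buffer. A rough standalone sketch of assembling such a packet (standard startup-message framing, shown only for illustration):

def build_startup_packet(params: dict[str, str]) -> bytes:
    # Assemble the message body in a mutable buffer, then freeze it.
    body = bytearray()
    body.extend((3 << 16).to_bytes(4, byteorder="big"))  # protocol version 3.0
    for key, value in params.items():
        body.extend(key.encode() + b"\0" + value.encode() + b"\0")
    body.extend(b"\0")
    # The length prefix counts itself plus the body.
    length = (4 + len(body)).to_bytes(4, byteorder="big")
    return length + bytes(body)


packet = build_startup_packet({"user": "postgres", "database": "postgres"})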
@@ -256,6 +256,7 @@ def test_sharding_split_compaction(
    # Cleanup part 1: while layers are still in PITR window, we should only drop layers that are fully redundant
    for shard in shards:
        ps = env.get_tenant_pageserver(shard)
        assert ps is not None

        # Invoke compaction: this should drop any layers that don't overlap with the shard's key stripes
        detail_before = ps.http_client().timeline_detail(shard, timeline_id)

@@ -1237,6 +1237,7 @@ def test_storage_controller_tenant_deletion(
    # Assert attachments all have local content
    for shard_id in shard_ids:
        pageserver = env.get_tenant_pageserver(shard_id)
        assert pageserver is not None
        assert pageserver.tenant_dir(shard_id).exists()

    # Assert all shards have some content in remote storage

@@ -2745,6 +2746,7 @@ def test_storage_controller_validate_during_migration(neon_env_builder: NeonEnvBuilder

    # Upload but don't compact
    origin_pageserver = env.get_tenant_pageserver(tenant_id)
    assert origin_pageserver is not None
    dest_ps_id = [p.id for p in env.pageservers if p.id != origin_pageserver.id][0]
    origin_pageserver.http_client().timeline_checkpoint(
        tenant_id, timeline_id, wait_until_uploaded=True, compact=False

@@ -245,6 +245,7 @@ def test_scrubber_physical_gc_ancestors(
    workload.write_rows(100, upload=False)
    for shard in shards:
        ps = env.get_tenant_pageserver(shard)
        assert ps is not None
        log.info(f"Waiting for shard {shard} on pageserver {ps.id}")
        ps.http_client().timeline_checkpoint(
            shard, timeline_id, compact=False, wait_until_uploaded=True

@@ -270,6 +271,7 @@ def test_scrubber_physical_gc_ancestors(
    workload.churn_rows(100)
    for shard in shards:
        ps = env.get_tenant_pageserver(shard)
        assert ps is not None
        ps.http_client().timeline_compact(shard, timeline_id, force_image_layer_creation=True)
        ps.http_client().timeline_gc(shard, timeline_id, 0)

@@ -336,12 +338,15 @@ def test_scrubber_physical_gc_timeline_deletion(neon_env_builder: NeonEnvBuilder

    # Issue a deletion queue flush so that the parent shard can't leave behind layers
    # that will look like unexpected garbage to the scrubber
    env.get_tenant_pageserver(tenant_id).http_client().deletion_queue_flush(execute=True)
    ps = env.get_tenant_pageserver(tenant_id)
    assert ps is not None
    ps.http_client().deletion_queue_flush(execute=True)

    new_shard_count = 4
    shards = env.storage_controller.tenant_shard_split(tenant_id, shard_count=new_shard_count)
    for shard in shards:
        ps = env.get_tenant_pageserver(shard)
        assert ps is not None
        log.info(f"Waiting for shard {shard} on pageserver {ps.id}")
        ps.http_client().timeline_checkpoint(
            shard, timeline_id, compact=False, wait_until_uploaded=True
@@ -315,6 +315,7 @@ def test_single_branch_get_tenant_size_grows(
        tenant_id: TenantId,
        timeline_id: TimelineId,
    ) -> tuple[Lsn, int]:
        size = 0
        consistent = False
        size_debug = None

@@ -360,7 +361,7 @@ def test_single_branch_get_tenant_size_grows(
    collected_responses.append(("CREATE", current_lsn, size))

    batch_size = 100

    prev_size = 0
    for i in range(3):
        with endpoint.cursor() as cur:
            cur.execute(

@@ -146,6 +146,7 @@ def test_threshold_based_eviction(
            out += [f" {remote} {layer.layer_file_name}"]
        return "\n".join(out)

    stable_for: float = 0
    observation_window = 8 * eviction_threshold
    consider_stable_when_no_change_for_seconds = 3 * eviction_threshold
    poll_interval = eviction_threshold / 3
@@ -1506,15 +1506,10 @@ class SafekeeperEnv:
            port=port.http,
            auth_token=None,
        )
        try:
            safekeeper_process = start_in_background(
                cmd, safekeeper_dir, "safekeeper.log", safekeeper_client.check_status
            )
            return safekeeper_process
        except Exception as e:
            log.error(e)
            safekeeper_process.kill()
            raise Exception(f"Failed to start safekepeer as {cmd}, reason: {e}") from e
        safekeeper_process = start_in_background(
            cmd, safekeeper_dir, "safekeeper.log", safekeeper_client.check_status
        )
        return safekeeper_process

    def get_safekeeper_connstrs(self):
        assert self.safekeepers is not None, "safekeepers are not initialized"
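In this last hunk the deleted except block called safekeeper_process.kill() even though safekeeper_process is only bound inside the try, so on the failure path the name can be unbound; the commit simply lets the exception propagate. If failure cleanup were still wanted, a shape that keeps every name bound looks roughly like this (a generic sketch, not the repository's helper):

from __future__ import annotations

import subprocess
from typing import Optional


def start_service(cmd: list[str]) -> subprocess.Popen[bytes]:
    # Keep the handle Optional and only touch it when it was assigned,
    # so the except block never references an unbound name.
    proc: Optional[subprocess.Popen[bytes]] = None
    try:
        proc = subprocess.Popen(cmd)
        proc.poll()  # stand-in for a health check
        return proc
    except Exception as e:
        if proc is not None:
            proc.kill()
        raise RuntimeError(f"failed to start {cmd}") from e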