Improve type safety according to pyright

Pyright found many issues that mypy either doesn't catch or isn't
configured to catch.

Signed-off-by: Tristan Partin <tristan@neon.tech>
Tristan Partin
2024-11-08 14:43:15 -06:00
committed by GitHub
parent af8238ae52
commit ecde8d7632
16 changed files with 67 additions and 31 deletions

View File

@@ -80,7 +80,13 @@ class PgBenchRunResult:
):
stdout_lines = stdout.splitlines()
number_of_clients = 0
number_of_threads = 0
number_of_transactions_actually_processed = 0
latency_average = 0.0
latency_stddev = None
tps = 0.0
scale = 0
# we know significant parts of these values from test input
# but to be precise take them from output
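
The hunk above shows the pattern pyright's possibly-unbound diagnostic pushes toward: variables assigned only inside a branch or loop get explicit defaults before parsing begins. A minimal sketch of the idea, using a hypothetical parse_tps helper rather than the real pgbench parser:

```python
def parse_tps(stdout: str) -> float:
    tps = 0.0  # bound up front, so pyright knows it exists after the loop
    for line in stdout.splitlines():
        if line.startswith("tps = "):
            tps = float(line.removeprefix("tps = ").split()[0])
    return tps
```

Without the initializer, `tps` would be flagged as possibly unbound at the `return`, since pyright cannot prove the loop body ever runs.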

View File

@@ -8,7 +8,7 @@ from contextlib import _GeneratorContextManager, contextmanager
# Type-related stuff
from pathlib import Path
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, final
import pytest
from _pytest.fixtures import FixtureRequest
@@ -70,12 +70,12 @@ class PgCompare(ABC):
@contextmanager
@abstractmethod
def record_pageserver_writes(self, out_name: str):
def record_pageserver_writes(self, out_name: str) -> Iterator[None]:
pass
@contextmanager
@abstractmethod
def record_duration(self, out_name: str):
def record_duration(self, out_name: str) -> Iterator[None]:
pass
@contextmanager
@@ -105,6 +105,7 @@ class PgCompare(ABC):
return results
@final
class NeonCompare(PgCompare):
"""PgCompare interface for the neon stack."""
@@ -206,6 +207,7 @@ class NeonCompare(PgCompare):
return self.zenbenchmark.record_duration(out_name)
@final
class VanillaCompare(PgCompare):
"""PgCompare interface for vanilla postgres."""
@@ -271,6 +273,7 @@ class VanillaCompare(PgCompare):
return self.zenbenchmark.record_duration(out_name)
@final
class RemoteCompare(PgCompare):
"""PgCompare interface for a remote postgres instance."""

View File

@@ -4,11 +4,14 @@ https://python-hyper.org/projects/hyper-h2/en/stable/asyncio-example.html
auth-broker -> local-proxy needs a h2 connection, so we need a h2 server :)
"""
from __future__ import annotations
import asyncio
import collections
import io
import json
from collections.abc import AsyncIterable
from typing import TYPE_CHECKING, final
import pytest_asyncio
from h2.config import H2Configuration
@@ -25,34 +28,45 @@ from h2.events import (
)
from h2.exceptions import ProtocolError, StreamClosedError
from h2.settings import SettingCodes
from typing_extensions import override
if TYPE_CHECKING:
from typing import Any, Optional
RequestData = collections.namedtuple("RequestData", ["headers", "data"])
@final
class H2Server:
def __init__(self, host, port) -> None:
def __init__(self, host: str, port: int) -> None:
self.host = host
self.port = port
@final
class H2Protocol(asyncio.Protocol):
def __init__(self):
config = H2Configuration(client_side=False, header_encoding="utf-8")
self.conn = H2Connection(config=config)
self.transport = None
self.stream_data = {}
self.flow_control_futures = {}
self.transport: Optional[asyncio.Transport] = None
self.stream_data: dict[int, RequestData] = {}
self.flow_control_futures: dict[int, asyncio.Future[Any]] = {}
def connection_made(self, transport: asyncio.Transport): # type: ignore[override]
@override
def connection_made(self, transport: asyncio.BaseTransport):
assert isinstance(transport, asyncio.Transport)
self.transport = transport
self.conn.initiate_connection()
self.transport.write(self.conn.data_to_send())
def connection_lost(self, _exc):
@override
def connection_lost(self, exc: Optional[Exception]):
for future in self.flow_control_futures.values():
future.cancel()
self.flow_control_futures = {}
@override
def data_received(self, data: bytes):
assert self.transport is not None
try:
@@ -77,7 +91,7 @@ class H2Protocol(asyncio.Protocol):
self.window_updated(event.stream_id, event.delta)
elif isinstance(event, RemoteSettingsChanged):
if SettingCodes.INITIAL_WINDOW_SIZE in event.changed_settings:
self.window_updated(None, 0)
self.window_updated(0, 0)
self.transport.write(self.conn.data_to_send())
@@ -123,7 +137,7 @@ class H2Protocol(asyncio.Protocol):
else:
stream_data.data.write(data)
def stream_reset(self, stream_id):
def stream_reset(self, stream_id: int):
"""
A stream reset was sent. Stop sending data.
"""
@@ -131,7 +145,7 @@ class H2Protocol(asyncio.Protocol):
future = self.flow_control_futures.pop(stream_id)
future.cancel()
async def send_data(self, data, stream_id):
async def send_data(self, data: bytes, stream_id: int):
"""
Send data according to the flow control rules.
"""
@@ -161,7 +175,7 @@ class H2Protocol(asyncio.Protocol):
self.transport.write(self.conn.data_to_send())
data = data[chunk_size:]
async def wait_for_flow_control(self, stream_id):
async def wait_for_flow_control(self, stream_id: int):
"""
Waits for a Future that fires when the flow control window is opened.
"""
@@ -169,7 +183,7 @@ class H2Protocol(asyncio.Protocol):
self.flow_control_futures[stream_id] = f
await f
def window_updated(self, stream_id, delta):
def window_updated(self, stream_id: int, delta):
"""
A window update frame was received. Unblock some number of flow control
Futures.
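
The interesting change here is `connection_made`. The base class declares it as taking `asyncio.BaseTransport`, so narrowing the parameter to `asyncio.Transport` previously needed a `type: ignore[override]`; accepting the base type and narrowing with `assert isinstance(...)` satisfies pyright without the suppression, and `typing_extensions.override` makes it verify the signature actually matches the base class. A reduced sketch:

```python
import asyncio
from typing import Optional

from typing_extensions import override

class EchoProtocol(asyncio.Protocol):  # hypothetical, reduced from H2Protocol
    def __init__(self) -> None:
        self.transport: Optional[asyncio.Transport] = None

    @override
    def connection_made(self, transport: asyncio.BaseTransport) -> None:
        # accept the declared base type, then narrow to the concrete one
        assert isinstance(transport, asyncio.Transport)
        self.transport = transport

    @override
    def data_received(self, data: bytes) -> None:
        assert self.transport is not None  # narrows the Optional attribute
        self.transport.write(data)
```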

View File

@@ -1857,7 +1857,7 @@ class NeonStorageController(MetricsGetter, LogUtils):
shard_count: Optional[int] = None,
shard_stripe_size: Optional[int] = None,
tenant_config: Optional[dict[Any, Any]] = None,
placement_policy: Optional[Union[dict[Any, Any] | str]] = None,
placement_policy: Optional[Union[dict[Any, Any], str]] = None,
):
"""
Use this rather than pageserver_api() when you need to include shard parameters
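
The one-character fix above is easy to misread: `Union[dict[Any, Any] | str]` hands `Union` a single (already-unioned) type argument, which pyright rejects, while `Union[dict[Any, Any], str]` is the conventional two-argument form. Both evaluate to the same runtime type. For illustration:

```python
from typing import Any, Optional, Union

# accepted: Union with two arguments
PlacementPolicy = Optional[Union[dict[Any, Any], str]]

# equivalent PEP 604 spelling (Python 3.10+), avoiding Union entirely:
PlacementPolicyAlt = dict[Any, Any] | str | None
```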

View File

@@ -316,7 +316,7 @@ class PageserverHttpClient(requests.Session, MetricsGetter):
def tenant_location_conf(
self,
tenant_id: Union[TenantId, TenantShardId],
location_conf=dict[str, Any],
location_conf: dict[str, Any],
flush_ms=None,
lazy: Optional[bool] = None,
):
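
`location_conf=dict[str, Any]` was a typo with a surprising meaning: it set the parameter's default *value* to the type object `dict[str, Any]` and left the parameter unannotated. The corrected `location_conf: dict[str, Any]` is a real annotation and makes the argument required. A hypothetical pair showing the difference:

```python
from typing import Any

def configure_bad(location_conf=dict[str, Any]):
    # default VALUE is the type object; the parameter has no type at all
    ...

def configure_good(location_conf: dict[str, Any]) -> None:
    # properly annotated, and callers must now pass a dict
    ...
```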

View File

@@ -56,6 +56,8 @@ def wait_for_upload(
lsn: Lsn,
):
"""waits for local timeline upload up to specified lsn"""
current_lsn = Lsn(0)
for i in range(20):
current_lsn = remote_consistent_lsn(pageserver_http, tenant, timeline)
if current_lsn >= lsn:
@@ -203,6 +205,8 @@ def wait_for_last_record_lsn(
lsn: Lsn,
) -> Lsn:
"""waits for pageserver to catch up to a certain lsn, returns the last observed lsn."""
current_lsn = Lsn(0)
for i in range(1000):
current_lsn = last_record_lsn(pageserver_http, tenant, timeline)
if current_lsn >= lsn:

View File

@@ -112,7 +112,7 @@ def compatibility_snapshot_dir() -> Iterator[Path]:
@pytest.fixture(scope="session")
def compatibility_neon_binpath() -> Optional[Iterator[Path]]:
def compatibility_neon_binpath() -> Iterator[Optional[Path]]:
if os.getenv("REMOTE_ENV"):
return
comp_binpath = None
@@ -133,7 +133,7 @@ def pg_distrib_dir(base_dir: Path) -> Iterator[Path]:
@pytest.fixture(scope="session")
def compatibility_pg_distrib_dir() -> Optional[Iterator[Path]]:
def compatibility_pg_distrib_dir() -> Iterator[Optional[Path]]:
compat_distrib_dir = None
if env_compat_postgres_bin := os.environ.get("COMPATIBILITY_POSTGRES_DISTRIB_DIR"):
compat_distrib_dir = Path(env_compat_postgres_bin).resolve()
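
Both fixture fixes move `Optional` inside the `Iterator`: a generator function always returns a generator, so it is the *yielded* value that may be `None`, not the iterator itself. A minimal sketch with a hypothetical fixture name:

```python
import os
from collections.abc import Iterator
from pathlib import Path
from typing import Optional

import pytest

@pytest.fixture(scope="session")
def compat_binpath() -> Iterator[Optional[Path]]:  # hypothetical fixture
    # the Optional belongs on the yielded item, not on the Iterator
    env = os.environ.get("COMPATIBILITY_NEON_BIN")
    yield Path(env).resolve() if env else None
```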

View File

@@ -2,11 +2,13 @@ from __future__ import annotations
from contextlib import closing
from io import BufferedReader, RawIOBase
from typing import Optional
from typing import Optional, final
from fixtures.compare_fixtures import PgCompare
from typing_extensions import override
@final
class CopyTestData(RawIOBase):
def __init__(self, rows: int):
self.rows = rows
@@ -14,6 +16,7 @@ class CopyTestData(RawIOBase):
self.linebuf: Optional[bytes] = None
self.ptr = 0
@override
def readable(self):
return True

View File

@@ -656,6 +656,7 @@ def test_upgrade_generationless_local_file_paths(
workload.write_rows(1000)
attached_pageserver = env.get_tenant_pageserver(tenant_id)
assert attached_pageserver is not None
secondary_pageserver = list([ps for ps in env.pageservers if ps.id != attached_pageserver.id])[
0
]
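
`get_tenant_pageserver()` returns an Optional, and pyright flags attribute access on it; the added `assert ... is not None` narrows the type for the rest of the test. The same pattern recurs in several of the test files below. A minimal sketch with hypothetical stand-in types:

```python
from typing import Optional

class Pageserver:  # hypothetical stand-in
    id: int = 1

def get_tenant_pageserver(tenant_id: str) -> Optional[Pageserver]:
    return Pageserver() if tenant_id else None

ps = get_tenant_pageserver("tenant-a")
assert ps is not None  # narrows Optional[Pageserver] to Pageserver
print(ps.id)  # no optional-member-access report from pyright
```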

View File

@@ -37,7 +37,7 @@ async def test_websockets(static_proxy: NeonProxy):
startup_message.extend(b"\0")
length = (4 + len(startup_message)).to_bytes(4, byteorder="big")
await websocket.send([length, startup_message])
await websocket.send([length, bytes(startup_message)])
startup_response = await websocket.recv()
assert isinstance(startup_response, bytes)
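
`startup_message` is a `bytearray` (it is built up with `extend`); the websocket client accepts bytes-like data at runtime, but its annotations are evidently narrower, so `bytes(...)` snapshots the buffer into the accepted type. Roughly:

```python
startup_message = bytearray()
startup_message.extend(b"user\0admin\0")
startup_message.extend(b"\0")

frame = bytes(startup_message)  # copies the mutable buffer into bytes
```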

View File

@@ -256,6 +256,7 @@ def test_sharding_split_compaction(
# Cleanup part 1: while layers are still in PITR window, we should only drop layers that are fully redundant
for shard in shards:
ps = env.get_tenant_pageserver(shard)
assert ps is not None
# Invoke compaction: this should drop any layers that don't overlap with the shard's key stripes
detail_before = ps.http_client().timeline_detail(shard, timeline_id)

View File

@@ -1237,6 +1237,7 @@ def test_storage_controller_tenant_deletion(
# Assert attachments all have local content
for shard_id in shard_ids:
pageserver = env.get_tenant_pageserver(shard_id)
assert pageserver is not None
assert pageserver.tenant_dir(shard_id).exists()
# Assert all shards have some content in remote storage
@@ -2745,6 +2746,7 @@ def test_storage_controller_validate_during_migration(neon_env_builder: NeonEnvB
# Upload but don't compact
origin_pageserver = env.get_tenant_pageserver(tenant_id)
assert origin_pageserver is not None
dest_ps_id = [p.id for p in env.pageservers if p.id != origin_pageserver.id][0]
origin_pageserver.http_client().timeline_checkpoint(
tenant_id, timeline_id, wait_until_uploaded=True, compact=False

View File

@@ -245,6 +245,7 @@ def test_scrubber_physical_gc_ancestors(
workload.write_rows(100, upload=False)
for shard in shards:
ps = env.get_tenant_pageserver(shard)
assert ps is not None
log.info(f"Waiting for shard {shard} on pageserver {ps.id}")
ps.http_client().timeline_checkpoint(
shard, timeline_id, compact=False, wait_until_uploaded=True
@@ -270,6 +271,7 @@ def test_scrubber_physical_gc_ancestors(
workload.churn_rows(100)
for shard in shards:
ps = env.get_tenant_pageserver(shard)
assert ps is not None
ps.http_client().timeline_compact(shard, timeline_id, force_image_layer_creation=True)
ps.http_client().timeline_gc(shard, timeline_id, 0)
@@ -336,12 +338,15 @@ def test_scrubber_physical_gc_timeline_deletion(neon_env_builder: NeonEnvBuilder
# Issue a deletion queue flush so that the parent shard can't leave behind layers
# that will look like unexpected garbage to the scrubber
env.get_tenant_pageserver(tenant_id).http_client().deletion_queue_flush(execute=True)
ps = env.get_tenant_pageserver(tenant_id)
assert ps is not None
ps.http_client().deletion_queue_flush(execute=True)
new_shard_count = 4
shards = env.storage_controller.tenant_shard_split(tenant_id, shard_count=new_shard_count)
for shard in shards:
ps = env.get_tenant_pageserver(shard)
assert ps is not None
log.info(f"Waiting for shard {shard} on pageserver {ps.id}")
ps.http_client().timeline_checkpoint(
shard, timeline_id, compact=False, wait_until_uploaded=True

View File

@@ -315,6 +315,7 @@ def test_single_branch_get_tenant_size_grows(
tenant_id: TenantId,
timeline_id: TimelineId,
) -> tuple[Lsn, int]:
size = 0
consistent = False
size_debug = None
@@ -360,7 +361,7 @@ def test_single_branch_get_tenant_size_grows(
collected_responses.append(("CREATE", current_lsn, size))
batch_size = 100
prev_size = 0
for i in range(3):
with endpoint.cursor() as cur:
cur.execute(

View File

@@ -146,6 +146,7 @@ def test_threshold_based_eviction(
out += [f" {remote} {layer.layer_file_name}"]
return "\n".join(out)
stable_for: float = 0
observation_window = 8 * eviction_threshold
consider_stable_when_no_change_for_seconds = 3 * eviction_threshold
poll_interval = eviction_threshold / 3
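
The added `stable_for: float = 0` both guarantees the name is bound before the polling loop that updates it and declares the intended type explicitly, even though the initializer is the int literal `0`. A guess at the shape of the surrounding code, with made-up values:

```python
eviction_threshold = 6.0
stable_for: float = 0  # declared float despite the int initializer
poll_interval = eviction_threshold / 3

while stable_for < 3 * eviction_threshold:
    stable_for += poll_interval  # float arithmetic, matching the annotation
```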

View File

@@ -1506,15 +1506,10 @@ class SafekeeperEnv:
port=port.http,
auth_token=None,
)
try:
safekeeper_process = start_in_background(
cmd, safekeeper_dir, "safekeeper.log", safekeeper_client.check_status
)
return safekeeper_process
except Exception as e:
log.error(e)
safekeeper_process.kill()
raise Exception(f"Failed to start safekepeer as {cmd}, reason: {e}") from e
safekeeper_process = start_in_background(
cmd, safekeeper_dir, "safekeeper.log", safekeeper_client.check_status
)
return safekeeper_process
def get_safekeeper_connstrs(self):
assert self.safekeepers is not None, "safekeepers are not initialized"
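
The deleted `try`/`except` was unsound in exactly the way pyright reports: `safekeeper_process` is only assigned by the call that can raise, so in the handler the name may be unbound and `kill()` would itself fail. Dropping the handler is the honest fix. A reduced sketch of the removed shape, with a stubbed helper:

```python
import subprocess

def start_in_background(cmd: list[str]) -> subprocess.Popen[bytes]:
    return subprocess.Popen(cmd)  # stand-in for the real helper

def start_or_raise(cmd: list[str]) -> subprocess.Popen[bytes]:
    try:
        process = start_in_background(cmd)  # the only assignment
        return process
    except Exception as e:
        # pyright: "process" is possibly unbound here -- if the call above
        # raised, the name was never bound and .kill() raises UnboundLocalError
        process.kill()
        raise Exception(f"Failed to start as {cmd}") from e
```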