Improve typing related to regress/test_logical_replication.py (#9725)

Signed-off-by: Tristan Partin <tristan@neon.tech>
This commit is contained in:
Tristan Partin
2024-11-11 17:36:45 -06:00
committed by GitHub
parent b018bc7da8
commit 5be6b07cf1
2 changed files with 33 additions and 21 deletions

View File

@@ -286,7 +286,7 @@ class PgProtocol:
return self.safe_psql_many([query], **kwargs)[0]
def safe_psql_many(
self, queries: Iterable[str], log_query=True, **kwargs: Any
self, queries: Iterable[str], log_query: bool = True, **kwargs: Any
) -> list[list[tuple[Any, ...]]]:
"""
Execute queries against the node and return all rows.
@@ -306,7 +306,7 @@ class PgProtocol:
result.append(cur.fetchall())
return result
def safe_psql_scalar(self, query, log_query=True) -> Any:
def safe_psql_scalar(self, query: str, log_query: bool = True) -> Any:
"""
Execute query returning single row with single column.
"""

View File

@@ -4,24 +4,31 @@ import time
from functools import partial
from random import choice
from string import ascii_lowercase
from typing import TYPE_CHECKING, cast
from fixtures.common_types import Lsn
from fixtures.common_types import Lsn, TenantId, TimelineId
from fixtures.log_helper import log
from fixtures.neon_fixtures import (
NeonEnv,
NeonEnvBuilder,
PgProtocol,
logical_replication_sync,
wait_for_last_flush_lsn,
)
from fixtures.utils import wait_until
if TYPE_CHECKING:
from fixtures.neon_fixtures import (
Endpoint,
NeonEnv,
NeonEnvBuilder,
PgProtocol,
VanillaPostgres,
)
def random_string(n: int):
return "".join([choice(ascii_lowercase) for _ in range(n)])
def test_logical_replication(neon_simple_env: NeonEnv, vanilla_pg):
def test_logical_replication(neon_simple_env: NeonEnv, vanilla_pg: VanillaPostgres):
env = neon_simple_env
tenant_id = env.initial_tenant
@@ -160,10 +167,10 @@ COMMIT;
# Test that neon.logical_replication_max_snap_files works
def test_obsolete_slot_drop(neon_simple_env: NeonEnv, vanilla_pg):
def slot_removed(ep):
def test_obsolete_slot_drop(neon_simple_env: NeonEnv, vanilla_pg: VanillaPostgres):
def slot_removed(ep: Endpoint):
assert (
endpoint.safe_psql(
ep.safe_psql(
"select count(*) from pg_replication_slots where slot_name = 'stale_slot'"
)[0][0]
== 0
@@ -254,7 +261,7 @@ FROM generate_series(1, 16384) AS seq; -- Inserts enough rows to exceed 16MB of
# Tests that walsender correctly blocks until WAL is downloaded from safekeepers
def test_lr_with_slow_safekeeper(neon_env_builder: NeonEnvBuilder, vanilla_pg):
def test_lr_with_slow_safekeeper(neon_env_builder: NeonEnvBuilder, vanilla_pg: VanillaPostgres):
neon_env_builder.num_safekeepers = 3
env = neon_env_builder.init_start()
@@ -336,13 +343,13 @@ FROM generate_series(1, 16384) AS seq; -- Inserts enough rows to exceed 16MB of
#
# Most pages start with a contrecord, so we don't do anything special
# to ensure that.
def test_restart_endpoint(neon_simple_env: NeonEnv, vanilla_pg):
def test_restart_endpoint(neon_simple_env: NeonEnv, vanilla_pg: VanillaPostgres):
env = neon_simple_env
env.create_branch("init")
endpoint = env.endpoints.create_start("init")
tenant_id = endpoint.safe_psql("show neon.tenant_id")[0][0]
timeline_id = endpoint.safe_psql("show neon.timeline_id")[0][0]
tenant_id = TenantId(cast("str", endpoint.safe_psql("show neon.tenant_id")[0][0]))
timeline_id = TimelineId(cast("str", endpoint.safe_psql("show neon.timeline_id")[0][0]))
cur = endpoint.connect().cursor()
cur.execute("create table t(key int, value text)")
@@ -380,7 +387,7 @@ def test_restart_endpoint(neon_simple_env: NeonEnv, vanilla_pg):
# logical replication bug as such, but without logical replication,
# records passed ot the WAL redo process are never large enough to hit
# the bug.
def test_large_records(neon_simple_env: NeonEnv, vanilla_pg):
def test_large_records(neon_simple_env: NeonEnv, vanilla_pg: VanillaPostgres):
env = neon_simple_env
env.create_branch("init")
@@ -522,15 +529,20 @@ def logical_replication_wait_flush_lsn_sync(publisher: PgProtocol) -> Lsn:
because for some WAL records like vacuum subscriber won't get any data at
all.
"""
publisher_flush_lsn = Lsn(publisher.safe_psql("SELECT pg_current_wal_flush_lsn()")[0][0])
publisher_flush_lsn = Lsn(
cast("str", publisher.safe_psql("SELECT pg_current_wal_flush_lsn()")[0][0])
)
def check_caughtup():
res = publisher.safe_psql(
"""
res = cast(
"tuple[str, str, str]",
publisher.safe_psql(
"""
select sent_lsn, flush_lsn, pg_current_wal_flush_lsn() from pg_stat_replication sr, pg_replication_slots s
where s.active_pid = sr.pid and s.slot_type = 'logical';
"""
)[0]
)[0],
)
sent_lsn, flush_lsn, curr_publisher_flush_lsn = Lsn(res[0]), Lsn(res[1]), Lsn(res[2])
log.info(
f"sent_lsn={sent_lsn}, flush_lsn={flush_lsn}, publisher_flush_lsn={curr_publisher_flush_lsn}, waiting flush_lsn to reach {publisher_flush_lsn}"
@@ -545,7 +557,7 @@ select sent_lsn, flush_lsn, pg_current_wal_flush_lsn() from pg_stat_replication
# flush_lsn reporting to publisher. Without this, subscriber may ack too far,
# losing data on restart because publisher implicitly advances positition given
# in START_REPLICATION to the confirmed_flush_lsn of the slot.
def test_subscriber_synchronous_commit(neon_simple_env: NeonEnv, vanilla_pg):
def test_subscriber_synchronous_commit(neon_simple_env: NeonEnv, vanilla_pg: VanillaPostgres):
env = neon_simple_env
# use vanilla as publisher to allow writes on it when safekeeper is down
vanilla_pg.configure(
@@ -593,7 +605,7 @@ def test_subscriber_synchronous_commit(neon_simple_env: NeonEnv, vanilla_pg):
# logical_replication_wait_flush_lsn_sync is expected to hang while
# safekeeper is down.
vanilla_pg.safe_psql("checkpoint;")
assert sub.safe_psql_scalar("SELECT count(*) FROM t") == 1000
assert cast("int", sub.safe_psql_scalar("SELECT count(*) FROM t")) == 1000
# restart subscriber and ensure it can catch up lost tail again
sub.stop(mode="immediate")