Refactor connection option handling in python tests

The PgProtocol.connect() function took extra options for username,
database, etc. Remove those options, and have a generic way for each
subclass of PgProtocol to provide some default options, with the
capability override them in the connect() call.
This commit is contained in:
Heikki Linnakangas
2022-04-14 13:31:40 +03:00
parent 19954dfd8a
commit a009fe912a
7 changed files with 69 additions and 86 deletions

View File

@@ -28,4 +28,4 @@ def test_createuser(zenith_simple_env: ZenithEnv):
pg2 = env.postgres.create_start('test_createuser2') pg2 = env.postgres.create_start('test_createuser2')
# Test that you can connect to new branch as a new user # Test that you can connect to new branch as a new user
assert pg2.safe_psql('select current_user', username='testuser') == [('testuser', )] assert pg2.safe_psql('select current_user', user='testuser') == [('testuser', )]

View File

@@ -19,6 +19,11 @@ async def copy_test_data_to_table(pg: Postgres, worker_id: int, table_name: str)
copy_input = repeat_bytes(buf.read(), 5000) copy_input = repeat_bytes(buf.read(), 5000)
pg_conn = await pg.connect_async() pg_conn = await pg.connect_async()
# PgProtocol.connect_async sets statement_timeout to 2 minutes.
# That's not enough for this test, on a slow system in debug mode.
await pg_conn.execute("SET statement_timeout='300s'")
await pg_conn.copy_to_table(table_name, source=copy_input) await pg_conn.copy_to_table(table_name, source=copy_input)

View File

@@ -379,7 +379,7 @@ class ProposerPostgres(PgProtocol):
tenant_id: uuid.UUID, tenant_id: uuid.UUID,
listen_addr: str, listen_addr: str,
port: int): port: int):
super().__init__(host=listen_addr, port=port, username='zenith_admin') super().__init__(host=listen_addr, port=port, user='zenith_admin', dbname='postgres')
self.pgdata_dir: str = pgdata_dir self.pgdata_dir: str = pgdata_dir
self.pg_bin: PgBin = pg_bin self.pg_bin: PgBin = pg_bin

View File

@@ -35,9 +35,9 @@ def test_isolation(zenith_simple_env: ZenithEnv, test_output_dir, pg_bin, capsys
] ]
env_vars = { env_vars = {
'PGPORT': str(pg.port), 'PGPORT': str(pg.default_options['port']),
'PGUSER': pg.username, 'PGUSER': pg.default_options['user'],
'PGHOST': pg.host, 'PGHOST': pg.default_options['host'],
} }
# Run the command. # Run the command.

View File

@@ -35,9 +35,9 @@ def test_pg_regress(zenith_simple_env: ZenithEnv, test_output_dir: str, pg_bin,
] ]
env_vars = { env_vars = {
'PGPORT': str(pg.port), 'PGPORT': str(pg.default_options['port']),
'PGUSER': pg.username, 'PGUSER': pg.default_options['user'],
'PGHOST': pg.host, 'PGHOST': pg.default_options['host'],
} }
# Run the command. # Run the command.

View File

@@ -40,9 +40,9 @@ def test_zenith_regress(zenith_simple_env: ZenithEnv, test_output_dir, pg_bin, c
log.info(pg_regress_command) log.info(pg_regress_command)
env_vars = { env_vars = {
'PGPORT': str(pg.port), 'PGPORT': str(pg.default_options['port']),
'PGUSER': pg.username, 'PGUSER': pg.default_options['user'],
'PGHOST': pg.host, 'PGHOST': pg.default_options['host'],
} }
# Run the command. # Run the command.

View File

@@ -27,6 +27,7 @@ from dataclasses import dataclass
# Type-related stuff # Type-related stuff
from psycopg2.extensions import connection as PgConnection from psycopg2.extensions import connection as PgConnection
from psycopg2.extensions import make_dsn, parse_dsn
from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, TypeVar, cast, Union, Tuple from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, TypeVar, cast, Union, Tuple
from typing_extensions import Literal from typing_extensions import Literal
@@ -238,98 +239,69 @@ def port_distributor(worker_base_port):
class PgProtocol: class PgProtocol:
""" Reusable connection logic """ """ Reusable connection logic """
def __init__(self, def __init__(self, **kwargs):
host: str, self.default_options = kwargs
port: int,
username: Optional[str] = None,
password: Optional[str] = None,
dbname: Optional[str] = None,
schema: Optional[str] = None):
self.host = host
self.port = port
self.username = username
self.password = password
self.dbname = dbname
self.schema = schema
def connstr(self, def connstr(self, **kwargs) -> str:
*,
dbname: Optional[str] = None,
schema: Optional[str] = None,
username: Optional[str] = None,
password: Optional[str] = None,
statement_timeout_ms: Optional[int] = None) -> str:
""" """
Build a libpq connection string for the Postgres instance. Build a libpq connection string for the Postgres instance.
""" """
return str(make_dsn(**self.conn_options(**kwargs)))
username = username or self.username def conn_options(self, **kwargs):
password = password or self.password conn_options = self.default_options.copy()
dbname = dbname or self.dbname or "postgres" if 'dsn' in kwargs:
schema = schema or self.schema conn_options.update(parse_dsn(kwargs['dsn']))
res = f'host={self.host} port={self.port} dbname={dbname}' conn_options.update(kwargs)
if username: # Individual statement timeout in seconds. 2 minutes should be
res = f'{res} user={username}' # enough for our tests, but if you need a longer, you can
# change it by calling "SET statement_timeout" after
if password: # connecting.
res = f'{res} password={password}' if 'options' in conn_options:
conn_options['options'] = f"-cstatement_timeout=120s " + conn_options['options']
if schema: else:
res = f"{res} options='-c search_path={schema}'" conn_options['options'] = "-cstatement_timeout=120s"
return conn_options
if statement_timeout_ms:
res = f"{res} options='-c statement_timeout={statement_timeout_ms}'"
return res
# autocommit=True here by default because that's what we need most of the time # autocommit=True here by default because that's what we need most of the time
def connect( def connect(self, autocommit=True, **kwargs) -> PgConnection:
self,
*,
autocommit=True,
dbname: Optional[str] = None,
schema: Optional[str] = None,
username: Optional[str] = None,
password: Optional[str] = None,
# individual statement timeout in seconds, 2 minutes should be enough for our tests
statement_timeout: Optional[int] = 120
) -> PgConnection:
""" """
Connect to the node. Connect to the node.
Returns psycopg2's connection object. Returns psycopg2's connection object.
This method passes all extra params to connstr. This method passes all extra params to connstr.
""" """
conn = psycopg2.connect(**self.conn_options(**kwargs))
conn = psycopg2.connect(
self.connstr(dbname=dbname,
schema=schema,
username=username,
password=password,
statement_timeout_ms=statement_timeout *
1000 if statement_timeout else None))
# WARNING: this setting affects *all* tests! # WARNING: this setting affects *all* tests!
conn.autocommit = autocommit conn.autocommit = autocommit
return conn return conn
async def connect_async(self, async def connect_async(self, **kwargs) -> asyncpg.Connection:
*,
dbname: str = 'postgres',
username: Optional[str] = None,
password: Optional[str] = None) -> asyncpg.Connection:
""" """
Connect to the node from async python. Connect to the node from async python.
Returns asyncpg's connection object. Returns asyncpg's connection object.
""" """
conn = await asyncpg.connect( # asyncpg takes slightly different options than psycopg2. Try
host=self.host, # to convert the defaults from the psycopg2 format.
port=self.port,
database=dbname, # The psycopg2 option 'dbname' is called 'database' is asyncpg
user=username or self.username, conn_options = self.conn_options(**kwargs)
password=password, if 'dbname' in conn_options:
) conn_options['database'] = conn_options.pop('dbname')
return conn
# Convert options='-c<key>=<val>' to server_settings
if 'options' in conn_options:
options = conn_options.pop('options')
for match in re.finditer('-c(\w*)=(\w*)', options):
key = match.group(1)
val = match.group(2)
if 'server_options' in conn_options:
conn_options['server_settings'].update({key: val})
else:
conn_options['server_settings'] = {key: val}
return await asyncpg.connect(**conn_options)
def safe_psql(self, query: str, **kwargs: Any) -> List[Any]: def safe_psql(self, query: str, **kwargs: Any) -> List[Any]:
""" """
@@ -1149,10 +1121,10 @@ class ZenithPageserver(PgProtocol):
port: PageserverPort, port: PageserverPort,
remote_storage: Optional[RemoteStorage] = None, remote_storage: Optional[RemoteStorage] = None,
config_override: Optional[str] = None): config_override: Optional[str] = None):
super().__init__(host='localhost', port=port.pg, username='zenith_admin') super().__init__(host='localhost', port=port.pg, user='zenith_admin')
self.env = env self.env = env
self.running = False self.running = False
self.service_port = port # do not shadow PgProtocol.port which is just int self.service_port = port
self.remote_storage = remote_storage self.remote_storage = remote_storage
self.config_override = config_override self.config_override = config_override
@@ -1291,7 +1263,7 @@ def pg_bin(test_output_dir: str) -> PgBin:
class VanillaPostgres(PgProtocol): class VanillaPostgres(PgProtocol):
def __init__(self, pgdatadir: str, pg_bin: PgBin, port: int): def __init__(self, pgdatadir: str, pg_bin: PgBin, port: int):
super().__init__(host='localhost', port=port) super().__init__(host='localhost', port=port, dbname='postgres')
self.pgdatadir = pgdatadir self.pgdatadir = pgdatadir
self.pg_bin = pg_bin self.pg_bin = pg_bin
self.running = False self.running = False
@@ -1335,8 +1307,14 @@ def vanilla_pg(test_output_dir: str) -> Iterator[VanillaPostgres]:
class ZenithProxy(PgProtocol): class ZenithProxy(PgProtocol):
def __init__(self, port: int): def __init__(self, port: int):
super().__init__(host="127.0.0.1", username="pytest", password="pytest", port=port) super().__init__(host="127.0.0.1",
user="pytest",
password="pytest",
port=port,
dbname='postgres')
self.http_port = 7001 self.http_port = 7001
self.host = "127.0.0.1"
self.port = port
self._popen: Optional[subprocess.Popen[bytes]] = None self._popen: Optional[subprocess.Popen[bytes]] = None
def start_static(self, addr="127.0.0.1:5432") -> None: def start_static(self, addr="127.0.0.1:5432") -> None:
@@ -1380,13 +1358,13 @@ def static_proxy(vanilla_pg) -> Iterator[ZenithProxy]:
class Postgres(PgProtocol): class Postgres(PgProtocol):
""" An object representing a running postgres daemon. """ """ An object representing a running postgres daemon. """
def __init__(self, env: ZenithEnv, tenant_id: uuid.UUID, port: int): def __init__(self, env: ZenithEnv, tenant_id: uuid.UUID, port: int):
super().__init__(host='localhost', port=port, username='zenith_admin') super().__init__(host='localhost', port=port, user='zenith_admin', dbname='postgres')
self.env = env self.env = env
self.running = False self.running = False
self.node_name: Optional[str] = None # dubious, see asserts below self.node_name: Optional[str] = None # dubious, see asserts below
self.pgdata_dir: Optional[str] = None # Path to computenode PGDATA self.pgdata_dir: Optional[str] = None # Path to computenode PGDATA
self.tenant_id = tenant_id self.tenant_id = tenant_id
self.port = port
# path to conf is <repo_dir>/pgdatadirs/tenants/<tenant_id>/<node_name>/postgresql.conf # path to conf is <repo_dir>/pgdatadirs/tenants/<tenant_id>/<node_name>/postgresql.conf
def create( def create(