Refactor connection option handling in python tests

The PgProtocol.connect() function took extra options for username,
database, etc. Remove those options, and have a generic way for each
subclass of PgProtocol to provide some default options, with the
capability override them in the connect() call.
This commit is contained in:
Heikki Linnakangas
2022-04-14 13:31:40 +03:00
parent 19954dfd8a
commit a009fe912a
7 changed files with 69 additions and 86 deletions

View File

@@ -28,4 +28,4 @@ def test_createuser(zenith_simple_env: ZenithEnv):
pg2 = env.postgres.create_start('test_createuser2')
# Test that you can connect to new branch as a new user
assert pg2.safe_psql('select current_user', username='testuser') == [('testuser', )]
assert pg2.safe_psql('select current_user', user='testuser') == [('testuser', )]

View File

@@ -19,6 +19,11 @@ async def copy_test_data_to_table(pg: Postgres, worker_id: int, table_name: str)
copy_input = repeat_bytes(buf.read(), 5000)
pg_conn = await pg.connect_async()
# PgProtocol.connect_async sets statement_timeout to 2 minutes.
# That's not enough for this test, on a slow system in debug mode.
await pg_conn.execute("SET statement_timeout='300s'")
await pg_conn.copy_to_table(table_name, source=copy_input)

View File

@@ -379,7 +379,7 @@ class ProposerPostgres(PgProtocol):
tenant_id: uuid.UUID,
listen_addr: str,
port: int):
super().__init__(host=listen_addr, port=port, username='zenith_admin')
super().__init__(host=listen_addr, port=port, user='zenith_admin', dbname='postgres')
self.pgdata_dir: str = pgdata_dir
self.pg_bin: PgBin = pg_bin

View File

@@ -35,9 +35,9 @@ def test_isolation(zenith_simple_env: ZenithEnv, test_output_dir, pg_bin, capsys
]
env_vars = {
'PGPORT': str(pg.port),
'PGUSER': pg.username,
'PGHOST': pg.host,
'PGPORT': str(pg.default_options['port']),
'PGUSER': pg.default_options['user'],
'PGHOST': pg.default_options['host'],
}
# Run the command.

View File

@@ -35,9 +35,9 @@ def test_pg_regress(zenith_simple_env: ZenithEnv, test_output_dir: str, pg_bin,
]
env_vars = {
'PGPORT': str(pg.port),
'PGUSER': pg.username,
'PGHOST': pg.host,
'PGPORT': str(pg.default_options['port']),
'PGUSER': pg.default_options['user'],
'PGHOST': pg.default_options['host'],
}
# Run the command.

View File

@@ -40,9 +40,9 @@ def test_zenith_regress(zenith_simple_env: ZenithEnv, test_output_dir, pg_bin, c
log.info(pg_regress_command)
env_vars = {
'PGPORT': str(pg.port),
'PGUSER': pg.username,
'PGHOST': pg.host,
'PGPORT': str(pg.default_options['port']),
'PGUSER': pg.default_options['user'],
'PGHOST': pg.default_options['host'],
}
# Run the command.

View File

@@ -27,6 +27,7 @@ from dataclasses import dataclass
# Type-related stuff
from psycopg2.extensions import connection as PgConnection
from psycopg2.extensions import make_dsn, parse_dsn
from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, TypeVar, cast, Union, Tuple
from typing_extensions import Literal
@@ -238,98 +239,69 @@ def port_distributor(worker_base_port):
class PgProtocol:
""" Reusable connection logic """
def __init__(self,
host: str,
port: int,
username: Optional[str] = None,
password: Optional[str] = None,
dbname: Optional[str] = None,
schema: Optional[str] = None):
self.host = host
self.port = port
self.username = username
self.password = password
self.dbname = dbname
self.schema = schema
def __init__(self, **kwargs):
self.default_options = kwargs
def connstr(self,
*,
dbname: Optional[str] = None,
schema: Optional[str] = None,
username: Optional[str] = None,
password: Optional[str] = None,
statement_timeout_ms: Optional[int] = None) -> str:
def connstr(self, **kwargs) -> str:
"""
Build a libpq connection string for the Postgres instance.
"""
return str(make_dsn(**self.conn_options(**kwargs)))
username = username or self.username
password = password or self.password
dbname = dbname or self.dbname or "postgres"
schema = schema or self.schema
res = f'host={self.host} port={self.port} dbname={dbname}'
def conn_options(self, **kwargs):
conn_options = self.default_options.copy()
if 'dsn' in kwargs:
conn_options.update(parse_dsn(kwargs['dsn']))
conn_options.update(kwargs)
if username:
res = f'{res} user={username}'
if password:
res = f'{res} password={password}'
if schema:
res = f"{res} options='-c search_path={schema}'"
if statement_timeout_ms:
res = f"{res} options='-c statement_timeout={statement_timeout_ms}'"
return res
# Individual statement timeout in seconds. 2 minutes should be
# enough for our tests, but if you need a longer, you can
# change it by calling "SET statement_timeout" after
# connecting.
if 'options' in conn_options:
conn_options['options'] = f"-cstatement_timeout=120s " + conn_options['options']
else:
conn_options['options'] = "-cstatement_timeout=120s"
return conn_options
# autocommit=True here by default because that's what we need most of the time
def connect(
self,
*,
autocommit=True,
dbname: Optional[str] = None,
schema: Optional[str] = None,
username: Optional[str] = None,
password: Optional[str] = None,
# individual statement timeout in seconds, 2 minutes should be enough for our tests
statement_timeout: Optional[int] = 120
) -> PgConnection:
def connect(self, autocommit=True, **kwargs) -> PgConnection:
"""
Connect to the node.
Returns psycopg2's connection object.
This method passes all extra params to connstr.
"""
conn = psycopg2.connect(**self.conn_options(**kwargs))
conn = psycopg2.connect(
self.connstr(dbname=dbname,
schema=schema,
username=username,
password=password,
statement_timeout_ms=statement_timeout *
1000 if statement_timeout else None))
# WARNING: this setting affects *all* tests!
conn.autocommit = autocommit
return conn
async def connect_async(self,
*,
dbname: str = 'postgres',
username: Optional[str] = None,
password: Optional[str] = None) -> asyncpg.Connection:
async def connect_async(self, **kwargs) -> asyncpg.Connection:
"""
Connect to the node from async python.
Returns asyncpg's connection object.
"""
conn = await asyncpg.connect(
host=self.host,
port=self.port,
database=dbname,
user=username or self.username,
password=password,
)
return conn
# asyncpg takes slightly different options than psycopg2. Try
# to convert the defaults from the psycopg2 format.
# The psycopg2 option 'dbname' is called 'database' is asyncpg
conn_options = self.conn_options(**kwargs)
if 'dbname' in conn_options:
conn_options['database'] = conn_options.pop('dbname')
# Convert options='-c<key>=<val>' to server_settings
if 'options' in conn_options:
options = conn_options.pop('options')
for match in re.finditer('-c(\w*)=(\w*)', options):
key = match.group(1)
val = match.group(2)
if 'server_options' in conn_options:
conn_options['server_settings'].update({key: val})
else:
conn_options['server_settings'] = {key: val}
return await asyncpg.connect(**conn_options)
def safe_psql(self, query: str, **kwargs: Any) -> List[Any]:
"""
@@ -1149,10 +1121,10 @@ class ZenithPageserver(PgProtocol):
port: PageserverPort,
remote_storage: Optional[RemoteStorage] = None,
config_override: Optional[str] = None):
super().__init__(host='localhost', port=port.pg, username='zenith_admin')
super().__init__(host='localhost', port=port.pg, user='zenith_admin')
self.env = env
self.running = False
self.service_port = port # do not shadow PgProtocol.port which is just int
self.service_port = port
self.remote_storage = remote_storage
self.config_override = config_override
@@ -1291,7 +1263,7 @@ def pg_bin(test_output_dir: str) -> PgBin:
class VanillaPostgres(PgProtocol):
def __init__(self, pgdatadir: str, pg_bin: PgBin, port: int):
super().__init__(host='localhost', port=port)
super().__init__(host='localhost', port=port, dbname='postgres')
self.pgdatadir = pgdatadir
self.pg_bin = pg_bin
self.running = False
@@ -1335,8 +1307,14 @@ def vanilla_pg(test_output_dir: str) -> Iterator[VanillaPostgres]:
class ZenithProxy(PgProtocol):
def __init__(self, port: int):
super().__init__(host="127.0.0.1", username="pytest", password="pytest", port=port)
super().__init__(host="127.0.0.1",
user="pytest",
password="pytest",
port=port,
dbname='postgres')
self.http_port = 7001
self.host = "127.0.0.1"
self.port = port
self._popen: Optional[subprocess.Popen[bytes]] = None
def start_static(self, addr="127.0.0.1:5432") -> None:
@@ -1380,13 +1358,13 @@ def static_proxy(vanilla_pg) -> Iterator[ZenithProxy]:
class Postgres(PgProtocol):
""" An object representing a running postgres daemon. """
def __init__(self, env: ZenithEnv, tenant_id: uuid.UUID, port: int):
super().__init__(host='localhost', port=port, username='zenith_admin')
super().__init__(host='localhost', port=port, user='zenith_admin', dbname='postgres')
self.env = env
self.running = False
self.node_name: Optional[str] = None # dubious, see asserts below
self.pgdata_dir: Optional[str] = None # Path to computenode PGDATA
self.tenant_id = tenant_id
self.port = port
# path to conf is <repo_dir>/pgdatadirs/tenants/<tenant_id>/<node_name>/postgresql.conf
def create(