diff --git a/test_runner/batch_others/test_createuser.py b/test_runner/batch_others/test_createuser.py index efb2af3f07..f4bbbc8a7a 100644 --- a/test_runner/batch_others/test_createuser.py +++ b/test_runner/batch_others/test_createuser.py @@ -28,4 +28,4 @@ def test_createuser(zenith_simple_env: ZenithEnv): pg2 = env.postgres.create_start('test_createuser2') # Test that you can connect to new branch as a new user - assert pg2.safe_psql('select current_user', username='testuser') == [('testuser', )] + assert pg2.safe_psql('select current_user', user='testuser') == [('testuser', )] diff --git a/test_runner/batch_others/test_parallel_copy.py b/test_runner/batch_others/test_parallel_copy.py index 4b7cc58d42..a44acecf21 100644 --- a/test_runner/batch_others/test_parallel_copy.py +++ b/test_runner/batch_others/test_parallel_copy.py @@ -19,6 +19,11 @@ async def copy_test_data_to_table(pg: Postgres, worker_id: int, table_name: str) copy_input = repeat_bytes(buf.read(), 5000) pg_conn = await pg.connect_async() + + # PgProtocol.connect_async sets statement_timeout to 2 minutes. + # That's not enough for this test, on a slow system in debug mode. + await pg_conn.execute("SET statement_timeout='300s'") + await pg_conn.copy_to_table(table_name, source=copy_input) diff --git a/test_runner/batch_others/test_wal_acceptor.py b/test_runner/batch_others/test_wal_acceptor.py index 8f87ff041f..dffcd7cc61 100644 --- a/test_runner/batch_others/test_wal_acceptor.py +++ b/test_runner/batch_others/test_wal_acceptor.py @@ -379,7 +379,7 @@ class ProposerPostgres(PgProtocol): tenant_id: uuid.UUID, listen_addr: str, port: int): - super().__init__(host=listen_addr, port=port, username='zenith_admin') + super().__init__(host=listen_addr, port=port, user='zenith_admin', dbname='postgres') self.pgdata_dir: str = pgdata_dir self.pg_bin: PgBin = pg_bin diff --git a/test_runner/batch_pg_regress/test_isolation.py b/test_runner/batch_pg_regress/test_isolation.py index ddafc3815b..cde56d9b88 100644 --- a/test_runner/batch_pg_regress/test_isolation.py +++ b/test_runner/batch_pg_regress/test_isolation.py @@ -35,9 +35,9 @@ def test_isolation(zenith_simple_env: ZenithEnv, test_output_dir, pg_bin, capsys ] env_vars = { - 'PGPORT': str(pg.port), - 'PGUSER': pg.username, - 'PGHOST': pg.host, + 'PGPORT': str(pg.default_options['port']), + 'PGUSER': pg.default_options['user'], + 'PGHOST': pg.default_options['host'], } # Run the command. diff --git a/test_runner/batch_pg_regress/test_pg_regress.py b/test_runner/batch_pg_regress/test_pg_regress.py index 5199f65216..07d2574f4a 100644 --- a/test_runner/batch_pg_regress/test_pg_regress.py +++ b/test_runner/batch_pg_regress/test_pg_regress.py @@ -35,9 +35,9 @@ def test_pg_regress(zenith_simple_env: ZenithEnv, test_output_dir: str, pg_bin, ] env_vars = { - 'PGPORT': str(pg.port), - 'PGUSER': pg.username, - 'PGHOST': pg.host, + 'PGPORT': str(pg.default_options['port']), + 'PGUSER': pg.default_options['user'], + 'PGHOST': pg.default_options['host'], } # Run the command. diff --git a/test_runner/batch_pg_regress/test_zenith_regress.py b/test_runner/batch_pg_regress/test_zenith_regress.py index 31d5b07093..2b57137d16 100644 --- a/test_runner/batch_pg_regress/test_zenith_regress.py +++ b/test_runner/batch_pg_regress/test_zenith_regress.py @@ -40,9 +40,9 @@ def test_zenith_regress(zenith_simple_env: ZenithEnv, test_output_dir, pg_bin, c log.info(pg_regress_command) env_vars = { - 'PGPORT': str(pg.port), - 'PGUSER': pg.username, - 'PGHOST': pg.host, + 'PGPORT': str(pg.default_options['port']), + 'PGUSER': pg.default_options['user'], + 'PGHOST': pg.default_options['host'], } # Run the command. diff --git a/test_runner/fixtures/zenith_fixtures.py b/test_runner/fixtures/zenith_fixtures.py index a95809687a..41d1443880 100644 --- a/test_runner/fixtures/zenith_fixtures.py +++ b/test_runner/fixtures/zenith_fixtures.py @@ -27,6 +27,7 @@ from dataclasses import dataclass # Type-related stuff from psycopg2.extensions import connection as PgConnection +from psycopg2.extensions import make_dsn, parse_dsn from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, TypeVar, cast, Union, Tuple from typing_extensions import Literal @@ -238,98 +239,69 @@ def port_distributor(worker_base_port): class PgProtocol: """ Reusable connection logic """ - def __init__(self, - host: str, - port: int, - username: Optional[str] = None, - password: Optional[str] = None, - dbname: Optional[str] = None, - schema: Optional[str] = None): - self.host = host - self.port = port - self.username = username - self.password = password - self.dbname = dbname - self.schema = schema + def __init__(self, **kwargs): + self.default_options = kwargs - def connstr(self, - *, - dbname: Optional[str] = None, - schema: Optional[str] = None, - username: Optional[str] = None, - password: Optional[str] = None, - statement_timeout_ms: Optional[int] = None) -> str: + def connstr(self, **kwargs) -> str: """ Build a libpq connection string for the Postgres instance. """ + return str(make_dsn(**self.conn_options(**kwargs))) - username = username or self.username - password = password or self.password - dbname = dbname or self.dbname or "postgres" - schema = schema or self.schema - res = f'host={self.host} port={self.port} dbname={dbname}' + def conn_options(self, **kwargs): + conn_options = self.default_options.copy() + if 'dsn' in kwargs: + conn_options.update(parse_dsn(kwargs['dsn'])) + conn_options.update(kwargs) - if username: - res = f'{res} user={username}' - - if password: - res = f'{res} password={password}' - - if schema: - res = f"{res} options='-c search_path={schema}'" - - if statement_timeout_ms: - res = f"{res} options='-c statement_timeout={statement_timeout_ms}'" - - return res + # Individual statement timeout in seconds. 2 minutes should be + # enough for our tests, but if you need a longer, you can + # change it by calling "SET statement_timeout" after + # connecting. + if 'options' in conn_options: + conn_options['options'] = f"-cstatement_timeout=120s " + conn_options['options'] + else: + conn_options['options'] = "-cstatement_timeout=120s" + return conn_options # autocommit=True here by default because that's what we need most of the time - def connect( - self, - *, - autocommit=True, - dbname: Optional[str] = None, - schema: Optional[str] = None, - username: Optional[str] = None, - password: Optional[str] = None, - # individual statement timeout in seconds, 2 minutes should be enough for our tests - statement_timeout: Optional[int] = 120 - ) -> PgConnection: + def connect(self, autocommit=True, **kwargs) -> PgConnection: """ Connect to the node. Returns psycopg2's connection object. This method passes all extra params to connstr. """ + conn = psycopg2.connect(**self.conn_options(**kwargs)) - conn = psycopg2.connect( - self.connstr(dbname=dbname, - schema=schema, - username=username, - password=password, - statement_timeout_ms=statement_timeout * - 1000 if statement_timeout else None)) # WARNING: this setting affects *all* tests! conn.autocommit = autocommit return conn - async def connect_async(self, - *, - dbname: str = 'postgres', - username: Optional[str] = None, - password: Optional[str] = None) -> asyncpg.Connection: + async def connect_async(self, **kwargs) -> asyncpg.Connection: """ Connect to the node from async python. Returns asyncpg's connection object. """ - conn = await asyncpg.connect( - host=self.host, - port=self.port, - database=dbname, - user=username or self.username, - password=password, - ) - return conn + # asyncpg takes slightly different options than psycopg2. Try + # to convert the defaults from the psycopg2 format. + + # The psycopg2 option 'dbname' is called 'database' is asyncpg + conn_options = self.conn_options(**kwargs) + if 'dbname' in conn_options: + conn_options['database'] = conn_options.pop('dbname') + + # Convert options='-c=' to server_settings + if 'options' in conn_options: + options = conn_options.pop('options') + for match in re.finditer('-c(\w*)=(\w*)', options): + key = match.group(1) + val = match.group(2) + if 'server_options' in conn_options: + conn_options['server_settings'].update({key: val}) + else: + conn_options['server_settings'] = {key: val} + return await asyncpg.connect(**conn_options) def safe_psql(self, query: str, **kwargs: Any) -> List[Any]: """ @@ -1149,10 +1121,10 @@ class ZenithPageserver(PgProtocol): port: PageserverPort, remote_storage: Optional[RemoteStorage] = None, config_override: Optional[str] = None): - super().__init__(host='localhost', port=port.pg, username='zenith_admin') + super().__init__(host='localhost', port=port.pg, user='zenith_admin') self.env = env self.running = False - self.service_port = port # do not shadow PgProtocol.port which is just int + self.service_port = port self.remote_storage = remote_storage self.config_override = config_override @@ -1291,7 +1263,7 @@ def pg_bin(test_output_dir: str) -> PgBin: class VanillaPostgres(PgProtocol): def __init__(self, pgdatadir: str, pg_bin: PgBin, port: int): - super().__init__(host='localhost', port=port) + super().__init__(host='localhost', port=port, dbname='postgres') self.pgdatadir = pgdatadir self.pg_bin = pg_bin self.running = False @@ -1335,8 +1307,14 @@ def vanilla_pg(test_output_dir: str) -> Iterator[VanillaPostgres]: class ZenithProxy(PgProtocol): def __init__(self, port: int): - super().__init__(host="127.0.0.1", username="pytest", password="pytest", port=port) + super().__init__(host="127.0.0.1", + user="pytest", + password="pytest", + port=port, + dbname='postgres') self.http_port = 7001 + self.host = "127.0.0.1" + self.port = port self._popen: Optional[subprocess.Popen[bytes]] = None def start_static(self, addr="127.0.0.1:5432") -> None: @@ -1380,13 +1358,13 @@ def static_proxy(vanilla_pg) -> Iterator[ZenithProxy]: class Postgres(PgProtocol): """ An object representing a running postgres daemon. """ def __init__(self, env: ZenithEnv, tenant_id: uuid.UUID, port: int): - super().__init__(host='localhost', port=port, username='zenith_admin') - + super().__init__(host='localhost', port=port, user='zenith_admin', dbname='postgres') self.env = env self.running = False self.node_name: Optional[str] = None # dubious, see asserts below self.pgdata_dir: Optional[str] = None # Path to computenode PGDATA self.tenant_id = tenant_id + self.port = port # path to conf is /pgdatadirs/tenants///postgresql.conf def create(