mirror of
https://github.com/neondatabase/neon.git
synced 2026-01-06 21:12:55 +00:00
Refactor connection option handling in python tests
The PgProtocol.connect() function took extra options for username, database, etc. Remove those options, and have a generic way for each subclass of PgProtocol to provide some default options, with the capability override them in the connect() call.
This commit is contained in:
@@ -28,4 +28,4 @@ def test_createuser(zenith_simple_env: ZenithEnv):
|
|||||||
pg2 = env.postgres.create_start('test_createuser2')
|
pg2 = env.postgres.create_start('test_createuser2')
|
||||||
|
|
||||||
# Test that you can connect to new branch as a new user
|
# Test that you can connect to new branch as a new user
|
||||||
assert pg2.safe_psql('select current_user', username='testuser') == [('testuser', )]
|
assert pg2.safe_psql('select current_user', user='testuser') == [('testuser', )]
|
||||||
|
|||||||
@@ -19,6 +19,11 @@ async def copy_test_data_to_table(pg: Postgres, worker_id: int, table_name: str)
|
|||||||
copy_input = repeat_bytes(buf.read(), 5000)
|
copy_input = repeat_bytes(buf.read(), 5000)
|
||||||
|
|
||||||
pg_conn = await pg.connect_async()
|
pg_conn = await pg.connect_async()
|
||||||
|
|
||||||
|
# PgProtocol.connect_async sets statement_timeout to 2 minutes.
|
||||||
|
# That's not enough for this test, on a slow system in debug mode.
|
||||||
|
await pg_conn.execute("SET statement_timeout='300s'")
|
||||||
|
|
||||||
await pg_conn.copy_to_table(table_name, source=copy_input)
|
await pg_conn.copy_to_table(table_name, source=copy_input)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -379,7 +379,7 @@ class ProposerPostgres(PgProtocol):
|
|||||||
tenant_id: uuid.UUID,
|
tenant_id: uuid.UUID,
|
||||||
listen_addr: str,
|
listen_addr: str,
|
||||||
port: int):
|
port: int):
|
||||||
super().__init__(host=listen_addr, port=port, username='zenith_admin')
|
super().__init__(host=listen_addr, port=port, user='zenith_admin', dbname='postgres')
|
||||||
|
|
||||||
self.pgdata_dir: str = pgdata_dir
|
self.pgdata_dir: str = pgdata_dir
|
||||||
self.pg_bin: PgBin = pg_bin
|
self.pg_bin: PgBin = pg_bin
|
||||||
|
|||||||
@@ -35,9 +35,9 @@ def test_isolation(zenith_simple_env: ZenithEnv, test_output_dir, pg_bin, capsys
|
|||||||
]
|
]
|
||||||
|
|
||||||
env_vars = {
|
env_vars = {
|
||||||
'PGPORT': str(pg.port),
|
'PGPORT': str(pg.default_options['port']),
|
||||||
'PGUSER': pg.username,
|
'PGUSER': pg.default_options['user'],
|
||||||
'PGHOST': pg.host,
|
'PGHOST': pg.default_options['host'],
|
||||||
}
|
}
|
||||||
|
|
||||||
# Run the command.
|
# Run the command.
|
||||||
|
|||||||
@@ -35,9 +35,9 @@ def test_pg_regress(zenith_simple_env: ZenithEnv, test_output_dir: str, pg_bin,
|
|||||||
]
|
]
|
||||||
|
|
||||||
env_vars = {
|
env_vars = {
|
||||||
'PGPORT': str(pg.port),
|
'PGPORT': str(pg.default_options['port']),
|
||||||
'PGUSER': pg.username,
|
'PGUSER': pg.default_options['user'],
|
||||||
'PGHOST': pg.host,
|
'PGHOST': pg.default_options['host'],
|
||||||
}
|
}
|
||||||
|
|
||||||
# Run the command.
|
# Run the command.
|
||||||
|
|||||||
@@ -40,9 +40,9 @@ def test_zenith_regress(zenith_simple_env: ZenithEnv, test_output_dir, pg_bin, c
|
|||||||
|
|
||||||
log.info(pg_regress_command)
|
log.info(pg_regress_command)
|
||||||
env_vars = {
|
env_vars = {
|
||||||
'PGPORT': str(pg.port),
|
'PGPORT': str(pg.default_options['port']),
|
||||||
'PGUSER': pg.username,
|
'PGUSER': pg.default_options['user'],
|
||||||
'PGHOST': pg.host,
|
'PGHOST': pg.default_options['host'],
|
||||||
}
|
}
|
||||||
|
|
||||||
# Run the command.
|
# Run the command.
|
||||||
|
|||||||
@@ -27,6 +27,7 @@ from dataclasses import dataclass
|
|||||||
|
|
||||||
# Type-related stuff
|
# Type-related stuff
|
||||||
from psycopg2.extensions import connection as PgConnection
|
from psycopg2.extensions import connection as PgConnection
|
||||||
|
from psycopg2.extensions import make_dsn, parse_dsn
|
||||||
from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, TypeVar, cast, Union, Tuple
|
from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, TypeVar, cast, Union, Tuple
|
||||||
from typing_extensions import Literal
|
from typing_extensions import Literal
|
||||||
|
|
||||||
@@ -238,98 +239,69 @@ def port_distributor(worker_base_port):
|
|||||||
|
|
||||||
class PgProtocol:
|
class PgProtocol:
|
||||||
""" Reusable connection logic """
|
""" Reusable connection logic """
|
||||||
def __init__(self,
|
def __init__(self, **kwargs):
|
||||||
host: str,
|
self.default_options = kwargs
|
||||||
port: int,
|
|
||||||
username: Optional[str] = None,
|
|
||||||
password: Optional[str] = None,
|
|
||||||
dbname: Optional[str] = None,
|
|
||||||
schema: Optional[str] = None):
|
|
||||||
self.host = host
|
|
||||||
self.port = port
|
|
||||||
self.username = username
|
|
||||||
self.password = password
|
|
||||||
self.dbname = dbname
|
|
||||||
self.schema = schema
|
|
||||||
|
|
||||||
def connstr(self,
|
def connstr(self, **kwargs) -> str:
|
||||||
*,
|
|
||||||
dbname: Optional[str] = None,
|
|
||||||
schema: Optional[str] = None,
|
|
||||||
username: Optional[str] = None,
|
|
||||||
password: Optional[str] = None,
|
|
||||||
statement_timeout_ms: Optional[int] = None) -> str:
|
|
||||||
"""
|
"""
|
||||||
Build a libpq connection string for the Postgres instance.
|
Build a libpq connection string for the Postgres instance.
|
||||||
"""
|
"""
|
||||||
|
return str(make_dsn(**self.conn_options(**kwargs)))
|
||||||
|
|
||||||
username = username or self.username
|
def conn_options(self, **kwargs):
|
||||||
password = password or self.password
|
conn_options = self.default_options.copy()
|
||||||
dbname = dbname or self.dbname or "postgres"
|
if 'dsn' in kwargs:
|
||||||
schema = schema or self.schema
|
conn_options.update(parse_dsn(kwargs['dsn']))
|
||||||
res = f'host={self.host} port={self.port} dbname={dbname}'
|
conn_options.update(kwargs)
|
||||||
|
|
||||||
if username:
|
# Individual statement timeout in seconds. 2 minutes should be
|
||||||
res = f'{res} user={username}'
|
# enough for our tests, but if you need a longer, you can
|
||||||
|
# change it by calling "SET statement_timeout" after
|
||||||
if password:
|
# connecting.
|
||||||
res = f'{res} password={password}'
|
if 'options' in conn_options:
|
||||||
|
conn_options['options'] = f"-cstatement_timeout=120s " + conn_options['options']
|
||||||
if schema:
|
else:
|
||||||
res = f"{res} options='-c search_path={schema}'"
|
conn_options['options'] = "-cstatement_timeout=120s"
|
||||||
|
return conn_options
|
||||||
if statement_timeout_ms:
|
|
||||||
res = f"{res} options='-c statement_timeout={statement_timeout_ms}'"
|
|
||||||
|
|
||||||
return res
|
|
||||||
|
|
||||||
# autocommit=True here by default because that's what we need most of the time
|
# autocommit=True here by default because that's what we need most of the time
|
||||||
def connect(
|
def connect(self, autocommit=True, **kwargs) -> PgConnection:
|
||||||
self,
|
|
||||||
*,
|
|
||||||
autocommit=True,
|
|
||||||
dbname: Optional[str] = None,
|
|
||||||
schema: Optional[str] = None,
|
|
||||||
username: Optional[str] = None,
|
|
||||||
password: Optional[str] = None,
|
|
||||||
# individual statement timeout in seconds, 2 minutes should be enough for our tests
|
|
||||||
statement_timeout: Optional[int] = 120
|
|
||||||
) -> PgConnection:
|
|
||||||
"""
|
"""
|
||||||
Connect to the node.
|
Connect to the node.
|
||||||
Returns psycopg2's connection object.
|
Returns psycopg2's connection object.
|
||||||
This method passes all extra params to connstr.
|
This method passes all extra params to connstr.
|
||||||
"""
|
"""
|
||||||
|
conn = psycopg2.connect(**self.conn_options(**kwargs))
|
||||||
|
|
||||||
conn = psycopg2.connect(
|
|
||||||
self.connstr(dbname=dbname,
|
|
||||||
schema=schema,
|
|
||||||
username=username,
|
|
||||||
password=password,
|
|
||||||
statement_timeout_ms=statement_timeout *
|
|
||||||
1000 if statement_timeout else None))
|
|
||||||
# WARNING: this setting affects *all* tests!
|
# WARNING: this setting affects *all* tests!
|
||||||
conn.autocommit = autocommit
|
conn.autocommit = autocommit
|
||||||
return conn
|
return conn
|
||||||
|
|
||||||
async def connect_async(self,
|
async def connect_async(self, **kwargs) -> asyncpg.Connection:
|
||||||
*,
|
|
||||||
dbname: str = 'postgres',
|
|
||||||
username: Optional[str] = None,
|
|
||||||
password: Optional[str] = None) -> asyncpg.Connection:
|
|
||||||
"""
|
"""
|
||||||
Connect to the node from async python.
|
Connect to the node from async python.
|
||||||
Returns asyncpg's connection object.
|
Returns asyncpg's connection object.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
conn = await asyncpg.connect(
|
# asyncpg takes slightly different options than psycopg2. Try
|
||||||
host=self.host,
|
# to convert the defaults from the psycopg2 format.
|
||||||
port=self.port,
|
|
||||||
database=dbname,
|
# The psycopg2 option 'dbname' is called 'database' is asyncpg
|
||||||
user=username or self.username,
|
conn_options = self.conn_options(**kwargs)
|
||||||
password=password,
|
if 'dbname' in conn_options:
|
||||||
)
|
conn_options['database'] = conn_options.pop('dbname')
|
||||||
return conn
|
|
||||||
|
# Convert options='-c<key>=<val>' to server_settings
|
||||||
|
if 'options' in conn_options:
|
||||||
|
options = conn_options.pop('options')
|
||||||
|
for match in re.finditer('-c(\w*)=(\w*)', options):
|
||||||
|
key = match.group(1)
|
||||||
|
val = match.group(2)
|
||||||
|
if 'server_options' in conn_options:
|
||||||
|
conn_options['server_settings'].update({key: val})
|
||||||
|
else:
|
||||||
|
conn_options['server_settings'] = {key: val}
|
||||||
|
return await asyncpg.connect(**conn_options)
|
||||||
|
|
||||||
def safe_psql(self, query: str, **kwargs: Any) -> List[Any]:
|
def safe_psql(self, query: str, **kwargs: Any) -> List[Any]:
|
||||||
"""
|
"""
|
||||||
@@ -1149,10 +1121,10 @@ class ZenithPageserver(PgProtocol):
|
|||||||
port: PageserverPort,
|
port: PageserverPort,
|
||||||
remote_storage: Optional[RemoteStorage] = None,
|
remote_storage: Optional[RemoteStorage] = None,
|
||||||
config_override: Optional[str] = None):
|
config_override: Optional[str] = None):
|
||||||
super().__init__(host='localhost', port=port.pg, username='zenith_admin')
|
super().__init__(host='localhost', port=port.pg, user='zenith_admin')
|
||||||
self.env = env
|
self.env = env
|
||||||
self.running = False
|
self.running = False
|
||||||
self.service_port = port # do not shadow PgProtocol.port which is just int
|
self.service_port = port
|
||||||
self.remote_storage = remote_storage
|
self.remote_storage = remote_storage
|
||||||
self.config_override = config_override
|
self.config_override = config_override
|
||||||
|
|
||||||
@@ -1291,7 +1263,7 @@ def pg_bin(test_output_dir: str) -> PgBin:
|
|||||||
|
|
||||||
class VanillaPostgres(PgProtocol):
|
class VanillaPostgres(PgProtocol):
|
||||||
def __init__(self, pgdatadir: str, pg_bin: PgBin, port: int):
|
def __init__(self, pgdatadir: str, pg_bin: PgBin, port: int):
|
||||||
super().__init__(host='localhost', port=port)
|
super().__init__(host='localhost', port=port, dbname='postgres')
|
||||||
self.pgdatadir = pgdatadir
|
self.pgdatadir = pgdatadir
|
||||||
self.pg_bin = pg_bin
|
self.pg_bin = pg_bin
|
||||||
self.running = False
|
self.running = False
|
||||||
@@ -1335,8 +1307,14 @@ def vanilla_pg(test_output_dir: str) -> Iterator[VanillaPostgres]:
|
|||||||
|
|
||||||
class ZenithProxy(PgProtocol):
|
class ZenithProxy(PgProtocol):
|
||||||
def __init__(self, port: int):
|
def __init__(self, port: int):
|
||||||
super().__init__(host="127.0.0.1", username="pytest", password="pytest", port=port)
|
super().__init__(host="127.0.0.1",
|
||||||
|
user="pytest",
|
||||||
|
password="pytest",
|
||||||
|
port=port,
|
||||||
|
dbname='postgres')
|
||||||
self.http_port = 7001
|
self.http_port = 7001
|
||||||
|
self.host = "127.0.0.1"
|
||||||
|
self.port = port
|
||||||
self._popen: Optional[subprocess.Popen[bytes]] = None
|
self._popen: Optional[subprocess.Popen[bytes]] = None
|
||||||
|
|
||||||
def start_static(self, addr="127.0.0.1:5432") -> None:
|
def start_static(self, addr="127.0.0.1:5432") -> None:
|
||||||
@@ -1380,13 +1358,13 @@ def static_proxy(vanilla_pg) -> Iterator[ZenithProxy]:
|
|||||||
class Postgres(PgProtocol):
|
class Postgres(PgProtocol):
|
||||||
""" An object representing a running postgres daemon. """
|
""" An object representing a running postgres daemon. """
|
||||||
def __init__(self, env: ZenithEnv, tenant_id: uuid.UUID, port: int):
|
def __init__(self, env: ZenithEnv, tenant_id: uuid.UUID, port: int):
|
||||||
super().__init__(host='localhost', port=port, username='zenith_admin')
|
super().__init__(host='localhost', port=port, user='zenith_admin', dbname='postgres')
|
||||||
|
|
||||||
self.env = env
|
self.env = env
|
||||||
self.running = False
|
self.running = False
|
||||||
self.node_name: Optional[str] = None # dubious, see asserts below
|
self.node_name: Optional[str] = None # dubious, see asserts below
|
||||||
self.pgdata_dir: Optional[str] = None # Path to computenode PGDATA
|
self.pgdata_dir: Optional[str] = None # Path to computenode PGDATA
|
||||||
self.tenant_id = tenant_id
|
self.tenant_id = tenant_id
|
||||||
|
self.port = port
|
||||||
# path to conf is <repo_dir>/pgdatadirs/tenants/<tenant_id>/<node_name>/postgresql.conf
|
# path to conf is <repo_dir>/pgdatadirs/tenants/<tenant_id>/<node_name>/postgresql.conf
|
||||||
|
|
||||||
def create(
|
def create(
|
||||||
|
|||||||
Reference in New Issue
Block a user