Add remote extension test with library component (#11301)

The current test was just SQL files only, but we also want to test a
remote extension which includes a loadable library. With both extensions
we should cover a larger portion of compute_ctl's remote extension code
paths.

Fixes: https://github.com/neondatabase/neon/issues/11146

Signed-off-by: Tristan Partin <tristan@neon.tech>
This commit is contained in:
Tristan Partin
2025-04-24 17:33:46 -05:00
committed by GitHub
parent 5ba7315c84
commit 2526f6aea1
12 changed files with 436 additions and 73 deletions

View File

@@ -16,4 +16,5 @@ pytest_plugins = (
"fixtures.slow",
"fixtures.reruns",
"fixtures.fast_import",
"fixtures.pg_config",
)

View File

@@ -0,0 +1,249 @@
from __future__ import annotations
import shlex
from enum import StrEnum
from pathlib import Path
from typing import TYPE_CHECKING, cast, final
import pytest
if TYPE_CHECKING:
from collections.abc import Iterator
from typing import IO
from fixtures.neon_fixtures import PgBin
@final
class PgConfigKey(StrEnum):
BINDIR = "BINDIR"
DOCDIR = "DOCDIR"
HTMLDIR = "HTMLDIR"
INCLUDEDIR = "INCLUDEDIR"
PKGINCLUDEDIR = "PKGINCLUDEDIR"
INCLUDEDIR_SERVER = "INCLUDEDIR-SERVER"
LIBDIR = "LIBDIR"
PKGLIBDIR = "PKGLIBDIR"
LOCALEDIR = "LOCALEDIR"
MANDIR = "MANDIR"
SHAREDIR = "SHAREDIR"
SYSCONFDIR = "SYSCONFDIR"
PGXS = "PGXS"
CONFIGURE = "CONFIGURE"
CC = "CC"
CPPFLAGS = "CPPFLAGS"
CFLAGS = "CFLAGS"
CFLAGS_SL = "CFLAGS_SL"
LDFLAGS = "LDFLAGS"
LDFLAGS_EX = "LDFLAGS_EX"
LDFLAGS_SL = "LDFLAGS_SL"
LIBS = "LIBS"
VERSION = "VERSION"
if TYPE_CHECKING:
# TODO: This could become a TypedDict if Python ever allows StrEnums to be
# keys.
PgConfig = dict[PgConfigKey, str | Path | list[str]]
def __get_pg_config(pg_bin: PgBin) -> PgConfig:
"""Get pg_config values by invoking the command"""
cmd = pg_bin.run_nonblocking(["pg_config"])
cmd.wait()
if cmd.returncode != 0:
pytest.exit("")
assert cmd.stdout
stdout = cast("IO[str]", cmd.stdout)
# Parse the output into a dictionary
values: PgConfig = {}
for line in stdout.readlines():
if "=" in line:
key, value = line.split("=", 1)
value = value.strip()
match PgConfigKey(key.strip()):
case (
(
PgConfigKey.CC
| PgConfigKey.CPPFLAGS
| PgConfigKey.CFLAGS
| PgConfigKey.CFLAGS_SL
| PgConfigKey.LDFLAGS
| PgConfigKey.LDFLAGS_EX
| PgConfigKey.LDFLAGS_SL
| PgConfigKey.LIBS
) as k
):
values[k] = shlex.split(value)
case (
(
PgConfigKey.BINDIR
| PgConfigKey.DOCDIR
| PgConfigKey.HTMLDIR
| PgConfigKey.INCLUDEDIR
| PgConfigKey.PKGINCLUDEDIR
| PgConfigKey.INCLUDEDIR_SERVER
| PgConfigKey.LIBDIR
| PgConfigKey.PKGLIBDIR
| PgConfigKey.LOCALEDIR
| PgConfigKey.MANDIR
| PgConfigKey.SHAREDIR
| PgConfigKey.SYSCONFDIR
| PgConfigKey.PGXS
) as k
):
values[k] = Path(value)
case _ as k:
values[k] = value
return values
@pytest.fixture(scope="function")
def pg_config(pg_bin: PgBin) -> Iterator[PgConfig]:
"""Dictionary of all pg_config values from the system"""
yield __get_pg_config(pg_bin)
@pytest.fixture(scope="function")
def pg_config_bindir(pg_config: PgConfig) -> Iterator[Path]:
"""BINDIR value from pg_config"""
yield cast("Path", pg_config[PgConfigKey.BINDIR])
@pytest.fixture(scope="function")
def pg_config_docdir(pg_config: PgConfig) -> Iterator[Path]:
"""DOCDIR value from pg_config"""
yield cast("Path", pg_config[PgConfigKey.DOCDIR])
@pytest.fixture(scope="function")
def pg_config_htmldir(pg_config: PgConfig) -> Iterator[Path]:
"""HTMLDIR value from pg_config"""
yield cast("Path", pg_config[PgConfigKey.HTMLDIR])
@pytest.fixture(scope="function")
def pg_config_includedir(
pg_config: dict[PgConfigKey, str | Path | list[str]],
) -> Iterator[Path]:
"""INCLUDEDIR value from pg_config"""
yield cast("Path", pg_config[PgConfigKey.INCLUDEDIR])
@pytest.fixture(scope="function")
def pg_config_pkgincludedir(pg_config: PgConfig) -> Iterator[Path]:
"""PKGINCLUDEDIR value from pg_config"""
yield cast("Path", pg_config[PgConfigKey.PKGINCLUDEDIR])
@pytest.fixture(scope="function")
def pg_config_includedir_server(pg_config: PgConfig) -> Iterator[Path]:
"""INCLUDEDIR-SERVER value from pg_config"""
yield cast("Path", pg_config[PgConfigKey.INCLUDEDIR_SERVER])
@pytest.fixture(scope="function")
def pg_config_libdir(pg_config: PgConfig) -> Iterator[Path]:
"""LIBDIR value from pg_config"""
yield cast("Path", pg_config[PgConfigKey.LIBDIR])
@pytest.fixture(scope="function")
def pg_config_pkglibdir(pg_config: PgConfig) -> Iterator[Path]:
"""PKGLIBDIR value from pg_config"""
yield cast("Path", pg_config[PgConfigKey.PKGLIBDIR])
@pytest.fixture(scope="function")
def pg_config_localedir(pg_config: PgConfig) -> Iterator[Path]:
"""LOCALEDIR value from pg_config"""
yield cast("Path", pg_config[PgConfigKey.LOCALEDIR])
@pytest.fixture(scope="function")
def pg_config_mandir(pg_config: PgConfig) -> Iterator[Path]:
"""MANDIR value from pg_config"""
yield cast("Path", pg_config[PgConfigKey.MANDIR])
@pytest.fixture(scope="function")
def pg_config_sharedir(pg_config: PgConfig) -> Iterator[Path]:
"""SHAREDIR value from pg_config"""
yield cast("Path", pg_config[PgConfigKey.SHAREDIR])
@pytest.fixture(scope="function")
def pg_config_sysconfdir(pg_config: PgConfig) -> Iterator[Path]:
"""SYSCONFDIR value from pg_config"""
yield cast("Path", pg_config[PgConfigKey.SYSCONFDIR])
@pytest.fixture(scope="function")
def pg_config_pgxs(pg_config: PgConfig) -> Iterator[Path]:
"""PGXS value from pg_config"""
yield cast("Path", pg_config[PgConfigKey.PGXS])
@pytest.fixture(scope="function")
def pg_config_configure(pg_config: PgConfig) -> Iterator[str]:
"""CONFIGURE value from pg_config"""
yield cast("str", pg_config[PgConfigKey.CONFIGURE])
@pytest.fixture(scope="function")
def pg_config_cc(pg_config: PgConfig) -> Iterator[list[str]]:
"""CC value from pg_config"""
yield cast("list[str]", pg_config[PgConfigKey.CC])
@pytest.fixture(scope="function")
def pg_config_cppflags(pg_config: PgConfig) -> Iterator[list[str]]:
"""CPPFLAGS value from pg_config"""
yield cast("list[str]", pg_config[PgConfigKey.CPPFLAGS])
@pytest.fixture(scope="function")
def pg_config_cflags(pg_config: PgConfig) -> Iterator[list[str]]:
"""CFLAGS value from pg_config"""
yield cast("list[str]", pg_config[PgConfigKey.CFLAGS])
@pytest.fixture(scope="function")
def pg_config_cflags_sl(pg_config: PgConfig) -> Iterator[list[str]]:
"""CFLAGS_SL value from pg_config"""
yield cast("list[str]", pg_config[PgConfigKey.CFLAGS_SL])
@pytest.fixture(scope="function")
def pg_config_ldflags(pg_config: PgConfig) -> Iterator[list[str]]:
"""LDFLAGS value from pg_config"""
yield cast("list[str]", pg_config[PgConfigKey.LDFLAGS])
@pytest.fixture(scope="function")
def pg_config_ldflags_ex(pg_config: PgConfig) -> Iterator[list[str]]:
"""LDFLAGS_EX value from pg_config"""
yield cast("list[str]", pg_config[PgConfigKey.LDFLAGS_EX])
@pytest.fixture(scope="function")
def pg_config_ldflags_sl(pg_config: PgConfig) -> Iterator[list[str]]:
"""LDFLAGS_SL value from pg_config"""
yield cast("list[str]", pg_config[PgConfigKey.LDFLAGS_SL])
@pytest.fixture(scope="function")
def pg_config_libs(pg_config: PgConfig) -> Iterator[list[str]]:
"""LIBS value from pg_config"""
yield cast("list[str]", pg_config[PgConfigKey.LIBS])
@pytest.fixture(scope="function")
def pg_config_version(pg_config: PgConfig) -> Iterator[str]:
"""VERSION value from pg_config"""
yield cast("str", pg_config[PgConfigKey.VERSION])

View File

@@ -1,12 +0,0 @@
\echo Use "CREATE EXTENSION test_extension" to load this file. \quit
CREATE SCHEMA test_extension;
CREATE FUNCTION test_extension.motd()
RETURNS void
IMMUTABLE LEAKPROOF PARALLEL SAFE
AS $$
BEGIN
RAISE NOTICE 'Have a great day';
END;
$$ LANGUAGE 'plpgsql';

View File

@@ -1,6 +1,6 @@
\echo Use "ALTER EXTENSION test_extension UPDATE TO '1.1'" to load this file. \quit
\echo Use "ALTER EXTENSION test_extension_sql_only UPDATE TO '1.1'" to load this file. \quit
CREATE FUNCTION test_extension.fun_fact()
CREATE FUNCTION test_extension_sql_only.fun_fact()
RETURNS void
IMMUTABLE LEAKPROOF PARALLEL SAFE
AS $$

View File

@@ -0,0 +1,12 @@
\echo Use "CREATE EXTENSION test_extension_sql_only" to load this file. \quit
CREATE SCHEMA test_extension_sql_only;
CREATE FUNCTION test_extension_sql_only.motd()
RETURNS void
IMMUTABLE LEAKPROOF PARALLEL SAFE
AS $$
BEGIN
RAISE NOTICE 'Have a great day';
END;
$$ LANGUAGE 'plpgsql';

View File

@@ -0,0 +1 @@
comment = 'Test extension SQL only'

View File

@@ -0,0 +1,6 @@
\echo Use "ALTER EXTENSION test_extension_with_lib UPDATE TO '1.1'" to load this file. \quit
CREATE FUNCTION test_extension_with_lib.fun_fact()
RETURNS void
IMMUTABLE LEAKPROOF PARALLEL SAFE
AS 'MODULE_PATHNAME', 'fun_fact' LANGUAGE C;

View File

@@ -0,0 +1,8 @@
\echo Use "CREATE EXTENSION test_extension_with_lib" to load this file. \quit
CREATE SCHEMA test_extension_with_lib;
CREATE FUNCTION test_extension_with_lib.motd()
RETURNS void
IMMUTABLE LEAKPROOF PARALLEL SAFE
AS 'MODULE_PATHNAME', 'motd' LANGUAGE C;

View File

@@ -0,0 +1,34 @@
#include <postgres.h>
#include <fmgr.h>
PG_MODULE_MAGIC;
PG_FUNCTION_INFO_V1(motd);
PG_FUNCTION_INFO_V1(fun_fact);
/* Old versions of Postgres didn't pre-declare this in fmgr.h */
#if PG_MAJORVERSION_NUM <= 15
void _PG_init(void);
#endif
void
_PG_init(void)
{
}
Datum
motd(PG_FUNCTION_ARGS)
{
elog(NOTICE, "Have a great day");
PG_RETURN_VOID();
}
Datum
fun_fact(PG_FUNCTION_ARGS)
{
elog(NOTICE, "Neon has a melting point of -246.08 C");
PG_RETURN_VOID();
}

View File

@@ -0,0 +1,2 @@
comment = 'Test extension with lib'
module_pathname = '$libdir/test_extension_with_lib'

View File

@@ -4,12 +4,17 @@ import os
import platform
import shutil
import tarfile
from typing import TYPE_CHECKING
from enum import StrEnum
from pathlib import Path
from typing import TYPE_CHECKING, cast, final
import pytest
import zstandard
from fixtures.log_helper import log
from fixtures.metrics import parse_metrics
from fixtures.paths import BASE_DIR
from fixtures.pg_config import PgConfigKey
from fixtures.utils import subprocess_capture
from werkzeug.wrappers.response import Response
if TYPE_CHECKING:
@@ -20,6 +25,7 @@ if TYPE_CHECKING:
from fixtures.neon_fixtures import (
NeonEnvBuilder,
)
from fixtures.pg_config import PgConfig
from fixtures.pg_version import PgVersion
from pytest_httpserver import HTTPServer
from werkzeug.wrappers.request import Request
@@ -46,46 +52,108 @@ def neon_env_builder_local(
return neon_env_builder
@final
class RemoteExtension(StrEnum):
SQL_ONLY = "test_extension_sql_only"
WITH_LIB = "test_extension_with_lib"
@property
def compressed_tarball_name(self) -> str:
return f"{self.tarball_name}.zst"
@property
def control_file_name(self) -> str:
return f"{self}.control"
@property
def directory(self) -> Path:
return BASE_DIR / "test_runner" / "regress" / "data" / "test_remote_extensions" / self
@property
def shared_library_name(self) -> str:
return f"{self}.so"
@property
def tarball_name(self) -> str:
return f"{self}.tar"
def archive_route(self, build_tag: str, arch: str, pg_version: PgVersion) -> str:
return f"{build_tag}/{arch}/v{pg_version}/extensions/{self.compressed_tarball_name}"
def build(self, pg_config: PgConfig, output_dir: Path) -> None:
if self is not RemoteExtension.WITH_LIB:
return
cmd: list[str] = [
*cast("list[str]", pg_config[PgConfigKey.CC]),
*cast("list[str]", pg_config[PgConfigKey.CPPFLAGS]),
*["-I", str(cast("Path", pg_config[PgConfigKey.INCLUDEDIR_SERVER]))],
*cast("list[str]", pg_config[PgConfigKey.CFLAGS]),
*cast("list[str]", pg_config[PgConfigKey.CFLAGS_SL]),
*cast("list[str]", pg_config[PgConfigKey.LDFLAGS_EX]),
*cast("list[str]", pg_config[PgConfigKey.LDFLAGS_SL]),
"-shared",
*["-o", str(output_dir / self.shared_library_name)],
str(self.directory / "src" / f"{self}.c"),
]
subprocess_capture(output_dir, cmd, check=True)
def control_file_contents(self) -> str:
with open(self.directory / self.control_file_name, encoding="utf-8") as f:
return f.read()
def files(self, output_dir: Path) -> dict[Path, str]:
files = {
# self.directory / self.control_file_name: f"share/extension/{self.control_file_name}",
self.directory / "sql" / f"{self}--1.0.sql": f"share/extension/{self}--1.0.sql",
self.directory
/ "sql"
/ f"{self}--1.0--1.1.sql": f"share/extension/{self}--1.0--1.1.sql",
}
if self is RemoteExtension.WITH_LIB:
files[output_dir / self.shared_library_name] = f"lib/{self.shared_library_name}"
return files
def package(self, output_dir: Path) -> Path:
tarball = output_dir / self.tarball_name
with tarfile.open(tarball, "x") as tarf:
for file, arcname in self.files(output_dir).items():
tarf.add(file, arcname=arcname)
return tarball
def remove(self, output_dir: Path, pg_version: PgVersion) -> None:
for file in self.files(output_dir).values():
if file.startswith("share/extension"):
file = f"share/postgresql/extension/{os.path.basename(file)}"
if file.startswith("lib"):
file = f"lib/postgresql/{os.path.basename(file)}"
(output_dir / "pg_install" / f"v{pg_version}" / file).unlink()
@pytest.mark.parametrize(
"extension",
(RemoteExtension.SQL_ONLY, RemoteExtension.WITH_LIB),
ids=["sql_only", "with_lib"],
)
def test_remote_extensions(
httpserver: HTTPServer,
neon_env_builder_local: NeonEnvBuilder,
httpserver_listen_address: ListenAddress,
test_output_dir: Path,
base_dir: Path,
pg_version: PgVersion,
pg_config: PgConfig,
extension: RemoteExtension,
):
# Setup a mock nginx S3 gateway which will return our test extension.
(host, port) = httpserver_listen_address
extensions_endpoint = f"http://{host}:{port}/pg-ext-s3-gateway"
build_tag = os.environ.get("BUILD_TAG", "latest")
# We have decided to use the Go naming convention due to Kubernetes.
arch = platform.machine()
match arch:
case "aarch64":
arch = "arm64"
case "x86_64":
arch = "amd64"
case _:
pass
archive_route = f"{build_tag}/{arch}/v{pg_version}/extensions/test_extension.tar.zst"
tarball = test_output_dir / "test_extension.tar"
extension_dir = (
base_dir / "test_runner" / "regress" / "data" / "test_remote_extensions" / "test_extension"
)
# Create tarball
with tarfile.open(tarball, "x") as tarf:
tarf.add(
extension_dir / "sql" / "test_extension--1.0.sql",
arcname="share/extension/test_extension--1.0.sql",
)
tarf.add(
extension_dir / "sql" / "test_extension--1.0--1.1.sql",
arcname="share/extension/test_extension--1.0--1.1.sql",
)
extension.build(pg_config, test_output_dir)
tarball = extension.package(test_output_dir)
def handler(request: Request) -> Response:
log.info(f"request: {request}")
@@ -104,8 +172,19 @@ def test_remote_extensions(
direct_passthrough=True,
)
# We have decided to use the Go naming convention due to Kubernetes.
arch = platform.machine()
match arch:
case "aarch64":
arch = "arm64"
case "x86_64":
arch = "amd64"
case _:
pass
httpserver.expect_request(
f"/pg-ext-s3-gateway/{archive_route}", method="GET"
f"/pg-ext-s3-gateway/{extension.archive_route(build_tag=os.environ.get('BUILD_TAG', 'latest'), arch=arch, pg_version=pg_version)}",
method="GET",
).respond_with_handler(handler)
# Start a compute node with remote_extension spec
@@ -114,21 +193,18 @@ def test_remote_extensions(
env.create_branch("test_remote_extensions")
endpoint = env.endpoints.create("test_remote_extensions")
with open(extension_dir / "test_extension.control", encoding="utf-8") as f:
control_data = f.read()
# mock remote_extensions spec
spec: dict[str, Any] = {
"public_extensions": ["test_extension"],
"public_extensions": [extension],
"custom_extensions": None,
"library_index": {
"test_extension": "test_extension",
extension: extension,
},
"extension_data": {
"test_extension": {
extension: {
"archive_path": "",
"control_data": {
"test_extension.control": control_data,
extension.control_file_name: extension.control_file_contents(),
},
},
},
@@ -141,8 +217,8 @@ def test_remote_extensions(
with endpoint.connect() as conn:
with conn.cursor() as cur:
# Check that appropriate files were downloaded
cur.execute("CREATE EXTENSION test_extension VERSION '1.0'")
cur.execute("SELECT test_extension.motd()")
cur.execute(f"CREATE EXTENSION {extension} VERSION '1.0'")
cur.execute(f"SELECT {extension}.motd()")
httpserver.check()
@@ -153,7 +229,7 @@ def test_remote_extensions(
remote_ext_requests = metrics.query_all(
"compute_ctl_remote_ext_requests_total",
# Check that we properly report the filename in the metrics
{"filename": "test_extension.tar.zst"},
{"filename": extension.compressed_tarball_name},
)
assert len(remote_ext_requests) == 1
for sample in remote_ext_requests:
@@ -162,20 +238,7 @@ def test_remote_extensions(
endpoint.stop()
# Remove the extension files to force a redownload of the extension.
for file in (
"test_extension.control",
"test_extension--1.0.sql",
"test_extension--1.0--1.1.sql",
):
(
test_output_dir
/ "pg_install"
/ f"v{pg_version}"
/ "share"
/ "postgresql"
/ "extension"
/ file
).unlink()
extension.remove(test_output_dir, pg_version)
endpoint.start(remote_ext_config=extensions_endpoint)
@@ -183,8 +246,8 @@ def test_remote_extensions(
with endpoint.connect() as conn:
with conn.cursor() as cur:
# Check that appropriate files were downloaded
cur.execute("ALTER EXTENSION test_extension UPDATE TO '1.1'")
cur.execute("SELECT test_extension.fun_fact()")
cur.execute(f"ALTER EXTENSION {extension} UPDATE TO '1.1'")
cur.execute(f"SELECT {extension}.fun_fact()")
# Check that we properly recorded downloads in the metrics
client = endpoint.http_client()
@@ -193,7 +256,7 @@ def test_remote_extensions(
remote_ext_requests = metrics.query_all(
"compute_ctl_remote_ext_requests_total",
# Check that we properly report the filename in the metrics
{"filename": "test_extension.tar.zst"},
{"filename": extension.compressed_tarball_name},
)
assert len(remote_ext_requests) == 1
for sample in remote_ext_requests: