Files
neon/test_runner/regress/test_download_extensions.py
Dimitri Fontaine 1a45b2ec90 Review security model for executing Event Trigger code. (#12463)
When a function is owned by a superuser (bootstrap user or otherwise),
we consider it safe to run it. Only a superuser could have installed it,
typically from CREATE EXTENSION script: we trust the code to execute.

## Problem

This is intended to allow running the pg_graphql Event Triggers
graphql_watch_ddl and graphql_watch_drop, which execute the SECURITY DEFINER
(secdef) function graphql.increment_schema_version().

## Summary of changes

Allow executing Event Trigger functions that are owned by a superuser and
have the SECURITY DEFINER property. Such Event Trigger code runs with
superuser privileges, and we consider that acceptable.

---------

Co-authored-by: Tristan Partin <tristan.partin@databricks.com>
2025-07-10 08:06:33 +00:00

267 lines
9.1 KiB
Python

from __future__ import annotations
import os
import platform
import tarfile
from enum import StrEnum
from pathlib import Path
from typing import TYPE_CHECKING, cast, final
import pytest
import zstandard
from fixtures.log_helper import log
from fixtures.metrics import parse_metrics
from fixtures.paths import BASE_DIR
from fixtures.pg_config import PgConfigKey
from fixtures.utils import WITH_SANITIZERS, subprocess_capture
from werkzeug.wrappers.response import Response
if TYPE_CHECKING:
from pathlib import Path
from typing import Any
from fixtures.httpserver import ListenAddress
from fixtures.neon_fixtures import (
NeonEnvBuilder,
)
from fixtures.pg_config import PgConfig
from fixtures.pg_version import PgVersion
from pytest_httpserver import HTTPServer
from werkzeug.wrappers.request import Request
@final
class RemoteExtension(StrEnum):
SQL_ONLY = "test_extension_sql_only"
WITH_LIB = "test_extension_with_lib"
@property
def compressed_tarball_name(self) -> str:
return f"{self.tarball_name}.zst"
@property
def control_file_name(self) -> str:
return f"{self}.control"
@property
def directory(self) -> Path:
return BASE_DIR / "test_runner" / "regress" / "data" / "test_remote_extensions" / self
@property
def shared_library_name(self) -> str:
return f"{self}.so"
@property
def tarball_name(self) -> str:
return f"{self}.tar"
def archive_route(self, build_tag: str, arch: str, pg_version: PgVersion) -> str:
return f"{build_tag}/{arch}/v{pg_version}/extensions/{self.compressed_tarball_name}"
def build(self, pg_config: PgConfig, output_dir: Path) -> None:
if self is not RemoteExtension.WITH_LIB:
return
cmd: list[str] = [
*cast("list[str]", pg_config[PgConfigKey.CC]),
*cast("list[str]", pg_config[PgConfigKey.CPPFLAGS]),
*["-I", str(cast("Path", pg_config[PgConfigKey.INCLUDEDIR_SERVER]))],
*cast("list[str]", pg_config[PgConfigKey.CFLAGS]),
*cast("list[str]", pg_config[PgConfigKey.CFLAGS_SL]),
*cast("list[str]", pg_config[PgConfigKey.LDFLAGS_EX]),
*cast("list[str]", pg_config[PgConfigKey.LDFLAGS_SL]),
"-shared",
*["-o", str(output_dir / self.shared_library_name)],
str(self.directory / "src" / f"{self}.c"),
]
subprocess_capture(output_dir, cmd, check=True)
def control_file_contents(self) -> str:
with open(self.directory / self.control_file_name, encoding="utf-8") as f:
return f.read()
def files(self, output_dir: Path) -> dict[Path, str]:
files = {
# self.directory / self.control_file_name: f"share/extension/{self.control_file_name}",
self.directory / "sql" / f"{self}--1.0.sql": f"share/extension/{self}--1.0.sql",
self.directory
/ "sql"
/ f"{self}--1.0--1.1.sql": f"share/extension/{self}--1.0--1.1.sql",
}
if self is RemoteExtension.WITH_LIB:
files[output_dir / self.shared_library_name] = f"lib/{self.shared_library_name}"
return files
def package(self, output_dir: Path) -> Path:
tarball = output_dir / self.tarball_name
with tarfile.open(tarball, "x") as tarf:
for file, arcname in self.files(output_dir).items():
tarf.add(file, arcname=arcname)
return tarball
def remove(self, output_dir: Path, pg_version: PgVersion) -> None:
for file in self.files(output_dir).values():
if file.startswith("share/extension"):
file = f"share/postgresql/extension/{os.path.basename(file)}"
if file.startswith("lib"):
file = f"lib/postgresql/{os.path.basename(file)}"
(output_dir / "pg_install" / f"v{pg_version}" / file).unlink()
@pytest.mark.parametrize(
    "extension",
    (RemoteExtension.SQL_ONLY, RemoteExtension.WITH_LIB),
    ids=["sql_only", "with_lib"],
)
def test_remote_extensions(
    httpserver: HTTPServer,
    neon_env_builder_local: NeonEnvBuilder,
    httpserver_listen_address: ListenAddress,
    test_output_dir: Path,
    pg_version: PgVersion,
    pg_config: PgConfig,
    extension: RemoteExtension,
):
    """Check that an endpoint downloads a remote extension and can use it.

    The extension is built, packaged and served from a mock HTTP gateway. The
    test then runs CREATE EXTENSION, removes the installed files, restarts the
    endpoint and runs ALTER EXTENSION ... UPDATE. After each phase it asserts
    that the metrics report exactly one request for the archive.
    """
    if WITH_SANITIZERS and extension is RemoteExtension.WITH_LIB:
        pytest.skip(
            "\n"
            "For this test to work with sanitizers enabled, we would need to\n"
            "compile the dummy Postgres extension with the same CFLAGS that we\n"
            "compile Postgres and the neon extension with to link the sanitizers.\n"
        )

    # Setup a mock nginx S3 gateway which will return our test extension.
    (host, port) = httpserver_listen_address
    remote_ext_base_url = f"http://{host}:{port}/pg-ext-s3-gateway"
    log.info(f"remote extensions base URL: {remote_ext_base_url}")

    extension.build(pg_config, test_output_dir)
    tarball = extension.package(test_output_dir)

    # Compress the tarball once up front; it does not change between requests,
    # so the handler can serve the same bytes every time.
    compressed_data = zstandard.ZstdCompressor().compress(tarball.read_bytes())

    def handler(request: Request) -> Response:
        """Serve the compressed extension archive."""
        log.info(f"request: {request}")
        return Response(
            compressed_data,
            mimetype="application/octet-stream",
            headers=[
                ("Content-Length", str(len(compressed_data))),
            ],
            direct_passthrough=True,
        )

    # We have decided to use the Go naming convention due to Kubernetes.
    # Machine names without a mapping are passed through unchanged.
    machine = platform.machine()
    arch = {"aarch64": "arm64", "x86_64": "amd64"}.get(machine, machine)

    archive_route = extension.archive_route(
        build_tag=os.environ.get("BUILD_TAG", "latest"),
        arch=arch,
        pg_version=pg_version,
    )
    httpserver.expect_request(
        f"/pg-ext-s3-gateway/{archive_route}",
        method="GET",
    ).respond_with_handler(handler)

    # Start a compute node with remote_extension spec
    # and check that it can download the extensions and use them to CREATE EXTENSION.
    env = neon_env_builder_local.init_start()
    env.create_branch("test_remote_extensions")
    endpoint = env.endpoints.create("test_remote_extensions")

    # mock remote_extensions spec
    spec: dict[str, Any] = {
        "public_extensions": [extension],
        "custom_extensions": None,
        "library_index": {
            extension: extension,
        },
        "extension_data": {
            extension: {
                "archive_path": "",
                "control_data": {
                    extension.control_file_name: extension.control_file_contents(),
                },
            },
        },
    }

    def assert_single_download_recorded() -> None:
        """Assert the endpoint's metrics report exactly one request for the archive."""
        metrics = parse_metrics(endpoint.http_client().metrics())
        remote_ext_requests = metrics.query_all(
            "compute_ctl_remote_ext_requests_total",
            # Check that we properly report the filename in the metrics
            {"filename": extension.compressed_tarball_name},
        )
        assert len(remote_ext_requests) == 1
        for sample in remote_ext_requests:
            assert sample.value == 1

    endpoint.create_remote_extension_spec(spec)
    endpoint.start(remote_ext_base_url=remote_ext_base_url)

    with endpoint.connect() as conn:
        with conn.cursor() as cur:
            # Check that appropriate files were downloaded
            cur.execute(f"CREATE EXTENSION {extension} VERSION '1.0'")
            cur.execute(f"SELECT {extension}.motd()")

    httpserver.check()

    # Check that we properly recorded downloads in the metrics
    assert_single_download_recorded()

    endpoint.stop()

    # Remove the extension files to force a redownload of the extension.
    extension.remove(test_output_dir, pg_version)

    endpoint.start(remote_ext_base_url=remote_ext_base_url)

    # Test that ALTER EXTENSION UPDATE statements also fetch remote extensions.
    with endpoint.connect() as conn:
        with conn.cursor() as cur:
            # Check that appropriate files were downloaded
            cur.execute(f"ALTER EXTENSION {extension} UPDATE TO '1.1'")
            cur.execute(f"SELECT {extension}.fun_fact()")

    # Check that we properly recorded downloads in the metrics
    # (the counter is expected to read 1 again after the restart).
    assert_single_download_recorded()
# TODO
# 1. Test downloading remote library.
#
# 2. Test a complex extension, which has multiple extensions in one archive
# using postgis as an example
#
# 3. Test that the extension is downloaded after endpoint restart,
# when the library is used in the query.
# Run the test with multiple simultaneous connections to an endpoint
# to ensure that the extension is downloaded only once.
#
# 4. Test that private extensions are only downloaded when they are present in the spec.
#