mirror of
https://github.com/neondatabase/neon.git
synced 2025-12-22 21:59:59 +00:00
When a function is owned by a superuser (bootstrap user or otherwise), we consider it safe to run it. Only a superuser could have installed it, typically from CREATE EXTENSION script: we trust the code to execute. ## Problem This is intended to solve running pg_graphql Event Triggers graphql_watch_ddl and graphql_watch_drop which are executing the secdef function graphql.increment_schema_version(). ## Summary of changes Allow executing Event Trigger function owned by a superuser and with SECURITY DEFINER properties. The Event Trigger code runs with superuser privileges, and we consider that it's fine. --------- Co-authored-by: Tristan Partin <tristan.partin@databricks.com>
267 lines
9.1 KiB
Python
from __future__ import annotations
|
|
|
|
import os
|
|
import platform
|
|
import tarfile
|
|
from enum import StrEnum
|
|
from pathlib import Path
|
|
from typing import TYPE_CHECKING, cast, final
|
|
|
|
import pytest
|
|
import zstandard
|
|
from fixtures.log_helper import log
|
|
from fixtures.metrics import parse_metrics
|
|
from fixtures.paths import BASE_DIR
|
|
from fixtures.pg_config import PgConfigKey
|
|
from fixtures.utils import WITH_SANITIZERS, subprocess_capture
|
|
from werkzeug.wrappers.response import Response
|
|
|
|
if TYPE_CHECKING:
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
from fixtures.httpserver import ListenAddress
|
|
from fixtures.neon_fixtures import (
|
|
NeonEnvBuilder,
|
|
)
|
|
from fixtures.pg_config import PgConfig
|
|
from fixtures.pg_version import PgVersion
|
|
from pytest_httpserver import HTTPServer
|
|
from werkzeug.wrappers.request import Request
|
|
|
|
|
|
@final
|
|
class RemoteExtension(StrEnum):
|
|
SQL_ONLY = "test_extension_sql_only"
|
|
WITH_LIB = "test_extension_with_lib"
|
|
|
|
@property
|
|
def compressed_tarball_name(self) -> str:
|
|
return f"{self.tarball_name}.zst"
|
|
|
|
@property
|
|
def control_file_name(self) -> str:
|
|
return f"{self}.control"
|
|
|
|
@property
|
|
def directory(self) -> Path:
|
|
return BASE_DIR / "test_runner" / "regress" / "data" / "test_remote_extensions" / self
|
|
|
|
@property
|
|
def shared_library_name(self) -> str:
|
|
return f"{self}.so"
|
|
|
|
@property
|
|
def tarball_name(self) -> str:
|
|
return f"{self}.tar"
|
|
|
|
def archive_route(self, build_tag: str, arch: str, pg_version: PgVersion) -> str:
|
|
return f"{build_tag}/{arch}/v{pg_version}/extensions/{self.compressed_tarball_name}"
|
|
|
|
def build(self, pg_config: PgConfig, output_dir: Path) -> None:
|
|
if self is not RemoteExtension.WITH_LIB:
|
|
return
|
|
|
|
cmd: list[str] = [
|
|
*cast("list[str]", pg_config[PgConfigKey.CC]),
|
|
*cast("list[str]", pg_config[PgConfigKey.CPPFLAGS]),
|
|
*["-I", str(cast("Path", pg_config[PgConfigKey.INCLUDEDIR_SERVER]))],
|
|
*cast("list[str]", pg_config[PgConfigKey.CFLAGS]),
|
|
*cast("list[str]", pg_config[PgConfigKey.CFLAGS_SL]),
|
|
*cast("list[str]", pg_config[PgConfigKey.LDFLAGS_EX]),
|
|
*cast("list[str]", pg_config[PgConfigKey.LDFLAGS_SL]),
|
|
"-shared",
|
|
*["-o", str(output_dir / self.shared_library_name)],
|
|
str(self.directory / "src" / f"{self}.c"),
|
|
]
|
|
|
|
subprocess_capture(output_dir, cmd, check=True)
|
|
|
|
def control_file_contents(self) -> str:
|
|
with open(self.directory / self.control_file_name, encoding="utf-8") as f:
|
|
return f.read()
|
|
|
|
def files(self, output_dir: Path) -> dict[Path, str]:
|
|
files = {
|
|
# self.directory / self.control_file_name: f"share/extension/{self.control_file_name}",
|
|
self.directory / "sql" / f"{self}--1.0.sql": f"share/extension/{self}--1.0.sql",
|
|
self.directory
|
|
/ "sql"
|
|
/ f"{self}--1.0--1.1.sql": f"share/extension/{self}--1.0--1.1.sql",
|
|
}
|
|
|
|
if self is RemoteExtension.WITH_LIB:
|
|
files[output_dir / self.shared_library_name] = f"lib/{self.shared_library_name}"
|
|
|
|
return files
|
|
|
|
def package(self, output_dir: Path) -> Path:
|
|
tarball = output_dir / self.tarball_name
|
|
with tarfile.open(tarball, "x") as tarf:
|
|
for file, arcname in self.files(output_dir).items():
|
|
tarf.add(file, arcname=arcname)
|
|
|
|
return tarball
|
|
|
|
def remove(self, output_dir: Path, pg_version: PgVersion) -> None:
|
|
for file in self.files(output_dir).values():
|
|
if file.startswith("share/extension"):
|
|
file = f"share/postgresql/extension/{os.path.basename(file)}"
|
|
if file.startswith("lib"):
|
|
file = f"lib/postgresql/{os.path.basename(file)}"
|
|
(output_dir / "pg_install" / f"v{pg_version}" / file).unlink()
|
|
|
|
|
|
@pytest.mark.parametrize(
|
|
"extension",
|
|
(RemoteExtension.SQL_ONLY, RemoteExtension.WITH_LIB),
|
|
ids=["sql_only", "with_lib"],
|
|
)
|
|
def test_remote_extensions(
|
|
httpserver: HTTPServer,
|
|
neon_env_builder_local: NeonEnvBuilder,
|
|
httpserver_listen_address: ListenAddress,
|
|
test_output_dir: Path,
|
|
pg_version: PgVersion,
|
|
pg_config: PgConfig,
|
|
extension: RemoteExtension,
|
|
):
|
|
if WITH_SANITIZERS and extension is RemoteExtension.WITH_LIB:
|
|
pytest.skip(
|
|
"""
|
|
For this test to work with sanitizers enabled, we would need to
|
|
compile the dummy Postgres extension with the same CFLAGS that we
|
|
compile Postgres and the neon extension with to link the sanitizers.
|
|
"""
|
|
)
|
|
|
|
# Setup a mock nginx S3 gateway which will return our test extension.
|
|
(host, port) = httpserver_listen_address
|
|
remote_ext_base_url = f"http://{host}:{port}/pg-ext-s3-gateway"
|
|
log.info(f"remote extensions base URL: {remote_ext_base_url}")
|
|
|
|
extension.build(pg_config, test_output_dir)
|
|
tarball = extension.package(test_output_dir)
|
|
|
|
def handler(request: Request) -> Response:
|
|
log.info(f"request: {request}")
|
|
|
|
# Compress tarball
|
|
compressor = zstandard.ZstdCompressor()
|
|
with open(tarball, "rb") as f:
|
|
compressed_data = compressor.compress(f.read())
|
|
|
|
return Response(
|
|
compressed_data,
|
|
mimetype="application/octet-stream",
|
|
headers=[
|
|
("Content-Length", str(len(compressed_data))),
|
|
],
|
|
direct_passthrough=True,
|
|
)
|
|
|
|
# We have decided to use the Go naming convention due to Kubernetes.
|
|
arch = platform.machine()
|
|
match arch:
|
|
case "aarch64":
|
|
arch = "arm64"
|
|
case "x86_64":
|
|
arch = "amd64"
|
|
case _:
|
|
pass
|
|
|
|
httpserver.expect_request(
|
|
f"/pg-ext-s3-gateway/{extension.archive_route(build_tag=os.environ.get('BUILD_TAG', 'latest'), arch=arch, pg_version=pg_version)}",
|
|
method="GET",
|
|
).respond_with_handler(handler)
|
|
|
|
# Start a compute node with remote_extension spec
|
|
# and check that it can download the extensions and use them to CREATE EXTENSION.
|
|
env = neon_env_builder_local.init_start()
|
|
env.create_branch("test_remote_extensions")
|
|
endpoint = env.endpoints.create("test_remote_extensions")
|
|
|
|
# mock remote_extensions spec
|
|
spec: dict[str, Any] = {
|
|
"public_extensions": [extension],
|
|
"custom_extensions": None,
|
|
"library_index": {
|
|
extension: extension,
|
|
},
|
|
"extension_data": {
|
|
extension: {
|
|
"archive_path": "",
|
|
"control_data": {
|
|
extension.control_file_name: extension.control_file_contents(),
|
|
},
|
|
},
|
|
},
|
|
}
|
|
|
|
endpoint.create_remote_extension_spec(spec)
|
|
|
|
endpoint.start(remote_ext_base_url=remote_ext_base_url)
|
|
|
|
with endpoint.connect() as conn:
|
|
with conn.cursor() as cur:
|
|
# Check that appropriate files were downloaded
|
|
cur.execute(f"CREATE EXTENSION {extension} VERSION '1.0'")
|
|
cur.execute(f"SELECT {extension}.motd()")
|
|
|
|
httpserver.check()
|
|
|
|
# Check that we properly recorded downloads in the metrics
|
|
client = endpoint.http_client()
|
|
raw_metrics = client.metrics()
|
|
metrics = parse_metrics(raw_metrics)
|
|
remote_ext_requests = metrics.query_all(
|
|
"compute_ctl_remote_ext_requests_total",
|
|
# Check that we properly report the filename in the metrics
|
|
{"filename": extension.compressed_tarball_name},
|
|
)
|
|
assert len(remote_ext_requests) == 1
|
|
for sample in remote_ext_requests:
|
|
assert sample.value == 1
|
|
|
|
endpoint.stop()
|
|
|
|
# Remove the extension files to force a redownload of the extension.
|
|
extension.remove(test_output_dir, pg_version)
|
|
|
|
endpoint.start(remote_ext_base_url=remote_ext_base_url)
|
|
|
|
# Test that ALTER EXTENSION UPDATE statements also fetch remote extensions.
|
|
with endpoint.connect() as conn:
|
|
with conn.cursor() as cur:
|
|
# Check that appropriate files were downloaded
|
|
cur.execute(f"ALTER EXTENSION {extension} UPDATE TO '1.1'")
|
|
cur.execute(f"SELECT {extension}.fun_fact()")
|
|
|
|
# Check that we properly recorded downloads in the metrics
|
|
client = endpoint.http_client()
|
|
raw_metrics = client.metrics()
|
|
metrics = parse_metrics(raw_metrics)
|
|
remote_ext_requests = metrics.query_all(
|
|
"compute_ctl_remote_ext_requests_total",
|
|
# Check that we properly report the filename in the metrics
|
|
{"filename": extension.compressed_tarball_name},
|
|
)
|
|
assert len(remote_ext_requests) == 1
|
|
for sample in remote_ext_requests:
|
|
assert sample.value == 1
|
|
|
|
|
|
# TODO
|
|
# 1. Test downloading remote library.
|
|
#
|
|
# 2. Test a complex extension, which has multiple extensions in one archive
|
|
# using postgis as an example
|
|
#
|
|
# 3. Test that the extension is downloaded after endpoint restart,
|
|
# when the library is used in the query.
|
|
# Run the test with multiple simultaneous connections to an endpoint.
|
|
# to ensure that the extension is downloaded only once.
|
|
#
|
|
# 4. Test that private extensions are only downloaded when they are present in the spec.
|
|
#
|