## Problem

Data import fails if the source database has any event triggers, because those can only be restored by a superuser. In particular, imports from Heroku and Supabase are guaranteed to fail.

Closes https://github.com/neondatabase/cloud/issues/27353

## Summary of changes

Depends on `pg_dump` patches for each supported PostgreSQL version:

- https://github.com/neondatabase/postgres/pull/630
- https://github.com/neondatabase/postgres/pull/629
- https://github.com/neondatabase/postgres/pull/627
- https://github.com/neondatabase/postgres/pull/628
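For context, `CREATE EVENT TRIGGER` requires superuser, which is why a plain dump/restore of such databases fails under a non-superuser import role. Below is a minimal sketch of that failure mode, outside the test suite, assuming a hypothetical non-superuser DSN (`app_user@localhost/appdb`) and a placeholder trigger function name:

```python
# Minimal reproduction sketch (not part of this test suite): CREATE EVENT TRIGGER is
# superuser-only, so replaying a dump that contains one fails for a regular role.
# The DSN, role, and function name below are hypothetical placeholders.
import psycopg2
import psycopg2.errors

conn = psycopg2.connect("postgresql://app_user:secret@localhost:5432/appdb")  # non-superuser
conn.autocommit = True
try:
    with conn.cursor() as cur:
        # This is the kind of statement a plain pg_dump/pg_restore cycle would replay
        # for a source database coming from Heroku or Supabase.
        cur.execute(
            "CREATE EVENT TRIGGER on_drop ON sql_drop EXECUTE FUNCTION some_trigger_fn()"
        )
except psycopg2.errors.InsufficientPrivilege as err:
    # e.g. 'permission denied to create event trigger "on_drop"'
    print(f"import would fail here: {err}")
finally:
    conn.close()
```

The linked `pg_dump` patches are what allow the import to leave event triggers behind instead of failing; `test_fast_import_event_triggers` below checks exactly that.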
import base64
import json
import time
from enum import Enum
from pathlib import Path
from threading import Event

import psycopg2
import psycopg2.errors
import pytest
from fixtures.common_types import Lsn, TenantId, TenantShardId, TimelineId
from fixtures.fast_import import FastImport
from fixtures.log_helper import log
from fixtures.neon_fixtures import NeonEnvBuilder, PgBin, PgProtocol, VanillaPostgres
from fixtures.pageserver.http import (
    ImportPgdataIdemptencyKey,
)
from fixtures.pg_version import PgVersion
from fixtures.port_distributor import PortDistributor
from fixtures.remote_storage import MockS3Server, RemoteStorageKind
from fixtures.utils import (
    run_only_on_default_postgres,
    shared_buffers_for_max_cu,
    skip_in_debug_build,
    wait_until,
)
from mypy_boto3_kms import KMSClient
from mypy_boto3_kms.type_defs import EncryptResponseTypeDef
from mypy_boto3_s3 import S3Client
from pytest_httpserver import HTTPServer
from werkzeug.wrappers.request import Request
from werkzeug.wrappers.response import Response

num_rows = 1000


class RelBlockSize(Enum):
    ONE_STRIPE_SIZE = 1
    TWO_STRIPES_PER_SHARD = 2
    MULTIPLE_RELATION_SEGMENTS = 3


# Each entry is a (shard_count, stripe_size, rel_block_size) tuple; see the parametrize below.
smoke_params = [
    # unsharded (the stripe size needs to be given for rel block size calculations)
    *[(None, 1024, s) for s in RelBlockSize],
    # many shards, small stripe size to speed up the test
    *[(8, 1024, s) for s in RelBlockSize],
]


def mock_import_bucket(vanilla_pg: VanillaPostgres, path: Path):
    """
    Mock the import S3 bucket into a local directory for a provided vanilla PG instance.
    """
    assert not vanilla_pg.is_running()

    path.mkdir()
    # what cplane writes before scheduling fast_import
    specpath = path / "spec.json"
    specpath.write_text(json.dumps({"branch_id": "somebranch", "project_id": "someproject"}))
    # what fast_import writes
    vanilla_pg.pgdatadir.rename(path / "pgdata")
    statusdir = path / "status"
    statusdir.mkdir()
    (statusdir / "pgdata").write_text(json.dumps({"done": True}))
    (statusdir / "fast_import").write_text(json.dumps({"command": "pgdata", "done": True}))
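
# For reference, mock_import_bucket() above produces the following layout under `path`,
# mirroring what the real import bucket looks like after fast_import has run:
#
#   spec.json            written by the control plane before scheduling fast_import
#   pgdata/              the Postgres data directory to import
#   status/pgdata        {"done": true}
#   status/fast_import   {"command": "pgdata", "done": true}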


@skip_in_debug_build("MULTIPLE_RELATION_SEGMENTS has a non-trivial amount of data")
@pytest.mark.parametrize("shard_count,stripe_size,rel_block_size", smoke_params)
def test_pgdata_import_smoke(
    vanilla_pg: VanillaPostgres,
    neon_env_builder: NeonEnvBuilder,
    shard_count: int | None,
    stripe_size: int,
    rel_block_size: RelBlockSize,
    make_httpserver: HTTPServer,
):
    #
    # Set up a fake control plane for import progress
    #
    import_completion_signaled = Event()

    def handler(request: Request) -> Response:
        log.info(f"control plane /import_complete request: {request.json}")
        import_completion_signaled.set()
        return Response(json.dumps({}), status=200)

    cplane_mgmt_api_server = make_httpserver
    cplane_mgmt_api_server.expect_request(
        "/storage/api/v1/import_complete", method="PUT"
    ).respond_with_handler(handler)

    neon_env_builder.enable_pageserver_remote_storage(RemoteStorageKind.LOCAL_FS)

    neon_env_builder.control_plane_hooks_api = (
        f"http://{cplane_mgmt_api_server.host}:{cplane_mgmt_api_server.port}/storage/api/v1/"
    )

    env = neon_env_builder.init_start()

    # The test needs LocalFs support, which is only built in testing mode.
    env.pageserver.is_testing_enabled_or_skip()

    env.pageserver.stop()
    env.pageserver.start()

    # By default our tests run with a tiny shared_buffers=1MB setting. That
    # doesn't allow any prefetching on v17 and above, where the new streaming
    # read machinery keeps buffers pinned while prefetching them. Use a higher
    # setting to enable prefetching and speed up the tests:
    # use a shared_buffers size like in production for an 8 CU compute.
    ep_config = [f"shared_buffers={shared_buffers_for_max_cu(8.0)}"]

    #
    # Put data in vanilla pg
    #

    vanilla_pg.start()
    vanilla_pg.safe_psql("create user cloud_admin with password 'postgres' superuser")

    log.info("create relblock data")
    if rel_block_size == RelBlockSize.ONE_STRIPE_SIZE:
        target_relblock_size = stripe_size * 8192
    elif rel_block_size == RelBlockSize.TWO_STRIPES_PER_SHARD:
        target_relblock_size = (shard_count or 1) * stripe_size * 8192 * 2
    elif rel_block_size == RelBlockSize.MULTIPLE_RELATION_SEGMENTS:
        # Postgres uses a 1GiB segment size, fixed at compile time, so we must use >2GB of data
        # to exercise multiple segments.
        target_relblock_size = int(((2.333 * 1024 * 1024 * 1024) // 8192) * 8192)
    else:
        raise ValueError

    # Use a low fillfactor so we don't need to produce that much data:
    # a ~900-byte row exceeds 10% of an 8192-byte page => 1 row per page.
    vanilla_pg.safe_psql("""create table t (data char(900)) with (fillfactor = 10)""")

    nrows = 0
    while True:
        relblock_size = vanilla_pg.safe_psql_scalar("select pg_relation_size('t')")
        log.info(
            f"relblock size: {relblock_size / 8192} pages (target: {target_relblock_size // 8192} pages)"
        )
        if relblock_size >= target_relblock_size:
            break
        addrows = int((target_relblock_size - relblock_size) // 8192)
        assert addrows >= 1, "forward progress"
        vanilla_pg.safe_psql(
            f"insert into t select generate_series({nrows + 1}, {nrows + addrows})"
        )
        nrows += addrows
    expect_nrows = nrows
    expect_sum = (
        (nrows) * (nrows + 1) // 2
    )  # https://stackoverflow.com/questions/43901484/sum-of-the-integers-from-1-to-n

    def validate_vanilla_equivalence(ep):
        # TODO: would be nicer to just compare pgdump

        # Enable IO concurrency for batching on the large sequential scan, to avoid making
        # this test unnecessarily onerous on CPU. Especially in debug mode it is still
        # pretty onerous, so increase statement_timeout to avoid timeouts.
        assert ep.safe_psql_many(
            [
                "set effective_io_concurrency=32;",
                "SET statement_timeout='300s';",
                "select count(*), sum(data::bigint)::bigint from t",
            ]
        ) == [[], [], [(expect_nrows, expect_sum)]]

    validate_vanilla_equivalence(vanilla_pg)

    vanilla_pg.stop()

    #
    # We have a Postgres data directory now.
    # Make a localfs remote storage that looks like it would after `fast_import` ran.
    # TODO: actually exercise fast_import here
    # TODO: test s3 remote storage
    #
    importbucket_path = neon_env_builder.repo_dir / "importbucket"
    mock_import_bucket(vanilla_pg, importbucket_path)

    #
    # Do the import
    #

    tenant_id = TenantId.generate()
    env.storage_controller.tenant_create(
        tenant_id, shard_count=shard_count, shard_stripe_size=stripe_size
    )

    timeline_id = TimelineId.generate()
    log.info("starting import")
    start = time.monotonic()

    idempotency = ImportPgdataIdemptencyKey.random()
    log.info(f"idempotency key {idempotency}")
    # TODO: teach the neon_local CLI about the idempotency key & 429 error so we can run inside the loop
    # and check for 429

    import_branch_name = "imported"
    env.storage_controller.timeline_create(
        tenant_id,
        {
            "new_timeline_id": str(timeline_id),
            "import_pgdata": {
                "idempotency_key": str(idempotency),
                "location": {"LocalFs": {"path": str(importbucket_path.absolute())}},
            },
        },
    )
    env.neon_cli.mappings_map_branch(import_branch_name, tenant_id, timeline_id)

    def cplane_notified():
        assert import_completion_signaled.is_set()

    # Generous timeout for the MULTIPLE_RELATION_SEGMENTS test variants
    wait_until(cplane_notified, timeout=90)

    import_duration = time.monotonic() - start
    log.info(f"import complete; duration={import_duration:.2f}s")

    #
    # Get some timeline details for later.
    #
    locations = env.storage_controller.locate(tenant_id)
    [shard_zero] = [
        loc for loc in locations if TenantShardId.parse(loc["shard_id"]).shard_number == 0
    ]
    shard_zero_ps = env.get_pageserver(shard_zero["node_id"])
    shard_zero_http = shard_zero_ps.http_client()
    shard_zero_timeline_info = shard_zero_http.timeline_detail(shard_zero["shard_id"], timeline_id)
    initdb_lsn = Lsn(shard_zero_timeline_info["initdb_lsn"])
    min_readable_lsn = Lsn(shard_zero_timeline_info["min_readable_lsn"])
    last_record_lsn = Lsn(shard_zero_timeline_info["last_record_lsn"])
    disk_consistent_lsn = Lsn(shard_zero_timeline_info["disk_consistent_lsn"])
    _remote_consistent_lsn = Lsn(shard_zero_timeline_info["remote_consistent_lsn"])
    remote_consistent_lsn_visible = Lsn(shard_zero_timeline_info["remote_consistent_lsn_visible"])
    # assert remote_consistent_lsn_visible == remote_consistent_lsn  # TODO: this fails initially and after restart, presumably because `UploadQueue::clean.1` is still `None`
    assert remote_consistent_lsn_visible == disk_consistent_lsn
    assert initdb_lsn == min_readable_lsn
    assert disk_consistent_lsn == initdb_lsn + 8
    assert last_record_lsn == disk_consistent_lsn
    # TODO: assert these values are the same everywhere

    #
    # Validate the resulting remote storage state.
    #

    #
    # Validate the imported data
    #

    ro_endpoint = env.endpoints.create_start(
        branch_name=import_branch_name,
        endpoint_id="ro",
        tenant_id=tenant_id,
        lsn=last_record_lsn,
        config_lines=ep_config,
    )

    validate_vanilla_equivalence(ro_endpoint)

    # ensure the import survives restarts
    ro_endpoint.stop()
    env.pageserver.stop(immediate=True)
    env.pageserver.start()
    ro_endpoint.start()
    validate_vanilla_equivalence(ro_endpoint)

    #
    # Validate that the layer files in each shard only contain the shard-specific data
    # (the implementation would be functional but not efficient without this characteristic).
    #

    shards = env.storage_controller.locate(tenant_id)
    for shard in shards:
        shard_ps = env.get_pageserver(shard["node_id"])
        result = shard_ps.timeline_scan_no_disposable_keys(shard["shard_id"], timeline_id)
        assert result.tally.disposable_count == 0
        assert result.tally.not_disposable_count > 0, (
            "sanity check, each shard should have some data"
        )

    #
    # Validate that we can write
    #
    rw_endpoint = env.endpoints.create_start(
        branch_name=import_branch_name,
        endpoint_id="rw",
        tenant_id=tenant_id,
        config_lines=ep_config,
    )
    rw_endpoint.safe_psql("create table othertable(values text)")
    rw_lsn = Lsn(rw_endpoint.safe_psql_scalar("select pg_current_wal_flush_lsn()"))

    # TODO: consider using `class Workload` here
    # to do compaction and whatnot?

    #
    # Validate that we can branch (important use case)
    #

    # ... at the tip
    _ = env.create_branch(
        new_branch_name="br-tip",
        ancestor_branch_name=import_branch_name,
        tenant_id=tenant_id,
        ancestor_start_lsn=rw_lsn,
    )
    br_tip_endpoint = env.endpoints.create_start(
        branch_name="br-tip", endpoint_id="br-tip-ro", tenant_id=tenant_id, config_lines=ep_config
    )
    validate_vanilla_equivalence(br_tip_endpoint)
    br_tip_endpoint.safe_psql("select * from othertable")

    # ... at the initdb lsn
    _ = env.create_branch(
        new_branch_name="br-initdb",
        ancestor_branch_name=import_branch_name,
        tenant_id=tenant_id,
        ancestor_start_lsn=initdb_lsn,
    )
    br_initdb_endpoint = env.endpoints.create_start(
        branch_name="br-initdb",
        endpoint_id="br-initdb-ro",
        tenant_id=tenant_id,
        config_lines=ep_config,
    )
    validate_vanilla_equivalence(br_initdb_endpoint)
    with pytest.raises(psycopg2.errors.UndefinedTable):
        br_initdb_endpoint.safe_psql("select * from othertable")


@run_only_on_default_postgres(reason="PG version is irrelevant here")
def test_import_completion_on_restart(
    neon_env_builder: NeonEnvBuilder, vanilla_pg: VanillaPostgres, make_httpserver: HTTPServer
):
    """
    Validate that the storage controller delivers the import completion notification
    eventually, even if it was restarted when the import initially completed.
    """
    # Set up a mock control plane HTTP server to listen for import completions
    import_completion_signaled = Event()

    def handler(request: Request) -> Response:
        log.info(f"control plane /import_complete request: {request.json}")
        import_completion_signaled.set()
        return Response(json.dumps({}), status=200)

    cplane_mgmt_api_server = make_httpserver
    cplane_mgmt_api_server.expect_request(
        "/storage/api/v1/import_complete", method="PUT"
    ).respond_with_handler(handler)

    # Plug the cplane mock in
    neon_env_builder.control_plane_hooks_api = (
        f"http://{cplane_mgmt_api_server.host}:{cplane_mgmt_api_server.port}/storage/api/v1/"
    )

    # The import will specify a local filesystem path mocking remote storage
    neon_env_builder.enable_pageserver_remote_storage(RemoteStorageKind.LOCAL_FS)

    vanilla_pg.start()
    vanilla_pg.stop()

    env = neon_env_builder.init_configs()
    env.start()

    importbucket_path = neon_env_builder.repo_dir / "test_import_completion_bucket"
    mock_import_bucket(vanilla_pg, importbucket_path)

    tenant_id = TenantId.generate()
    timeline_id = TimelineId.generate()
    idempotency = ImportPgdataIdemptencyKey.random()

    # Pause before sending the notification
    failpoint_name = "timeline-import-pre-cplane-notification"
    env.storage_controller.configure_failpoints((failpoint_name, "pause"))

    env.storage_controller.tenant_create(tenant_id)
    env.storage_controller.timeline_create(
        tenant_id,
        {
            "new_timeline_id": str(timeline_id),
            "import_pgdata": {
                "idempotency_key": str(idempotency),
                "location": {"LocalFs": {"path": str(importbucket_path.absolute())}},
            },
        },
    )

    def hit_failpoint():
        log.info("Checking log for pattern...")
        try:
            assert env.storage_controller.log_contains(f".*at failpoint {failpoint_name}.*")
        except Exception:
            log.exception("Failed to find pattern in log")
            raise

    wait_until(hit_failpoint)
    assert not import_completion_signaled.is_set()

    # Restart the storage controller before the control plane is signalled.
    # This clears the failpoint, and we expect the import start-up reconciliation
    # to kick in and notify cplane.
    env.storage_controller.stop()
    env.storage_controller.start()

    def cplane_notified():
        assert import_completion_signaled.is_set()

    wait_until(cplane_notified)
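
# The test below exercises the full S3-based flow: a spec.json with a KMS-encrypted source
# connection string is uploaded to mock S3, fast_import builds a pgdata directory and writes
# its status files back to the bucket, and the pageserver ingests the result through an
# import_pgdata timeline. The imported data is then validated through an endpoint and via
# branches at the tip and at the initdb LSN.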


def test_fast_import_with_pageserver_ingest(
    test_output_dir,
    vanilla_pg: VanillaPostgres,
    port_distributor: PortDistributor,
    fast_import: FastImport,
    pg_distrib_dir: Path,
    pg_version: PgVersion,
    mock_s3_server: MockS3Server,
    mock_kms: KMSClient,
    mock_s3_client: S3Client,
    neon_env_builder: NeonEnvBuilder,
    make_httpserver: HTTPServer,
):
    # Prepare KMS and S3
    key_response = mock_kms.create_key(
        Description="Test key",
        KeyUsage="ENCRYPT_DECRYPT",
        Origin="AWS_KMS",
    )
    key_id = key_response["KeyMetadata"]["KeyId"]

    def encrypt(x: str) -> EncryptResponseTypeDef:
        return mock_kms.encrypt(KeyId=key_id, Plaintext=x)

    # Start source postgres and ingest data
    vanilla_pg.start()
    vanilla_pg.safe_psql("CREATE TABLE foo (a int); INSERT INTO foo SELECT generate_series(1, 10);")

    # Set up the pageserver and a fake cplane for import progress
    import_completion_signaled = Event()

    def handler(request: Request) -> Response:
        log.info(f"control plane /import_complete request: {request.json}")
        import_completion_signaled.set()
        return Response(json.dumps({}), status=200)

    cplane_mgmt_api_server = make_httpserver
    cplane_mgmt_api_server.expect_request(
        "/storage/api/v1/import_complete", method="PUT"
    ).respond_with_handler(handler)

    neon_env_builder.control_plane_hooks_api = (
        f"http://{cplane_mgmt_api_server.host}:{cplane_mgmt_api_server.port}/storage/api/v1/"
    )

    neon_env_builder.enable_pageserver_remote_storage(RemoteStorageKind.MOCK_S3)
    env = neon_env_builder.init_start()

    env.pageserver.patch_config_toml_nonrecursive(
        {
            # because the import_pgdata code uses this endpoint, not the one in the common remote storage config
            # TODO: maybe use the common remote_storage config in the pageserver?
            "import_pgdata_aws_endpoint_url": env.s3_mock_server.endpoint(),
        }
    )
    env.pageserver.stop()
    env.pageserver.start()

    # Encrypt connstrings and put the spec into S3
    source_connstring_encrypted = encrypt(vanilla_pg.connstr())
    spec = {
        "encryption_secret": {"KMS": {"key_id": key_id}},
        "source_connstring_ciphertext_base64": base64.b64encode(
            source_connstring_encrypted["CiphertextBlob"]
        ).decode("utf-8"),
        "project_id": "someproject",
        "branch_id": "somebranch",
    }

    bucket = "test-bucket"
    key_prefix = "test-prefix"
    mock_s3_client.create_bucket(Bucket=bucket)
    mock_s3_client.put_object(Bucket=bucket, Key=f"{key_prefix}/spec.json", Body=json.dumps(spec))

    # Create the timeline with import_pgdata
    tenant_id = TenantId.generate()
    env.storage_controller.tenant_create(tenant_id)

    timeline_id = TimelineId.generate()
    log.info("starting import")
    start = time.monotonic()

    idempotency = ImportPgdataIdemptencyKey.random()
    log.info(f"idempotency key {idempotency}")
    # TODO: teach the neon_local CLI about the idempotency key & 429 error so we can run inside the loop
    # and check for 429

    import_branch_name = "imported"
    env.storage_controller.timeline_create(
        tenant_id,
        {
            "new_timeline_id": str(timeline_id),
            "import_pgdata": {
                "idempotency_key": str(idempotency),
                "location": {
                    "AwsS3": {
                        "region": env.s3_mock_server.region(),
                        "bucket": bucket,
                        "key": key_prefix,
                    }
                },
            },
        },
    )
    env.neon_cli.mappings_map_branch(import_branch_name, tenant_id, timeline_id)

    # Run fast_import
    fast_import.set_aws_creds(mock_s3_server, {"RUST_LOG": "aws_config=debug,aws_sdk_kms=debug"})
    pg_port = port_distributor.get_port()
    fast_import.run_pgdata(pg_port=pg_port, s3prefix=f"s3://{bucket}/{key_prefix}")

    pgdata_status_obj = mock_s3_client.get_object(Bucket=bucket, Key=f"{key_prefix}/status/pgdata")
    pgdata_status = pgdata_status_obj["Body"].read().decode("utf-8")
    assert json.loads(pgdata_status) == {"done": True}, f"got status: {pgdata_status}"

    job_status_obj = mock_s3_client.get_object(
        Bucket=bucket, Key=f"{key_prefix}/status/fast_import"
    )
    job_status = job_status_obj["Body"].read().decode("utf-8")
    assert json.loads(job_status) == {
        "command": "pgdata",
        "done": True,
    }, f"got status: {job_status}"

    vanilla_pg.stop()

    def validate_vanilla_equivalence(ep):
        res = ep.safe_psql("SELECT count(*), sum(a) FROM foo;", dbname="neondb")
        assert res[0] == (10, 55), f"got result: {res}"

    # Sanity check that the data in pgdata is as expected:
    pgbin = PgBin(test_output_dir, fast_import.pg_distrib_dir, fast_import.pg_version)
    with VanillaPostgres(
        fast_import.workdir / "pgdata", pgbin, pg_port, False
    ) as new_pgdata_vanilla_pg:
        new_pgdata_vanilla_pg.start()

        # The database name and user are hardcoded in the fast_import binary, and they differ from normal vanilla postgres
        conn = PgProtocol(dsn=f"postgresql://cloud_admin@localhost:{pg_port}/neondb")
        validate_vanilla_equivalence(conn)

    def cplane_notified():
        assert import_completion_signaled.is_set()

    wait_until(cplane_notified, timeout=60)

    import_duration = time.monotonic() - start
    log.info(f"import complete; duration={import_duration:.2f}s")

    ep = env.endpoints.create_start(branch_name=import_branch_name, tenant_id=tenant_id)

    # check that the data is there
    validate_vanilla_equivalence(ep)

    # check that we can do basic ops

    ep.safe_psql("create table othertable(values text)", dbname="neondb")
    rw_lsn = Lsn(ep.safe_psql_scalar("select pg_current_wal_flush_lsn()"))
    ep.stop()

    # ... at the tip
    _ = env.create_branch(
        new_branch_name="br-tip",
        ancestor_branch_name=import_branch_name,
        tenant_id=tenant_id,
        ancestor_start_lsn=rw_lsn,
    )
    br_tip_endpoint = env.endpoints.create_start(
        branch_name="br-tip", endpoint_id="br-tip-ro", tenant_id=tenant_id
    )
    validate_vanilla_equivalence(br_tip_endpoint)
    br_tip_endpoint.safe_psql("select * from othertable", dbname="neondb")
    br_tip_endpoint.stop()

    # ... at the initdb lsn
    locations = env.storage_controller.locate(tenant_id)
    [shard_zero] = [
        loc for loc in locations if TenantShardId.parse(loc["shard_id"]).shard_number == 0
    ]
    shard_zero_ps = env.get_pageserver(shard_zero["node_id"])
    shard_zero_timeline_info = shard_zero_ps.http_client().timeline_detail(
        shard_zero["shard_id"], timeline_id
    )
    initdb_lsn = Lsn(shard_zero_timeline_info["initdb_lsn"])
    _ = env.create_branch(
        new_branch_name="br-initdb",
        ancestor_branch_name=import_branch_name,
        tenant_id=tenant_id,
        ancestor_start_lsn=initdb_lsn,
    )
    br_initdb_endpoint = env.endpoints.create_start(
        branch_name="br-initdb", endpoint_id="br-initdb-ro", tenant_id=tenant_id
    )
    validate_vanilla_equivalence(br_initdb_endpoint)
    with pytest.raises(psycopg2.errors.UndefinedTable):
        br_initdb_endpoint.safe_psql("select * from othertable", dbname="neondb")
    br_initdb_endpoint.stop()

    env.pageserver.stop(immediate=True)


def test_fast_import_binary(
    test_output_dir,
    vanilla_pg: VanillaPostgres,
    port_distributor: PortDistributor,
    fast_import: FastImport,
):
    vanilla_pg.start()
    vanilla_pg.safe_psql("CREATE TABLE foo (a int); INSERT INTO foo SELECT generate_series(1, 10);")

    pg_port = port_distributor.get_port()
    fast_import.run_pgdata(pg_port=pg_port, source_connection_string=vanilla_pg.connstr())
    vanilla_pg.stop()

    pgbin = PgBin(test_output_dir, fast_import.pg_distrib_dir, fast_import.pg_version)
    with VanillaPostgres(
        fast_import.workdir / "pgdata", pgbin, pg_port, False
    ) as new_pgdata_vanilla_pg:
        new_pgdata_vanilla_pg.start()

        # database name and user are hardcoded in fast_import binary, and they are different from normal vanilla postgres
        conn = PgProtocol(dsn=f"postgresql://cloud_admin@localhost:{pg_port}/neondb")
        res = conn.safe_psql("SELECT count(*) FROM foo;")
        log.info(f"Result: {res}")
        assert res[0][0] == 10
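
# Event triggers can only be restored by a superuser (see the PR description at the top of
# this page), so the import is expected to leave them behind. The test below creates an
# event trigger on the source database and asserts that pg_event_trigger is empty after
# fast_import has run.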


def test_fast_import_event_triggers(
    test_output_dir,
    vanilla_pg: VanillaPostgres,
    port_distributor: PortDistributor,
    fast_import: FastImport,
):
    vanilla_pg.start()
    vanilla_pg.safe_psql("""
    CREATE FUNCTION test_event_trigger_for_drops()
        RETURNS event_trigger LANGUAGE plpgsql AS $$
    DECLARE
        obj record;
    BEGIN
        FOR obj IN SELECT * FROM pg_event_trigger_dropped_objects()
        LOOP
            RAISE NOTICE '% dropped object: % %.% %',
                tg_tag,
                obj.object_type,
                obj.schema_name,
                obj.object_name,
                obj.object_identity;
        END LOOP;
    END
    $$;

    CREATE EVENT TRIGGER test_event_trigger_for_drops
        ON sql_drop
        EXECUTE PROCEDURE test_event_trigger_for_drops();
    """)

    pg_port = port_distributor.get_port()
    p = fast_import.run_pgdata(pg_port=pg_port, source_connection_string=vanilla_pg.connstr())
    assert p.returncode == 0

    vanilla_pg.stop()

    pgbin = PgBin(test_output_dir, fast_import.pg_distrib_dir, fast_import.pg_version)
    with VanillaPostgres(
        fast_import.workdir / "pgdata", pgbin, pg_port, False
    ) as new_pgdata_vanilla_pg:
        new_pgdata_vanilla_pg.start()

        # database name and user are hardcoded in fast_import binary, and they are different from normal vanilla postgres
        conn = PgProtocol(dsn=f"postgresql://cloud_admin@localhost:{pg_port}/neondb")
        res = conn.safe_psql("SELECT count(*) FROM pg_event_trigger;")
        log.info(f"Result: {res}")
        assert res[0][0] == 0, f"Neon does not support importing event triggers, got: {res[0][0]}"


def test_fast_import_restore_to_connstring(
    test_output_dir,
    vanilla_pg: VanillaPostgres,
    port_distributor: PortDistributor,
    fast_import: FastImport,
    pg_distrib_dir: Path,
    pg_version: PgVersion,
):
    vanilla_pg.start()
    vanilla_pg.safe_psql("CREATE TABLE foo (a int); INSERT INTO foo SELECT generate_series(1, 10);")

    pgdatadir = test_output_dir / "destination-pgdata"
    pg_bin = PgBin(test_output_dir, pg_distrib_dir, pg_version)
    port = port_distributor.get_port()
    with VanillaPostgres(pgdatadir, pg_bin, port) as destination_vanilla_pg:
        destination_vanilla_pg.configure(["shared_preload_libraries='neon_rmgr'"])
        destination_vanilla_pg.start()

        # create another database & role and try to restore there
        destination_vanilla_pg.safe_psql("""
        CREATE ROLE testrole WITH
            LOGIN
            PASSWORD 'testpassword'
            NOSUPERUSER
            NOCREATEDB
            NOCREATEROLE;
        """)
        destination_vanilla_pg.safe_psql("CREATE DATABASE testdb OWNER testrole;")

        destination_connstring = destination_vanilla_pg.connstr(
            dbname="testdb", user="testrole", password="testpassword"
        )
        fast_import.run_dump_restore(
            source_connection_string=vanilla_pg.connstr(),
            destination_connection_string=destination_connstring,
        )
        vanilla_pg.stop()
        conn = PgProtocol(dsn=destination_connstring)
        res = conn.safe_psql("SELECT count(*) FROM foo;")
        log.info(f"Result: {res}")
        assert res[0][0] == 10


def test_fast_import_restore_to_connstring_from_s3_spec(
    test_output_dir,
    vanilla_pg: VanillaPostgres,
    port_distributor: PortDistributor,
    fast_import: FastImport,
    pg_distrib_dir: Path,
    pg_version: PgVersion,
    mock_s3_server: MockS3Server,
    mock_kms: KMSClient,
    mock_s3_client: S3Client,
):
    # Prepare KMS and S3
    key_response = mock_kms.create_key(
        Description="Test key",
        KeyUsage="ENCRYPT_DECRYPT",
        Origin="AWS_KMS",
    )
    key_id = key_response["KeyMetadata"]["KeyId"]

    def encrypt(x: str) -> EncryptResponseTypeDef:
        return mock_kms.encrypt(KeyId=key_id, Plaintext=x)

    # Start source postgres and ingest data
    vanilla_pg.start()
    vanilla_pg.safe_psql("CREATE TABLE foo (a int); INSERT INTO foo SELECT generate_series(1, 10);")

    # Start target postgres
    pgdatadir = test_output_dir / "destination-pgdata"
    pg_bin = PgBin(test_output_dir, pg_distrib_dir, pg_version)
    port = port_distributor.get_port()
    with VanillaPostgres(pgdatadir, pg_bin, port) as destination_vanilla_pg:
        destination_vanilla_pg.configure(["shared_preload_libraries='neon_rmgr'"])
        destination_vanilla_pg.start()

        # Encrypt connstrings and put spec into S3
        source_connstring_encrypted = encrypt(vanilla_pg.connstr())
        destination_connstring_encrypted = encrypt(destination_vanilla_pg.connstr())
        spec = {
            "encryption_secret": {"KMS": {"key_id": key_id}},
            "source_connstring_ciphertext_base64": base64.b64encode(
                source_connstring_encrypted["CiphertextBlob"]
            ).decode("utf-8"),
            "destination_connstring_ciphertext_base64": base64.b64encode(
                destination_connstring_encrypted["CiphertextBlob"]
            ).decode("utf-8"),
        }

        bucket = "test-bucket"
        key_prefix = "test-prefix"
        mock_s3_client.create_bucket(Bucket=bucket)
        mock_s3_client.put_object(
            Bucket=bucket, Key=f"{key_prefix}/spec.json", Body=json.dumps(spec)
        )

        # Run fast_import
        fast_import.set_aws_creds(
            mock_s3_server, {"RUST_LOG": "aws_config=debug,aws_sdk_kms=debug"}
        )
        fast_import.run_dump_restore(s3prefix=f"s3://{bucket}/{key_prefix}")

        job_status_obj = mock_s3_client.get_object(
            Bucket=bucket, Key=f"{key_prefix}/status/fast_import"
        )
        job_status = job_status_obj["Body"].read().decode("utf-8")
        assert json.loads(job_status) == {
            "done": True,
            "command": "dump-restore",
        }, f"got status: {job_status}"
        vanilla_pg.stop()

        res = destination_vanilla_pg.safe_psql("SELECT count(*) FROM foo;")
        log.info(f"Result: {res}")
        assert res[0][0] == 10


def test_fast_import_restore_to_connstring_error_to_s3_bad_destination(
    test_output_dir,
    vanilla_pg: VanillaPostgres,
    port_distributor: PortDistributor,
    fast_import: FastImport,
    pg_distrib_dir: Path,
    pg_version: PgVersion,
    mock_s3_server: MockS3Server,
    mock_kms: KMSClient,
    mock_s3_client: S3Client,
):
    # Prepare KMS and S3
    key_response = mock_kms.create_key(
        Description="Test key",
        KeyUsage="ENCRYPT_DECRYPT",
        Origin="AWS_KMS",
    )
    key_id = key_response["KeyMetadata"]["KeyId"]

    def encrypt(x: str) -> EncryptResponseTypeDef:
        return mock_kms.encrypt(KeyId=key_id, Plaintext=x)

    # Start source postgres and ingest data
    vanilla_pg.start()
    vanilla_pg.safe_psql("CREATE TABLE foo (a int); INSERT INTO foo SELECT generate_series(1, 10);")

    # Encrypt connstrings and put spec into S3
    source_connstring_encrypted = encrypt(vanilla_pg.connstr())
    destination_connstring_encrypted = encrypt("postgres://random:connection@string:5432/neondb")
    spec = {
        "encryption_secret": {"KMS": {"key_id": key_id}},
        "source_connstring_ciphertext_base64": base64.b64encode(
            source_connstring_encrypted["CiphertextBlob"]
        ).decode("utf-8"),
        "destination_connstring_ciphertext_base64": base64.b64encode(
            destination_connstring_encrypted["CiphertextBlob"]
        ).decode("utf-8"),
    }

    bucket = "test-bucket"
    key_prefix = "test-prefix"
    mock_s3_client.create_bucket(Bucket=bucket)
    mock_s3_client.put_object(Bucket=bucket, Key=f"{key_prefix}/spec.json", Body=json.dumps(spec))

    # Run fast_import
    fast_import.set_aws_creds(mock_s3_server, {"RUST_LOG": "aws_config=debug,aws_sdk_kms=debug"})
    fast_import.run_dump_restore(s3prefix=f"s3://{bucket}/{key_prefix}")

    job_status_obj = mock_s3_client.get_object(
        Bucket=bucket, Key=f"{key_prefix}/status/fast_import"
    )
    job_status = job_status_obj["Body"].read().decode("utf-8")
    assert json.loads(job_status) == {
        "command": "dump-restore",
        "done": False,
        "error": "pg_restore failed",
    }, f"got status: {job_status}"
    vanilla_pg.stop()


def test_fast_import_restore_to_connstring_error_to_s3_kms_error(
    test_output_dir,
    port_distributor: PortDistributor,
    fast_import: FastImport,
    pg_distrib_dir: Path,
    pg_version: PgVersion,
    mock_s3_server: MockS3Server,
    mock_kms: KMSClient,
    mock_s3_client: S3Client,
):
    # Prepare KMS and S3
    key_response = mock_kms.create_key(
        Description="Test key",
        KeyUsage="ENCRYPT_DECRYPT",
        Origin="AWS_KMS",
    )
    key_id = key_response["KeyMetadata"]["KeyId"]

    def encrypt(x: str) -> EncryptResponseTypeDef:
        return mock_kms.encrypt(KeyId=key_id, Plaintext=x)

    # Put a spec with an invalid source connstring ciphertext into S3
    spec = {
        "encryption_secret": {"KMS": {"key_id": key_id}},
        "source_connstring_ciphertext_base64": base64.b64encode(b"invalid encrypted string").decode(
            "utf-8"
        ),
    }

    bucket = "test-bucket"
    key_prefix = "test-prefix"
    mock_s3_client.create_bucket(Bucket=bucket)
    mock_s3_client.put_object(Bucket=bucket, Key=f"{key_prefix}/spec.json", Body=json.dumps(spec))

    # Run fast_import
    fast_import.set_aws_creds(mock_s3_server, {"RUST_LOG": "aws_config=debug,aws_sdk_kms=debug"})
    fast_import.run_dump_restore(s3prefix=f"s3://{bucket}/{key_prefix}")

    job_status_obj = mock_s3_client.get_object(
        Bucket=bucket, Key=f"{key_prefix}/status/fast_import"
    )
    job_status = job_status_obj["Body"].read().decode("utf-8")
    assert json.loads(job_status) == {
        "command": "dump-restore",
        "done": False,
        "error": "decrypt source connection string",
    }, f"got status: {job_status}"