fast imports: initial Importer and Storage changes (#9218)

Co-authored-by: Heikki Linnakangas <heikki@neon.tech>
Co-authored-by: Stas Kelvic <stas@neon.tech>

# Context

This PR contains PoC-level changes for a product feature that allows
onboarding large databases into Neon without going through the regular
data path.

# Changes

This internal RFC provides all the context
* https://github.com/neondatabase/cloud/pull/19799

In the language of the RFC, this PR covers

* the Importer code (`fast_import`) 
* all the Pageserver changes (mgmt API changes, flow implementation,
etc)
* a basic test for the Pageserver changes

# Reviewing

As acknowledged in the RFC, the code added in this PR is not ready for
general availability.
Also, the **architecture is not to be discussed in this PR**, but in the
RFC and associated Slack channel instead.

Reviewers of this PR should take that into consideration.
The quality bar to apply during review depends on what area of the code
is being reviewed:

* Importer code (`fast_import`): practically anything goes
* Core flow (`flow.rs`):
* Malicious input data must be expected and the existing threat models
apply.
* The code must not be safe to execute on *dedicated* Pageserver
instances:
* This means in particular that tenants *on other* Pageserver instances
must not be affected negatively wrt data confidentiality, integrity or
availability.
* Other code: the usual quality bar
* Pay special attention to correct use of gate guards, timeline
cancellation in all places during shutdown & migration, etc.
* Consider the broader system impact; if you find potentially
problematic interactions with Storage features that were not covered in
the RFC, bring that up during the review.

I recommend submitting three separate reviews, for the three high-level
areas with different quality bars.


# References

(Internal-only)

* refs https://github.com/neondatabase/cloud/issues/17507
* refs https://github.com/neondatabase/company_projects/issues/293
* refs https://github.com/neondatabase/company_projects/issues/309
* refs https://github.com/neondatabase/cloud/issues/20646

---------

Co-authored-by: Stas Kelvich <stas.kelvich@gmail.com>
Co-authored-by: Heikki Linnakangas <heikki@neon.tech>
Co-authored-by: John Spray <john@neon.tech>
This commit is contained in:
Christian Schwarz
2024-11-22 23:47:06 +01:00
committed by GitHub
parent 3245f7b88d
commit 450be26bbb
39 changed files with 3368 additions and 182 deletions

View File

@@ -190,6 +190,25 @@ class TenantTimelineId:
)
@dataclass
class ShardIndex:
shard_number: int
shard_count: int
# cf impl Display for ShardIndex
@override
def __str__(self) -> str:
return f"{self.shard_number:02x}{self.shard_count:02x}"
@classmethod
def parse(cls: type[ShardIndex], input: str) -> ShardIndex:
assert len(input) == 4
return cls(
shard_number=int(input[0:2], 16),
shard_count=int(input[2:4], 16),
)
class TenantShardId:
def __init__(self, tenant_id: TenantId, shard_number: int, shard_count: int):
self.tenant_id = tenant_id
@@ -222,6 +241,10 @@ class TenantShardId:
# Unsharded case: equivalent of Rust TenantShardId::unsharded(tenant_id)
return str(self.tenant_id)
@property
def shard_index(self) -> ShardIndex:
return ShardIndex(self.shard_number, self.shard_count)
@override
def __repr__(self):
return self.__str__()

View File

@@ -1883,6 +1883,20 @@ class NeonStorageController(MetricsGetter, LogUtils):
response.raise_for_status()
log.info(f"tenant_create success: {response.json()}")
def timeline_create(
self,
tenant_id: TenantId,
body: dict[str, Any],
):
response = self.request(
"POST",
f"{self.api}/v1/tenant/{tenant_id}/timeline",
json=body,
headers=self.headers(TokenScope.PAGE_SERVER_API),
)
response.raise_for_status()
log.info(f"timeline_create success: {response.json()}")
def locate(self, tenant_id: TenantId) -> list[dict[str, Any]]:
"""
:return: list of {"shard_id": "", "node_id": int, "listen_pg_addr": str, "listen_pg_port": int, "listen_http_addr": str, "listen_http_port": int}

View File

@@ -1,5 +1,9 @@
from __future__ import annotations
import dataclasses
import json
import random
import string
import time
from collections import defaultdict
from dataclasses import dataclass
@@ -10,7 +14,14 @@ import requests
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry
from fixtures.common_types import Lsn, TenantId, TenantShardId, TimelineArchivalState, TimelineId
from fixtures.common_types import (
Id,
Lsn,
TenantId,
TenantShardId,
TimelineArchivalState,
TimelineId,
)
from fixtures.log_helper import log
from fixtures.metrics import Metrics, MetricsGetter, parse_metrics
from fixtures.pg_version import PgVersion
@@ -24,6 +35,69 @@ class PageserverApiException(Exception):
self.status_code = status_code
@dataclass
class ImportPgdataIdemptencyKey:
key: str
@staticmethod
def random() -> ImportPgdataIdemptencyKey:
return ImportPgdataIdemptencyKey(
"".join(random.choices(string.ascii_letters + string.digits, k=20))
)
@dataclass
class LocalFs:
path: str
@dataclass
class AwsS3:
region: str
bucket: str
key: str
@dataclass
class ImportPgdataLocation:
LocalFs: None | LocalFs = None
AwsS3: None | AwsS3 = None
@dataclass
class TimelineCreateRequestModeImportPgdata:
location: ImportPgdataLocation
idempotency_key: ImportPgdataIdemptencyKey
@dataclass
class TimelineCreateRequestMode:
Branch: None | dict[str, Any] = None
Bootstrap: None | dict[str, Any] = None
ImportPgdata: None | TimelineCreateRequestModeImportPgdata = None
@dataclass
class TimelineCreateRequest:
new_timeline_id: TimelineId
mode: TimelineCreateRequestMode
def to_json(self) -> str:
class EnhancedJSONEncoder(json.JSONEncoder):
def default(self, o):
if dataclasses.is_dataclass(o) and not isinstance(o, type):
return dataclasses.asdict(o)
elif isinstance(o, Id):
return o.id.hex()
return super().default(o)
# mode is flattened
this = dataclasses.asdict(self)
mode = this.pop("mode")
this.update(mode)
return json.dumps(self, cls=EnhancedJSONEncoder)
class TimelineCreate406(PageserverApiException):
def __init__(self, res: requests.Response):
assert res.status_code == 406

View File

@@ -674,6 +674,13 @@ def run_only_on_default_postgres(reason: str):
)
def run_only_on_postgres(versions: Iterable[PgVersion], reason: str):
return pytest.mark.skipif(
PgVersion(os.getenv("DEFAULT_PG_VERSION", PgVersion.DEFAULT)) not in versions,
reason=reason,
)
def skip_in_debug_build(reason: str):
return pytest.mark.skipif(
os.getenv("BUILD_TYPE", "debug") == "debug",

View File

@@ -0,0 +1,307 @@
import json
import re
import time
from enum import Enum
import psycopg2
import psycopg2.errors
import pytest
from fixtures.common_types import Lsn, TenantId, TenantShardId, TimelineId
from fixtures.log_helper import log
from fixtures.neon_fixtures import NeonEnvBuilder, VanillaPostgres
from fixtures.pageserver.http import (
ImportPgdataIdemptencyKey,
PageserverApiException,
)
from fixtures.pg_version import PgVersion
from fixtures.remote_storage import RemoteStorageKind
from fixtures.utils import run_only_on_postgres
from pytest_httpserver import HTTPServer
from werkzeug.wrappers.request import Request
from werkzeug.wrappers.response import Response
num_rows = 1000
class RelBlockSize(Enum):
ONE_STRIPE_SIZE = 1
TWO_STRPES_PER_SHARD = 2
MULTIPLE_RELATION_SEGMENTS = 3
smoke_params = [
# unsharded (the stripe size needs to be given for rel block size calculations)
*[(None, 1024, s) for s in RelBlockSize],
# many shards, small stripe size to speed up test
*[(8, 1024, s) for s in RelBlockSize],
]
@run_only_on_postgres(
[PgVersion.V14, PgVersion.V15, PgVersion.V16],
"newer control file catalog version and struct format isn't supported",
)
@pytest.mark.parametrize("shard_count,stripe_size,rel_block_size", smoke_params)
def test_pgdata_import_smoke(
vanilla_pg: VanillaPostgres,
neon_env_builder: NeonEnvBuilder,
shard_count: int | None,
stripe_size: int,
rel_block_size: RelBlockSize,
make_httpserver: HTTPServer,
):
#
# Setup fake control plane for import progress
#
def handler(request: Request) -> Response:
log.info(f"control plane request: {request.json}")
return Response(json.dumps({}), status=200)
cplane_mgmt_api_server = make_httpserver
cplane_mgmt_api_server.expect_request(re.compile(".*")).respond_with_handler(handler)
neon_env_builder.enable_pageserver_remote_storage(RemoteStorageKind.LOCAL_FS)
env = neon_env_builder.init_start()
env.pageserver.patch_config_toml_nonrecursive(
{
"import_pgdata_upcall_api": f"http://{cplane_mgmt_api_server.host}:{cplane_mgmt_api_server.port}/path/to/mgmt/api"
}
)
env.pageserver.stop()
env.pageserver.start()
#
# Put data in vanilla pg
#
vanilla_pg.start()
vanilla_pg.safe_psql("create user cloud_admin with password 'postgres' superuser")
log.info("create relblock data")
if rel_block_size == RelBlockSize.ONE_STRIPE_SIZE:
target_relblock_size = stripe_size * 8192
elif rel_block_size == RelBlockSize.TWO_STRPES_PER_SHARD:
target_relblock_size = (shard_count or 1) * stripe_size * 8192 * 2
elif rel_block_size == RelBlockSize.MULTIPLE_RELATION_SEGMENTS:
target_relblock_size = int(((2.333 * 1024 * 1024 * 1024) // 8192) * 8192)
else:
raise ValueError
# fillfactor so we don't need to produce that much data
# 900 byte per row is > 10% => 1 row per page
vanilla_pg.safe_psql("""create table t (data char(900)) with (fillfactor = 10)""")
nrows = 0
while True:
relblock_size = vanilla_pg.safe_psql_scalar("select pg_relation_size('t')")
log.info(
f"relblock size: {relblock_size/8192} pages (target: {target_relblock_size//8192}) pages"
)
if relblock_size >= target_relblock_size:
break
addrows = int((target_relblock_size - relblock_size) // 8192)
assert addrows >= 1, "forward progress"
vanilla_pg.safe_psql(f"insert into t select generate_series({nrows+1}, {nrows + addrows})")
nrows += addrows
expect_nrows = nrows
expect_sum = (
(nrows) * (nrows + 1) // 2
) # https://stackoverflow.com/questions/43901484/sum-of-the-integers-from-1-to-n
def validate_vanilla_equivalence(ep):
# TODO: would be nicer to just compare pgdump
assert ep.safe_psql("select count(*), sum(data::bigint)::bigint from t") == [
(expect_nrows, expect_sum)
]
validate_vanilla_equivalence(vanilla_pg)
vanilla_pg.stop()
#
# We have a Postgres data directory now.
# Make a localfs remote storage that looks like how after `fast_import` ran.
# TODO: actually exercise fast_import here
# TODO: test s3 remote storage
#
importbucket = neon_env_builder.repo_dir / "importbucket"
importbucket.mkdir()
# what cplane writes before scheduling fast_import
specpath = importbucket / "spec.json"
specpath.write_text(json.dumps({"branch_id": "somebranch", "project_id": "someproject"}))
# what fast_import writes
vanilla_pg.pgdatadir.rename(importbucket / "pgdata")
statusdir = importbucket / "status"
statusdir.mkdir()
(statusdir / "pgdata").write_text(json.dumps({"done": True}))
#
# Do the import
#
tenant_id = TenantId.generate()
env.storage_controller.tenant_create(
tenant_id, shard_count=shard_count, shard_stripe_size=stripe_size
)
timeline_id = TimelineId.generate()
log.info("starting import")
start = time.monotonic()
idempotency = ImportPgdataIdemptencyKey.random()
log.info(f"idempotency key {idempotency}")
# TODO: teach neon_local CLI about the idempotency & 429 error so we can run inside the loop
# and check for 429
import_branch_name = "imported"
env.storage_controller.timeline_create(
tenant_id,
{
"new_timeline_id": str(timeline_id),
"import_pgdata": {
"idempotency_key": str(idempotency),
"location": {"LocalFs": {"path": str(importbucket.absolute())}},
},
},
)
env.neon_cli.mappings_map_branch(import_branch_name, tenant_id, timeline_id)
while True:
locations = env.storage_controller.locate(tenant_id)
active_count = 0
for location in locations:
shard_id = TenantShardId.parse(location["shard_id"])
ps = env.get_pageserver(location["node_id"])
try:
detail = ps.http_client().timeline_detail(shard_id, timeline_id)
state = detail["state"]
log.info(f"shard {shard_id} state: {state}")
if state == "Active":
active_count += 1
except PageserverApiException as e:
if e.status_code == 404:
log.info("not found, import is in progress")
continue
elif e.status_code == 429:
log.info("import is in progress")
continue
else:
raise
shard_status_file = statusdir / f"shard-{shard_id.shard_index}"
if state == "Active":
shard_status_file_contents = (
shard_status_file.read_text()
) # Active state implies import is done
shard_status = json.loads(shard_status_file_contents)
assert shard_status["done"] is True
if active_count == len(locations):
log.info("all shards are active")
break
time.sleep(1)
import_duration = time.monotonic() - start
log.info(f"import complete; duration={import_duration:.2f}s")
#
# Get some timeline details for later.
#
locations = env.storage_controller.locate(tenant_id)
[shard_zero] = [
loc for loc in locations if TenantShardId.parse(loc["shard_id"]).shard_number == 0
]
shard_zero_ps = env.get_pageserver(shard_zero["node_id"])
shard_zero_http = shard_zero_ps.http_client()
shard_zero_timeline_info = shard_zero_http.timeline_detail(shard_zero["shard_id"], timeline_id)
initdb_lsn = Lsn(shard_zero_timeline_info["initdb_lsn"])
latest_gc_cutoff_lsn = Lsn(shard_zero_timeline_info["latest_gc_cutoff_lsn"])
last_record_lsn = Lsn(shard_zero_timeline_info["last_record_lsn"])
disk_consistent_lsn = Lsn(shard_zero_timeline_info["disk_consistent_lsn"])
_remote_consistent_lsn = Lsn(shard_zero_timeline_info["remote_consistent_lsn"])
remote_consistent_lsn_visible = Lsn(shard_zero_timeline_info["remote_consistent_lsn_visible"])
# assert remote_consistent_lsn_visible == remote_consistent_lsn TODO: this fails initially and after restart, presumably because `UploadQueue::clean.1` is still `None`
assert remote_consistent_lsn_visible == disk_consistent_lsn
assert initdb_lsn == latest_gc_cutoff_lsn
assert disk_consistent_lsn == initdb_lsn + 8
assert last_record_lsn == disk_consistent_lsn
# TODO: assert these values are the same everywhere
#
# Validate the resulting remote storage state.
#
#
# Validate the imported data
#
ro_endpoint = env.endpoints.create_start(
branch_name=import_branch_name, endpoint_id="ro", tenant_id=tenant_id, lsn=last_record_lsn
)
validate_vanilla_equivalence(ro_endpoint)
# ensure the import survives restarts
ro_endpoint.stop()
env.pageserver.stop(immediate=True)
env.pageserver.start()
ro_endpoint.start()
validate_vanilla_equivalence(ro_endpoint)
#
# validate the layer files in each shard only have the shard-specific data
# (the implementation would be functional but not efficient without this characteristic)
#
shards = env.storage_controller.locate(tenant_id)
for shard in shards:
shard_ps = env.get_pageserver(shard["node_id"])
result = shard_ps.timeline_scan_no_disposable_keys(shard["shard_id"], timeline_id)
assert result.tally.disposable_count == 0
assert (
result.tally.not_disposable_count > 0
), "sanity check, each shard should have some data"
#
# validate that we can write
#
rw_endpoint = env.endpoints.create_start(
branch_name=import_branch_name, endpoint_id="rw", tenant_id=tenant_id
)
rw_endpoint.safe_psql("create table othertable(values text)")
rw_lsn = Lsn(rw_endpoint.safe_psql_scalar("select pg_current_wal_flush_lsn()"))
# TODO: consider using `class Workload` here
# to do compaction and whatnot?
#
# validate that we can branch (important use case)
#
# ... at the tip
_ = env.create_branch(
new_branch_name="br-tip",
ancestor_branch_name=import_branch_name,
tenant_id=tenant_id,
ancestor_start_lsn=rw_lsn,
)
br_tip_endpoint = env.endpoints.create_start(
branch_name="br-tip", endpoint_id="br-tip-ro", tenant_id=tenant_id
)
validate_vanilla_equivalence(br_tip_endpoint)
br_tip_endpoint.safe_psql("select * from othertable")
# ... at the initdb lsn
_ = env.create_branch(
new_branch_name="br-initdb",
ancestor_branch_name=import_branch_name,
tenant_id=tenant_id,
ancestor_start_lsn=initdb_lsn,
)
br_initdb_endpoint = env.endpoints.create_start(
branch_name="br-initdb", endpoint_id="br-initdb-ro", tenant_id=tenant_id
)
validate_vanilla_equivalence(br_initdb_endpoint)
with pytest.raises(psycopg2.errors.UndefinedTable):
br_initdb_endpoint.safe_psql("select * from othertable")