diff --git a/compute_tools/src/bin/fast_import.rs b/compute_tools/src/bin/fast_import.rs index 5b008f8182..c8440afb64 100644 --- a/compute_tools/src/bin/fast_import.rs +++ b/compute_tools/src/bin/fast_import.rs @@ -58,6 +58,8 @@ struct Args { pg_bin_dir: Utf8PathBuf, #[clap(long)] pg_lib_dir: Utf8PathBuf, + #[clap(long)] + pg_port: Option, // port to run postgres on, 5432 is default } #[serde_with::serde_as] @@ -74,6 +76,13 @@ enum EncryptionSecret { KMS { key_id: String }, } +// copied from pageserver_api::config::defaults::DEFAULT_LOCALE to avoid dependency just for a constant +const DEFAULT_LOCALE: &str = if cfg!(target_os = "macos") { + "C" +} else { + "C.UTF-8" +}; + #[tokio::main] pub(crate) async fn main() -> anyhow::Result<()> { utils::logging::init( @@ -97,6 +106,10 @@ pub(crate) async fn main() -> anyhow::Result<()> { let working_directory = args.working_directory; let pg_bin_dir = args.pg_bin_dir; let pg_lib_dir = args.pg_lib_dir; + let pg_port = args.pg_port.unwrap_or_else(|| { + info!("pg_port not specified, using default 5432"); + 5432 + }); // Initialize AWS clients only if s3_prefix is specified let (aws_config, kms_client) = if args.s3_prefix.is_some() { @@ -180,7 +193,7 @@ pub(crate) async fn main() -> anyhow::Result<()> { let superuser = "cloud_admin"; // XXX: this shouldn't be hard-coded postgres_initdb::do_run_initdb(postgres_initdb::RunInitdbArgs { superuser, - locale: "en_US.UTF-8", // XXX: this shouldn't be hard-coded, + locale: DEFAULT_LOCALE, // XXX: this shouldn't be hard-coded, pg_version, initdb_bin: pg_bin_dir.join("initdb").as_ref(), library_search_path: &pg_lib_dir, // TODO: is this right? Prob works in compute image, not sure about neon_local. @@ -197,6 +210,7 @@ pub(crate) async fn main() -> anyhow::Result<()> { let mut postgres_proc = tokio::process::Command::new(pgbin) .arg("-D") .arg(&pgdata_dir) + .args(["-p", &format!("{pg_port}")]) .args(["-c", "wal_level=minimal"]) .args(["-c", "shared_buffers=10GB"]) .args(["-c", "max_wal_senders=0"]) @@ -216,6 +230,7 @@ pub(crate) async fn main() -> anyhow::Result<()> { ), ]) .env_clear() + .env("LD_LIBRARY_PATH", &pg_lib_dir) .stdout(std::process::Stdio::piped()) .stderr(std::process::Stdio::piped()) .spawn() @@ -232,7 +247,7 @@ pub(crate) async fn main() -> anyhow::Result<()> { // Create neondb database in the running postgres let restore_pg_connstring = - format!("host=localhost port=5432 user={superuser} dbname=postgres"); + format!("host=localhost port={pg_port} user={superuser} dbname=postgres"); let start_time = std::time::Instant::now(); @@ -314,6 +329,7 @@ pub(crate) async fn main() -> anyhow::Result<()> { .arg(&source_connection_string) // how we run it .env_clear() + .env("LD_LIBRARY_PATH", &pg_lib_dir) .kill_on_drop(true) .stdout(std::process::Stdio::piped()) .stderr(std::process::Stdio::piped()) @@ -347,6 +363,7 @@ pub(crate) async fn main() -> anyhow::Result<()> { .arg(&dumpdir) // how we run it .env_clear() + .env("LD_LIBRARY_PATH", &pg_lib_dir) .kill_on_drop(true) .stdout(std::process::Stdio::piped()) .stderr(std::process::Stdio::piped()) diff --git a/test_runner/conftest.py b/test_runner/conftest.py index 9e32469d69..4b591d3316 100644 --- a/test_runner/conftest.py +++ b/test_runner/conftest.py @@ -15,4 +15,5 @@ pytest_plugins = ( "fixtures.compare_fixtures", "fixtures.slow", "fixtures.reruns", + "fixtures.fast_import", ) diff --git a/test_runner/fixtures/fast_import.py b/test_runner/fixtures/fast_import.py new file mode 100644 index 0000000000..33248132ab --- /dev/null +++ b/test_runner/fixtures/fast_import.py @@ -0,0 +1,104 @@ +import os +import shutil +import subprocess +import tempfile +from collections.abc import Iterator +from pathlib import Path + +import pytest + +from fixtures.log_helper import log +from fixtures.neon_cli import AbstractNeonCli +from fixtures.pg_version import PgVersion + + +class FastImport(AbstractNeonCli): + COMMAND = "fast_import" + cmd: subprocess.CompletedProcess[str] | None = None + + def __init__( + self, + extra_env: dict[str, str] | None, + binpath: Path, + pg_distrib_dir: Path, + pg_version: PgVersion, + workdir: Path, + ): + if extra_env is None: + env_vars = {} + else: + env_vars = extra_env.copy() + + if not (binpath / self.COMMAND).exists(): + raise Exception(f"{self.COMMAND} binary not found at '{binpath}'") + super().__init__(env_vars, binpath) + + pg_dir = pg_distrib_dir / pg_version.v_prefixed + self.pg_distrib_dir = pg_distrib_dir + self.pg_version = pg_version + self.pg_bin = pg_dir / "bin" + if not (self.pg_bin / "postgres").exists(): + raise Exception(f"postgres binary was not found at '{self.pg_bin}'") + self.pg_lib = pg_dir / "lib" + if env_vars.get("LD_LIBRARY_PATH") is not None: + self.pg_lib = Path(env_vars["LD_LIBRARY_PATH"]) + elif os.getenv("LD_LIBRARY_PATH") is not None: + self.pg_lib = Path(str(os.getenv("LD_LIBRARY_PATH"))) + if not workdir.exists(): + raise Exception(f"Working directory '{workdir}' does not exist") + self.workdir = workdir + + def run( + self, + pg_port: int, + source_connection_string: str | None = None, + s3prefix: str | None = None, + interactive: bool = False, + ) -> subprocess.CompletedProcess[str]: + if self.cmd is not None: + raise Exception("Command already executed") + args = [ + f"--pg-bin-dir={self.pg_bin}", + f"--pg-lib-dir={self.pg_lib}", + f"--pg-port={pg_port}", + f"--working-directory={self.workdir}", + ] + if source_connection_string is not None: + args.append(f"--source-connection-string={source_connection_string}") + if s3prefix is not None: + args.append(f"--s3-prefix={s3prefix}") + if interactive: + args.append("--interactive") + + self.cmd = self.raw_cli(args) + return self.cmd + + def __enter__(self): + return self + + def __exit__(self, *args): + if self.workdir.exists(): + shutil.rmtree(self.workdir) + + +@pytest.fixture(scope="function") +def fast_import( + pg_version: PgVersion, + test_output_dir: Path, + neon_binpath: Path, + pg_distrib_dir: Path, +) -> Iterator[FastImport]: + workdir = Path(tempfile.mkdtemp()) + with FastImport(None, neon_binpath, pg_distrib_dir, pg_version, workdir) as fi: + yield fi + + if fi.cmd is None: + return + + # dump stdout & stderr into test log dir + with open(test_output_dir / "fast_import.stdout", "w") as f: + f.write(fi.cmd.stdout) + with open(test_output_dir / "fast_import.stderr", "w") as f: + f.write(fi.cmd.stderr) + + log.info("Written logs to %s", test_output_dir) diff --git a/test_runner/regress/test_import_pgdata.py b/test_runner/regress/test_import_pgdata.py index 6ea2393a9d..d02a9d19db 100644 --- a/test_runner/regress/test_import_pgdata.py +++ b/test_runner/regress/test_import_pgdata.py @@ -7,13 +7,15 @@ import psycopg2 import psycopg2.errors import pytest from fixtures.common_types import Lsn, TenantId, TenantShardId, TimelineId +from fixtures.fast_import import FastImport from fixtures.log_helper import log -from fixtures.neon_fixtures import NeonEnvBuilder, VanillaPostgres +from fixtures.neon_fixtures import NeonEnvBuilder, PgBin, PgProtocol, VanillaPostgres from fixtures.pageserver.http import ( ImportPgdataIdemptencyKey, PageserverApiException, ) from fixtures.pg_version import PgVersion +from fixtures.port_distributor import PortDistributor from fixtures.remote_storage import RemoteStorageKind from fixtures.utils import run_only_on_postgres from pytest_httpserver import HTTPServer @@ -313,3 +315,41 @@ def test_pgdata_import_smoke( validate_vanilla_equivalence(br_initdb_endpoint) with pytest.raises(psycopg2.errors.UndefinedTable): br_initdb_endpoint.safe_psql("select * from othertable") + + +@run_only_on_postgres( + [PgVersion.V14, PgVersion.V15, PgVersion.V16], + "newer control file catalog version and struct format isn't supported", +) +def test_fast_import_binary( + test_output_dir, + vanilla_pg: VanillaPostgres, + port_distributor: PortDistributor, + fast_import: FastImport, +): + vanilla_pg.start() + vanilla_pg.safe_psql("CREATE TABLE foo (a int); INSERT INTO foo SELECT generate_series(1, 10);") + + pg_port = port_distributor.get_port() + fast_import.run(pg_port, vanilla_pg.connstr()) + vanilla_pg.stop() + + pgbin = PgBin(test_output_dir, fast_import.pg_distrib_dir, fast_import.pg_version) + with VanillaPostgres( + fast_import.workdir / "pgdata", pgbin, pg_port, False + ) as new_pgdata_vanilla_pg: + new_pgdata_vanilla_pg.start() + + # database name and user are hardcoded in fast_import binary, and they are different from normal vanilla postgres + conn = PgProtocol(dsn=f"postgresql://cloud_admin@localhost:{pg_port}/neondb") + res = conn.safe_psql("SELECT count(*) FROM foo;") + log.info(f"Result: {res}") + assert res[0][0] == 10 + + +# TODO: Maybe test with pageserver? +# 1. run whole neon env +# 2. create timeline with some s3 path??? +# 3. run fast_import with s3 prefix +# 4. ??? mock http where pageserver will report progress +# 5. run compute on this timeline and check if data is there