From 3d7a32f6196e87b00491fcdc4887ec9ed1bd1640 Mon Sep 17 00:00:00 2001 From: Gleb Novikov Date: Fri, 14 Feb 2025 16:10:06 +0000 Subject: [PATCH] fast import: allow restore to provided connection string (#10407) Within https://github.com/neondatabase/cloud/issues/22089 we decided that would be nice to start with import that runs dump-restore into a running compute (more on this [here](https://www.notion.so/neondatabase/2024-Jan-13-Migration-Assistant-Next-Steps-Proposal-Revised-17af189e004780228bdbcad13eeda93f?pvs=4#17af189e004780de816ccd9c13afd953)) We could do it by writing another tool or by extending existing `fast_import.rs`, we chose the latter. In this PR, I have added optional `restore_connection_string` as a cli arg and as a part of the json spec. If specified, the script will not run postgres and will just perform restore into provided connection string. TODO: - [x] fast_import.rs: - [x] cli arg in the fast_import.rs - [x] encoded connstring in json spec - [x] simplify `fn main` a little, take out too verbose stuff to some functions - [ ] ~~allow streaming from dump stdout to restore stdin~~ will do in a separate PR - [ ] ~~address https://github.com/neondatabase/neon/pull/10251#pullrequestreview-2551877845~~ will do in a separate PR - [x] tests: - [x] restore with cli arg in the fast_import.rs - [x] restore with encoded connstring in json spec in s3 - [ ] ~~test with custom dbname~~ will do in a separate PR - [ ] ~~test with s3 + pageserver + fast import binary~~ https://github.com/neondatabase/neon/pull/10487 - [ ] ~~https://github.com/neondatabase/neon/pull/10271#discussion_r1923715493~~ will do in a separate PR neondatabase/cloud#22775 --------- Co-authored-by: Eduard Dykman --- compute_tools/src/bin/fast_import.rs | 656 ++++++++++++++-------- poetry.lock | 15 +- pyproject.toml | 2 +- test_runner/fixtures/fast_import.py | 62 +- test_runner/fixtures/neon_fixtures.py | 27 + test_runner/regress/test_import_pgdata.py | 346 +++++++++++- 6 files changed, 866 insertions(+), 242 deletions(-) diff --git a/compute_tools/src/bin/fast_import.rs b/compute_tools/src/bin/fast_import.rs index 27cf1c2317..dad15d67b7 100644 --- a/compute_tools/src/bin/fast_import.rs +++ b/compute_tools/src/bin/fast_import.rs @@ -25,10 +25,10 @@ //! docker push localhost:3030/localregistry/compute-node-v14:latest //! ``` -use anyhow::Context; +use anyhow::{bail, Context}; use aws_config::BehaviorVersion; use camino::{Utf8Path, Utf8PathBuf}; -use clap::Parser; +use clap::{Parser, Subcommand}; use compute_tools::extension_server::{get_pg_version, PostgresMajorVersion}; use nix::unistd::Pid; use tracing::{error, info, info_span, warn, Instrument}; @@ -44,32 +44,59 @@ mod s3_uri; const PG_WAIT_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(600); const PG_WAIT_RETRY_INTERVAL: std::time::Duration = std::time::Duration::from_millis(300); +#[derive(Subcommand, Debug)] +enum Command { + /// Runs local postgres (neon binary), restores into it, + /// uploads pgdata to s3 to be consumed by pageservers + Pgdata { + /// Raw connection string to the source database. Used only in tests, + /// real scenario uses encrypted connection string in spec.json from s3. + #[clap(long)] + source_connection_string: Option, + /// If specified, will not shut down the local postgres after the import. Used in local testing + #[clap(short, long)] + interactive: bool, + /// Port to run postgres on. Default is 5432. + #[clap(long, default_value_t = 5432)] + pg_port: u16, // port to run postgres on, 5432 is default + + /// Number of CPUs in the system. This is used to configure # of + /// parallel worker processes, for index creation. + #[clap(long, env = "NEON_IMPORTER_NUM_CPUS")] + num_cpus: Option, + + /// Amount of RAM in the system. This is used to configure shared_buffers + /// and maintenance_work_mem. + #[clap(long, env = "NEON_IMPORTER_MEMORY_MB")] + memory_mb: Option, + }, + + /// Runs pg_dump-pg_restore from source to destination without running local postgres. + DumpRestore { + /// Raw connection string to the source database. Used only in tests, + /// real scenario uses encrypted connection string in spec.json from s3. + #[clap(long)] + source_connection_string: Option, + /// Raw connection string to the destination database. Used only in tests, + /// real scenario uses encrypted connection string in spec.json from s3. + #[clap(long)] + destination_connection_string: Option, + }, +} + #[derive(clap::Parser)] struct Args { - #[clap(long)] + #[clap(long, env = "NEON_IMPORTER_WORKDIR")] working_directory: Utf8PathBuf, #[clap(long, env = "NEON_IMPORTER_S3_PREFIX")] s3_prefix: Option, - #[clap(long)] - source_connection_string: Option, - #[clap(short, long)] - interactive: bool, - #[clap(long)] + #[clap(long, env = "NEON_IMPORTER_PG_BIN_DIR")] pg_bin_dir: Utf8PathBuf, - #[clap(long)] + #[clap(long, env = "NEON_IMPORTER_PG_LIB_DIR")] pg_lib_dir: Utf8PathBuf, - #[clap(long)] - pg_port: Option, // port to run postgres on, 5432 is default - /// Number of CPUs in the system. This is used to configure # of - /// parallel worker processes, for index creation. - #[clap(long, env = "NEON_IMPORTER_NUM_CPUS")] - num_cpus: Option, - - /// Amount of RAM in the system. This is used to configure shared_buffers - /// and maintenance_work_mem. - #[clap(long, env = "NEON_IMPORTER_MEMORY_MB")] - memory_mb: Option, + #[clap(subcommand)] + command: Command, } #[serde_with::serde_as] @@ -78,6 +105,8 @@ struct Spec { encryption_secret: EncryptionSecret, #[serde_as(as = "serde_with::base64::Base64")] source_connstring_ciphertext_base64: Vec, + #[serde_as(as = "Option")] + destination_connstring_ciphertext_base64: Option>, } #[derive(serde::Deserialize)] @@ -93,192 +122,151 @@ const DEFAULT_LOCALE: &str = if cfg!(target_os = "macos") { "C.UTF-8" }; -#[tokio::main] -pub(crate) async fn main() -> anyhow::Result<()> { - utils::logging::init( - utils::logging::LogFormat::Plain, - utils::logging::TracingErrorLayerEnablement::EnableWithRustLogFilter, - utils::logging::Output::Stdout, - )?; - - info!("starting"); - - let args = Args::parse(); - - // Validate arguments - if args.s3_prefix.is_none() && args.source_connection_string.is_none() { - anyhow::bail!("either s3_prefix or source_connection_string must be specified"); - } - if args.s3_prefix.is_some() && args.source_connection_string.is_some() { - anyhow::bail!("only one of s3_prefix or source_connection_string can be specified"); - } - - let working_directory = args.working_directory; - let pg_bin_dir = args.pg_bin_dir; - let pg_lib_dir = args.pg_lib_dir; - let pg_port = args.pg_port.unwrap_or_else(|| { - info!("pg_port not specified, using default 5432"); - 5432 - }); - - // Initialize AWS clients only if s3_prefix is specified - let (aws_config, kms_client) = if args.s3_prefix.is_some() { - let config = aws_config::load_defaults(BehaviorVersion::v2024_03_28()).await; - let kms = aws_sdk_kms::Client::new(&config); - (Some(config), Some(kms)) - } else { - (None, None) - }; - - // Get source connection string either from S3 spec or direct argument - let source_connection_string = if let Some(s3_prefix) = &args.s3_prefix { - let spec: Spec = { - let spec_key = s3_prefix.append("/spec.json"); - let s3_client = aws_sdk_s3::Client::new(aws_config.as_ref().unwrap()); - let object = s3_client - .get_object() - .bucket(&spec_key.bucket) - .key(spec_key.key) - .send() - .await - .context("get spec from s3")? - .body - .collect() - .await - .context("download spec body")?; - serde_json::from_slice(&object.into_bytes()).context("parse spec as json")? - }; - - match spec.encryption_secret { - EncryptionSecret::KMS { key_id } => { - let mut output = kms_client - .unwrap() - .decrypt() - .key_id(key_id) - .ciphertext_blob(aws_sdk_s3::primitives::Blob::new( - spec.source_connstring_ciphertext_base64, - )) - .send() - .await - .context("decrypt source connection string")?; - let plaintext = output - .plaintext - .take() - .context("get plaintext source connection string")?; - String::from_utf8(plaintext.into_inner()) - .context("parse source connection string as utf8")? - } - } - } else { - args.source_connection_string.unwrap() - }; - - match tokio::fs::create_dir(&working_directory).await { - Ok(()) => {} - Err(e) if e.kind() == std::io::ErrorKind::AlreadyExists => { - if !is_directory_empty(&working_directory) - .await - .context("check if working directory is empty")? - { - anyhow::bail!("working directory is not empty"); - } else { - // ok - } - } - Err(e) => return Err(anyhow::Error::new(e).context("create working directory")), - } - - let pgdata_dir = working_directory.join("pgdata"); - tokio::fs::create_dir(&pgdata_dir) +async fn decode_connstring( + kms_client: &aws_sdk_kms::Client, + key_id: &String, + connstring_ciphertext_base64: Vec, +) -> Result { + let mut output = kms_client + .decrypt() + .key_id(key_id) + .ciphertext_blob(aws_sdk_s3::primitives::Blob::new( + connstring_ciphertext_base64, + )) + .send() .await - .context("create pgdata directory")?; + .context("decrypt connection string")?; - let pgbin = pg_bin_dir.join("postgres"); - let pg_version = match get_pg_version(pgbin.as_ref()) { - PostgresMajorVersion::V14 => 14, - PostgresMajorVersion::V15 => 15, - PostgresMajorVersion::V16 => 16, - PostgresMajorVersion::V17 => 17, - }; - let superuser = "cloud_admin"; // XXX: this shouldn't be hard-coded - postgres_initdb::do_run_initdb(postgres_initdb::RunInitdbArgs { - superuser, - locale: DEFAULT_LOCALE, // XXX: this shouldn't be hard-coded, - pg_version, - initdb_bin: pg_bin_dir.join("initdb").as_ref(), - library_search_path: &pg_lib_dir, // TODO: is this right? Prob works in compute image, not sure about neon_local. - pgdata: &pgdata_dir, - }) - .await - .context("initdb")?; + let plaintext = output + .plaintext + .take() + .context("get plaintext connection string")?; - // If the caller didn't specify CPU / RAM to use for sizing, default to - // number of CPUs in the system, and pretty arbitrarily, 256 MB of RAM. - let nproc = args.num_cpus.unwrap_or_else(num_cpus::get); - let memory_mb = args.memory_mb.unwrap_or(256); + String::from_utf8(plaintext.into_inner()).context("parse connection string as utf8") +} - // Somewhat arbitrarily, use 10 % of memory for shared buffer cache, 70% for - // maintenance_work_mem (i.e. for sorting during index creation), and leave the rest - // available for misc other stuff that PostgreSQL uses memory for. - let shared_buffers_mb = ((memory_mb as f32) * 0.10) as usize; - let maintenance_work_mem_mb = ((memory_mb as f32) * 0.70) as usize; +struct PostgresProcess { + pgdata_dir: Utf8PathBuf, + pg_bin_dir: Utf8PathBuf, + pgbin: Utf8PathBuf, + pg_lib_dir: Utf8PathBuf, + postgres_proc: Option, +} - // - // Launch postgres process - // - let mut postgres_proc = tokio::process::Command::new(pgbin) - .arg("-D") - .arg(&pgdata_dir) - .args(["-p", &format!("{pg_port}")]) - .args(["-c", "wal_level=minimal"]) - .args(["-c", &format!("shared_buffers={shared_buffers_mb}MB")]) - .args(["-c", "max_wal_senders=0"]) - .args(["-c", "fsync=off"]) - .args(["-c", "full_page_writes=off"]) - .args(["-c", "synchronous_commit=off"]) - .args([ - "-c", - &format!("maintenance_work_mem={maintenance_work_mem_mb}MB"), - ]) - .args(["-c", &format!("max_parallel_maintenance_workers={nproc}")]) - .args(["-c", &format!("max_parallel_workers={nproc}")]) - .args(["-c", &format!("max_parallel_workers_per_gather={nproc}")]) - .args(["-c", &format!("max_worker_processes={nproc}")]) - .args([ - "-c", - &format!( - "effective_io_concurrency={}", - if cfg!(target_os = "macos") { 0 } else { 100 } - ), - ]) - .env_clear() - .env("LD_LIBRARY_PATH", &pg_lib_dir) - .env( - "ASAN_OPTIONS", - std::env::var("ASAN_OPTIONS").unwrap_or_default(), +impl PostgresProcess { + fn new(pgdata_dir: Utf8PathBuf, pg_bin_dir: Utf8PathBuf, pg_lib_dir: Utf8PathBuf) -> Self { + Self { + pgdata_dir, + pgbin: pg_bin_dir.join("postgres"), + pg_bin_dir, + pg_lib_dir, + postgres_proc: None, + } + } + + async fn prepare(&self, initdb_user: &str) -> Result<(), anyhow::Error> { + tokio::fs::create_dir(&self.pgdata_dir) + .await + .context("create pgdata directory")?; + + let pg_version = match get_pg_version(self.pgbin.as_ref()) { + PostgresMajorVersion::V14 => 14, + PostgresMajorVersion::V15 => 15, + PostgresMajorVersion::V16 => 16, + PostgresMajorVersion::V17 => 17, + }; + postgres_initdb::do_run_initdb(postgres_initdb::RunInitdbArgs { + superuser: initdb_user, + locale: DEFAULT_LOCALE, // XXX: this shouldn't be hard-coded, + pg_version, + initdb_bin: self.pg_bin_dir.join("initdb").as_ref(), + library_search_path: &self.pg_lib_dir, // TODO: is this right? Prob works in compute image, not sure about neon_local. + pgdata: &self.pgdata_dir, + }) + .await + .context("initdb") + } + + async fn start( + &mut self, + initdb_user: &str, + port: u16, + nproc: usize, + memory_mb: usize, + ) -> Result<&tokio::process::Child, anyhow::Error> { + self.prepare(initdb_user).await?; + + // Somewhat arbitrarily, use 10 % of memory for shared buffer cache, 70% for + // maintenance_work_mem (i.e. for sorting during index creation), and leave the rest + // available for misc other stuff that PostgreSQL uses memory for. + let shared_buffers_mb = ((memory_mb as f32) * 0.10) as usize; + let maintenance_work_mem_mb = ((memory_mb as f32) * 0.70) as usize; + + // + // Launch postgres process + // + let mut proc = tokio::process::Command::new(&self.pgbin) + .arg("-D") + .arg(&self.pgdata_dir) + .args(["-p", &format!("{port}")]) + .args(["-c", "wal_level=minimal"]) + .args(["-c", &format!("shared_buffers={shared_buffers_mb}MB")]) + .args(["-c", "shared_buffers=10GB"]) + .args(["-c", "max_wal_senders=0"]) + .args(["-c", "fsync=off"]) + .args(["-c", "full_page_writes=off"]) + .args(["-c", "synchronous_commit=off"]) + .args([ + "-c", + &format!("maintenance_work_mem={maintenance_work_mem_mb}MB"), + ]) + .args(["-c", &format!("max_parallel_maintenance_workers={nproc}")]) + .args(["-c", &format!("max_parallel_workers={nproc}")]) + .args(["-c", &format!("max_parallel_workers_per_gather={nproc}")]) + .args(["-c", &format!("max_worker_processes={nproc}")]) + .args(["-c", "effective_io_concurrency=100"]) + .env_clear() + .env("LD_LIBRARY_PATH", &self.pg_lib_dir) + .env( + "ASAN_OPTIONS", + std::env::var("ASAN_OPTIONS").unwrap_or_default(), + ) + .env( + "UBSAN_OPTIONS", + std::env::var("UBSAN_OPTIONS").unwrap_or_default(), + ) + .stdout(std::process::Stdio::piped()) + .stderr(std::process::Stdio::piped()) + .spawn() + .context("spawn postgres")?; + + info!("spawned postgres, waiting for it to become ready"); + tokio::spawn( + child_stdio_to_log::relay_process_output(proc.stdout.take(), proc.stderr.take()) + .instrument(info_span!("postgres")), + ); + + self.postgres_proc = Some(proc); + Ok(self.postgres_proc.as_ref().unwrap()) + } + + async fn shutdown(&mut self) -> Result<(), anyhow::Error> { + let proc: &mut tokio::process::Child = self.postgres_proc.as_mut().unwrap(); + info!("shutdown postgres"); + nix::sys::signal::kill( + Pid::from_raw(i32::try_from(proc.id().unwrap()).expect("convert child pid to i32")), + nix::sys::signal::SIGTERM, ) - .env( - "UBSAN_OPTIONS", - std::env::var("UBSAN_OPTIONS").unwrap_or_default(), - ) - .stdout(std::process::Stdio::piped()) - .stderr(std::process::Stdio::piped()) - .spawn() - .context("spawn postgres")?; - - info!("spawned postgres, waiting for it to become ready"); - tokio::spawn( - child_stdio_to_log::relay_process_output( - postgres_proc.stdout.take(), - postgres_proc.stderr.take(), - ) - .instrument(info_span!("postgres")), - ); + .context("signal postgres to shut down")?; + proc.wait() + .await + .context("wait for postgres to shut down") + .map(|_| ()) + } +} +async fn wait_until_ready(connstring: String, create_dbname: String) { // Create neondb database in the running postgres - let restore_pg_connstring = - format!("host=localhost port={pg_port} user={superuser} dbname=postgres"); - let start_time = std::time::Instant::now(); loop { @@ -289,7 +277,12 @@ pub(crate) async fn main() -> anyhow::Result<()> { std::process::exit(1); } - match tokio_postgres::connect(&restore_pg_connstring, tokio_postgres::NoTls).await { + match tokio_postgres::connect( + &connstring.replace("dbname=neondb", "dbname=postgres"), + tokio_postgres::NoTls, + ) + .await + { Ok((client, connection)) => { // Spawn the connection handling task to maintain the connection tokio::spawn(async move { @@ -298,9 +291,12 @@ pub(crate) async fn main() -> anyhow::Result<()> { } }); - match client.simple_query("CREATE DATABASE neondb;").await { + match client + .simple_query(format!("CREATE DATABASE {create_dbname};").as_str()) + .await + { Ok(_) => { - info!("created neondb database"); + info!("created {} database", create_dbname); break; } Err(e) => { @@ -324,10 +320,16 @@ pub(crate) async fn main() -> anyhow::Result<()> { } } } +} - let restore_pg_connstring = restore_pg_connstring.replace("dbname=postgres", "dbname=neondb"); - - let dumpdir = working_directory.join("dumpdir"); +async fn run_dump_restore( + workdir: Utf8PathBuf, + pg_bin_dir: Utf8PathBuf, + pg_lib_dir: Utf8PathBuf, + source_connstring: String, + destination_connstring: String, +) -> Result<(), anyhow::Error> { + let dumpdir = workdir.join("dumpdir"); let common_args = [ // schema mapping (prob suffices to specify them on one side) @@ -356,7 +358,7 @@ pub(crate) async fn main() -> anyhow::Result<()> { .arg("--no-sync") // POSITIONAL args // source db (db name included in connection string) - .arg(&source_connection_string) + .arg(&source_connstring) // how we run it .env_clear() .env("LD_LIBRARY_PATH", &pg_lib_dir) @@ -376,19 +378,18 @@ pub(crate) async fn main() -> anyhow::Result<()> { let st = pg_dump.wait().await.context("wait for pg_dump")?; info!(status=?st, "pg_dump exited"); if !st.success() { - warn!(status=%st, "pg_dump failed, restore will likely fail as well"); + error!(status=%st, "pg_dump failed, restore will likely fail as well"); + bail!("pg_dump failed"); } } - // TODO: do it in a streaming way, plenty of internal research done on this already + // TODO: maybe do it in a streaming way, plenty of internal research done on this already // TODO: do the unlogged table trick - - info!("restore from working directory into vanilla postgres"); { let mut pg_restore = tokio::process::Command::new(pg_bin_dir.join("pg_restore")) .args(&common_args) .arg("-d") - .arg(&restore_pg_connstring) + .arg(&destination_connstring) // POSITIONAL args .arg(&dumpdir) // how we run it @@ -411,33 +412,82 @@ pub(crate) async fn main() -> anyhow::Result<()> { let st = pg_restore.wait().await.context("wait for pg_restore")?; info!(status=?st, "pg_restore exited"); if !st.success() { - warn!(status=%st, "pg_restore failed, restore will likely fail as well"); + error!(status=%st, "pg_restore failed, restore will likely fail as well"); + bail!("pg_restore failed"); } } + Ok(()) +} + +#[allow(clippy::too_many_arguments)] +async fn cmd_pgdata( + kms_client: Option, + maybe_s3_prefix: Option, + maybe_spec: Option, + source_connection_string: Option, + interactive: bool, + pg_port: u16, + workdir: Utf8PathBuf, + pg_bin_dir: Utf8PathBuf, + pg_lib_dir: Utf8PathBuf, + num_cpus: Option, + memory_mb: Option, +) -> Result<(), anyhow::Error> { + if maybe_spec.is_none() && source_connection_string.is_none() { + bail!("spec must be provided for pgdata command"); + } + if maybe_spec.is_some() && source_connection_string.is_some() { + bail!("only one of spec or source_connection_string can be provided"); + } + + let source_connection_string = if let Some(spec) = maybe_spec { + match spec.encryption_secret { + EncryptionSecret::KMS { key_id } => { + decode_connstring( + kms_client.as_ref().unwrap(), + &key_id, + spec.source_connstring_ciphertext_base64, + ) + .await? + } + } + } else { + source_connection_string.unwrap() + }; + + let superuser = "cloud_admin"; + let destination_connstring = format!( + "host=localhost port={} user={} dbname=neondb", + pg_port, superuser + ); + + let pgdata_dir = workdir.join("pgdata"); + let mut proc = PostgresProcess::new(pgdata_dir.clone(), pg_bin_dir.clone(), pg_lib_dir.clone()); + let nproc = num_cpus.unwrap_or_else(num_cpus::get); + let memory_mb = memory_mb.unwrap_or(256); + proc.start(superuser, pg_port, nproc, memory_mb).await?; + wait_until_ready(destination_connstring.clone(), "neondb".to_string()).await; + + run_dump_restore( + workdir.clone(), + pg_bin_dir, + pg_lib_dir, + source_connection_string, + destination_connstring, + ) + .await?; + // If interactive mode, wait for Ctrl+C - if args.interactive { + if interactive { info!("Running in interactive mode. Press Ctrl+C to shut down."); tokio::signal::ctrl_c().await.context("wait for ctrl-c")?; } - info!("shutdown postgres"); - { - nix::sys::signal::kill( - Pid::from_raw( - i32::try_from(postgres_proc.id().unwrap()).expect("convert child pid to i32"), - ), - nix::sys::signal::SIGTERM, - ) - .context("signal postgres to shut down")?; - postgres_proc - .wait() - .await - .context("wait for postgres to shut down")?; - } + proc.shutdown().await?; // Only sync if s3_prefix was specified - if let Some(s3_prefix) = args.s3_prefix { + if let Some(s3_prefix) = maybe_s3_prefix { info!("upload pgdata"); aws_s3_sync::sync(Utf8Path::new(&pgdata_dir), &s3_prefix.append("/pgdata/")) .await @@ -445,7 +495,7 @@ pub(crate) async fn main() -> anyhow::Result<()> { info!("write status"); { - let status_dir = working_directory.join("status"); + let status_dir = workdir.join("status"); std::fs::create_dir(&status_dir).context("create status directory")?; let status_file = status_dir.join("pgdata"); std::fs::write(&status_file, serde_json::json!({"done": true}).to_string()) @@ -458,3 +508,153 @@ pub(crate) async fn main() -> anyhow::Result<()> { Ok(()) } + +async fn cmd_dumprestore( + kms_client: Option, + maybe_spec: Option, + source_connection_string: Option, + destination_connection_string: Option, + workdir: Utf8PathBuf, + pg_bin_dir: Utf8PathBuf, + pg_lib_dir: Utf8PathBuf, +) -> Result<(), anyhow::Error> { + let (source_connstring, destination_connstring) = if let Some(spec) = maybe_spec { + match spec.encryption_secret { + EncryptionSecret::KMS { key_id } => { + let source = decode_connstring( + kms_client.as_ref().unwrap(), + &key_id, + spec.source_connstring_ciphertext_base64, + ) + .await?; + + let dest = if let Some(dest_ciphertext) = + spec.destination_connstring_ciphertext_base64 + { + decode_connstring(kms_client.as_ref().unwrap(), &key_id, dest_ciphertext) + .await? + } else { + bail!("destination connection string must be provided in spec for dump_restore command"); + }; + + (source, dest) + } + } + } else { + ( + source_connection_string.unwrap(), + if let Some(val) = destination_connection_string { + val + } else { + bail!("destination connection string must be provided for dump_restore command"); + }, + ) + }; + + run_dump_restore( + workdir, + pg_bin_dir, + pg_lib_dir, + source_connstring, + destination_connstring, + ) + .await +} + +#[tokio::main] +pub(crate) async fn main() -> anyhow::Result<()> { + utils::logging::init( + utils::logging::LogFormat::Json, + utils::logging::TracingErrorLayerEnablement::EnableWithRustLogFilter, + utils::logging::Output::Stdout, + )?; + + info!("starting"); + + let args = Args::parse(); + + // Initialize AWS clients only if s3_prefix is specified + let (aws_config, kms_client) = if args.s3_prefix.is_some() { + let config = aws_config::load_defaults(BehaviorVersion::v2024_03_28()).await; + let kms = aws_sdk_kms::Client::new(&config); + (Some(config), Some(kms)) + } else { + (None, None) + }; + + let spec: Option = if let Some(s3_prefix) = &args.s3_prefix { + let spec_key = s3_prefix.append("/spec.json"); + let s3_client = aws_sdk_s3::Client::new(aws_config.as_ref().unwrap()); + let object = s3_client + .get_object() + .bucket(&spec_key.bucket) + .key(spec_key.key) + .send() + .await + .context("get spec from s3")? + .body + .collect() + .await + .context("download spec body")?; + serde_json::from_slice(&object.into_bytes()).context("parse spec as json")? + } else { + None + }; + + match tokio::fs::create_dir(&args.working_directory).await { + Ok(()) => {} + Err(e) if e.kind() == std::io::ErrorKind::AlreadyExists => { + if !is_directory_empty(&args.working_directory) + .await + .context("check if working directory is empty")? + { + bail!("working directory is not empty"); + } else { + // ok + } + } + Err(e) => return Err(anyhow::Error::new(e).context("create working directory")), + } + + match args.command { + Command::Pgdata { + source_connection_string, + interactive, + pg_port, + num_cpus, + memory_mb, + } => { + cmd_pgdata( + kms_client, + args.s3_prefix, + spec, + source_connection_string, + interactive, + pg_port, + args.working_directory, + args.pg_bin_dir, + args.pg_lib_dir, + num_cpus, + memory_mb, + ) + .await?; + } + Command::DumpRestore { + source_connection_string, + destination_connection_string, + } => { + cmd_dumprestore( + kms_client, + spec, + source_connection_string, + destination_connection_string, + args.working_directory, + args.pg_bin_dir, + args.pg_lib_dir, + ) + .await?; + } + } + + Ok(()) +} diff --git a/poetry.lock b/poetry.lock index fd200159b9..e2c71ca012 100644 --- a/poetry.lock +++ b/poetry.lock @@ -412,6 +412,7 @@ files = [ [package.dependencies] botocore-stubs = "*" +mypy-boto3-kms = {version = ">=1.26.0,<1.27.0", optional = true, markers = "extra == \"kms\""} mypy-boto3-s3 = {version = ">=1.26.0,<1.27.0", optional = true, markers = "extra == \"s3\""} types-s3transfer = "*" typing-extensions = ">=4.1.0" @@ -2022,6 +2023,18 @@ install-types = ["pip"] mypyc = ["setuptools (>=50)"] reports = ["lxml"] +[[package]] +name = "mypy-boto3-kms" +version = "1.26.147" +description = "Type annotations for boto3.KMS 1.26.147 service generated with mypy-boto3-builder 7.14.5" +optional = false +python-versions = ">=3.7" +groups = ["main"] +files = [ + {file = "mypy-boto3-kms-1.26.147.tar.gz", hash = "sha256:816a4d1bb0585e1b9620a3f96c1d69a06f53b7b5621858579dd77c60dbb5fa5c"}, + {file = "mypy_boto3_kms-1.26.147-py3-none-any.whl", hash = "sha256:493f0db674a25c88769f5cb8ab8ac00d3dda5dfc903d5cda34c990ee64689f79"}, +] + [[package]] name = "mypy-boto3-s3" version = "1.26.0.post1" @@ -3807,4 +3820,4 @@ cffi = ["cffi (>=1.11)"] [metadata] lock-version = "2.1" python-versions = "^3.11" -content-hash = "4dc3165fe22c0e0f7a030ea0f8a680ae2ff74561d8658c393abbe9112caaf5d7" +content-hash = "03697c0a4d438ef088b0d397b8f0570aa3998ccf833fe612400824792498878b" diff --git a/pyproject.toml b/pyproject.toml index e299c421e9..51cd68e002 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,7 +17,7 @@ Jinja2 = "^3.1.5" types-requests = "^2.31.0.0" types-psycopg2 = "^2.9.21.20241019" boto3 = "^1.34.11" -boto3-stubs = {extras = ["s3"], version = "^1.26.16"} +boto3-stubs = {extras = ["s3", "kms"], version = "^1.26.16"} moto = {extras = ["server"], version = "^5.0.6"} backoff = "^2.2.1" pytest-lazy-fixture = "^0.6.3" diff --git a/test_runner/fixtures/fast_import.py b/test_runner/fixtures/fast_import.py index 33248132ab..d674be99de 100644 --- a/test_runner/fixtures/fast_import.py +++ b/test_runner/fixtures/fast_import.py @@ -4,8 +4,10 @@ import subprocess import tempfile from collections.abc import Iterator from pathlib import Path +from typing import cast import pytest +from _pytest.config import Config from fixtures.log_helper import log from fixtures.neon_cli import AbstractNeonCli @@ -23,6 +25,7 @@ class FastImport(AbstractNeonCli): pg_distrib_dir: Path, pg_version: PgVersion, workdir: Path, + cleanup: bool = True, ): if extra_env is None: env_vars = {} @@ -47,12 +50,43 @@ class FastImport(AbstractNeonCli): if not workdir.exists(): raise Exception(f"Working directory '{workdir}' does not exist") self.workdir = workdir + self.cleanup = cleanup + + def run_pgdata( + self, + s3prefix: str | None = None, + pg_port: int | None = None, + source_connection_string: str | None = None, + interactive: bool = False, + ): + return self.run( + "pgdata", + s3prefix=s3prefix, + pg_port=pg_port, + source_connection_string=source_connection_string, + interactive=interactive, + ) + + def run_dump_restore( + self, + s3prefix: str | None = None, + source_connection_string: str | None = None, + destination_connection_string: str | None = None, + ): + return self.run( + "dump-restore", + s3prefix=s3prefix, + source_connection_string=source_connection_string, + destination_connection_string=destination_connection_string, + ) def run( self, - pg_port: int, - source_connection_string: str | None = None, + command: str, s3prefix: str | None = None, + pg_port: int | None = None, + source_connection_string: str | None = None, + destination_connection_string: str | None = None, interactive: bool = False, ) -> subprocess.CompletedProcess[str]: if self.cmd is not None: @@ -60,13 +94,17 @@ class FastImport(AbstractNeonCli): args = [ f"--pg-bin-dir={self.pg_bin}", f"--pg-lib-dir={self.pg_lib}", - f"--pg-port={pg_port}", f"--working-directory={self.workdir}", ] - if source_connection_string is not None: - args.append(f"--source-connection-string={source_connection_string}") if s3prefix is not None: args.append(f"--s3-prefix={s3prefix}") + args.append(command) + if pg_port is not None: + args.append(f"--pg-port={pg_port}") + if source_connection_string is not None: + args.append(f"--source-connection-string={source_connection_string}") + if destination_connection_string is not None: + args.append(f"--destination-connection-string={destination_connection_string}") if interactive: args.append("--interactive") @@ -77,7 +115,7 @@ class FastImport(AbstractNeonCli): return self def __exit__(self, *args): - if self.workdir.exists(): + if self.workdir.exists() and self.cleanup: shutil.rmtree(self.workdir) @@ -87,9 +125,17 @@ def fast_import( test_output_dir: Path, neon_binpath: Path, pg_distrib_dir: Path, + pytestconfig: Config, ) -> Iterator[FastImport]: - workdir = Path(tempfile.mkdtemp()) - with FastImport(None, neon_binpath, pg_distrib_dir, pg_version, workdir) as fi: + workdir = Path(tempfile.mkdtemp(dir=test_output_dir, prefix="fast_import_")) + with FastImport( + None, + neon_binpath, + pg_distrib_dir, + pg_version, + workdir, + cleanup=not cast(bool, pytestconfig.getoption("--preserve-database-files")), + ) as fi: yield fi if fi.cmd is None: diff --git a/test_runner/fixtures/neon_fixtures.py b/test_runner/fixtures/neon_fixtures.py index 469bc8a1e5..73607db7d8 100644 --- a/test_runner/fixtures/neon_fixtures.py +++ b/test_runner/fixtures/neon_fixtures.py @@ -27,6 +27,7 @@ from urllib.parse import quote, urlparse import asyncpg import backoff +import boto3 import httpx import psycopg2 import psycopg2.sql @@ -37,6 +38,8 @@ from _pytest.config import Config from _pytest.config.argparsing import Parser from _pytest.fixtures import FixtureRequest from jwcrypto import jwk +from mypy_boto3_kms import KMSClient +from mypy_boto3_s3 import S3Client # Type-related stuff from psycopg2.extensions import connection as PgConnection @@ -199,6 +202,30 @@ def mock_s3_server(port_distributor: PortDistributor) -> Iterator[MockS3Server]: mock_s3_server.kill() +@pytest.fixture(scope="session") +def mock_kms(mock_s3_server: MockS3Server) -> Iterator[KMSClient]: + yield boto3.client( + "kms", + endpoint_url=mock_s3_server.endpoint(), + region_name=mock_s3_server.region(), + aws_access_key_id=mock_s3_server.access_key(), + aws_secret_access_key=mock_s3_server.secret_key(), + aws_session_token=mock_s3_server.session_token(), + ) + + +@pytest.fixture(scope="session") +def mock_s3_client(mock_s3_server: MockS3Server) -> Iterator[S3Client]: + yield boto3.client( + "s3", + endpoint_url=mock_s3_server.endpoint(), + region_name=mock_s3_server.region(), + aws_access_key_id=mock_s3_server.access_key(), + aws_secret_access_key=mock_s3_server.secret_key(), + aws_session_token=mock_s3_server.session_token(), + ) + + class PgProtocol: """Reusable connection logic""" diff --git a/test_runner/regress/test_import_pgdata.py b/test_runner/regress/test_import_pgdata.py index ea86eb62eb..71e0d16edd 100644 --- a/test_runner/regress/test_import_pgdata.py +++ b/test_runner/regress/test_import_pgdata.py @@ -1,7 +1,9 @@ +import base64 import json import re import time from enum import Enum +from pathlib import Path import psycopg2 import psycopg2.errors @@ -14,8 +16,12 @@ from fixtures.pageserver.http import ( ImportPgdataIdemptencyKey, PageserverApiException, ) +from fixtures.pg_version import PgVersion from fixtures.port_distributor import PortDistributor -from fixtures.remote_storage import RemoteStorageKind +from fixtures.remote_storage import MockS3Server, RemoteStorageKind +from mypy_boto3_kms import KMSClient +from mypy_boto3_kms.type_defs import EncryptResponseTypeDef +from mypy_boto3_s3 import S3Client from pytest_httpserver import HTTPServer from werkzeug.wrappers.request import Request from werkzeug.wrappers.response import Response @@ -103,13 +109,15 @@ def test_pgdata_import_smoke( while True: relblock_size = vanilla_pg.safe_psql_scalar("select pg_relation_size('t')") log.info( - f"relblock size: {relblock_size/8192} pages (target: {target_relblock_size//8192}) pages" + f"relblock size: {relblock_size / 8192} pages (target: {target_relblock_size // 8192}) pages" ) if relblock_size >= target_relblock_size: break addrows = int((target_relblock_size - relblock_size) // 8192) assert addrows >= 1, "forward progress" - vanilla_pg.safe_psql(f"insert into t select generate_series({nrows+1}, {nrows + addrows})") + vanilla_pg.safe_psql( + f"insert into t select generate_series({nrows + 1}, {nrows + addrows})" + ) nrows += addrows expect_nrows = nrows expect_sum = ( @@ -332,6 +340,224 @@ def test_pgdata_import_smoke( br_initdb_endpoint.safe_psql("select * from othertable") +def test_fast_import_with_pageserver_ingest( + test_output_dir, + vanilla_pg: VanillaPostgres, + port_distributor: PortDistributor, + fast_import: FastImport, + pg_distrib_dir: Path, + pg_version: PgVersion, + mock_s3_server: MockS3Server, + mock_kms: KMSClient, + mock_s3_client: S3Client, + neon_env_builder: NeonEnvBuilder, + make_httpserver: HTTPServer, +): + # Prepare KMS and S3 + key_response = mock_kms.create_key( + Description="Test key", + KeyUsage="ENCRYPT_DECRYPT", + Origin="AWS_KMS", + ) + key_id = key_response["KeyMetadata"]["KeyId"] + + def encrypt(x: str) -> EncryptResponseTypeDef: + return mock_kms.encrypt(KeyId=key_id, Plaintext=x) + + # Start source postgres and ingest data + vanilla_pg.start() + vanilla_pg.safe_psql("CREATE TABLE foo (a int); INSERT INTO foo SELECT generate_series(1, 10);") + + # Setup pageserver and fake cplane for import progress + def handler(request: Request) -> Response: + log.info(f"control plane request: {request.json}") + return Response(json.dumps({}), status=200) + + cplane_mgmt_api_server = make_httpserver + cplane_mgmt_api_server.expect_request(re.compile(".*")).respond_with_handler(handler) + + neon_env_builder.enable_pageserver_remote_storage(RemoteStorageKind.MOCK_S3) + env = neon_env_builder.init_start() + + env.pageserver.patch_config_toml_nonrecursive( + { + "import_pgdata_upcall_api": f"http://{cplane_mgmt_api_server.host}:{cplane_mgmt_api_server.port}/path/to/mgmt/api", + # because import_pgdata code uses this endpoint, not the one in common remote storage config + # TODO: maybe use common remote_storage config in pageserver? + "import_pgdata_aws_endpoint_url": env.s3_mock_server.endpoint(), + } + ) + env.pageserver.stop() + env.pageserver.start() + + # Encrypt connstrings and put spec into S3 + source_connstring_encrypted = encrypt(vanilla_pg.connstr()) + spec = { + "encryption_secret": {"KMS": {"key_id": key_id}}, + "source_connstring_ciphertext_base64": base64.b64encode( + source_connstring_encrypted["CiphertextBlob"] + ).decode("utf-8"), + "project_id": "someproject", + "branch_id": "somebranch", + } + + bucket = "test-bucket" + key_prefix = "test-prefix" + mock_s3_client.create_bucket(Bucket=bucket) + mock_s3_client.put_object(Bucket=bucket, Key=f"{key_prefix}/spec.json", Body=json.dumps(spec)) + + # Create timeline with import_pgdata + tenant_id = TenantId.generate() + env.storage_controller.tenant_create(tenant_id) + + timeline_id = TimelineId.generate() + log.info("starting import") + start = time.monotonic() + + idempotency = ImportPgdataIdemptencyKey.random() + log.info(f"idempotency key {idempotency}") + # TODO: teach neon_local CLI about the idempotency & 429 error so we can run inside the loop + # and check for 429 + + import_branch_name = "imported" + env.storage_controller.timeline_create( + tenant_id, + { + "new_timeline_id": str(timeline_id), + "import_pgdata": { + "idempotency_key": str(idempotency), + "location": { + "AwsS3": { + "region": env.s3_mock_server.region(), + "bucket": bucket, + "key": key_prefix, + } + }, + }, + }, + ) + env.neon_cli.mappings_map_branch(import_branch_name, tenant_id, timeline_id) + + # Run fast_import + if fast_import.extra_env is None: + fast_import.extra_env = {} + fast_import.extra_env["AWS_ACCESS_KEY_ID"] = mock_s3_server.access_key() + fast_import.extra_env["AWS_SECRET_ACCESS_KEY"] = mock_s3_server.secret_key() + fast_import.extra_env["AWS_SESSION_TOKEN"] = mock_s3_server.session_token() + fast_import.extra_env["AWS_REGION"] = mock_s3_server.region() + fast_import.extra_env["AWS_ENDPOINT_URL"] = mock_s3_server.endpoint() + fast_import.extra_env["RUST_LOG"] = "aws_config=debug,aws_sdk_kms=debug" + pg_port = port_distributor.get_port() + fast_import.run_pgdata(pg_port=pg_port, s3prefix=f"s3://{bucket}/{key_prefix}") + vanilla_pg.stop() + + def validate_vanilla_equivalence(ep): + res = ep.safe_psql("SELECT count(*), sum(a) FROM foo;", dbname="neondb") + assert res[0] == (10, 55), f"got result: {res}" + + # Sanity check that data in pgdata is expected: + pgbin = PgBin(test_output_dir, fast_import.pg_distrib_dir, fast_import.pg_version) + with VanillaPostgres( + fast_import.workdir / "pgdata", pgbin, pg_port, False + ) as new_pgdata_vanilla_pg: + new_pgdata_vanilla_pg.start() + + # database name and user are hardcoded in fast_import binary, and they are different from normal vanilla postgres + conn = PgProtocol(dsn=f"postgresql://cloud_admin@localhost:{pg_port}/neondb") + validate_vanilla_equivalence(conn) + + # Poll pageserver statuses in s3 + while True: + locations = env.storage_controller.locate(tenant_id) + active_count = 0 + for location in locations: + shard_id = TenantShardId.parse(location["shard_id"]) + ps = env.get_pageserver(location["node_id"]) + try: + detail = ps.http_client().timeline_detail(shard_id, timeline_id) + log.info(f"timeline {tenant_id}/{timeline_id} detail: {detail}") + state = detail["state"] + log.info(f"shard {shard_id} state: {state}") + if state == "Active": + active_count += 1 + except PageserverApiException as e: + if e.status_code == 404: + log.info("not found, import is in progress") + continue + elif e.status_code == 429: + log.info("import is in progress") + continue + else: + raise + + if state == "Active": + key = f"{key_prefix}/status/shard-{shard_id.shard_index}" + shard_status_file_contents = ( + mock_s3_client.get_object(Bucket=bucket, Key=key)["Body"].read().decode("utf-8") + ) + shard_status = json.loads(shard_status_file_contents) + assert shard_status["done"] is True + + if active_count == len(locations): + log.info("all shards are active") + break + time.sleep(0.5) + + import_duration = time.monotonic() - start + log.info(f"import complete; duration={import_duration:.2f}s") + + ep = env.endpoints.create_start(branch_name=import_branch_name, tenant_id=tenant_id) + + # check that data is there + validate_vanilla_equivalence(ep) + + # check that we can do basic ops + + ep.safe_psql("create table othertable(values text)", dbname="neondb") + rw_lsn = Lsn(ep.safe_psql_scalar("select pg_current_wal_flush_lsn()")) + ep.stop() + + # ... at the tip + _ = env.create_branch( + new_branch_name="br-tip", + ancestor_branch_name=import_branch_name, + tenant_id=tenant_id, + ancestor_start_lsn=rw_lsn, + ) + br_tip_endpoint = env.endpoints.create_start( + branch_name="br-tip", endpoint_id="br-tip-ro", tenant_id=tenant_id + ) + validate_vanilla_equivalence(br_tip_endpoint) + br_tip_endpoint.safe_psql("select * from othertable", dbname="neondb") + br_tip_endpoint.stop() + + # ... at the initdb lsn + locations = env.storage_controller.locate(tenant_id) + [shard_zero] = [ + loc for loc in locations if TenantShardId.parse(loc["shard_id"]).shard_number == 0 + ] + shard_zero_ps = env.get_pageserver(shard_zero["node_id"]) + shard_zero_timeline_info = shard_zero_ps.http_client().timeline_detail( + shard_zero["shard_id"], timeline_id + ) + initdb_lsn = Lsn(shard_zero_timeline_info["initdb_lsn"]) + _ = env.create_branch( + new_branch_name="br-initdb", + ancestor_branch_name=import_branch_name, + tenant_id=tenant_id, + ancestor_start_lsn=initdb_lsn, + ) + br_initdb_endpoint = env.endpoints.create_start( + branch_name="br-initdb", endpoint_id="br-initdb-ro", tenant_id=tenant_id + ) + validate_vanilla_equivalence(br_initdb_endpoint) + with pytest.raises(psycopg2.errors.UndefinedTable): + br_initdb_endpoint.safe_psql("select * from othertable", dbname="neondb") + br_initdb_endpoint.stop() + + env.pageserver.stop(immediate=True) + + def test_fast_import_binary( test_output_dir, vanilla_pg: VanillaPostgres, @@ -342,7 +568,7 @@ def test_fast_import_binary( vanilla_pg.safe_psql("CREATE TABLE foo (a int); INSERT INTO foo SELECT generate_series(1, 10);") pg_port = port_distributor.get_port() - fast_import.run(pg_port, vanilla_pg.connstr()) + fast_import.run_pgdata(pg_port=pg_port, source_connection_string=vanilla_pg.connstr()) vanilla_pg.stop() pgbin = PgBin(test_output_dir, fast_import.pg_distrib_dir, fast_import.pg_version) @@ -358,6 +584,118 @@ def test_fast_import_binary( assert res[0][0] == 10 +def test_fast_import_restore_to_connstring( + test_output_dir, + vanilla_pg: VanillaPostgres, + port_distributor: PortDistributor, + fast_import: FastImport, + pg_distrib_dir: Path, + pg_version: PgVersion, +): + vanilla_pg.start() + vanilla_pg.safe_psql("CREATE TABLE foo (a int); INSERT INTO foo SELECT generate_series(1, 10);") + + pgdatadir = test_output_dir / "destination-pgdata" + pg_bin = PgBin(test_output_dir, pg_distrib_dir, pg_version) + port = port_distributor.get_port() + with VanillaPostgres(pgdatadir, pg_bin, port) as destination_vanilla_pg: + destination_vanilla_pg.configure(["shared_preload_libraries='neon_rmgr'"]) + destination_vanilla_pg.start() + + # create another database & role and try to restore there + destination_vanilla_pg.safe_psql(""" + CREATE ROLE testrole WITH + LOGIN + PASSWORD 'testpassword' + NOSUPERUSER + NOCREATEDB + NOCREATEROLE; + """) + destination_vanilla_pg.safe_psql("CREATE DATABASE testdb OWNER testrole;") + + destination_connstring = destination_vanilla_pg.connstr( + dbname="testdb", user="testrole", password="testpassword" + ) + fast_import.run_dump_restore( + source_connection_string=vanilla_pg.connstr(), + destination_connection_string=destination_connstring, + ) + vanilla_pg.stop() + conn = PgProtocol(dsn=destination_connstring) + res = conn.safe_psql("SELECT count(*) FROM foo;") + log.info(f"Result: {res}") + assert res[0][0] == 10 + + +def test_fast_import_restore_to_connstring_from_s3_spec( + test_output_dir, + vanilla_pg: VanillaPostgres, + port_distributor: PortDistributor, + fast_import: FastImport, + pg_distrib_dir: Path, + pg_version: PgVersion, + mock_s3_server: MockS3Server, + mock_kms: KMSClient, + mock_s3_client: S3Client, +): + # Prepare KMS and S3 + key_response = mock_kms.create_key( + Description="Test key", + KeyUsage="ENCRYPT_DECRYPT", + Origin="AWS_KMS", + ) + key_id = key_response["KeyMetadata"]["KeyId"] + + def encrypt(x: str) -> EncryptResponseTypeDef: + return mock_kms.encrypt(KeyId=key_id, Plaintext=x) + + # Start source postgres and ingest data + vanilla_pg.start() + vanilla_pg.safe_psql("CREATE TABLE foo (a int); INSERT INTO foo SELECT generate_series(1, 10);") + + # Start target postgres + pgdatadir = test_output_dir / "destination-pgdata" + pg_bin = PgBin(test_output_dir, pg_distrib_dir, pg_version) + port = port_distributor.get_port() + with VanillaPostgres(pgdatadir, pg_bin, port) as destination_vanilla_pg: + destination_vanilla_pg.configure(["shared_preload_libraries='neon_rmgr'"]) + destination_vanilla_pg.start() + + # Encrypt connstrings and put spec into S3 + source_connstring_encrypted = encrypt(vanilla_pg.connstr()) + destination_connstring_encrypted = encrypt(destination_vanilla_pg.connstr()) + spec = { + "encryption_secret": {"KMS": {"key_id": key_id}}, + "source_connstring_ciphertext_base64": base64.b64encode( + source_connstring_encrypted["CiphertextBlob"] + ).decode("utf-8"), + "destination_connstring_ciphertext_base64": base64.b64encode( + destination_connstring_encrypted["CiphertextBlob"] + ).decode("utf-8"), + } + + mock_s3_client.create_bucket(Bucket="test-bucket") + mock_s3_client.put_object( + Bucket="test-bucket", Key="test-prefix/spec.json", Body=json.dumps(spec) + ) + + # Run fast_import + if fast_import.extra_env is None: + fast_import.extra_env = {} + fast_import.extra_env["AWS_ACCESS_KEY_ID"] = mock_s3_server.access_key() + fast_import.extra_env["AWS_SECRET_ACCESS_KEY"] = mock_s3_server.secret_key() + fast_import.extra_env["AWS_SESSION_TOKEN"] = mock_s3_server.session_token() + fast_import.extra_env["AWS_REGION"] = mock_s3_server.region() + fast_import.extra_env["AWS_ENDPOINT_URL"] = mock_s3_server.endpoint() + fast_import.extra_env["RUST_LOG"] = "aws_config=debug,aws_sdk_kms=debug" + fast_import.run_dump_restore(s3prefix="s3://test-bucket/test-prefix") + vanilla_pg.stop() + + res = destination_vanilla_pg.safe_psql("SELECT count(*) FROM foo;") + log.info(f"Result: {res}") + assert res[0][0] == 10 + + # TODO: Maybe test with pageserver? # 1. run whole neon env # 2. create timeline with some s3 path???