From 0e405444b2a80b4fa457a0930cfef20e9c26f24e Mon Sep 17 00:00:00 2001
From: Gleb Novikov
Date: Thu, 16 Jan 2025 19:08:42 +0000
Subject: [PATCH] (manual cherry-pick) fast import: optional restore_connstring, skip running postgres if specified

---
 compute_tools/src/bin/fast_import.rs | 375 +++++++++++++++------------
 1 file changed, 209 insertions(+), 166 deletions(-)

diff --git a/compute_tools/src/bin/fast_import.rs b/compute_tools/src/bin/fast_import.rs
index 275765bba4..346381f4d9 100644
--- a/compute_tools/src/bin/fast_import.rs
+++ b/compute_tools/src/bin/fast_import.rs
@@ -52,6 +52,8 @@ struct Args {
     s3_prefix: Option<s3_uri::S3Uri>,
     #[clap(long)]
     source_connection_string: Option<String>,
+    #[clap(long)]
+    restore_connection_string: Option<String>, // will not run postgres if specified, will do pg_restore to this connection string
     #[clap(short, long)]
     interactive: bool,
     #[clap(long)]
@@ -68,6 +70,8 @@ struct Spec {
     encryption_secret: EncryptionSecret,
     #[serde_as(as = "serde_with::base64::Base64")]
     source_connstring_ciphertext_base64: Vec<u8>,
+    #[serde_as(as = "serde_with::base64::Base64")]
+    restore_connstring_ciphertext_base64: Option<Vec<u8>>,
 }
 
 #[derive(serde::Deserialize)]
@@ -83,6 +87,26 @@ const DEFAULT_LOCALE: &str = if cfg!(target_os = "macos") {
     "C.UTF-8"
 };
 
+async fn decode_connstring(kms_client: &aws_sdk_kms::Client, key_id: &String, connstring_ciphertext_base64: Vec<u8>) -> Result<String> {
+    let mut output = kms_client
+        .decrypt()
+        .key_id(key_id)
+        .ciphertext_blob(aws_sdk_s3::primitives::Blob::new(
+            connstring_ciphertext_base64,
+        ))
+        .send()
+        .await
+        .context("decrypt connection string")?;
+
+    let plaintext = output
+        .plaintext
+        .take()
+        .context("get plaintext connection string")?;
+
+    String::from_utf8(plaintext.into_inner())
+        .context("parse connection string as utf8")
+}
+
 #[tokio::main]
 pub(crate) async fn main() -> anyhow::Result<()> {
     utils::logging::init(
@@ -106,12 +130,6 @@ pub(crate) async fn main() -> anyhow::Result<()> {
     let working_directory = args.working_directory;
     let pg_bin_dir = args.pg_bin_dir;
     let pg_lib_dir = args.pg_lib_dir;
-    let pg_port = if args.pg_port.is_some() {
-        args.pg_port.unwrap()
-    } else {
-        info!("pg_port not specified, using default 5432");
-        5432
-    };
 
     // Initialize AWS clients only if s3_prefix is specified
     let (aws_config, kms_client) = if args.s3_prefix.is_some() {
@@ -122,8 +140,20 @@ pub(crate) async fn main() -> anyhow::Result<()> {
         (None, None)
     };
 
-    // Get source connection string either from S3 spec or direct argument
-    let source_connection_string = if let Some(s3_prefix) = &args.s3_prefix {
+    let superuser = "cloud_admin";
+    let pg_port = || {
+        if args.pg_port.is_some() {
+            args.pg_port.unwrap()
+        } else {
+            info!("pg_port not specified, using default 5432");
+            5432
+        }
+    };
+
+    let mut run_postgres = true;
+
+    // Get connection strings either from S3 spec or direct arguments
+    let (source_connstring, restore_connstring) = if let Some(s3_prefix) = &args.s3_prefix {
         let spec: Spec = {
             let spec_key = s3_prefix.append("/spec.json");
             let s3_client = aws_sdk_s3::Client::new(aws_config.as_ref().unwrap());
@@ -143,28 +173,47 @@ pub(crate) async fn main() -> anyhow::Result<()> {
 
         match spec.encryption_secret {
             EncryptionSecret::KMS { key_id } => {
-                let mut output = kms_client
-                    .unwrap()
-                    .decrypt()
-                    .key_id(key_id)
-                    .ciphertext_blob(aws_sdk_s3::primitives::Blob::new(
-                        spec.source_connstring_ciphertext_base64,
-                    ))
-                    .send()
-                    .await
-                    .context("decrypt source connection string")?;
-                let plaintext = output
-                    .plaintext
-                    .take()
-                    .context("get plaintext source connection string")?;
-                String::from_utf8(plaintext.into_inner())
-                    .context("parse source connection string as utf8")?
+                let source = decode_connstring(
+                    kms_client.as_ref().unwrap(),
+                    &key_id,
+                    spec.source_connstring_ciphertext_base64
+                ).await?;
+
+                let restore =
+                    if let Some(restore_ciphertext) = spec.restore_connstring_ciphertext_base64 {
+                        run_postgres = false;
+                        decode_connstring(kms_client.as_ref().unwrap(), &key_id, restore_ciphertext).await?
+                    } else {
+                        // restoring to local postgres otherwise
+                        format!(
+                            "host=localhost port={} user={} dbname=neondb",
+                            pg_port(),
+                            superuser
+                        )
+                    };
+
+                (source, restore)
             }
         }
     } else {
-        args.source_connection_string.unwrap()
+        (
+            args.source_connection_string.unwrap(),
+            if args.restore_connection_string.is_none() {
+                format!(
+                    "host=localhost port={} user={} dbname=neondb",
+                    pg_port(),
+                    superuser
+                )
+            } else {
+                run_postgres = false;
+                args.restore_connection_string.unwrap()
+            },
+        )
     };
 
+    let nproc = num_cpus::get();
+
+    // unused if run_postgres is false, but needed for shutdown
+    let pgdata_dir = working_directory.join("pgdata");
+
     match tokio::fs::create_dir(&working_directory).await {
         Ok(()) => {}
         Err(e) if e.kind() == std::io::ErrorKind::AlreadyExists => {
@@ -179,125 +228,120 @@ pub(crate) async fn main() -> anyhow::Result<()> {
         }
         Err(e) => return Err(anyhow::Error::new(e).context("create working directory")),
     }
-    let pgdata_dir = working_directory.join("pgdata");
-    tokio::fs::create_dir(&pgdata_dir)
-        .await
-        .context("create pgdata directory")?;
-    let pgbin = pg_bin_dir.join("postgres");
-    let pg_version = match get_pg_version(pgbin.as_ref()) {
-        PostgresMajorVersion::V14 => 14,
-        PostgresMajorVersion::V15 => 15,
-        PostgresMajorVersion::V16 => 16,
-        PostgresMajorVersion::V17 => 17,
-    };
-    let superuser = "cloud_admin"; // XXX: this shouldn't be hard-coded
-    postgres_initdb::do_run_initdb(postgres_initdb::RunInitdbArgs {
-        superuser,
-        locale: DEFAULT_LOCALE, // XXX: this shouldn't be hard-coded,
-        pg_version,
-        initdb_bin: pg_bin_dir.join("initdb").as_ref(),
-        library_search_path: &pg_lib_dir, // TODO: is this right? Prob works in compute image, not sure about neon_local.
-        pgdata: &pgdata_dir,
-    })
-    .await
-    .context("initdb")?;
+    let postgres_proc = if run_postgres {
+        assert!(restore_connstring.contains("host=localhost"));
+        tokio::fs::create_dir(&pgdata_dir)
+            .await
+            .context("create pgdata directory")?;
 
-    let nproc = num_cpus::get();
+        let pgbin = pg_bin_dir.join("postgres");
+        let pg_version = match get_pg_version(pgbin.as_ref()) {
+            PostgresMajorVersion::V14 => 14,
+            PostgresMajorVersion::V15 => 15,
+            PostgresMajorVersion::V16 => 16,
+            PostgresMajorVersion::V17 => 17,
+        };
+        postgres_initdb::do_run_initdb(postgres_initdb::RunInitdbArgs {
+            superuser,
+            locale: DEFAULT_LOCALE, // XXX: this shouldn't be hard-coded,
+            pg_version,
+            initdb_bin: pg_bin_dir.join("initdb").as_ref(),
+            library_search_path: &pg_lib_dir, // TODO: is this right? Prob works in compute image, not sure about neon_local.
+            pgdata: &pgdata_dir,
+        })
+        .await
+        .context("initdb")?;
 
-    //
-    // Launch postgres process
-    //
-    let mut postgres_proc = tokio::process::Command::new(pgbin)
-        .arg("-D")
-        .arg(&pgdata_dir)
-        .args(["-p", &format!("{pg_port}")])
-        .args(["-c", "wal_level=minimal"])
-        .args(["-c", "shared_buffers=10GB"])
-        .args(["-c", "max_wal_senders=0"])
-        .args(["-c", "fsync=off"])
-        .args(["-c", "full_page_writes=off"])
-        .args(["-c", "synchronous_commit=off"])
-        .args(["-c", "maintenance_work_mem=8388608"])
-        .args(["-c", &format!("max_parallel_maintenance_workers={nproc}")])
-        .args(["-c", &format!("max_parallel_workers={nproc}")])
-        .args(["-c", &format!("max_parallel_workers_per_gather={nproc}")])
-        .args(["-c", &format!("max_worker_processes={nproc}")])
-        .args([
-            "-c",
-            &format!(
-                "effective_io_concurrency={}",
-                if cfg!(target_os = "macos") { 0 } else { 100 }
-            ),
-        ])
-        .env_clear()
-        .env("LD_LIBRARY_PATH", &pg_lib_dir)
-        .stdout(std::process::Stdio::piped())
-        .stderr(std::process::Stdio::piped())
-        .spawn()
-        .context("spawn postgres")?;
+        //
+        // Launch postgres process
+        //
+        let mut proc = tokio::process::Command::new(pgbin)
+            .arg("-D")
+            .arg(&pgdata_dir)
+            .args(["-p", &format!("{}", pg_port())])
+            .args(["-c", "wal_level=minimal"])
+            .args(["-c", "shared_buffers=10GB"])
+            .args(["-c", "max_wal_senders=0"])
+            .args(["-c", "fsync=off"])
+            .args(["-c", "full_page_writes=off"])
+            .args(["-c", "synchronous_commit=off"])
+            .args(["-c", "maintenance_work_mem=8388608"])
+            .args(["-c", &format!("max_parallel_maintenance_workers={nproc}")])
+            .args(["-c", &format!("max_parallel_workers={nproc}")])
+            .args(["-c", &format!("max_parallel_workers_per_gather={nproc}")])
+            .args(["-c", &format!("max_worker_processes={nproc}")])
+            .args(["-c", "effective_io_concurrency=100"])
+            .env_clear()
+            .env("LD_LIBRARY_PATH", &pg_lib_dir)
+            .stdout(std::process::Stdio::piped())
+            .stderr(std::process::Stdio::piped())
+            .spawn()
+            .context("spawn postgres")?;
 
-    info!("spawned postgres, waiting for it to become ready");
-    tokio::spawn(
-        child_stdio_to_log::relay_process_output(
-            postgres_proc.stdout.take(),
-            postgres_proc.stderr.take(),
-        )
-        .instrument(info_span!("postgres")),
-    );
+        info!("spawned postgres, waiting for it to become ready");
+        tokio::spawn(
+            child_stdio_to_log::relay_process_output(proc.stdout.take(), proc.stderr.take())
+                .instrument(info_span!("postgres")),
+        );
 
-    // Create neondb database in the running postgres
-    let restore_pg_connstring =
-        format!("host=localhost port={pg_port} user={superuser} dbname=postgres");
+        // Create neondb database in the running postgres
+        let start_time = std::time::Instant::now();
 
-    let start_time = std::time::Instant::now();
-
-    loop {
-        if start_time.elapsed() > PG_WAIT_TIMEOUT {
-            error!(
+        loop {
+            if start_time.elapsed() > PG_WAIT_TIMEOUT {
+                error!(
                 "timeout exceeded: failed to poll postgres and create database within 10 minutes"
             );
-            std::process::exit(1);
-        }
+                std::process::exit(1);
+            }
 
-        match tokio_postgres::connect(&restore_pg_connstring, tokio_postgres::NoTls).await {
-            Ok((client, connection)) => {
-                // Spawn the connection handling task to maintain the connection
-                tokio::spawn(async move {
-                    if let Err(e) = connection.await {
-                        warn!("connection error: {}", e);
-                    }
-                });
+            match tokio_postgres::connect(
+                &restore_connstring.replace("dbname=neondb", "dbname=postgres"),
+                tokio_postgres::NoTls,
+            )
+            .await
+            {
+                Ok((client, connection)) => {
+                    // Spawn the connection handling task to maintain the connection
+                    tokio::spawn(async move {
+                        if let Err(e) = connection.await {
+                            warn!("connection error: {}", e);
+                        }
+                    });
 
-                match client.simple_query("CREATE DATABASE neondb;").await {
-                    Ok(_) => {
-                        info!("created neondb database");
-                        break;
-                    }
-                    Err(e) => {
-                        warn!(
+                    match client.simple_query("CREATE DATABASE neondb;").await {
+                        Ok(_) => {
+                            info!("created neondb database");
+                            break;
+                        }
+                        Err(e) => {
+                            warn!(
                             "failed to create database: {}, retying in {}s",
                             e,
                             PG_WAIT_RETRY_INTERVAL.as_secs_f32()
                         );
-                        tokio::time::sleep(PG_WAIT_RETRY_INTERVAL).await;
-                        continue;
+                            tokio::time::sleep(PG_WAIT_RETRY_INTERVAL).await;
+                            continue;
+                        }
                     }
                 }
-            }
-            Err(_) => {
-                info!(
-                    "postgres not ready yet, retrying in {}s",
-                    PG_WAIT_RETRY_INTERVAL.as_secs_f32()
-                );
-                tokio::time::sleep(PG_WAIT_RETRY_INTERVAL).await;
-                continue;
+                Err(_) => {
+                    info!(
+                        "postgres not ready yet, retrying in {}s",
+                        PG_WAIT_RETRY_INTERVAL.as_secs_f32()
+                    );
+                    tokio::time::sleep(PG_WAIT_RETRY_INTERVAL).await;
+                    continue;
+                },
             }
         }
-    }
-
-    let restore_pg_connstring = restore_pg_connstring.replace("dbname=postgres", "dbname=neondb");
+        Some(proc)
+    } else {
+        info!("restore_connection_string specified, not running postgres process");
+        None
+    };
 
     let dumpdir = working_directory.join("dumpdir");
@@ -328,7 +372,7 @@ pub(crate) async fn main() -> anyhow::Result<()> {
             .arg("--no-sync")
             // POSITIONAL args
             // source db (db name included in connection string)
-            .arg(&source_connection_string)
+            .arg(&source_connstring)
             // how we run it
             .env_clear()
             .env("LD_LIBRARY_PATH", &pg_lib_dir)
@@ -354,13 +398,11 @@ pub(crate) async fn main() -> anyhow::Result<()> {
 
     // TODO: do it in a streaming way, plenty of internal research done on this already
     // TODO: do the unlogged table trick
-
-    info!("restore from working directory into vanilla postgres");
     {
         let mut pg_restore = tokio::process::Command::new(pg_bin_dir.join("pg_restore"))
            .args(&common_args)
            .arg("-d")
-           .arg(&restore_pg_connstring)
+           .arg(&restore_connstring)
            // POSITIONAL args
            .arg(&dumpdir)
            // how we run it
@@ -378,7 +420,7 @@ pub(crate) async fn main() -> anyhow::Result<()> {
                 pg_restore.stdout.take(),
                 pg_restore.stderr.take(),
             )
-            .instrument(info_span!("pg_restore")),
+                .instrument(info_span!("pg_restore")),
         );
         let st = pg_restore.wait().await.context("wait for pg_restore")?;
         info!(status=?st, "pg_restore exited");
@@ -387,45 +429,46 @@ pub(crate) async fn main() -> anyhow::Result<()> {
         }
     }
 
-    // If interactive mode, wait for Ctrl+C
-    if args.interactive {
-        info!("Running in interactive mode. Press Ctrl+C to shut down.");
-        tokio::signal::ctrl_c().await.context("wait for ctrl-c")?;
-    }
+    if let Some(mut proc) = postgres_proc {
+        // If interactive mode, wait for Ctrl+C
+        if args.interactive {
+            info!("Running in interactive mode. Press Ctrl+C to shut down.");
+            tokio::signal::ctrl_c().await.context("wait for ctrl-c")?;
+        }
 
-    info!("shutdown postgres");
-    {
-        nix::sys::signal::kill(
-            Pid::from_raw(
-                i32::try_from(postgres_proc.id().unwrap()).expect("convert child pid to i32"),
-            ),
-            nix::sys::signal::SIGTERM,
-        )
-        .context("signal postgres to shut down")?;
-        postgres_proc
-            .wait()
-            .await
-            .context("wait for postgres to shut down")?;
-    }
+        info!("shutdown postgres");
+        {
+            nix::sys::signal::kill(
+                Pid::from_raw(
+                    i32::try_from(proc.id().unwrap()).expect("convert child pid to i32"),
+                ),
+                nix::sys::signal::SIGTERM,
+            )
+            .context("signal postgres to shut down")?;
+            proc.wait()
+                .await
+                .context("wait for postgres to shut down")?;
+        }
 
-    // Only sync if s3_prefix was specified
-    if let Some(s3_prefix) = args.s3_prefix {
-        info!("upload pgdata");
-        aws_s3_sync::sync(Utf8Path::new(&pgdata_dir), &s3_prefix.append("/pgdata/"))
-            .await
-            .context("sync dump directory to destination")?;
+        // Only sync if s3_prefix was specified
+        if let Some(s3_prefix) = args.s3_prefix {
+            info!("upload pgdata");
+            aws_s3_sync::sync(Utf8Path::new(&pgdata_dir), &s3_prefix.append("/pgdata/"))
+                .await
+                .context("sync dump directory to destination")?;
 
-        info!("write status");
-        {
-            let status_dir = working_directory.join("status");
-            std::fs::create_dir(&status_dir).context("create status directory")?;
-            let status_file = status_dir.join("pgdata");
-            std::fs::write(&status_file, serde_json::json!({"done": true}).to_string())
-                .context("write status file")?;
-            aws_s3_sync::sync(&status_dir, &s3_prefix.append("/status/"))
-                .await
-                .context("sync status directory to destination")?;
-        }
+            info!("write status");
+            {
+                let status_dir = working_directory.join("status");
+                std::fs::create_dir(&status_dir).context("create status directory")?;
+                let status_file = status_dir.join("pgdata");
+                std::fs::write(&status_file, serde_json::json!({"done": true}).to_string())
+                    .context("write status file")?;
+                aws_s3_sync::sync(&status_dir, &s3_prefix.append("/status/"))
+                    .await
+                    .context("sync status directory to destination")?;
+            }
+        }
     }
 
     Ok(())