(manual cherry-pick) fast import: optional restore_connstring, skip running postgres if specified

This commit is contained in:
Gleb Novikov
2025-01-16 19:08:42 +00:00
parent 0dda2ad7ec
commit 0e405444b2

View File

@@ -52,6 +52,8 @@ struct Args {
s3_prefix: Option<s3_uri::S3Uri>,
#[clap(long)]
source_connection_string: Option<String>,
#[clap(long)]
restore_connection_string: Option<String>, // will not run postgres if specified, will do pg_restore to this connection string
#[clap(short, long)]
interactive: bool,
#[clap(long)]
@@ -68,6 +70,8 @@ struct Spec {
encryption_secret: EncryptionSecret,
#[serde_as(as = "serde_with::base64::Base64")]
source_connstring_ciphertext_base64: Vec<u8>,
#[serde_as(as = "serde_with::base64::Base64")]
restore_connstring_ciphertext_base64: Option<Vec<u8>>,
}
#[derive(serde::Deserialize)]
@@ -83,6 +87,26 @@ const DEFAULT_LOCALE: &str = if cfg!(target_os = "macos") {
"C.UTF-8"
};
async fn decode_connstring(kms_client: &aws_sdk_kms::Client, key_id: &String, connstring_ciphertext_base64: Vec<u8>) -> Result<String, anyhow::Error> {
let mut output = kms_client
.decrypt()
.key_id(key_id)
.ciphertext_blob(aws_sdk_s3::primitives::Blob::new(
connstring_ciphertext_base64,
))
.send()
.await
.context("decrypt connection string")?;
let plaintext = output
.plaintext
.take()
.context("get plaintext connection string")?;
String::from_utf8(plaintext.into_inner())
.context("parse connection string as utf8")
}
#[tokio::main]
pub(crate) async fn main() -> anyhow::Result<()> {
utils::logging::init(
@@ -106,12 +130,6 @@ pub(crate) async fn main() -> anyhow::Result<()> {
let working_directory = args.working_directory;
let pg_bin_dir = args.pg_bin_dir;
let pg_lib_dir = args.pg_lib_dir;
let pg_port = if args.pg_port.is_some() {
args.pg_port.unwrap()
} else {
info!("pg_port not specified, using default 5432");
5432
};
// Initialize AWS clients only if s3_prefix is specified
let (aws_config, kms_client) = if args.s3_prefix.is_some() {
@@ -122,8 +140,20 @@ pub(crate) async fn main() -> anyhow::Result<()> {
(None, None)
};
// Get source connection string either from S3 spec or direct argument
let source_connection_string = if let Some(s3_prefix) = &args.s3_prefix {
let superuser = "cloud_admin";
let pg_port = || {
if args.pg_port.is_some() {
args.pg_port.unwrap()
} else {
info!("pg_port not specified, using default 5432");
5432
}
};
let mut run_postgres = true;
// Get connection strings either from S3 spec or direct arguments
let (source_connstring, restore_connstring) = if let Some(s3_prefix) = &args.s3_prefix {
let spec: Spec = {
let spec_key = s3_prefix.append("/spec.json");
let s3_client = aws_sdk_s3::Client::new(aws_config.as_ref().unwrap());
@@ -143,28 +173,47 @@ pub(crate) async fn main() -> anyhow::Result<()> {
match spec.encryption_secret {
EncryptionSecret::KMS { key_id } => {
let mut output = kms_client
.unwrap()
.decrypt()
.key_id(key_id)
.ciphertext_blob(aws_sdk_s3::primitives::Blob::new(
spec.source_connstring_ciphertext_base64,
))
.send()
.await
.context("decrypt source connection string")?;
let plaintext = output
.plaintext
.take()
.context("get plaintext source connection string")?;
String::from_utf8(plaintext.into_inner())
.context("parse source connection string as utf8")?
let source = decode_connstring(
kms_client.as_ref().unwrap(),
&key_id,
spec.source_connstring_ciphertext_base64
).await?;
let restore =
if let Some(restore_ciphertext) = spec.restore_connstring_ciphertext_base64 {
run_postgres = false;
decode_connstring(kms_client.as_ref().unwrap(), &key_id, restore_ciphertext).await?
} else {
// restoring to local postgres otherwise
format!(
"host=localhost port={} user={} dbname=neondb",
pg_port(),
superuser
)
};
(source, restore)
}
}
} else {
args.source_connection_string.unwrap()
(
args.source_connection_string.unwrap(),
if args.restore_connection_string.is_none() {
format!(
"host=localhost port={} user={} dbname=neondb",
pg_port(),
superuser
)
} else {
run_postgres = false;
args.restore_connection_string.unwrap()
},
)
};
let nproc = num_cpus::get();
// unused if run_postgres is false, but needed for shutdown
match tokio::fs::create_dir(&working_directory).await {
Ok(()) => {}
Err(e) if e.kind() == std::io::ErrorKind::AlreadyExists => {
@@ -179,125 +228,120 @@ pub(crate) async fn main() -> anyhow::Result<()> {
}
Err(e) => return Err(anyhow::Error::new(e).context("create working directory")),
}
let pgdata_dir = working_directory.join("pgdata");
tokio::fs::create_dir(&pgdata_dir)
.await
.context("create pgdata directory")?;
let pgbin = pg_bin_dir.join("postgres");
let pg_version = match get_pg_version(pgbin.as_ref()) {
PostgresMajorVersion::V14 => 14,
PostgresMajorVersion::V15 => 15,
PostgresMajorVersion::V16 => 16,
PostgresMajorVersion::V17 => 17,
};
let superuser = "cloud_admin"; // XXX: this shouldn't be hard-coded
postgres_initdb::do_run_initdb(postgres_initdb::RunInitdbArgs {
superuser,
locale: DEFAULT_LOCALE, // XXX: this shouldn't be hard-coded,
pg_version,
initdb_bin: pg_bin_dir.join("initdb").as_ref(),
library_search_path: &pg_lib_dir, // TODO: is this right? Prob works in compute image, not sure about neon_local.
pgdata: &pgdata_dir,
})
.await
.context("initdb")?;
let postgres_proc = if run_postgres {
assert!(restore_connstring.contains("host=localhost"));
tokio::fs::create_dir(&pgdata_dir)
.await
.context("create pgdata directory")?;
let nproc = num_cpus::get();
let pgbin = pg_bin_dir.join("postgres");
let pg_version = match get_pg_version(pgbin.as_ref()) {
PostgresMajorVersion::V14 => 14,
PostgresMajorVersion::V15 => 15,
PostgresMajorVersion::V16 => 16,
PostgresMajorVersion::V17 => 17,
};
postgres_initdb::do_run_initdb(postgres_initdb::RunInitdbArgs {
superuser,
locale: DEFAULT_LOCALE, // XXX: this shouldn't be hard-coded,
pg_version,
initdb_bin: pg_bin_dir.join("initdb").as_ref(),
library_search_path: &pg_lib_dir, // TODO: is this right? Prob works in compute image, not sure about neon_local.
pgdata: &pgdata_dir,
})
.await
.context("initdb")?;
//
// Launch postgres process
//
let mut postgres_proc = tokio::process::Command::new(pgbin)
.arg("-D")
.arg(&pgdata_dir)
.args(["-p", &format!("{pg_port}")])
.args(["-c", "wal_level=minimal"])
.args(["-c", "shared_buffers=10GB"])
.args(["-c", "max_wal_senders=0"])
.args(["-c", "fsync=off"])
.args(["-c", "full_page_writes=off"])
.args(["-c", "synchronous_commit=off"])
.args(["-c", "maintenance_work_mem=8388608"])
.args(["-c", &format!("max_parallel_maintenance_workers={nproc}")])
.args(["-c", &format!("max_parallel_workers={nproc}")])
.args(["-c", &format!("max_parallel_workers_per_gather={nproc}")])
.args(["-c", &format!("max_worker_processes={nproc}")])
.args([
"-c",
&format!(
"effective_io_concurrency={}",
if cfg!(target_os = "macos") { 0 } else { 100 }
),
])
.env_clear()
.env("LD_LIBRARY_PATH", &pg_lib_dir)
.stdout(std::process::Stdio::piped())
.stderr(std::process::Stdio::piped())
.spawn()
.context("spawn postgres")?;
//
// Launch postgres process
//
let mut proc = tokio::process::Command::new(pgbin)
.arg("-D")
.arg(&pgdata_dir)
.args(["-p", &format!("{}", pg_port())])
.args(["-c", "wal_level=minimal"])
.args(["-c", "shared_buffers=10GB"])
.args(["-c", "max_wal_senders=0"])
.args(["-c", "fsync=off"])
.args(["-c", "full_page_writes=off"])
.args(["-c", "synchronous_commit=off"])
.args(["-c", "maintenance_work_mem=8388608"])
.args(["-c", &format!("max_parallel_maintenance_workers={nproc}")])
.args(["-c", &format!("max_parallel_workers={nproc}")])
.args(["-c", &format!("max_parallel_workers_per_gather={nproc}")])
.args(["-c", &format!("max_worker_processes={nproc}")])
.args(["-c", "effective_io_concurrency=100"])
.env_clear()
.env("LD_LIBRARY_PATH", &pg_lib_dir)
.stdout(std::process::Stdio::piped())
.stderr(std::process::Stdio::piped())
.spawn()
.context("spawn postgres")?;
info!("spawned postgres, waiting for it to become ready");
tokio::spawn(
child_stdio_to_log::relay_process_output(
postgres_proc.stdout.take(),
postgres_proc.stderr.take(),
)
.instrument(info_span!("postgres")),
);
info!("spawned postgres, waiting for it to become ready");
tokio::spawn(
child_stdio_to_log::relay_process_output(proc.stdout.take(), proc.stderr.take())
.instrument(info_span!("postgres")),
);
// Create neondb database in the running postgres
let restore_pg_connstring =
format!("host=localhost port={pg_port} user={superuser} dbname=postgres");
// Create neondb database in the running postgres
let start_time = std::time::Instant::now();
let start_time = std::time::Instant::now();
loop {
if start_time.elapsed() > PG_WAIT_TIMEOUT {
error!(
loop {
if start_time.elapsed() > PG_WAIT_TIMEOUT {
error!(
"timeout exceeded: failed to poll postgres and create database within 10 minutes"
);
std::process::exit(1);
}
std::process::exit(1);
}
match tokio_postgres::connect(&restore_pg_connstring, tokio_postgres::NoTls).await {
Ok((client, connection)) => {
// Spawn the connection handling task to maintain the connection
tokio::spawn(async move {
if let Err(e) = connection.await {
warn!("connection error: {}", e);
}
});
match tokio_postgres::connect(
&restore_connstring.replace("dbname=neondb", "dbname=postgres"),
tokio_postgres::NoTls,
)
.await
{
Ok((client, connection)) => {
// Spawn the connection handling task to maintain the connection
tokio::spawn(async move {
if let Err(e) = connection.await {
warn!("connection error: {}", e);
}
});
match client.simple_query("CREATE DATABASE neondb;").await {
Ok(_) => {
info!("created neondb database");
break;
}
Err(e) => {
warn!(
match client.simple_query("CREATE DATABASE neondb;").await {
Ok(_) => {
info!("created neondb database");
break;
}
Err(e) => {
warn!(
"failed to create database: {}, retying in {}s",
e,
PG_WAIT_RETRY_INTERVAL.as_secs_f32()
);
tokio::time::sleep(PG_WAIT_RETRY_INTERVAL).await;
continue;
tokio::time::sleep(PG_WAIT_RETRY_INTERVAL).await;
continue;
}
}
}
}
Err(_) => {
info!(
"postgres not ready yet, retrying in {}s",
PG_WAIT_RETRY_INTERVAL.as_secs_f32()
);
tokio::time::sleep(PG_WAIT_RETRY_INTERVAL).await;
continue;
Err(_) => {
info!(
"postgres not ready yet, retrying in {}s",
PG_WAIT_RETRY_INTERVAL.as_secs_f32()
);
tokio::time::sleep(PG_WAIT_RETRY_INTERVAL).await;
continue;
},
}
}
}
let restore_pg_connstring = restore_pg_connstring.replace("dbname=postgres", "dbname=neondb");
Some(proc)
} else {
info!("restore_connection_string specified, not running postgres process");
None
};
let dumpdir = working_directory.join("dumpdir");
@@ -328,7 +372,7 @@ pub(crate) async fn main() -> anyhow::Result<()> {
.arg("--no-sync")
// POSITIONAL args
// source db (db name included in connection string)
.arg(&source_connection_string)
.arg(&source_connstring)
// how we run it
.env_clear()
.env("LD_LIBRARY_PATH", &pg_lib_dir)
@@ -354,13 +398,11 @@ pub(crate) async fn main() -> anyhow::Result<()> {
// TODO: do it in a streaming way, plenty of internal research done on this already
// TODO: do the unlogged table trick
info!("restore from working directory into vanilla postgres");
{
let mut pg_restore = tokio::process::Command::new(pg_bin_dir.join("pg_restore"))
.args(&common_args)
.arg("-d")
.arg(&restore_pg_connstring)
.arg(&restore_connstring)
// POSITIONAL args
.arg(&dumpdir)
// how we run it
@@ -378,7 +420,7 @@ pub(crate) async fn main() -> anyhow::Result<()> {
pg_restore.stdout.take(),
pg_restore.stderr.take(),
)
.instrument(info_span!("pg_restore")),
.instrument(info_span!("pg_restore")),
);
let st = pg_restore.wait().await.context("wait for pg_restore")?;
info!(status=?st, "pg_restore exited");
@@ -387,45 +429,46 @@ pub(crate) async fn main() -> anyhow::Result<()> {
}
}
// If interactive mode, wait for Ctrl+C
if args.interactive {
info!("Running in interactive mode. Press Ctrl+C to shut down.");
tokio::signal::ctrl_c().await.context("wait for ctrl-c")?;
}
if let Some(mut proc) = postgres_proc {
// If interactive mode, wait for Ctrl+C
if args.interactive {
info!("Running in interactive mode. Press Ctrl+C to shut down.");
tokio::signal::ctrl_c().await.context("wait for ctrl-c")?;
}
info!("shutdown postgres");
{
nix::sys::signal::kill(
Pid::from_raw(
i32::try_from(postgres_proc.id().unwrap()).expect("convert child pid to i32"),
),
nix::sys::signal::SIGTERM,
)
.context("signal postgres to shut down")?;
postgres_proc
.wait()
.await
.context("wait for postgres to shut down")?;
}
info!("shutdown postgres");
{
nix::sys::signal::kill(
Pid::from_raw(
i32::try_from(proc.id().unwrap()).expect("convert child pid to i32"),
),
nix::sys::signal::SIGTERM,
)
.context("signal postgres to shut down")?;
proc.wait()
.await
.context("wait for postgres to shut down")?;
}
// Only sync if s3_prefix was specified
if let Some(s3_prefix) = args.s3_prefix {
info!("upload pgdata");
aws_s3_sync::sync(Utf8Path::new(&pgdata_dir), &s3_prefix.append("/pgdata/"))
.await
.context("sync dump directory to destination")?;
// Only sync if s3_prefix was specified
if let Some(s3_prefix) = args.s3_prefix {
info!("upload pgdata");
aws_s3_sync::sync(Utf8Path::new(&pgdata_dir), &s3_prefix.append("/pgdata/"))
.await
.context("sync dump directory to destination")?;
info!("write status");
{
let status_dir = working_directory.join("status");
std::fs::create_dir(&status_dir).context("create status directory")?;
let status_file = status_dir.join("pgdata");
std::fs::write(&status_file, serde_json::json!({"done": true}).to_string())
.context("write status file")?;
aws_s3_sync::sync(&status_dir, &s3_prefix.append("/status/"))
.await
.context("sync status directory to destination")?;
}
info!("write status");
{
let status_dir = working_directory.join("status");
std::fs::create_dir(&status_dir).context("create status directory")?;
let status_file = status_dir.join("pgdata");
std::fs::write(&status_file, serde_json::json!({"done": true}).to_string())
.context("write status file")?;
aws_s3_sync::sync(&status_dir, &s3_prefix.append("/status/"))
.await
.context("sync status directory to destination")?;
}
}
}
Ok(())