fast import: allow restore to provided connection string (#10407)
Within https://github.com/neondatabase/cloud/issues/22089 we decided that it would be nice to start with an import that runs dump-restore into a running compute (more on this [here](https://www.notion.so/neondatabase/2024-Jan-13-Migration-Assistant-Next-Steps-Proposal-Revised-17af189e004780228bdbcad13eeda93f?pvs=4#17af189e004780de816ccd9c13afd953)). We could do this either by writing another tool or by extending the existing `fast_import.rs`; we chose the latter.

In this PR, I have added an optional `restore_connection_string`, both as a CLI arg and as part of the JSON spec. If it is specified, the binary will not run postgres and will just perform a restore into the provided connection string.

TODO:
- [x] fast_import.rs:
  - [x] cli arg in the fast_import.rs
  - [x] encoded connstring in json spec
  - [x] simplify `fn main` a little, take out too verbose stuff to some functions
  - [ ] ~~allow streaming from dump stdout to restore stdin~~ will do in a separate PR
  - [ ] ~~address https://github.com/neondatabase/neon/pull/10251#pullrequestreview-2551877845~~ will do in a separate PR
- [x] tests:
  - [x] restore with cli arg in the fast_import.rs
  - [x] restore with encoded connstring in json spec in s3
  - [ ] ~~test with custom dbname~~ will do in a separate PR
  - [ ] ~~test with s3 + pageserver + fast import binary~~ https://github.com/neondatabase/neon/pull/10487
  - [ ] ~~https://github.com/neondatabase/neon/pull/10271#discussion_r1923715493~~ will do in a separate PR

neondatabase/cloud#22775

---------

Co-authored-by: Eduard Dykman <bird.duskpoet@gmail.com>
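For reviewers who want to try the new mode locally, here is a minimal invocation sketch. The binary path and connection strings are placeholder assumptions; the flags and the `dump-restore` subcommand name come from the clap definitions in the diff below:

```python
import subprocess

# Sketch: run fast_import in the new dump-restore mode (no local postgres).
# Binary path and connstrings are hypothetical; flags match the CLI in this PR.
subprocess.run(
    [
        "target/release/fast_import",
        "--working-directory=/tmp/fast_import_workdir",
        "--pg-bin-dir=/usr/local/pgsql/bin",
        "--pg-lib-dir=/usr/local/pgsql/lib",
        "dump-restore",
        "--source-connection-string=postgres://user:pass@source:5432/neondb",
        "--destination-connection-string=postgres://user:pass@destination:5432/neondb",
    ],
    check=True,
)
```

With `--s3-prefix` instead of the two raw connection strings, the same subcommand reads an encrypted spec.json from S3 (see the sketch after the Rust diff).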
@@ -25,10 +25,10 @@
//! docker push localhost:3030/localregistry/compute-node-v14:latest
//! ```

use anyhow::Context;
use anyhow::{bail, Context};
use aws_config::BehaviorVersion;
use camino::{Utf8Path, Utf8PathBuf};
use clap::Parser;
use clap::{Parser, Subcommand};
use compute_tools::extension_server::{get_pg_version, PostgresMajorVersion};
use nix::unistd::Pid;
use tracing::{error, info, info_span, warn, Instrument};
@@ -44,32 +44,59 @@ mod s3_uri;
const PG_WAIT_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(600);
const PG_WAIT_RETRY_INTERVAL: std::time::Duration = std::time::Duration::from_millis(300);

#[derive(Subcommand, Debug)]
enum Command {
    /// Runs local postgres (neon binary), restores into it,
    /// uploads pgdata to s3 to be consumed by pageservers
    Pgdata {
        /// Raw connection string to the source database. Used only in tests,
        /// real scenario uses encrypted connection string in spec.json from s3.
        #[clap(long)]
        source_connection_string: Option<String>,
        /// If specified, will not shut down the local postgres after the import. Used in local testing
        #[clap(short, long)]
        interactive: bool,
        /// Port to run postgres on. Default is 5432.
        #[clap(long, default_value_t = 5432)]
        pg_port: u16, // port to run postgres on, 5432 is default

        /// Number of CPUs in the system. This is used to configure # of
        /// parallel worker processes, for index creation.
        #[clap(long, env = "NEON_IMPORTER_NUM_CPUS")]
        num_cpus: Option<usize>,

        /// Amount of RAM in the system. This is used to configure shared_buffers
        /// and maintenance_work_mem.
        #[clap(long, env = "NEON_IMPORTER_MEMORY_MB")]
        memory_mb: Option<usize>,
    },

    /// Runs pg_dump-pg_restore from source to destination without running local postgres.
    DumpRestore {
        /// Raw connection string to the source database. Used only in tests,
        /// real scenario uses encrypted connection string in spec.json from s3.
        #[clap(long)]
        source_connection_string: Option<String>,
        /// Raw connection string to the destination database. Used only in tests,
        /// real scenario uses encrypted connection string in spec.json from s3.
        #[clap(long)]
        destination_connection_string: Option<String>,
    },
}

#[derive(clap::Parser)]
struct Args {
    #[clap(long)]
    #[clap(long, env = "NEON_IMPORTER_WORKDIR")]
    working_directory: Utf8PathBuf,
    #[clap(long, env = "NEON_IMPORTER_S3_PREFIX")]
    s3_prefix: Option<s3_uri::S3Uri>,
    #[clap(long)]
    source_connection_string: Option<String>,
    #[clap(short, long)]
    interactive: bool,
    #[clap(long)]
    #[clap(long, env = "NEON_IMPORTER_PG_BIN_DIR")]
    pg_bin_dir: Utf8PathBuf,
    #[clap(long)]
    #[clap(long, env = "NEON_IMPORTER_PG_LIB_DIR")]
    pg_lib_dir: Utf8PathBuf,
    #[clap(long)]
    pg_port: Option<u16>, // port to run postgres on, 5432 is default

    /// Number of CPUs in the system. This is used to configure # of
    /// parallel worker processes, for index creation.
    #[clap(long, env = "NEON_IMPORTER_NUM_CPUS")]
    num_cpus: Option<usize>,

    /// Amount of RAM in the system. This is used to configure shared_buffers
    /// and maintenance_work_mem.
    #[clap(long, env = "NEON_IMPORTER_MEMORY_MB")]
    memory_mb: Option<usize>,
    #[clap(subcommand)]
    command: Command,
}

#[serde_with::serde_as]
@@ -78,6 +105,8 @@ struct Spec {
    encryption_secret: EncryptionSecret,
    #[serde_as(as = "serde_with::base64::Base64")]
    source_connstring_ciphertext_base64: Vec<u8>,
    #[serde_as(as = "Option<serde_with::base64::Base64>")]
    destination_connstring_ciphertext_base64: Option<Vec<u8>>,
}

#[derive(serde::Deserialize)]
@@ -93,192 +122,151 @@ const DEFAULT_LOCALE: &str = if cfg!(target_os = "macos") {
    "C.UTF-8"
};

#[tokio::main]
pub(crate) async fn main() -> anyhow::Result<()> {
    utils::logging::init(
        utils::logging::LogFormat::Plain,
        utils::logging::TracingErrorLayerEnablement::EnableWithRustLogFilter,
        utils::logging::Output::Stdout,
    )?;

    info!("starting");

    let args = Args::parse();

    // Validate arguments
    if args.s3_prefix.is_none() && args.source_connection_string.is_none() {
        anyhow::bail!("either s3_prefix or source_connection_string must be specified");
    }
    if args.s3_prefix.is_some() && args.source_connection_string.is_some() {
        anyhow::bail!("only one of s3_prefix or source_connection_string can be specified");
    }

    let working_directory = args.working_directory;
    let pg_bin_dir = args.pg_bin_dir;
    let pg_lib_dir = args.pg_lib_dir;
    let pg_port = args.pg_port.unwrap_or_else(|| {
        info!("pg_port not specified, using default 5432");
        5432
    });

    // Initialize AWS clients only if s3_prefix is specified
    let (aws_config, kms_client) = if args.s3_prefix.is_some() {
        let config = aws_config::load_defaults(BehaviorVersion::v2024_03_28()).await;
        let kms = aws_sdk_kms::Client::new(&config);
        (Some(config), Some(kms))
    } else {
        (None, None)
    };

    // Get source connection string either from S3 spec or direct argument
    let source_connection_string = if let Some(s3_prefix) = &args.s3_prefix {
        let spec: Spec = {
            let spec_key = s3_prefix.append("/spec.json");
            let s3_client = aws_sdk_s3::Client::new(aws_config.as_ref().unwrap());
            let object = s3_client
                .get_object()
                .bucket(&spec_key.bucket)
                .key(spec_key.key)
                .send()
                .await
                .context("get spec from s3")?
                .body
                .collect()
                .await
                .context("download spec body")?;
            serde_json::from_slice(&object.into_bytes()).context("parse spec as json")?
        };

        match spec.encryption_secret {
            EncryptionSecret::KMS { key_id } => {
                let mut output = kms_client
                    .unwrap()
                    .decrypt()
                    .key_id(key_id)
                    .ciphertext_blob(aws_sdk_s3::primitives::Blob::new(
                        spec.source_connstring_ciphertext_base64,
                    ))
                    .send()
                    .await
                    .context("decrypt source connection string")?;
                let plaintext = output
                    .plaintext
                    .take()
                    .context("get plaintext source connection string")?;
                String::from_utf8(plaintext.into_inner())
                    .context("parse source connection string as utf8")?
            }
        }
    } else {
        args.source_connection_string.unwrap()
    };

    match tokio::fs::create_dir(&working_directory).await {
        Ok(()) => {}
        Err(e) if e.kind() == std::io::ErrorKind::AlreadyExists => {
            if !is_directory_empty(&working_directory)
                .await
                .context("check if working directory is empty")?
            {
                anyhow::bail!("working directory is not empty");
            } else {
                // ok
            }
        }
        Err(e) => return Err(anyhow::Error::new(e).context("create working directory")),
    }

    let pgdata_dir = working_directory.join("pgdata");
    tokio::fs::create_dir(&pgdata_dir)
async fn decode_connstring(
    kms_client: &aws_sdk_kms::Client,
    key_id: &String,
    connstring_ciphertext_base64: Vec<u8>,
) -> Result<String, anyhow::Error> {
    let mut output = kms_client
        .decrypt()
        .key_id(key_id)
        .ciphertext_blob(aws_sdk_s3::primitives::Blob::new(
            connstring_ciphertext_base64,
        ))
        .send()
        .await
        .context("create pgdata directory")?;
        .context("decrypt connection string")?;

    let pgbin = pg_bin_dir.join("postgres");
    let pg_version = match get_pg_version(pgbin.as_ref()) {
        PostgresMajorVersion::V14 => 14,
        PostgresMajorVersion::V15 => 15,
        PostgresMajorVersion::V16 => 16,
        PostgresMajorVersion::V17 => 17,
    };
    let superuser = "cloud_admin"; // XXX: this shouldn't be hard-coded
    postgres_initdb::do_run_initdb(postgres_initdb::RunInitdbArgs {
        superuser,
        locale: DEFAULT_LOCALE, // XXX: this shouldn't be hard-coded,
        pg_version,
        initdb_bin: pg_bin_dir.join("initdb").as_ref(),
        library_search_path: &pg_lib_dir, // TODO: is this right? Prob works in compute image, not sure about neon_local.
        pgdata: &pgdata_dir,
    })
    .await
    .context("initdb")?;
    let plaintext = output
        .plaintext
        .take()
        .context("get plaintext connection string")?;

    // If the caller didn't specify CPU / RAM to use for sizing, default to
    // number of CPUs in the system, and pretty arbitrarily, 256 MB of RAM.
    let nproc = args.num_cpus.unwrap_or_else(num_cpus::get);
    let memory_mb = args.memory_mb.unwrap_or(256);
    String::from_utf8(plaintext.into_inner()).context("parse connection string as utf8")
}

    // Somewhat arbitrarily, use 10 % of memory for shared buffer cache, 70% for
    // maintenance_work_mem (i.e. for sorting during index creation), and leave the rest
    // available for misc other stuff that PostgreSQL uses memory for.
    let shared_buffers_mb = ((memory_mb as f32) * 0.10) as usize;
    let maintenance_work_mem_mb = ((memory_mb as f32) * 0.70) as usize;
struct PostgresProcess {
    pgdata_dir: Utf8PathBuf,
    pg_bin_dir: Utf8PathBuf,
    pgbin: Utf8PathBuf,
    pg_lib_dir: Utf8PathBuf,
    postgres_proc: Option<tokio::process::Child>,
}

    //
    // Launch postgres process
    //
    let mut postgres_proc = tokio::process::Command::new(pgbin)
        .arg("-D")
        .arg(&pgdata_dir)
        .args(["-p", &format!("{pg_port}")])
        .args(["-c", "wal_level=minimal"])
        .args(["-c", &format!("shared_buffers={shared_buffers_mb}MB")])
        .args(["-c", "max_wal_senders=0"])
        .args(["-c", "fsync=off"])
        .args(["-c", "full_page_writes=off"])
        .args(["-c", "synchronous_commit=off"])
        .args([
            "-c",
            &format!("maintenance_work_mem={maintenance_work_mem_mb}MB"),
        ])
        .args(["-c", &format!("max_parallel_maintenance_workers={nproc}")])
        .args(["-c", &format!("max_parallel_workers={nproc}")])
        .args(["-c", &format!("max_parallel_workers_per_gather={nproc}")])
        .args(["-c", &format!("max_worker_processes={nproc}")])
        .args([
            "-c",
            &format!(
                "effective_io_concurrency={}",
                if cfg!(target_os = "macos") { 0 } else { 100 }
            ),
        ])
        .env_clear()
        .env("LD_LIBRARY_PATH", &pg_lib_dir)
        .env(
            "ASAN_OPTIONS",
            std::env::var("ASAN_OPTIONS").unwrap_or_default(),
impl PostgresProcess {
    fn new(pgdata_dir: Utf8PathBuf, pg_bin_dir: Utf8PathBuf, pg_lib_dir: Utf8PathBuf) -> Self {
        Self {
            pgdata_dir,
            pgbin: pg_bin_dir.join("postgres"),
            pg_bin_dir,
            pg_lib_dir,
            postgres_proc: None,
        }
    }

    async fn prepare(&self, initdb_user: &str) -> Result<(), anyhow::Error> {
        tokio::fs::create_dir(&self.pgdata_dir)
            .await
            .context("create pgdata directory")?;

        let pg_version = match get_pg_version(self.pgbin.as_ref()) {
            PostgresMajorVersion::V14 => 14,
            PostgresMajorVersion::V15 => 15,
            PostgresMajorVersion::V16 => 16,
            PostgresMajorVersion::V17 => 17,
        };
        postgres_initdb::do_run_initdb(postgres_initdb::RunInitdbArgs {
            superuser: initdb_user,
            locale: DEFAULT_LOCALE, // XXX: this shouldn't be hard-coded,
            pg_version,
            initdb_bin: self.pg_bin_dir.join("initdb").as_ref(),
            library_search_path: &self.pg_lib_dir, // TODO: is this right? Prob works in compute image, not sure about neon_local.
            pgdata: &self.pgdata_dir,
        })
        .await
        .context("initdb")
    }

    async fn start(
        &mut self,
        initdb_user: &str,
        port: u16,
        nproc: usize,
        memory_mb: usize,
    ) -> Result<&tokio::process::Child, anyhow::Error> {
        self.prepare(initdb_user).await?;

        // Somewhat arbitrarily, use 10 % of memory for shared buffer cache, 70% for
        // maintenance_work_mem (i.e. for sorting during index creation), and leave the rest
        // available for misc other stuff that PostgreSQL uses memory for.
        let shared_buffers_mb = ((memory_mb as f32) * 0.10) as usize;
        let maintenance_work_mem_mb = ((memory_mb as f32) * 0.70) as usize;

        //
        // Launch postgres process
        //
        let mut proc = tokio::process::Command::new(&self.pgbin)
            .arg("-D")
            .arg(&self.pgdata_dir)
            .args(["-p", &format!("{port}")])
            .args(["-c", "wal_level=minimal"])
            .args(["-c", &format!("shared_buffers={shared_buffers_mb}MB")])
            .args(["-c", "shared_buffers=10GB"])
            .args(["-c", "max_wal_senders=0"])
            .args(["-c", "fsync=off"])
            .args(["-c", "full_page_writes=off"])
            .args(["-c", "synchronous_commit=off"])
            .args([
                "-c",
                &format!("maintenance_work_mem={maintenance_work_mem_mb}MB"),
            ])
            .args(["-c", &format!("max_parallel_maintenance_workers={nproc}")])
            .args(["-c", &format!("max_parallel_workers={nproc}")])
            .args(["-c", &format!("max_parallel_workers_per_gather={nproc}")])
            .args(["-c", &format!("max_worker_processes={nproc}")])
            .args(["-c", "effective_io_concurrency=100"])
            .env_clear()
            .env("LD_LIBRARY_PATH", &self.pg_lib_dir)
            .env(
                "ASAN_OPTIONS",
                std::env::var("ASAN_OPTIONS").unwrap_or_default(),
            )
            .env(
                "UBSAN_OPTIONS",
                std::env::var("UBSAN_OPTIONS").unwrap_or_default(),
            )
            .stdout(std::process::Stdio::piped())
            .stderr(std::process::Stdio::piped())
            .spawn()
            .context("spawn postgres")?;

        info!("spawned postgres, waiting for it to become ready");
        tokio::spawn(
            child_stdio_to_log::relay_process_output(proc.stdout.take(), proc.stderr.take())
                .instrument(info_span!("postgres")),
        );

        self.postgres_proc = Some(proc);
        Ok(self.postgres_proc.as_ref().unwrap())
    }

    async fn shutdown(&mut self) -> Result<(), anyhow::Error> {
        let proc: &mut tokio::process::Child = self.postgres_proc.as_mut().unwrap();
        info!("shutdown postgres");
        nix::sys::signal::kill(
            Pid::from_raw(i32::try_from(proc.id().unwrap()).expect("convert child pid to i32")),
            nix::sys::signal::SIGTERM,
        )
        .env(
            "UBSAN_OPTIONS",
            std::env::var("UBSAN_OPTIONS").unwrap_or_default(),
        )
        .stdout(std::process::Stdio::piped())
        .stderr(std::process::Stdio::piped())
        .spawn()
        .context("spawn postgres")?;

        info!("spawned postgres, waiting for it to become ready");
        tokio::spawn(
            child_stdio_to_log::relay_process_output(
                postgres_proc.stdout.take(),
                postgres_proc.stderr.take(),
            )
            .instrument(info_span!("postgres")),
        );
        .context("signal postgres to shut down")?;
        proc.wait()
            .await
            .context("wait for postgres to shut down")
            .map(|_| ())
    }
}

async fn wait_until_ready(connstring: String, create_dbname: String) {
    // Create neondb database in the running postgres
    let restore_pg_connstring =
        format!("host=localhost port={pg_port} user={superuser} dbname=postgres");

    let start_time = std::time::Instant::now();

    loop {
@@ -289,7 +277,12 @@ pub(crate) async fn main() -> anyhow::Result<()> {
            std::process::exit(1);
        }

        match tokio_postgres::connect(&restore_pg_connstring, tokio_postgres::NoTls).await {
        match tokio_postgres::connect(
            &connstring.replace("dbname=neondb", "dbname=postgres"),
            tokio_postgres::NoTls,
        )
        .await
        {
            Ok((client, connection)) => {
                // Spawn the connection handling task to maintain the connection
                tokio::spawn(async move {
@@ -298,9 +291,12 @@ pub(crate) async fn main() -> anyhow::Result<()> {
                    }
                });

                match client.simple_query("CREATE DATABASE neondb;").await {
                match client
                    .simple_query(format!("CREATE DATABASE {create_dbname};").as_str())
                    .await
                {
                    Ok(_) => {
                        info!("created neondb database");
                        info!("created {} database", create_dbname);
                        break;
                    }
                    Err(e) => {
@@ -324,10 +320,16 @@ pub(crate) async fn main() -> anyhow::Result<()> {
            }
        }
    }
}

    let restore_pg_connstring = restore_pg_connstring.replace("dbname=postgres", "dbname=neondb");

    let dumpdir = working_directory.join("dumpdir");
async fn run_dump_restore(
    workdir: Utf8PathBuf,
    pg_bin_dir: Utf8PathBuf,
    pg_lib_dir: Utf8PathBuf,
    source_connstring: String,
    destination_connstring: String,
) -> Result<(), anyhow::Error> {
    let dumpdir = workdir.join("dumpdir");

    let common_args = [
        // schema mapping (prob suffices to specify them on one side)
@@ -356,7 +358,7 @@ pub(crate) async fn main() -> anyhow::Result<()> {
            .arg("--no-sync")
            // POSITIONAL args
            // source db (db name included in connection string)
            .arg(&source_connection_string)
            .arg(&source_connstring)
            // how we run it
            .env_clear()
            .env("LD_LIBRARY_PATH", &pg_lib_dir)
@@ -376,19 +378,18 @@ pub(crate) async fn main() -> anyhow::Result<()> {
        let st = pg_dump.wait().await.context("wait for pg_dump")?;
        info!(status=?st, "pg_dump exited");
        if !st.success() {
            warn!(status=%st, "pg_dump failed, restore will likely fail as well");
            error!(status=%st, "pg_dump failed, restore will likely fail as well");
            bail!("pg_dump failed");
        }
    }

    // TODO: do it in a streaming way, plenty of internal research done on this already
    // TODO: maybe do it in a streaming way, plenty of internal research done on this already
    // TODO: do the unlogged table trick

    info!("restore from working directory into vanilla postgres");
    {
        let mut pg_restore = tokio::process::Command::new(pg_bin_dir.join("pg_restore"))
            .args(&common_args)
            .arg("-d")
            .arg(&restore_pg_connstring)
            .arg(&destination_connstring)
            // POSITIONAL args
            .arg(&dumpdir)
            // how we run it
@@ -411,33 +412,82 @@ pub(crate) async fn main() -> anyhow::Result<()> {
        let st = pg_restore.wait().await.context("wait for pg_restore")?;
        info!(status=?st, "pg_restore exited");
        if !st.success() {
            warn!(status=%st, "pg_restore failed, restore will likely fail as well");
            error!(status=%st, "pg_restore failed, restore will likely fail as well");
            bail!("pg_restore failed");
        }
    }

    Ok(())
}

#[allow(clippy::too_many_arguments)]
async fn cmd_pgdata(
    kms_client: Option<aws_sdk_kms::Client>,
    maybe_s3_prefix: Option<s3_uri::S3Uri>,
    maybe_spec: Option<Spec>,
    source_connection_string: Option<String>,
    interactive: bool,
    pg_port: u16,
    workdir: Utf8PathBuf,
    pg_bin_dir: Utf8PathBuf,
    pg_lib_dir: Utf8PathBuf,
    num_cpus: Option<usize>,
    memory_mb: Option<usize>,
) -> Result<(), anyhow::Error> {
    if maybe_spec.is_none() && source_connection_string.is_none() {
        bail!("spec must be provided for pgdata command");
    }
    if maybe_spec.is_some() && source_connection_string.is_some() {
        bail!("only one of spec or source_connection_string can be provided");
    }

    let source_connection_string = if let Some(spec) = maybe_spec {
        match spec.encryption_secret {
            EncryptionSecret::KMS { key_id } => {
                decode_connstring(
                    kms_client.as_ref().unwrap(),
                    &key_id,
                    spec.source_connstring_ciphertext_base64,
                )
                .await?
            }
        }
    } else {
        source_connection_string.unwrap()
    };

    let superuser = "cloud_admin";
    let destination_connstring = format!(
        "host=localhost port={} user={} dbname=neondb",
        pg_port, superuser
    );

    let pgdata_dir = workdir.join("pgdata");
    let mut proc = PostgresProcess::new(pgdata_dir.clone(), pg_bin_dir.clone(), pg_lib_dir.clone());
    let nproc = num_cpus.unwrap_or_else(num_cpus::get);
    let memory_mb = memory_mb.unwrap_or(256);
    proc.start(superuser, pg_port, nproc, memory_mb).await?;
    wait_until_ready(destination_connstring.clone(), "neondb".to_string()).await;

    run_dump_restore(
        workdir.clone(),
        pg_bin_dir,
        pg_lib_dir,
        source_connection_string,
        destination_connstring,
    )
    .await?;

    // If interactive mode, wait for Ctrl+C
    if args.interactive {
    if interactive {
        info!("Running in interactive mode. Press Ctrl+C to shut down.");
        tokio::signal::ctrl_c().await.context("wait for ctrl-c")?;
    }

    info!("shutdown postgres");
    {
        nix::sys::signal::kill(
            Pid::from_raw(
                i32::try_from(postgres_proc.id().unwrap()).expect("convert child pid to i32"),
            ),
            nix::sys::signal::SIGTERM,
        )
        .context("signal postgres to shut down")?;
        postgres_proc
            .wait()
            .await
            .context("wait for postgres to shut down")?;
    }
    proc.shutdown().await?;

    // Only sync if s3_prefix was specified
    if let Some(s3_prefix) = args.s3_prefix {
    if let Some(s3_prefix) = maybe_s3_prefix {
        info!("upload pgdata");
        aws_s3_sync::sync(Utf8Path::new(&pgdata_dir), &s3_prefix.append("/pgdata/"))
            .await
@@ -445,7 +495,7 @@ pub(crate) async fn main() -> anyhow::Result<()> {

        info!("write status");
        {
            let status_dir = working_directory.join("status");
            let status_dir = workdir.join("status");
            std::fs::create_dir(&status_dir).context("create status directory")?;
            let status_file = status_dir.join("pgdata");
            std::fs::write(&status_file, serde_json::json!({"done": true}).to_string())
@@ -458,3 +508,153 @@ pub(crate) async fn main() -> anyhow::Result<()> {

    Ok(())
}

async fn cmd_dumprestore(
    kms_client: Option<aws_sdk_kms::Client>,
    maybe_spec: Option<Spec>,
    source_connection_string: Option<String>,
    destination_connection_string: Option<String>,
    workdir: Utf8PathBuf,
    pg_bin_dir: Utf8PathBuf,
    pg_lib_dir: Utf8PathBuf,
) -> Result<(), anyhow::Error> {
    let (source_connstring, destination_connstring) = if let Some(spec) = maybe_spec {
        match spec.encryption_secret {
            EncryptionSecret::KMS { key_id } => {
                let source = decode_connstring(
                    kms_client.as_ref().unwrap(),
                    &key_id,
                    spec.source_connstring_ciphertext_base64,
                )
                .await?;

                let dest = if let Some(dest_ciphertext) =
                    spec.destination_connstring_ciphertext_base64
                {
                    decode_connstring(kms_client.as_ref().unwrap(), &key_id, dest_ciphertext)
                        .await?
                } else {
                    bail!("destination connection string must be provided in spec for dump_restore command");
                };

                (source, dest)
            }
        }
    } else {
        (
            source_connection_string.unwrap(),
            if let Some(val) = destination_connection_string {
                val
            } else {
                bail!("destination connection string must be provided for dump_restore command");
            },
        )
    };

    run_dump_restore(
        workdir,
        pg_bin_dir,
        pg_lib_dir,
        source_connstring,
        destination_connstring,
    )
    .await
}

#[tokio::main]
pub(crate) async fn main() -> anyhow::Result<()> {
    utils::logging::init(
        utils::logging::LogFormat::Json,
        utils::logging::TracingErrorLayerEnablement::EnableWithRustLogFilter,
        utils::logging::Output::Stdout,
    )?;

    info!("starting");

    let args = Args::parse();

    // Initialize AWS clients only if s3_prefix is specified
    let (aws_config, kms_client) = if args.s3_prefix.is_some() {
        let config = aws_config::load_defaults(BehaviorVersion::v2024_03_28()).await;
        let kms = aws_sdk_kms::Client::new(&config);
        (Some(config), Some(kms))
    } else {
        (None, None)
    };

    let spec: Option<Spec> = if let Some(s3_prefix) = &args.s3_prefix {
        let spec_key = s3_prefix.append("/spec.json");
        let s3_client = aws_sdk_s3::Client::new(aws_config.as_ref().unwrap());
        let object = s3_client
            .get_object()
            .bucket(&spec_key.bucket)
            .key(spec_key.key)
            .send()
            .await
            .context("get spec from s3")?
            .body
            .collect()
            .await
            .context("download spec body")?;
        serde_json::from_slice(&object.into_bytes()).context("parse spec as json")?
    } else {
        None
    };

    match tokio::fs::create_dir(&args.working_directory).await {
        Ok(()) => {}
        Err(e) if e.kind() == std::io::ErrorKind::AlreadyExists => {
            if !is_directory_empty(&args.working_directory)
                .await
                .context("check if working directory is empty")?
            {
                bail!("working directory is not empty");
            } else {
                // ok
            }
        }
        Err(e) => return Err(anyhow::Error::new(e).context("create working directory")),
    }

    match args.command {
        Command::Pgdata {
            source_connection_string,
            interactive,
            pg_port,
            num_cpus,
            memory_mb,
        } => {
            cmd_pgdata(
                kms_client,
                args.s3_prefix,
                spec,
                source_connection_string,
                interactive,
                pg_port,
                args.working_directory,
                args.pg_bin_dir,
                args.pg_lib_dir,
                num_cpus,
                memory_mb,
            )
            .await?;
        }
        Command::DumpRestore {
            source_connection_string,
            destination_connection_string,
        } => {
            cmd_dumprestore(
                kms_client,
                spec,
                source_connection_string,
                destination_connection_string,
                args.working_directory,
                args.pg_bin_dir,
                args.pg_lib_dir,
            )
            .await?;
        }
    }

    Ok(())
}

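For the S3-driven path, the caller is expected to KMS-encrypt the connection strings, base64-encode the ciphertexts into a spec.json next to the import prefix, and point the binary at it. A sketch of producing such a spec, mirroring what the tests below do against moto (bucket, key id, and connection strings are placeholders):

```python
import base64
import json

import boto3

kms = boto3.client("kms")
s3 = boto3.client("s3")

KEY_ID = "00000000-0000-0000-0000-000000000000"  # placeholder KMS key id


def encrypt_b64(plaintext: str) -> str:
    # KMS-encrypt a connection string and base64-encode it for the JSON spec
    blob = kms.encrypt(KeyId=KEY_ID, Plaintext=plaintext)["CiphertextBlob"]
    return base64.b64encode(blob).decode("utf-8")


spec = {
    "encryption_secret": {"KMS": {"key_id": KEY_ID}},
    "source_connstring_ciphertext_base64": encrypt_b64("postgres://source/neondb"),
    # New in this PR, and optional: when present, dump-restore restores into it.
    "destination_connstring_ciphertext_base64": encrypt_b64("postgres://destination/testdb"),
}
s3.put_object(Bucket="some-bucket", Key="some-prefix/spec.json", Body=json.dumps(spec))
# Then: fast_import --s3-prefix=s3://some-bucket/some-prefix ... dump-restore
```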
poetry.lock (generated)
@@ -412,6 +412,7 @@ files = [

[package.dependencies]
botocore-stubs = "*"
mypy-boto3-kms = {version = ">=1.26.0,<1.27.0", optional = true, markers = "extra == \"kms\""}
mypy-boto3-s3 = {version = ">=1.26.0,<1.27.0", optional = true, markers = "extra == \"s3\""}
types-s3transfer = "*"
typing-extensions = ">=4.1.0"
@@ -2022,6 +2023,18 @@ install-types = ["pip"]
mypyc = ["setuptools (>=50)"]
reports = ["lxml"]

[[package]]
name = "mypy-boto3-kms"
version = "1.26.147"
description = "Type annotations for boto3.KMS 1.26.147 service generated with mypy-boto3-builder 7.14.5"
optional = false
python-versions = ">=3.7"
groups = ["main"]
files = [
    {file = "mypy-boto3-kms-1.26.147.tar.gz", hash = "sha256:816a4d1bb0585e1b9620a3f96c1d69a06f53b7b5621858579dd77c60dbb5fa5c"},
    {file = "mypy_boto3_kms-1.26.147-py3-none-any.whl", hash = "sha256:493f0db674a25c88769f5cb8ab8ac00d3dda5dfc903d5cda34c990ee64689f79"},
]

[[package]]
name = "mypy-boto3-s3"
version = "1.26.0.post1"
@@ -3807,4 +3820,4 @@ cffi = ["cffi (>=1.11)"]
[metadata]
lock-version = "2.1"
python-versions = "^3.11"
content-hash = "4dc3165fe22c0e0f7a030ea0f8a680ae2ff74561d8658c393abbe9112caaf5d7"
content-hash = "03697c0a4d438ef088b0d397b8f0570aa3998ccf833fe612400824792498878b"

@@ -17,7 +17,7 @@ Jinja2 = "^3.1.5"
types-requests = "^2.31.0.0"
types-psycopg2 = "^2.9.21.20241019"
boto3 = "^1.34.11"
boto3-stubs = {extras = ["s3"], version = "^1.26.16"}
boto3-stubs = {extras = ["s3", "kms"], version = "^1.26.16"}
moto = {extras = ["server"], version = "^5.0.6"}
backoff = "^2.2.1"
pytest-lazy-fixture = "^0.6.3"

@@ -4,8 +4,10 @@ import subprocess
import tempfile
from collections.abc import Iterator
from pathlib import Path
from typing import cast

import pytest
from _pytest.config import Config

from fixtures.log_helper import log
from fixtures.neon_cli import AbstractNeonCli
@@ -23,6 +25,7 @@ class FastImport(AbstractNeonCli):
        pg_distrib_dir: Path,
        pg_version: PgVersion,
        workdir: Path,
        cleanup: bool = True,
    ):
        if extra_env is None:
            env_vars = {}
@@ -47,12 +50,43 @@ class FastImport(AbstractNeonCli):
        if not workdir.exists():
            raise Exception(f"Working directory '{workdir}' does not exist")
        self.workdir = workdir
        self.cleanup = cleanup

    def run_pgdata(
        self,
        s3prefix: str | None = None,
        pg_port: int | None = None,
        source_connection_string: str | None = None,
        interactive: bool = False,
    ):
        return self.run(
            "pgdata",
            s3prefix=s3prefix,
            pg_port=pg_port,
            source_connection_string=source_connection_string,
            interactive=interactive,
        )

    def run_dump_restore(
        self,
        s3prefix: str | None = None,
        source_connection_string: str | None = None,
        destination_connection_string: str | None = None,
    ):
        return self.run(
            "dump-restore",
            s3prefix=s3prefix,
            source_connection_string=source_connection_string,
            destination_connection_string=destination_connection_string,
        )

    def run(
        self,
        pg_port: int,
        source_connection_string: str | None = None,
        command: str,
        s3prefix: str | None = None,
        pg_port: int | None = None,
        source_connection_string: str | None = None,
        destination_connection_string: str | None = None,
        interactive: bool = False,
    ) -> subprocess.CompletedProcess[str]:
        if self.cmd is not None:
@@ -60,13 +94,17 @@ class FastImport(AbstractNeonCli):
        args = [
            f"--pg-bin-dir={self.pg_bin}",
            f"--pg-lib-dir={self.pg_lib}",
            f"--pg-port={pg_port}",
            f"--working-directory={self.workdir}",
        ]
        if source_connection_string is not None:
            args.append(f"--source-connection-string={source_connection_string}")
        if s3prefix is not None:
            args.append(f"--s3-prefix={s3prefix}")
        args.append(command)
        if pg_port is not None:
            args.append(f"--pg-port={pg_port}")
        if source_connection_string is not None:
            args.append(f"--source-connection-string={source_connection_string}")
        if destination_connection_string is not None:
            args.append(f"--destination-connection-string={destination_connection_string}")
        if interactive:
            args.append("--interactive")

@@ -77,7 +115,7 @@ class FastImport(AbstractNeonCli):
        return self

    def __exit__(self, *args):
        if self.workdir.exists():
        if self.workdir.exists() and self.cleanup:
            shutil.rmtree(self.workdir)


@@ -87,9 +125,17 @@ def fast_import(
    test_output_dir: Path,
    neon_binpath: Path,
    pg_distrib_dir: Path,
    pytestconfig: Config,
) -> Iterator[FastImport]:
    workdir = Path(tempfile.mkdtemp())
    with FastImport(None, neon_binpath, pg_distrib_dir, pg_version, workdir) as fi:
    workdir = Path(tempfile.mkdtemp(dir=test_output_dir, prefix="fast_import_"))
    with FastImport(
        None,
        neon_binpath,
        pg_distrib_dir,
        pg_version,
        workdir,
        cleanup=not cast(bool, pytestconfig.getoption("--preserve-database-files")),
    ) as fi:
        yield fi

        if fi.cmd is None:

@@ -27,6 +27,7 @@ from urllib.parse import quote, urlparse

import asyncpg
import backoff
import boto3
import httpx
import psycopg2
import psycopg2.sql
@@ -37,6 +38,8 @@ from _pytest.config import Config
from _pytest.config.argparsing import Parser
from _pytest.fixtures import FixtureRequest
from jwcrypto import jwk
from mypy_boto3_kms import KMSClient
from mypy_boto3_s3 import S3Client

# Type-related stuff
from psycopg2.extensions import connection as PgConnection
@@ -199,6 +202,30 @@ def mock_s3_server(port_distributor: PortDistributor) -> Iterator[MockS3Server]:
    mock_s3_server.kill()


@pytest.fixture(scope="session")
def mock_kms(mock_s3_server: MockS3Server) -> Iterator[KMSClient]:
    yield boto3.client(
        "kms",
        endpoint_url=mock_s3_server.endpoint(),
        region_name=mock_s3_server.region(),
        aws_access_key_id=mock_s3_server.access_key(),
        aws_secret_access_key=mock_s3_server.secret_key(),
        aws_session_token=mock_s3_server.session_token(),
    )


@pytest.fixture(scope="session")
def mock_s3_client(mock_s3_server: MockS3Server) -> Iterator[S3Client]:
    yield boto3.client(
        "s3",
        endpoint_url=mock_s3_server.endpoint(),
        region_name=mock_s3_server.region(),
        aws_access_key_id=mock_s3_server.access_key(),
        aws_secret_access_key=mock_s3_server.secret_key(),
        aws_session_token=mock_s3_server.session_token(),
    )


class PgProtocol:
    """Reusable connection logic"""


@@ -1,7 +1,9 @@
import base64
import json
import re
import time
from enum import Enum
from pathlib import Path

import psycopg2
import psycopg2.errors
@@ -14,8 +16,12 @@ from fixtures.pageserver.http import (
    ImportPgdataIdemptencyKey,
    PageserverApiException,
)
from fixtures.pg_version import PgVersion
from fixtures.port_distributor import PortDistributor
from fixtures.remote_storage import RemoteStorageKind
from fixtures.remote_storage import MockS3Server, RemoteStorageKind
from mypy_boto3_kms import KMSClient
from mypy_boto3_kms.type_defs import EncryptResponseTypeDef
from mypy_boto3_s3 import S3Client
from pytest_httpserver import HTTPServer
from werkzeug.wrappers.request import Request
from werkzeug.wrappers.response import Response
@@ -103,13 +109,15 @@ def test_pgdata_import_smoke(
    while True:
        relblock_size = vanilla_pg.safe_psql_scalar("select pg_relation_size('t')")
        log.info(
            f"relblock size: {relblock_size/8192} pages (target: {target_relblock_size//8192}) pages"
            f"relblock size: {relblock_size / 8192} pages (target: {target_relblock_size // 8192}) pages"
        )
        if relblock_size >= target_relblock_size:
            break
        addrows = int((target_relblock_size - relblock_size) // 8192)
        assert addrows >= 1, "forward progress"
        vanilla_pg.safe_psql(f"insert into t select generate_series({nrows+1}, {nrows + addrows})")
        vanilla_pg.safe_psql(
            f"insert into t select generate_series({nrows + 1}, {nrows + addrows})"
        )
        nrows += addrows
    expect_nrows = nrows
    expect_sum = (
@@ -332,6 +340,224 @@ def test_pgdata_import_smoke(
    br_initdb_endpoint.safe_psql("select * from othertable")


def test_fast_import_with_pageserver_ingest(
    test_output_dir,
    vanilla_pg: VanillaPostgres,
    port_distributor: PortDistributor,
    fast_import: FastImport,
    pg_distrib_dir: Path,
    pg_version: PgVersion,
    mock_s3_server: MockS3Server,
    mock_kms: KMSClient,
    mock_s3_client: S3Client,
    neon_env_builder: NeonEnvBuilder,
    make_httpserver: HTTPServer,
):
    # Prepare KMS and S3
    key_response = mock_kms.create_key(
        Description="Test key",
        KeyUsage="ENCRYPT_DECRYPT",
        Origin="AWS_KMS",
    )
    key_id = key_response["KeyMetadata"]["KeyId"]

    def encrypt(x: str) -> EncryptResponseTypeDef:
        return mock_kms.encrypt(KeyId=key_id, Plaintext=x)

    # Start source postgres and ingest data
    vanilla_pg.start()
    vanilla_pg.safe_psql("CREATE TABLE foo (a int); INSERT INTO foo SELECT generate_series(1, 10);")

    # Setup pageserver and fake cplane for import progress
    def handler(request: Request) -> Response:
        log.info(f"control plane request: {request.json}")
        return Response(json.dumps({}), status=200)

    cplane_mgmt_api_server = make_httpserver
    cplane_mgmt_api_server.expect_request(re.compile(".*")).respond_with_handler(handler)

    neon_env_builder.enable_pageserver_remote_storage(RemoteStorageKind.MOCK_S3)
    env = neon_env_builder.init_start()

    env.pageserver.patch_config_toml_nonrecursive(
        {
            "import_pgdata_upcall_api": f"http://{cplane_mgmt_api_server.host}:{cplane_mgmt_api_server.port}/path/to/mgmt/api",
            # because import_pgdata code uses this endpoint, not the one in common remote storage config
            # TODO: maybe use common remote_storage config in pageserver?
            "import_pgdata_aws_endpoint_url": env.s3_mock_server.endpoint(),
        }
    )
    env.pageserver.stop()
    env.pageserver.start()

    # Encrypt connstrings and put spec into S3
    source_connstring_encrypted = encrypt(vanilla_pg.connstr())
    spec = {
        "encryption_secret": {"KMS": {"key_id": key_id}},
        "source_connstring_ciphertext_base64": base64.b64encode(
            source_connstring_encrypted["CiphertextBlob"]
        ).decode("utf-8"),
        "project_id": "someproject",
        "branch_id": "somebranch",
    }

    bucket = "test-bucket"
    key_prefix = "test-prefix"
    mock_s3_client.create_bucket(Bucket=bucket)
    mock_s3_client.put_object(Bucket=bucket, Key=f"{key_prefix}/spec.json", Body=json.dumps(spec))

    # Create timeline with import_pgdata
    tenant_id = TenantId.generate()
    env.storage_controller.tenant_create(tenant_id)

    timeline_id = TimelineId.generate()
    log.info("starting import")
    start = time.monotonic()

    idempotency = ImportPgdataIdemptencyKey.random()
    log.info(f"idempotency key {idempotency}")
    # TODO: teach neon_local CLI about the idempotency & 429 error so we can run inside the loop
    # and check for 429

    import_branch_name = "imported"
    env.storage_controller.timeline_create(
        tenant_id,
        {
            "new_timeline_id": str(timeline_id),
            "import_pgdata": {
                "idempotency_key": str(idempotency),
                "location": {
                    "AwsS3": {
                        "region": env.s3_mock_server.region(),
                        "bucket": bucket,
                        "key": key_prefix,
                    }
                },
            },
        },
    )
    env.neon_cli.mappings_map_branch(import_branch_name, tenant_id, timeline_id)

    # Run fast_import
    if fast_import.extra_env is None:
        fast_import.extra_env = {}
    fast_import.extra_env["AWS_ACCESS_KEY_ID"] = mock_s3_server.access_key()
    fast_import.extra_env["AWS_SECRET_ACCESS_KEY"] = mock_s3_server.secret_key()
    fast_import.extra_env["AWS_SESSION_TOKEN"] = mock_s3_server.session_token()
    fast_import.extra_env["AWS_REGION"] = mock_s3_server.region()
    fast_import.extra_env["AWS_ENDPOINT_URL"] = mock_s3_server.endpoint()
    fast_import.extra_env["RUST_LOG"] = "aws_config=debug,aws_sdk_kms=debug"
    pg_port = port_distributor.get_port()
    fast_import.run_pgdata(pg_port=pg_port, s3prefix=f"s3://{bucket}/{key_prefix}")
    vanilla_pg.stop()

    def validate_vanilla_equivalence(ep):
        res = ep.safe_psql("SELECT count(*), sum(a) FROM foo;", dbname="neondb")
        assert res[0] == (10, 55), f"got result: {res}"

    # Sanity check that data in pgdata is expected:
    pgbin = PgBin(test_output_dir, fast_import.pg_distrib_dir, fast_import.pg_version)
    with VanillaPostgres(
        fast_import.workdir / "pgdata", pgbin, pg_port, False
    ) as new_pgdata_vanilla_pg:
        new_pgdata_vanilla_pg.start()

        # database name and user are hardcoded in fast_import binary, and they are different from normal vanilla postgres
        conn = PgProtocol(dsn=f"postgresql://cloud_admin@localhost:{pg_port}/neondb")
        validate_vanilla_equivalence(conn)

    # Poll pageserver statuses in s3
    while True:
        locations = env.storage_controller.locate(tenant_id)
        active_count = 0
        for location in locations:
            shard_id = TenantShardId.parse(location["shard_id"])
            ps = env.get_pageserver(location["node_id"])
            try:
                detail = ps.http_client().timeline_detail(shard_id, timeline_id)
                log.info(f"timeline {tenant_id}/{timeline_id} detail: {detail}")
                state = detail["state"]
                log.info(f"shard {shard_id} state: {state}")
                if state == "Active":
                    active_count += 1
            except PageserverApiException as e:
                if e.status_code == 404:
                    log.info("not found, import is in progress")
                    continue
                elif e.status_code == 429:
                    log.info("import is in progress")
                    continue
                else:
                    raise

            if state == "Active":
                key = f"{key_prefix}/status/shard-{shard_id.shard_index}"
                shard_status_file_contents = (
                    mock_s3_client.get_object(Bucket=bucket, Key=key)["Body"].read().decode("utf-8")
                )
                shard_status = json.loads(shard_status_file_contents)
                assert shard_status["done"] is True

        if active_count == len(locations):
            log.info("all shards are active")
            break
        time.sleep(0.5)

    import_duration = time.monotonic() - start
    log.info(f"import complete; duration={import_duration:.2f}s")

    ep = env.endpoints.create_start(branch_name=import_branch_name, tenant_id=tenant_id)

    # check that data is there
    validate_vanilla_equivalence(ep)

    # check that we can do basic ops

    ep.safe_psql("create table othertable(values text)", dbname="neondb")
    rw_lsn = Lsn(ep.safe_psql_scalar("select pg_current_wal_flush_lsn()"))
    ep.stop()

    # ... at the tip
    _ = env.create_branch(
        new_branch_name="br-tip",
        ancestor_branch_name=import_branch_name,
        tenant_id=tenant_id,
        ancestor_start_lsn=rw_lsn,
    )
    br_tip_endpoint = env.endpoints.create_start(
        branch_name="br-tip", endpoint_id="br-tip-ro", tenant_id=tenant_id
    )
    validate_vanilla_equivalence(br_tip_endpoint)
    br_tip_endpoint.safe_psql("select * from othertable", dbname="neondb")
    br_tip_endpoint.stop()

    # ... at the initdb lsn
    locations = env.storage_controller.locate(tenant_id)
    [shard_zero] = [
        loc for loc in locations if TenantShardId.parse(loc["shard_id"]).shard_number == 0
    ]
    shard_zero_ps = env.get_pageserver(shard_zero["node_id"])
    shard_zero_timeline_info = shard_zero_ps.http_client().timeline_detail(
        shard_zero["shard_id"], timeline_id
    )
    initdb_lsn = Lsn(shard_zero_timeline_info["initdb_lsn"])
    _ = env.create_branch(
        new_branch_name="br-initdb",
        ancestor_branch_name=import_branch_name,
        tenant_id=tenant_id,
        ancestor_start_lsn=initdb_lsn,
    )
    br_initdb_endpoint = env.endpoints.create_start(
        branch_name="br-initdb", endpoint_id="br-initdb-ro", tenant_id=tenant_id
    )
    validate_vanilla_equivalence(br_initdb_endpoint)
    with pytest.raises(psycopg2.errors.UndefinedTable):
        br_initdb_endpoint.safe_psql("select * from othertable", dbname="neondb")
    br_initdb_endpoint.stop()

    env.pageserver.stop(immediate=True)


def test_fast_import_binary(
    test_output_dir,
    vanilla_pg: VanillaPostgres,
@@ -342,7 +568,7 @@ def test_fast_import_binary(
    vanilla_pg.safe_psql("CREATE TABLE foo (a int); INSERT INTO foo SELECT generate_series(1, 10);")

    pg_port = port_distributor.get_port()
    fast_import.run(pg_port, vanilla_pg.connstr())
    fast_import.run_pgdata(pg_port=pg_port, source_connection_string=vanilla_pg.connstr())
    vanilla_pg.stop()

    pgbin = PgBin(test_output_dir, fast_import.pg_distrib_dir, fast_import.pg_version)
@@ -358,6 +584,118 @@ def test_fast_import_binary(
    assert res[0][0] == 10


def test_fast_import_restore_to_connstring(
    test_output_dir,
    vanilla_pg: VanillaPostgres,
    port_distributor: PortDistributor,
    fast_import: FastImport,
    pg_distrib_dir: Path,
    pg_version: PgVersion,
):
    vanilla_pg.start()
    vanilla_pg.safe_psql("CREATE TABLE foo (a int); INSERT INTO foo SELECT generate_series(1, 10);")

    pgdatadir = test_output_dir / "destination-pgdata"
    pg_bin = PgBin(test_output_dir, pg_distrib_dir, pg_version)
    port = port_distributor.get_port()
    with VanillaPostgres(pgdatadir, pg_bin, port) as destination_vanilla_pg:
        destination_vanilla_pg.configure(["shared_preload_libraries='neon_rmgr'"])
        destination_vanilla_pg.start()

        # create another database & role and try to restore there
        destination_vanilla_pg.safe_psql("""
            CREATE ROLE testrole WITH
            LOGIN
            PASSWORD 'testpassword'
            NOSUPERUSER
            NOCREATEDB
            NOCREATEROLE;
        """)
        destination_vanilla_pg.safe_psql("CREATE DATABASE testdb OWNER testrole;")

        destination_connstring = destination_vanilla_pg.connstr(
            dbname="testdb", user="testrole", password="testpassword"
        )
        fast_import.run_dump_restore(
            source_connection_string=vanilla_pg.connstr(),
            destination_connection_string=destination_connstring,
        )
        vanilla_pg.stop()
        conn = PgProtocol(dsn=destination_connstring)
        res = conn.safe_psql("SELECT count(*) FROM foo;")
        log.info(f"Result: {res}")
        assert res[0][0] == 10


def test_fast_import_restore_to_connstring_from_s3_spec(
    test_output_dir,
    vanilla_pg: VanillaPostgres,
    port_distributor: PortDistributor,
    fast_import: FastImport,
    pg_distrib_dir: Path,
    pg_version: PgVersion,
    mock_s3_server: MockS3Server,
    mock_kms: KMSClient,
    mock_s3_client: S3Client,
):
    # Prepare KMS and S3
    key_response = mock_kms.create_key(
        Description="Test key",
        KeyUsage="ENCRYPT_DECRYPT",
        Origin="AWS_KMS",
    )
    key_id = key_response["KeyMetadata"]["KeyId"]

    def encrypt(x: str) -> EncryptResponseTypeDef:
        return mock_kms.encrypt(KeyId=key_id, Plaintext=x)

    # Start source postgres and ingest data
    vanilla_pg.start()
    vanilla_pg.safe_psql("CREATE TABLE foo (a int); INSERT INTO foo SELECT generate_series(1, 10);")

    # Start target postgres
    pgdatadir = test_output_dir / "destination-pgdata"
    pg_bin = PgBin(test_output_dir, pg_distrib_dir, pg_version)
    port = port_distributor.get_port()
    with VanillaPostgres(pgdatadir, pg_bin, port) as destination_vanilla_pg:
        destination_vanilla_pg.configure(["shared_preload_libraries='neon_rmgr'"])
        destination_vanilla_pg.start()

        # Encrypt connstrings and put spec into S3
        source_connstring_encrypted = encrypt(vanilla_pg.connstr())
        destination_connstring_encrypted = encrypt(destination_vanilla_pg.connstr())
        spec = {
            "encryption_secret": {"KMS": {"key_id": key_id}},
            "source_connstring_ciphertext_base64": base64.b64encode(
                source_connstring_encrypted["CiphertextBlob"]
            ).decode("utf-8"),
            "destination_connstring_ciphertext_base64": base64.b64encode(
                destination_connstring_encrypted["CiphertextBlob"]
            ).decode("utf-8"),
        }

        mock_s3_client.create_bucket(Bucket="test-bucket")
        mock_s3_client.put_object(
            Bucket="test-bucket", Key="test-prefix/spec.json", Body=json.dumps(spec)
        )

        # Run fast_import
        if fast_import.extra_env is None:
            fast_import.extra_env = {}
        fast_import.extra_env["AWS_ACCESS_KEY_ID"] = mock_s3_server.access_key()
        fast_import.extra_env["AWS_SECRET_ACCESS_KEY"] = mock_s3_server.secret_key()
        fast_import.extra_env["AWS_SESSION_TOKEN"] = mock_s3_server.session_token()
        fast_import.extra_env["AWS_REGION"] = mock_s3_server.region()
        fast_import.extra_env["AWS_ENDPOINT_URL"] = mock_s3_server.endpoint()
        fast_import.extra_env["RUST_LOG"] = "aws_config=debug,aws_sdk_kms=debug"
        fast_import.run_dump_restore(s3prefix="s3://test-bucket/test-prefix")
        vanilla_pg.stop()

        res = destination_vanilla_pg.safe_psql("SELECT count(*) FROM foo;")
        log.info(f"Result: {res}")
        assert res[0][0] == 10


# TODO: Maybe test with pageserver?
# 1. run whole neon env
# 2. create timeline with some s3 path???