Skip custom extensions in fast import

piercypixel
2025-05-13 11:39:33 +00:00
parent 5bdba70f7d
commit 4f4a96ea25
8 changed files with 152 additions and 31 deletions

.gitmodules

@@ -1,16 +1,16 @@
[submodule "vendor/postgres-v14"]
path = vendor/postgres-v14
url = https://github.com/neondatabase/postgres.git
-branch = REL_14_STABLE_neon
+branch = 28934-pg-dump-schema-no-create-v14
[submodule "vendor/postgres-v15"]
path = vendor/postgres-v15
url = https://github.com/neondatabase/postgres.git
-branch = REL_15_STABLE_neon
+branch = 28934-pg-dump-schema-no-create-v15
[submodule "vendor/postgres-v16"]
path = vendor/postgres-v16
url = https://github.com/neondatabase/postgres.git
-branch = REL_16_STABLE_neon
+branch = 28934-pg-dump-schema-no-create-v16
[submodule "vendor/postgres-v17"]
path = vendor/postgres-v17
url = https://github.com/neondatabase/postgres.git
-branch = REL_17_STABLE_neon
+branch = 28934-pg-dump-schema-no-create-v17

@@ -70,6 +70,14 @@ enum Command {
/// and maintenance_work_mem.
#[clap(long, env = "NEON_IMPORTER_MEMORY_MB")]
memory_mb: Option<usize>,
+/// List of schemas to dump.
+#[clap(long)]
+schema: Vec<String>,
+/// List of extensions to dump.
+#[clap(long)]
+extension: Vec<String>,
},
/// Runs pg_dump-pg_restore from source to destination without running local postgres.
@@ -82,6 +90,12 @@ enum Command {
/// real scenario uses encrypted connection string in spec.json from s3.
#[clap(long)]
destination_connection_string: Option<String>,
+/// List of schemas to dump.
+#[clap(long)]
+schema: Vec<String>,
+/// List of extensions to dump.
+#[clap(long)]
+extension: Vec<String>,
},
}
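Both subcommands rely on clap's derive behavior for `Vec<String>` fields: a `#[clap(long)]` flag backed by a `Vec` may be passed multiple times, once per value. A minimal, self-contained sketch of that behavior (the `importer` binary name and the argument values are illustrative, not the real CLI):

```rust
use clap::Parser;

#[derive(Parser, Debug)]
struct Args {
    /// May be given multiple times: --schema a --schema b
    #[clap(long)]
    schema: Vec<String>,
    /// May be given multiple times: --extension x --extension y
    #[clap(long)]
    extension: Vec<String>,
}

fn main() {
    // Simulated invocation; mirrors how the new flags above are collected.
    let args = Args::parse_from([
        "importer", "--schema", "custom", "--extension", "plpgsql", "--extension", "neon",
    ]);
    assert_eq!(args.schema, vec!["custom"]);
    assert_eq!(args.extension, vec!["plpgsql", "neon"]);
}
```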
@@ -117,6 +131,8 @@ struct Spec {
source_connstring_ciphertext_base64: Vec<u8>,
#[serde_as(as = "Option<serde_with::base64::Base64>")]
destination_connstring_ciphertext_base64: Option<Vec<u8>>,
+schemas: Option<Vec<String>>,
+extensions: Option<Vec<String>>,
}
#[derive(serde::Deserialize)]
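The new `schemas` and `extensions` fields are `Option<Vec<String>>`, so older spec.json payloads that omit the keys still deserialize: serde leaves a missing `Option` field as `None`, which the call sites below turn into an empty list via `unwrap_or`. A small sketch under that assumption; `SpecSubset` is a hypothetical stand-in for the full `Spec` struct:

```rust
use serde::Deserialize;

#[derive(Deserialize, Debug)]
struct SpecSubset {
    schemas: Option<Vec<String>>,
    extensions: Option<Vec<String>>,
}

fn main() {
    // A legacy spec without the new keys still parses; both fields are None.
    let legacy: SpecSubset = serde_json::from_str("{}").unwrap();
    assert!(legacy.schemas.is_none() && legacy.extensions.is_none());

    // A new-style spec, shaped like the one in the test below.
    let json = r#"{"schemas": ["custom"], "extensions": ["plpgsql", "neon"]}"#;
    let spec: SpecSubset = serde_json::from_str(json).unwrap();
    assert_eq!(spec.schemas.unwrap_or(vec![]), vec!["custom"]);
}
```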
@@ -337,6 +353,8 @@ async fn run_dump_restore(
pg_lib_dir: Utf8PathBuf,
source_connstring: String,
destination_connstring: String,
+schemas: Vec<String>,
+extensions: Vec<String>,
) -> Result<(), anyhow::Error> {
let dumpdir = workdir.join("dumpdir");
let num_jobs = num_cpus::get().to_string();
@@ -351,6 +369,7 @@ async fn run_dump_restore(
"--no-subscriptions".to_string(),
"--no-tablespaces".to_string(),
"--no-event-triggers".to_string(),
"--enable-row-security".to_string(),
// format
"--format".to_string(),
"directory".to_string(),
@@ -361,10 +380,36 @@ async fn run_dump_restore(
"--verbose".to_string(),
];
+let mut pg_dump_args = vec![
+// listing plpgsql ensures unsupported extensions are excluded from the dump,
+// even when no supported extensions are specified explicitly
+"--extension".to_string(),
+"plpgsql".to_string(),
+];
+// if no schemas are specified, pg_dump dumps all schemas by default
+if !schemas.is_empty() {
+// always include objects from the public schema, but never emit CREATE SCHEMA
+// for it: it already exists in every Postgres cluster by default
+pg_dump_args.push("--schema-no-create".to_string());
+pg_dump_args.push("public".to_string());
+for schema in &schemas {
+pg_dump_args.push("--schema".to_string());
+pg_dump_args.push(schema.clone());
+}
+}
+for extension in &extensions {
+pg_dump_args.push("--extension".to_string());
+pg_dump_args.push(extension.clone());
+}
info!("dump into the working directory");
{
let mut pg_dump = tokio::process::Command::new(pg_bin_dir.join("pg_dump"))
.args(&common_args)
+.args(&pg_dump_args)
.arg("-f")
.arg(&dumpdir)
.arg("--no-sync")
@@ -455,6 +500,8 @@ async fn cmd_pgdata(
maybe_s3_prefix: Option<s3_uri::S3Uri>,
maybe_spec: Option<Spec>,
source_connection_string: Option<String>,
+schemas: Vec<String>,
+extensions: Vec<String>,
interactive: bool,
pg_port: u16,
workdir: Utf8PathBuf,
@@ -470,19 +517,25 @@ async fn cmd_pgdata(
bail!("only one of spec or source_connection_string can be provided");
}
-let source_connection_string = if let Some(spec) = maybe_spec {
+let (source_connection_string, schemas, extensions) = if let Some(spec) = maybe_spec {
match spec.encryption_secret {
EncryptionSecret::KMS { key_id } => {
-decode_connstring(
+let schemas = spec.schemas.unwrap_or(vec![]);
+let extensions = spec.extensions.unwrap_or(vec![]);
+let source = decode_connstring(
kms_client.as_ref().unwrap(),
&key_id,
spec.source_connstring_ciphertext_base64,
)
-.await?
+.await
+.context("decrypt source connection string")?;
+(source, schemas, extensions)
}
}
} else {
-source_connection_string.unwrap()
+(source_connection_string.unwrap(), schemas, extensions)
};
let superuser = "cloud_admin";
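Replacing the bare `.await?` with `.await.context(...)?` uses anyhow's `Context` trait to prepend a stage description to whatever the KMS call fails with, so the failure mode is identifiable in logs. A minimal sketch with a hypothetical failing stub in place of `decode_connstring`:

```rust
use anyhow::Context;

// Hypothetical stand-in for decode_connstring; always fails so the
// wrapped error can be shown.
async fn decode_connstring_stub() -> anyhow::Result<String> {
    Err(anyhow::anyhow!("KMS: access denied"))
}

#[tokio::main]
async fn main() {
    let err = decode_connstring_stub()
        .await
        .context("decrypt source connection string")
        .unwrap_err();
    // "{:#}" renders the context plus its cause on one line:
    // decrypt source connection string: KMS: access denied
    println!("{:#}", err);
}
```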
@@ -504,6 +557,8 @@ async fn cmd_pgdata(
pg_lib_dir,
source_connection_string,
destination_connstring,
+schemas,
+extensions,
)
.await?;
@@ -546,18 +601,26 @@ async fn cmd_pgdata(
Ok(())
}
+#[allow(clippy::too_many_arguments)]
async fn cmd_dumprestore(
kms_client: Option<aws_sdk_kms::Client>,
maybe_spec: Option<Spec>,
source_connection_string: Option<String>,
destination_connection_string: Option<String>,
+schemas: Vec<String>,
+extensions: Vec<String>,
workdir: Utf8PathBuf,
pg_bin_dir: Utf8PathBuf,
pg_lib_dir: Utf8PathBuf,
) -> Result<(), anyhow::Error> {
-let (source_connstring, destination_connstring) = if let Some(spec) = maybe_spec {
+let (source_connstring, destination_connstring, schemas, extensions) = if let Some(spec) =
+maybe_spec
+{
match spec.encryption_secret {
EncryptionSecret::KMS { key_id } => {
+let schemas = spec.schemas.unwrap_or(vec![]);
+let extensions = spec.extensions.unwrap_or(vec![]);
let source = decode_connstring(
kms_client.as_ref().unwrap(),
&key_id,
@@ -578,18 +641,17 @@ async fn cmd_dumprestore(
);
};
-(source, dest)
+(source, dest, schemas, extensions)
}
}
} else {
-(
-source_connection_string.unwrap(),
-if let Some(val) = destination_connection_string {
-val
-} else {
-bail!("destination connection string must be provided for dump_restore command");
-},
-)
+let dest = if let Some(val) = destination_connection_string {
+val
+} else {
+bail!("destination connection string must be provided for dump_restore command");
+};
+(source_connection_string.unwrap(), dest, schemas, extensions)
};
run_dump_restore(
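The refactored else-branch hoists the destination check into its own `let` so the widened tuple stays readable. As an aside, anyhow's `Context` is also implemented for `Option`, so an equivalent formulation (a hedged alternative sketch, not what the commit does) collapses the if/else plus `bail!` into one expression:

```rust
use anyhow::Context;

// Alternative sketch: Option::context turns None into an anyhow error
// carrying this message, like the bail! branch above.
fn resolve_dest(destination: Option<String>) -> anyhow::Result<String> {
    destination.context("destination connection string must be provided for dump_restore command")
}

fn main() {
    assert!(resolve_dest(None).is_err());
    assert_eq!(
        resolve_dest(Some("postgres://dest".into())).unwrap(),
        "postgres://dest"
    );
}
```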
@@ -598,6 +660,8 @@ async fn cmd_dumprestore(
pg_lib_dir,
source_connstring,
destination_connstring,
+schemas,
+extensions,
)
.await
}
@@ -679,6 +743,8 @@ pub(crate) async fn main() -> anyhow::Result<()> {
pg_port,
num_cpus,
memory_mb,
+schema,
+extension,
} => {
cmd_pgdata(
s3_client.as_ref(),
@@ -686,6 +752,8 @@ pub(crate) async fn main() -> anyhow::Result<()> {
args.s3_prefix.clone(),
spec,
source_connection_string,
+schema,
+extension,
interactive,
pg_port,
args.working_directory.clone(),
@@ -699,12 +767,16 @@ pub(crate) async fn main() -> anyhow::Result<()> {
Command::DumpRestore {
source_connection_string,
destination_connection_string,
+schema,
+extension,
} => {
cmd_dumprestore(
kms_client,
spec,
source_connection_string,
destination_connection_string,
+schema,
+extension,
args.working_directory.clone(),
args.pg_bin_dir,
args.pg_lib_dir,

@@ -9,7 +9,6 @@ from pathlib import Path
from threading import Event
import psycopg2
import psycopg2.errors
import pytest
from fixtures.common_types import Lsn, TenantId, TenantShardId, TimelineId
from fixtures.fast_import import (
@@ -1070,15 +1069,41 @@ def test_fast_import_restore_to_connstring_from_s3_spec(
return mock_kms.encrypt(KeyId=key_id, Plaintext=x)
# Start source postgres and ingest data
+vanilla_pg.configure(["shared_preload_libraries='neon,neon_utils,neon_rmgr'"])
vanilla_pg.start()
-vanilla_pg.safe_psql("CREATE TABLE foo (a int); INSERT INTO foo SELECT generate_series(1, 10);")
+res = vanilla_pg.safe_psql("SHOW shared_preload_libraries;")
+log.info(f"shared_preload_libraries: {res}")
+res = vanilla_pg.safe_psql("SELECT name FROM pg_available_extensions;")
+log.info(f"pg_available_extensions: {res}")
+res = vanilla_pg.safe_psql("SELECT extname FROM pg_extension;")
+log.info(f"pg_extension: {res}")
+# Create several extensions; only the selected ones should be dumped
+vanilla_pg.safe_psql("CREATE EXTENSION neon;")
+vanilla_pg.safe_psql("CREATE EXTENSION neon_utils;")
+vanilla_pg.safe_psql("CREATE EXTENSION pg_visibility;")
+# The default (public) schema is always dumped
+vanilla_pg.safe_psql(
+"CREATE TABLE public.foo (a int); INSERT INTO public.foo SELECT generate_series(1, 7);"
+)
+# Create several schemas; only the selected ones should be dumped
+vanilla_pg.safe_psql("CREATE SCHEMA custom;")
+vanilla_pg.safe_psql(
+"CREATE TABLE custom.foo (a int); INSERT INTO custom.foo SELECT generate_series(1, 13);"
+)
+vanilla_pg.safe_psql("CREATE SCHEMA other;")
+vanilla_pg.safe_psql(
+"CREATE TABLE other.foo (a int); INSERT INTO other.foo SELECT generate_series(1, 42);"
+)
# Start target postgres
pgdatadir = test_output_dir / "destination-pgdata"
pg_bin = PgBin(test_output_dir, pg_distrib_dir, pg_version)
port = port_distributor.get_port()
with VanillaPostgres(pgdatadir, pg_bin, port) as destination_vanilla_pg:
-destination_vanilla_pg.configure(["shared_preload_libraries='neon_rmgr'"])
+destination_vanilla_pg.configure(["shared_preload_libraries='neon,neon_utils,neon_rmgr'"])
destination_vanilla_pg.start()
# Encrypt connstrings and put spec into S3
@@ -1092,6 +1117,8 @@ def test_fast_import_restore_to_connstring_from_s3_spec(
"destination_connstring_ciphertext_base64": base64.b64encode(
destination_connstring_encrypted["CiphertextBlob"]
).decode("utf-8"),
"schemas": ["custom"],
"extensions": ["plpgsql", "neon"],
}
bucket = "test-bucket"
@@ -1117,9 +1144,31 @@ def test_fast_import_restore_to_connstring_from_s3_spec(
}, f"got status: {job_status}"
vanilla_pg.stop()
-res = destination_vanilla_pg.safe_psql("SELECT count(*) FROM foo;")
+res = destination_vanilla_pg.safe_psql("SELECT count(*) FROM public.foo;")
log.info(f"Result: {res}")
-assert res[0][0] == 10
+assert res[0][0] == 7
+res = destination_vanilla_pg.safe_psql("SELECT count(*) FROM custom.foo;")
+log.info(f"Result: {res}")
+assert res[0][0] == 13
+# Check that the non-selected schema is not restored
+with pytest.raises(psycopg2.errors.UndefinedTable):
+    destination_vanilla_pg.safe_psql("SELECT count(*) FROM other.foo;")
+# Check that the skipped schema is absent from pg_namespace
+res = destination_vanilla_pg.safe_psql("SELECT nspname FROM pg_namespace;")
+log.info(f"Result: {res}")
+schemas = [row[0] for row in res]
+assert "other" not in schemas
+# Check that only the selected extensions are restored
+res = destination_vanilla_pg.safe_psql("SELECT extname FROM pg_extension;")
+log.info(f"Result: {res}")
+assert len(res) == 2
+extensions = set([str(row[0]) for row in res])
+assert "plpgsql" in extensions
+assert "neon" in extensions
def test_fast_import_restore_to_connstring_error_to_s3_bad_destination(

@@ -1,18 +1,18 @@
{
"v17": [
"17.5",
-"8be779fd3ab9e87206da96a7e4842ef1abf04f44"
+"10c002910447b3138e13213befca662df7cbe1d0"
],
"v16": [
"16.9",
-"0bf96bd6d70301a0b43b0b3457bb3cf8fb43c198"
+"94ad7e11cd43cce32f5af5674af29b3f334551a7"
],
"v15": [
"15.13",
-"de7640f55da07512834d5cc40c4b3fb376b5f04f"
+"cd0b534a761c18d8ef4654d4f749c63c5663215f"
],
"v14": [
"14.18",
-"55c0d45abe6467c02084c2192bca117eda6ce1e7"
+"b1e9959858f0529ea33d3cc5e833c0acc43f583a"
]
}