fast_import: put job status to s3 (#11284)

## Problem

The `fast_import` binary runs inside NeonVMs, which do not currently support proper
`kubectl describe`/`kubectl logs` access, and there are a number of other caveats as
well: https://github.com/neondatabase/autoscaling/issues/1320

We needed a signal for whether a job finished successfully and, if not, at least some
error message for the cplane operation. After [a short
discussion](https://neondb.slack.com/archives/C07PG8J1L0P/p1741954251813609),
an S3 object turned out to be the most convenient option at the moment.

## Summary of changes

If `s3_prefix` is provided to the `fast_import` invocation, every job run now writes a
status object to `{s3_prefix}/status/fast_import`, with contents like
`{"command": "pgdata", "done": true}` on success or
`{"command": "dump-restore", "done": false, "error": "..."}` on failure. Tests were
added as well.
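
For illustration, a minimal sketch of how a consumer (for example, cplane) could read this status object. The use of boto3, the helper name, and the bucket/prefix values are assumptions for the example, not part of this change:

```python
import json

import boto3  # assumed consumer-side dependency, mirroring what the tests do with mock_s3_client


def read_fast_import_status(bucket: str, prefix: str) -> dict:
    """Fetch and parse {prefix}/status/fast_import written by a fast_import run."""
    s3 = boto3.client("s3")
    obj = s3.get_object(Bucket=bucket, Key=f"{prefix}/status/fast_import")
    status = json.loads(obj["Body"].read().decode("utf-8"))
    # Expected shapes per this change:
    #   {"command": "pgdata", "done": true}
    #   {"command": "dump-restore", "done": false, "error": "..."}
    return status


# Hypothetical usage:
# status = read_fast_import_status("import-bucket", "imports/job-1234")
# if not status["done"]:
#     print("fast_import failed:", status.get("error"))
```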
Gleb Novikov, committed by GitHub
2025-03-20 15:23:35 +00:00
parent 3da70abfa5, commit 2065074559
3 changed files with 275 additions and 95 deletions


@@ -31,6 +31,7 @@ use camino::{Utf8Path, Utf8PathBuf};
use clap::{Parser, Subcommand};
use compute_tools::extension_server::{PostgresMajorVersion, get_pg_version};
use nix::unistd::Pid;
+use std::ops::Not;
use tracing::{Instrument, error, info, info_span, warn};
use utils::fs_ext::is_directory_empty;
@@ -44,7 +45,7 @@ mod s3_uri;
const PG_WAIT_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(600);
const PG_WAIT_RETRY_INTERVAL: std::time::Duration = std::time::Duration::from_millis(300);
-#[derive(Subcommand, Debug)]
+#[derive(Subcommand, Debug, Clone, serde::Serialize)]
enum Command {
/// Runs local postgres (neon binary), restores into it,
/// uploads pgdata to s3 to be consumed by pageservers
@@ -84,6 +85,15 @@ enum Command {
},
}
+impl Command {
+    fn as_str(&self) -> &'static str {
+        match self {
+            Command::Pgdata { .. } => "pgdata",
+            Command::DumpRestore { .. } => "dump-restore",
+        }
+    }
+}
#[derive(clap::Parser)]
struct Args {
#[clap(long, env = "NEON_IMPORTER_WORKDIR")]
@@ -437,7 +447,7 @@ async fn run_dump_restore(
#[allow(clippy::too_many_arguments)]
async fn cmd_pgdata(
-    s3_client: Option<aws_sdk_s3::Client>,
+    s3_client: Option<&aws_sdk_s3::Client>,
kms_client: Option<aws_sdk_kms::Client>,
maybe_s3_prefix: Option<s3_uri::S3Uri>,
maybe_spec: Option<Spec>,
@@ -506,14 +516,14 @@ async fn cmd_pgdata(
if let Some(s3_prefix) = maybe_s3_prefix {
info!("upload pgdata");
aws_s3_sync::upload_dir_recursive(
-            s3_client.as_ref().unwrap(),
+            s3_client.unwrap(),
Utf8Path::new(&pgdata_dir),
&s3_prefix.append("/pgdata/"),
)
.await
.context("sync dump directory to destination")?;
info!("write status");
info!("write pgdata status to s3");
{
let status_dir = workdir.join("status");
std::fs::create_dir(&status_dir).context("create status directory")?;
@@ -550,13 +560,15 @@ async fn cmd_dumprestore(
&key_id,
spec.source_connstring_ciphertext_base64,
)
-        .await?;
+        .await
+        .context("decrypt source connection string")?;
let dest = if let Some(dest_ciphertext) =
spec.destination_connstring_ciphertext_base64
{
decode_connstring(kms_client.as_ref().unwrap(), &key_id, dest_ciphertext)
-                .await?
+                .await
+                .context("decrypt destination connection string")?
} else {
bail!(
"destination connection string must be provided in spec for dump_restore command"
@@ -601,7 +613,18 @@ pub(crate) async fn main() -> anyhow::Result<()> {
// Initialize AWS clients only if s3_prefix is specified
let (s3_client, kms_client) = if args.s3_prefix.is_some() {
-        let config = aws_config::load_defaults(BehaviorVersion::v2024_03_28()).await;
+        // Create AWS config with enhanced retry settings
+        let config = aws_config::defaults(BehaviorVersion::v2024_03_28())
+            .retry_config(
+                aws_config::retry::RetryConfig::standard()
+                    .with_max_attempts(5) // Retry up to 5 times
+                    .with_initial_backoff(std::time::Duration::from_millis(200)) // Start with 200ms delay
+                    .with_max_backoff(std::time::Duration::from_secs(5)), // Cap at 5 seconds
+            )
+            .load()
+            .await;
+        // Create clients from the config with enhanced retry settings
let s3_client = aws_sdk_s3::Client::new(&config);
let kms = aws_sdk_kms::Client::new(&config);
(Some(s3_client), Some(kms))
@@ -609,79 +632,108 @@ pub(crate) async fn main() -> anyhow::Result<()> {
(None, None)
};
-    let spec: Option<Spec> = if let Some(s3_prefix) = &args.s3_prefix {
-        let spec_key = s3_prefix.append("/spec.json");
-        let object = s3_client
-            .as_ref()
-            .unwrap()
-            .get_object()
-            .bucket(&spec_key.bucket)
-            .key(spec_key.key)
-            .send()
-            .await
-            .context("get spec from s3")?
-            .body
-            .collect()
-            .await
-            .context("download spec body")?;
-        serde_json::from_slice(&object.into_bytes()).context("parse spec as json")?
-    } else {
-        None
-    };
-    match tokio::fs::create_dir(&args.working_directory).await {
-        Ok(()) => {}
-        Err(e) if e.kind() == std::io::ErrorKind::AlreadyExists => {
-            if !is_directory_empty(&args.working_directory)
-                .await
-                .context("check if working directory is empty")?
-            {
-                bail!("working directory is not empty");
-            } else {
-                // ok
-            }
-        }
-        Err(e) => return Err(anyhow::Error::new(e).context("create working directory")),
-    }
-    match args.command {
-        Command::Pgdata {
-            source_connection_string,
-            interactive,
-            pg_port,
-            num_cpus,
-            memory_mb,
-        } => {
-            cmd_pgdata(
-                s3_client,
-                kms_client,
-                args.s3_prefix,
-                spec,
-                source_connection_string,
-                interactive,
-                pg_port,
-                args.working_directory,
-                args.pg_bin_dir,
-                args.pg_lib_dir,
-                num_cpus,
-                memory_mb,
-            )
-            .await?;
-        }
-        Command::DumpRestore {
-            source_connection_string,
-            destination_connection_string,
-        } => {
-            cmd_dumprestore(
-                kms_client,
-                spec,
-                source_connection_string,
-                destination_connection_string,
-                args.working_directory,
-                args.pg_bin_dir,
-                args.pg_lib_dir,
-            )
-            .await?;
-        }
-    }
+    // Capture everything from spec assignment onwards to handle errors
+    let res = async {
+        let spec: Option<Spec> = if let Some(s3_prefix) = &args.s3_prefix {
+            let spec_key = s3_prefix.append("/spec.json");
+            let object = s3_client
+                .as_ref()
+                .unwrap()
+                .get_object()
+                .bucket(&spec_key.bucket)
+                .key(spec_key.key)
+                .send()
+                .await
+                .context("get spec from s3")?
+                .body
+                .collect()
+                .await
+                .context("download spec body")?;
+            serde_json::from_slice(&object.into_bytes()).context("parse spec as json")?
+        } else {
+            None
+        };
+        match tokio::fs::create_dir(&args.working_directory).await {
+            Ok(()) => {}
+            Err(e) if e.kind() == std::io::ErrorKind::AlreadyExists => {
+                if !is_directory_empty(&args.working_directory)
+                    .await
+                    .context("check if working directory is empty")?
+                {
+                    bail!("working directory is not empty");
+                } else {
+                    // ok
+                }
+            }
+            Err(e) => return Err(anyhow::Error::new(e).context("create working directory")),
+        }
+        match args.command.clone() {
+            Command::Pgdata {
+                source_connection_string,
+                interactive,
+                pg_port,
+                num_cpus,
+                memory_mb,
+            } => {
+                cmd_pgdata(
+                    s3_client.as_ref(),
+                    kms_client,
+                    args.s3_prefix.clone(),
+                    spec,
+                    source_connection_string,
+                    interactive,
+                    pg_port,
+                    args.working_directory.clone(),
+                    args.pg_bin_dir,
+                    args.pg_lib_dir,
+                    num_cpus,
+                    memory_mb,
+                )
+                .await
+            }
+            Command::DumpRestore {
+                source_connection_string,
+                destination_connection_string,
+            } => {
+                cmd_dumprestore(
+                    kms_client,
+                    spec,
+                    source_connection_string,
+                    destination_connection_string,
+                    args.working_directory.clone(),
+                    args.pg_bin_dir,
+                    args.pg_lib_dir,
+                )
+                .await
+            }
+        }
+    }
+    .await;
+    if let Some(s3_prefix) = args.s3_prefix {
+        info!("write job status to s3");
+        {
+            let status_dir = args.working_directory.join("status");
+            if std::fs::exists(&status_dir)?.not() {
+                std::fs::create_dir(&status_dir).context("create status directory")?;
+            }
+            let status_file = status_dir.join("fast_import");
+            let res_obj = match res {
+                Ok(_) => serde_json::json!({"command": args.command.as_str(), "done": true}),
+                Err(err) => {
+                    serde_json::json!({"command": args.command.as_str(), "done": false, "error": err.to_string()})
+                }
+            };
+            std::fs::write(&status_file, res_obj.to_string()).context("write status file")?;
+            aws_s3_sync::upload_dir_recursive(
+                s3_client.as_ref().unwrap(),
+                &status_dir,
+                &s3_prefix.append("/status/"),
+            )
+            .await
+            .context("sync status directory to destination")?;
+        }
+    }


@@ -12,6 +12,7 @@ from _pytest.config import Config
from fixtures.log_helper import log
from fixtures.neon_cli import AbstractNeonCli
from fixtures.pg_version import PgVersion
+from fixtures.remote_storage import MockS3Server
class FastImport(AbstractNeonCli):
@@ -111,6 +112,18 @@ class FastImport(AbstractNeonCli):
self.cmd = self.raw_cli(args)
return self.cmd
+    def set_aws_creds(self, mock_s3_server: MockS3Server, extra_env: dict[str, str] | None = None):
+        if self.extra_env is None:
+            self.extra_env = {}
+        self.extra_env["AWS_ACCESS_KEY_ID"] = mock_s3_server.access_key()
+        self.extra_env["AWS_SECRET_ACCESS_KEY"] = mock_s3_server.secret_key()
+        self.extra_env["AWS_SESSION_TOKEN"] = mock_s3_server.session_token()
+        self.extra_env["AWS_REGION"] = mock_s3_server.region()
+        self.extra_env["AWS_ENDPOINT_URL"] = mock_s3_server.endpoint()
+        if extra_env is not None:
+            self.extra_env.update(extra_env)
def __enter__(self):
return self


@@ -158,6 +158,7 @@ def test_pgdata_import_smoke(
statusdir = importbucket / "status"
statusdir.mkdir()
(statusdir / "pgdata").write_text(json.dumps({"done": True}))
(statusdir / "fast_import").write_text(json.dumps({"command": "pgdata", "done": True}))
#
# Do the import
@@ -439,16 +440,23 @@ def test_fast_import_with_pageserver_ingest(
env.neon_cli.mappings_map_branch(import_branch_name, tenant_id, timeline_id)
# Run fast_import
-    if fast_import.extra_env is None:
-        fast_import.extra_env = {}
-    fast_import.extra_env["AWS_ACCESS_KEY_ID"] = mock_s3_server.access_key()
-    fast_import.extra_env["AWS_SECRET_ACCESS_KEY"] = mock_s3_server.secret_key()
-    fast_import.extra_env["AWS_SESSION_TOKEN"] = mock_s3_server.session_token()
-    fast_import.extra_env["AWS_REGION"] = mock_s3_server.region()
-    fast_import.extra_env["AWS_ENDPOINT_URL"] = mock_s3_server.endpoint()
-    fast_import.extra_env["RUST_LOG"] = "aws_config=debug,aws_sdk_kms=debug"
+    fast_import.set_aws_creds(mock_s3_server, {"RUST_LOG": "aws_config=debug,aws_sdk_kms=debug"})
pg_port = port_distributor.get_port()
fast_import.run_pgdata(pg_port=pg_port, s3prefix=f"s3://{bucket}/{key_prefix}")
pgdata_status_obj = mock_s3_client.get_object(Bucket=bucket, Key=f"{key_prefix}/status/pgdata")
pgdata_status = pgdata_status_obj["Body"].read().decode("utf-8")
assert json.loads(pgdata_status) == {"done": True}, f"got status: {pgdata_status}"
+    job_status_obj = mock_s3_client.get_object(
+        Bucket=bucket, Key=f"{key_prefix}/status/fast_import"
+    )
+    job_status = job_status_obj["Body"].read().decode("utf-8")
+    assert json.loads(job_status) == {
+        "command": "pgdata",
+        "done": True,
+    }, f"got status: {job_status}"
vanilla_pg.stop()
def validate_vanilla_equivalence(ep):
@@ -674,21 +682,27 @@ def test_fast_import_restore_to_connstring_from_s3_spec(
).decode("utf-8"),
}
mock_s3_client.create_bucket(Bucket="test-bucket")
bucket = "test-bucket"
key_prefix = "test-prefix"
mock_s3_client.create_bucket(Bucket=bucket)
mock_s3_client.put_object(
Bucket="test-bucket", Key="test-prefix/spec.json", Body=json.dumps(spec)
Bucket=bucket, Key=f"{key_prefix}/spec.json", Body=json.dumps(spec)
)
# Run fast_import
-    if fast_import.extra_env is None:
-        fast_import.extra_env = {}
-    fast_import.extra_env["AWS_ACCESS_KEY_ID"] = mock_s3_server.access_key()
-    fast_import.extra_env["AWS_SECRET_ACCESS_KEY"] = mock_s3_server.secret_key()
-    fast_import.extra_env["AWS_SESSION_TOKEN"] = mock_s3_server.session_token()
-    fast_import.extra_env["AWS_REGION"] = mock_s3_server.region()
-    fast_import.extra_env["AWS_ENDPOINT_URL"] = mock_s3_server.endpoint()
-    fast_import.extra_env["RUST_LOG"] = "aws_config=debug,aws_sdk_kms=debug"
-    fast_import.run_dump_restore(s3prefix="s3://test-bucket/test-prefix")
+    fast_import.set_aws_creds(
+        mock_s3_server, {"RUST_LOG": "aws_config=debug,aws_sdk_kms=debug"}
+    )
+    fast_import.run_dump_restore(s3prefix=f"s3://{bucket}/{key_prefix}")
+    job_status_obj = mock_s3_client.get_object(
+        Bucket=bucket, Key=f"{key_prefix}/status/fast_import"
+    )
+    job_status = job_status_obj["Body"].read().decode("utf-8")
+    assert json.loads(job_status) == {
+        "done": True,
+        "command": "dump-restore",
+    }, f"got status: {job_status}"
vanilla_pg.stop()
res = destination_vanilla_pg.safe_psql("SELECT count(*) FROM foo;")
@@ -696,9 +710,110 @@ def test_fast_import_restore_to_connstring_from_s3_spec(
assert res[0][0] == 10
# TODO: Maybe test with pageserver?
# 1. run whole neon env
# 2. create timeline with some s3 path???
# 3. run fast_import with s3 prefix
# 4. ??? mock http where pageserver will report progress
# 5. run compute on this timeline and check if data is there
def test_fast_import_restore_to_connstring_error_to_s3_bad_destination(
test_output_dir,
vanilla_pg: VanillaPostgres,
port_distributor: PortDistributor,
fast_import: FastImport,
pg_distrib_dir: Path,
pg_version: PgVersion,
mock_s3_server: MockS3Server,
mock_kms: KMSClient,
mock_s3_client: S3Client,
):
# Prepare KMS and S3
key_response = mock_kms.create_key(
Description="Test key",
KeyUsage="ENCRYPT_DECRYPT",
Origin="AWS_KMS",
)
key_id = key_response["KeyMetadata"]["KeyId"]
def encrypt(x: str) -> EncryptResponseTypeDef:
return mock_kms.encrypt(KeyId=key_id, Plaintext=x)
# Start source postgres and ingest data
vanilla_pg.start()
vanilla_pg.safe_psql("CREATE TABLE foo (a int); INSERT INTO foo SELECT generate_series(1, 10);")
# Encrypt connstrings and put spec into S3
source_connstring_encrypted = encrypt(vanilla_pg.connstr())
destination_connstring_encrypted = encrypt("postgres://random:connection@string:5432/neondb")
spec = {
"encryption_secret": {"KMS": {"key_id": key_id}},
"source_connstring_ciphertext_base64": base64.b64encode(
source_connstring_encrypted["CiphertextBlob"]
).decode("utf-8"),
"destination_connstring_ciphertext_base64": base64.b64encode(
destination_connstring_encrypted["CiphertextBlob"]
).decode("utf-8"),
}
bucket = "test-bucket"
key_prefix = "test-prefix"
mock_s3_client.create_bucket(Bucket=bucket)
mock_s3_client.put_object(Bucket=bucket, Key=f"{key_prefix}/spec.json", Body=json.dumps(spec))
# Run fast_import
fast_import.set_aws_creds(mock_s3_server, {"RUST_LOG": "aws_config=debug,aws_sdk_kms=debug"})
fast_import.run_dump_restore(s3prefix=f"s3://{bucket}/{key_prefix}")
job_status_obj = mock_s3_client.get_object(
Bucket=bucket, Key=f"{key_prefix}/status/fast_import"
)
job_status = job_status_obj["Body"].read().decode("utf-8")
assert json.loads(job_status) == {
"command": "dump-restore",
"done": False,
"error": "pg_restore failed",
}, f"got status: {job_status}"
vanilla_pg.stop()
def test_fast_import_restore_to_connstring_error_to_s3_kms_error(
test_output_dir,
port_distributor: PortDistributor,
fast_import: FastImport,
pg_distrib_dir: Path,
pg_version: PgVersion,
mock_s3_server: MockS3Server,
mock_kms: KMSClient,
mock_s3_client: S3Client,
):
# Prepare KMS and S3
key_response = mock_kms.create_key(
Description="Test key",
KeyUsage="ENCRYPT_DECRYPT",
Origin="AWS_KMS",
)
key_id = key_response["KeyMetadata"]["KeyId"]
def encrypt(x: str) -> EncryptResponseTypeDef:
return mock_kms.encrypt(KeyId=key_id, Plaintext=x)
# Encrypt connstrings and put spec into S3
spec = {
"encryption_secret": {"KMS": {"key_id": key_id}},
"source_connstring_ciphertext_base64": base64.b64encode(b"invalid encrypted string").decode(
"utf-8"
),
}
bucket = "test-bucket"
key_prefix = "test-prefix"
mock_s3_client.create_bucket(Bucket=bucket)
mock_s3_client.put_object(Bucket=bucket, Key=f"{key_prefix}/spec.json", Body=json.dumps(spec))
# Run fast_import
fast_import.set_aws_creds(mock_s3_server, {"RUST_LOG": "aws_config=debug,aws_sdk_kms=debug"})
fast_import.run_dump_restore(s3prefix=f"s3://{bucket}/{key_prefix}")
job_status_obj = mock_s3_client.get_object(
Bucket=bucket, Key=f"{key_prefix}/status/fast_import"
)
job_status = job_status_obj["Body"].read().decode("utf-8")
assert json.loads(job_status) == {
"command": "dump-restore",
"done": False,
"error": "decrypt source connection string",
}, f"got status: {job_status}"