From ba17025a57bc4916b3efeb0fd068f2ada7f668a8 Mon Sep 17 00:00:00 2001 From: Tristan Partin Date: Fri, 12 Jul 2024 13:38:51 -0500 Subject: [PATCH] Run each migration in its own transaction Previously, every migration was run in the same transaction. This is preparatory work for fixing CVE-2024-4317. --- compute_tools/src/migration.rs | 46 +++++++++++--------------- test_runner/fixtures/neon_fixtures.py | 6 ++-- test_runner/regress/test_migrations.py | 7 +--- 3 files changed, 24 insertions(+), 35 deletions(-) diff --git a/compute_tools/src/migration.rs b/compute_tools/src/migration.rs index 241ccd4100..22ab145eda 100644 --- a/compute_tools/src/migration.rs +++ b/compute_tools/src/migration.rs @@ -9,6 +9,9 @@ pub(crate) struct MigrationRunner<'m> { impl<'m> MigrationRunner<'m> { pub fn new(client: &'m mut Client, migrations: &'m [&'m str]) -> Self { + // The neon_migration.migration_id::id column is a bigint, which is equivalent to an i64 + assert!(migrations.len() + 1 < i64::MAX as usize); + Self { client, migrations } } @@ -22,11 +25,8 @@ impl<'m> MigrationRunner<'m> { Ok(row.get::<&str, i64>("id")) } - fn update_migration_id(&mut self) -> Result<()> { - let setval = format!( - "UPDATE neon_migration.migration_id SET id={}", - self.migrations.len() - ); + fn update_migration_id(&mut self, migration_id: i64) -> Result<()> { + let setval = format!("UPDATE neon_migration.migration_id SET id={}", migration_id); self.client .simple_query(&setval) @@ -57,14 +57,7 @@ impl<'m> MigrationRunner<'m> { pub fn run_migrations(mut self) -> Result<()> { self.prepare_migrations()?; - let mut current_migration: usize = self.get_migration_id()? as usize; - let starting_migration_id = current_migration; - - let query = "BEGIN"; - self.client - .simple_query(query) - .context("run_migrations begin")?; - + let mut current_migration = self.get_migration_id()? as usize; while current_migration < self.migrations.len() { macro_rules! migration_id { ($cm:expr) => { @@ -83,29 +76,30 @@ impl<'m> MigrationRunner<'m> { migration ); + self.client + .simple_query("BEGIN") + .context("begin migration")?; + self.client.simple_query(migration).with_context(|| { format!( - "run_migration migration id={}", + "run_migrations migration id={}", migration_id!(current_migration) ) })?; + + // Migration IDs start at 1 + self.update_migration_id(migration_id!(current_migration))?; + + self.client + .simple_query("COMMIT") + .context("commit migration")?; + + info!("Finished migration id={}", migration_id!(current_migration)); } current_migration += 1; } - self.update_migration_id()?; - - let query = "COMMIT"; - self.client - .simple_query(query) - .context("run_migrations commit")?; - - info!( - "Ran {} migrations", - (self.migrations.len() - starting_migration_id) - ); - Ok(()) } } diff --git a/test_runner/fixtures/neon_fixtures.py b/test_runner/fixtures/neon_fixtures.py index 625e9096f5..4766b72516 100644 --- a/test_runner/fixtures/neon_fixtures.py +++ b/test_runner/fixtures/neon_fixtures.py @@ -3798,13 +3798,13 @@ class Endpoint(PgProtocol, LogUtils): json.dump(dict(data_dict, **kwargs), file, indent=4) # Please note: Migrations only run if pg_skip_catalog_updates is false - def wait_for_migrations(self): + def wait_for_migrations(self, num_migrations: int = 10): with self.cursor() as cur: def check_migrations_done(): cur.execute("SELECT id FROM neon_migration.migration_id") - migration_id = cur.fetchall()[0][0] - assert migration_id != 0 + migration_id: int = cur.fetchall()[0][0] + assert migration_id >= num_migrations wait_until(20, 0.5, check_migrations_done) diff --git a/test_runner/regress/test_migrations.py b/test_runner/regress/test_migrations.py index 91bd3ea50c..880dead4e8 100644 --- a/test_runner/regress/test_migrations.py +++ b/test_runner/regress/test_migrations.py @@ -11,17 +11,14 @@ def test_migrations(neon_simple_env: NeonEnv): endpoint.respec(skip_pg_catalog_updates=False) endpoint.start() - endpoint.wait_for_migrations() - num_migrations = 10 + endpoint.wait_for_migrations(num_migrations=num_migrations) with endpoint.cursor() as cur: cur.execute("SELECT id FROM neon_migration.migration_id") migration_id = cur.fetchall() assert migration_id[0][0] == num_migrations - endpoint.assert_log_contains(f"INFO handle_migrations: Ran {num_migrations} migrations") - endpoint.stop() endpoint.start() # We don't have a good way of knowing that the migrations code path finished executing @@ -31,5 +28,3 @@ def test_migrations(neon_simple_env: NeonEnv): cur.execute("SELECT id FROM neon_migration.migration_id") migration_id = cur.fetchall() assert migration_id[0][0] == num_migrations - - endpoint.assert_log_contains("INFO handle_migrations: Ran 0 migrations")