Run each migration in its own transaction

Previously, every migration was run in the same transaction. This
is preparatory work for fixing CVE-2024-4317.
This commit is contained in:
Tristan Partin
2024-07-12 13:38:51 -05:00
committed by Tristan Partin
parent b5ab055526
commit ba17025a57
3 changed files with 24 additions and 35 deletions

View File

@@ -9,6 +9,9 @@ pub(crate) struct MigrationRunner<'m> {
impl<'m> MigrationRunner<'m> {
pub fn new(client: &'m mut Client, migrations: &'m [&'m str]) -> Self {
// The neon_migration.migration_id::id column is a bigint, which is equivalent to an i64
assert!(migrations.len() + 1 < i64::MAX as usize);
Self { client, migrations }
}
@@ -22,11 +25,8 @@ impl<'m> MigrationRunner<'m> {
Ok(row.get::<&str, i64>("id"))
}
fn update_migration_id(&mut self) -> Result<()> {
let setval = format!(
"UPDATE neon_migration.migration_id SET id={}",
self.migrations.len()
);
fn update_migration_id(&mut self, migration_id: i64) -> Result<()> {
let setval = format!("UPDATE neon_migration.migration_id SET id={}", migration_id);
self.client
.simple_query(&setval)
@@ -57,14 +57,7 @@ impl<'m> MigrationRunner<'m> {
pub fn run_migrations(mut self) -> Result<()> {
self.prepare_migrations()?;
let mut current_migration: usize = self.get_migration_id()? as usize;
let starting_migration_id = current_migration;
let query = "BEGIN";
self.client
.simple_query(query)
.context("run_migrations begin")?;
let mut current_migration = self.get_migration_id()? as usize;
while current_migration < self.migrations.len() {
macro_rules! migration_id {
($cm:expr) => {
@@ -83,29 +76,30 @@ impl<'m> MigrationRunner<'m> {
migration
);
self.client
.simple_query("BEGIN")
.context("begin migration")?;
self.client.simple_query(migration).with_context(|| {
format!(
"run_migration migration id={}",
"run_migrations migration id={}",
migration_id!(current_migration)
)
})?;
// Migration IDs start at 1
self.update_migration_id(migration_id!(current_migration))?;
self.client
.simple_query("COMMIT")
.context("commit migration")?;
info!("Finished migration id={}", migration_id!(current_migration));
}
current_migration += 1;
}
self.update_migration_id()?;
let query = "COMMIT";
self.client
.simple_query(query)
.context("run_migrations commit")?;
info!(
"Ran {} migrations",
(self.migrations.len() - starting_migration_id)
);
Ok(())
}
}

View File

@@ -3798,13 +3798,13 @@ class Endpoint(PgProtocol, LogUtils):
json.dump(dict(data_dict, **kwargs), file, indent=4)
# Please note: Migrations only run if pg_skip_catalog_updates is false
def wait_for_migrations(self):
def wait_for_migrations(self, num_migrations: int = 10):
with self.cursor() as cur:
def check_migrations_done():
cur.execute("SELECT id FROM neon_migration.migration_id")
migration_id = cur.fetchall()[0][0]
assert migration_id != 0
migration_id: int = cur.fetchall()[0][0]
assert migration_id >= num_migrations
wait_until(20, 0.5, check_migrations_done)

View File

@@ -11,17 +11,14 @@ def test_migrations(neon_simple_env: NeonEnv):
endpoint.respec(skip_pg_catalog_updates=False)
endpoint.start()
endpoint.wait_for_migrations()
num_migrations = 10
endpoint.wait_for_migrations(num_migrations=num_migrations)
with endpoint.cursor() as cur:
cur.execute("SELECT id FROM neon_migration.migration_id")
migration_id = cur.fetchall()
assert migration_id[0][0] == num_migrations
endpoint.assert_log_contains(f"INFO handle_migrations: Ran {num_migrations} migrations")
endpoint.stop()
endpoint.start()
# We don't have a good way of knowing that the migrations code path finished executing
@@ -31,5 +28,3 @@ def test_migrations(neon_simple_env: NeonEnv):
cur.execute("SELECT id FROM neon_migration.migration_id")
migration_id = cur.fetchall()
assert migration_id[0][0] == num_migrations
endpoint.assert_log_contains("INFO handle_migrations: Ran 0 migrations")