mirror of
https://github.com/neondatabase/neon.git
synced 2025-12-22 21:59:59 +00:00
Run each migration in its own transaction
Previously, every migration was run in the same transaction. This is preparatory work for fixing CVE-2024-4317.
This commit is contained in:
committed by
Tristan Partin
parent
b5ab055526
commit
ba17025a57
@@ -9,6 +9,9 @@ pub(crate) struct MigrationRunner<'m> {
|
||||
|
||||
impl<'m> MigrationRunner<'m> {
|
||||
pub fn new(client: &'m mut Client, migrations: &'m [&'m str]) -> Self {
|
||||
// The neon_migration.migration_id::id column is a bigint, which is equivalent to an i64
|
||||
assert!(migrations.len() + 1 < i64::MAX as usize);
|
||||
|
||||
Self { client, migrations }
|
||||
}
|
||||
|
||||
@@ -22,11 +25,8 @@ impl<'m> MigrationRunner<'m> {
|
||||
Ok(row.get::<&str, i64>("id"))
|
||||
}
|
||||
|
||||
fn update_migration_id(&mut self) -> Result<()> {
|
||||
let setval = format!(
|
||||
"UPDATE neon_migration.migration_id SET id={}",
|
||||
self.migrations.len()
|
||||
);
|
||||
fn update_migration_id(&mut self, migration_id: i64) -> Result<()> {
|
||||
let setval = format!("UPDATE neon_migration.migration_id SET id={}", migration_id);
|
||||
|
||||
self.client
|
||||
.simple_query(&setval)
|
||||
@@ -57,14 +57,7 @@ impl<'m> MigrationRunner<'m> {
|
||||
pub fn run_migrations(mut self) -> Result<()> {
|
||||
self.prepare_migrations()?;
|
||||
|
||||
let mut current_migration: usize = self.get_migration_id()? as usize;
|
||||
let starting_migration_id = current_migration;
|
||||
|
||||
let query = "BEGIN";
|
||||
self.client
|
||||
.simple_query(query)
|
||||
.context("run_migrations begin")?;
|
||||
|
||||
let mut current_migration = self.get_migration_id()? as usize;
|
||||
while current_migration < self.migrations.len() {
|
||||
macro_rules! migration_id {
|
||||
($cm:expr) => {
|
||||
@@ -83,29 +76,30 @@ impl<'m> MigrationRunner<'m> {
|
||||
migration
|
||||
);
|
||||
|
||||
self.client
|
||||
.simple_query("BEGIN")
|
||||
.context("begin migration")?;
|
||||
|
||||
self.client.simple_query(migration).with_context(|| {
|
||||
format!(
|
||||
"run_migration migration id={}",
|
||||
"run_migrations migration id={}",
|
||||
migration_id!(current_migration)
|
||||
)
|
||||
})?;
|
||||
|
||||
// Migration IDs start at 1
|
||||
self.update_migration_id(migration_id!(current_migration))?;
|
||||
|
||||
self.client
|
||||
.simple_query("COMMIT")
|
||||
.context("commit migration")?;
|
||||
|
||||
info!("Finished migration id={}", migration_id!(current_migration));
|
||||
}
|
||||
|
||||
current_migration += 1;
|
||||
}
|
||||
|
||||
self.update_migration_id()?;
|
||||
|
||||
let query = "COMMIT";
|
||||
self.client
|
||||
.simple_query(query)
|
||||
.context("run_migrations commit")?;
|
||||
|
||||
info!(
|
||||
"Ran {} migrations",
|
||||
(self.migrations.len() - starting_migration_id)
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3798,13 +3798,13 @@ class Endpoint(PgProtocol, LogUtils):
|
||||
json.dump(dict(data_dict, **kwargs), file, indent=4)
|
||||
|
||||
# Please note: Migrations only run if pg_skip_catalog_updates is false
|
||||
def wait_for_migrations(self):
|
||||
def wait_for_migrations(self, num_migrations: int = 10):
|
||||
with self.cursor() as cur:
|
||||
|
||||
def check_migrations_done():
|
||||
cur.execute("SELECT id FROM neon_migration.migration_id")
|
||||
migration_id = cur.fetchall()[0][0]
|
||||
assert migration_id != 0
|
||||
migration_id: int = cur.fetchall()[0][0]
|
||||
assert migration_id >= num_migrations
|
||||
|
||||
wait_until(20, 0.5, check_migrations_done)
|
||||
|
||||
|
||||
@@ -11,17 +11,14 @@ def test_migrations(neon_simple_env: NeonEnv):
|
||||
endpoint.respec(skip_pg_catalog_updates=False)
|
||||
endpoint.start()
|
||||
|
||||
endpoint.wait_for_migrations()
|
||||
|
||||
num_migrations = 10
|
||||
endpoint.wait_for_migrations(num_migrations=num_migrations)
|
||||
|
||||
with endpoint.cursor() as cur:
|
||||
cur.execute("SELECT id FROM neon_migration.migration_id")
|
||||
migration_id = cur.fetchall()
|
||||
assert migration_id[0][0] == num_migrations
|
||||
|
||||
endpoint.assert_log_contains(f"INFO handle_migrations: Ran {num_migrations} migrations")
|
||||
|
||||
endpoint.stop()
|
||||
endpoint.start()
|
||||
# We don't have a good way of knowing that the migrations code path finished executing
|
||||
@@ -31,5 +28,3 @@ def test_migrations(neon_simple_env: NeonEnv):
|
||||
cur.execute("SELECT id FROM neon_migration.migration_id")
|
||||
migration_id = cur.fetchall()
|
||||
assert migration_id[0][0] == num_migrations
|
||||
|
||||
endpoint.assert_log_contains("INFO handle_migrations: Ran 0 migrations")
|
||||
|
||||
Reference in New Issue
Block a user