diff --git a/pgxn/neon/control_plane_connector.c b/pgxn/neon/control_plane_connector.c index b47b22cd20..59096a1bc8 100644 --- a/pgxn/neon/control_plane_connector.c +++ b/pgxn/neon/control_plane_connector.c @@ -428,6 +428,8 @@ MergeTable() hash_seq_init(&status, old_table->role_table); while ((entry = hash_seq_search(&status)) != NULL) { + RoleEntry * old; + bool found_old = false; RoleEntry *to_write = hash_search( CurrentDdlTable->role_table, entry->name, @@ -435,30 +437,23 @@ MergeTable() NULL); to_write->type = entry->type; - if (entry->password) - to_write->password = entry->password; + to_write->password = entry->password; strlcpy(to_write->old_name, entry->old_name, NAMEDATALEN); - if (entry->old_name[0] != '\0') - { - bool found_old = false; - RoleEntry *old = hash_search( - CurrentDdlTable->role_table, - entry->old_name, - HASH_FIND, - &found_old); + if (entry->old_name[0] == '\0') + continue; - if (found_old) - { - if (old->old_name[0] != '\0') - strlcpy(to_write->old_name, old->old_name, NAMEDATALEN); - else - strlcpy(to_write->old_name, entry->old_name, NAMEDATALEN); - hash_search(CurrentDdlTable->role_table, - entry->old_name, - HASH_REMOVE, - NULL); - } - } + old = hash_search( + CurrentDdlTable->role_table, + entry->old_name, + HASH_FIND, + &found_old); + if (!found_old) + continue; + strlcpy(to_write->old_name, old->old_name, NAMEDATALEN); + hash_search(CurrentDdlTable->role_table, + entry->old_name, + HASH_REMOVE, + NULL); } hash_destroy(old_table->role_table); } diff --git a/test_runner/regress/test_ddl_forwarding.py b/test_runner/regress/test_ddl_forwarding.py index de44bbcbc8..b10e38885e 100644 --- a/test_runner/regress/test_ddl_forwarding.py +++ b/test_runner/regress/test_ddl_forwarding.py @@ -60,14 +60,12 @@ def ddl_forward_handler( if request.json is None: log.info("Received invalid JSON") return Response(status=400) - json = request.json + json: dict[str, list[str]] = request.json # Handle roles first - if "roles" in json: - for operation in json["roles"]: - handle_role(dbs, roles, operation) - if "dbs" in json: - for operation in json["dbs"]: - handle_db(dbs, roles, operation) + for operation in json.get("roles", []): + handle_role(dbs, roles, operation) + for operation in json.get("dbs", []): + handle_db(dbs, roles, operation) return Response(status=200) @@ -207,6 +205,23 @@ def test_ddl_forwarding(ddl: DdlForwardingContext): ddl.wait() assert ddl.roles == {} + cur.execute("CREATE ROLE bork WITH PASSWORD 'newyork'") + cur.execute("BEGIN") + cur.execute("SAVEPOINT point") + cur.execute("DROP ROLE bork") + cur.execute("COMMIT") + ddl.wait() + assert ddl.roles == {} + + cur.execute("CREATE ROLE bork WITH PASSWORD 'oldyork'") + cur.execute("BEGIN") + cur.execute("SAVEPOINT point") + cur.execute("ALTER ROLE bork PASSWORD NULL") + cur.execute("COMMIT") + cur.execute("DROP ROLE bork") + ddl.wait() + assert ddl.roles == {} + cur.execute("CREATE ROLE bork WITH PASSWORD 'dork'") cur.execute("CREATE DATABASE stork WITH OWNER=bork") cur.execute("ALTER ROLE bork RENAME TO cork")