diff --git a/.github/workflows/build_and_test.yml b/.github/workflows/build_and_test.yml index 6dcf988191..bcc02398a1 100644 --- a/.github/workflows/build_and_test.yml +++ b/.github/workflows/build_and_test.yml @@ -661,6 +661,9 @@ jobs: project: nrdv0s4kcs push: true tags: 369495373322.dkr.ecr.eu-central-1.amazonaws.com/neon:depot-${{needs.tag.outputs.build-tag}} + build-args: | + GIT_VERSION=${{ github.sha }} + REPOSITORY=369495373322.dkr.ecr.eu-central-1.amazonaws.com compute-tools-image: runs-on: [ self-hosted, gen3, large ] diff --git a/Dockerfile.compute-node b/Dockerfile.compute-node index 3a3dee8a8a..de8a904c02 100644 --- a/Dockerfile.compute-node +++ b/Dockerfile.compute-node @@ -632,6 +632,7 @@ RUN apt update && \ libxml2 \ libxslt1.1 \ libzstd1 \ + libcurl4-openssl-dev \ procps && \ rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/* && \ localedef -i en_US -c -f UTF-8 -A /usr/share/locale/locale.alias en_US.UTF-8 diff --git a/compute_tools/src/compute.rs b/compute_tools/src/compute.rs index da5ad00da6..a7746629a8 100644 --- a/compute_tools/src/compute.rs +++ b/compute_tools/src/compute.rs @@ -362,6 +362,8 @@ impl ComputeNode { }; // Proceed with post-startup configuration. Note, that order of operations is important. + // Disable DDL forwarding because control plane already knows about these roles/databases. + client.simple_query("SET neon.forward_ddl = false")?; let spec = &compute_state.pspec.as_ref().expect("spec must be set").spec; handle_roles(spec, &mut client)?; handle_databases(spec, &mut client)?; @@ -403,7 +405,9 @@ impl ComputeNode { self.pg_reload_conf(&mut client)?; // Proceed with post-startup configuration. Note, that order of operations is important. + // Disable DDL forwarding because control plane already knows about these roles/databases. if spec.mode == ComputeMode::Primary { + client.simple_query("SET neon.forward_ddl = false")?; handle_roles(&spec, &mut client)?; handle_databases(&spec, &mut client)?; handle_role_deletions(&spec, self.connstr.as_str(), &mut client)?; diff --git a/compute_tools/src/pg_helpers.rs b/compute_tools/src/pg_helpers.rs index 40dbea6907..ed00485d5a 100644 --- a/compute_tools/src/pg_helpers.rs +++ b/compute_tools/src/pg_helpers.rs @@ -121,9 +121,8 @@ impl RoleExt for Role { /// string of arguments. fn to_pg_options(&self) -> String { // XXX: consider putting LOGIN as a default option somewhere higher, e.g. in control-plane. - // For now, we do not use generic `options` for roles. Once used, add - // `self.options.as_pg_options()` somewhere here. - let mut params: String = "LOGIN".to_string(); + let mut params: String = self.options.as_pg_options(); + params.push_str(" LOGIN"); if let Some(pass) = &self.encrypted_password { // Some time ago we supported only md5 and treated all encrypted_password as md5. diff --git a/compute_tools/src/spec.rs b/compute_tools/src/spec.rs index bf3c407202..a2a19ae0da 100644 --- a/compute_tools/src/spec.rs +++ b/compute_tools/src/spec.rs @@ -62,7 +62,7 @@ fn do_control_plane_request( } } -/// Request spec from the control-plane by compute_id. If `NEON_CONSOLE_JWT` +/// Request spec from the control-plane by compute_id. If `NEON_CONTROL_PLANE_TOKEN` /// env variable is set, it will be used for authorization. pub fn get_spec_from_control_plane( base_uri: &str, diff --git a/compute_tools/tests/pg_helpers_tests.rs b/compute_tools/tests/pg_helpers_tests.rs index a63ee038c7..265556d3b9 100644 --- a/compute_tools/tests/pg_helpers_tests.rs +++ b/compute_tools/tests/pg_helpers_tests.rs @@ -16,7 +16,7 @@ mod pg_helpers_tests { ); assert_eq!( spec.cluster.roles.first().unwrap().to_pg_options(), - "LOGIN PASSWORD 'md56b1d16b78004bbd51fa06af9eda75972'" + " LOGIN PASSWORD 'md56b1d16b78004bbd51fa06af9eda75972'" ); } diff --git a/pgxn/neon/Makefile b/pgxn/neon/Makefile index ec377dbb1e..1948023472 100644 --- a/pgxn/neon/Makefile +++ b/pgxn/neon/Makefile @@ -11,10 +11,12 @@ OBJS = \ pagestore_smgr.o \ relsize_cache.o \ walproposer.o \ - walproposer_utils.o + walproposer_utils.o \ + control_plane_connector.o PG_CPPFLAGS = -I$(libpq_srcdir) SHLIB_LINK_INTERNAL = $(libpq) +SHLIB_LINK = -lcurl EXTENSION = neon DATA = neon--1.0.sql diff --git a/pgxn/neon/control_plane_connector.c b/pgxn/neon/control_plane_connector.c new file mode 100644 index 0000000000..82e4af4b4a --- /dev/null +++ b/pgxn/neon/control_plane_connector.c @@ -0,0 +1,830 @@ +/*------------------------------------------------------------------------- + * + * control_plane_connector.c + * Captures updates to roles/databases using ProcessUtility_hook and + * sends them to the control ProcessUtility_hook. The changes are sent + * via HTTP to the URL specified by the GUC neon.console_url when the + * transaction commits. Forwarding may be disabled temporarily by + * setting neon.forward_ddl to false. + * + * Currently, the transaction may abort AFTER + * changes have already been forwarded, and that case is not handled. + * Subtransactions are handled using a stack of hash tables, which + * accumulate changes. On subtransaction commit, the top of the stack + * is merged with the table below it. + * + * IDENTIFICATION + * contrib/neon/control_plane_connector.c + * + *------------------------------------------------------------------------- + */ +#include "postgres.h" +#include "tcop/pquery.h" +#include "tcop/utility.h" +#include "access/xact.h" +#include "utils/hsearch.h" +#include "utils/memutils.h" +#include "commands/defrem.h" +#include "miscadmin.h" +#include "utils/acl.h" +#include "fmgr.h" +#include "utils/guc.h" +#include "port.h" +#include +#include "utils/jsonb.h" + +static ProcessUtility_hook_type PreviousProcessUtilityHook = NULL; + +/* GUCs */ +static char *ConsoleURL = NULL; +static bool ForwardDDL = true; + +/* Curl structures for sending the HTTP requests */ +static CURL * CurlHandle; +static struct curl_slist *ContentHeader = NULL; + +/* + * CURL docs say that this buffer must exist until we call curl_easy_cleanup + * (which we never do), so we make this a static + */ +static char CurlErrorBuf[CURL_ERROR_SIZE]; + +typedef enum +{ + Op_Set, /* An upsert: Either a creation or an alter */ + Op_Delete, +} OpType; + +typedef struct +{ + char name[NAMEDATALEN]; + Oid owner; + char old_name[NAMEDATALEN]; + OpType type; +} DbEntry; + +typedef struct +{ + char name[NAMEDATALEN]; + char old_name[NAMEDATALEN]; + const char *password; + OpType type; +} RoleEntry; + +/* + * We keep one of these for each subtransaction in a stack. When a subtransaction + * commits, we merge the top of the stack into the table below it. It is allocated in the + * subtransaction's context. + */ +typedef struct DdlHashTable +{ + struct DdlHashTable *prev_table; + HTAB *db_table; + HTAB *role_table; +} DdlHashTable; + +static DdlHashTable RootTable; +static DdlHashTable * CurrentDdlTable = &RootTable; + +static void +PushKeyValue(JsonbParseState **state, char *key, char *value) +{ + JsonbValue k, + v; + + k.type = jbvString; + k.val.string.len = strlen(key); + k.val.string.val = key; + v.type = jbvString; + v.val.string.len = strlen(value); + v.val.string.val = value; + pushJsonbValue(state, WJB_KEY, &k); + pushJsonbValue(state, WJB_VALUE, &v); +} + +static char * +ConstructDeltaMessage() +{ + JsonbParseState *state = NULL; + + pushJsonbValue(&state, WJB_BEGIN_OBJECT, NULL); + if (RootTable.db_table) + { + JsonbValue dbs; + + dbs.type = jbvString; + dbs.val.string.val = "dbs"; + dbs.val.string.len = strlen(dbs.val.string.val); + pushJsonbValue(&state, WJB_KEY, &dbs); + pushJsonbValue(&state, WJB_BEGIN_ARRAY, NULL); + + HASH_SEQ_STATUS status; + DbEntry *entry; + + hash_seq_init(&status, RootTable.db_table); + while ((entry = hash_seq_search(&status)) != NULL) + { + pushJsonbValue(&state, WJB_BEGIN_OBJECT, NULL); + PushKeyValue(&state, "op", entry->type == Op_Set ? "set" : "del"); + PushKeyValue(&state, "name", entry->name); + if (entry->owner != InvalidOid) + { + PushKeyValue(&state, "owner", GetUserNameFromId(entry->owner, false)); + } + if (entry->old_name[0] != '\0') + { + PushKeyValue(&state, "old_name", entry->old_name); + } + pushJsonbValue(&state, WJB_END_OBJECT, NULL); + } + pushJsonbValue(&state, WJB_END_ARRAY, NULL); + } + + if (RootTable.role_table) + { + JsonbValue roles; + + roles.type = jbvString; + roles.val.string.val = "roles"; + roles.val.string.len = strlen(roles.val.string.val); + pushJsonbValue(&state, WJB_KEY, &roles); + pushJsonbValue(&state, WJB_BEGIN_ARRAY, NULL); + + HASH_SEQ_STATUS status; + RoleEntry *entry; + + hash_seq_init(&status, RootTable.role_table); + while ((entry = hash_seq_search(&status)) != NULL) + { + pushJsonbValue(&state, WJB_BEGIN_OBJECT, NULL); + PushKeyValue(&state, "op", entry->type == Op_Set ? "set" : "del"); + PushKeyValue(&state, "name", entry->name); + if (entry->password) + { + PushKeyValue(&state, "password", (char *) entry->password); + } + if (entry->old_name[0] != '\0') + { + PushKeyValue(&state, "old_name", entry->old_name); + } + pushJsonbValue(&state, WJB_END_OBJECT, NULL); + } + pushJsonbValue(&state, WJB_END_ARRAY, NULL); + } + JsonbValue *result = pushJsonbValue(&state, WJB_END_OBJECT, NULL); + Jsonb *jsonb = JsonbValueToJsonb(result); + + return JsonbToCString(NULL, &jsonb->root, 0 /* estimated_len */ ); +} + +#define ERROR_SIZE 1024 + +typedef struct +{ + char str[ERROR_SIZE]; + size_t size; +} ErrorString; + +static size_t +ErrorWriteCallback(char *ptr, size_t size, size_t nmemb, void *userdata) +{ + /* Docs say size is always 1 */ + ErrorString *str = userdata; + + size_t to_write = nmemb; + + /* +1 for null terminator */ + if (str->size + nmemb + 1 >= ERROR_SIZE) + to_write = ERROR_SIZE - str->size - 1; + + /* Ignore everyrthing past the first ERROR_SIZE bytes */ + if (to_write == 0) + return nmemb; + memcpy(str->str + str->size, ptr, to_write); + str->size += to_write; + str->str[str->size] = '\0'; + return nmemb; +} + +static void +SendDeltasToControlPlane() +{ + if (!RootTable.db_table && !RootTable.role_table) + return; + if (!ConsoleURL) + { + elog(LOG, "ConsoleURL not set, skipping forwarding"); + return; + } + if (!ForwardDDL) + return; + + char *message = ConstructDeltaMessage(); + ErrorString str = {}; + + curl_easy_setopt(CurlHandle, CURLOPT_CUSTOMREQUEST, "PATCH"); + curl_easy_setopt(CurlHandle, CURLOPT_HTTPHEADER, ContentHeader); + curl_easy_setopt(CurlHandle, CURLOPT_POSTFIELDS, message); + curl_easy_setopt(CurlHandle, CURLOPT_URL, ConsoleURL); + curl_easy_setopt(CurlHandle, CURLOPT_ERRORBUFFER, CurlErrorBuf); + curl_easy_setopt(CurlHandle, CURLOPT_TIMEOUT, 3L /* seconds */ ); + curl_easy_setopt(CurlHandle, CURLOPT_WRITEDATA, &str); + curl_easy_setopt(CurlHandle, CURLOPT_WRITEFUNCTION, ErrorWriteCallback); + + const int num_retries = 5; + int curl_status; + + for (int i = 0; i < num_retries; i++) + { + if ((curl_status = curl_easy_perform(CurlHandle)) == 0) + break; + elog(LOG, "Curl request failed on attempt %d: %s", i, CurlErrorBuf); + pg_usleep(1000 * 1000); + } + if (curl_status != 0) + { + elog(ERROR, "Failed to perform curl request: %s", CurlErrorBuf); + } + else + { + long response_code; + + if (curl_easy_getinfo(CurlHandle, CURLINFO_RESPONSE_CODE, &response_code) != CURLE_UNKNOWN_OPTION) + { + bool error_exists = str.size != 0; + + if (response_code != 200) + { + if (error_exists) + { + elog(ERROR, + "Received HTTP code %ld from control plane: %s", + response_code, + str.str); + } + else + { + elog(ERROR, + "Received HTTP code %ld from control plane", + response_code); + } + } + } + } +} + +static void +InitDbTableIfNeeded() +{ + if (!CurrentDdlTable->db_table) + { + HASHCTL db_ctl = {}; + + db_ctl.keysize = NAMEDATALEN; + db_ctl.entrysize = sizeof(DbEntry); + db_ctl.hcxt = CurTransactionContext; + CurrentDdlTable->db_table = hash_create( + "Dbs Created", + 4, + &db_ctl, + HASH_ELEM | HASH_STRINGS | HASH_CONTEXT); + } +} + +static void +InitRoleTableIfNeeded() +{ + if (!CurrentDdlTable->role_table) + { + HASHCTL role_ctl = {}; + + role_ctl.keysize = NAMEDATALEN; + role_ctl.entrysize = sizeof(RoleEntry); + role_ctl.hcxt = CurTransactionContext; + CurrentDdlTable->role_table = hash_create( + "Roles Created", + 4, + &role_ctl, + HASH_ELEM | HASH_STRINGS | HASH_CONTEXT); + } +} + +static void +PushTable() +{ + DdlHashTable *new_table = MemoryContextAlloc(CurTransactionContext, sizeof(DdlHashTable)); + + new_table->prev_table = CurrentDdlTable; + new_table->role_table = NULL; + new_table->db_table = NULL; + CurrentDdlTable = new_table; +} + +static void +MergeTable() +{ + DdlHashTable *old_table = CurrentDdlTable; + + CurrentDdlTable = old_table->prev_table; + + if (old_table->db_table) + { + InitDbTableIfNeeded(); + DbEntry *entry; + HASH_SEQ_STATUS status; + + hash_seq_init(&status, old_table->db_table); + while ((entry = hash_seq_search(&status)) != NULL) + { + DbEntry *to_write = hash_search( + CurrentDdlTable->db_table, + entry->name, + HASH_ENTER, + NULL); + + to_write->type = entry->type; + if (entry->owner != InvalidOid) + to_write->owner = entry->owner; + strlcpy(to_write->old_name, entry->old_name, NAMEDATALEN); + if (entry->old_name[0] != '\0') + { + bool found_old = false; + DbEntry *old = hash_search( + CurrentDdlTable->db_table, + entry->old_name, + HASH_FIND, + &found_old); + + if (found_old) + { + if (old->old_name[0] != '\0') + strlcpy(to_write->old_name, old->old_name, NAMEDATALEN); + else + strlcpy(to_write->old_name, entry->old_name, NAMEDATALEN); + hash_search( + CurrentDdlTable->db_table, + entry->old_name, + HASH_REMOVE, + NULL); + } + } + } + hash_destroy(old_table->db_table); + } + + if (old_table->role_table) + { + InitRoleTableIfNeeded(); + RoleEntry *entry; + HASH_SEQ_STATUS status; + + hash_seq_init(&status, old_table->role_table); + while ((entry = hash_seq_search(&status)) != NULL) + { + RoleEntry *to_write = hash_search( + CurrentDdlTable->role_table, + entry->name, + HASH_ENTER, + NULL); + + to_write->type = entry->type; + if (entry->password) + to_write->password = entry->password; + strlcpy(to_write->old_name, entry->old_name, NAMEDATALEN); + if (entry->old_name[0] != '\0') + { + bool found_old = false; + RoleEntry *old = hash_search( + CurrentDdlTable->role_table, + entry->old_name, + HASH_FIND, + &found_old); + + if (found_old) + { + if (old->old_name[0] != '\0') + strlcpy(to_write->old_name, old->old_name, NAMEDATALEN); + else + strlcpy(to_write->old_name, entry->old_name, NAMEDATALEN); + hash_search(CurrentDdlTable->role_table, + entry->old_name, + HASH_REMOVE, + NULL); + } + } + } + hash_destroy(old_table->role_table); + } +} + +static void +PopTable() +{ + /* + * Current table gets freed because it is allocated in aborted + * subtransaction's memory context. + */ + CurrentDdlTable = CurrentDdlTable->prev_table; +} + +static void +NeonSubXactCallback( + SubXactEvent event, + SubTransactionId mySubid, + SubTransactionId parentSubid, + void *arg) +{ + switch (event) + { + case SUBXACT_EVENT_START_SUB: + return PushTable(); + case SUBXACT_EVENT_COMMIT_SUB: + return MergeTable(); + case SUBXACT_EVENT_ABORT_SUB: + return PopTable(); + default: + return; + } +} + +static void +NeonXactCallback(XactEvent event, void *arg) +{ + if (event == XACT_EVENT_PRE_COMMIT || event == XACT_EVENT_PARALLEL_PRE_COMMIT) + { + SendDeltasToControlPlane(); + } + RootTable.role_table = NULL; + RootTable.db_table = NULL; + Assert(CurrentDdlTable == &RootTable); +} + +static void +HandleCreateDb(CreatedbStmt *stmt) +{ + InitDbTableIfNeeded(); + DefElem *downer = NULL; + ListCell *option; + + foreach(option, stmt->options) + { + DefElem *defel = lfirst(option); + + if (strcmp(defel->defname, "owner") == 0) + downer = defel; + } + bool found = false; + DbEntry *entry = hash_search( + CurrentDdlTable->db_table, + stmt->dbname, + HASH_ENTER, + &found); + + if (!found) + memset(entry->old_name, 0, sizeof(entry->old_name)); + + entry->type = Op_Set; + if (downer && downer->arg) + entry->owner = get_role_oid(defGetString(downer), false); + else + entry->owner = GetUserId(); +} + +static void +HandleAlterOwner(AlterOwnerStmt *stmt) +{ + if (stmt->objectType != OBJECT_DATABASE) + return; + InitDbTableIfNeeded(); + const char *name = strVal(stmt->object); + bool found = false; + DbEntry *entry = hash_search( + CurrentDdlTable->db_table, + name, + HASH_ENTER, + &found); + + if (!found) + memset(entry->old_name, 0, sizeof(entry->old_name)); + + entry->owner = get_role_oid(get_rolespec_name(stmt->newowner), false); + entry->type = Op_Set; +} + +static void +HandleDbRename(RenameStmt *stmt) +{ + Assert(stmt->renameType == OBJECT_DATABASE); + InitDbTableIfNeeded(); + bool found = false; + DbEntry *entry = hash_search( + CurrentDdlTable->db_table, + stmt->subname, + HASH_FIND, + &found); + DbEntry *entry_for_new_name = hash_search( + CurrentDdlTable->db_table, + stmt->newname, + HASH_ENTER, + NULL); + + entry_for_new_name->type = Op_Set; + if (found) + { + if (entry->old_name[0] != '\0') + strlcpy(entry_for_new_name->old_name, entry->old_name, NAMEDATALEN); + else + strlcpy(entry_for_new_name->old_name, entry->name, NAMEDATALEN); + entry_for_new_name->owner = entry->owner; + hash_search( + CurrentDdlTable->db_table, + stmt->subname, + HASH_REMOVE, + NULL); + } + else + { + strlcpy(entry_for_new_name->old_name, stmt->subname, NAMEDATALEN); + entry_for_new_name->owner = InvalidOid; + } +} + +static void +HandleDropDb(DropdbStmt *stmt) +{ + InitDbTableIfNeeded(); + bool found = false; + DbEntry *entry = hash_search( + CurrentDdlTable->db_table, + stmt->dbname, + HASH_ENTER, + &found); + + entry->type = Op_Delete; + entry->owner = InvalidOid; + if (!found) + memset(entry->old_name, 0, sizeof(entry->old_name)); +} + +static void +HandleCreateRole(CreateRoleStmt *stmt) +{ + InitRoleTableIfNeeded(); + bool found = false; + RoleEntry *entry = hash_search( + CurrentDdlTable->role_table, + stmt->role, + HASH_ENTER, + &found); + DefElem *dpass = NULL; + ListCell *option; + + foreach(option, stmt->options) + { + DefElem *defel = lfirst(option); + + if (strcmp(defel->defname, "password") == 0) + dpass = defel; + } + if (!found) + memset(entry->old_name, 0, sizeof(entry->old_name)); + if (dpass && dpass->arg) + entry->password = MemoryContextStrdup(CurTransactionContext, strVal(dpass->arg)); + else + entry->password = NULL; + entry->type = Op_Set; +} + +static void +HandleAlterRole(AlterRoleStmt *stmt) +{ + InitRoleTableIfNeeded(); + DefElem *dpass = NULL; + ListCell *option; + + foreach(option, stmt->options) + { + DefElem *defel = lfirst(option); + + if (strcmp(defel->defname, "password") == 0) + dpass = defel; + } + /* We only care about updates to the password */ + if (!dpass) + return; + bool found = false; + RoleEntry *entry = hash_search( + CurrentDdlTable->role_table, + stmt->role->rolename, + HASH_ENTER, + &found); + + if (!found) + memset(entry->old_name, 0, sizeof(entry->old_name)); + if (dpass->arg) + entry->password = MemoryContextStrdup(CurTransactionContext, strVal(dpass->arg)); + else + entry->password = NULL; + entry->type = Op_Set; +} + +static void +HandleRoleRename(RenameStmt *stmt) +{ + InitRoleTableIfNeeded(); + Assert(stmt->renameType == OBJECT_ROLE); + bool found = false; + RoleEntry *entry = hash_search( + CurrentDdlTable->role_table, + stmt->subname, + HASH_FIND, + &found); + + RoleEntry *entry_for_new_name = hash_search( + CurrentDdlTable->role_table, + stmt->newname, + HASH_ENTER, + NULL); + + entry_for_new_name->type = Op_Set; + if (found) + { + if (entry->old_name[0] != '\0') + strlcpy(entry_for_new_name->old_name, entry->old_name, NAMEDATALEN); + else + strlcpy(entry_for_new_name->old_name, entry->name, NAMEDATALEN); + entry_for_new_name->password = entry->password; + hash_search( + CurrentDdlTable->role_table, + entry->name, + HASH_REMOVE, + NULL); + } + else + { + strlcpy(entry_for_new_name->old_name, stmt->subname, NAMEDATALEN); + entry_for_new_name->password = NULL; + } +} + +static void +HandleDropRole(DropRoleStmt *stmt) +{ + InitRoleTableIfNeeded(); + ListCell *item; + + foreach(item, stmt->roles) + { + RoleSpec *spec = lfirst(item); + bool found = false; + RoleEntry *entry = hash_search( + CurrentDdlTable->role_table, + spec->rolename, + HASH_ENTER, + &found); + + entry->type = Op_Delete; + entry->password = NULL; + if (!found) + memset(entry->old_name, 0, sizeof(entry)); + } +} + +static void +HandleRename(RenameStmt *stmt) +{ + if (stmt->renameType == OBJECT_DATABASE) + return HandleDbRename(stmt); + else if (stmt->renameType == OBJECT_ROLE) + return HandleRoleRename(stmt); +} + +static void +NeonProcessUtility( + PlannedStmt *pstmt, + const char *queryString, + bool readOnlyTree, + ProcessUtilityContext context, + ParamListInfo params, + QueryEnvironment *queryEnv, + DestReceiver *dest, + QueryCompletion *qc) +{ + Node *parseTree = pstmt->utilityStmt; + + switch (nodeTag(parseTree)) + { + case T_CreatedbStmt: + HandleCreateDb(castNode(CreatedbStmt, parseTree)); + break; + case T_AlterOwnerStmt: + HandleAlterOwner(castNode(AlterOwnerStmt, parseTree)); + break; + case T_RenameStmt: + HandleRename(castNode(RenameStmt, parseTree)); + break; + case T_DropdbStmt: + HandleDropDb(castNode(DropdbStmt, parseTree)); + break; + case T_CreateRoleStmt: + HandleCreateRole(castNode(CreateRoleStmt, parseTree)); + break; + case T_AlterRoleStmt: + HandleAlterRole(castNode(AlterRoleStmt, parseTree)); + break; + case T_DropRoleStmt: + HandleDropRole(castNode(DropRoleStmt, parseTree)); + break; + default: + break; + } + + if (PreviousProcessUtilityHook) + { + PreviousProcessUtilityHook( + pstmt, + queryString, + readOnlyTree, + context, + params, + queryEnv, + dest, + qc); + } + else + { + standard_ProcessUtility( + pstmt, + queryString, + readOnlyTree, + context, + params, + queryEnv, + dest, + qc); + } +} + +extern void +InitControlPlaneConnector() +{ + PreviousProcessUtilityHook = ProcessUtility_hook; + ProcessUtility_hook = NeonProcessUtility; + RegisterXactCallback(NeonXactCallback, NULL); + RegisterSubXactCallback(NeonSubXactCallback, NULL); + + DefineCustomStringVariable( + "neon.console_url", + "URL of the Neon Console, which will be forwarded changes to dbs and roles", + NULL, + &ConsoleURL, + NULL, + PGC_POSTMASTER, + 0, + NULL, + NULL, + NULL); + + DefineCustomBoolVariable( + "neon.forward_ddl", + "Controls whether to forward DDL to the control plane", + NULL, + &ForwardDDL, + true, + PGC_SUSET, + 0, + NULL, + NULL, + NULL); + + const char *jwt_token = getenv("NEON_CONTROL_PLANE_TOKEN"); + + if (!jwt_token) + { + elog(LOG, "Missing NEON_CONTROL_PLANE_TOKEN environment variable, forwarding will not be authenticated"); + } + + if (curl_global_init(CURL_GLOBAL_DEFAULT)) + { + elog(ERROR, "Failed to initialize curl"); + } + if ((CurlHandle = curl_easy_init()) == NULL) + { + elog(ERROR, "Failed to initialize curl handle"); + } + if ((ContentHeader = curl_slist_append(ContentHeader, "Content-Type: application/json")) == NULL) + { + elog(ERROR, "Failed to initialize content header"); + } + + if (jwt_token) + { + char auth_header[8192]; + + snprintf(auth_header, sizeof(auth_header), "Authorization: Bearer %s", jwt_token); + if ((ContentHeader = curl_slist_append(ContentHeader, auth_header)) == NULL) + { + elog(ERROR, "Failed to initialize authorization header"); + } + } +} diff --git a/pgxn/neon/control_plane_connector.h b/pgxn/neon/control_plane_connector.h new file mode 100644 index 0000000000..12d6a97562 --- /dev/null +++ b/pgxn/neon/control_plane_connector.h @@ -0,0 +1,6 @@ +#ifndef CONTROL_PLANE_CONNECTOR_H +#define CONTROL_PLANE_CONNECTOR_H + +void InitControlPlaneConnector(); + +#endif diff --git a/pgxn/neon/neon.c b/pgxn/neon/neon.c index 217c1974a0..b45d7cfc32 100644 --- a/pgxn/neon/neon.c +++ b/pgxn/neon/neon.c @@ -25,6 +25,7 @@ #include "neon.h" #include "walproposer.h" #include "pagestore_client.h" +#include "control_plane_connector.h" PG_MODULE_MAGIC; void _PG_init(void); @@ -34,7 +35,11 @@ _PG_init(void) { pg_init_libpagestore(); pg_init_walproposer(); + InitControlPlaneConnector(); + // Important: This must happen after other parts of the extension + // are loaded, otherwise any settings to GUCs that were set before + // the extension was loaded will be removed. EmitWarningsOnPlaceholders("neon"); } diff --git a/test_runner/fixtures/pg_version.py b/test_runner/fixtures/pg_version.py index d67f088365..14ae88cc2c 100644 --- a/test_runner/fixtures/pg_version.py +++ b/test_runner/fixtures/pg_version.py @@ -27,6 +27,10 @@ class PgVersion(str, enum.Enum): def __repr__(self) -> str: return f"'{self.value}'" + # Make this explicit for Python 3.11 compatibility, which changes the behavior of enums + def __str__(self) -> str: + return self.value + # In GitHub workflows we use Postgres version with v-prefix (e.g. v14 instead of just 14), # sometime we need to do so in tests. @property @@ -78,11 +82,11 @@ def pytest_addoption(parser: Parser): @pytest.fixture(scope="session") def pg_version(request: FixtureRequest) -> Iterator[PgVersion]: if v := request.config.getoption("--pg-version"): - version, source = v, "from --pg-version commad-line argument" + version, source = v, "from --pg-version command-line argument" elif v := os.environ.get("DEFAULT_PG_VERSION"): version, source = PgVersion(v), "from DEFAULT_PG_VERSION environment variable" else: - version, source = DEFAULT_VERSION, "default verson" + version, source = DEFAULT_VERSION, "default version" log.info(f"pg_version is {version} ({source})") yield version diff --git a/test_runner/regress/test_ddl_forwarding.py b/test_runner/regress/test_ddl_forwarding.py new file mode 100644 index 0000000000..27ebd3c181 --- /dev/null +++ b/test_runner/regress/test_ddl_forwarding.py @@ -0,0 +1,219 @@ +from types import TracebackType +from typing import Any, Dict, List, Optional, Tuple, Type + +import psycopg2 +import pytest +from fixtures.log_helper import log +from fixtures.neon_fixtures import ( + PortDistributor, + VanillaPostgres, +) +from pytest_httpserver import HTTPServer +from werkzeug.wrappers.request import Request +from werkzeug.wrappers.response import Response + + +@pytest.fixture(scope="session") +def httpserver_listen_address(port_distributor: PortDistributor): + port = port_distributor.get_port() + return ("localhost", port) + + +def handle_db(dbs, roles, operation): + if operation["op"] == "set": + if "old_name" in operation and operation["old_name"] in dbs: + dbs[operation["name"]] = dbs[operation["old_name"]] + dbs.pop(operation["old_name"]) + if "owner" in operation: + dbs[operation["name"]] = operation["owner"] + elif operation["op"] == "del": + dbs.pop(operation["name"]) + else: + raise ValueError("Invalid op") + + +def handle_role(dbs, roles, operation): + if operation["op"] == "set": + if "old_name" in operation and operation["old_name"] in roles: + roles[operation["name"]] = roles[operation["old_name"]] + roles.pop(operation["old_name"]) + for db, owner in dbs.items(): + if owner == operation["old_name"]: + dbs[db] = operation["name"] + if "password" in operation: + roles[operation["name"]] = operation["password"] + elif operation["op"] == "del": + if "old_name" in operation: + roles.pop(operation["old_name"]) + roles.pop(operation["name"]) + else: + raise ValueError("Invalid op") + + +fail = False + + +def ddl_forward_handler(request: Request, dbs: Dict[str, str], roles: Dict[str, str]) -> Response: + log.info(f"Received request with data {request.get_data(as_text=True)}") + if fail: + log.info("FAILING") + return Response(status=500, response="Failed just cuz") + if request.json is None: + log.info("Received invalid JSON") + return Response(status=400) + json = request.json + # Handle roles first + if "roles" in json: + for operation in json["roles"]: + handle_role(dbs, roles, operation) + if "dbs" in json: + for operation in json["dbs"]: + handle_db(dbs, roles, operation) + return Response(status=200) + + +class DdlForwardingContext: + def __init__(self, httpserver: HTTPServer, vanilla_pg: VanillaPostgres, host: str, port: int): + self.server = httpserver + self.pg = vanilla_pg + self.host = host + self.port = port + self.dbs: Dict[str, str] = {} + self.roles: Dict[str, str] = {} + endpoint = "/management/api/v2/roles_and_databases" + ddl_url = f"http://{host}:{port}{endpoint}" + self.pg.configure( + [ + f"neon.console_url={ddl_url}", + "shared_preload_libraries = 'neon'", + ] + ) + log.info(f"Listening on {ddl_url}") + self.server.expect_request(endpoint, method="PATCH").respond_with_handler( + lambda request: ddl_forward_handler(request, self.dbs, self.roles) + ) + + def __enter__(self): + self.pg.start() + return self + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc: Optional[BaseException], + tb: Optional[TracebackType], + ): + self.pg.stop() + + def send(self, query: str) -> List[Tuple[Any, ...]]: + return self.pg.safe_psql(query) + + def wait(self, timeout=3): + self.server.wait(timeout=timeout) + + def send_and_wait(self, query: str, timeout=3) -> List[Tuple[Any, ...]]: + res = self.send(query) + self.wait(timeout=timeout) + return res + + +@pytest.fixture(scope="function") +def ddl( + httpserver: HTTPServer, vanilla_pg: VanillaPostgres, httpserver_listen_address: tuple[str, int] +): + (host, port) = httpserver_listen_address + with DdlForwardingContext(httpserver, vanilla_pg, host, port) as ddl: + yield ddl + + +def test_ddl_forwarding(ddl: DdlForwardingContext): + curr_user = ddl.send("SELECT current_user")[0][0] + log.info(f"Current user is {curr_user}") + ddl.send_and_wait("CREATE DATABASE bork") + assert ddl.dbs == {"bork": curr_user} + ddl.send_and_wait("CREATE ROLE volk WITH PASSWORD 'nu_zayats'") + ddl.send_and_wait("ALTER DATABASE bork RENAME TO nu_pogodi") + assert ddl.dbs == {"nu_pogodi": curr_user} + ddl.send_and_wait("ALTER DATABASE nu_pogodi OWNER TO volk") + assert ddl.dbs == {"nu_pogodi": "volk"} + ddl.send_and_wait("DROP DATABASE nu_pogodi") + assert ddl.dbs == {} + ddl.send_and_wait("DROP ROLE volk") + assert ddl.roles == {} + + ddl.send_and_wait("CREATE ROLE tarzan WITH PASSWORD 'of_the_apes'") + assert ddl.roles == {"tarzan": "of_the_apes"} + ddl.send_and_wait("DROP ROLE tarzan") + assert ddl.roles == {} + ddl.send_and_wait("CREATE ROLE tarzan WITH PASSWORD 'of_the_apes'") + assert ddl.roles == {"tarzan": "of_the_apes"} + ddl.send_and_wait("ALTER ROLE tarzan WITH PASSWORD 'jungle_man'") + assert ddl.roles == {"tarzan": "jungle_man"} + ddl.send_and_wait("ALTER ROLE tarzan RENAME TO mowgli") + assert ddl.roles == {"mowgli": "jungle_man"} + ddl.send_and_wait("DROP ROLE mowgli") + assert ddl.roles == {} + + conn = ddl.pg.connect() + cur = conn.cursor() + + cur.execute("BEGIN") + cur.execute("CREATE ROLE bork WITH PASSWORD 'cork'") + cur.execute("COMMIT") + ddl.wait() + assert ddl.roles == {"bork": "cork"} + cur.execute("BEGIN") + cur.execute("CREATE ROLE stork WITH PASSWORD 'pork'") + cur.execute("ABORT") + ddl.wait() + assert ("stork", "pork") not in ddl.roles.items() + cur.execute("BEGIN") + cur.execute("ALTER ROLE bork WITH PASSWORD 'pork'") + cur.execute("ALTER ROLE bork RENAME TO stork") + cur.execute("COMMIT") + ddl.wait() + assert ddl.roles == {"stork": "pork"} + cur.execute("BEGIN") + cur.execute("CREATE ROLE dork WITH PASSWORD 'york'") + cur.execute("SAVEPOINT point") + cur.execute("ALTER ROLE dork WITH PASSWORD 'zork'") + cur.execute("ALTER ROLE dork RENAME TO fork") + cur.execute("ROLLBACK TO SAVEPOINT point") + cur.execute("ALTER ROLE dork WITH PASSWORD 'fork'") + cur.execute("ALTER ROLE dork RENAME TO zork") + cur.execute("RELEASE SAVEPOINT point") + cur.execute("COMMIT") + ddl.wait() + assert ddl.roles == {"stork": "pork", "zork": "fork"} + + cur.execute("DROP ROLE stork") + cur.execute("DROP ROLE zork") + ddl.wait() + assert ddl.roles == {} + + cur.execute("CREATE ROLE bork WITH PASSWORD 'dork'") + cur.execute("CREATE ROLE stork WITH PASSWORD 'cork'") + cur.execute("BEGIN") + cur.execute("DROP ROLE bork") + cur.execute("ALTER ROLE stork RENAME TO bork") + cur.execute("COMMIT") + ddl.wait() + assert ddl.roles == {"bork": "cork"} + + cur.execute("DROP ROLE bork") + ddl.wait() + assert ddl.roles == {} + + cur.execute("CREATE ROLE bork WITH PASSWORD 'dork'") + cur.execute("CREATE DATABASE stork WITH OWNER=bork") + cur.execute("ALTER ROLE bork RENAME TO cork") + ddl.wait() + assert ddl.dbs == {"stork": "cork"} + + with pytest.raises(psycopg2.InternalError): + global fail + fail = True + cur.execute("CREATE DATABASE failure WITH OWNER=cork") + ddl.wait() + + conn.close()