diff --git a/pgxn/neon/libpqwalproposer.c b/pgxn/neon/libpqwalproposer.c index ed3b8a817f..ce9a1475d3 100644 --- a/pgxn/neon/libpqwalproposer.c +++ b/pgxn/neon/libpqwalproposer.c @@ -74,7 +74,7 @@ walprop_connect_start(char *conninfo, char *password) if (password) { keywords[n] = "password"; - values[n] = neon_auth_token; + values[n] = password; n++; } keywords[n] = "dbname"; diff --git a/pgxn/neon/walproposer.c b/pgxn/neon/walproposer.c index 807fd5c91b..d9999ef2b1 100644 --- a/pgxn/neon/walproposer.c +++ b/pgxn/neon/walproposer.c @@ -1393,8 +1393,22 @@ WalProposerRecovery(int donor, TimeLineID timeline, XLogRecPtr startpos, XLogRec char *err; WalReceiverConn *wrconn; WalRcvStreamOptions options; + char conninfo[MAXCONNINFO]; - wrconn = walrcv_connect(safekeeper[donor].conninfo, false, "wal_proposer_recovery", &err); + if (!neon_auth_token) + { + memcpy(conninfo, safekeeper[donor].conninfo, MAXCONNINFO); + } + else + { + int written = 0; + + written = snprintf((char *) conninfo, MAXCONNINFO, "password=%s %s", neon_auth_token, safekeeper[donor].conninfo); + if (written > MAXCONNINFO || written < 0) + elog(FATAL, "could not append password to the safekeeper connection string"); + } + + wrconn = walrcv_connect(conninfo, false, "wal_proposer_recovery", &err); if (!wrconn) { ereport(WARNING, diff --git a/test_runner/regress/test_wal_acceptor.py b/test_runner/regress/test_wal_acceptor.py index 6695819899..c471b18db7 100644 --- a/test_runner/regress/test_wal_acceptor.py +++ b/test_runner/regress/test_wal_acceptor.py @@ -981,6 +981,35 @@ def test_sk_auth(neon_env_builder: NeonEnvBuilder): connector.safe_psql("IDENTIFY_SYSTEM", port=sk.port.pg_tenant_only, password=tenant_token) +# Try restarting endpoint with enabled auth. +def test_restart_endpoint(neon_env_builder: NeonEnvBuilder): + neon_env_builder.auth_enabled = True + neon_env_builder.num_safekeepers = 3 + env = neon_env_builder.init_start() + + env.neon_cli.create_branch("test_sk_auth_restart_endpoint") + endpoint = env.endpoints.create_start("test_sk_auth_restart_endpoint") + + with closing(endpoint.connect()) as conn: + with conn.cursor() as cur: + cur.execute("create table t(i int)") + + # Restarting endpoints and random safekeepers, to trigger recovery. + for _i in range(3): + random_sk = random.choice(env.safekeepers) + random_sk.stop() + + with closing(endpoint.connect()) as conn: + with conn.cursor() as cur: + start = random.randint(1, 100000) + end = start + random.randint(1, 10000) + cur.execute("insert into t select generate_series(%s,%s)", (start, end)) + + endpoint.stop() + random_sk.start() + endpoint.start() + + class SafekeeperEnv: def __init__( self,