From 35b77a138f6884f2e1ffc5f9648e92596970998a Mon Sep 17 00:00:00 2001 From: Heikki Linnakangas Date: Wed, 15 Mar 2023 13:32:08 +0200 Subject: [PATCH] Support "kick" callback from console in the HTTP API. We cannot remove the old libpq-based "mgmt" interface until the control plane has been changed to use the new HTTP API, and it has been deployed to all the environments. But this gets us started with the migration to the HTTP API. Fixes https://github.com/neondatabase/neon/issues/1094 --- proxy/src/console/mgmt.rs | 15 ++++-- proxy/src/http/server.rs | 31 +++++++++++-- test_runner/fixtures/neon_fixtures.py | 66 ++++++++++++++++----------- test_runner/regress/test_proxy.py | 12 ++++- 4 files changed, 88 insertions(+), 36 deletions(-) diff --git a/proxy/src/console/mgmt.rs b/proxy/src/console/mgmt.rs index 30364be6f4..912383a611 100644 --- a/proxy/src/console/mgmt.rs +++ b/proxy/src/console/mgmt.rs @@ -29,8 +29,15 @@ pub fn notify(psql_session_id: &str, msg: ComputeReady) -> Result<(), waiters::N CPLANE_WAITERS.notify(psql_session_id, msg) } -/// Console management API listener task. -/// It spawns console response handlers needed for the link auth. +/// Listener task for mgmt connections using the libpq protocol, for +/// the purposes of "kick callback" from the console, for link authentication. +/// +/// The protocol is the libpq protocol, but the only "query" we accept is a +/// JSON document. The JSON document is a KickSession serialized to JSON. +/// +/// This is considered legacy now. The preferred way to deliver "kick +/// callbacks" is now via the HTTP API. See `server.rs`. Once the control +/// plane has switched to using the HTTP API, this can be removed. pub async fn task_main(listener: TcpListener) -> anyhow::Result<()> { scopeguard::defer! 
{ info!("mgmt has shut down"); @@ -44,11 +51,12 @@ pub async fn task_main(listener: TcpListener) -> anyhow::Result<()> { .set_nodelay(true) .context("failed to set client socket option")?; + // spawn a task to handle this connection tokio::task::spawn(async move { let span = info_span!("mgmt", peer = %peer_addr); let _enter = span.enter(); - info!("started a new console management API thread"); + info!("started a new console management API task"); scopeguard::defer! { info!("console management API thread is about to finish"); } @@ -68,7 +76,6 @@ async fn handle_connection(socket: TcpStream) -> Result<(), QueryError> { /// A message received by `mgmt` when a compute node is ready. pub type ComputeReady = Result; -// TODO: replace with an http-based protocol. struct MgmtHandler; #[async_trait::async_trait] impl postgres_backend::Handler for MgmtHandler { diff --git a/proxy/src/http/server.rs b/proxy/src/http/server.rs index f35f4f9a62..79528a8f06 100644 --- a/proxy/src/http/server.rs +++ b/proxy/src/http/server.rs @@ -1,5 +1,8 @@ -use anyhow::anyhow; -use hyper::{Body, Request, Response, StatusCode}; +use crate::console; +use crate::console::messages::KickSession; +use crate::waiters::NotifyError; +use anyhow::{anyhow, Context}; +use hyper::{body, Body, Request, Response, StatusCode}; use std::net::TcpListener; use tracing::info; use utils::http::{endpoint, error::ApiError, json::json_response, RouterBuilder, RouterService}; @@ -8,8 +11,30 @@ async fn status_handler(_: Request) -> Result, ApiError> { json_response(StatusCode::OK, "") } +/// Process a session kick callback from the control plane. The body is a +/// KickSession as a JSON document. 
+/// +/// TODO: authentication +async fn kick_session_handler(req: Request) -> Result, ApiError> { + let body = &body::to_bytes(req.into_body()) + .await + .context("Failed to get request body") + .map_err(ApiError::BadRequest)?; + let kick_session_json: KickSession = serde_json::from_slice(body) + .context("Failed to parse query as json") + .map_err(ApiError::BadRequest)?; + + match console::mgmt::notify(kick_session_json.session_id, Ok(kick_session_json.result)) { + Ok(()) => json_response(StatusCode::OK, ""), + Err(NotifyError::NotFound(s)) => Err(ApiError::NotFound(anyhow::anyhow!(s))), + Err(e @ NotifyError::Hangup) => Err(ApiError::NotFound(anyhow::anyhow!(e))), + } +} + fn make_router() -> RouterBuilder { - endpoint::make_router().get("/v1/status", status_handler) + endpoint::make_router() + .get("/v1/status", status_handler) + .post("/v1/kick_session", kick_session_handler) } pub async fn task_main(http_listener: TcpListener) -> anyhow::Result<()> { diff --git a/test_runner/fixtures/neon_fixtures.py b/test_runner/fixtures/neon_fixtures.py index c5e260a962..9b1ef821a7 100644 --- a/test_runner/fixtures/neon_fixtures.py +++ b/test_runner/fixtures/neon_fixtures.py @@ -25,6 +25,7 @@ from types import TracebackType from typing import Any, Dict, Iterator, List, Optional, Tuple, Type, Union, cast from urllib.parse import urlparse +import aiohttp import asyncpg import backoff # type: ignore import boto3 @@ -2561,7 +2562,11 @@ class NeonProxy(PgProtocol): @staticmethod async def activate_link_auth( - local_vanilla_pg, proxy_with_metric_collector, psql_session_id, create_user=True + local_vanilla_pg, + proxy_with_metric_collector, + psql_session_id, + create_user=True, + use_legacy_mgmt_api=False, ): pg_user = "proxy" @@ -2570,33 +2575,40 @@ class NeonProxy(PgProtocol): local_vanilla_pg.start() local_vanilla_pg.safe_psql(f"create user {pg_user} with login superuser") - db_info = json.dumps( - { - "session_id": psql_session_id, - "result": { - "Success": { - "host": 
local_vanilla_pg.default_options["host"], - "port": local_vanilla_pg.default_options["port"], - "dbname": local_vanilla_pg.default_options["dbname"], - "user": pg_user, - "aux": { - "project_id": "test_project_id", - "endpoint_id": "test_endpoint_id", - "branch_id": "test_branch_id", - }, - } - }, - } - ) + db_info = { + "session_id": psql_session_id, + "result": { + "Success": { + "host": local_vanilla_pg.default_options["host"], + "port": local_vanilla_pg.default_options["port"], + "dbname": local_vanilla_pg.default_options["dbname"], + "user": pg_user, + "aux": { + "project_id": "test_project_id", + "endpoint_id": "test_endpoint_id", + "branch_id": "test_branch_id", + }, + } + }, + } - log.info("sending session activation message") - psql = await PSQL( - host=proxy_with_metric_collector.host, - port=proxy_with_metric_collector.mgmt_port, - ).run(db_info) - assert psql.stdout is not None - out = (await psql.stdout.read()).decode("utf-8").strip() - assert out == "ok" + if use_legacy_mgmt_api: + log.info("sending session activation message using legacy libpq mgmt interface") + psql = await PSQL( + host=proxy_with_metric_collector.host, + port=proxy_with_metric_collector.mgmt_port, + ).run(json.dumps(db_info)) + assert psql.stdout is not None + out = (await psql.stdout.read()).decode("utf-8").strip() + assert out == "ok" + else: + log.info("sending session activation message using HTTP mgmt interface") + async with aiohttp.request( + "POST", + f"http://{proxy_with_metric_collector.host}:{proxy_with_metric_collector.http_port}/v1/kick_session", + json=db_info, + ) as resp: + assert resp.ok @pytest.fixture(scope="function") diff --git a/test_runner/regress/test_proxy.py b/test_runner/regress/test_proxy.py index 51fabdd2a1..489a9442e7 100644 --- a/test_runner/regress/test_proxy.py +++ b/test_runner/regress/test_proxy.py @@ -35,7 +35,10 @@ def test_password_hack(static_proxy: NeonProxy): @pytest.mark.asyncio -async def test_link_auth(vanilla_pg: VanillaPostgres, 
link_proxy: NeonProxy): +@pytest.mark.parametrize("use_legacy_mgmt_api", [True, False]) +async def test_link_auth( + vanilla_pg: VanillaPostgres, link_proxy: NeonProxy, use_legacy_mgmt_api: bool +): """ Check the Link auth flow: a lightweight auth method which delegates all necessary checks to the console by sending client an auth URL. @@ -47,7 +50,12 @@ async def test_link_auth(vanilla_pg: VanillaPostgres, link_proxy: NeonProxy): link = await NeonProxy.find_auth_link(base_uri, psql) psql_session_id = NeonProxy.get_session_id(base_uri, link) - await NeonProxy.activate_link_auth(vanilla_pg, link_proxy, psql_session_id) + await NeonProxy.activate_link_auth( + vanilla_pg, + link_proxy, + psql_session_id, + use_legacy_mgmt_api=use_legacy_mgmt_api, + ) assert psql.stdout is not None out = (await psql.stdout.read()).decode("utf-8").strip()