diff --git a/proxy/src/console/mgmt.rs b/proxy/src/console/mgmt.rs index 30364be6f4..912383a611 100644 --- a/proxy/src/console/mgmt.rs +++ b/proxy/src/console/mgmt.rs @@ -29,8 +29,15 @@ pub fn notify(psql_session_id: &str, msg: ComputeReady) -> Result<(), waiters::N CPLANE_WAITERS.notify(psql_session_id, msg) } -/// Console management API listener task. -/// It spawns console response handlers needed for the link auth. +/// Listener task for mgmt connections using the libpq protocol, for +/// the purposes of "kick callback" from proxy, for link authentication. +/// +/// The protocol is the libpq protocol, but the only "query" we accept is a +/// JSON document. The JSON document is a KickSession serialized to JSON. +/// +/// This is considered legacy now. The preferred way to deliver "kick +/// callbacks" is now via the HTTP API. See `server.rs`. Once the control +/// plane has switched to using the HTTP API, this can be removed. pub async fn task_main(listener: TcpListener) -> anyhow::Result<()> { scopeguard::defer! { info!("mgmt has shut down"); @@ -44,11 +51,12 @@ pub async fn task_main(listener: TcpListener) -> anyhow::Result<()> { .set_nodelay(true) .context("failed to set client socket option")?; + // spawn a task to handle this connection tokio::task::spawn(async move { let span = info_span!("mgmt", peer = %peer_addr); let _enter = span.enter(); - info!("started a new console management API thread"); + info!("started a new console management API task"); scopeguard::defer! { info!("console management API thread is about to finish"); } @@ -68,7 +76,6 @@ async fn handle_connection(socket: TcpStream) -> Result<(), QueryError> { /// A message received by `mgmt` when a compute node is ready. pub type ComputeReady = Result; -// TODO: replace with an http-based protocol. struct MgmtHandler; #[async_trait::async_trait] impl postgres_backend::Handler for MgmtHandler { diff --git a/proxy/src/http/server.rs b/proxy/src/http/server.rs index f35f4f9a62..79528a8f06 100644 --- a/proxy/src/http/server.rs +++ b/proxy/src/http/server.rs @@ -1,5 +1,8 @@ -use anyhow::anyhow; -use hyper::{Body, Request, Response, StatusCode}; +use crate::console; +use crate::console::messages::KickSession; +use crate::waiters::NotifyError; +use anyhow::{anyhow, Context}; +use hyper::{body, Body, Request, Response, StatusCode}; use std::net::TcpListener; use tracing::info; use utils::http::{endpoint, error::ApiError, json::json_response, RouterBuilder, RouterService}; @@ -8,8 +11,30 @@ async fn status_handler(_: Request) -> Result, ApiError> { json_response(StatusCode::OK, "") } +/// Process a session kick callback from the control plane. The body is a +/// KickSession as a JSON document. +/// +/// TODO: authentication +async fn kick_session_handler(req: Request) -> Result, ApiError> { + let body = &body::to_bytes(req.into_body()) + .await + .context("Failed to get request body") + .map_err(ApiError::BadRequest)?; + let kick_session_json: KickSession = serde_json::from_slice(body) + .context("Failed to parse query as json") + .map_err(ApiError::BadRequest)?; + + match console::mgmt::notify(kick_session_json.session_id, Ok(kick_session_json.result)) { + Ok(()) => json_response(StatusCode::OK, ""), + Err(NotifyError::NotFound(s)) => Err(ApiError::NotFound(anyhow::anyhow!(s))), + Err(e @ NotifyError::Hangup) => Err(ApiError::NotFound(anyhow::anyhow!(e))), + } +} + fn make_router() -> RouterBuilder { - endpoint::make_router().get("/v1/status", status_handler) + endpoint::make_router() + .get("/v1/status", status_handler) + .post("/v1/kick_session", kick_session_handler) } pub async fn task_main(http_listener: TcpListener) -> anyhow::Result<()> { diff --git a/test_runner/fixtures/neon_fixtures.py b/test_runner/fixtures/neon_fixtures.py index c5e260a962..9b1ef821a7 100644 --- a/test_runner/fixtures/neon_fixtures.py +++ b/test_runner/fixtures/neon_fixtures.py @@ -25,6 +25,7 @@ from types import TracebackType from typing import Any, Dict, Iterator, List, Optional, Tuple, Type, Union, cast from urllib.parse import urlparse +import aiohttp import asyncpg import backoff # type: ignore import boto3 @@ -2561,7 +2562,11 @@ class NeonProxy(PgProtocol): @staticmethod async def activate_link_auth( - local_vanilla_pg, proxy_with_metric_collector, psql_session_id, create_user=True + local_vanilla_pg, + proxy_with_metric_collector, + psql_session_id, + create_user=True, + use_legacy_mgmt_api=False, ): pg_user = "proxy" @@ -2570,33 +2575,40 @@ class NeonProxy(PgProtocol): local_vanilla_pg.start() local_vanilla_pg.safe_psql(f"create user {pg_user} with login superuser") - db_info = json.dumps( - { - "session_id": psql_session_id, - "result": { - "Success": { - "host": local_vanilla_pg.default_options["host"], - "port": local_vanilla_pg.default_options["port"], - "dbname": local_vanilla_pg.default_options["dbname"], - "user": pg_user, - "aux": { - "project_id": "test_project_id", - "endpoint_id": "test_endpoint_id", - "branch_id": "test_branch_id", - }, - } - }, - } - ) + db_info = { + "session_id": psql_session_id, + "result": { + "Success": { + "host": local_vanilla_pg.default_options["host"], + "port": local_vanilla_pg.default_options["port"], + "dbname": local_vanilla_pg.default_options["dbname"], + "user": pg_user, + "aux": { + "project_id": "test_project_id", + "endpoint_id": "test_endpoint_id", + "branch_id": "test_branch_id", + }, + } + }, + } - log.info("sending session activation message") - psql = await PSQL( - host=proxy_with_metric_collector.host, - port=proxy_with_metric_collector.mgmt_port, - ).run(db_info) - assert psql.stdout is not None - out = (await psql.stdout.read()).decode("utf-8").strip() - assert out == "ok" + if use_legacy_mgmt_api: + log.info("sending session activation message using legacy libpq mgmt interface") + psql = await PSQL( + host=proxy_with_metric_collector.host, + port=proxy_with_metric_collector.mgmt_port, + ).run(json.dumps(db_info)) + assert psql.stdout is not None + out = (await psql.stdout.read()).decode("utf-8").strip() + assert out == "ok" + else: + log.info("sending session activation message using HTTP mgmt interface") + async with aiohttp.request( + "POST", + f"http://{proxy_with_metric_collector.host}:{proxy_with_metric_collector.http_port}/v1/kick_session", + json=db_info, + ) as resp: + assert resp.ok @pytest.fixture(scope="function") diff --git a/test_runner/regress/test_proxy.py b/test_runner/regress/test_proxy.py index 51fabdd2a1..489a9442e7 100644 --- a/test_runner/regress/test_proxy.py +++ b/test_runner/regress/test_proxy.py @@ -35,7 +35,10 @@ def test_password_hack(static_proxy: NeonProxy): @pytest.mark.asyncio -async def test_link_auth(vanilla_pg: VanillaPostgres, link_proxy: NeonProxy): +@pytest.mark.parametrize("use_legacy_mgmt_api", [True, False]) +async def test_link_auth( + vanilla_pg: VanillaPostgres, link_proxy: NeonProxy, use_legacy_mgmt_api: bool +): """ Check the Link auth flow: a lightweight auth method which delegates all necessary checks to the console by sending client an auth URL. @@ -47,7 +50,12 @@ async def test_link_auth(vanilla_pg: VanillaPostgres, link_proxy: NeonProxy): link = await NeonProxy.find_auth_link(base_uri, psql) psql_session_id = NeonProxy.get_session_id(base_uri, link) - await NeonProxy.activate_link_auth(vanilla_pg, link_proxy, psql_session_id) + await NeonProxy.activate_link_auth( + vanilla_pg, + link_proxy, + psql_session_id, + use_legacy_mgmt_api=use_legacy_mgmt_api, + ) assert psql.stdout is not None out = (await psql.stdout.read()).decode("utf-8").strip()