Compare commits

...

1 Commits

Author SHA1 Message Date
Heikki Linnakangas
35b77a138f Support "kick" callback from console in the HTTP API.
We cannot remove the old libpq-based "mgmt" interface until the
control plane has been changed to use the new HTTP API, and it has
been deployed to all the environments. But this gets us started with
the migration to the HTTP API.

Fixes https://github.com/neondatabase/neon/issues/1094
2023-03-15 13:32:08 +02:00
4 changed files with 88 additions and 36 deletions

View File

@@ -29,8 +29,15 @@ pub fn notify(psql_session_id: &str, msg: ComputeReady) -> Result<(), waiters::N
CPLANE_WAITERS.notify(psql_session_id, msg) CPLANE_WAITERS.notify(psql_session_id, msg)
} }
/// Console management API listener task. /// Listener task for mgmt connections using the libpq protocol, for
/// It spawns console response handlers needed for the link auth. /// the purposes of "kick callback" from proxy, for link authentication.
///
/// The protocol is the libpq protocol, but the only "query" we accept is a
/// JSON document. The JSON document is a KickSession serialized to JSON.
///
/// This is considered legacy now. The preferred way to deliver "kick
/// callbacks" is now via the HTTP API. See `server.rs`. Once the control
/// plane has switched to using the HTTP API, this can be removed.
pub async fn task_main(listener: TcpListener) -> anyhow::Result<()> { pub async fn task_main(listener: TcpListener) -> anyhow::Result<()> {
scopeguard::defer! { scopeguard::defer! {
info!("mgmt has shut down"); info!("mgmt has shut down");
@@ -44,11 +51,12 @@ pub async fn task_main(listener: TcpListener) -> anyhow::Result<()> {
.set_nodelay(true) .set_nodelay(true)
.context("failed to set client socket option")?; .context("failed to set client socket option")?;
// spawn a task to handle this connection
tokio::task::spawn(async move { tokio::task::spawn(async move {
let span = info_span!("mgmt", peer = %peer_addr); let span = info_span!("mgmt", peer = %peer_addr);
let _enter = span.enter(); let _enter = span.enter();
info!("started a new console management API thread"); info!("started a new console management API task");
scopeguard::defer! { scopeguard::defer! {
info!("console management API thread is about to finish"); info!("console management API thread is about to finish");
} }
@@ -68,7 +76,6 @@ async fn handle_connection(socket: TcpStream) -> Result<(), QueryError> {
/// A message received by `mgmt` when a compute node is ready. /// A message received by `mgmt` when a compute node is ready.
pub type ComputeReady = Result<DatabaseInfo, String>; pub type ComputeReady = Result<DatabaseInfo, String>;
// TODO: replace with an http-based protocol.
struct MgmtHandler; struct MgmtHandler;
#[async_trait::async_trait] #[async_trait::async_trait]
impl postgres_backend::Handler<tokio::net::TcpStream> for MgmtHandler { impl postgres_backend::Handler<tokio::net::TcpStream> for MgmtHandler {

View File

@@ -1,5 +1,8 @@
use anyhow::anyhow; use crate::console;
use hyper::{Body, Request, Response, StatusCode}; use crate::console::messages::KickSession;
use crate::waiters::NotifyError;
use anyhow::{anyhow, Context};
use hyper::{body, Body, Request, Response, StatusCode};
use std::net::TcpListener; use std::net::TcpListener;
use tracing::info; use tracing::info;
use utils::http::{endpoint, error::ApiError, json::json_response, RouterBuilder, RouterService}; use utils::http::{endpoint, error::ApiError, json::json_response, RouterBuilder, RouterService};
@@ -8,8 +11,30 @@ async fn status_handler(_: Request<Body>) -> Result<Response<Body>, ApiError> {
json_response(StatusCode::OK, "") json_response(StatusCode::OK, "")
} }
/// Process a session kick callback from the control plane. The body is a
/// KickSession as a JSON document.
///
/// TODO: authentication
async fn kick_session_handler(req: Request<Body>) -> Result<Response<Body>, ApiError> {
let body = &body::to_bytes(req.into_body())
.await
.context("Failed to get request body")
.map_err(ApiError::BadRequest)?;
let kick_session_json: KickSession = serde_json::from_slice(body)
.context("Failed to parse query as json")
.map_err(ApiError::BadRequest)?;
match console::mgmt::notify(kick_session_json.session_id, Ok(kick_session_json.result)) {
Ok(()) => json_response(StatusCode::OK, ""),
Err(NotifyError::NotFound(s)) => Err(ApiError::NotFound(anyhow::anyhow!(s))),
Err(e @ NotifyError::Hangup) => Err(ApiError::NotFound(anyhow::anyhow!(e))),
}
}
fn make_router() -> RouterBuilder<hyper::Body, ApiError> { fn make_router() -> RouterBuilder<hyper::Body, ApiError> {
endpoint::make_router().get("/v1/status", status_handler) endpoint::make_router()
.get("/v1/status", status_handler)
.post("/v1/kick_session", kick_session_handler)
} }
pub async fn task_main(http_listener: TcpListener) -> anyhow::Result<()> { pub async fn task_main(http_listener: TcpListener) -> anyhow::Result<()> {

View File

@@ -25,6 +25,7 @@ from types import TracebackType
from typing import Any, Dict, Iterator, List, Optional, Tuple, Type, Union, cast from typing import Any, Dict, Iterator, List, Optional, Tuple, Type, Union, cast
from urllib.parse import urlparse from urllib.parse import urlparse
import aiohttp
import asyncpg import asyncpg
import backoff # type: ignore import backoff # type: ignore
import boto3 import boto3
@@ -2561,7 +2562,11 @@ class NeonProxy(PgProtocol):
@staticmethod @staticmethod
async def activate_link_auth( async def activate_link_auth(
local_vanilla_pg, proxy_with_metric_collector, psql_session_id, create_user=True local_vanilla_pg,
proxy_with_metric_collector,
psql_session_id,
create_user=True,
use_legacy_mgmt_api=False,
): ):
pg_user = "proxy" pg_user = "proxy"
@@ -2570,33 +2575,40 @@ class NeonProxy(PgProtocol):
local_vanilla_pg.start() local_vanilla_pg.start()
local_vanilla_pg.safe_psql(f"create user {pg_user} with login superuser") local_vanilla_pg.safe_psql(f"create user {pg_user} with login superuser")
db_info = json.dumps( db_info = {
{ "session_id": psql_session_id,
"session_id": psql_session_id, "result": {
"result": { "Success": {
"Success": { "host": local_vanilla_pg.default_options["host"],
"host": local_vanilla_pg.default_options["host"], "port": local_vanilla_pg.default_options["port"],
"port": local_vanilla_pg.default_options["port"], "dbname": local_vanilla_pg.default_options["dbname"],
"dbname": local_vanilla_pg.default_options["dbname"], "user": pg_user,
"user": pg_user, "aux": {
"aux": { "project_id": "test_project_id",
"project_id": "test_project_id", "endpoint_id": "test_endpoint_id",
"endpoint_id": "test_endpoint_id", "branch_id": "test_branch_id",
"branch_id": "test_branch_id", },
}, }
} },
}, }
}
)
log.info("sending session activation message") if use_legacy_mgmt_api:
psql = await PSQL( log.info("sending session activation message using legacy libpq mgmt interface")
host=proxy_with_metric_collector.host, psql = await PSQL(
port=proxy_with_metric_collector.mgmt_port, host=proxy_with_metric_collector.host,
).run(db_info) port=proxy_with_metric_collector.mgmt_port,
assert psql.stdout is not None ).run(json.dumps(db_info))
out = (await psql.stdout.read()).decode("utf-8").strip() assert psql.stdout is not None
assert out == "ok" out = (await psql.stdout.read()).decode("utf-8").strip()
assert out == "ok"
else:
log.info("sending session activation message using HTTP mgmt interface")
async with aiohttp.request(
"POST",
f"http://{proxy_with_metric_collector.host}:{proxy_with_metric_collector.http_port}/v1/kick_session",
json=db_info,
) as resp:
assert resp.ok
@pytest.fixture(scope="function") @pytest.fixture(scope="function")

View File

@@ -35,7 +35,10 @@ def test_password_hack(static_proxy: NeonProxy):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_link_auth(vanilla_pg: VanillaPostgres, link_proxy: NeonProxy): @pytest.mark.parametrize("use_legacy_mgmt_api", [True, False])
async def test_link_auth(
vanilla_pg: VanillaPostgres, link_proxy: NeonProxy, use_legacy_mgmt_api: bool
):
""" """
Check the Link auth flow: a lightweight auth method which delegates Check the Link auth flow: a lightweight auth method which delegates
all necessary checks to the console by sending client an auth URL. all necessary checks to the console by sending client an auth URL.
@@ -47,7 +50,12 @@ async def test_link_auth(vanilla_pg: VanillaPostgres, link_proxy: NeonProxy):
link = await NeonProxy.find_auth_link(base_uri, psql) link = await NeonProxy.find_auth_link(base_uri, psql)
psql_session_id = NeonProxy.get_session_id(base_uri, link) psql_session_id = NeonProxy.get_session_id(base_uri, link)
await NeonProxy.activate_link_auth(vanilla_pg, link_proxy, psql_session_id) await NeonProxy.activate_link_auth(
vanilla_pg,
link_proxy,
psql_session_id,
use_legacy_mgmt_api=use_legacy_mgmt_api,
)
assert psql.stdout is not None assert psql.stdout is not None
out = (await psql.stdout.read()).decode("utf-8").strip() out = (await psql.stdout.read()).decode("utf-8").strip()