mirror of
https://github.com/neondatabase/neon.git
synced 2026-05-21 23:20:40 +00:00
Support "kick" callback from console in the HTTP API.
We cannot remove the old libpq-based "mgmt" interface until the control plane has been changed to use the new HTTP API, and it has been deployed to all the environments. But this gets us started with the migration to the HTTP API. Fixes https://github.com/neondatabase/neon/issues/1094
This commit is contained in:
@@ -29,8 +29,15 @@ pub fn notify(psql_session_id: &str, msg: ComputeReady) -> Result<(), waiters::N
|
||||
CPLANE_WAITERS.notify(psql_session_id, msg)
|
||||
}
|
||||
|
||||
/// Console management API listener task.
|
||||
/// It spawns console response handlers needed for the link auth.
|
||||
/// Listener task for mgmt connections using the libpq protocol, for
|
||||
/// the purposes of "kick callback" from proxy, for link authentication.
|
||||
///
|
||||
/// The protocol is the libpq protocol, but the only "query" we accept is a
|
||||
/// JSON document. The JSON document is a KickSession serialized to JSON.
|
||||
///
|
||||
/// This is considered legacy now. The preferred way to deliver "kick
|
||||
/// callbacks" is now via the HTTP API. See `server.rs`. Once the control
|
||||
/// plane has switched to using the HTTP API, this can be removed.
|
||||
pub async fn task_main(listener: TcpListener) -> anyhow::Result<()> {
|
||||
scopeguard::defer! {
|
||||
info!("mgmt has shut down");
|
||||
@@ -44,11 +51,12 @@ pub async fn task_main(listener: TcpListener) -> anyhow::Result<()> {
|
||||
.set_nodelay(true)
|
||||
.context("failed to set client socket option")?;
|
||||
|
||||
// spawn a task to handle this connection
|
||||
tokio::task::spawn(async move {
|
||||
let span = info_span!("mgmt", peer = %peer_addr);
|
||||
let _enter = span.enter();
|
||||
|
||||
info!("started a new console management API thread");
|
||||
info!("started a new console management API task");
|
||||
scopeguard::defer! {
|
||||
info!("console management API thread is about to finish");
|
||||
}
|
||||
@@ -68,7 +76,6 @@ async fn handle_connection(socket: TcpStream) -> Result<(), QueryError> {
|
||||
/// A message received by `mgmt` when a compute node is ready.
|
||||
pub type ComputeReady = Result<DatabaseInfo, String>;
|
||||
|
||||
// TODO: replace with an http-based protocol.
|
||||
struct MgmtHandler;
|
||||
#[async_trait::async_trait]
|
||||
impl postgres_backend::Handler<tokio::net::TcpStream> for MgmtHandler {
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
use anyhow::anyhow;
|
||||
use hyper::{Body, Request, Response, StatusCode};
|
||||
use crate::console;
|
||||
use crate::console::messages::KickSession;
|
||||
use crate::waiters::NotifyError;
|
||||
use anyhow::{anyhow, Context};
|
||||
use hyper::{body, Body, Request, Response, StatusCode};
|
||||
use std::net::TcpListener;
|
||||
use tracing::info;
|
||||
use utils::http::{endpoint, error::ApiError, json::json_response, RouterBuilder, RouterService};
|
||||
@@ -8,8 +11,30 @@ async fn status_handler(_: Request<Body>) -> Result<Response<Body>, ApiError> {
|
||||
json_response(StatusCode::OK, "")
|
||||
}
|
||||
|
||||
/// Process a session kick callback from the control plane. The body is a
|
||||
/// KickSession as a JSON document.
|
||||
///
|
||||
/// TODO: authentication
|
||||
async fn kick_session_handler(req: Request<Body>) -> Result<Response<Body>, ApiError> {
|
||||
let body = &body::to_bytes(req.into_body())
|
||||
.await
|
||||
.context("Failed to get request body")
|
||||
.map_err(ApiError::BadRequest)?;
|
||||
let kick_session_json: KickSession = serde_json::from_slice(body)
|
||||
.context("Failed to parse query as json")
|
||||
.map_err(ApiError::BadRequest)?;
|
||||
|
||||
match console::mgmt::notify(kick_session_json.session_id, Ok(kick_session_json.result)) {
|
||||
Ok(()) => json_response(StatusCode::OK, ""),
|
||||
Err(NotifyError::NotFound(s)) => Err(ApiError::NotFound(anyhow::anyhow!(s))),
|
||||
Err(e @ NotifyError::Hangup) => Err(ApiError::NotFound(anyhow::anyhow!(e))),
|
||||
}
|
||||
}
|
||||
|
||||
fn make_router() -> RouterBuilder<hyper::Body, ApiError> {
|
||||
endpoint::make_router().get("/v1/status", status_handler)
|
||||
endpoint::make_router()
|
||||
.get("/v1/status", status_handler)
|
||||
.post("/v1/kick_session", kick_session_handler)
|
||||
}
|
||||
|
||||
pub async fn task_main(http_listener: TcpListener) -> anyhow::Result<()> {
|
||||
|
||||
@@ -25,6 +25,7 @@ from types import TracebackType
|
||||
from typing import Any, Dict, Iterator, List, Optional, Tuple, Type, Union, cast
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import aiohttp
|
||||
import asyncpg
|
||||
import backoff # type: ignore
|
||||
import boto3
|
||||
@@ -2561,7 +2562,11 @@ class NeonProxy(PgProtocol):
|
||||
|
||||
@staticmethod
|
||||
async def activate_link_auth(
|
||||
local_vanilla_pg, proxy_with_metric_collector, psql_session_id, create_user=True
|
||||
local_vanilla_pg,
|
||||
proxy_with_metric_collector,
|
||||
psql_session_id,
|
||||
create_user=True,
|
||||
use_legacy_mgmt_api=False,
|
||||
):
|
||||
pg_user = "proxy"
|
||||
|
||||
@@ -2570,33 +2575,40 @@ class NeonProxy(PgProtocol):
|
||||
local_vanilla_pg.start()
|
||||
local_vanilla_pg.safe_psql(f"create user {pg_user} with login superuser")
|
||||
|
||||
db_info = json.dumps(
|
||||
{
|
||||
"session_id": psql_session_id,
|
||||
"result": {
|
||||
"Success": {
|
||||
"host": local_vanilla_pg.default_options["host"],
|
||||
"port": local_vanilla_pg.default_options["port"],
|
||||
"dbname": local_vanilla_pg.default_options["dbname"],
|
||||
"user": pg_user,
|
||||
"aux": {
|
||||
"project_id": "test_project_id",
|
||||
"endpoint_id": "test_endpoint_id",
|
||||
"branch_id": "test_branch_id",
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
)
|
||||
db_info = {
|
||||
"session_id": psql_session_id,
|
||||
"result": {
|
||||
"Success": {
|
||||
"host": local_vanilla_pg.default_options["host"],
|
||||
"port": local_vanilla_pg.default_options["port"],
|
||||
"dbname": local_vanilla_pg.default_options["dbname"],
|
||||
"user": pg_user,
|
||||
"aux": {
|
||||
"project_id": "test_project_id",
|
||||
"endpoint_id": "test_endpoint_id",
|
||||
"branch_id": "test_branch_id",
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
log.info("sending session activation message")
|
||||
psql = await PSQL(
|
||||
host=proxy_with_metric_collector.host,
|
||||
port=proxy_with_metric_collector.mgmt_port,
|
||||
).run(db_info)
|
||||
assert psql.stdout is not None
|
||||
out = (await psql.stdout.read()).decode("utf-8").strip()
|
||||
assert out == "ok"
|
||||
if use_legacy_mgmt_api:
|
||||
log.info("sending session activation message using legacy libpq mgmt interface")
|
||||
psql = await PSQL(
|
||||
host=proxy_with_metric_collector.host,
|
||||
port=proxy_with_metric_collector.mgmt_port,
|
||||
).run(json.dumps(db_info))
|
||||
assert psql.stdout is not None
|
||||
out = (await psql.stdout.read()).decode("utf-8").strip()
|
||||
assert out == "ok"
|
||||
else:
|
||||
log.info("sending session activation message using HTTP mgmt interface")
|
||||
async with aiohttp.request(
|
||||
"POST",
|
||||
f"http://{proxy_with_metric_collector.host}:{proxy_with_metric_collector.http_port}/v1/kick_session",
|
||||
json=db_info,
|
||||
) as resp:
|
||||
assert resp.ok
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
|
||||
@@ -35,7 +35,10 @@ def test_password_hack(static_proxy: NeonProxy):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_link_auth(vanilla_pg: VanillaPostgres, link_proxy: NeonProxy):
|
||||
@pytest.mark.parametrize("use_legacy_mgmt_api", [True, False])
|
||||
async def test_link_auth(
|
||||
vanilla_pg: VanillaPostgres, link_proxy: NeonProxy, use_legacy_mgmt_api: bool
|
||||
):
|
||||
"""
|
||||
Check the Link auth flow: a lightweight auth method which delegates
|
||||
all necessary checks to the console by sending client an auth URL.
|
||||
@@ -47,7 +50,12 @@ async def test_link_auth(vanilla_pg: VanillaPostgres, link_proxy: NeonProxy):
|
||||
link = await NeonProxy.find_auth_link(base_uri, psql)
|
||||
|
||||
psql_session_id = NeonProxy.get_session_id(base_uri, link)
|
||||
await NeonProxy.activate_link_auth(vanilla_pg, link_proxy, psql_session_id)
|
||||
await NeonProxy.activate_link_auth(
|
||||
vanilla_pg,
|
||||
link_proxy,
|
||||
psql_session_id,
|
||||
use_legacy_mgmt_api=use_legacy_mgmt_api,
|
||||
)
|
||||
|
||||
assert psql.stdout is not None
|
||||
out = (await psql.stdout.read()).decode("utf-8").strip()
|
||||
|
||||
Reference in New Issue
Block a user