Compare commits

...

1 Commits

Author SHA1 Message Date
Heikki Linnakangas
35b77a138f Support "kick" callback from console in the HTTP API.
We cannot remove the old libpq-based "mgmt" interface until the
control plane has been changed to use the new HTTP API, and it has
been deployed to all the environments. But this gets us started with
the migration to the HTTP API.

Fixes https://github.com/neondatabase/neon/issues/1094
2023-03-15 13:32:08 +02:00
4 changed files with 88 additions and 36 deletions

View File

@@ -29,8 +29,15 @@ pub fn notify(psql_session_id: &str, msg: ComputeReady) -> Result<(), waiters::N
CPLANE_WAITERS.notify(psql_session_id, msg)
}
/// Console management API listener task.
/// It spawns console response handlers needed for the link auth.
/// Listener task for mgmt connections using the libpq protocol, for
/// the purposes of "kick callback" from proxy, for link authentication.
///
/// The protocol is the libpq protocol, but the only "query" we accept is a
/// JSON document. The JSON document is a KickSession serialized to JSON.
///
/// This is considered legacy now. The preferred way to deliver "kick
/// callbacks" is now via the HTTP API. See `server.rs`. Once the control
/// plane has switched to using the HTTP API, this can be removed.
pub async fn task_main(listener: TcpListener) -> anyhow::Result<()> {
scopeguard::defer! {
info!("mgmt has shut down");
@@ -44,11 +51,12 @@ pub async fn task_main(listener: TcpListener) -> anyhow::Result<()> {
.set_nodelay(true)
.context("failed to set client socket option")?;
// spawn a task to handle this connection
tokio::task::spawn(async move {
let span = info_span!("mgmt", peer = %peer_addr);
let _enter = span.enter();
info!("started a new console management API thread");
info!("started a new console management API task");
scopeguard::defer! {
info!("console management API thread is about to finish");
}
@@ -68,7 +76,6 @@ async fn handle_connection(socket: TcpStream) -> Result<(), QueryError> {
/// A message received by `mgmt` when a compute node is ready.
pub type ComputeReady = Result<DatabaseInfo, String>;
// TODO: replace with an http-based protocol.
struct MgmtHandler;
#[async_trait::async_trait]
impl postgres_backend::Handler<tokio::net::TcpStream> for MgmtHandler {

View File

@@ -1,5 +1,8 @@
use anyhow::anyhow;
use hyper::{Body, Request, Response, StatusCode};
use crate::console;
use crate::console::messages::KickSession;
use crate::waiters::NotifyError;
use anyhow::{anyhow, Context};
use hyper::{body, Body, Request, Response, StatusCode};
use std::net::TcpListener;
use tracing::info;
use utils::http::{endpoint, error::ApiError, json::json_response, RouterBuilder, RouterService};
@@ -8,8 +11,30 @@ async fn status_handler(_: Request<Body>) -> Result<Response<Body>, ApiError> {
json_response(StatusCode::OK, "")
}
/// Process a session kick callback from the control plane. The body is a
/// KickSession as a JSON document.
///
/// TODO: authentication
async fn kick_session_handler(req: Request<Body>) -> Result<Response<Body>, ApiError> {
let body = &body::to_bytes(req.into_body())
.await
.context("Failed to get request body")
.map_err(ApiError::BadRequest)?;
let kick_session_json: KickSession = serde_json::from_slice(body)
.context("Failed to parse query as json")
.map_err(ApiError::BadRequest)?;
match console::mgmt::notify(kick_session_json.session_id, Ok(kick_session_json.result)) {
Ok(()) => json_response(StatusCode::OK, ""),
Err(NotifyError::NotFound(s)) => Err(ApiError::NotFound(anyhow::anyhow!(s))),
Err(e @ NotifyError::Hangup) => Err(ApiError::NotFound(anyhow::anyhow!(e))),
}
}
fn make_router() -> RouterBuilder<hyper::Body, ApiError> {
endpoint::make_router().get("/v1/status", status_handler)
endpoint::make_router()
.get("/v1/status", status_handler)
.post("/v1/kick_session", kick_session_handler)
}
pub async fn task_main(http_listener: TcpListener) -> anyhow::Result<()> {

View File

@@ -25,6 +25,7 @@ from types import TracebackType
from typing import Any, Dict, Iterator, List, Optional, Tuple, Type, Union, cast
from urllib.parse import urlparse
import aiohttp
import asyncpg
import backoff # type: ignore
import boto3
@@ -2561,7 +2562,11 @@ class NeonProxy(PgProtocol):
@staticmethod
async def activate_link_auth(
local_vanilla_pg, proxy_with_metric_collector, psql_session_id, create_user=True
local_vanilla_pg,
proxy_with_metric_collector,
psql_session_id,
create_user=True,
use_legacy_mgmt_api=False,
):
pg_user = "proxy"
@@ -2570,33 +2575,40 @@ class NeonProxy(PgProtocol):
local_vanilla_pg.start()
local_vanilla_pg.safe_psql(f"create user {pg_user} with login superuser")
db_info = json.dumps(
{
"session_id": psql_session_id,
"result": {
"Success": {
"host": local_vanilla_pg.default_options["host"],
"port": local_vanilla_pg.default_options["port"],
"dbname": local_vanilla_pg.default_options["dbname"],
"user": pg_user,
"aux": {
"project_id": "test_project_id",
"endpoint_id": "test_endpoint_id",
"branch_id": "test_branch_id",
},
}
},
}
)
db_info = {
"session_id": psql_session_id,
"result": {
"Success": {
"host": local_vanilla_pg.default_options["host"],
"port": local_vanilla_pg.default_options["port"],
"dbname": local_vanilla_pg.default_options["dbname"],
"user": pg_user,
"aux": {
"project_id": "test_project_id",
"endpoint_id": "test_endpoint_id",
"branch_id": "test_branch_id",
},
}
},
}
log.info("sending session activation message")
psql = await PSQL(
host=proxy_with_metric_collector.host,
port=proxy_with_metric_collector.mgmt_port,
).run(db_info)
assert psql.stdout is not None
out = (await psql.stdout.read()).decode("utf-8").strip()
assert out == "ok"
if use_legacy_mgmt_api:
log.info("sending session activation message using legacy libpq mgmt interface")
psql = await PSQL(
host=proxy_with_metric_collector.host,
port=proxy_with_metric_collector.mgmt_port,
).run(json.dumps(db_info))
assert psql.stdout is not None
out = (await psql.stdout.read()).decode("utf-8").strip()
assert out == "ok"
else:
log.info("sending session activation message using HTTP mgmt interface")
async with aiohttp.request(
"POST",
f"http://{proxy_with_metric_collector.host}:{proxy_with_metric_collector.http_port}/v1/kick_session",
json=db_info,
) as resp:
assert resp.ok
@pytest.fixture(scope="function")

View File

@@ -35,7 +35,10 @@ def test_password_hack(static_proxy: NeonProxy):
@pytest.mark.asyncio
async def test_link_auth(vanilla_pg: VanillaPostgres, link_proxy: NeonProxy):
@pytest.mark.parametrize("use_legacy_mgmt_api", [True, False])
async def test_link_auth(
vanilla_pg: VanillaPostgres, link_proxy: NeonProxy, use_legacy_mgmt_api: bool
):
"""
Check the Link auth flow: a lightweight auth method which delegates
all necessary checks to the console by sending client an auth URL.
@@ -47,7 +50,12 @@ async def test_link_auth(vanilla_pg: VanillaPostgres, link_proxy: NeonProxy):
link = await NeonProxy.find_auth_link(base_uri, psql)
psql_session_id = NeonProxy.get_session_id(base_uri, link)
await NeonProxy.activate_link_auth(vanilla_pg, link_proxy, psql_session_id)
await NeonProxy.activate_link_auth(
vanilla_pg,
link_proxy,
psql_session_id,
use_legacy_mgmt_api=use_legacy_mgmt_api,
)
assert psql.stdout is not None
out = (await psql.stdout.read()).decode("utf-8").strip()