mirror of
https://github.com/neondatabase/neon.git
synced 2026-02-02 10:10:37 +00:00
Compare commits
1 Commits
hack/compu
...
proxy-kick
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
35b77a138f |
@@ -29,8 +29,15 @@ pub fn notify(psql_session_id: &str, msg: ComputeReady) -> Result<(), waiters::N
|
|||||||
CPLANE_WAITERS.notify(psql_session_id, msg)
|
CPLANE_WAITERS.notify(psql_session_id, msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Console management API listener task.
|
/// Listener task for mgmt connections using the libpq protocol, for
|
||||||
/// It spawns console response handlers needed for the link auth.
|
/// the purposes of "kick callback" from proxy, for link authentication.
|
||||||
|
///
|
||||||
|
/// The protocol is the libpq protocol, but the only "query" we accept is a
|
||||||
|
/// JSON document. The JSON document is a KickSession serialized to JSON.
|
||||||
|
///
|
||||||
|
/// This is considered legacy now. The preferred way to deliver "kick
|
||||||
|
/// callbacks" is now via the HTTP API. See `server.rs`. Once the control
|
||||||
|
/// plane has switched to using the HTTP API, this can be removed.
|
||||||
pub async fn task_main(listener: TcpListener) -> anyhow::Result<()> {
|
pub async fn task_main(listener: TcpListener) -> anyhow::Result<()> {
|
||||||
scopeguard::defer! {
|
scopeguard::defer! {
|
||||||
info!("mgmt has shut down");
|
info!("mgmt has shut down");
|
||||||
@@ -44,11 +51,12 @@ pub async fn task_main(listener: TcpListener) -> anyhow::Result<()> {
|
|||||||
.set_nodelay(true)
|
.set_nodelay(true)
|
||||||
.context("failed to set client socket option")?;
|
.context("failed to set client socket option")?;
|
||||||
|
|
||||||
|
// spawn a task to handle this connection
|
||||||
tokio::task::spawn(async move {
|
tokio::task::spawn(async move {
|
||||||
let span = info_span!("mgmt", peer = %peer_addr);
|
let span = info_span!("mgmt", peer = %peer_addr);
|
||||||
let _enter = span.enter();
|
let _enter = span.enter();
|
||||||
|
|
||||||
info!("started a new console management API thread");
|
info!("started a new console management API task");
|
||||||
scopeguard::defer! {
|
scopeguard::defer! {
|
||||||
info!("console management API thread is about to finish");
|
info!("console management API thread is about to finish");
|
||||||
}
|
}
|
||||||
@@ -68,7 +76,6 @@ async fn handle_connection(socket: TcpStream) -> Result<(), QueryError> {
|
|||||||
/// A message received by `mgmt` when a compute node is ready.
|
/// A message received by `mgmt` when a compute node is ready.
|
||||||
pub type ComputeReady = Result<DatabaseInfo, String>;
|
pub type ComputeReady = Result<DatabaseInfo, String>;
|
||||||
|
|
||||||
// TODO: replace with an http-based protocol.
|
|
||||||
struct MgmtHandler;
|
struct MgmtHandler;
|
||||||
#[async_trait::async_trait]
|
#[async_trait::async_trait]
|
||||||
impl postgres_backend::Handler<tokio::net::TcpStream> for MgmtHandler {
|
impl postgres_backend::Handler<tokio::net::TcpStream> for MgmtHandler {
|
||||||
|
|||||||
@@ -1,5 +1,8 @@
|
|||||||
use anyhow::anyhow;
|
use crate::console;
|
||||||
use hyper::{Body, Request, Response, StatusCode};
|
use crate::console::messages::KickSession;
|
||||||
|
use crate::waiters::NotifyError;
|
||||||
|
use anyhow::{anyhow, Context};
|
||||||
|
use hyper::{body, Body, Request, Response, StatusCode};
|
||||||
use std::net::TcpListener;
|
use std::net::TcpListener;
|
||||||
use tracing::info;
|
use tracing::info;
|
||||||
use utils::http::{endpoint, error::ApiError, json::json_response, RouterBuilder, RouterService};
|
use utils::http::{endpoint, error::ApiError, json::json_response, RouterBuilder, RouterService};
|
||||||
@@ -8,8 +11,30 @@ async fn status_handler(_: Request<Body>) -> Result<Response<Body>, ApiError> {
|
|||||||
json_response(StatusCode::OK, "")
|
json_response(StatusCode::OK, "")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Process a session kick callback from the control plane. The body is a
|
||||||
|
/// KickSession as a JSON document.
|
||||||
|
///
|
||||||
|
/// TODO: authentication
|
||||||
|
async fn kick_session_handler(req: Request<Body>) -> Result<Response<Body>, ApiError> {
|
||||||
|
let body = &body::to_bytes(req.into_body())
|
||||||
|
.await
|
||||||
|
.context("Failed to get request body")
|
||||||
|
.map_err(ApiError::BadRequest)?;
|
||||||
|
let kick_session_json: KickSession = serde_json::from_slice(body)
|
||||||
|
.context("Failed to parse query as json")
|
||||||
|
.map_err(ApiError::BadRequest)?;
|
||||||
|
|
||||||
|
match console::mgmt::notify(kick_session_json.session_id, Ok(kick_session_json.result)) {
|
||||||
|
Ok(()) => json_response(StatusCode::OK, ""),
|
||||||
|
Err(NotifyError::NotFound(s)) => Err(ApiError::NotFound(anyhow::anyhow!(s))),
|
||||||
|
Err(e @ NotifyError::Hangup) => Err(ApiError::NotFound(anyhow::anyhow!(e))),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn make_router() -> RouterBuilder<hyper::Body, ApiError> {
|
fn make_router() -> RouterBuilder<hyper::Body, ApiError> {
|
||||||
endpoint::make_router().get("/v1/status", status_handler)
|
endpoint::make_router()
|
||||||
|
.get("/v1/status", status_handler)
|
||||||
|
.post("/v1/kick_session", kick_session_handler)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn task_main(http_listener: TcpListener) -> anyhow::Result<()> {
|
pub async fn task_main(http_listener: TcpListener) -> anyhow::Result<()> {
|
||||||
|
|||||||
@@ -25,6 +25,7 @@ from types import TracebackType
|
|||||||
from typing import Any, Dict, Iterator, List, Optional, Tuple, Type, Union, cast
|
from typing import Any, Dict, Iterator, List, Optional, Tuple, Type, Union, cast
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
import asyncpg
|
import asyncpg
|
||||||
import backoff # type: ignore
|
import backoff # type: ignore
|
||||||
import boto3
|
import boto3
|
||||||
@@ -2561,7 +2562,11 @@ class NeonProxy(PgProtocol):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def activate_link_auth(
|
async def activate_link_auth(
|
||||||
local_vanilla_pg, proxy_with_metric_collector, psql_session_id, create_user=True
|
local_vanilla_pg,
|
||||||
|
proxy_with_metric_collector,
|
||||||
|
psql_session_id,
|
||||||
|
create_user=True,
|
||||||
|
use_legacy_mgmt_api=False,
|
||||||
):
|
):
|
||||||
pg_user = "proxy"
|
pg_user = "proxy"
|
||||||
|
|
||||||
@@ -2570,33 +2575,40 @@ class NeonProxy(PgProtocol):
|
|||||||
local_vanilla_pg.start()
|
local_vanilla_pg.start()
|
||||||
local_vanilla_pg.safe_psql(f"create user {pg_user} with login superuser")
|
local_vanilla_pg.safe_psql(f"create user {pg_user} with login superuser")
|
||||||
|
|
||||||
db_info = json.dumps(
|
db_info = {
|
||||||
{
|
"session_id": psql_session_id,
|
||||||
"session_id": psql_session_id,
|
"result": {
|
||||||
"result": {
|
"Success": {
|
||||||
"Success": {
|
"host": local_vanilla_pg.default_options["host"],
|
||||||
"host": local_vanilla_pg.default_options["host"],
|
"port": local_vanilla_pg.default_options["port"],
|
||||||
"port": local_vanilla_pg.default_options["port"],
|
"dbname": local_vanilla_pg.default_options["dbname"],
|
||||||
"dbname": local_vanilla_pg.default_options["dbname"],
|
"user": pg_user,
|
||||||
"user": pg_user,
|
"aux": {
|
||||||
"aux": {
|
"project_id": "test_project_id",
|
||||||
"project_id": "test_project_id",
|
"endpoint_id": "test_endpoint_id",
|
||||||
"endpoint_id": "test_endpoint_id",
|
"branch_id": "test_branch_id",
|
||||||
"branch_id": "test_branch_id",
|
},
|
||||||
},
|
}
|
||||||
}
|
},
|
||||||
},
|
}
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
log.info("sending session activation message")
|
if use_legacy_mgmt_api:
|
||||||
psql = await PSQL(
|
log.info("sending session activation message using legacy libpq mgmt interface")
|
||||||
host=proxy_with_metric_collector.host,
|
psql = await PSQL(
|
||||||
port=proxy_with_metric_collector.mgmt_port,
|
host=proxy_with_metric_collector.host,
|
||||||
).run(db_info)
|
port=proxy_with_metric_collector.mgmt_port,
|
||||||
assert psql.stdout is not None
|
).run(json.dumps(db_info))
|
||||||
out = (await psql.stdout.read()).decode("utf-8").strip()
|
assert psql.stdout is not None
|
||||||
assert out == "ok"
|
out = (await psql.stdout.read()).decode("utf-8").strip()
|
||||||
|
assert out == "ok"
|
||||||
|
else:
|
||||||
|
log.info("sending session activation message using HTTP mgmt interface")
|
||||||
|
async with aiohttp.request(
|
||||||
|
"POST",
|
||||||
|
f"http://{proxy_with_metric_collector.host}:{proxy_with_metric_collector.http_port}/v1/kick_session",
|
||||||
|
json=db_info,
|
||||||
|
) as resp:
|
||||||
|
assert resp.ok
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="function")
|
@pytest.fixture(scope="function")
|
||||||
|
|||||||
@@ -35,7 +35,10 @@ def test_password_hack(static_proxy: NeonProxy):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_link_auth(vanilla_pg: VanillaPostgres, link_proxy: NeonProxy):
|
@pytest.mark.parametrize("use_legacy_mgmt_api", [True, False])
|
||||||
|
async def test_link_auth(
|
||||||
|
vanilla_pg: VanillaPostgres, link_proxy: NeonProxy, use_legacy_mgmt_api: bool
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Check the Link auth flow: a lightweight auth method which delegates
|
Check the Link auth flow: a lightweight auth method which delegates
|
||||||
all necessary checks to the console by sending client an auth URL.
|
all necessary checks to the console by sending client an auth URL.
|
||||||
@@ -47,7 +50,12 @@ async def test_link_auth(vanilla_pg: VanillaPostgres, link_proxy: NeonProxy):
|
|||||||
link = await NeonProxy.find_auth_link(base_uri, psql)
|
link = await NeonProxy.find_auth_link(base_uri, psql)
|
||||||
|
|
||||||
psql_session_id = NeonProxy.get_session_id(base_uri, link)
|
psql_session_id = NeonProxy.get_session_id(base_uri, link)
|
||||||
await NeonProxy.activate_link_auth(vanilla_pg, link_proxy, psql_session_id)
|
await NeonProxy.activate_link_auth(
|
||||||
|
vanilla_pg,
|
||||||
|
link_proxy,
|
||||||
|
psql_session_id,
|
||||||
|
use_legacy_mgmt_api=use_legacy_mgmt_api,
|
||||||
|
)
|
||||||
|
|
||||||
assert psql.stdout is not None
|
assert psql.stdout is not None
|
||||||
out = (await psql.stdout.read()).decode("utf-8").strip()
|
out = (await psql.stdout.read()).decode("utf-8").strip()
|
||||||
|
|||||||
Reference in New Issue
Block a user