mirror of
https://github.com/neondatabase/neon.git
synced 2026-01-07 13:32:57 +00:00
[proxy] Pass extra parameters to the console (#2467)
With this change we now pass additional params to the console's auth methods.
This commit is contained in:
6
Cargo.lock
generated
6
Cargo.lock
generated
@@ -2283,6 +2283,7 @@ dependencies = [
|
||||
"tokio-rustls",
|
||||
"url",
|
||||
"utils",
|
||||
"uuid",
|
||||
"workspace_hack",
|
||||
"x509-parser",
|
||||
]
|
||||
@@ -3663,6 +3664,10 @@ name = "uuid"
|
||||
version = "0.8.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "bc5cf98d8186244414c848017f0e2676b3fcb46807f6668a97dfe67359a3c4b7"
|
||||
dependencies = [
|
||||
"getrandom",
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "valuable"
|
||||
@@ -3953,6 +3958,7 @@ dependencies = [
|
||||
"tokio-util",
|
||||
"tracing",
|
||||
"tracing-core",
|
||||
"uuid",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
||||
@@ -11,13 +11,14 @@ bstr = "0.2.17"
|
||||
bytes = { version = "1.0.1", features = ['serde'] }
|
||||
clap = "3.0"
|
||||
futures = "0.3.13"
|
||||
git-version = "0.3.5"
|
||||
hashbrown = "0.12"
|
||||
hex = "0.4.3"
|
||||
hmac = "0.12.1"
|
||||
hyper = "0.14"
|
||||
itertools = "0.10.3"
|
||||
once_cell = "1.13.0"
|
||||
md5 = "0.7.0"
|
||||
once_cell = "1.13.0"
|
||||
parking_lot = "0.12"
|
||||
pin-project-lite = "0.2.7"
|
||||
rand = "0.8.3"
|
||||
@@ -35,14 +36,13 @@ tokio = { version = "1.17", features = ["macros"] }
|
||||
tokio-postgres = { git = "https://github.com/neondatabase/rust-postgres.git", rev="d052ee8b86fff9897c77b0fe89ea9daba0e1fa38" }
|
||||
tokio-rustls = "0.23.0"
|
||||
url = "2.2.2"
|
||||
git-version = "0.3.5"
|
||||
uuid = { version = "0.8.2", features = ["v4", "serde"]}
|
||||
x509-parser = "0.13.2"
|
||||
|
||||
utils = { path = "../libs/utils" }
|
||||
metrics = { path = "../libs/metrics" }
|
||||
workspace_hack = { version = "0.1", path = "../workspace_hack" }
|
||||
|
||||
x509-parser = "0.13.2"
|
||||
|
||||
[dev-dependencies]
|
||||
rcgen = "0.8.14"
|
||||
rstest = "0.12"
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
//! Client authentication mechanisms.
|
||||
|
||||
pub mod backend;
|
||||
pub use backend::{BackendType, DatabaseInfo};
|
||||
pub use backend::{BackendType, ConsoleReqExtra, DatabaseInfo};
|
||||
|
||||
mod credentials;
|
||||
pub use credentials::ClientCredentials;
|
||||
|
||||
@@ -8,13 +8,12 @@ pub use console::{GetAuthInfoError, WakeComputeError};
|
||||
|
||||
use crate::{
|
||||
auth::{self, AuthFlow, ClientCredentials},
|
||||
compute, config, mgmt,
|
||||
stream::PqStream,
|
||||
compute, http, mgmt, stream, url,
|
||||
waiters::{self, Waiter, Waiters},
|
||||
};
|
||||
|
||||
use once_cell::sync::Lazy;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::borrow::Cow;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
|
||||
static CPLANE_WAITERS: Lazy<Waiters<mgmt::ComputeReady>> = Lazy::new(Default::default);
|
||||
@@ -75,6 +74,14 @@ impl From<DatabaseInfo> for tokio_postgres::Config {
|
||||
}
|
||||
}
|
||||
|
||||
/// Extra query params we'd like to pass to the console.
|
||||
pub struct ConsoleReqExtra<'a> {
|
||||
/// A unique identifier for a connection.
|
||||
pub session_id: uuid::Uuid,
|
||||
/// Name of client application, if set.
|
||||
pub application_name: Option<&'a str>,
|
||||
}
|
||||
|
||||
/// This type serves two purposes:
|
||||
///
|
||||
/// * When `T` is `()`, it's just a regular auth backend selector
|
||||
@@ -83,53 +90,83 @@ impl From<DatabaseInfo> for tokio_postgres::Config {
|
||||
/// * However, when we substitute `T` with [`ClientCredentials`],
|
||||
/// this helps us provide the credentials only to those auth
|
||||
/// backends which require them for the authentication process.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum BackendType<T> {
|
||||
#[derive(Debug)]
|
||||
pub enum BackendType<'a, T> {
|
||||
/// Current Cloud API (V2).
|
||||
Console(T),
|
||||
Console(Cow<'a, http::Endpoint>, T),
|
||||
/// Local mock of Cloud API (V2).
|
||||
Postgres(T),
|
||||
Postgres(Cow<'a, url::ApiUrl>, T),
|
||||
/// Authentication via a web browser.
|
||||
Link,
|
||||
Link(Cow<'a, url::ApiUrl>),
|
||||
}
|
||||
|
||||
impl<T> BackendType<T> {
|
||||
impl std::fmt::Display for BackendType<'_, ()> {
|
||||
fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
use BackendType::*;
|
||||
match self {
|
||||
Console(endpoint, _) => fmt
|
||||
.debug_tuple("Console")
|
||||
.field(&endpoint.url().as_str())
|
||||
.finish(),
|
||||
Postgres(endpoint, _) => fmt
|
||||
.debug_tuple("Postgres")
|
||||
.field(&endpoint.as_str())
|
||||
.finish(),
|
||||
Link(url) => fmt.debug_tuple("Link").field(&url.as_str()).finish(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> BackendType<'_, T> {
|
||||
/// Very similar to [`std::option::Option::as_ref`].
|
||||
/// This helps us pass structured config to async tasks.
|
||||
pub fn as_ref(&self) -> BackendType<'_, &T> {
|
||||
use BackendType::*;
|
||||
match self {
|
||||
Console(c, x) => Console(Cow::Borrowed(c), x),
|
||||
Postgres(c, x) => Postgres(Cow::Borrowed(c), x),
|
||||
Link(c) => Link(Cow::Borrowed(c)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T> BackendType<'a, T> {
|
||||
/// Very similar to [`std::option::Option::map`].
|
||||
/// Maps [`BackendType<T>`] to [`BackendType<R>`] by applying
|
||||
/// a function to a contained value.
|
||||
pub fn map<R>(self, f: impl FnOnce(T) -> R) -> BackendType<R> {
|
||||
pub fn map<R>(self, f: impl FnOnce(T) -> R) -> BackendType<'a, R> {
|
||||
use BackendType::*;
|
||||
match self {
|
||||
Console(x) => Console(f(x)),
|
||||
Postgres(x) => Postgres(f(x)),
|
||||
Link => Link,
|
||||
Console(c, x) => Console(c, f(x)),
|
||||
Postgres(c, x) => Postgres(c, f(x)),
|
||||
Link(c) => Link(c),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, E> BackendType<Result<T, E>> {
|
||||
impl<'a, T, E> BackendType<'a, Result<T, E>> {
|
||||
/// Very similar to [`std::option::Option::transpose`].
|
||||
/// This is most useful for error handling.
|
||||
pub fn transpose(self) -> Result<BackendType<T>, E> {
|
||||
pub fn transpose(self) -> Result<BackendType<'a, T>, E> {
|
||||
use BackendType::*;
|
||||
match self {
|
||||
Console(x) => x.map(Console),
|
||||
Postgres(x) => x.map(Postgres),
|
||||
Link => Ok(Link),
|
||||
Console(c, x) => x.map(|x| Console(c, x)),
|
||||
Postgres(c, x) => x.map(|x| Postgres(c, x)),
|
||||
Link(c) => Ok(Link(c)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl BackendType<ClientCredentials<'_>> {
|
||||
impl BackendType<'_, ClientCredentials<'_>> {
|
||||
/// Authenticate the client via the requested backend, possibly using credentials.
|
||||
pub async fn authenticate(
|
||||
mut self,
|
||||
urls: &config::AuthUrls,
|
||||
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin + Send>,
|
||||
extra: &ConsoleReqExtra<'_>,
|
||||
client: &mut stream::PqStream<impl AsyncRead + AsyncWrite + Unpin + Send>,
|
||||
) -> super::Result<compute::NodeInfo> {
|
||||
use BackendType::*;
|
||||
|
||||
if let Console(creds) | Postgres(creds) = &mut self {
|
||||
if let Console(_, creds) | Postgres(_, creds) = &mut self {
|
||||
// If there's no project so far, that entails that client doesn't
|
||||
// support SNI or other means of passing the project name.
|
||||
// We now expect to see a very specific payload in the place of password.
|
||||
@@ -145,15 +182,13 @@ impl BackendType<ClientCredentials<'_>> {
|
||||
creds.project = Some(payload.project.into());
|
||||
|
||||
let mut config = match &self {
|
||||
Console(creds) => {
|
||||
console::Api::new(&urls.auth_endpoint, creds)
|
||||
Console(endpoint, creds) => {
|
||||
console::Api::new(endpoint, extra, creds)
|
||||
.wake_compute()
|
||||
.await?
|
||||
}
|
||||
Postgres(creds) => {
|
||||
postgres::Api::new(&urls.auth_endpoint, creds)
|
||||
.wake_compute()
|
||||
.await?
|
||||
Postgres(endpoint, creds) => {
|
||||
postgres::Api::new(endpoint, creds).wake_compute().await?
|
||||
}
|
||||
_ => unreachable!("see the patterns above"),
|
||||
};
|
||||
@@ -169,49 +204,18 @@ impl BackendType<ClientCredentials<'_>> {
|
||||
}
|
||||
|
||||
match self {
|
||||
Console(creds) => {
|
||||
console::Api::new(&urls.auth_endpoint, &creds)
|
||||
Console(endpoint, creds) => {
|
||||
console::Api::new(&endpoint, extra, &creds)
|
||||
.handle_user(client)
|
||||
.await
|
||||
}
|
||||
Postgres(creds) => {
|
||||
postgres::Api::new(&urls.auth_endpoint, &creds)
|
||||
Postgres(endpoint, creds) => {
|
||||
postgres::Api::new(&endpoint, &creds)
|
||||
.handle_user(client)
|
||||
.await
|
||||
}
|
||||
// NOTE: this auth backend doesn't use client credentials.
|
||||
Link => link::handle_user(&urls.auth_link_uri, client).await,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_backend_type_map() {
|
||||
let values = [
|
||||
BackendType::Console(0),
|
||||
BackendType::Postgres(0),
|
||||
BackendType::Link,
|
||||
];
|
||||
|
||||
for value in values {
|
||||
assert_eq!(value.map(|x| x), value);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_backend_type_transpose() {
|
||||
let values = [
|
||||
BackendType::Console(Ok::<_, ()>(0)),
|
||||
BackendType::Postgres(Ok(0)),
|
||||
BackendType::Link,
|
||||
];
|
||||
|
||||
for value in values {
|
||||
assert_eq!(value.map(Result::unwrap), value.transpose().unwrap());
|
||||
Link(url) => link::handle_user(&url, client).await,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
//! Cloud API V2.
|
||||
|
||||
use super::ConsoleReqExtra;
|
||||
use crate::{
|
||||
auth::{self, AuthFlow, ClientCredentials},
|
||||
compute::{self, ComputeConnCfg},
|
||||
error::{io_error, UserFacingError},
|
||||
scram,
|
||||
http, scram,
|
||||
stream::PqStream,
|
||||
url::ApiUrl,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::future::Future;
|
||||
@@ -120,14 +120,23 @@ pub enum AuthInfo {
|
||||
|
||||
#[must_use]
|
||||
pub(super) struct Api<'a> {
|
||||
endpoint: &'a ApiUrl,
|
||||
endpoint: &'a http::Endpoint,
|
||||
extra: &'a ConsoleReqExtra<'a>,
|
||||
creds: &'a ClientCredentials<'a>,
|
||||
}
|
||||
|
||||
impl<'a> Api<'a> {
|
||||
/// Construct an API object containing the auth parameters.
|
||||
pub(super) fn new(endpoint: &'a ApiUrl, creds: &'a ClientCredentials) -> Self {
|
||||
Self { endpoint, creds }
|
||||
pub(super) fn new(
|
||||
endpoint: &'a http::Endpoint,
|
||||
extra: &'a ConsoleReqExtra<'a>,
|
||||
creds: &'a ClientCredentials,
|
||||
) -> Self {
|
||||
Self {
|
||||
endpoint,
|
||||
extra,
|
||||
creds,
|
||||
}
|
||||
}
|
||||
|
||||
/// Authenticate the existing user or throw an error.
|
||||
@@ -139,16 +148,22 @@ impl<'a> Api<'a> {
|
||||
}
|
||||
|
||||
async fn get_auth_info(&self) -> Result<AuthInfo, GetAuthInfoError> {
|
||||
let mut url = self.endpoint.clone();
|
||||
url.path_segments_mut().push("proxy_get_role_secret");
|
||||
url.query_pairs_mut()
|
||||
.append_pair("project", self.creds.project().expect("impossible"))
|
||||
.append_pair("role", self.creds.user);
|
||||
let req = self
|
||||
.endpoint
|
||||
.get("proxy_get_role_secret")
|
||||
.header("X-Request-ID", uuid::Uuid::new_v4().to_string())
|
||||
.query(&[("session_id", self.extra.session_id)])
|
||||
.query(&[
|
||||
("application_name", self.extra.application_name),
|
||||
("project", Some(self.creds.project().expect("impossible"))),
|
||||
("role", Some(self.creds.user)),
|
||||
])
|
||||
.build()?;
|
||||
|
||||
// TODO: use a proper logger
|
||||
println!("cplane request: {url}");
|
||||
println!("cplane request: {}", req.url());
|
||||
|
||||
let resp = reqwest::get(url.into_inner()).await?;
|
||||
let resp = self.endpoint.execute(req).await?;
|
||||
if !resp.status().is_success() {
|
||||
return Err(TransportError::HttpStatus(resp.status()).into());
|
||||
}
|
||||
@@ -162,15 +177,21 @@ impl<'a> Api<'a> {
|
||||
|
||||
/// Wake up the compute node and return the corresponding connection info.
|
||||
pub(super) async fn wake_compute(&self) -> Result<ComputeConnCfg, WakeComputeError> {
|
||||
let mut url = self.endpoint.clone();
|
||||
url.path_segments_mut().push("proxy_wake_compute");
|
||||
url.query_pairs_mut()
|
||||
.append_pair("project", self.creds.project().expect("impossible"));
|
||||
let req = self
|
||||
.endpoint
|
||||
.get("proxy_wake_compute")
|
||||
.header("X-Request-ID", uuid::Uuid::new_v4().to_string())
|
||||
.query(&[("session_id", self.extra.session_id)])
|
||||
.query(&[
|
||||
("application_name", self.extra.application_name),
|
||||
("project", Some(self.creds.project().expect("impossible"))),
|
||||
])
|
||||
.build()?;
|
||||
|
||||
// TODO: use a proper logger
|
||||
println!("cplane request: {url}");
|
||||
println!("cplane request: {}", req.url());
|
||||
|
||||
let resp = reqwest::get(url.into_inner()).await?;
|
||||
let resp = self.endpoint.execute(req).await?;
|
||||
if !resp.status().is_success() {
|
||||
return Err(TransportError::HttpStatus(resp.status()).into());
|
||||
}
|
||||
|
||||
@@ -29,7 +29,7 @@ impl UserFacingError for LinkAuthError {
|
||||
}
|
||||
}
|
||||
|
||||
fn hello_message(redirect_uri: &str, session_id: &str) -> String {
|
||||
fn hello_message(redirect_uri: &reqwest::Url, session_id: &str) -> String {
|
||||
format!(
|
||||
concat![
|
||||
"Welcome to Neon!\n",
|
||||
@@ -46,11 +46,11 @@ pub fn new_psql_session_id() -> String {
|
||||
}
|
||||
|
||||
pub async fn handle_user(
|
||||
redirect_uri: &reqwest::Url,
|
||||
link_uri: &reqwest::Url,
|
||||
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
|
||||
) -> auth::Result<compute::NodeInfo> {
|
||||
let psql_session_id = new_psql_session_id();
|
||||
let greeting = hello_message(redirect_uri.as_str(), &psql_session_id);
|
||||
let greeting = hello_message(link_uri, &psql_session_id);
|
||||
|
||||
let db_info = super::with_waiter(psql_session_id, |waiter| async {
|
||||
// Give user a URL to spawn a new database
|
||||
|
||||
@@ -1,16 +1,10 @@
|
||||
use crate::{auth, url::ApiUrl};
|
||||
use crate::auth;
|
||||
use anyhow::{ensure, Context};
|
||||
use std::sync::Arc;
|
||||
|
||||
pub struct ProxyConfig {
|
||||
pub tls_config: Option<TlsConfig>,
|
||||
pub auth_backend: auth::BackendType<()>,
|
||||
pub auth_urls: AuthUrls,
|
||||
}
|
||||
|
||||
pub struct AuthUrls {
|
||||
pub auth_endpoint: ApiUrl,
|
||||
pub auth_link_uri: ApiUrl,
|
||||
pub auth_backend: auth::BackendType<'static, ()>,
|
||||
}
|
||||
|
||||
pub struct TlsConfig {
|
||||
|
||||
@@ -1,27 +1,81 @@
|
||||
use anyhow::anyhow;
|
||||
use hyper::{Body, Request, Response, StatusCode};
|
||||
use std::net::TcpListener;
|
||||
use utils::http::{endpoint, error::ApiError, json::json_response, RouterBuilder, RouterService};
|
||||
pub mod server;
|
||||
|
||||
async fn status_handler(_: Request<Body>) -> Result<Response<Body>, ApiError> {
|
||||
json_response(StatusCode::OK, "")
|
||||
use crate::url::ApiUrl;
|
||||
|
||||
/// Thin convenience wrapper for an API provided by an http endpoint.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Endpoint {
|
||||
/// API's base URL.
|
||||
endpoint: ApiUrl,
|
||||
/// Connection manager with built-in pooling.
|
||||
client: reqwest::Client,
|
||||
}
|
||||
|
||||
fn make_router() -> RouterBuilder<hyper::Body, ApiError> {
|
||||
let router = endpoint::make_router();
|
||||
router.get("/v1/status", status_handler)
|
||||
}
|
||||
|
||||
pub async fn thread_main(http_listener: TcpListener) -> anyhow::Result<()> {
|
||||
scopeguard::defer! {
|
||||
println!("http has shut down");
|
||||
impl Endpoint {
|
||||
/// Construct a new HTTP endpoint wrapper.
|
||||
pub fn new(endpoint: ApiUrl, client: reqwest::Client) -> Self {
|
||||
Self { endpoint, client }
|
||||
}
|
||||
|
||||
let service = || RouterService::new(make_router().build()?);
|
||||
pub fn url(&self) -> &ApiUrl {
|
||||
&self.endpoint
|
||||
}
|
||||
|
||||
hyper::Server::from_tcp(http_listener)?
|
||||
.serve(service().map_err(|e| anyhow!(e))?)
|
||||
.await?;
|
||||
/// Return a [builder](reqwest::RequestBuilder) for a `GET` request,
|
||||
/// appending a single `path` segment to the base endpoint URL.
|
||||
pub fn get(&self, path: &str) -> reqwest::RequestBuilder {
|
||||
let mut url = self.endpoint.clone();
|
||||
url.path_segments_mut().push(path);
|
||||
self.client.get(url.into_inner())
|
||||
}
|
||||
|
||||
Ok(())
|
||||
/// Execute a [request](reqwest::Request).
|
||||
pub async fn execute(
|
||||
&self,
|
||||
request: reqwest::Request,
|
||||
) -> Result<reqwest::Response, reqwest::Error> {
|
||||
self.client.execute(request).await
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn optional_query_params() -> anyhow::Result<()> {
|
||||
let url = "http://example.com".parse()?;
|
||||
let endpoint = Endpoint::new(url, reqwest::Client::new());
|
||||
|
||||
// Validate that this pattern makes sense.
|
||||
let req = endpoint
|
||||
.get("frobnicate")
|
||||
.query(&[
|
||||
("foo", Some("10")), // should be just `foo=10`
|
||||
("bar", None), // shouldn't be passed at all
|
||||
])
|
||||
.build()?;
|
||||
|
||||
assert_eq!(req.url().as_str(), "http://example.com/frobnicate?foo=10");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn uuid_params() -> anyhow::Result<()> {
|
||||
let url = "http://example.com".parse()?;
|
||||
let endpoint = Endpoint::new(url, reqwest::Client::new());
|
||||
|
||||
let req = endpoint
|
||||
.get("frobnicate")
|
||||
.query(&[("session_id", uuid::Uuid::nil())])
|
||||
.build()?;
|
||||
|
||||
assert_eq!(
|
||||
req.url().as_str(),
|
||||
"http://example.com/frobnicate?session_id=00000000-0000-0000-0000-000000000000"
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
27
proxy/src/http/server.rs
Normal file
27
proxy/src/http/server.rs
Normal file
@@ -0,0 +1,27 @@
|
||||
use anyhow::anyhow;
|
||||
use hyper::{Body, Request, Response, StatusCode};
|
||||
use std::net::TcpListener;
|
||||
use utils::http::{endpoint, error::ApiError, json::json_response, RouterBuilder, RouterService};
|
||||
|
||||
async fn status_handler(_: Request<Body>) -> Result<Response<Body>, ApiError> {
|
||||
json_response(StatusCode::OK, "")
|
||||
}
|
||||
|
||||
fn make_router() -> RouterBuilder<hyper::Body, ApiError> {
|
||||
let router = endpoint::make_router();
|
||||
router.get("/v1/status", status_handler)
|
||||
}
|
||||
|
||||
pub async fn thread_main(http_listener: TcpListener) -> anyhow::Result<()> {
|
||||
scopeguard::defer! {
|
||||
println!("http has shut down");
|
||||
}
|
||||
|
||||
let service = || RouterService::new(make_router().build()?);
|
||||
|
||||
hyper::Server::from_tcp(http_listener)?
|
||||
.serve(service().map_err(|e| anyhow!(e))?)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -23,7 +23,7 @@ use anyhow::{bail, Context};
|
||||
use clap::{self, Arg};
|
||||
use config::ProxyConfig;
|
||||
use futures::FutureExt;
|
||||
use std::{future::Future, net::SocketAddr};
|
||||
use std::{borrow::Cow, future::Future, net::SocketAddr};
|
||||
use tokio::{net::TcpListener, task::JoinError};
|
||||
use utils::project_git_version;
|
||||
|
||||
@@ -36,23 +36,6 @@ async fn flatten_err(
|
||||
f.map(|r| r.context("join error").and_then(|x| x)).await
|
||||
}
|
||||
|
||||
/// A proper parser for auth backend parameter.
|
||||
impl clap::ValueEnum for auth::BackendType<()> {
|
||||
fn value_variants<'a>() -> &'a [Self] {
|
||||
use auth::BackendType::*;
|
||||
&[Console(()), Postgres(()), Link]
|
||||
}
|
||||
|
||||
fn to_possible_value<'a>(&self) -> Option<clap::PossibleValue<'a>> {
|
||||
use auth::BackendType::*;
|
||||
Some(clap::PossibleValue::new(match self {
|
||||
Console(_) => "console",
|
||||
Postgres(_) => "postgres",
|
||||
Link => "link",
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
let arg_matches = clap::App::new("Neon proxy/router")
|
||||
@@ -69,7 +52,7 @@ async fn main() -> anyhow::Result<()> {
|
||||
Arg::new("auth-backend")
|
||||
.long("auth-backend")
|
||||
.takes_value(true)
|
||||
.value_parser(clap::builder::EnumValueParser::<auth::BackendType<()>>::new())
|
||||
.possible_values(["console", "postgres", "link"])
|
||||
.default_value("link"),
|
||||
)
|
||||
.arg(
|
||||
@@ -135,23 +118,30 @@ async fn main() -> anyhow::Result<()> {
|
||||
let mgmt_address: SocketAddr = arg_matches.value_of("mgmt").unwrap().parse()?;
|
||||
let http_address: SocketAddr = arg_matches.value_of("http").unwrap().parse()?;
|
||||
|
||||
let auth_backend = *arg_matches
|
||||
.try_get_one::<auth::BackendType<()>>("auth-backend")?
|
||||
.unwrap();
|
||||
|
||||
let auth_urls = config::AuthUrls {
|
||||
auth_endpoint: arg_matches.value_of("auth-endpoint").unwrap().parse()?,
|
||||
auth_link_uri: arg_matches.value_of("uri").unwrap().parse()?,
|
||||
let auth_backend = match arg_matches.value_of("auth-backend").unwrap() {
|
||||
"console" => {
|
||||
let url = arg_matches.value_of("auth-endpoint").unwrap().parse()?;
|
||||
let endpoint = http::Endpoint::new(url, reqwest::Client::new());
|
||||
auth::BackendType::Console(Cow::Owned(endpoint), ())
|
||||
}
|
||||
"postgres" => {
|
||||
let url = arg_matches.value_of("auth-endpoint").unwrap().parse()?;
|
||||
auth::BackendType::Postgres(Cow::Owned(url), ())
|
||||
}
|
||||
"link" => {
|
||||
let url = arg_matches.value_of("uri").unwrap().parse()?;
|
||||
auth::BackendType::Link(Cow::Owned(url))
|
||||
}
|
||||
other => bail!("unsupported auth backend: {other}"),
|
||||
};
|
||||
|
||||
let config: &ProxyConfig = Box::leak(Box::new(ProxyConfig {
|
||||
tls_config,
|
||||
auth_backend,
|
||||
auth_urls,
|
||||
}));
|
||||
|
||||
println!("Version: {GIT_VERSION}");
|
||||
println!("Authentication backend: {:?}", config.auth_backend);
|
||||
println!("Authentication backend: {}", config.auth_backend);
|
||||
|
||||
// Check that we can bind to address before further initialization
|
||||
println!("Starting http on {}", http_address);
|
||||
@@ -164,7 +154,7 @@ async fn main() -> anyhow::Result<()> {
|
||||
let proxy_listener = TcpListener::bind(proxy_address).await?;
|
||||
|
||||
let tasks = [
|
||||
tokio::spawn(http::thread_main(http_listener)),
|
||||
tokio::spawn(http::server::thread_main(http_listener)),
|
||||
tokio::spawn(proxy::thread_main(config, proxy_listener)),
|
||||
tokio::task::spawn_blocking(move || mgmt::thread_main(mgmt_listener)),
|
||||
]
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use crate::auth;
|
||||
use crate::cancellation::{self, CancelMap};
|
||||
use crate::config::{AuthUrls, ProxyConfig, TlsConfig};
|
||||
use crate::config::{ProxyConfig, TlsConfig};
|
||||
use crate::stream::{MetricsStream, PqStream, Stream};
|
||||
use anyhow::{bail, Context};
|
||||
use futures::TryFutureExt;
|
||||
@@ -99,6 +99,7 @@ async fn handle_client(
|
||||
let common_name = tls.and_then(|tls| tls.common_name.as_deref());
|
||||
let result = config
|
||||
.auth_backend
|
||||
.as_ref()
|
||||
.map(|_| auth::ClientCredentials::parse(¶ms, sni, common_name))
|
||||
.transpose();
|
||||
|
||||
@@ -107,7 +108,7 @@ async fn handle_client(
|
||||
|
||||
let client = Client::new(stream, creds, ¶ms);
|
||||
cancel_map
|
||||
.with_session(|session| client.connect_to_db(&config.auth_urls, session))
|
||||
.with_session(|session| client.connect_to_db(session))
|
||||
.await
|
||||
}
|
||||
|
||||
@@ -179,7 +180,7 @@ struct Client<'a, S> {
|
||||
/// The underlying libpq protocol stream.
|
||||
stream: PqStream<S>,
|
||||
/// Client credentials that we care about.
|
||||
creds: auth::BackendType<auth::ClientCredentials<'a>>,
|
||||
creds: auth::BackendType<'a, auth::ClientCredentials<'a>>,
|
||||
/// KV-dictionary with PostgreSQL connection params.
|
||||
params: &'a StartupMessageParams,
|
||||
}
|
||||
@@ -188,7 +189,7 @@ impl<'a, S> Client<'a, S> {
|
||||
/// Construct a new connection context.
|
||||
fn new(
|
||||
stream: PqStream<S>,
|
||||
creds: auth::BackendType<auth::ClientCredentials<'a>>,
|
||||
creds: auth::BackendType<'a, auth::ClientCredentials<'a>>,
|
||||
params: &'a StartupMessageParams,
|
||||
) -> Self {
|
||||
Self {
|
||||
@@ -201,19 +202,22 @@ impl<'a, S> Client<'a, S> {
|
||||
|
||||
impl<S: AsyncRead + AsyncWrite + Unpin + Send> Client<'_, S> {
|
||||
/// Let the client authenticate and connect to the designated compute node.
|
||||
async fn connect_to_db(
|
||||
self,
|
||||
urls: &AuthUrls,
|
||||
session: cancellation::Session<'_>,
|
||||
) -> anyhow::Result<()> {
|
||||
async fn connect_to_db(self, session: cancellation::Session<'_>) -> anyhow::Result<()> {
|
||||
let Self {
|
||||
mut stream,
|
||||
creds,
|
||||
params,
|
||||
} = self;
|
||||
|
||||
let extra = auth::ConsoleReqExtra {
|
||||
// Currently it's OK to generate a new UUID **here**, but
|
||||
// it might be better to move this to `cancellation::Session`.
|
||||
session_id: uuid::Uuid::new_v4(),
|
||||
application_name: params.get("application_name"),
|
||||
};
|
||||
|
||||
// Authenticate and connect to a compute node.
|
||||
let auth = creds.authenticate(urls, &mut stream).await;
|
||||
let auth = creds.authenticate(&extra, &mut stream).await;
|
||||
let node = async { auth }.or_else(|e| stream.throw_error(e)).await?;
|
||||
let reported_auth_ok = node.reported_auth_ok;
|
||||
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
use anyhow::bail;
|
||||
use url::form_urlencoded::Serializer;
|
||||
|
||||
/// A [url](url::Url) type with additional guarantees.
|
||||
#[derive(Debug, Clone)]
|
||||
#[repr(transparent)]
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct ApiUrl(url::Url);
|
||||
|
||||
impl ApiUrl {
|
||||
@@ -11,11 +11,6 @@ impl ApiUrl {
|
||||
self.0
|
||||
}
|
||||
|
||||
/// See [`url::Url::query_pairs_mut`].
|
||||
pub fn query_pairs_mut(&mut self) -> Serializer<'_, url::UrlQuery<'_>> {
|
||||
self.0.query_pairs_mut()
|
||||
}
|
||||
|
||||
/// See [`url::Url::path_segments_mut`].
|
||||
pub fn path_segments_mut(&mut self) -> url::PathSegmentsMut {
|
||||
// We've already verified that it works during construction.
|
||||
@@ -72,10 +67,7 @@ mod tests {
|
||||
let mut b = url.parse::<ApiUrl>().expect("unexpected parsing failure");
|
||||
|
||||
a.path_segments_mut().unwrap().push("method");
|
||||
a.query_pairs_mut().append_pair("key", "value");
|
||||
|
||||
b.path_segments_mut().push("method");
|
||||
b.query_pairs_mut().append_pair("key", "value");
|
||||
|
||||
assert_eq!(a, b.into_inner());
|
||||
}
|
||||
|
||||
@@ -43,6 +43,7 @@ tokio = { version = "1", features = ["bytes", "fs", "io-std", "io-util", "libc",
|
||||
tokio-util = { version = "0.7", features = ["codec", "io", "io-util", "tracing"] }
|
||||
tracing = { version = "0.1", features = ["attributes", "log", "std", "tracing-attributes"] }
|
||||
tracing-core = { version = "0.1", features = ["once_cell", "std"] }
|
||||
uuid = { version = "0.8", features = ["getrandom", "serde", "std", "v4"] }
|
||||
|
||||
[build-dependencies]
|
||||
ahash = { version = "0.7", features = ["std"] }
|
||||
|
||||
Reference in New Issue
Block a user