mirror of
https://github.com/neondatabase/neon.git
synced 2026-01-15 01:12:56 +00:00
[local_proxy]: install pg_session_jwt extension on demand
This commit is contained in:
@@ -1,23 +1,32 @@
|
||||
use std::net::SocketAddr;
|
||||
|
||||
use arc_swap::ArcSwapOption;
|
||||
use tokio::sync::Semaphore;
|
||||
|
||||
use super::jwt::{AuthRule, FetchAuthRules};
|
||||
use crate::auth::backend::jwt::FetchAuthRulesError;
|
||||
use crate::compute::ConnCfg;
|
||||
use crate::compute_ctl::ComputeCtlApi;
|
||||
use crate::context::RequestMonitoring;
|
||||
use crate::control_plane::messages::{ColdStartInfo, EndpointJwksResponse, MetricsAuxInfo};
|
||||
use crate::control_plane::NodeInfo;
|
||||
use crate::intern::{BranchIdTag, EndpointIdTag, InternId, ProjectIdTag};
|
||||
use crate::EndpointId;
|
||||
use crate::url::ApiUrl;
|
||||
use crate::{http, EndpointId};
|
||||
|
||||
pub struct LocalBackend {
|
||||
pub(crate) initialize: Semaphore,
|
||||
pub(crate) compute_ctl: ComputeCtlApi,
|
||||
pub(crate) node_info: NodeInfo,
|
||||
}
|
||||
|
||||
impl LocalBackend {
|
||||
pub fn new(postgres_addr: SocketAddr) -> Self {
|
||||
pub fn new(postgres_addr: SocketAddr, compute_ctl: ApiUrl) -> Self {
|
||||
LocalBackend {
|
||||
initialize: Semaphore::new(1),
|
||||
compute_ctl: ComputeCtlApi {
|
||||
api: http::Endpoint::new(compute_ctl, http::new_client()),
|
||||
},
|
||||
node_info: NodeInfo {
|
||||
config: {
|
||||
let mut cfg = ConnCfg::new();
|
||||
|
||||
@@ -25,6 +25,7 @@ use proxy::rate_limiter::{
|
||||
use proxy::scram::threadpool::ThreadPool;
|
||||
use proxy::serverless::cancel_set::CancelSet;
|
||||
use proxy::serverless::{self, GlobalConnPoolOptions};
|
||||
use proxy::url::ApiUrl;
|
||||
use proxy::RoleName;
|
||||
|
||||
project_git_version!(GIT_VERSION);
|
||||
@@ -80,7 +81,10 @@ struct LocalProxyCliArgs {
|
||||
connect_to_compute_retry: String,
|
||||
/// Address of the postgres server
|
||||
#[clap(long, default_value = "127.0.0.1:5432")]
|
||||
compute: SocketAddr,
|
||||
postgres: SocketAddr,
|
||||
/// Address of the compute-ctl api service
|
||||
#[clap(long, default_value = "http://127.0.0.1:3080/")]
|
||||
compute_ctl: ApiUrl,
|
||||
/// Path of the local proxy config file
|
||||
#[clap(long, default_value = "./local_proxy.json")]
|
||||
config_path: Utf8PathBuf,
|
||||
@@ -295,7 +299,7 @@ fn build_auth_backend(
|
||||
args: &LocalProxyCliArgs,
|
||||
) -> anyhow::Result<&'static auth::Backend<'static, ()>> {
|
||||
let auth_backend = proxy::auth::Backend::Local(proxy::auth::backend::MaybeOwned::Owned(
|
||||
LocalBackend::new(args.compute),
|
||||
LocalBackend::new(args.postgres, args.compute_ctl.clone()),
|
||||
));
|
||||
|
||||
Ok(Box::leak(Box::new(auth_backend)))
|
||||
|
||||
51
proxy/src/compute_ctl/mod.rs
Normal file
51
proxy/src/compute_ctl/mod.rs
Normal file
@@ -0,0 +1,51 @@
|
||||
use anyhow::Context;
|
||||
use hyper::Method;
|
||||
use typed_json::json;
|
||||
|
||||
use crate::http;
|
||||
|
||||
pub struct ComputeCtlApi {
|
||||
pub(crate) api: http::Endpoint,
|
||||
}
|
||||
|
||||
// The following article is a stub.
|
||||
// You can help Wikipedia by filling it out
|
||||
|
||||
impl ComputeCtlApi {
|
||||
pub async fn install_extension(&self, db: &str, ext: &str) -> anyhow::Result<()> {
|
||||
self.api
|
||||
.request_with_url(Method::POST, |url| {
|
||||
url.path_segments_mut().push("extension");
|
||||
})
|
||||
.json(&json! {{
|
||||
"extension": ext,
|
||||
"database": db,
|
||||
}})
|
||||
.send()
|
||||
.await
|
||||
.context("connection error")?
|
||||
.error_for_status()
|
||||
.context("api error")?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn grant_role(&self, db: &str, role: &str, schema: &str) -> anyhow::Result<()> {
|
||||
self.api
|
||||
.request_with_url(Method::POST, |url| {
|
||||
url.path_segments_mut().push("grant");
|
||||
})
|
||||
.json(&json! {{
|
||||
"schema": schema,
|
||||
"role": role,
|
||||
"database": db,
|
||||
}})
|
||||
.send()
|
||||
.await
|
||||
.context("connection error")?
|
||||
.error_for_status()
|
||||
.context("api error")?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -8,6 +8,7 @@ use std::time::Duration;
|
||||
|
||||
use anyhow::bail;
|
||||
use bytes::Bytes;
|
||||
use http::Method;
|
||||
use http_body_util::BodyExt;
|
||||
use hyper::body::Body;
|
||||
pub(crate) use reqwest::{Request, Response};
|
||||
@@ -93,9 +94,19 @@ impl Endpoint {
|
||||
/// Return a [builder](RequestBuilder) for a `GET` request,
|
||||
/// accepting a closure to modify the url path segments for more complex paths queries.
|
||||
pub(crate) fn get_with_url(&self, f: impl for<'a> FnOnce(&'a mut ApiUrl)) -> RequestBuilder {
|
||||
self.request_with_url(Method::GET, f)
|
||||
}
|
||||
|
||||
/// Return a [builder](RequestBuilder) for a request,
|
||||
/// accepting a closure to modify the url path segments for more complex paths queries.
|
||||
pub(crate) fn request_with_url(
|
||||
&self,
|
||||
method: Method,
|
||||
f: impl for<'a> FnOnce(&'a mut ApiUrl),
|
||||
) -> RequestBuilder {
|
||||
let mut url = self.endpoint.clone();
|
||||
f(&mut url);
|
||||
self.client.get(url.into_inner())
|
||||
self.client.request(method, url.into_inner())
|
||||
}
|
||||
|
||||
/// Execute a [request](reqwest::Request).
|
||||
|
||||
@@ -94,6 +94,7 @@ pub mod auth;
|
||||
pub mod cache;
|
||||
pub mod cancellation;
|
||||
pub mod compute;
|
||||
pub mod compute_ctl;
|
||||
pub mod config;
|
||||
pub mod console_redirect_proxy;
|
||||
pub mod context;
|
||||
|
||||
@@ -34,6 +34,7 @@ pub(crate) struct PoolingBackend {
|
||||
pub(crate) http_conn_pool: Arc<super::http_conn_pool::GlobalConnPool>,
|
||||
pub(crate) local_pool: Arc<LocalConnPool<tokio_postgres::Client>>,
|
||||
pub(crate) pool: Arc<GlobalConnPool<tokio_postgres::Client>>,
|
||||
|
||||
pub(crate) config: &'static ProxyConfig,
|
||||
pub(crate) auth_backend: &'static crate::auth::Backend<'static, ()>,
|
||||
pub(crate) endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||
@@ -249,16 +250,41 @@ impl PoolingBackend {
|
||||
return Ok(client);
|
||||
}
|
||||
|
||||
let local_backend = match &self.auth_backend {
|
||||
auth::Backend::ControlPlane(_, ()) => {
|
||||
unreachable!("only local_proxy can connect to local postgres")
|
||||
}
|
||||
auth::Backend::Local(local) => local,
|
||||
};
|
||||
|
||||
#[allow(unreachable_code, clippy::todo)]
|
||||
if !self.local_pool.initialized(&conn_info) {
|
||||
// only install and grant usage one at a time.
|
||||
let _permit = local_backend.initialize.acquire().await.unwrap();
|
||||
|
||||
// check again for race
|
||||
if !self.local_pool.initialized(&conn_info) {
|
||||
local_backend
|
||||
.compute_ctl
|
||||
.install_extension(&conn_info.dbname, "pg_session_jwt")
|
||||
.await
|
||||
.map_err(|_| HttpConnError::LocalProxyConnectionError(todo!()))?;
|
||||
|
||||
local_backend
|
||||
.compute_ctl
|
||||
.grant_role(&conn_info.dbname, &conn_info.user_info.user, "auth")
|
||||
.await
|
||||
.map_err(|_| HttpConnError::LocalProxyConnectionError(todo!()))?;
|
||||
|
||||
self.local_pool.set_initialized(&conn_info);
|
||||
}
|
||||
}
|
||||
|
||||
let conn_id = uuid::Uuid::new_v4();
|
||||
tracing::Span::current().record("conn_id", display(conn_id));
|
||||
info!(%conn_id, "local_pool: opening a new connection '{conn_info}'");
|
||||
|
||||
let mut node_info = match &self.auth_backend {
|
||||
auth::Backend::ControlPlane(_, ()) => {
|
||||
unreachable!("only local_proxy can connect to local postgres")
|
||||
}
|
||||
auth::Backend::Local(local) => local.node_info.clone(),
|
||||
};
|
||||
let mut node_info = local_backend.node_info.clone();
|
||||
|
||||
let (key, jwk) = create_random_jwk();
|
||||
|
||||
|
||||
@@ -1,3 +1,14 @@
|
||||
//! Manages the pool of connections between local_proxy and postgres.
|
||||
//!
|
||||
//! The pool is keyed by database and role_name, and can contain multiple connections
|
||||
//! shared between users.
|
||||
//!
|
||||
//! The pool manages the pg_session_jwt extension used for authorizing
|
||||
//! requests in the db.
|
||||
//!
|
||||
//! The first time a db/role pair is seen, local_proxy attempts to install the extension
|
||||
//! and grant usage to the role on the given schema.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::pin::pin;
|
||||
use std::sync::{Arc, Weak};
|
||||
@@ -32,9 +43,6 @@ struct ConnPoolEntry<C: ClientInnerExt> {
|
||||
_last_access: std::time::Instant,
|
||||
}
|
||||
|
||||
// /// key id for the pg_session_jwt state
|
||||
// static PG_SESSION_JWT_KID: AtomicU64 = AtomicU64::new(1);
|
||||
|
||||
// Per-endpoint connection pool, (dbname, username) -> DbUserConnPool
|
||||
// Number of open connections is limited by the `max_conns_per_endpoint`.
|
||||
pub(crate) struct EndpointConnPool<C: ClientInnerExt> {
|
||||
@@ -140,11 +148,18 @@ impl<C: ClientInnerExt> Drop for EndpointConnPool<C> {
|
||||
|
||||
pub(crate) struct DbUserConnPool<C: ClientInnerExt> {
|
||||
conns: Vec<ConnPoolEntry<C>>,
|
||||
|
||||
// true if we have definitely installed the extension and
|
||||
// granted the role access to the auth schema.
|
||||
initialized: bool,
|
||||
}
|
||||
|
||||
impl<C: ClientInnerExt> Default for DbUserConnPool<C> {
|
||||
fn default() -> Self {
|
||||
Self { conns: Vec::new() }
|
||||
Self {
|
||||
conns: Vec::new(),
|
||||
initialized: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -199,25 +214,16 @@ impl<C: ClientInnerExt> LocalConnPool<C> {
|
||||
self.config.pool_options.idle_timeout
|
||||
}
|
||||
|
||||
// pub(crate) fn shutdown(&self) {
|
||||
// let mut pool = self.global_pool.write();
|
||||
// pool.pools.clear();
|
||||
// pool.total_conns = 0;
|
||||
// }
|
||||
|
||||
pub(crate) fn get(
|
||||
self: &Arc<Self>,
|
||||
ctx: &RequestMonitoring,
|
||||
conn_info: &ConnInfo,
|
||||
) -> Result<Option<LocalClient<C>>, HttpConnError> {
|
||||
let mut client: Option<ClientInner<C>> = None;
|
||||
if let Some(entry) = self
|
||||
let client = self
|
||||
.global_pool
|
||||
.write()
|
||||
.get_conn_entry(conn_info.db_and_user())
|
||||
{
|
||||
client = Some(entry.conn);
|
||||
}
|
||||
.map(|entry| entry.conn);
|
||||
|
||||
// ok return cached connection if found and establish a new one otherwise
|
||||
if let Some(client) = client {
|
||||
@@ -245,6 +251,23 @@ impl<C: ClientInnerExt> LocalConnPool<C> {
|
||||
}
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
pub(crate) fn initialized(self: &Arc<Self>, conn_info: &ConnInfo) -> bool {
|
||||
self.global_pool
|
||||
.read()
|
||||
.pools
|
||||
.get(&conn_info.db_and_user())
|
||||
.map_or(false, |pool| pool.initialized)
|
||||
}
|
||||
|
||||
pub(crate) fn set_initialized(self: &Arc<Self>, conn_info: &ConnInfo) {
|
||||
self.global_pool
|
||||
.write()
|
||||
.pools
|
||||
.entry(conn_info.db_and_user())
|
||||
.or_default()
|
||||
.initialized = true;
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
|
||||
Reference in New Issue
Block a user