[local_proxy]: install pg_session_jwt extension on demand (#9370)

Follow up on #9344. We want to install the extension automatically. We
didn't want to couple the extension into compute_ctl so instead
local_proxy is the one to issue requests specific to the extension.

depends on #9344 and #9395
This commit is contained in:
Conrad Ludgate
2024-10-18 14:41:21 +01:00
committed by GitHub
parent ec6d3422a5
commit 5cbdec9c79
8 changed files with 222 additions and 29 deletions

View File

@@ -14,10 +14,13 @@ use tracing::{debug, info};
use super::conn_pool::poll_client;
use super::conn_pool_lib::{Client, ConnInfo, GlobalConnPool};
use super::http_conn_pool::{self, poll_http2_client, Send};
use super::local_conn_pool::{self, LocalClient, LocalConnPool};
use super::local_conn_pool::{self, LocalClient, LocalConnPool, EXT_NAME, EXT_SCHEMA, EXT_VERSION};
use crate::auth::backend::local::StaticAuthRules;
use crate::auth::backend::{ComputeCredentials, ComputeUserInfo};
use crate::auth::{self, check_peer_addr_is_in_list, AuthError};
use crate::compute_ctl::{
ComputeCtlError, ExtensionInstallRequest, Privilege, SetRoleGrantsRequest,
};
use crate::config::ProxyConfig;
use crate::context::RequestMonitoring;
use crate::control_plane::errors::{GetAuthInfoError, WakeComputeError};
@@ -35,6 +38,7 @@ pub(crate) struct PoolingBackend {
pub(crate) http_conn_pool: Arc<super::http_conn_pool::GlobalConnPool<Send>>,
pub(crate) local_pool: Arc<LocalConnPool<tokio_postgres::Client>>,
pub(crate) pool: Arc<GlobalConnPool<tokio_postgres::Client>>,
pub(crate) config: &'static ProxyConfig,
pub(crate) auth_backend: &'static crate::auth::Backend<'static, ()>,
pub(crate) endpoint_rate_limiter: Arc<EndpointRateLimiter>,
@@ -250,16 +254,47 @@ impl PoolingBackend {
return Ok(client);
}
let local_backend = match &self.auth_backend {
auth::Backend::ControlPlane(_, ()) => {
unreachable!("only local_proxy can connect to local postgres")
}
auth::Backend::Local(local) => local,
};
if !self.local_pool.initialized(&conn_info) {
// only install and grant usage one at a time.
let _permit = local_backend.initialize.acquire().await.unwrap();
// check again for race
if !self.local_pool.initialized(&conn_info) {
local_backend
.compute_ctl
.install_extension(&ExtensionInstallRequest {
extension: EXT_NAME,
database: conn_info.dbname.clone(),
version: EXT_VERSION,
})
.await?;
local_backend
.compute_ctl
.grant_role(&SetRoleGrantsRequest {
schema: EXT_SCHEMA,
privileges: vec![Privilege::Usage],
database: conn_info.dbname.clone(),
role: conn_info.user_info.user.clone(),
})
.await?;
self.local_pool.set_initialized(&conn_info);
}
}
let conn_id = uuid::Uuid::new_v4();
tracing::Span::current().record("conn_id", display(conn_id));
info!(%conn_id, "local_pool: opening a new connection '{conn_info}'");
let mut node_info = match &self.auth_backend {
auth::Backend::ControlPlane(_, ()) => {
unreachable!("only local_proxy can connect to local postgres")
}
auth::Backend::Local(local) => local.node_info.clone(),
};
let mut node_info = local_backend.node_info.clone();
let (key, jwk) = create_random_jwk();
@@ -324,6 +359,8 @@ pub(crate) enum HttpConnError {
#[error("could not parse JWT payload")]
JwtPayloadError(serde_json::Error),
#[error("could not install extension: {0}")]
ComputeCtl(#[from] ComputeCtlError),
#[error("could not get auth info")]
GetAuthInfo(#[from] GetAuthInfoError),
#[error("user not authenticated")]
@@ -348,6 +385,7 @@ impl ReportableError for HttpConnError {
HttpConnError::ConnectionClosedAbruptly(_) => ErrorKind::Compute,
HttpConnError::PostgresConnectionError(p) => p.get_error_kind(),
HttpConnError::LocalProxyConnectionError(_) => ErrorKind::Compute,
HttpConnError::ComputeCtl(_) => ErrorKind::Service,
HttpConnError::JwtPayloadError(_) => ErrorKind::User,
HttpConnError::GetAuthInfo(a) => a.get_error_kind(),
HttpConnError::AuthError(a) => a.get_error_kind(),
@@ -363,6 +401,7 @@ impl UserFacingError for HttpConnError {
HttpConnError::ConnectionClosedAbruptly(_) => self.to_string(),
HttpConnError::PostgresConnectionError(p) => p.to_string(),
HttpConnError::LocalProxyConnectionError(p) => p.to_string(),
HttpConnError::ComputeCtl(_) => "could not set up the JWT authorization database extension".to_string(),
HttpConnError::JwtPayloadError(p) => p.to_string(),
HttpConnError::GetAuthInfo(c) => c.to_string_client(),
HttpConnError::AuthError(c) => c.to_string_client(),
@@ -379,6 +418,7 @@ impl CouldRetry for HttpConnError {
match self {
HttpConnError::PostgresConnectionError(e) => e.could_retry(),
HttpConnError::LocalProxyConnectionError(e) => e.could_retry(),
HttpConnError::ComputeCtl(_) => false,
HttpConnError::ConnectionClosedAbruptly(_) => false,
HttpConnError::JwtPayloadError(_) => false,
HttpConnError::GetAuthInfo(_) => false,

View File

@@ -1,3 +1,14 @@
//! Manages the pool of connections between local_proxy and postgres.
//!
//! The pool is keyed by database and role_name, and can contain multiple connections
//! shared between users.
//!
//! The pool manages the pg_session_jwt extension used for authorizing
//! requests in the db.
//!
//! The first time a db/role pair is seen, local_proxy attempts to install the extension
//! and grant usage to the role on the given schema.
use std::collections::HashMap;
use std::pin::pin;
use std::sync::{Arc, Weak};
@@ -27,14 +38,15 @@ use crate::metrics::Metrics;
use crate::usage_metrics::{Ids, MetricCounter, USAGE_METRICS};
use crate::{DbName, RoleName};
pub(crate) const EXT_NAME: &str = "pg_session_jwt";
pub(crate) const EXT_VERSION: &str = "0.1.1";
pub(crate) const EXT_SCHEMA: &str = "auth";
struct ConnPoolEntry<C: ClientInnerExt> {
conn: ClientInner<C>,
_last_access: std::time::Instant,
}
// /// key id for the pg_session_jwt state
// static PG_SESSION_JWT_KID: AtomicU64 = AtomicU64::new(1);
// Per-endpoint connection pool, (dbname, username) -> DbUserConnPool
// Number of open connections is limited by the `max_conns_per_endpoint`.
pub(crate) struct EndpointConnPool<C: ClientInnerExt> {
@@ -140,11 +152,18 @@ impl<C: ClientInnerExt> Drop for EndpointConnPool<C> {
pub(crate) struct DbUserConnPool<C: ClientInnerExt> {
conns: Vec<ConnPoolEntry<C>>,
// true if we have definitely installed the extension and
// granted the role access to the auth schema.
initialized: bool,
}
impl<C: ClientInnerExt> Default for DbUserConnPool<C> {
fn default() -> Self {
Self { conns: Vec::new() }
Self {
conns: Vec::new(),
initialized: false,
}
}
}
@@ -199,25 +218,16 @@ impl<C: ClientInnerExt> LocalConnPool<C> {
self.config.pool_options.idle_timeout
}
// pub(crate) fn shutdown(&self) {
// let mut pool = self.global_pool.write();
// pool.pools.clear();
// pool.total_conns = 0;
// }
pub(crate) fn get(
self: &Arc<Self>,
ctx: &RequestMonitoring,
conn_info: &ConnInfo,
) -> Result<Option<LocalClient<C>>, HttpConnError> {
let mut client: Option<ClientInner<C>> = None;
if let Some(entry) = self
let client = self
.global_pool
.write()
.get_conn_entry(conn_info.db_and_user())
{
client = Some(entry.conn);
}
.map(|entry| entry.conn);
// ok return cached connection if found and establish a new one otherwise
if let Some(client) = client {
@@ -245,6 +255,23 @@ impl<C: ClientInnerExt> LocalConnPool<C> {
}
Ok(None)
}
pub(crate) fn initialized(self: &Arc<Self>, conn_info: &ConnInfo) -> bool {
self.global_pool
.read()
.pools
.get(&conn_info.db_and_user())
.map_or(false, |pool| pool.initialized)
}
pub(crate) fn set_initialized(self: &Arc<Self>, conn_info: &ConnInfo) {
self.global_pool
.write()
.pools
.entry(conn_info.db_and_user())
.or_default()
.initialized = true;
}
}
#[allow(clippy::too_many_arguments)]