From 94204dbbaedd0d8da6aa5b6d20cf08155ae6b331 Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Sat, 12 Oct 2024 07:18:13 +0200 Subject: [PATCH] [local_proxy]: install pg_session_jwt extension on demand --- proxy/src/auth/backend/local.rs | 13 +++++- proxy/src/bin/local_proxy.rs | 8 +++- proxy/src/compute_ctl/mod.rs | 51 ++++++++++++++++++++++++ proxy/src/http/mod.rs | 13 +++++- proxy/src/lib.rs | 1 + proxy/src/serverless/backend.rs | 38 +++++++++++++++--- proxy/src/serverless/local_conn_pool.rs | 53 ++++++++++++++++++------- 7 files changed, 151 insertions(+), 26 deletions(-) create mode 100644 proxy/src/compute_ctl/mod.rs diff --git a/proxy/src/auth/backend/local.rs b/proxy/src/auth/backend/local.rs index e3995ac6c0..1e029ff609 100644 --- a/proxy/src/auth/backend/local.rs +++ b/proxy/src/auth/backend/local.rs @@ -1,23 +1,32 @@ use std::net::SocketAddr; use arc_swap::ArcSwapOption; +use tokio::sync::Semaphore; use super::jwt::{AuthRule, FetchAuthRules}; use crate::auth::backend::jwt::FetchAuthRulesError; use crate::compute::ConnCfg; +use crate::compute_ctl::ComputeCtlApi; use crate::context::RequestMonitoring; use crate::control_plane::messages::{ColdStartInfo, EndpointJwksResponse, MetricsAuxInfo}; use crate::control_plane::NodeInfo; use crate::intern::{BranchIdTag, EndpointIdTag, InternId, ProjectIdTag}; -use crate::EndpointId; +use crate::url::ApiUrl; +use crate::{http, EndpointId}; pub struct LocalBackend { + pub(crate) initialize: Semaphore, + pub(crate) compute_ctl: ComputeCtlApi, pub(crate) node_info: NodeInfo, } impl LocalBackend { - pub fn new(postgres_addr: SocketAddr) -> Self { + pub fn new(postgres_addr: SocketAddr, compute_ctl: ApiUrl) -> Self { LocalBackend { + initialize: Semaphore::new(1), + compute_ctl: ComputeCtlApi { + api: http::Endpoint::new(compute_ctl, http::new_client()), + }, node_info: NodeInfo { config: { let mut cfg = ConnCfg::new(); diff --git a/proxy/src/bin/local_proxy.rs b/proxy/src/bin/local_proxy.rs index e6bc369d9a..a16c288e5d 100644 --- a/proxy/src/bin/local_proxy.rs +++ b/proxy/src/bin/local_proxy.rs @@ -25,6 +25,7 @@ use proxy::rate_limiter::{ use proxy::scram::threadpool::ThreadPool; use proxy::serverless::cancel_set::CancelSet; use proxy::serverless::{self, GlobalConnPoolOptions}; +use proxy::url::ApiUrl; use proxy::RoleName; project_git_version!(GIT_VERSION); @@ -80,7 +81,10 @@ struct LocalProxyCliArgs { connect_to_compute_retry: String, /// Address of the postgres server #[clap(long, default_value = "127.0.0.1:5432")] - compute: SocketAddr, + postgres: SocketAddr, + /// Address of the compute-ctl api service + #[clap(long, default_value = "http://127.0.0.1:3080/")] + compute_ctl: ApiUrl, /// Path of the local proxy config file #[clap(long, default_value = "./local_proxy.json")] config_path: Utf8PathBuf, @@ -295,7 +299,7 @@ fn build_auth_backend( args: &LocalProxyCliArgs, ) -> anyhow::Result<&'static auth::Backend<'static, ()>> { let auth_backend = proxy::auth::Backend::Local(proxy::auth::backend::MaybeOwned::Owned( - LocalBackend::new(args.compute), + LocalBackend::new(args.postgres, args.compute_ctl.clone()), )); Ok(Box::leak(Box::new(auth_backend))) diff --git a/proxy/src/compute_ctl/mod.rs b/proxy/src/compute_ctl/mod.rs new file mode 100644 index 0000000000..77c75618c4 --- /dev/null +++ b/proxy/src/compute_ctl/mod.rs @@ -0,0 +1,51 @@ +use anyhow::Context; +use hyper::Method; +use typed_json::json; + +use crate::http; + +pub struct ComputeCtlApi { + pub(crate) api: http::Endpoint, +} + +// The following article is a stub. +// You can help Wikipedia by filling it out + +impl ComputeCtlApi { + pub async fn install_extension(&self, db: &str, ext: &str) -> anyhow::Result<()> { + self.api + .request_with_url(Method::POST, |url| { + url.path_segments_mut().push("extension"); + }) + .json(&json! {{ + "extension": ext, + "database": db, + }}) + .send() + .await + .context("connection error")? + .error_for_status() + .context("api error")?; + + Ok(()) + } + + pub async fn grant_role(&self, db: &str, role: &str, schema: &str) -> anyhow::Result<()> { + self.api + .request_with_url(Method::POST, |url| { + url.path_segments_mut().push("grant"); + }) + .json(&json! {{ + "schema": schema, + "role": role, + "database": db, + }}) + .send() + .await + .context("connection error")? + .error_for_status() + .context("api error")?; + + Ok(()) + } +} diff --git a/proxy/src/http/mod.rs b/proxy/src/http/mod.rs index fd587e8f01..f1b632e704 100644 --- a/proxy/src/http/mod.rs +++ b/proxy/src/http/mod.rs @@ -8,6 +8,7 @@ use std::time::Duration; use anyhow::bail; use bytes::Bytes; +use http::Method; use http_body_util::BodyExt; use hyper::body::Body; pub(crate) use reqwest::{Request, Response}; @@ -93,9 +94,19 @@ impl Endpoint { /// Return a [builder](RequestBuilder) for a `GET` request, /// accepting a closure to modify the url path segments for more complex paths queries. pub(crate) fn get_with_url(&self, f: impl for<'a> FnOnce(&'a mut ApiUrl)) -> RequestBuilder { + self.request_with_url(Method::GET, f) + } + + /// Return a [builder](RequestBuilder) for a request, + /// accepting a closure to modify the url path segments for more complex paths queries. + pub(crate) fn request_with_url( + &self, + method: Method, + f: impl for<'a> FnOnce(&'a mut ApiUrl), + ) -> RequestBuilder { let mut url = self.endpoint.clone(); f(&mut url); - self.client.get(url.into_inner()) + self.client.request(method, url.into_inner()) } /// Execute a [request](reqwest::Request). diff --git a/proxy/src/lib.rs b/proxy/src/lib.rs index 74bc778a36..b5568ccea4 100644 --- a/proxy/src/lib.rs +++ b/proxy/src/lib.rs @@ -94,6 +94,7 @@ pub mod auth; pub mod cache; pub mod cancellation; pub mod compute; +pub mod compute_ctl; pub mod config; pub mod console_redirect_proxy; pub mod context; diff --git a/proxy/src/serverless/backend.rs b/proxy/src/serverless/backend.rs index a180c4c2ed..5d5fa1f05a 100644 --- a/proxy/src/serverless/backend.rs +++ b/proxy/src/serverless/backend.rs @@ -34,6 +34,7 @@ pub(crate) struct PoolingBackend { pub(crate) http_conn_pool: Arc, pub(crate) local_pool: Arc>, pub(crate) pool: Arc>, + pub(crate) config: &'static ProxyConfig, pub(crate) auth_backend: &'static crate::auth::Backend<'static, ()>, pub(crate) endpoint_rate_limiter: Arc, @@ -249,16 +250,41 @@ impl PoolingBackend { return Ok(client); } + let local_backend = match &self.auth_backend { + auth::Backend::ControlPlane(_, ()) => { + unreachable!("only local_proxy can connect to local postgres") + } + auth::Backend::Local(local) => local, + }; + + #[allow(unreachable_code, clippy::todo)] + if !self.local_pool.initialized(&conn_info) { + // only install and grant usage one at a time. + let _permit = local_backend.initialize.acquire().await.unwrap(); + + // check again for race + if !self.local_pool.initialized(&conn_info) { + local_backend + .compute_ctl + .install_extension(&conn_info.dbname, "pg_session_jwt") + .await + .map_err(|_| HttpConnError::LocalProxyConnectionError(todo!()))?; + + local_backend + .compute_ctl + .grant_role(&conn_info.dbname, &conn_info.user_info.user, "auth") + .await + .map_err(|_| HttpConnError::LocalProxyConnectionError(todo!()))?; + + self.local_pool.set_initialized(&conn_info); + } + } + let conn_id = uuid::Uuid::new_v4(); tracing::Span::current().record("conn_id", display(conn_id)); info!(%conn_id, "local_pool: opening a new connection '{conn_info}'"); - let mut node_info = match &self.auth_backend { - auth::Backend::ControlPlane(_, ()) => { - unreachable!("only local_proxy can connect to local postgres") - } - auth::Backend::Local(local) => local.node_info.clone(), - }; + let mut node_info = local_backend.node_info.clone(); let (key, jwk) = create_random_jwk(); diff --git a/proxy/src/serverless/local_conn_pool.rs b/proxy/src/serverless/local_conn_pool.rs index 5df37a8762..78bef1aa59 100644 --- a/proxy/src/serverless/local_conn_pool.rs +++ b/proxy/src/serverless/local_conn_pool.rs @@ -1,3 +1,14 @@ +//! Manages the pool of connections between local_proxy and postgres. +//! +//! The pool is keyed by database and role_name, and can contain multiple connections +//! shared between users. +//! +//! The pool manages the pg_session_jwt extension used for authorizing +//! requests in the db. +//! +//! The first time a db/role pair is seen, local_proxy attempts to install the extension +//! and grant usage to the role on the given schema. + use std::collections::HashMap; use std::pin::pin; use std::sync::{Arc, Weak}; @@ -32,9 +43,6 @@ struct ConnPoolEntry { _last_access: std::time::Instant, } -// /// key id for the pg_session_jwt state -// static PG_SESSION_JWT_KID: AtomicU64 = AtomicU64::new(1); - // Per-endpoint connection pool, (dbname, username) -> DbUserConnPool // Number of open connections is limited by the `max_conns_per_endpoint`. pub(crate) struct EndpointConnPool { @@ -140,11 +148,18 @@ impl Drop for EndpointConnPool { pub(crate) struct DbUserConnPool { conns: Vec>, + + // true if we have definitely installed the extension and + // granted the role access to the auth schema. + initialized: bool, } impl Default for DbUserConnPool { fn default() -> Self { - Self { conns: Vec::new() } + Self { + conns: Vec::new(), + initialized: false, + } } } @@ -199,25 +214,16 @@ impl LocalConnPool { self.config.pool_options.idle_timeout } - // pub(crate) fn shutdown(&self) { - // let mut pool = self.global_pool.write(); - // pool.pools.clear(); - // pool.total_conns = 0; - // } - pub(crate) fn get( self: &Arc, ctx: &RequestMonitoring, conn_info: &ConnInfo, ) -> Result>, HttpConnError> { - let mut client: Option> = None; - if let Some(entry) = self + let client = self .global_pool .write() .get_conn_entry(conn_info.db_and_user()) - { - client = Some(entry.conn); - } + .map(|entry| entry.conn); // ok return cached connection if found and establish a new one otherwise if let Some(client) = client { @@ -245,6 +251,23 @@ impl LocalConnPool { } Ok(None) } + + pub(crate) fn initialized(self: &Arc, conn_info: &ConnInfo) -> bool { + self.global_pool + .read() + .pools + .get(&conn_info.db_and_user()) + .map_or(false, |pool| pool.initialized) + } + + pub(crate) fn set_initialized(self: &Arc, conn_info: &ConnInfo) { + self.global_pool + .write() + .pools + .entry(conn_info.db_and_user()) + .or_default() + .initialized = true; + } } #[allow(clippy::too_many_arguments)]