diff --git a/compute/Dockerfile.compute-node b/compute/Dockerfile.compute-node index 45c1fd9f38..74970696b5 100644 --- a/compute/Dockerfile.compute-node +++ b/compute/Dockerfile.compute-node @@ -975,8 +975,8 @@ ARG PG_VERSION RUN case "${PG_VERSION}" in "v17") \ echo "pg_session_jwt does not yet have a release that supports pg17" && exit 0;; \ esac && \ - wget https://github.com/neondatabase/pg_session_jwt/archive/5aee2625af38213650e1a07ae038fdc427250ee4.tar.gz -O pg_session_jwt.tar.gz && \ - echo "5d91b10bc1347d36cffc456cb87bec25047935d6503dc652ca046f04760828e7 pg_session_jwt.tar.gz" | sha256sum --check && \ + wget https://github.com/neondatabase/pg_session_jwt/archive/e642528f429dd3f5403845a50191b78d434b84a6.tar.gz -O pg_session_jwt.tar.gz && \ + echo "1a69210703cc91224785e59a0a67562dd9eed9a0914ac84b11447582ca0d5b93 pg_session_jwt.tar.gz" | sha256sum --check && \ mkdir pg_session_jwt-src && cd pg_session_jwt-src && tar xzf ../pg_session_jwt.tar.gz --strip-components=1 -C . && \ sed -i 's/pgrx = "=0.11.3"/pgrx = { version = "=0.11.3", features = [ "unsafe-postgres" ] }/g' Cargo.toml && \ cargo pgrx install --release diff --git a/proxy/src/auth/backend/local.rs b/proxy/src/auth/backend/local.rs index e3995ac6c0..1e029ff609 100644 --- a/proxy/src/auth/backend/local.rs +++ b/proxy/src/auth/backend/local.rs @@ -1,23 +1,32 @@ use std::net::SocketAddr; use arc_swap::ArcSwapOption; +use tokio::sync::Semaphore; use super::jwt::{AuthRule, FetchAuthRules}; use crate::auth::backend::jwt::FetchAuthRulesError; use crate::compute::ConnCfg; +use crate::compute_ctl::ComputeCtlApi; use crate::context::RequestMonitoring; use crate::control_plane::messages::{ColdStartInfo, EndpointJwksResponse, MetricsAuxInfo}; use crate::control_plane::NodeInfo; use crate::intern::{BranchIdTag, EndpointIdTag, InternId, ProjectIdTag}; -use crate::EndpointId; +use crate::url::ApiUrl; +use crate::{http, EndpointId}; pub struct LocalBackend { + pub(crate) initialize: Semaphore, + pub(crate) compute_ctl: ComputeCtlApi, pub(crate) node_info: NodeInfo, } impl LocalBackend { - pub fn new(postgres_addr: SocketAddr) -> Self { + pub fn new(postgres_addr: SocketAddr, compute_ctl: ApiUrl) -> Self { LocalBackend { + initialize: Semaphore::new(1), + compute_ctl: ComputeCtlApi { + api: http::Endpoint::new(compute_ctl, http::new_client()), + }, node_info: NodeInfo { config: { let mut cfg = ConnCfg::new(); diff --git a/proxy/src/bin/local_proxy.rs b/proxy/src/bin/local_proxy.rs index e6bc369d9a..a16c288e5d 100644 --- a/proxy/src/bin/local_proxy.rs +++ b/proxy/src/bin/local_proxy.rs @@ -25,6 +25,7 @@ use proxy::rate_limiter::{ use proxy::scram::threadpool::ThreadPool; use proxy::serverless::cancel_set::CancelSet; use proxy::serverless::{self, GlobalConnPoolOptions}; +use proxy::url::ApiUrl; use proxy::RoleName; project_git_version!(GIT_VERSION); @@ -80,7 +81,10 @@ struct LocalProxyCliArgs { connect_to_compute_retry: String, /// Address of the postgres server #[clap(long, default_value = "127.0.0.1:5432")] - compute: SocketAddr, + postgres: SocketAddr, + /// Address of the compute-ctl api service + #[clap(long, default_value = "http://127.0.0.1:3080/")] + compute_ctl: ApiUrl, /// Path of the local proxy config file #[clap(long, default_value = "./local_proxy.json")] config_path: Utf8PathBuf, @@ -295,7 +299,7 @@ fn build_auth_backend( args: &LocalProxyCliArgs, ) -> anyhow::Result<&'static auth::Backend<'static, ()>> { let auth_backend = proxy::auth::Backend::Local(proxy::auth::backend::MaybeOwned::Owned( - LocalBackend::new(args.compute), + LocalBackend::new(args.postgres, args.compute_ctl.clone()), )); Ok(Box::leak(Box::new(auth_backend))) diff --git a/proxy/src/compute_ctl/mod.rs b/proxy/src/compute_ctl/mod.rs new file mode 100644 index 0000000000..2b57897223 --- /dev/null +++ b/proxy/src/compute_ctl/mod.rs @@ -0,0 +1,101 @@ +use compute_api::responses::GenericAPIError; +use hyper::{Method, StatusCode}; +use serde::de::DeserializeOwned; +use serde::{Deserialize, Serialize}; +use thiserror::Error; + +use crate::url::ApiUrl; +use crate::{http, DbName, RoleName}; + +pub struct ComputeCtlApi { + pub(crate) api: http::Endpoint, +} + +#[derive(Serialize, Debug)] +pub struct ExtensionInstallRequest { + pub extension: &'static str, + pub database: DbName, + pub version: &'static str, +} + +#[derive(Serialize, Debug)] +pub struct SetRoleGrantsRequest { + pub database: DbName, + pub schema: &'static str, + pub privileges: Vec, + pub role: RoleName, +} + +#[derive(Clone, Debug, Deserialize)] +pub struct ExtensionInstallResponse {} + +#[derive(Clone, Debug, Deserialize)] +pub struct SetRoleGrantsResponse {} + +#[derive(Debug, Serialize, Deserialize, Clone, Copy)] +#[serde(rename_all = "UPPERCASE")] +pub enum Privilege { + Usage, +} + +#[derive(Error, Debug)] +pub enum ComputeCtlError { + #[error("connection error: {0}")] + ConnectionError(#[source] reqwest_middleware::Error), + #[error("request error [{status}]: {body:?}")] + RequestError { + status: StatusCode, + body: Option, + }, + #[error("response parsing error: {0}")] + ResponseError(#[source] reqwest::Error), +} + +impl ComputeCtlApi { + pub async fn install_extension( + &self, + req: &ExtensionInstallRequest, + ) -> Result { + self.generic_request(req, Method::POST, |url| { + url.path_segments_mut().push("extensions"); + }) + .await + } + + pub async fn grant_role( + &self, + req: &SetRoleGrantsRequest, + ) -> Result { + self.generic_request(req, Method::POST, |url| { + url.path_segments_mut().push("grants"); + }) + .await + } + + async fn generic_request( + &self, + req: &Req, + method: Method, + url: impl for<'a> FnOnce(&'a mut ApiUrl), + ) -> Result + where + Req: Serialize, + Resp: DeserializeOwned, + { + let resp = self + .api + .request_with_url(method, url) + .json(req) + .send() + .await + .map_err(ComputeCtlError::ConnectionError)?; + + let status = resp.status(); + if status.is_client_error() || status.is_server_error() { + let body = resp.json().await.ok(); + return Err(ComputeCtlError::RequestError { status, body }); + } + + resp.json().await.map_err(ComputeCtlError::ResponseError) + } +} diff --git a/proxy/src/http/mod.rs b/proxy/src/http/mod.rs index fd587e8f01..f1b632e704 100644 --- a/proxy/src/http/mod.rs +++ b/proxy/src/http/mod.rs @@ -8,6 +8,7 @@ use std::time::Duration; use anyhow::bail; use bytes::Bytes; +use http::Method; use http_body_util::BodyExt; use hyper::body::Body; pub(crate) use reqwest::{Request, Response}; @@ -93,9 +94,19 @@ impl Endpoint { /// Return a [builder](RequestBuilder) for a `GET` request, /// accepting a closure to modify the url path segments for more complex paths queries. pub(crate) fn get_with_url(&self, f: impl for<'a> FnOnce(&'a mut ApiUrl)) -> RequestBuilder { + self.request_with_url(Method::GET, f) + } + + /// Return a [builder](RequestBuilder) for a request, + /// accepting a closure to modify the url path segments for more complex paths queries. + pub(crate) fn request_with_url( + &self, + method: Method, + f: impl for<'a> FnOnce(&'a mut ApiUrl), + ) -> RequestBuilder { let mut url = self.endpoint.clone(); f(&mut url); - self.client.get(url.into_inner()) + self.client.request(method, url.into_inner()) } /// Execute a [request](reqwest::Request). diff --git a/proxy/src/lib.rs b/proxy/src/lib.rs index a7b3d45c95..ea17a88067 100644 --- a/proxy/src/lib.rs +++ b/proxy/src/lib.rs @@ -90,6 +90,7 @@ pub mod auth; pub mod cache; pub mod cancellation; pub mod compute; +pub mod compute_ctl; pub mod config; pub mod console_redirect_proxy; pub mod context; diff --git a/proxy/src/serverless/backend.rs b/proxy/src/serverless/backend.rs index 82e81dbcfe..5d59b4d252 100644 --- a/proxy/src/serverless/backend.rs +++ b/proxy/src/serverless/backend.rs @@ -14,10 +14,13 @@ use tracing::{debug, info}; use super::conn_pool::poll_client; use super::conn_pool_lib::{Client, ConnInfo, GlobalConnPool}; use super::http_conn_pool::{self, poll_http2_client, Send}; -use super::local_conn_pool::{self, LocalClient, LocalConnPool}; +use super::local_conn_pool::{self, LocalClient, LocalConnPool, EXT_NAME, EXT_SCHEMA, EXT_VERSION}; use crate::auth::backend::local::StaticAuthRules; use crate::auth::backend::{ComputeCredentials, ComputeUserInfo}; use crate::auth::{self, check_peer_addr_is_in_list, AuthError}; +use crate::compute_ctl::{ + ComputeCtlError, ExtensionInstallRequest, Privilege, SetRoleGrantsRequest, +}; use crate::config::ProxyConfig; use crate::context::RequestMonitoring; use crate::control_plane::errors::{GetAuthInfoError, WakeComputeError}; @@ -35,6 +38,7 @@ pub(crate) struct PoolingBackend { pub(crate) http_conn_pool: Arc>, pub(crate) local_pool: Arc>, pub(crate) pool: Arc>, + pub(crate) config: &'static ProxyConfig, pub(crate) auth_backend: &'static crate::auth::Backend<'static, ()>, pub(crate) endpoint_rate_limiter: Arc, @@ -250,16 +254,47 @@ impl PoolingBackend { return Ok(client); } + let local_backend = match &self.auth_backend { + auth::Backend::ControlPlane(_, ()) => { + unreachable!("only local_proxy can connect to local postgres") + } + auth::Backend::Local(local) => local, + }; + + if !self.local_pool.initialized(&conn_info) { + // only install and grant usage one at a time. + let _permit = local_backend.initialize.acquire().await.unwrap(); + + // check again for race + if !self.local_pool.initialized(&conn_info) { + local_backend + .compute_ctl + .install_extension(&ExtensionInstallRequest { + extension: EXT_NAME, + database: conn_info.dbname.clone(), + version: EXT_VERSION, + }) + .await?; + + local_backend + .compute_ctl + .grant_role(&SetRoleGrantsRequest { + schema: EXT_SCHEMA, + privileges: vec![Privilege::Usage], + database: conn_info.dbname.clone(), + role: conn_info.user_info.user.clone(), + }) + .await?; + + self.local_pool.set_initialized(&conn_info); + } + } + let conn_id = uuid::Uuid::new_v4(); tracing::Span::current().record("conn_id", display(conn_id)); info!(%conn_id, "local_pool: opening a new connection '{conn_info}'"); - let mut node_info = match &self.auth_backend { - auth::Backend::ControlPlane(_, ()) => { - unreachable!("only local_proxy can connect to local postgres") - } - auth::Backend::Local(local) => local.node_info.clone(), - }; + let mut node_info = local_backend.node_info.clone(); let (key, jwk) = create_random_jwk(); @@ -324,6 +359,8 @@ pub(crate) enum HttpConnError { #[error("could not parse JWT payload")] JwtPayloadError(serde_json::Error), + #[error("could not install extension: {0}")] + ComputeCtl(#[from] ComputeCtlError), #[error("could not get auth info")] GetAuthInfo(#[from] GetAuthInfoError), #[error("user not authenticated")] @@ -348,6 +385,7 @@ impl ReportableError for HttpConnError { HttpConnError::ConnectionClosedAbruptly(_) => ErrorKind::Compute, HttpConnError::PostgresConnectionError(p) => p.get_error_kind(), HttpConnError::LocalProxyConnectionError(_) => ErrorKind::Compute, + HttpConnError::ComputeCtl(_) => ErrorKind::Service, HttpConnError::JwtPayloadError(_) => ErrorKind::User, HttpConnError::GetAuthInfo(a) => a.get_error_kind(), HttpConnError::AuthError(a) => a.get_error_kind(), @@ -363,6 +401,7 @@ impl UserFacingError for HttpConnError { HttpConnError::ConnectionClosedAbruptly(_) => self.to_string(), HttpConnError::PostgresConnectionError(p) => p.to_string(), HttpConnError::LocalProxyConnectionError(p) => p.to_string(), + HttpConnError::ComputeCtl(_) => "could not set up the JWT authorization database extension".to_string(), HttpConnError::JwtPayloadError(p) => p.to_string(), HttpConnError::GetAuthInfo(c) => c.to_string_client(), HttpConnError::AuthError(c) => c.to_string_client(), @@ -379,6 +418,7 @@ impl CouldRetry for HttpConnError { match self { HttpConnError::PostgresConnectionError(e) => e.could_retry(), HttpConnError::LocalProxyConnectionError(e) => e.could_retry(), + HttpConnError::ComputeCtl(_) => false, HttpConnError::ConnectionClosedAbruptly(_) => false, HttpConnError::JwtPayloadError(_) => false, HttpConnError::GetAuthInfo(_) => false, diff --git a/proxy/src/serverless/local_conn_pool.rs b/proxy/src/serverless/local_conn_pool.rs index a01afd2820..beb2ad4e8f 100644 --- a/proxy/src/serverless/local_conn_pool.rs +++ b/proxy/src/serverless/local_conn_pool.rs @@ -1,3 +1,14 @@ +//! Manages the pool of connections between local_proxy and postgres. +//! +//! The pool is keyed by database and role_name, and can contain multiple connections +//! shared between users. +//! +//! The pool manages the pg_session_jwt extension used for authorizing +//! requests in the db. +//! +//! The first time a db/role pair is seen, local_proxy attempts to install the extension +//! and grant usage to the role on the given schema. + use std::collections::HashMap; use std::pin::pin; use std::sync::{Arc, Weak}; @@ -27,14 +38,15 @@ use crate::metrics::Metrics; use crate::usage_metrics::{Ids, MetricCounter, USAGE_METRICS}; use crate::{DbName, RoleName}; +pub(crate) const EXT_NAME: &str = "pg_session_jwt"; +pub(crate) const EXT_VERSION: &str = "0.1.1"; +pub(crate) const EXT_SCHEMA: &str = "auth"; + struct ConnPoolEntry { conn: ClientInner, _last_access: std::time::Instant, } -// /// key id for the pg_session_jwt state -// static PG_SESSION_JWT_KID: AtomicU64 = AtomicU64::new(1); - // Per-endpoint connection pool, (dbname, username) -> DbUserConnPool // Number of open connections is limited by the `max_conns_per_endpoint`. pub(crate) struct EndpointConnPool { @@ -140,11 +152,18 @@ impl Drop for EndpointConnPool { pub(crate) struct DbUserConnPool { conns: Vec>, + + // true if we have definitely installed the extension and + // granted the role access to the auth schema. + initialized: bool, } impl Default for DbUserConnPool { fn default() -> Self { - Self { conns: Vec::new() } + Self { + conns: Vec::new(), + initialized: false, + } } } @@ -199,25 +218,16 @@ impl LocalConnPool { self.config.pool_options.idle_timeout } - // pub(crate) fn shutdown(&self) { - // let mut pool = self.global_pool.write(); - // pool.pools.clear(); - // pool.total_conns = 0; - // } - pub(crate) fn get( self: &Arc, ctx: &RequestMonitoring, conn_info: &ConnInfo, ) -> Result>, HttpConnError> { - let mut client: Option> = None; - if let Some(entry) = self + let client = self .global_pool .write() .get_conn_entry(conn_info.db_and_user()) - { - client = Some(entry.conn); - } + .map(|entry| entry.conn); // ok return cached connection if found and establish a new one otherwise if let Some(client) = client { @@ -245,6 +255,23 @@ impl LocalConnPool { } Ok(None) } + + pub(crate) fn initialized(self: &Arc, conn_info: &ConnInfo) -> bool { + self.global_pool + .read() + .pools + .get(&conn_info.db_and_user()) + .map_or(false, |pool| pool.initialized) + } + + pub(crate) fn set_initialized(self: &Arc, conn_info: &ConnInfo) { + self.global_pool + .write() + .pools + .entry(conn_info.db_and_user()) + .or_default() + .initialized = true; + } } #[allow(clippy::too_many_arguments)]