From 284c84da5c78ddfb2fe46d86ec0b9ac1e7d9ed15 Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Thu, 17 Oct 2024 10:01:53 +0100 Subject: [PATCH] refactor --- proxy/src/compute_ctl/mod.rs | 122 +++++++++++++++++------- proxy/src/serverless/backend.rs | 31 ++++-- proxy/src/serverless/local_conn_pool.rs | 4 + 3 files changed, 114 insertions(+), 43 deletions(-) diff --git a/proxy/src/compute_ctl/mod.rs b/proxy/src/compute_ctl/mod.rs index 77c75618c4..8dc7c98490 100644 --- a/proxy/src/compute_ctl/mod.rs +++ b/proxy/src/compute_ctl/mod.rs @@ -1,51 +1,101 @@ -use anyhow::Context; -use hyper::Method; -use typed_json::json; +use compute_api::responses::GenericAPIError; +use hyper::{Method, StatusCode}; +use serde::de::DeserializeOwned; +use serde::{Deserialize, Serialize}; +use thiserror::Error; -use crate::http; +use crate::url::ApiUrl; +use crate::{http, DbName, RoleName}; pub struct ComputeCtlApi { pub(crate) api: http::Endpoint, } -// The following article is a stub. -// You can help Wikipedia by filling it out +#[derive(Serialize, Debug)] +pub struct ExtensionInstallRequest { + pub extension: &'static str, + pub database: DbName, + pub version: &'static str, +} + +#[derive(Serialize, Debug)] +pub struct SetRoleGrantsRequest { + pub database: DbName, + pub schema: &'static str, + pub privileges: Vec, + pub role: RoleName, +} + +#[derive(Clone, Debug, Deserialize)] +pub struct ExtensionInstallResponse {} + +#[derive(Clone, Debug, Deserialize)] +pub struct SetRoleGrantsResponse {} + +#[derive(Debug, Serialize, Deserialize, Clone, Copy)] +#[serde(rename_all = "UPPERCASE")] +pub enum Privilege { + Usage, +} + +#[derive(Error, Debug)] +pub enum ComputeCtlError { + #[error("connection error: {0}")] + ConnectionError(#[source] reqwest_middleware::Error), + #[error("request error [{status}]: {body:?}")] + RequestError { + status: StatusCode, + body: Option, + }, + #[error("response parsing error: {0}")] + ResonseError(#[source] reqwest::Error), +} impl ComputeCtlApi { - pub async fn install_extension(&self, db: &str, ext: &str) -> anyhow::Result<()> { - self.api - .request_with_url(Method::POST, |url| { - url.path_segments_mut().push("extension"); - }) - .json(&json! {{ - "extension": ext, - "database": db, - }}) - .send() - .await - .context("connection error")? - .error_for_status() - .context("api error")?; - - Ok(()) + pub async fn install_extension( + &self, + req: &ExtensionInstallRequest, + ) -> Result { + self.generic_request(req, Method::POST, |url| { + url.path_segments_mut().push("extensions"); + }) + .await } - pub async fn grant_role(&self, db: &str, role: &str, schema: &str) -> anyhow::Result<()> { - self.api - .request_with_url(Method::POST, |url| { - url.path_segments_mut().push("grant"); - }) - .json(&json! {{ - "schema": schema, - "role": role, - "database": db, - }}) + pub async fn grant_role( + &self, + req: &SetRoleGrantsRequest, + ) -> Result { + self.generic_request(req, Method::POST, |url| { + url.path_segments_mut().push("grants"); + }) + .await + } + + async fn generic_request( + &self, + req: &Req, + method: Method, + url: impl for<'a> FnOnce(&'a mut ApiUrl), + ) -> Result + where + Req: Serialize, + Resp: DeserializeOwned, + { + let resp = self + .api + .request_with_url(method, url) + .json(req) .send() .await - .context("connection error")? - .error_for_status() - .context("api error")?; + .map_err(ComputeCtlError::ConnectionError)?; - Ok(()) + let status = resp.status(); + if status.is_client_error() || status.is_server_error() { + let body = resp.json().await.ok(); + return Err(ComputeCtlError::RequestError { status, body }); + } + + resp.json().await.map_err(ComputeCtlError::ResonseError) } } diff --git a/proxy/src/serverless/backend.rs b/proxy/src/serverless/backend.rs index 5d5fa1f05a..9d1f070b91 100644 --- a/proxy/src/serverless/backend.rs +++ b/proxy/src/serverless/backend.rs @@ -13,10 +13,13 @@ use tracing::{debug, info}; use super::conn_pool::{poll_client, Client, ConnInfo, GlobalConnPool}; use super::http_conn_pool::{self, poll_http2_client}; -use super::local_conn_pool::{self, LocalClient, LocalConnPool}; +use super::local_conn_pool::{self, LocalClient, LocalConnPool, EXT_NAME, EXT_SCHEMA, EXT_VERSION}; use crate::auth::backend::local::StaticAuthRules; use crate::auth::backend::{ComputeCredentials, ComputeUserInfo}; use crate::auth::{self, check_peer_addr_is_in_list, AuthError}; +use crate::compute_ctl::{ + ComputeCtlError, ExtensionInstallRequest, Privilege, SetRoleGrantsRequest, +}; use crate::config::ProxyConfig; use crate::context::RequestMonitoring; use crate::control_plane::errors::{GetAuthInfoError, WakeComputeError}; @@ -266,15 +269,24 @@ impl PoolingBackend { if !self.local_pool.initialized(&conn_info) { local_backend .compute_ctl - .install_extension(&conn_info.dbname, "pg_session_jwt") - .await - .map_err(|_| HttpConnError::LocalProxyConnectionError(todo!()))?; + .install_extension(&ExtensionInstallRequest { + extension: EXT_NAME, + database: conn_info.dbname.clone(), + // todo: move to const or config + version: EXT_VERSION, + }) + .await?; local_backend .compute_ctl - .grant_role(&conn_info.dbname, &conn_info.user_info.user, "auth") - .await - .map_err(|_| HttpConnError::LocalProxyConnectionError(todo!()))?; + .grant_role(&SetRoleGrantsRequest { + // fixed for pg_session_jwt + schema: EXT_SCHEMA, + privileges: vec![Privilege::Usage], + database: conn_info.dbname.clone(), + role: conn_info.user_info.user.clone(), + }) + .await?; self.local_pool.set_initialized(&conn_info); } @@ -349,6 +361,8 @@ pub(crate) enum HttpConnError { #[error("could not parse JWT payload")] JwtPayloadError(serde_json::Error), + #[error("could not install extension: {0}")] + ComputeCtl(#[from] ComputeCtlError), #[error("could not get auth info")] GetAuthInfo(#[from] GetAuthInfoError), #[error("user not authenticated")] @@ -373,6 +387,7 @@ impl ReportableError for HttpConnError { HttpConnError::ConnectionClosedAbruptly(_) => ErrorKind::Compute, HttpConnError::PostgresConnectionError(p) => p.get_error_kind(), HttpConnError::LocalProxyConnectionError(_) => ErrorKind::Compute, + HttpConnError::ComputeCtl(_) => ErrorKind::Service, HttpConnError::JwtPayloadError(_) => ErrorKind::User, HttpConnError::GetAuthInfo(a) => a.get_error_kind(), HttpConnError::AuthError(a) => a.get_error_kind(), @@ -388,6 +403,7 @@ impl UserFacingError for HttpConnError { HttpConnError::ConnectionClosedAbruptly(_) => self.to_string(), HttpConnError::PostgresConnectionError(p) => p.to_string(), HttpConnError::LocalProxyConnectionError(p) => p.to_string(), + HttpConnError::ComputeCtl(_) => "could not set up the JWT authorization database extension".to_string(), HttpConnError::JwtPayloadError(p) => p.to_string(), HttpConnError::GetAuthInfo(c) => c.to_string_client(), HttpConnError::AuthError(c) => c.to_string_client(), @@ -404,6 +420,7 @@ impl CouldRetry for HttpConnError { match self { HttpConnError::PostgresConnectionError(e) => e.could_retry(), HttpConnError::LocalProxyConnectionError(e) => e.could_retry(), + HttpConnError::ComputeCtl(_) => false, HttpConnError::ConnectionClosedAbruptly(_) => false, HttpConnError::JwtPayloadError(_) => false, HttpConnError::GetAuthInfo(_) => false, diff --git a/proxy/src/serverless/local_conn_pool.rs b/proxy/src/serverless/local_conn_pool.rs index 78bef1aa59..976c2592b3 100644 --- a/proxy/src/serverless/local_conn_pool.rs +++ b/proxy/src/serverless/local_conn_pool.rs @@ -38,6 +38,10 @@ use crate::metrics::Metrics; use crate::usage_metrics::{Ids, MetricCounter, USAGE_METRICS}; use crate::{DbName, RoleName}; +pub(crate) const EXT_NAME: &str = "pg_session_jwt"; +pub(crate) const EXT_VERSION: &str = "0.1.1"; +pub(crate) const EXT_SCHEMA: &str = "auth"; + struct ConnPoolEntry { conn: ClientInner, _last_access: std::time::Instant,