mirror of
https://github.com/neondatabase/neon.git
synced 2026-01-15 01:12:56 +00:00
refactor
This commit is contained in:
@@ -1,51 +1,101 @@
|
||||
use anyhow::Context;
|
||||
use hyper::Method;
|
||||
use typed_json::json;
|
||||
use compute_api::responses::GenericAPIError;
|
||||
use hyper::{Method, StatusCode};
|
||||
use serde::de::DeserializeOwned;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use thiserror::Error;
|
||||
|
||||
use crate::http;
|
||||
use crate::url::ApiUrl;
|
||||
use crate::{http, DbName, RoleName};
|
||||
|
||||
pub struct ComputeCtlApi {
|
||||
pub(crate) api: http::Endpoint,
|
||||
}
|
||||
|
||||
// The following article is a stub.
|
||||
// You can help Wikipedia by filling it out
|
||||
#[derive(Serialize, Debug)]
|
||||
pub struct ExtensionInstallRequest {
|
||||
pub extension: &'static str,
|
||||
pub database: DbName,
|
||||
pub version: &'static str,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Debug)]
|
||||
pub struct SetRoleGrantsRequest {
|
||||
pub database: DbName,
|
||||
pub schema: &'static str,
|
||||
pub privileges: Vec<Privilege>,
|
||||
pub role: RoleName,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize)]
|
||||
pub struct ExtensionInstallResponse {}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize)]
|
||||
pub struct SetRoleGrantsResponse {}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Clone, Copy)]
|
||||
#[serde(rename_all = "UPPERCASE")]
|
||||
pub enum Privilege {
|
||||
Usage,
|
||||
}
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
pub enum ComputeCtlError {
|
||||
#[error("connection error: {0}")]
|
||||
ConnectionError(#[source] reqwest_middleware::Error),
|
||||
#[error("request error [{status}]: {body:?}")]
|
||||
RequestError {
|
||||
status: StatusCode,
|
||||
body: Option<GenericAPIError>,
|
||||
},
|
||||
#[error("response parsing error: {0}")]
|
||||
ResonseError(#[source] reqwest::Error),
|
||||
}
|
||||
|
||||
impl ComputeCtlApi {
|
||||
pub async fn install_extension(&self, db: &str, ext: &str) -> anyhow::Result<()> {
|
||||
self.api
|
||||
.request_with_url(Method::POST, |url| {
|
||||
url.path_segments_mut().push("extension");
|
||||
})
|
||||
.json(&json! {{
|
||||
"extension": ext,
|
||||
"database": db,
|
||||
}})
|
||||
.send()
|
||||
.await
|
||||
.context("connection error")?
|
||||
.error_for_status()
|
||||
.context("api error")?;
|
||||
|
||||
Ok(())
|
||||
pub async fn install_extension(
|
||||
&self,
|
||||
req: &ExtensionInstallRequest,
|
||||
) -> Result<ExtensionInstallResponse, ComputeCtlError> {
|
||||
self.generic_request(req, Method::POST, |url| {
|
||||
url.path_segments_mut().push("extensions");
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn grant_role(&self, db: &str, role: &str, schema: &str) -> anyhow::Result<()> {
|
||||
self.api
|
||||
.request_with_url(Method::POST, |url| {
|
||||
url.path_segments_mut().push("grant");
|
||||
})
|
||||
.json(&json! {{
|
||||
"schema": schema,
|
||||
"role": role,
|
||||
"database": db,
|
||||
}})
|
||||
pub async fn grant_role(
|
||||
&self,
|
||||
req: &SetRoleGrantsRequest,
|
||||
) -> Result<SetRoleGrantsResponse, ComputeCtlError> {
|
||||
self.generic_request(req, Method::POST, |url| {
|
||||
url.path_segments_mut().push("grants");
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
async fn generic_request<Req, Resp>(
|
||||
&self,
|
||||
req: &Req,
|
||||
method: Method,
|
||||
url: impl for<'a> FnOnce(&'a mut ApiUrl),
|
||||
) -> Result<Resp, ComputeCtlError>
|
||||
where
|
||||
Req: Serialize,
|
||||
Resp: DeserializeOwned,
|
||||
{
|
||||
let resp = self
|
||||
.api
|
||||
.request_with_url(method, url)
|
||||
.json(req)
|
||||
.send()
|
||||
.await
|
||||
.context("connection error")?
|
||||
.error_for_status()
|
||||
.context("api error")?;
|
||||
.map_err(ComputeCtlError::ConnectionError)?;
|
||||
|
||||
Ok(())
|
||||
let status = resp.status();
|
||||
if status.is_client_error() || status.is_server_error() {
|
||||
let body = resp.json().await.ok();
|
||||
return Err(ComputeCtlError::RequestError { status, body });
|
||||
}
|
||||
|
||||
resp.json().await.map_err(ComputeCtlError::ResonseError)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -13,10 +13,13 @@ use tracing::{debug, info};
|
||||
|
||||
use super::conn_pool::{poll_client, Client, ConnInfo, GlobalConnPool};
|
||||
use super::http_conn_pool::{self, poll_http2_client};
|
||||
use super::local_conn_pool::{self, LocalClient, LocalConnPool};
|
||||
use super::local_conn_pool::{self, LocalClient, LocalConnPool, EXT_NAME, EXT_SCHEMA, EXT_VERSION};
|
||||
use crate::auth::backend::local::StaticAuthRules;
|
||||
use crate::auth::backend::{ComputeCredentials, ComputeUserInfo};
|
||||
use crate::auth::{self, check_peer_addr_is_in_list, AuthError};
|
||||
use crate::compute_ctl::{
|
||||
ComputeCtlError, ExtensionInstallRequest, Privilege, SetRoleGrantsRequest,
|
||||
};
|
||||
use crate::config::ProxyConfig;
|
||||
use crate::context::RequestMonitoring;
|
||||
use crate::control_plane::errors::{GetAuthInfoError, WakeComputeError};
|
||||
@@ -266,15 +269,24 @@ impl PoolingBackend {
|
||||
if !self.local_pool.initialized(&conn_info) {
|
||||
local_backend
|
||||
.compute_ctl
|
||||
.install_extension(&conn_info.dbname, "pg_session_jwt")
|
||||
.await
|
||||
.map_err(|_| HttpConnError::LocalProxyConnectionError(todo!()))?;
|
||||
.install_extension(&ExtensionInstallRequest {
|
||||
extension: EXT_NAME,
|
||||
database: conn_info.dbname.clone(),
|
||||
// todo: move to const or config
|
||||
version: EXT_VERSION,
|
||||
})
|
||||
.await?;
|
||||
|
||||
local_backend
|
||||
.compute_ctl
|
||||
.grant_role(&conn_info.dbname, &conn_info.user_info.user, "auth")
|
||||
.await
|
||||
.map_err(|_| HttpConnError::LocalProxyConnectionError(todo!()))?;
|
||||
.grant_role(&SetRoleGrantsRequest {
|
||||
// fixed for pg_session_jwt
|
||||
schema: EXT_SCHEMA,
|
||||
privileges: vec![Privilege::Usage],
|
||||
database: conn_info.dbname.clone(),
|
||||
role: conn_info.user_info.user.clone(),
|
||||
})
|
||||
.await?;
|
||||
|
||||
self.local_pool.set_initialized(&conn_info);
|
||||
}
|
||||
@@ -349,6 +361,8 @@ pub(crate) enum HttpConnError {
|
||||
#[error("could not parse JWT payload")]
|
||||
JwtPayloadError(serde_json::Error),
|
||||
|
||||
#[error("could not install extension: {0}")]
|
||||
ComputeCtl(#[from] ComputeCtlError),
|
||||
#[error("could not get auth info")]
|
||||
GetAuthInfo(#[from] GetAuthInfoError),
|
||||
#[error("user not authenticated")]
|
||||
@@ -373,6 +387,7 @@ impl ReportableError for HttpConnError {
|
||||
HttpConnError::ConnectionClosedAbruptly(_) => ErrorKind::Compute,
|
||||
HttpConnError::PostgresConnectionError(p) => p.get_error_kind(),
|
||||
HttpConnError::LocalProxyConnectionError(_) => ErrorKind::Compute,
|
||||
HttpConnError::ComputeCtl(_) => ErrorKind::Service,
|
||||
HttpConnError::JwtPayloadError(_) => ErrorKind::User,
|
||||
HttpConnError::GetAuthInfo(a) => a.get_error_kind(),
|
||||
HttpConnError::AuthError(a) => a.get_error_kind(),
|
||||
@@ -388,6 +403,7 @@ impl UserFacingError for HttpConnError {
|
||||
HttpConnError::ConnectionClosedAbruptly(_) => self.to_string(),
|
||||
HttpConnError::PostgresConnectionError(p) => p.to_string(),
|
||||
HttpConnError::LocalProxyConnectionError(p) => p.to_string(),
|
||||
HttpConnError::ComputeCtl(_) => "could not set up the JWT authorization database extension".to_string(),
|
||||
HttpConnError::JwtPayloadError(p) => p.to_string(),
|
||||
HttpConnError::GetAuthInfo(c) => c.to_string_client(),
|
||||
HttpConnError::AuthError(c) => c.to_string_client(),
|
||||
@@ -404,6 +420,7 @@ impl CouldRetry for HttpConnError {
|
||||
match self {
|
||||
HttpConnError::PostgresConnectionError(e) => e.could_retry(),
|
||||
HttpConnError::LocalProxyConnectionError(e) => e.could_retry(),
|
||||
HttpConnError::ComputeCtl(_) => false,
|
||||
HttpConnError::ConnectionClosedAbruptly(_) => false,
|
||||
HttpConnError::JwtPayloadError(_) => false,
|
||||
HttpConnError::GetAuthInfo(_) => false,
|
||||
|
||||
@@ -38,6 +38,10 @@ use crate::metrics::Metrics;
|
||||
use crate::usage_metrics::{Ids, MetricCounter, USAGE_METRICS};
|
||||
use crate::{DbName, RoleName};
|
||||
|
||||
pub(crate) const EXT_NAME: &str = "pg_session_jwt";
|
||||
pub(crate) const EXT_VERSION: &str = "0.1.1";
|
||||
pub(crate) const EXT_SCHEMA: &str = "auth";
|
||||
|
||||
struct ConnPoolEntry<C: ClientInnerExt> {
|
||||
conn: ClientInner<C>,
|
||||
_last_access: std::time::Instant,
|
||||
|
||||
Reference in New Issue
Block a user