mirror of
https://github.com/neondatabase/neon.git
synced 2026-01-03 19:42:55 +00:00
local_proxy: integrate with pg_session_jwt extension (#9086)
This commit is contained in:
18
Cargo.lock
generated
18
Cargo.lock
generated
@@ -1820,6 +1820,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b5e6043086bf7973472e0c7dff2142ea0b680d30e18d9cc40f267efbf222bd47"
|
||||
dependencies = [
|
||||
"base16ct 0.2.0",
|
||||
"base64ct",
|
||||
"crypto-bigint 0.5.5",
|
||||
"digest",
|
||||
"ff 0.13.0",
|
||||
@@ -1829,6 +1830,8 @@ dependencies = [
|
||||
"pkcs8 0.10.2",
|
||||
"rand_core 0.6.4",
|
||||
"sec1 0.7.3",
|
||||
"serde_json",
|
||||
"serdect",
|
||||
"subtle",
|
||||
"zeroize",
|
||||
]
|
||||
@@ -4037,6 +4040,8 @@ dependencies = [
|
||||
"bytes",
|
||||
"fallible-iterator",
|
||||
"postgres-protocol",
|
||||
"serde",
|
||||
"serde_json",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -5256,6 +5261,7 @@ dependencies = [
|
||||
"der 0.7.8",
|
||||
"generic-array",
|
||||
"pkcs8 0.10.2",
|
||||
"serdect",
|
||||
"subtle",
|
||||
"zeroize",
|
||||
]
|
||||
@@ -5510,6 +5516,16 @@ dependencies = [
|
||||
"syn 2.0.52",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serdect"
|
||||
version = "0.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a84f14a19e9a014bb9f4512488d9829a68e04ecabffb0f9904cd1ace94598177"
|
||||
dependencies = [
|
||||
"base16ct 0.2.0",
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "sha1"
|
||||
version = "0.10.5"
|
||||
@@ -7302,6 +7318,7 @@ dependencies = [
|
||||
"num-traits",
|
||||
"once_cell",
|
||||
"parquet",
|
||||
"postgres-types",
|
||||
"prettyplease",
|
||||
"proc-macro2",
|
||||
"prost",
|
||||
@@ -7326,6 +7343,7 @@ dependencies = [
|
||||
"time",
|
||||
"time-macros",
|
||||
"tokio",
|
||||
"tokio-postgres",
|
||||
"tokio-stream",
|
||||
"tokio-util",
|
||||
"toml_edit",
|
||||
|
||||
@@ -77,7 +77,7 @@ subtle.workspace = true
|
||||
thiserror.workspace = true
|
||||
tikv-jemallocator.workspace = true
|
||||
tikv-jemalloc-ctl = { workspace = true, features = ["use_std"] }
|
||||
tokio-postgres.workspace = true
|
||||
tokio-postgres = { workspace = true, features = ["with-serde_json-1"] }
|
||||
tokio-postgres-rustls.workspace = true
|
||||
tokio-rustls.workspace = true
|
||||
tokio-util.workspace = true
|
||||
@@ -101,7 +101,7 @@ jose-jwa = "0.1.2"
|
||||
jose-jwk = { version = "0.1.2", features = ["p256", "p384", "rsa"] }
|
||||
signature = "2"
|
||||
ecdsa = "0.16"
|
||||
p256 = "0.13"
|
||||
p256 = { version = "0.13", features = ["jwk"] }
|
||||
rsa = "0.9"
|
||||
|
||||
workspace_hack.workspace = true
|
||||
|
||||
@@ -17,6 +17,8 @@ use crate::{
|
||||
RoleName,
|
||||
};
|
||||
|
||||
use super::ComputeCredentialKeys;
|
||||
|
||||
// TODO(conrad): make these configurable.
|
||||
const CLOCK_SKEW_LEEWAY: Duration = Duration::from_secs(30);
|
||||
const MIN_RENEW: Duration = Duration::from_secs(30);
|
||||
@@ -241,7 +243,7 @@ impl JwkCacheEntryLock {
|
||||
endpoint: EndpointId,
|
||||
role_name: &RoleName,
|
||||
fetch: &F,
|
||||
) -> Result<(), anyhow::Error> {
|
||||
) -> Result<ComputeCredentialKeys, anyhow::Error> {
|
||||
// JWT compact form is defined to be
|
||||
// <B64(Header)> || . || <B64(Payload)> || . || <B64(Signature)>
|
||||
// where Signature = alg(<B64(Header)> || . || <B64(Payload)>);
|
||||
@@ -300,9 +302,9 @@ impl JwkCacheEntryLock {
|
||||
key => bail!("unsupported key type {key:?}"),
|
||||
};
|
||||
|
||||
let payload = base64::decode_config(payload, base64::URL_SAFE_NO_PAD)
|
||||
let payloadb = base64::decode_config(payload, base64::URL_SAFE_NO_PAD)
|
||||
.context("Provided authentication token is not a valid JWT encoding")?;
|
||||
let payload = serde_json::from_slice::<JwtPayload<'_>>(&payload)
|
||||
let payload = serde_json::from_slice::<JwtPayload<'_>>(&payloadb)
|
||||
.context("Provided authentication token is not a valid JWT encoding")?;
|
||||
|
||||
tracing::debug!(?payload, "JWT signature valid with claims");
|
||||
@@ -327,7 +329,7 @@ impl JwkCacheEntryLock {
|
||||
);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
Ok(ComputeCredentialKeys::JwtPayload(payloadb))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -339,7 +341,7 @@ impl JwkCache {
|
||||
role_name: &RoleName,
|
||||
fetch: &F,
|
||||
jwt: &str,
|
||||
) -> Result<(), anyhow::Error> {
|
||||
) -> Result<ComputeCredentialKeys, anyhow::Error> {
|
||||
// try with just a read lock first
|
||||
let key = (endpoint.clone(), role_name.clone());
|
||||
let entry = self.map.get(&key).as_deref().map(Arc::clone);
|
||||
|
||||
@@ -175,10 +175,12 @@ impl ComputeUserInfo {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg_attr(test, derive(Debug))]
|
||||
pub(crate) enum ComputeCredentialKeys {
|
||||
#[cfg(any(test, feature = "testing"))]
|
||||
Password(Vec<u8>),
|
||||
AuthKeys(AuthKeys),
|
||||
JwtPayload(Vec<u8>),
|
||||
None,
|
||||
}
|
||||
|
||||
|
||||
@@ -309,7 +309,7 @@ impl NodeInfo {
|
||||
#[cfg(any(test, feature = "testing"))]
|
||||
ComputeCredentialKeys::Password(password) => self.config.password(password),
|
||||
ComputeCredentialKeys::AuthKeys(auth_keys) => self.config.auth_keys(*auth_keys),
|
||||
ComputeCredentialKeys::None => &mut self.config,
|
||||
ComputeCredentialKeys::JwtPayload(_) | ComputeCredentialKeys::None => &mut self.config,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,10 +3,12 @@ use std::{io, sync::Arc, time::Duration};
|
||||
use async_trait::async_trait;
|
||||
use hyper_util::rt::{TokioExecutor, TokioIo, TokioTimer};
|
||||
use tokio::net::{lookup_host, TcpStream};
|
||||
use tracing::{field::display, info};
|
||||
use tokio_postgres::types::ToSql;
|
||||
use tracing::{debug, field::display, info};
|
||||
|
||||
use crate::{
|
||||
auth::{
|
||||
self,
|
||||
backend::{local::StaticAuthRules, ComputeCredentials, ComputeUserInfo},
|
||||
check_peer_addr_is_in_list, AuthError,
|
||||
},
|
||||
@@ -32,10 +34,12 @@ use crate::{
|
||||
use super::{
|
||||
conn_pool::{poll_client, Client, ConnInfo, GlobalConnPool},
|
||||
http_conn_pool::{self, poll_http2_client},
|
||||
local_conn_pool::{self, LocalClient, LocalConnPool},
|
||||
};
|
||||
|
||||
pub(crate) struct PoolingBackend {
|
||||
pub(crate) http_conn_pool: Arc<super::http_conn_pool::GlobalConnPool>,
|
||||
pub(crate) local_pool: Arc<LocalConnPool<tokio_postgres::Client>>,
|
||||
pub(crate) pool: Arc<GlobalConnPool<tokio_postgres::Client>>,
|
||||
pub(crate) config: &'static ProxyConfig,
|
||||
pub(crate) endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||
@@ -112,7 +116,7 @@ impl PoolingBackend {
|
||||
config: &AuthenticationConfig,
|
||||
user_info: &ComputeUserInfo,
|
||||
jwt: String,
|
||||
) -> Result<(), AuthError> {
|
||||
) -> Result<ComputeCredentials, AuthError> {
|
||||
match &self.config.auth_backend {
|
||||
crate::auth::Backend::ControlPlane(console, ()) => {
|
||||
config
|
||||
@@ -127,13 +131,16 @@ impl PoolingBackend {
|
||||
.await
|
||||
.map_err(|e| AuthError::auth_failed(e.to_string()))?;
|
||||
|
||||
Ok(())
|
||||
Ok(ComputeCredentials {
|
||||
info: user_info.clone(),
|
||||
keys: crate::auth::backend::ComputeCredentialKeys::None,
|
||||
})
|
||||
}
|
||||
crate::auth::Backend::ConsoleRedirect(_, ()) => Err(AuthError::auth_failed(
|
||||
"JWT login over web auth proxy is not supported",
|
||||
)),
|
||||
crate::auth::Backend::Local(_) => {
|
||||
config
|
||||
let keys = config
|
||||
.jwks_cache
|
||||
.check_jwt(
|
||||
ctx,
|
||||
@@ -145,8 +152,10 @@ impl PoolingBackend {
|
||||
.await
|
||||
.map_err(|e| AuthError::auth_failed(e.to_string()))?;
|
||||
|
||||
// todo: rewrite JWT signature with key shared somehow between local proxy and postgres
|
||||
Ok(())
|
||||
Ok(ComputeCredentials {
|
||||
info: user_info.clone(),
|
||||
keys,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -231,6 +240,77 @@ impl PoolingBackend {
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
/// Connect to postgres over localhost.
|
||||
///
|
||||
/// We expect postgres to be started here, so we won't do any retries.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Panics if called with a non-local_proxy backend.
|
||||
#[tracing::instrument(fields(pid = tracing::field::Empty), skip_all)]
|
||||
pub(crate) async fn connect_to_local_postgres(
|
||||
&self,
|
||||
ctx: &RequestMonitoring,
|
||||
conn_info: ConnInfo,
|
||||
) -> Result<LocalClient<tokio_postgres::Client>, HttpConnError> {
|
||||
if let Some(client) = self.local_pool.get(ctx, &conn_info)? {
|
||||
return Ok(client);
|
||||
}
|
||||
|
||||
let conn_id = uuid::Uuid::new_v4();
|
||||
tracing::Span::current().record("conn_id", display(conn_id));
|
||||
info!(%conn_id, "local_pool: opening a new connection '{conn_info}'");
|
||||
|
||||
let mut node_info = match &self.config.auth_backend {
|
||||
auth::Backend::ControlPlane(_, ()) | auth::Backend::ConsoleRedirect(_, ()) => {
|
||||
unreachable!("only local_proxy can connect to local postgres")
|
||||
}
|
||||
auth::Backend::Local(local) => local.node_info.clone(),
|
||||
};
|
||||
|
||||
let config = node_info
|
||||
.config
|
||||
.user(&conn_info.user_info.user)
|
||||
.dbname(&conn_info.dbname);
|
||||
|
||||
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
|
||||
let (client, connection) = config.connect(tokio_postgres::NoTls).await?;
|
||||
drop(pause);
|
||||
|
||||
tracing::Span::current().record("pid", tracing::field::display(client.get_process_id()));
|
||||
|
||||
let handle = local_conn_pool::poll_client(
|
||||
self.local_pool.clone(),
|
||||
ctx,
|
||||
conn_info,
|
||||
client,
|
||||
connection,
|
||||
conn_id,
|
||||
node_info.aux.clone(),
|
||||
);
|
||||
|
||||
let kid = handle.get_client().get_process_id() as i64;
|
||||
let jwk = p256::PublicKey::from(handle.key().verifying_key()).to_jwk();
|
||||
|
||||
debug!(kid, ?jwk, "setting up backend session state");
|
||||
|
||||
// initiates the auth session
|
||||
handle
|
||||
.get_client()
|
||||
.query(
|
||||
"select auth.init($1, $2);",
|
||||
&[
|
||||
&kid as &(dyn ToSql + Sync),
|
||||
&tokio_postgres::types::Json(jwk),
|
||||
],
|
||||
)
|
||||
.await?;
|
||||
|
||||
info!(?kid, "backend session state init");
|
||||
|
||||
Ok(handle)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
@@ -241,6 +321,8 @@ pub(crate) enum HttpConnError {
|
||||
PostgresConnectionError(#[from] tokio_postgres::Error),
|
||||
#[error("could not connection to local-proxy in compute")]
|
||||
LocalProxyConnectionError(#[from] LocalProxyConnError),
|
||||
#[error("could not parse JWT payload")]
|
||||
JwtPayloadError(serde_json::Error),
|
||||
|
||||
#[error("could not get auth info")]
|
||||
GetAuthInfo(#[from] GetAuthInfoError),
|
||||
@@ -266,6 +348,7 @@ impl ReportableError for HttpConnError {
|
||||
HttpConnError::ConnectionClosedAbruptly(_) => ErrorKind::Compute,
|
||||
HttpConnError::PostgresConnectionError(p) => p.get_error_kind(),
|
||||
HttpConnError::LocalProxyConnectionError(_) => ErrorKind::Compute,
|
||||
HttpConnError::JwtPayloadError(_) => ErrorKind::User,
|
||||
HttpConnError::GetAuthInfo(a) => a.get_error_kind(),
|
||||
HttpConnError::AuthError(a) => a.get_error_kind(),
|
||||
HttpConnError::WakeCompute(w) => w.get_error_kind(),
|
||||
@@ -280,6 +363,7 @@ impl UserFacingError for HttpConnError {
|
||||
HttpConnError::ConnectionClosedAbruptly(_) => self.to_string(),
|
||||
HttpConnError::PostgresConnectionError(p) => p.to_string(),
|
||||
HttpConnError::LocalProxyConnectionError(p) => p.to_string(),
|
||||
HttpConnError::JwtPayloadError(p) => p.to_string(),
|
||||
HttpConnError::GetAuthInfo(c) => c.to_string_client(),
|
||||
HttpConnError::AuthError(c) => c.to_string_client(),
|
||||
HttpConnError::WakeCompute(c) => c.to_string_client(),
|
||||
@@ -296,6 +380,7 @@ impl CouldRetry for HttpConnError {
|
||||
HttpConnError::PostgresConnectionError(e) => e.could_retry(),
|
||||
HttpConnError::LocalProxyConnectionError(e) => e.could_retry(),
|
||||
HttpConnError::ConnectionClosedAbruptly(_) => false,
|
||||
HttpConnError::JwtPayloadError(_) => false,
|
||||
HttpConnError::GetAuthInfo(_) => false,
|
||||
HttpConnError::AuthError(_) => false,
|
||||
HttpConnError::WakeCompute(_) => false,
|
||||
|
||||
544
proxy/src/serverless/local_conn_pool.rs
Normal file
544
proxy/src/serverless/local_conn_pool.rs
Normal file
@@ -0,0 +1,544 @@
|
||||
use futures::{future::poll_fn, Future};
|
||||
use jose_jwk::jose_b64::base64ct::{Base64UrlUnpadded, Encoding};
|
||||
use p256::ecdsa::{Signature, SigningKey};
|
||||
use parking_lot::RwLock;
|
||||
use rand::rngs::OsRng;
|
||||
use serde_json::Value;
|
||||
use signature::Signer;
|
||||
use std::task::{ready, Poll};
|
||||
use std::{collections::HashMap, pin::pin, sync::Arc, sync::Weak, time::Duration};
|
||||
use tokio::time::Instant;
|
||||
use tokio_postgres::tls::NoTlsStream;
|
||||
use tokio_postgres::types::ToSql;
|
||||
use tokio_postgres::{AsyncMessage, ReadyForQueryStatus, Socket};
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use typed_json::json;
|
||||
|
||||
use crate::control_plane::messages::{ColdStartInfo, MetricsAuxInfo};
|
||||
use crate::metrics::Metrics;
|
||||
use crate::usage_metrics::{Ids, MetricCounter, USAGE_METRICS};
|
||||
use crate::{context::RequestMonitoring, DbName, RoleName};
|
||||
|
||||
use tracing::{debug, error, warn, Span};
|
||||
use tracing::{info, info_span, Instrument};
|
||||
|
||||
use super::backend::HttpConnError;
|
||||
use super::conn_pool::{ClientInnerExt, ConnInfo};
|
||||
|
||||
struct ConnPoolEntry<C: ClientInnerExt> {
|
||||
conn: ClientInner<C>,
|
||||
_last_access: std::time::Instant,
|
||||
}
|
||||
|
||||
// /// key id for the pg_session_jwt state
|
||||
// static PG_SESSION_JWT_KID: AtomicU64 = AtomicU64::new(1);
|
||||
|
||||
// Per-endpoint connection pool, (dbname, username) -> DbUserConnPool
|
||||
// Number of open connections is limited by the `max_conns_per_endpoint`.
|
||||
pub(crate) struct EndpointConnPool<C: ClientInnerExt> {
|
||||
pools: HashMap<(DbName, RoleName), DbUserConnPool<C>>,
|
||||
total_conns: usize,
|
||||
max_conns: usize,
|
||||
global_pool_size_max_conns: usize,
|
||||
}
|
||||
|
||||
impl<C: ClientInnerExt> EndpointConnPool<C> {
|
||||
fn get_conn_entry(&mut self, db_user: (DbName, RoleName)) -> Option<ConnPoolEntry<C>> {
|
||||
let Self {
|
||||
pools, total_conns, ..
|
||||
} = self;
|
||||
pools
|
||||
.get_mut(&db_user)
|
||||
.and_then(|pool_entries| pool_entries.get_conn_entry(total_conns))
|
||||
}
|
||||
|
||||
fn remove_client(&mut self, db_user: (DbName, RoleName), conn_id: uuid::Uuid) -> bool {
|
||||
let Self {
|
||||
pools, total_conns, ..
|
||||
} = self;
|
||||
if let Some(pool) = pools.get_mut(&db_user) {
|
||||
let old_len = pool.conns.len();
|
||||
pool.conns.retain(|conn| conn.conn.conn_id != conn_id);
|
||||
let new_len = pool.conns.len();
|
||||
let removed = old_len - new_len;
|
||||
if removed > 0 {
|
||||
Metrics::get()
|
||||
.proxy
|
||||
.http_pool_opened_connections
|
||||
.get_metric()
|
||||
.dec_by(removed as i64);
|
||||
}
|
||||
*total_conns -= removed;
|
||||
removed > 0
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
fn put(pool: &RwLock<Self>, conn_info: &ConnInfo, client: ClientInner<C>) {
|
||||
let conn_id = client.conn_id;
|
||||
|
||||
if client.is_closed() {
|
||||
info!(%conn_id, "local_pool: throwing away connection '{conn_info}' because connection is closed");
|
||||
return;
|
||||
}
|
||||
let global_max_conn = pool.read().global_pool_size_max_conns;
|
||||
if pool.read().total_conns >= global_max_conn {
|
||||
info!(%conn_id, "local_pool: throwing away connection '{conn_info}' because pool is full");
|
||||
return;
|
||||
}
|
||||
|
||||
// return connection to the pool
|
||||
let mut returned = false;
|
||||
let mut per_db_size = 0;
|
||||
let total_conns = {
|
||||
let mut pool = pool.write();
|
||||
|
||||
if pool.total_conns < pool.max_conns {
|
||||
let pool_entries = pool.pools.entry(conn_info.db_and_user()).or_default();
|
||||
pool_entries.conns.push(ConnPoolEntry {
|
||||
conn: client,
|
||||
_last_access: std::time::Instant::now(),
|
||||
});
|
||||
|
||||
returned = true;
|
||||
per_db_size = pool_entries.conns.len();
|
||||
|
||||
pool.total_conns += 1;
|
||||
Metrics::get()
|
||||
.proxy
|
||||
.http_pool_opened_connections
|
||||
.get_metric()
|
||||
.inc();
|
||||
}
|
||||
|
||||
pool.total_conns
|
||||
};
|
||||
|
||||
// do logging outside of the mutex
|
||||
if returned {
|
||||
info!(%conn_id, "local_pool: returning connection '{conn_info}' back to the pool, total_conns={total_conns}, for this (db, user)={per_db_size}");
|
||||
} else {
|
||||
info!(%conn_id, "local_pool: throwing away connection '{conn_info}' because pool is full, total_conns={total_conns}");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<C: ClientInnerExt> Drop for EndpointConnPool<C> {
|
||||
fn drop(&mut self) {
|
||||
if self.total_conns > 0 {
|
||||
Metrics::get()
|
||||
.proxy
|
||||
.http_pool_opened_connections
|
||||
.get_metric()
|
||||
.dec_by(self.total_conns as i64);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) struct DbUserConnPool<C: ClientInnerExt> {
|
||||
conns: Vec<ConnPoolEntry<C>>,
|
||||
}
|
||||
|
||||
impl<C: ClientInnerExt> Default for DbUserConnPool<C> {
|
||||
fn default() -> Self {
|
||||
Self { conns: Vec::new() }
|
||||
}
|
||||
}
|
||||
|
||||
impl<C: ClientInnerExt> DbUserConnPool<C> {
|
||||
fn clear_closed_clients(&mut self, conns: &mut usize) -> usize {
|
||||
let old_len = self.conns.len();
|
||||
|
||||
self.conns.retain(|conn| !conn.conn.is_closed());
|
||||
|
||||
let new_len = self.conns.len();
|
||||
let removed = old_len - new_len;
|
||||
*conns -= removed;
|
||||
removed
|
||||
}
|
||||
|
||||
fn get_conn_entry(&mut self, conns: &mut usize) -> Option<ConnPoolEntry<C>> {
|
||||
let mut removed = self.clear_closed_clients(conns);
|
||||
let conn = self.conns.pop();
|
||||
if conn.is_some() {
|
||||
*conns -= 1;
|
||||
removed += 1;
|
||||
}
|
||||
Metrics::get()
|
||||
.proxy
|
||||
.http_pool_opened_connections
|
||||
.get_metric()
|
||||
.dec_by(removed as i64);
|
||||
conn
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) struct LocalConnPool<C: ClientInnerExt> {
|
||||
global_pool: RwLock<EndpointConnPool<C>>,
|
||||
|
||||
config: &'static crate::config::HttpConfig,
|
||||
}
|
||||
|
||||
impl<C: ClientInnerExt> LocalConnPool<C> {
|
||||
pub(crate) fn new(config: &'static crate::config::HttpConfig) -> Arc<Self> {
|
||||
Arc::new(Self {
|
||||
global_pool: RwLock::new(EndpointConnPool {
|
||||
pools: HashMap::new(),
|
||||
total_conns: 0,
|
||||
max_conns: config.pool_options.max_conns_per_endpoint,
|
||||
global_pool_size_max_conns: config.pool_options.max_total_conns,
|
||||
}),
|
||||
config,
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) fn get_idle_timeout(&self) -> Duration {
|
||||
self.config.pool_options.idle_timeout
|
||||
}
|
||||
|
||||
// pub(crate) fn shutdown(&self) {
|
||||
// let mut pool = self.global_pool.write();
|
||||
// pool.pools.clear();
|
||||
// pool.total_conns = 0;
|
||||
// }
|
||||
|
||||
pub(crate) fn get(
|
||||
self: &Arc<Self>,
|
||||
ctx: &RequestMonitoring,
|
||||
conn_info: &ConnInfo,
|
||||
) -> Result<Option<LocalClient<C>>, HttpConnError> {
|
||||
let mut client: Option<ClientInner<C>> = None;
|
||||
if let Some(entry) = self
|
||||
.global_pool
|
||||
.write()
|
||||
.get_conn_entry(conn_info.db_and_user())
|
||||
{
|
||||
client = Some(entry.conn);
|
||||
}
|
||||
|
||||
// ok return cached connection if found and establish a new one otherwise
|
||||
if let Some(client) = client {
|
||||
if client.is_closed() {
|
||||
info!("local_pool: cached connection '{conn_info}' is closed, opening a new one");
|
||||
return Ok(None);
|
||||
}
|
||||
tracing::Span::current().record("conn_id", tracing::field::display(client.conn_id));
|
||||
tracing::Span::current().record(
|
||||
"pid",
|
||||
tracing::field::display(client.inner.get_process_id()),
|
||||
);
|
||||
info!(
|
||||
cold_start_info = ColdStartInfo::HttpPoolHit.as_str(),
|
||||
"local_pool: reusing connection '{conn_info}'"
|
||||
);
|
||||
client.session.send(ctx.session_id())?;
|
||||
ctx.set_cold_start_info(ColdStartInfo::HttpPoolHit);
|
||||
ctx.success();
|
||||
return Ok(Some(LocalClient::new(
|
||||
client,
|
||||
conn_info.clone(),
|
||||
Arc::downgrade(self),
|
||||
)));
|
||||
}
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn poll_client(
|
||||
global_pool: Arc<LocalConnPool<tokio_postgres::Client>>,
|
||||
ctx: &RequestMonitoring,
|
||||
conn_info: ConnInfo,
|
||||
client: tokio_postgres::Client,
|
||||
mut connection: tokio_postgres::Connection<Socket, NoTlsStream>,
|
||||
conn_id: uuid::Uuid,
|
||||
aux: MetricsAuxInfo,
|
||||
) -> LocalClient<tokio_postgres::Client> {
|
||||
let conn_gauge = Metrics::get().proxy.db_connections.guard(ctx.protocol());
|
||||
let mut session_id = ctx.session_id();
|
||||
let (tx, mut rx) = tokio::sync::watch::channel(session_id);
|
||||
|
||||
let span = info_span!(parent: None, "connection", %conn_id);
|
||||
let cold_start_info = ctx.cold_start_info();
|
||||
span.in_scope(|| {
|
||||
info!(cold_start_info = cold_start_info.as_str(), %conn_info, %session_id, "new connection");
|
||||
});
|
||||
let pool = Arc::downgrade(&global_pool);
|
||||
let pool_clone = pool.clone();
|
||||
|
||||
let db_user = conn_info.db_and_user();
|
||||
let idle = global_pool.get_idle_timeout();
|
||||
let cancel = CancellationToken::new();
|
||||
let cancelled = cancel.clone().cancelled_owned();
|
||||
|
||||
tokio::spawn(
|
||||
async move {
|
||||
let _conn_gauge = conn_gauge;
|
||||
let mut idle_timeout = pin!(tokio::time::sleep(idle));
|
||||
let mut cancelled = pin!(cancelled);
|
||||
|
||||
poll_fn(move |cx| {
|
||||
if cancelled.as_mut().poll(cx).is_ready() {
|
||||
info!("connection dropped");
|
||||
return Poll::Ready(())
|
||||
}
|
||||
|
||||
match rx.has_changed() {
|
||||
Ok(true) => {
|
||||
session_id = *rx.borrow_and_update();
|
||||
info!(%session_id, "changed session");
|
||||
idle_timeout.as_mut().reset(Instant::now() + idle);
|
||||
}
|
||||
Err(_) => {
|
||||
info!("connection dropped");
|
||||
return Poll::Ready(())
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
// 5 minute idle connection timeout
|
||||
if idle_timeout.as_mut().poll(cx).is_ready() {
|
||||
idle_timeout.as_mut().reset(Instant::now() + idle);
|
||||
info!("connection idle");
|
||||
if let Some(pool) = pool.clone().upgrade() {
|
||||
// remove client from pool - should close the connection if it's idle.
|
||||
// does nothing if the client is currently checked-out and in-use
|
||||
if pool.global_pool.write().remove_client(db_user.clone(), conn_id) {
|
||||
info!("idle connection removed");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
loop {
|
||||
let message = ready!(connection.poll_message(cx));
|
||||
|
||||
match message {
|
||||
Some(Ok(AsyncMessage::Notice(notice))) => {
|
||||
info!(%session_id, "notice: {}", notice);
|
||||
}
|
||||
Some(Ok(AsyncMessage::Notification(notif))) => {
|
||||
warn!(%session_id, pid = notif.process_id(), channel = notif.channel(), "notification received");
|
||||
}
|
||||
Some(Ok(_)) => {
|
||||
warn!(%session_id, "unknown message");
|
||||
}
|
||||
Some(Err(e)) => {
|
||||
error!(%session_id, "connection error: {}", e);
|
||||
break
|
||||
}
|
||||
None => {
|
||||
info!("connection closed");
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// remove from connection pool
|
||||
if let Some(pool) = pool.clone().upgrade() {
|
||||
if pool.global_pool.write().remove_client(db_user.clone(), conn_id) {
|
||||
info!("closed connection removed");
|
||||
}
|
||||
}
|
||||
|
||||
Poll::Ready(())
|
||||
}).await;
|
||||
|
||||
}
|
||||
.instrument(span));
|
||||
|
||||
let key = SigningKey::random(&mut OsRng);
|
||||
|
||||
let inner = ClientInner {
|
||||
inner: client,
|
||||
session: tx,
|
||||
cancel,
|
||||
aux,
|
||||
conn_id,
|
||||
key,
|
||||
jti: 0,
|
||||
};
|
||||
LocalClient::new(inner, conn_info, pool_clone)
|
||||
}
|
||||
|
||||
struct ClientInner<C: ClientInnerExt> {
|
||||
inner: C,
|
||||
session: tokio::sync::watch::Sender<uuid::Uuid>,
|
||||
cancel: CancellationToken,
|
||||
aux: MetricsAuxInfo,
|
||||
conn_id: uuid::Uuid,
|
||||
|
||||
// needed for pg_session_jwt state
|
||||
key: SigningKey,
|
||||
jti: u64,
|
||||
}
|
||||
|
||||
impl<C: ClientInnerExt> Drop for ClientInner<C> {
|
||||
fn drop(&mut self) {
|
||||
// on client drop, tell the conn to shut down
|
||||
self.cancel.cancel();
|
||||
}
|
||||
}
|
||||
|
||||
impl<C: ClientInnerExt> ClientInner<C> {
|
||||
pub(crate) fn is_closed(&self) -> bool {
|
||||
self.inner.is_closed()
|
||||
}
|
||||
}
|
||||
|
||||
impl<C: ClientInnerExt> LocalClient<C> {
|
||||
pub(crate) fn metrics(&self) -> Arc<MetricCounter> {
|
||||
let aux = &self.inner.as_ref().unwrap().aux;
|
||||
USAGE_METRICS.register(Ids {
|
||||
endpoint_id: aux.endpoint_id,
|
||||
branch_id: aux.branch_id,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) struct LocalClient<C: ClientInnerExt> {
|
||||
span: Span,
|
||||
inner: Option<ClientInner<C>>,
|
||||
conn_info: ConnInfo,
|
||||
pool: Weak<LocalConnPool<C>>,
|
||||
}
|
||||
|
||||
pub(crate) struct Discard<'a, C: ClientInnerExt> {
|
||||
conn_info: &'a ConnInfo,
|
||||
pool: &'a mut Weak<LocalConnPool<C>>,
|
||||
}
|
||||
|
||||
impl<C: ClientInnerExt> LocalClient<C> {
|
||||
pub(self) fn new(
|
||||
inner: ClientInner<C>,
|
||||
conn_info: ConnInfo,
|
||||
pool: Weak<LocalConnPool<C>>,
|
||||
) -> Self {
|
||||
Self {
|
||||
inner: Some(inner),
|
||||
span: Span::current(),
|
||||
conn_info,
|
||||
pool,
|
||||
}
|
||||
}
|
||||
pub(crate) fn inner(&mut self) -> (&mut C, Discard<'_, C>) {
|
||||
let Self {
|
||||
inner,
|
||||
pool,
|
||||
conn_info,
|
||||
span: _,
|
||||
} = self;
|
||||
let inner = inner.as_mut().expect("client inner should not be removed");
|
||||
(&mut inner.inner, Discard { conn_info, pool })
|
||||
}
|
||||
pub(crate) fn key(&self) -> &SigningKey {
|
||||
let inner = &self
|
||||
.inner
|
||||
.as_ref()
|
||||
.expect("client inner should not be removed");
|
||||
&inner.key
|
||||
}
|
||||
}
|
||||
|
||||
impl LocalClient<tokio_postgres::Client> {
|
||||
pub(crate) async fn set_jwt_session(&mut self, payload: &[u8]) -> Result<(), HttpConnError> {
|
||||
let inner = self
|
||||
.inner
|
||||
.as_mut()
|
||||
.expect("client inner should not be removed");
|
||||
inner.jti += 1;
|
||||
|
||||
let kid = inner.inner.get_process_id();
|
||||
let header = json!({"kid":kid}).to_string();
|
||||
|
||||
let mut payload = serde_json::from_slice::<serde_json::Map<String, Value>>(payload)
|
||||
.map_err(HttpConnError::JwtPayloadError)?;
|
||||
payload.insert("jti".to_string(), Value::Number(inner.jti.into()));
|
||||
let payload = Value::Object(payload).to_string();
|
||||
|
||||
debug!(
|
||||
kid,
|
||||
jti = inner.jti,
|
||||
?header,
|
||||
?payload,
|
||||
"signing new ephemeral JWT"
|
||||
);
|
||||
|
||||
let token = sign_jwt(&inner.key, header, payload);
|
||||
|
||||
// initiates the auth session
|
||||
inner.inner.simple_query("discard all").await?;
|
||||
inner
|
||||
.inner
|
||||
.query(
|
||||
"select auth.jwt_session_init($1)",
|
||||
&[&token as &(dyn ToSql + Sync)],
|
||||
)
|
||||
.await?;
|
||||
|
||||
info!(kid, jti = inner.jti, "user session state init");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
fn sign_jwt(sk: &SigningKey, header: String, payload: String) -> String {
|
||||
let header = Base64UrlUnpadded::encode_string(header.as_bytes());
|
||||
let payload = Base64UrlUnpadded::encode_string(payload.as_bytes());
|
||||
|
||||
let message = format!("{header}.{payload}");
|
||||
let sig: Signature = sk.sign(message.as_bytes());
|
||||
let base64_sig = Base64UrlUnpadded::encode_string(&sig.to_bytes());
|
||||
format!("{message}.{base64_sig}")
|
||||
}
|
||||
|
||||
impl<C: ClientInnerExt> Discard<'_, C> {
|
||||
pub(crate) fn check_idle(&mut self, status: ReadyForQueryStatus) {
|
||||
let conn_info = &self.conn_info;
|
||||
if status != ReadyForQueryStatus::Idle && std::mem::take(self.pool).strong_count() > 0 {
|
||||
info!(
|
||||
"local_pool: throwing away connection '{conn_info}' because connection is not idle"
|
||||
);
|
||||
}
|
||||
}
|
||||
pub(crate) fn discard(&mut self) {
|
||||
let conn_info = &self.conn_info;
|
||||
if std::mem::take(self.pool).strong_count() > 0 {
|
||||
info!("local_pool: throwing away connection '{conn_info}' because connection is potentially in a broken state");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<C: ClientInnerExt> LocalClient<C> {
|
||||
pub fn get_client(&self) -> &C {
|
||||
&self
|
||||
.inner
|
||||
.as_ref()
|
||||
.expect("client inner should not be removed")
|
||||
.inner
|
||||
}
|
||||
|
||||
fn do_drop(&mut self) -> Option<impl FnOnce()> {
|
||||
let conn_info = self.conn_info.clone();
|
||||
let client = self
|
||||
.inner
|
||||
.take()
|
||||
.expect("client inner should not be removed");
|
||||
if let Some(conn_pool) = std::mem::take(&mut self.pool).upgrade() {
|
||||
let current_span = self.span.clone();
|
||||
// return connection to the pool
|
||||
return Some(move || {
|
||||
let _span = current_span.enter();
|
||||
EndpointConnPool::put(&conn_pool.global_pool, &conn_info, client);
|
||||
});
|
||||
}
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
impl<C: ClientInnerExt> Drop for LocalClient<C> {
|
||||
fn drop(&mut self) {
|
||||
if let Some(drop) = self.do_drop() {
|
||||
tokio::task::spawn_blocking(drop);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -8,6 +8,7 @@ mod conn_pool;
|
||||
mod http_conn_pool;
|
||||
mod http_util;
|
||||
mod json;
|
||||
mod local_conn_pool;
|
||||
mod sql_over_http;
|
||||
mod websocket;
|
||||
|
||||
@@ -63,6 +64,7 @@ pub async fn task_main(
|
||||
info!("websocket server has shut down");
|
||||
}
|
||||
|
||||
let local_pool = local_conn_pool::LocalConnPool::new(&config.http_config);
|
||||
let conn_pool = conn_pool::GlobalConnPool::new(&config.http_config);
|
||||
{
|
||||
let conn_pool = Arc::clone(&conn_pool);
|
||||
@@ -105,6 +107,7 @@ pub async fn task_main(
|
||||
|
||||
let backend = Arc::new(PoolingBackend {
|
||||
http_conn_pool: Arc::clone(&http_conn_pool),
|
||||
local_pool,
|
||||
pool: Arc::clone(&conn_pool),
|
||||
config,
|
||||
endpoint_rate_limiter: Arc::clone(&endpoint_rate_limiter),
|
||||
|
||||
@@ -40,7 +40,7 @@ use url::Url;
|
||||
use urlencoding;
|
||||
use utils::http::error::ApiError;
|
||||
|
||||
use crate::auth::backend::ComputeCredentials;
|
||||
use crate::auth::backend::ComputeCredentialKeys;
|
||||
use crate::auth::backend::ComputeUserInfo;
|
||||
use crate::auth::endpoint_sni;
|
||||
use crate::auth::ComputeUserInfoParseError;
|
||||
@@ -56,20 +56,22 @@ use crate::metrics::Metrics;
|
||||
use crate::proxy::run_until_cancelled;
|
||||
use crate::proxy::NeonOptions;
|
||||
use crate::serverless::backend::HttpConnError;
|
||||
use crate::usage_metrics::MetricCounter;
|
||||
use crate::usage_metrics::MetricCounterRecorder;
|
||||
use crate::DbName;
|
||||
use crate::RoleName;
|
||||
|
||||
use super::backend::LocalProxyConnError;
|
||||
use super::backend::PoolingBackend;
|
||||
use super::conn_pool;
|
||||
use super::conn_pool::AuthData;
|
||||
use super::conn_pool::Client;
|
||||
use super::conn_pool::ConnInfo;
|
||||
use super::conn_pool::ConnInfoWithAuth;
|
||||
use super::http_util::json_response;
|
||||
use super::json::json_to_pg_text;
|
||||
use super::json::pg_text_row_to_json;
|
||||
use super::json::JsonConversionError;
|
||||
use super::local_conn_pool;
|
||||
|
||||
#[derive(serde::Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
@@ -620,6 +622,9 @@ async fn handle_db_inner(
|
||||
|
||||
let authenticate_and_connect = Box::pin(
|
||||
async {
|
||||
let is_local_proxy =
|
||||
matches!(backend.config.auth_backend, crate::auth::Backend::Local(_));
|
||||
|
||||
let keys = match auth {
|
||||
AuthData::Password(pw) => {
|
||||
backend
|
||||
@@ -639,18 +644,24 @@ async fn handle_db_inner(
|
||||
&conn_info.user_info,
|
||||
jwt,
|
||||
)
|
||||
.await?;
|
||||
|
||||
ComputeCredentials {
|
||||
info: conn_info.user_info.clone(),
|
||||
keys: crate::auth::backend::ComputeCredentialKeys::None,
|
||||
}
|
||||
.await?
|
||||
}
|
||||
};
|
||||
|
||||
let client = match keys.keys {
|
||||
ComputeCredentialKeys::JwtPayload(payload) if is_local_proxy => {
|
||||
let mut client = backend.connect_to_local_postgres(ctx, conn_info).await?;
|
||||
client.set_jwt_session(&payload).await?;
|
||||
Client::Local(client)
|
||||
}
|
||||
_ => {
|
||||
let client = backend
|
||||
.connect_to_compute(ctx, conn_info, keys, !allow_pool)
|
||||
.await?;
|
||||
Client::Remote(client)
|
||||
}
|
||||
};
|
||||
|
||||
let client = backend
|
||||
.connect_to_compute(ctx, conn_info, keys, !allow_pool)
|
||||
.await?;
|
||||
// not strictly necessary to mark success here,
|
||||
// but it's just insurance for if we forget it somewhere else
|
||||
ctx.success();
|
||||
@@ -791,7 +802,7 @@ impl QueryData {
|
||||
self,
|
||||
config: &'static ProxyConfig,
|
||||
cancel: CancellationToken,
|
||||
client: &mut Client<tokio_postgres::Client>,
|
||||
client: &mut Client,
|
||||
parsed_headers: HttpHeaders,
|
||||
) -> Result<String, SqlOverHttpError> {
|
||||
let (inner, mut discard) = client.inner();
|
||||
@@ -865,7 +876,7 @@ impl BatchQueryData {
|
||||
self,
|
||||
config: &'static ProxyConfig,
|
||||
cancel: CancellationToken,
|
||||
client: &mut Client<tokio_postgres::Client>,
|
||||
client: &mut Client,
|
||||
parsed_headers: HttpHeaders,
|
||||
) -> Result<String, SqlOverHttpError> {
|
||||
info!("starting transaction");
|
||||
@@ -1058,3 +1069,50 @@ async fn query_to_json<T: GenericClient>(
|
||||
|
||||
Ok((ready, results))
|
||||
}
|
||||
|
||||
enum Client {
|
||||
Remote(conn_pool::Client<tokio_postgres::Client>),
|
||||
Local(local_conn_pool::LocalClient<tokio_postgres::Client>),
|
||||
}
|
||||
|
||||
enum Discard<'a> {
|
||||
Remote(conn_pool::Discard<'a, tokio_postgres::Client>),
|
||||
Local(local_conn_pool::Discard<'a, tokio_postgres::Client>),
|
||||
}
|
||||
|
||||
impl Client {
|
||||
fn metrics(&self) -> Arc<MetricCounter> {
|
||||
match self {
|
||||
Client::Remote(client) => client.metrics(),
|
||||
Client::Local(local_client) => local_client.metrics(),
|
||||
}
|
||||
}
|
||||
|
||||
fn inner(&mut self) -> (&mut tokio_postgres::Client, Discard<'_>) {
|
||||
match self {
|
||||
Client::Remote(client) => {
|
||||
let (c, d) = client.inner();
|
||||
(c, Discard::Remote(d))
|
||||
}
|
||||
Client::Local(local_client) => {
|
||||
let (c, d) = local_client.inner();
|
||||
(c, Discard::Local(d))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Discard<'_> {
|
||||
fn check_idle(&mut self, status: ReadyForQueryStatus) {
|
||||
match self {
|
||||
Discard::Remote(discard) => discard.check_idle(status),
|
||||
Discard::Local(discard) => discard.check_idle(status),
|
||||
}
|
||||
}
|
||||
fn discard(&mut self) {
|
||||
match self {
|
||||
Discard::Remote(discard) => discard.discard(),
|
||||
Discard::Local(discard) => discard.discard(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -58,6 +58,7 @@ num-integer = { version = "0.1", features = ["i128"] }
|
||||
num-traits = { version = "0.2", features = ["i128", "libm"] }
|
||||
once_cell = { version = "1" }
|
||||
parquet = { version = "53", default-features = false, features = ["zstd"] }
|
||||
postgres-types = { git = "https://github.com/neondatabase/rust-postgres.git", rev = "20031d7a9ee1addeae6e0968e3899ae6bf01cee2", default-features = false, features = ["with-serde_json-1"] }
|
||||
prost = { version = "0.13", features = ["prost-derive"] }
|
||||
rand = { version = "0.8", features = ["small_rng"] }
|
||||
regex = { version = "1" }
|
||||
@@ -66,7 +67,7 @@ regex-syntax = { version = "0.8" }
|
||||
reqwest = { version = "0.12", default-features = false, features = ["blocking", "json", "rustls-tls", "stream"] }
|
||||
scopeguard = { version = "1" }
|
||||
serde = { version = "1", features = ["alloc", "derive"] }
|
||||
serde_json = { version = "1", features = ["raw_value"] }
|
||||
serde_json = { version = "1", features = ["alloc", "raw_value"] }
|
||||
sha2 = { version = "0.10", features = ["asm", "oid"] }
|
||||
signature = { version = "2", default-features = false, features = ["digest", "rand_core", "std"] }
|
||||
smallvec = { version = "1", default-features = false, features = ["const_new", "write"] }
|
||||
@@ -76,6 +77,7 @@ sync_wrapper = { version = "0.1", default-features = false, features = ["futures
|
||||
tikv-jemalloc-sys = { version = "0.5" }
|
||||
time = { version = "0.3", features = ["macros", "serde-well-known"] }
|
||||
tokio = { version = "1", features = ["fs", "io-std", "io-util", "macros", "net", "process", "rt-multi-thread", "signal", "test-util"] }
|
||||
tokio-postgres = { git = "https://github.com/neondatabase/rust-postgres.git", rev = "20031d7a9ee1addeae6e0968e3899ae6bf01cee2", features = ["with-serde_json-1"] }
|
||||
tokio-stream = { version = "0.1", features = ["net"] }
|
||||
tokio-util = { version = "0.7", features = ["codec", "compat", "io", "rt"] }
|
||||
toml_edit = { version = "0.22", features = ["serde"] }
|
||||
|
||||
Reference in New Issue
Block a user