proxy: subzero integration in auth-broker (embedded data-api) (#12474)

## Problem
We want to have the data-api served by the proxy directly instead of
relying on a 3rd party to run a deployment for each project/endpoint.

## Summary of changes
With the changes below, the proxy (auth-broker) becomes also a
"rest-broker", that can be thought of as a "Multi-tenant" data-api which
provides an automated REST api for all the databases in the region.

The core of the implementation (that leverages the subzero library) is
in proxy/src/serverless/rest.rs and this is the only place that has "new
logic".

---------

Co-authored-by: Ruslan Talpa <ruslan.talpa@databricks.com>
Co-authored-by: Alexander Bayandin <alexander@neon.tech>
Co-authored-by: Conrad Ludgate <conrad@neon.tech>
This commit is contained in:
Ruslan Talpa
2025-07-21 21:16:28 +03:00
committed by GitHub
parent 187170be47
commit 0dbe551802
30 changed files with 2073 additions and 70 deletions

View File

@@ -20,6 +20,8 @@ use crate::auth::backend::jwt::JwkCache;
use crate::auth::backend::local::LocalBackend;
use crate::auth::{self};
use crate::cancellation::CancellationHandler;
#[cfg(feature = "rest_broker")]
use crate::config::RestConfig;
use crate::config::{
self, AuthenticationConfig, ComputeConfig, HttpConfig, ProxyConfig, RetryConfig,
refresh_config_loop,
@@ -276,6 +278,13 @@ fn build_config(args: &LocalProxyCliArgs) -> anyhow::Result<&'static ProxyConfig
accept_jwts: true,
console_redirect_confirmation_timeout: Duration::ZERO,
},
#[cfg(feature = "rest_broker")]
rest_config: RestConfig {
is_rest_broker: false,
db_schema_cache: None,
max_schema_size: 0,
hostname_prefix: String::new(),
},
proxy_protocol_v2: config::ProxyProtocolV2::Rejected,
handshake_timeout: Duration::from_secs(10),
wake_compute_retry_config: RetryConfig::parse(RetryConfig::WAKE_COMPUTE_DEFAULT_VALUES)?,

View File

@@ -31,6 +31,8 @@ use crate::auth::backend::local::LocalBackend;
use crate::auth::backend::{ConsoleRedirectBackend, MaybeOwned};
use crate::batch::BatchQueue;
use crate::cancellation::{CancellationHandler, CancellationProcessor};
#[cfg(feature = "rest_broker")]
use crate::config::RestConfig;
#[cfg(any(test, feature = "testing"))]
use crate::config::refresh_config_loop;
use crate::config::{
@@ -47,6 +49,8 @@ use crate::redis::{elasticache, notifications};
use crate::scram::threadpool::ThreadPool;
use crate::serverless::GlobalConnPoolOptions;
use crate::serverless::cancel_set::CancelSet;
#[cfg(feature = "rest_broker")]
use crate::serverless::rest::DbSchemaCache;
use crate::tls::client_config::compute_client_config_with_root_certs;
#[cfg(any(test, feature = "testing"))]
use crate::url::ApiUrl;
@@ -246,11 +250,23 @@ struct ProxyCliArgs {
/// if this is not local proxy, this toggles whether we accept Postgres REST requests
#[clap(long, default_value_t = false, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)]
#[cfg(feature = "rest_broker")]
is_rest_broker: bool,
/// cache for `db_schema_cache` introspection (use `size=0` to disable)
#[clap(long, default_value = "size=1000,ttl=1h")]
#[cfg(feature = "rest_broker")]
db_schema_cache: String,
/// Maximum size allowed for schema in bytes
#[clap(long, default_value_t = 5 * 1024 * 1024)] // 5MB
#[cfg(feature = "rest_broker")]
max_schema_size: usize,
/// Hostname prefix to strip from request hostname to get database hostname
#[clap(long, default_value = "apirest.")]
#[cfg(feature = "rest_broker")]
hostname_prefix: String,
}
#[derive(clap::Args, Clone, Copy, Debug)]
@@ -517,6 +533,17 @@ pub async fn run() -> anyhow::Result<()> {
));
maintenance_tasks.spawn(control_plane::mgmt::task_main(mgmt_listener));
// add a task to flush the db_schema cache every 10 minutes
#[cfg(feature = "rest_broker")]
if let Some(db_schema_cache) = &config.rest_config.db_schema_cache {
maintenance_tasks.spawn(async move {
loop {
tokio::time::sleep(Duration::from_secs(600)).await;
db_schema_cache.flush();
}
});
}
if let Some(metrics_config) = &config.metric_collection {
// TODO: Add gc regardles of the metric collection being enabled.
maintenance_tasks.spawn(usage_metrics::task_main(metrics_config));
@@ -679,6 +706,30 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
timeout: Duration::from_secs(2),
};
#[cfg(feature = "rest_broker")]
let rest_config = {
let db_schema_cache_config: CacheOptions = args.db_schema_cache.parse()?;
info!("Using DbSchemaCache with options={db_schema_cache_config:?}");
let db_schema_cache = if args.is_rest_broker {
Some(DbSchemaCache::new(
"db_schema_cache",
db_schema_cache_config.size,
db_schema_cache_config.ttl,
true,
))
} else {
None
};
RestConfig {
is_rest_broker: args.is_rest_broker,
db_schema_cache,
max_schema_size: args.max_schema_size,
hostname_prefix: args.hostname_prefix.clone(),
}
};
let config = ProxyConfig {
tls_config,
metric_collection,
@@ -691,6 +742,8 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
connect_to_compute: compute_config,
#[cfg(feature = "testing")]
disable_pg_session_jwt: false,
#[cfg(feature = "rest_broker")]
rest_config,
};
let config = Box::leak(Box::new(config));

View File

@@ -204,6 +204,11 @@ impl<K: Hash + Eq + Clone, V: Clone> TimedLru<K, V> {
self.insert_raw_ttl(key, value, ttl, false);
}
#[cfg(feature = "rest_broker")]
pub(crate) fn insert(&self, key: K, value: V) {
self.insert_raw_ttl(key, value, self.ttl, self.update_ttl_on_retrieval);
}
pub(crate) fn insert_unit(&self, key: K, value: V) -> (Option<V>, Cached<&Self, ()>) {
let (_, old) = self.insert_raw(key.clone(), value);
@@ -214,6 +219,29 @@ impl<K: Hash + Eq + Clone, V: Clone> TimedLru<K, V> {
(old, cached)
}
#[cfg(feature = "rest_broker")]
pub(crate) fn flush(&self) {
let now = Instant::now();
let mut cache = self.cache.lock();
// Collect keys of expired entries first
let expired_keys: Vec<_> = cache
.iter()
.filter_map(|(key, entry)| {
if entry.expires_at <= now {
Some(key.clone())
} else {
None
}
})
.collect();
// Remove expired entries
for key in expired_keys {
cache.remove(&key);
}
}
}
impl<K: Hash + Eq, V: Clone> TimedLru<K, V> {

View File

@@ -22,6 +22,8 @@ use crate::rate_limiter::{RateLimitAlgorithm, RateLimiterConfig};
use crate::scram::threadpool::ThreadPool;
use crate::serverless::GlobalConnPoolOptions;
use crate::serverless::cancel_set::CancelSet;
#[cfg(feature = "rest_broker")]
use crate::serverless::rest::DbSchemaCache;
pub use crate::tls::server_config::{TlsConfig, configure_tls};
use crate::types::{Host, RoleName};
@@ -30,6 +32,8 @@ pub struct ProxyConfig {
pub metric_collection: Option<MetricCollectionConfig>,
pub http_config: HttpConfig,
pub authentication_config: AuthenticationConfig,
#[cfg(feature = "rest_broker")]
pub rest_config: RestConfig,
pub proxy_protocol_v2: ProxyProtocolV2,
pub handshake_timeout: Duration,
pub wake_compute_retry_config: RetryConfig,
@@ -80,6 +84,14 @@ pub struct AuthenticationConfig {
pub console_redirect_confirmation_timeout: tokio::time::Duration,
}
#[cfg(feature = "rest_broker")]
pub struct RestConfig {
pub is_rest_broker: bool,
pub db_schema_cache: Option<DbSchemaCache>,
pub max_schema_size: usize,
pub hostname_prefix: String,
}
#[derive(Debug)]
pub struct MetricBackupCollectionConfig {
pub remote_storage_config: Option<RemoteStorageConfig>,

View File

@@ -10,6 +10,7 @@ use super::connection_with_credentials_provider::ConnectionWithCredentialsProvid
use crate::cache::project_info::ProjectInfoCache;
use crate::intern::{AccountIdInt, EndpointIdInt, ProjectIdInt, RoleNameInt};
use crate::metrics::{Metrics, RedisErrors, RedisEventsCount};
use crate::util::deserialize_json_string;
const CPLANE_CHANNEL_NAME: &str = "neondb-proxy-ws-updates";
const RECONNECT_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(20);
@@ -121,15 +122,6 @@ struct InvalidateRole {
role_name: RoleNameInt,
}
fn deserialize_json_string<'de, D, T>(deserializer: D) -> Result<T, D::Error>
where
T: for<'de2> serde::Deserialize<'de2>,
D: serde::Deserializer<'de>,
{
let s = String::deserialize(deserializer)?;
serde_json::from_str(&s).map_err(<D::Error as serde::de::Error>::custom)
}
// https://github.com/serde-rs/serde/issues/1714
fn deserialize_unknown_topic<'de, D>(deserializer: D) -> Result<(), D::Error>
where

View File

@@ -11,6 +11,8 @@ mod http_conn_pool;
mod http_util;
mod json;
mod local_conn_pool;
#[cfg(feature = "rest_broker")]
pub mod rest;
mod sql_over_http;
mod websocket;
@@ -487,6 +489,42 @@ async fn request_handler(
.body(Empty::new().map_err(|x| match x {}).boxed())
.map_err(|e| ApiError::InternalServerError(e.into()))
} else {
json_response(StatusCode::BAD_REQUEST, "query is not supported")
#[cfg(feature = "rest_broker")]
{
if config.rest_config.is_rest_broker
// we are testing for the path to be /database_name/rest/...
&& request
.uri()
.path()
.split('/')
.nth(2)
.is_some_and(|part| part.starts_with("rest"))
{
let ctx =
RequestContext::new(session_id, conn_info, crate::metrics::Protocol::Http);
let span = ctx.span();
let testodrome_id = request
.headers()
.get("X-Neon-Query-ID")
.and_then(|value| value.to_str().ok())
.map(|s| s.to_string());
if let Some(query_id) = testodrome_id {
info!(parent: &span, "testodrome query ID: {query_id}");
ctx.set_testodrome_id(query_id.into());
}
rest::handle(config, ctx, request, backend, http_cancellation_token)
.instrument(span)
.await
} else {
json_response(StatusCode::BAD_REQUEST, "query is not supported")
}
}
#[cfg(not(feature = "rest_broker"))]
{
json_response(StatusCode::BAD_REQUEST, "query is not supported")
}
}
}

1165
proxy/src/serverless/rest.rs Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -64,7 +64,7 @@ enum Payload {
Batch(BatchQueryData),
}
static HEADER_VALUE_TRUE: HeaderValue = HeaderValue::from_static("true");
pub(super) const HEADER_VALUE_TRUE: HeaderValue = HeaderValue::from_static("true");
fn bytes_to_pg_text<'de, D>(deserializer: D) -> Result<Vec<Option<String>>, D::Error>
where

View File

@@ -20,3 +20,13 @@ pub async fn run_until<F1: Future, F2: Future>(
Either::Right((f2, _)) => Err(f2),
}
}
pub fn deserialize_json_string<'de, D, T>(deserializer: D) -> Result<T, D::Error>
where
T: for<'de2> serde::Deserialize<'de2>,
D: serde::Deserializer<'de>,
{
use serde::Deserialize;
let s = String::deserialize(deserializer)?;
serde_json::from_str(&s).map_err(<D::Error as serde::de::Error>::custom)
}