subzero pre-integration refactor (#12416)

## Problem
integrating subzero requires a bit of refactoring. To make the
integration PR a bit more manageable, the refactoring is done in this
separate PR.
 
## Summary of changes
* move common types/functions used in sql_over_http to errors.rs and
http_util.rs
* add the "Local" auth backend to proxy (similar to local_proxy), useful
in local testing
* change the Connect and Send type for the http client to allow for
custom body when making post requests to local_proxy from the proxy

---------

Co-authored-by: Ruslan Talpa <ruslan.talpa@databricks.com>
This commit is contained in:
Ruslan Talpa
2025-07-03 14:04:08 +03:00
committed by GitHub
parent 1bc1eae5e8
commit 95e1011cd6
12 changed files with 581 additions and 412 deletions

View File

@@ -138,3 +138,62 @@ Now from client you can start a new session:
```sh
PGSSLROOTCERT=./server.crt psql "postgresql://proxy:password@endpoint.local.neon.build:4432/postgres?sslmode=verify-full"
```
## auth broker setup:
Create a postgres instance:
```sh
docker run \
--detach \
--name proxy-postgres \
--env POSTGRES_HOST_AUTH_METHOD=trust \
--env POSTGRES_USER=authenticated \
--env POSTGRES_DB=database \
--publish 5432:5432 \
postgres:17-bookworm
```
Create a configuration file called `local_proxy.json` in the root of the repo (used also by the auth broker to validate JWTs)
```sh
{
"jwks": [
{
"id": "1",
"role_names": ["authenticator", "authenticated", "anon"],
"jwks_url": "https://climbing-minnow-11.clerk.accounts.dev/.well-known/jwks.json",
"provider_name": "foo",
"jwt_audience": null
}
]
}
```
Start the local proxy:
```sh
cargo run --bin local_proxy -- \
--disable_pg_session_jwt true \
--http 0.0.0.0:7432
```
Start the auth broker:
```sh
LOGFMT=text OTEL_SDK_DISABLED=true cargo run --bin proxy --features testing -- \
-c server.crt -k server.key \
--is-auth-broker true \
--wss 0.0.0.0:8080 \
--http 0.0.0.0:7002 \
--auth-backend local
```
Create a JWT in your auth provider (e.g. Clerk) and set it in the `NEON_JWT` environment variable.
```sh
export NEON_JWT="..."
```
Run a query against the auth broker:
```sh
curl -k "https://foo.local.neon.build:8080/sql" \
-H "Authorization: Bearer $NEON_JWT" \
-H "neon-connection-string: postgresql://authenticator@foo.local.neon.build/database" \
-d '{"query":"select 1","params":[]}'
```

View File

@@ -171,7 +171,6 @@ impl ComputeUserInfo {
pub(crate) enum ComputeCredentialKeys {
AuthKeys(AuthKeys),
JwtPayload(Vec<u8>),
None,
}
impl TryFrom<ComputeUserInfoMaybeEndpoint> for ComputeUserInfo {

View File

@@ -1,43 +1,39 @@
use std::net::SocketAddr;
use std::pin::pin;
use std::str::FromStr;
use std::sync::Arc;
use std::time::Duration;
use anyhow::{Context, bail, ensure};
use anyhow::bail;
use arc_swap::ArcSwapOption;
use camino::{Utf8Path, Utf8PathBuf};
use camino::Utf8PathBuf;
use clap::Parser;
use compute_api::spec::LocalProxySpec;
use futures::future::Either;
use thiserror::Error;
use tokio::net::TcpListener;
use tokio::sync::Notify;
use tokio::task::JoinSet;
use tokio_util::sync::CancellationToken;
use tracing::{debug, error, info, warn};
use tracing::{debug, error, info};
use utils::sentry_init::init_sentry;
use utils::{pid_file, project_build_tag, project_git_version};
use crate::auth::backend::jwt::JwkCache;
use crate::auth::backend::local::{JWKS_ROLE_MAP, LocalBackend};
use crate::auth::backend::local::LocalBackend;
use crate::auth::{self};
use crate::cancellation::CancellationHandler;
use crate::config::refresh_config_loop;
use crate::config::{
self, AuthenticationConfig, ComputeConfig, HttpConfig, ProxyConfig, RetryConfig,
};
use crate::control_plane::locks::ApiLocks;
use crate::control_plane::messages::{EndpointJwksResponse, JwksSettings};
use crate::ext::TaskExt;
use crate::http::health_server::AppMetrics;
use crate::intern::RoleNameInt;
use crate::metrics::{Metrics, ThreadPoolMetrics};
use crate::rate_limiter::{EndpointRateLimiter, LeakyBucketConfig, RateBucketInfo};
use crate::scram::threadpool::ThreadPool;
use crate::serverless::cancel_set::CancelSet;
use crate::serverless::{self, GlobalConnPoolOptions};
use crate::tls::client_config::compute_client_config_with_root_certs;
use crate::types::RoleName;
use crate::url::ApiUrl;
project_git_version!(GIT_VERSION);
@@ -82,6 +78,11 @@ struct LocalProxyCliArgs {
/// Path of the local proxy PID file
#[clap(long, default_value = "./local_proxy.pid")]
pid_path: Utf8PathBuf,
/// Disable pg_session_jwt extension installation
/// This is useful for testing the local proxy with vanilla postgres.
#[clap(long, default_value = "false")]
#[cfg(feature = "testing")]
disable_pg_session_jwt: bool,
}
#[derive(clap::Args, Clone, Copy, Debug)]
@@ -282,6 +283,8 @@ fn build_config(args: &LocalProxyCliArgs) -> anyhow::Result<&'static ProxyConfig
wake_compute_retry_config: RetryConfig::parse(RetryConfig::WAKE_COMPUTE_DEFAULT_VALUES)?,
connect_compute_locks,
connect_to_compute: compute_config,
#[cfg(feature = "testing")]
disable_pg_session_jwt: args.disable_pg_session_jwt,
})))
}
@@ -293,132 +296,3 @@ fn build_auth_backend(args: &LocalProxyCliArgs) -> &'static auth::Backend<'stati
Box::leak(Box::new(auth_backend))
}
#[derive(Error, Debug)]
enum RefreshConfigError {
#[error(transparent)]
Read(#[from] std::io::Error),
#[error(transparent)]
Parse(#[from] serde_json::Error),
#[error(transparent)]
Validate(anyhow::Error),
#[error(transparent)]
Tls(anyhow::Error),
}
async fn refresh_config_loop(config: &ProxyConfig, path: Utf8PathBuf, rx: Arc<Notify>) {
let mut init = true;
loop {
rx.notified().await;
match refresh_config_inner(config, &path).await {
Ok(()) => {}
// don't log for file not found errors if this is the first time we are checking
// for computes that don't use local_proxy, this is not an error.
Err(RefreshConfigError::Read(e))
if init && e.kind() == std::io::ErrorKind::NotFound =>
{
debug!(error=?e, ?path, "could not read config file");
}
Err(RefreshConfigError::Tls(e)) => {
error!(error=?e, ?path, "could not read TLS certificates");
}
Err(e) => {
error!(error=?e, ?path, "could not read config file");
}
}
init = false;
}
}
async fn refresh_config_inner(
config: &ProxyConfig,
path: &Utf8Path,
) -> Result<(), RefreshConfigError> {
let bytes = tokio::fs::read(&path).await?;
let data: LocalProxySpec = serde_json::from_slice(&bytes)?;
let mut jwks_set = vec![];
fn parse_jwks_settings(jwks: compute_api::spec::JwksSettings) -> anyhow::Result<JwksSettings> {
let mut jwks_url = url::Url::from_str(&jwks.jwks_url).context("parsing JWKS url")?;
ensure!(
jwks_url.has_authority()
&& (jwks_url.scheme() == "http" || jwks_url.scheme() == "https"),
"Invalid JWKS url. Must be HTTP",
);
ensure!(
jwks_url.host().is_some_and(|h| h != url::Host::Domain("")),
"Invalid JWKS url. No domain listed",
);
// clear username, password and ports
jwks_url
.set_username("")
.expect("url can be a base and has a valid host and is not a file. should not error");
jwks_url
.set_password(None)
.expect("url can be a base and has a valid host and is not a file. should not error");
// local testing is hard if we need to have a specific restricted port
if cfg!(not(feature = "testing")) {
jwks_url.set_port(None).expect(
"url can be a base and has a valid host and is not a file. should not error",
);
}
// clear query params
jwks_url.set_fragment(None);
jwks_url.query_pairs_mut().clear().finish();
if jwks_url.scheme() != "https" {
// local testing is hard if we need to set up https support.
if cfg!(not(feature = "testing")) {
jwks_url
.set_scheme("https")
.expect("should not error to set the scheme to https if it was http");
} else {
warn!(scheme = jwks_url.scheme(), "JWKS url is not HTTPS");
}
}
Ok(JwksSettings {
id: jwks.id,
jwks_url,
_provider_name: jwks.provider_name,
jwt_audience: jwks.jwt_audience,
role_names: jwks
.role_names
.into_iter()
.map(RoleName::from)
.map(|s| RoleNameInt::from(&s))
.collect(),
})
}
for jwks in data.jwks.into_iter().flatten() {
jwks_set.push(parse_jwks_settings(jwks).map_err(RefreshConfigError::Validate)?);
}
info!("successfully loaded new config");
JWKS_ROLE_MAP.store(Some(Arc::new(EndpointJwksResponse { jwks: jwks_set })));
if let Some(tls_config) = data.tls {
let tls_config = tokio::task::spawn_blocking(move || {
crate::tls::server_config::configure_tls(
tls_config.key_path.as_ref(),
tls_config.cert_path.as_ref(),
None,
false,
)
})
.await
.propagate_task_panic()
.map_err(RefreshConfigError::Tls)?;
config.tls_config.store(Some(Arc::new(tls_config)));
}
Ok(())
}

View File

@@ -22,9 +22,13 @@ use utils::sentry_init::init_sentry;
use utils::{project_build_tag, project_git_version};
use crate::auth::backend::jwt::JwkCache;
#[cfg(any(test, feature = "testing"))]
use crate::auth::backend::local::LocalBackend;
use crate::auth::backend::{ConsoleRedirectBackend, MaybeOwned};
use crate::batch::BatchQueue;
use crate::cancellation::{CancellationHandler, CancellationProcessor};
#[cfg(any(test, feature = "testing"))]
use crate::config::refresh_config_loop;
use crate::config::{
self, AuthenticationConfig, CacheOptions, ComputeConfig, HttpConfig, ProjectInfoCacheOptions,
ProxyConfig, ProxyProtocolV2, remote_storage_from_toml,
@@ -43,6 +47,10 @@ use crate::tls::client_config::compute_client_config_with_root_certs;
#[cfg(any(test, feature = "testing"))]
use crate::url::ApiUrl;
use crate::{auth, control_plane, http, serverless, usage_metrics};
#[cfg(any(test, feature = "testing"))]
use camino::Utf8PathBuf;
#[cfg(any(test, feature = "testing"))]
use tokio::sync::Notify;
project_git_version!(GIT_VERSION);
project_build_tag!(BUILD_TAG);
@@ -60,6 +68,9 @@ enum AuthBackendType {
#[cfg(any(test, feature = "testing"))]
Postgres,
#[cfg(any(test, feature = "testing"))]
Local,
}
/// Neon proxy/router
@@ -74,6 +85,10 @@ struct ProxyCliArgs {
proxy: SocketAddr,
#[clap(value_enum, long, default_value_t = AuthBackendType::ConsoleRedirect)]
auth_backend: AuthBackendType,
/// Path of the local proxy config file (used for local-file auth backend)
#[clap(long, default_value = "./local_proxy.json")]
#[cfg(any(test, feature = "testing"))]
config_path: Utf8PathBuf,
/// listen for management callback connection on ip:port
#[clap(short, long, default_value = "127.0.0.1:7000")]
mgmt: SocketAddr,
@@ -226,6 +241,14 @@ struct ProxyCliArgs {
#[clap(flatten)]
pg_sni_router: PgSniRouterArgs,
/// if this is not local proxy, this toggles whether we accept Postgres REST requests
#[clap(long, default_value_t = false, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)]
is_rest_broker: bool,
/// cache for `db_schema_cache` introspection (use `size=0` to disable)
#[clap(long, default_value = "size=1000,ttl=1h")]
db_schema_cache: String,
}
#[derive(clap::Args, Clone, Copy, Debug)]
@@ -386,6 +409,8 @@ pub async fn run() -> anyhow::Result<()> {
64,
));
#[cfg(any(test, feature = "testing"))]
let refresh_config_notify = Arc::new(Notify::new());
// client facing tasks. these will exit on error or on cancellation
// cancellation returns Ok(())
let mut client_tasks = JoinSet::new();
@@ -412,6 +437,17 @@ pub async fn run() -> anyhow::Result<()> {
endpoint_rate_limiter.clone(),
));
}
// if auth backend is local, we need to load the config file
#[cfg(any(test, feature = "testing"))]
if let auth::Backend::Local(_) = &auth_backend {
refresh_config_notify.notify_one();
tokio::spawn(refresh_config_loop(
config,
args.config_path,
refresh_config_notify.clone(),
));
}
}
Either::Right(auth_backend) => {
if let Some(proxy_listener) = proxy_listener {
@@ -462,7 +498,13 @@ pub async fn run() -> anyhow::Result<()> {
// maintenance tasks. these never return unless there's an error
let mut maintenance_tasks = JoinSet::new();
maintenance_tasks.spawn(crate::signals::handle(cancellation_token.clone(), || {}));
maintenance_tasks.spawn(crate::signals::handle(cancellation_token.clone(), {
move || {
#[cfg(any(test, feature = "testing"))]
refresh_config_notify.notify_one();
}
}));
maintenance_tasks.spawn(http::health_server::task_main(
http_listener,
AppMetrics {
@@ -653,6 +695,8 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
wake_compute_retry_config: config::RetryConfig::parse(&args.wake_compute_retry)?,
connect_compute_locks,
connect_to_compute: compute_config,
#[cfg(feature = "testing")]
disable_pg_session_jwt: false,
};
let config = Box::leak(Box::new(config));
@@ -806,6 +850,19 @@ fn build_auth_backend(
Ok(Either::Right(config))
}
#[cfg(any(test, feature = "testing"))]
AuthBackendType::Local => {
let postgres: SocketAddr = "127.0.0.1:7432".parse()?;
let compute_ctl: ApiUrl = "http://127.0.0.1:3081/".parse()?;
let auth_backend = crate::auth::Backend::Local(
crate::auth::backend::MaybeOwned::Owned(LocalBackend::new(postgres, compute_ctl)),
);
let config = Box::leak(Box::new(auth_backend));
Ok(Either::Left(config))
}
}
}

View File

@@ -165,7 +165,7 @@ impl AuthInfo {
ComputeCredentialKeys::AuthKeys(AuthKeys::ScramSha256(auth_keys)) => {
Some(Auth::Scram(Box::new(auth_keys)))
}
ComputeCredentialKeys::JwtPayload(_) | ComputeCredentialKeys::None => None,
ComputeCredentialKeys::JwtPayload(_) => None,
},
server_params: StartupMessageParams::default(),
skip_db_user: false,

View File

@@ -16,6 +16,17 @@ use crate::serverless::cancel_set::CancelSet;
pub use crate::tls::server_config::{TlsConfig, configure_tls};
use crate::types::Host;
use crate::auth::backend::local::JWKS_ROLE_MAP;
use crate::control_plane::messages::{EndpointJwksResponse, JwksSettings};
use crate::ext::TaskExt;
use crate::intern::RoleNameInt;
use crate::types::RoleName;
use camino::{Utf8Path, Utf8PathBuf};
use compute_api::spec::LocalProxySpec;
use thiserror::Error;
use tokio::sync::Notify;
use tracing::{debug, error, info, warn};
pub struct ProxyConfig {
pub tls_config: ArcSwapOption<TlsConfig>,
pub metric_collection: Option<MetricCollectionConfig>,
@@ -26,6 +37,8 @@ pub struct ProxyConfig {
pub wake_compute_retry_config: RetryConfig,
pub connect_compute_locks: ApiLocks<Host>,
pub connect_to_compute: ComputeConfig,
#[cfg(feature = "testing")]
pub disable_pg_session_jwt: bool,
}
pub struct ComputeConfig {
@@ -409,6 +422,135 @@ impl FromStr for ConcurrencyLockOptions {
}
}
#[derive(Error, Debug)]
pub(crate) enum RefreshConfigError {
#[error(transparent)]
Read(#[from] std::io::Error),
#[error(transparent)]
Parse(#[from] serde_json::Error),
#[error(transparent)]
Validate(anyhow::Error),
#[error(transparent)]
Tls(anyhow::Error),
}
pub(crate) async fn refresh_config_loop(config: &ProxyConfig, path: Utf8PathBuf, rx: Arc<Notify>) {
let mut init = true;
loop {
rx.notified().await;
match refresh_config_inner(config, &path).await {
std::result::Result::Ok(()) => {}
// don't log for file not found errors if this is the first time we are checking
// for computes that don't use local_proxy, this is not an error.
Err(RefreshConfigError::Read(e))
if init && e.kind() == std::io::ErrorKind::NotFound =>
{
debug!(error=?e, ?path, "could not read config file");
}
Err(RefreshConfigError::Tls(e)) => {
error!(error=?e, ?path, "could not read TLS certificates");
}
Err(e) => {
error!(error=?e, ?path, "could not read config file");
}
}
init = false;
}
}
pub(crate) async fn refresh_config_inner(
config: &ProxyConfig,
path: &Utf8Path,
) -> Result<(), RefreshConfigError> {
let bytes = tokio::fs::read(&path).await?;
let data: LocalProxySpec = serde_json::from_slice(&bytes)?;
let mut jwks_set = vec![];
fn parse_jwks_settings(jwks: compute_api::spec::JwksSettings) -> anyhow::Result<JwksSettings> {
let mut jwks_url = url::Url::from_str(&jwks.jwks_url).context("parsing JWKS url")?;
ensure!(
jwks_url.has_authority()
&& (jwks_url.scheme() == "http" || jwks_url.scheme() == "https"),
"Invalid JWKS url. Must be HTTP",
);
ensure!(
jwks_url.host().is_some_and(|h| h != url::Host::Domain("")),
"Invalid JWKS url. No domain listed",
);
// clear username, password and ports
jwks_url
.set_username("")
.expect("url can be a base and has a valid host and is not a file. should not error");
jwks_url
.set_password(None)
.expect("url can be a base and has a valid host and is not a file. should not error");
// local testing is hard if we need to have a specific restricted port
if cfg!(not(feature = "testing")) {
jwks_url.set_port(None).expect(
"url can be a base and has a valid host and is not a file. should not error",
);
}
// clear query params
jwks_url.set_fragment(None);
jwks_url.query_pairs_mut().clear().finish();
if jwks_url.scheme() != "https" {
// local testing is hard if we need to set up https support.
if cfg!(not(feature = "testing")) {
jwks_url
.set_scheme("https")
.expect("should not error to set the scheme to https if it was http");
} else {
warn!(scheme = jwks_url.scheme(), "JWKS url is not HTTPS");
}
}
Ok(JwksSettings {
id: jwks.id,
jwks_url,
_provider_name: jwks.provider_name,
jwt_audience: jwks.jwt_audience,
role_names: jwks
.role_names
.into_iter()
.map(RoleName::from)
.map(|s| RoleNameInt::from(&s))
.collect(),
})
}
for jwks in data.jwks.into_iter().flatten() {
jwks_set.push(parse_jwks_settings(jwks).map_err(RefreshConfigError::Validate)?);
}
info!("successfully loaded new config");
JWKS_ROLE_MAP.store(Some(Arc::new(EndpointJwksResponse { jwks: jwks_set })));
if let Some(tls_config) = data.tls {
let tls_config = tokio::task::spawn_blocking(move || {
crate::tls::server_config::configure_tls(
tls_config.key_path.as_ref(),
tls_config.cert_path.as_ref(),
None,
false,
)
})
.await
.propagate_task_panic()
.map_err(RefreshConfigError::Tls)?;
config.tls_config.store(Some(Arc::new(tls_config)));
}
std::result::Result::Ok(())
}
#[cfg(test)]
mod tests {
use super::*;

View File

@@ -115,7 +115,8 @@ impl PoolingBackend {
match &self.auth_backend {
crate::auth::Backend::ControlPlane(console, ()) => {
self.config
let keys = self
.config
.authentication_config
.jwks_cache
.check_jwt(
@@ -129,7 +130,7 @@ impl PoolingBackend {
Ok(ComputeCredentials {
info: user_info.clone(),
keys: crate::auth::backend::ComputeCredentialKeys::None,
keys,
})
}
crate::auth::Backend::Local(_) => {
@@ -256,6 +257,7 @@ impl PoolingBackend {
&self,
ctx: &RequestContext,
conn_info: ConnInfo,
disable_pg_session_jwt: bool,
) -> Result<Client<postgres_client::Client>, HttpConnError> {
if let Some(client) = self.local_pool.get(ctx, &conn_info)? {
return Ok(client);
@@ -277,7 +279,7 @@ impl PoolingBackend {
.expect("semaphore should never be closed");
// check again for race
if !self.local_pool.initialized(&conn_info) {
if !self.local_pool.initialized(&conn_info) && !disable_pg_session_jwt {
local_backend
.compute_ctl
.install_extension(&ExtensionInstallRequest {
@@ -313,14 +315,16 @@ impl PoolingBackend {
.to_postgres_client_config();
config
.user(&conn_info.user_info.user)
.dbname(&conn_info.dbname)
.set_param(
.dbname(&conn_info.dbname);
if !disable_pg_session_jwt {
config.set_param(
"options",
&format!(
"-c pg_session_jwt.jwk={}",
serde_json::to_string(&jwk).expect("serializing jwk to json should not fail")
),
);
}
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
let (client, connection) = config.connect(&postgres_client::NoTls).await?;
@@ -345,9 +349,11 @@ impl PoolingBackend {
debug!("setting up backend session state");
// initiates the auth session
if let Err(e) = client.batch_execute("select auth.init();").await {
discard.discard();
return Err(e.into());
if !disable_pg_session_jwt {
if let Err(e) = client.batch_execute("select auth.init();").await {
discard.discard();
return Err(e.into());
}
}
info!("backend session state initialized");

View File

@@ -1,5 +1,93 @@
use http::StatusCode;
use http::header::HeaderName;
use crate::auth::ComputeUserInfoParseError;
use crate::error::{ErrorKind, ReportableError, UserFacingError};
use crate::http::ReadBodyError;
pub trait HttpCodeError {
fn get_http_status_code(&self) -> StatusCode;
}
#[derive(Debug, thiserror::Error)]
pub(crate) enum ConnInfoError {
#[error("invalid header: {0}")]
InvalidHeader(&'static HeaderName),
#[error("invalid connection string: {0}")]
UrlParseError(#[from] url::ParseError),
#[error("incorrect scheme")]
IncorrectScheme,
#[error("missing database name")]
MissingDbName,
#[error("invalid database name")]
InvalidDbName,
#[error("missing username")]
MissingUsername,
#[error("invalid username: {0}")]
InvalidUsername(#[from] std::string::FromUtf8Error),
#[error("missing authentication credentials: {0}")]
MissingCredentials(Credentials),
#[error("missing hostname")]
MissingHostname,
#[error("invalid hostname: {0}")]
InvalidEndpoint(#[from] ComputeUserInfoParseError),
}
#[derive(Debug, thiserror::Error)]
pub(crate) enum Credentials {
#[error("required password")]
Password,
#[error("required authorization bearer token in JWT format")]
BearerJwt,
}
impl ReportableError for ConnInfoError {
fn get_error_kind(&self) -> ErrorKind {
ErrorKind::User
}
}
impl UserFacingError for ConnInfoError {
fn to_string_client(&self) -> String {
self.to_string()
}
}
#[derive(Debug, thiserror::Error)]
pub(crate) enum ReadPayloadError {
#[error("could not read the HTTP request body: {0}")]
Read(#[from] hyper::Error),
#[error("request is too large (max is {limit} bytes)")]
BodyTooLarge { limit: usize },
#[error("could not parse the HTTP request body: {0}")]
Parse(#[from] serde_json::Error),
}
impl From<ReadBodyError<hyper::Error>> for ReadPayloadError {
fn from(value: ReadBodyError<hyper::Error>) -> Self {
match value {
ReadBodyError::BodyTooLarge { limit } => Self::BodyTooLarge { limit },
ReadBodyError::Read(e) => Self::Read(e),
}
}
}
impl ReportableError for ReadPayloadError {
fn get_error_kind(&self) -> ErrorKind {
match self {
ReadPayloadError::Read(_) => ErrorKind::ClientDisconnect,
ReadPayloadError::BodyTooLarge { .. } => ErrorKind::User,
ReadPayloadError::Parse(_) => ErrorKind::User,
}
}
}
impl HttpCodeError for ReadPayloadError {
fn get_http_status_code(&self) -> StatusCode {
match self {
ReadPayloadError::Read(_) => StatusCode::BAD_REQUEST,
ReadPayloadError::BodyTooLarge { .. } => StatusCode::PAYLOAD_TOO_LARGE,
ReadPayloadError::Parse(_) => StatusCode::BAD_REQUEST,
}
}
}

View File

@@ -20,9 +20,12 @@ use crate::metrics::{HttpEndpointPoolsGuard, Metrics};
use crate::protocol2::ConnectionInfoExtra;
use crate::types::EndpointCacheKey;
use crate::usage_metrics::{Ids, MetricCounter, USAGE_METRICS};
use bytes::Bytes;
use http_body_util::combinators::BoxBody;
pub(crate) type Send = http2::SendRequest<hyper::body::Incoming>;
pub(crate) type Connect = http2::Connection<TokioIo<AsyncRW>, hyper::body::Incoming, TokioExecutor>;
pub(crate) type Send = http2::SendRequest<BoxBody<Bytes, hyper::Error>>;
pub(crate) type Connect =
http2::Connection<TokioIo<AsyncRW>, BoxBody<Bytes, hyper::Error>, TokioExecutor>;
#[derive(Clone)]
pub(crate) struct ClientDataHttp();

View File

@@ -3,11 +3,43 @@
use anyhow::Context;
use bytes::Bytes;
use http::{Response, StatusCode};
use http::header::AUTHORIZATION;
use http::{HeaderMap, HeaderName, HeaderValue, Response, StatusCode};
use http_body_util::combinators::BoxBody;
use http_body_util::{BodyExt, Full};
use http_utils::error::ApiError;
use serde::Serialize;
use url::Url;
use uuid::Uuid;
use super::conn_pool::AuthData;
use super::conn_pool::ConnInfoWithAuth;
use super::conn_pool_lib::ConnInfo;
use super::error::{ConnInfoError, Credentials};
use crate::auth::backend::ComputeUserInfo;
use crate::config::AuthenticationConfig;
use crate::context::RequestContext;
use crate::metrics::{Metrics, SniGroup, SniKind};
use crate::pqproto::StartupMessageParams;
use crate::proxy::NeonOptions;
use crate::types::{DbName, EndpointId, RoleName};
// Common header names used across serverless modules
pub(super) static NEON_REQUEST_ID: HeaderName = HeaderName::from_static("neon-request-id");
pub(super) static CONN_STRING: HeaderName = HeaderName::from_static("neon-connection-string");
pub(super) static RAW_TEXT_OUTPUT: HeaderName = HeaderName::from_static("neon-raw-text-output");
pub(super) static ARRAY_MODE: HeaderName = HeaderName::from_static("neon-array-mode");
pub(super) static ALLOW_POOL: HeaderName = HeaderName::from_static("neon-pool-opt-in");
pub(super) static TXN_ISOLATION_LEVEL: HeaderName =
HeaderName::from_static("neon-batch-isolation-level");
pub(super) static TXN_READ_ONLY: HeaderName = HeaderName::from_static("neon-batch-read-only");
pub(super) static TXN_DEFERRABLE: HeaderName = HeaderName::from_static("neon-batch-deferrable");
pub(crate) fn uuid_to_header_value(id: Uuid) -> HeaderValue {
let mut uuid = [0; uuid::fmt::Hyphenated::LENGTH];
HeaderValue::from_str(id.as_hyphenated().encode_lower(&mut uuid[..]))
.expect("uuid hyphenated format should be all valid header characters")
}
/// Like [`ApiError::into_response`]
pub(crate) fn api_error_into_response(this: ApiError) -> Response<BoxBody<Bytes, hyper::Error>> {
@@ -107,3 +139,136 @@ pub(crate) fn json_response<T: Serialize>(
.map_err(|e| ApiError::InternalServerError(e.into()))?;
Ok(response)
}
pub(crate) fn get_conn_info(
config: &'static AuthenticationConfig,
ctx: &RequestContext,
connection_string: Option<&str>,
headers: &HeaderMap,
) -> Result<ConnInfoWithAuth, ConnInfoError> {
let connection_url = match connection_string {
Some(connection_string) => Url::parse(connection_string)?,
None => {
let connection_string = headers
.get(&CONN_STRING)
.ok_or(ConnInfoError::InvalidHeader(&CONN_STRING))?
.to_str()
.map_err(|_| ConnInfoError::InvalidHeader(&CONN_STRING))?;
Url::parse(connection_string)?
}
};
let protocol = connection_url.scheme();
if protocol != "postgres" && protocol != "postgresql" {
return Err(ConnInfoError::IncorrectScheme);
}
let mut url_path = connection_url
.path_segments()
.ok_or(ConnInfoError::MissingDbName)?;
let dbname: DbName =
urlencoding::decode(url_path.next().ok_or(ConnInfoError::InvalidDbName)?)?.into();
ctx.set_dbname(dbname.clone());
let username = RoleName::from(urlencoding::decode(connection_url.username())?);
if username.is_empty() {
return Err(ConnInfoError::MissingUsername);
}
ctx.set_user(username.clone());
// TODO: make sure this is right in the context of rest broker
let auth = if let Some(auth) = headers.get(&AUTHORIZATION) {
if !config.accept_jwts {
return Err(ConnInfoError::MissingCredentials(Credentials::Password));
}
let auth = auth
.to_str()
.map_err(|_| ConnInfoError::InvalidHeader(&AUTHORIZATION))?;
AuthData::Jwt(
auth.strip_prefix("Bearer ")
.ok_or(ConnInfoError::MissingCredentials(Credentials::BearerJwt))?
.into(),
)
} else if let Some(pass) = connection_url.password() {
// wrong credentials provided
if config.accept_jwts {
return Err(ConnInfoError::MissingCredentials(Credentials::BearerJwt));
}
AuthData::Password(match urlencoding::decode_binary(pass.as_bytes()) {
std::borrow::Cow::Borrowed(b) => b.into(),
std::borrow::Cow::Owned(b) => b.into(),
})
} else if config.accept_jwts {
return Err(ConnInfoError::MissingCredentials(Credentials::BearerJwt));
} else {
return Err(ConnInfoError::MissingCredentials(Credentials::Password));
};
let endpoint: EndpointId = match connection_url.host() {
Some(url::Host::Domain(hostname)) => hostname
.split_once('.')
.map_or(hostname, |(prefix, _)| prefix)
.into(),
Some(url::Host::Ipv4(_) | url::Host::Ipv6(_)) | None => {
return Err(ConnInfoError::MissingHostname);
}
};
ctx.set_endpoint_id(endpoint.clone());
let pairs = connection_url.query_pairs();
let mut options = Option::None;
let mut params = StartupMessageParams::default();
params.insert("user", &username);
params.insert("database", &dbname);
for (key, value) in pairs {
params.insert(&key, &value);
if key == "options" {
options = Some(NeonOptions::parse_options_raw(&value));
}
}
// check the URL that was used, for metrics
{
let host_endpoint = headers
// get the host header
.get("host")
// extract the domain
.and_then(|h| {
let (host, _port) = h.to_str().ok()?.split_once(':')?;
Some(host)
})
// get the endpoint prefix
.map(|h| h.split_once('.').map_or(h, |(prefix, _)| prefix));
let kind = if host_endpoint == Some(&*endpoint) {
SniKind::Sni
} else {
SniKind::NoSni
};
let protocol = ctx.protocol();
Metrics::get()
.proxy
.accepted_connections_by_sni
.inc(SniGroup { protocol, kind });
}
ctx.set_user_agent(
headers
.get(hyper::header::USER_AGENT)
.and_then(|h| h.to_str().ok())
.map(Into::into),
);
let user_info = ComputeUserInfo {
endpoint,
user: username,
options: options.unwrap_or_default(),
};
let conn_info = ConnInfo { user_info, dbname };
Ok(ConnInfoWithAuth { conn_info, auth })
}

View File

@@ -29,13 +29,13 @@ use futures::future::{Either, select};
use http::{Method, Response, StatusCode};
use http_body_util::combinators::BoxBody;
use http_body_util::{BodyExt, Empty};
use http_util::{NEON_REQUEST_ID, uuid_to_header_value};
use http_utils::error::ApiError;
use hyper::body::Incoming;
use hyper_util::rt::TokioExecutor;
use hyper_util::server::conn::auto::Builder;
use rand::SeedableRng;
use rand::rngs::StdRng;
use sql_over_http::{NEON_REQUEST_ID, uuid_to_header_value};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::{TcpListener, TcpStream};
use tokio::time::timeout;

View File

@@ -1,49 +1,45 @@
use std::pin::pin;
use std::sync::Arc;
use bytes::Bytes;
use futures::future::{Either, select, try_join};
use futures::{StreamExt, TryFutureExt};
use http::Method;
use http::header::AUTHORIZATION;
use http_body_util::combinators::BoxBody;
use http_body_util::{BodyExt, Full};
use http::{Method, header::AUTHORIZATION};
use http_body_util::{BodyExt, Full, combinators::BoxBody};
use http_utils::error::ApiError;
use hyper::body::Incoming;
use hyper::http::{HeaderName, HeaderValue};
use hyper::{HeaderMap, Request, Response, StatusCode, header};
use hyper::{
Request, Response, StatusCode, header,
http::{HeaderName, HeaderValue},
};
use indexmap::IndexMap;
use postgres_client::error::{DbError, ErrorPosition, SqlState};
use postgres_client::{
GenericClient, IsolationLevel, NoTls, ReadyForQueryStatus, RowStream, Transaction,
};
use serde::Serialize;
use serde_json::Value;
use serde_json::value::RawValue;
use serde_json::{Value, value::RawValue};
use std::pin::pin;
use std::sync::Arc;
use tokio::time::{self, Instant};
use tokio_util::sync::CancellationToken;
use tracing::{Level, debug, error, info};
use typed_json::json;
use url::Url;
use uuid::Uuid;
use super::backend::{LocalProxyConnError, PoolingBackend};
use super::conn_pool::{AuthData, ConnInfoWithAuth};
use super::conn_pool::AuthData;
use super::conn_pool_lib::{self, ConnInfo};
use super::error::HttpCodeError;
use super::http_util::json_response;
use super::error::{ConnInfoError, HttpCodeError, ReadPayloadError};
use super::http_util::{
ALLOW_POOL, ARRAY_MODE, CONN_STRING, NEON_REQUEST_ID, RAW_TEXT_OUTPUT, TXN_DEFERRABLE,
TXN_ISOLATION_LEVEL, TXN_READ_ONLY, get_conn_info, json_response, uuid_to_header_value,
};
use super::json::{JsonConversionError, json_to_pg_text, pg_text_row_to_json};
use crate::auth::ComputeUserInfoParseError;
use crate::auth::backend::{ComputeCredentialKeys, ComputeUserInfo};
use crate::config::{AuthenticationConfig, HttpConfig, ProxyConfig};
use crate::auth::backend::ComputeCredentialKeys;
use crate::config::{HttpConfig, ProxyConfig};
use crate::context::RequestContext;
use crate::error::{ErrorKind, ReportableError, UserFacingError};
use crate::http::{ReadBodyError, read_body_with_limit};
use crate::metrics::{HttpDirection, Metrics, SniGroup, SniKind};
use crate::pqproto::StartupMessageParams;
use crate::proxy::NeonOptions;
use crate::http::read_body_with_limit;
use crate::metrics::{HttpDirection, Metrics};
use crate::serverless::backend::HttpConnError;
use crate::types::{DbName, EndpointId, RoleName};
use crate::usage_metrics::{MetricCounter, MetricCounterRecorder};
use crate::util::run_until_cancelled;
@@ -70,16 +66,6 @@ enum Payload {
Batch(BatchQueryData),
}
pub(super) static NEON_REQUEST_ID: HeaderName = HeaderName::from_static("neon-request-id");
static CONN_STRING: HeaderName = HeaderName::from_static("neon-connection-string");
static RAW_TEXT_OUTPUT: HeaderName = HeaderName::from_static("neon-raw-text-output");
static ARRAY_MODE: HeaderName = HeaderName::from_static("neon-array-mode");
static ALLOW_POOL: HeaderName = HeaderName::from_static("neon-pool-opt-in");
static TXN_ISOLATION_LEVEL: HeaderName = HeaderName::from_static("neon-batch-isolation-level");
static TXN_READ_ONLY: HeaderName = HeaderName::from_static("neon-batch-read-only");
static TXN_DEFERRABLE: HeaderName = HeaderName::from_static("neon-batch-deferrable");
static HEADER_VALUE_TRUE: HeaderValue = HeaderValue::from_static("true");
fn bytes_to_pg_text<'de, D>(deserializer: D) -> Result<Vec<Option<String>>, D::Error>
@@ -91,179 +77,6 @@ where
Ok(json_to_pg_text(json))
}
#[derive(Debug, thiserror::Error)]
pub(crate) enum ConnInfoError {
#[error("invalid header: {0}")]
InvalidHeader(&'static HeaderName),
#[error("invalid connection string: {0}")]
UrlParseError(#[from] url::ParseError),
#[error("incorrect scheme")]
IncorrectScheme,
#[error("missing database name")]
MissingDbName,
#[error("invalid database name")]
InvalidDbName,
#[error("missing username")]
MissingUsername,
#[error("invalid username: {0}")]
InvalidUsername(#[from] std::string::FromUtf8Error),
#[error("missing authentication credentials: {0}")]
MissingCredentials(Credentials),
#[error("missing hostname")]
MissingHostname,
#[error("invalid hostname: {0}")]
InvalidEndpoint(#[from] ComputeUserInfoParseError),
}
#[derive(Debug, thiserror::Error)]
pub(crate) enum Credentials {
#[error("required password")]
Password,
#[error("required authorization bearer token in JWT format")]
BearerJwt,
}
impl ReportableError for ConnInfoError {
fn get_error_kind(&self) -> ErrorKind {
ErrorKind::User
}
}
impl UserFacingError for ConnInfoError {
fn to_string_client(&self) -> String {
self.to_string()
}
}
fn get_conn_info(
config: &'static AuthenticationConfig,
ctx: &RequestContext,
headers: &HeaderMap,
) -> Result<ConnInfoWithAuth, ConnInfoError> {
let connection_string = headers
.get(&CONN_STRING)
.ok_or(ConnInfoError::InvalidHeader(&CONN_STRING))?
.to_str()
.map_err(|_| ConnInfoError::InvalidHeader(&CONN_STRING))?;
let connection_url = Url::parse(connection_string)?;
let protocol = connection_url.scheme();
if protocol != "postgres" && protocol != "postgresql" {
return Err(ConnInfoError::IncorrectScheme);
}
let mut url_path = connection_url
.path_segments()
.ok_or(ConnInfoError::MissingDbName)?;
let dbname: DbName =
urlencoding::decode(url_path.next().ok_or(ConnInfoError::InvalidDbName)?)?.into();
ctx.set_dbname(dbname.clone());
let username = RoleName::from(urlencoding::decode(connection_url.username())?);
if username.is_empty() {
return Err(ConnInfoError::MissingUsername);
}
ctx.set_user(username.clone());
let auth = if let Some(auth) = headers.get(&AUTHORIZATION) {
if !config.accept_jwts {
return Err(ConnInfoError::MissingCredentials(Credentials::Password));
}
let auth = auth
.to_str()
.map_err(|_| ConnInfoError::InvalidHeader(&AUTHORIZATION))?;
AuthData::Jwt(
auth.strip_prefix("Bearer ")
.ok_or(ConnInfoError::MissingCredentials(Credentials::BearerJwt))?
.into(),
)
} else if let Some(pass) = connection_url.password() {
// wrong credentials provided
if config.accept_jwts {
return Err(ConnInfoError::MissingCredentials(Credentials::BearerJwt));
}
AuthData::Password(match urlencoding::decode_binary(pass.as_bytes()) {
std::borrow::Cow::Borrowed(b) => b.into(),
std::borrow::Cow::Owned(b) => b.into(),
})
} else if config.accept_jwts {
return Err(ConnInfoError::MissingCredentials(Credentials::BearerJwt));
} else {
return Err(ConnInfoError::MissingCredentials(Credentials::Password));
};
let endpoint: EndpointId = match connection_url.host() {
Some(url::Host::Domain(hostname)) => hostname
.split_once('.')
.map_or(hostname, |(prefix, _)| prefix)
.into(),
Some(url::Host::Ipv4(_) | url::Host::Ipv6(_)) | None => {
return Err(ConnInfoError::MissingHostname);
}
};
ctx.set_endpoint_id(endpoint.clone());
let pairs = connection_url.query_pairs();
let mut options = Option::None;
let mut params = StartupMessageParams::default();
params.insert("user", &username);
params.insert("database", &dbname);
for (key, value) in pairs {
params.insert(&key, &value);
if key == "options" {
options = Some(NeonOptions::parse_options_raw(&value));
}
}
// check the URL that was used, for metrics
{
let host_endpoint = headers
// get the host header
.get("host")
// extract the domain
.and_then(|h| {
let (host, _port) = h.to_str().ok()?.split_once(':')?;
Some(host)
})
// get the endpoint prefix
.map(|h| h.split_once('.').map_or(h, |(prefix, _)| prefix));
let kind = if host_endpoint == Some(&*endpoint) {
SniKind::Sni
} else {
SniKind::NoSni
};
let protocol = ctx.protocol();
Metrics::get()
.proxy
.accepted_connections_by_sni
.inc(SniGroup { protocol, kind });
}
ctx.set_user_agent(
headers
.get(hyper::header::USER_AGENT)
.and_then(|h| h.to_str().ok())
.map(Into::into),
);
let user_info = ComputeUserInfo {
endpoint,
user: username,
options: options.unwrap_or_default(),
};
let conn_info = ConnInfo { user_info, dbname };
Ok(ConnInfoWithAuth { conn_info, auth })
}
pub(crate) async fn handle(
config: &'static ProxyConfig,
ctx: RequestContext,
@@ -532,45 +345,6 @@ impl HttpCodeError for SqlOverHttpError {
}
}
#[derive(Debug, thiserror::Error)]
pub(crate) enum ReadPayloadError {
#[error("could not read the HTTP request body: {0}")]
Read(#[from] hyper::Error),
#[error("request is too large (max is {limit} bytes)")]
BodyTooLarge { limit: usize },
#[error("could not parse the HTTP request body: {0}")]
Parse(#[from] serde_json::Error),
}
impl From<ReadBodyError<hyper::Error>> for ReadPayloadError {
fn from(value: ReadBodyError<hyper::Error>) -> Self {
match value {
ReadBodyError::BodyTooLarge { limit } => Self::BodyTooLarge { limit },
ReadBodyError::Read(e) => Self::Read(e),
}
}
}
impl ReportableError for ReadPayloadError {
fn get_error_kind(&self) -> ErrorKind {
match self {
ReadPayloadError::Read(_) => ErrorKind::ClientDisconnect,
ReadPayloadError::BodyTooLarge { .. } => ErrorKind::User,
ReadPayloadError::Parse(_) => ErrorKind::User,
}
}
}
impl HttpCodeError for ReadPayloadError {
fn get_http_status_code(&self) -> StatusCode {
match self {
ReadPayloadError::Read(_) => StatusCode::BAD_REQUEST,
ReadPayloadError::BodyTooLarge { .. } => StatusCode::PAYLOAD_TOO_LARGE,
ReadPayloadError::Parse(_) => StatusCode::BAD_REQUEST,
}
}
}
#[derive(Debug, thiserror::Error)]
pub(crate) enum SqlOverHttpCancel {
#[error("query was cancelled")]
@@ -661,7 +435,7 @@ async fn handle_inner(
"handling interactive connection from client"
);
let conn_info = get_conn_info(&config.authentication_config, ctx, request.headers())?;
let conn_info = get_conn_info(&config.authentication_config, ctx, None, request.headers())?;
info!(
user = conn_info.conn_info.user_info.user.as_str(),
"credentials"
@@ -747,9 +521,17 @@ async fn handle_db_inner(
ComputeCredentialKeys::JwtPayload(payload)
if backend.auth_backend.is_local_proxy() =>
{
let mut client = backend.connect_to_local_postgres(ctx, conn_info).await?;
let (cli_inner, _dsc) = client.client_inner();
cli_inner.set_jwt_session(&payload).await?;
#[cfg(feature = "testing")]
let disable_pg_session_jwt = config.disable_pg_session_jwt;
#[cfg(not(feature = "testing"))]
let disable_pg_session_jwt = false;
let mut client = backend
.connect_to_local_postgres(ctx, conn_info, disable_pg_session_jwt)
.await?;
if !disable_pg_session_jwt {
let (cli_inner, _dsc) = client.client_inner();
cli_inner.set_jwt_session(&payload).await?;
}
Client::Local(client)
}
_ => {
@@ -848,12 +630,6 @@ static HEADERS_TO_FORWARD: &[&HeaderName] = &[
&TXN_DEFERRABLE,
];
pub(crate) fn uuid_to_header_value(id: Uuid) -> HeaderValue {
let mut uuid = [0; uuid::fmt::Hyphenated::LENGTH];
HeaderValue::from_str(id.as_hyphenated().encode_lower(&mut uuid[..]))
.expect("uuid hyphenated format should be all valid header characters")
}
async fn handle_auth_broker_inner(
ctx: &RequestContext,
request: Request<Incoming>,
@@ -883,7 +659,7 @@ async fn handle_auth_broker_inner(
req = req.header(&NEON_REQUEST_ID, uuid_to_header_value(ctx.session_id()));
let req = req
.body(body)
.body(body.map_err(|e| e).boxed()) //TODO: is there a potential for a regression here?
.expect("all headers and params received via hyper should be valid for request");
// todo: map body to count egress