diff --git a/proxy/README.md b/proxy/README.md
index 583db36f28..e10ff3d710 100644
--- a/proxy/README.md
+++ b/proxy/README.md
@@ -138,3 +138,62 @@ Now from client you can start a new session:
 ```sh
 PGSSLROOTCERT=./server.crt psql "postgresql://proxy:password@endpoint.local.neon.build:4432/postgres?sslmode=verify-full"
 ```
+
+## Auth broker setup
+
+Create a Postgres instance:
+```sh
+docker run \
+  --detach \
+  --name proxy-postgres \
+  --env POSTGRES_HOST_AUTH_METHOD=trust \
+  --env POSTGRES_USER=authenticated \
+  --env POSTGRES_DB=database \
+  --publish 5432:5432 \
+  postgres:17-bookworm
+```
+
+Create a configuration file called `local_proxy.json` in the root of the repo (it is also used by the auth broker to validate JWTs):
+```json
+{
+  "jwks": [
+    {
+      "id": "1",
+      "role_names": ["authenticator", "authenticated", "anon"],
+      "jwks_url": "https://climbing-minnow-11.clerk.accounts.dev/.well-known/jwks.json",
+      "provider_name": "foo",
+      "jwt_audience": null
+    }
+  ]
+}
+```
+
+Start the local proxy:
+```sh
+cargo run --bin local_proxy -- \
+  --disable_pg_session_jwt true \
+  --http 0.0.0.0:7432
+```
+
+Start the auth broker:
+```sh
+LOGFMT=text OTEL_SDK_DISABLED=true cargo run --bin proxy --features testing -- \
+  -c server.crt -k server.key \
+  --is-auth-broker true \
+  --wss 0.0.0.0:8080 \
+  --http 0.0.0.0:7002 \
+  --auth-backend local
+```
+
+Create a JWT in your auth provider (e.g., Clerk) and set it in the `NEON_JWT` environment variable:
+```sh
+export NEON_JWT="..."
+```
+
+Run a query against the auth broker:
+```sh
+curl -k "https://foo.local.neon.build:8080/sql" \
+  -H "Authorization: Bearer $NEON_JWT" \
+  -H "neon-connection-string: postgresql://authenticator@foo.local.neon.build/database" \
+  -d '{"query":"select 1","params":[]}'
+```
diff --git a/proxy/src/auth/backend/mod.rs b/proxy/src/auth/backend/mod.rs
index 2e3013ead0..8fc3ea1978 100644
--- a/proxy/src/auth/backend/mod.rs
+++ b/proxy/src/auth/backend/mod.rs
@@ -171,7 +171,6 @@ impl ComputeUserInfo {
 pub(crate) enum ComputeCredentialKeys {
     AuthKeys(AuthKeys),
     JwtPayload(Vec),
-    None,
 }
 
 impl TryFrom for ComputeUserInfo {
diff --git a/proxy/src/binary/local_proxy.rs b/proxy/src/binary/local_proxy.rs
index 423ecf821e..04cc7b3907 100644
--- a/proxy/src/binary/local_proxy.rs
+++ b/proxy/src/binary/local_proxy.rs
@@ -1,43 +1,39 @@
 use std::net::SocketAddr;
 use std::pin::pin;
-use std::str::FromStr;
 use std::sync::Arc;
 use std::time::Duration;
 
-use anyhow::{Context, bail, ensure};
+use anyhow::bail;
 use arc_swap::ArcSwapOption;
-use camino::{Utf8Path, Utf8PathBuf};
+use camino::Utf8PathBuf;
 use clap::Parser;
-use compute_api::spec::LocalProxySpec;
+
 use futures::future::Either;
-use thiserror::Error;
+
 use tokio::net::TcpListener;
 use tokio::sync::Notify;
 use tokio::task::JoinSet;
 use tokio_util::sync::CancellationToken;
-use tracing::{debug, error, info, warn};
+use tracing::{debug, error, info};
 use utils::sentry_init::init_sentry;
 use utils::{pid_file, project_build_tag, project_git_version};
 
 use crate::auth::backend::jwt::JwkCache;
-use crate::auth::backend::local::{JWKS_ROLE_MAP, LocalBackend};
+use crate::auth::backend::local::LocalBackend;
 use crate::auth::{self};
 use crate::cancellation::CancellationHandler;
+use crate::config::refresh_config_loop;
 use crate::config::{
     self, AuthenticationConfig, ComputeConfig, HttpConfig, ProxyConfig, RetryConfig,
 };
 use crate::control_plane::locks::ApiLocks;
-use crate::control_plane::messages::{EndpointJwksResponse, JwksSettings};
-use crate::ext::TaskExt;
 use
crate::http::health_server::AppMetrics; -use crate::intern::RoleNameInt; use crate::metrics::{Metrics, ThreadPoolMetrics}; use crate::rate_limiter::{EndpointRateLimiter, LeakyBucketConfig, RateBucketInfo}; use crate::scram::threadpool::ThreadPool; use crate::serverless::cancel_set::CancelSet; use crate::serverless::{self, GlobalConnPoolOptions}; use crate::tls::client_config::compute_client_config_with_root_certs; -use crate::types::RoleName; use crate::url::ApiUrl; project_git_version!(GIT_VERSION); @@ -82,6 +78,11 @@ struct LocalProxyCliArgs { /// Path of the local proxy PID file #[clap(long, default_value = "./local_proxy.pid")] pid_path: Utf8PathBuf, + /// Disable pg_session_jwt extension installation + /// This is useful for testing the local proxy with vanilla postgres. + #[clap(long, default_value = "false")] + #[cfg(feature = "testing")] + disable_pg_session_jwt: bool, } #[derive(clap::Args, Clone, Copy, Debug)] @@ -282,6 +283,8 @@ fn build_config(args: &LocalProxyCliArgs) -> anyhow::Result<&'static ProxyConfig wake_compute_retry_config: RetryConfig::parse(RetryConfig::WAKE_COMPUTE_DEFAULT_VALUES)?, connect_compute_locks, connect_to_compute: compute_config, + #[cfg(feature = "testing")] + disable_pg_session_jwt: args.disable_pg_session_jwt, }))) } @@ -293,132 +296,3 @@ fn build_auth_backend(args: &LocalProxyCliArgs) -> &'static auth::Backend<'stati Box::leak(Box::new(auth_backend)) } - -#[derive(Error, Debug)] -enum RefreshConfigError { - #[error(transparent)] - Read(#[from] std::io::Error), - #[error(transparent)] - Parse(#[from] serde_json::Error), - #[error(transparent)] - Validate(anyhow::Error), - #[error(transparent)] - Tls(anyhow::Error), -} - -async fn refresh_config_loop(config: &ProxyConfig, path: Utf8PathBuf, rx: Arc) { - let mut init = true; - loop { - rx.notified().await; - - match refresh_config_inner(config, &path).await { - Ok(()) => {} - // don't log for file not found errors if this is the first time we are checking - // for computes that don't use local_proxy, this is not an error. - Err(RefreshConfigError::Read(e)) - if init && e.kind() == std::io::ErrorKind::NotFound => - { - debug!(error=?e, ?path, "could not read config file"); - } - Err(RefreshConfigError::Tls(e)) => { - error!(error=?e, ?path, "could not read TLS certificates"); - } - Err(e) => { - error!(error=?e, ?path, "could not read config file"); - } - } - - init = false; - } -} - -async fn refresh_config_inner( - config: &ProxyConfig, - path: &Utf8Path, -) -> Result<(), RefreshConfigError> { - let bytes = tokio::fs::read(&path).await?; - let data: LocalProxySpec = serde_json::from_slice(&bytes)?; - - let mut jwks_set = vec![]; - - fn parse_jwks_settings(jwks: compute_api::spec::JwksSettings) -> anyhow::Result { - let mut jwks_url = url::Url::from_str(&jwks.jwks_url).context("parsing JWKS url")?; - - ensure!( - jwks_url.has_authority() - && (jwks_url.scheme() == "http" || jwks_url.scheme() == "https"), - "Invalid JWKS url. Must be HTTP", - ); - - ensure!( - jwks_url.host().is_some_and(|h| h != url::Host::Domain("")), - "Invalid JWKS url. No domain listed", - ); - - // clear username, password and ports - jwks_url - .set_username("") - .expect("url can be a base and has a valid host and is not a file. should not error"); - jwks_url - .set_password(None) - .expect("url can be a base and has a valid host and is not a file. 
should not error"); - // local testing is hard if we need to have a specific restricted port - if cfg!(not(feature = "testing")) { - jwks_url.set_port(None).expect( - "url can be a base and has a valid host and is not a file. should not error", - ); - } - - // clear query params - jwks_url.set_fragment(None); - jwks_url.query_pairs_mut().clear().finish(); - - if jwks_url.scheme() != "https" { - // local testing is hard if we need to set up https support. - if cfg!(not(feature = "testing")) { - jwks_url - .set_scheme("https") - .expect("should not error to set the scheme to https if it was http"); - } else { - warn!(scheme = jwks_url.scheme(), "JWKS url is not HTTPS"); - } - } - - Ok(JwksSettings { - id: jwks.id, - jwks_url, - _provider_name: jwks.provider_name, - jwt_audience: jwks.jwt_audience, - role_names: jwks - .role_names - .into_iter() - .map(RoleName::from) - .map(|s| RoleNameInt::from(&s)) - .collect(), - }) - } - - for jwks in data.jwks.into_iter().flatten() { - jwks_set.push(parse_jwks_settings(jwks).map_err(RefreshConfigError::Validate)?); - } - - info!("successfully loaded new config"); - JWKS_ROLE_MAP.store(Some(Arc::new(EndpointJwksResponse { jwks: jwks_set }))); - - if let Some(tls_config) = data.tls { - let tls_config = tokio::task::spawn_blocking(move || { - crate::tls::server_config::configure_tls( - tls_config.key_path.as_ref(), - tls_config.cert_path.as_ref(), - None, - false, - ) - }) - .await - .propagate_task_panic() - .map_err(RefreshConfigError::Tls)?; - config.tls_config.store(Some(Arc::new(tls_config))); - } - - Ok(()) -} diff --git a/proxy/src/binary/proxy.rs b/proxy/src/binary/proxy.rs index 2133f33a4d..7522dd5162 100644 --- a/proxy/src/binary/proxy.rs +++ b/proxy/src/binary/proxy.rs @@ -22,9 +22,13 @@ use utils::sentry_init::init_sentry; use utils::{project_build_tag, project_git_version}; use crate::auth::backend::jwt::JwkCache; +#[cfg(any(test, feature = "testing"))] +use crate::auth::backend::local::LocalBackend; use crate::auth::backend::{ConsoleRedirectBackend, MaybeOwned}; use crate::batch::BatchQueue; use crate::cancellation::{CancellationHandler, CancellationProcessor}; +#[cfg(any(test, feature = "testing"))] +use crate::config::refresh_config_loop; use crate::config::{ self, AuthenticationConfig, CacheOptions, ComputeConfig, HttpConfig, ProjectInfoCacheOptions, ProxyConfig, ProxyProtocolV2, remote_storage_from_toml, @@ -43,6 +47,10 @@ use crate::tls::client_config::compute_client_config_with_root_certs; #[cfg(any(test, feature = "testing"))] use crate::url::ApiUrl; use crate::{auth, control_plane, http, serverless, usage_metrics}; +#[cfg(any(test, feature = "testing"))] +use camino::Utf8PathBuf; +#[cfg(any(test, feature = "testing"))] +use tokio::sync::Notify; project_git_version!(GIT_VERSION); project_build_tag!(BUILD_TAG); @@ -60,6 +68,9 @@ enum AuthBackendType { #[cfg(any(test, feature = "testing"))] Postgres, + + #[cfg(any(test, feature = "testing"))] + Local, } /// Neon proxy/router @@ -74,6 +85,10 @@ struct ProxyCliArgs { proxy: SocketAddr, #[clap(value_enum, long, default_value_t = AuthBackendType::ConsoleRedirect)] auth_backend: AuthBackendType, + /// Path of the local proxy config file (used for local-file auth backend) + #[clap(long, default_value = "./local_proxy.json")] + #[cfg(any(test, feature = "testing"))] + config_path: Utf8PathBuf, /// listen for management callback connection on ip:port #[clap(short, long, default_value = "127.0.0.1:7000")] mgmt: SocketAddr, @@ -226,6 +241,14 @@ struct ProxyCliArgs { #[clap(flatten)] pg_sni_router: 
PgSniRouterArgs, + + /// if this is not local proxy, this toggles whether we accept Postgres REST requests + #[clap(long, default_value_t = false, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)] + is_rest_broker: bool, + + /// cache for `db_schema_cache` introspection (use `size=0` to disable) + #[clap(long, default_value = "size=1000,ttl=1h")] + db_schema_cache: String, } #[derive(clap::Args, Clone, Copy, Debug)] @@ -386,6 +409,8 @@ pub async fn run() -> anyhow::Result<()> { 64, )); + #[cfg(any(test, feature = "testing"))] + let refresh_config_notify = Arc::new(Notify::new()); // client facing tasks. these will exit on error or on cancellation // cancellation returns Ok(()) let mut client_tasks = JoinSet::new(); @@ -412,6 +437,17 @@ pub async fn run() -> anyhow::Result<()> { endpoint_rate_limiter.clone(), )); } + + // if auth backend is local, we need to load the config file + #[cfg(any(test, feature = "testing"))] + if let auth::Backend::Local(_) = &auth_backend { + refresh_config_notify.notify_one(); + tokio::spawn(refresh_config_loop( + config, + args.config_path, + refresh_config_notify.clone(), + )); + } } Either::Right(auth_backend) => { if let Some(proxy_listener) = proxy_listener { @@ -462,7 +498,13 @@ pub async fn run() -> anyhow::Result<()> { // maintenance tasks. these never return unless there's an error let mut maintenance_tasks = JoinSet::new(); - maintenance_tasks.spawn(crate::signals::handle(cancellation_token.clone(), || {})); + + maintenance_tasks.spawn(crate::signals::handle(cancellation_token.clone(), { + move || { + #[cfg(any(test, feature = "testing"))] + refresh_config_notify.notify_one(); + } + })); maintenance_tasks.spawn(http::health_server::task_main( http_listener, AppMetrics { @@ -653,6 +695,8 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> { wake_compute_retry_config: config::RetryConfig::parse(&args.wake_compute_retry)?, connect_compute_locks, connect_to_compute: compute_config, + #[cfg(feature = "testing")] + disable_pg_session_jwt: false, }; let config = Box::leak(Box::new(config)); @@ -806,6 +850,19 @@ fn build_auth_backend( Ok(Either::Right(config)) } + + #[cfg(any(test, feature = "testing"))] + AuthBackendType::Local => { + let postgres: SocketAddr = "127.0.0.1:7432".parse()?; + let compute_ctl: ApiUrl = "http://127.0.0.1:3081/".parse()?; + let auth_backend = crate::auth::Backend::Local( + crate::auth::backend::MaybeOwned::Owned(LocalBackend::new(postgres, compute_ctl)), + ); + + let config = Box::leak(Box::new(auth_backend)); + + Ok(Either::Left(config)) + } } } diff --git a/proxy/src/compute/mod.rs b/proxy/src/compute/mod.rs index 0a19090ce0..7b9183b05e 100644 --- a/proxy/src/compute/mod.rs +++ b/proxy/src/compute/mod.rs @@ -165,7 +165,7 @@ impl AuthInfo { ComputeCredentialKeys::AuthKeys(AuthKeys::ScramSha256(auth_keys)) => { Some(Auth::Scram(Box::new(auth_keys))) } - ComputeCredentialKeys::JwtPayload(_) | ComputeCredentialKeys::None => None, + ComputeCredentialKeys::JwtPayload(_) => None, }, server_params: StartupMessageParams::default(), skip_db_user: false, diff --git a/proxy/src/config.rs b/proxy/src/config.rs index cee15ac7fa..d5e6e1e4cb 100644 --- a/proxy/src/config.rs +++ b/proxy/src/config.rs @@ -16,6 +16,17 @@ use crate::serverless::cancel_set::CancelSet; pub use crate::tls::server_config::{TlsConfig, configure_tls}; use crate::types::Host; +use crate::auth::backend::local::JWKS_ROLE_MAP; +use crate::control_plane::messages::{EndpointJwksResponse, JwksSettings}; +use 
crate::ext::TaskExt; +use crate::intern::RoleNameInt; +use crate::types::RoleName; +use camino::{Utf8Path, Utf8PathBuf}; +use compute_api::spec::LocalProxySpec; +use thiserror::Error; +use tokio::sync::Notify; +use tracing::{debug, error, info, warn}; + pub struct ProxyConfig { pub tls_config: ArcSwapOption, pub metric_collection: Option, @@ -26,6 +37,8 @@ pub struct ProxyConfig { pub wake_compute_retry_config: RetryConfig, pub connect_compute_locks: ApiLocks, pub connect_to_compute: ComputeConfig, + #[cfg(feature = "testing")] + pub disable_pg_session_jwt: bool, } pub struct ComputeConfig { @@ -409,6 +422,135 @@ impl FromStr for ConcurrencyLockOptions { } } +#[derive(Error, Debug)] +pub(crate) enum RefreshConfigError { + #[error(transparent)] + Read(#[from] std::io::Error), + #[error(transparent)] + Parse(#[from] serde_json::Error), + #[error(transparent)] + Validate(anyhow::Error), + #[error(transparent)] + Tls(anyhow::Error), +} + +pub(crate) async fn refresh_config_loop(config: &ProxyConfig, path: Utf8PathBuf, rx: Arc) { + let mut init = true; + loop { + rx.notified().await; + + match refresh_config_inner(config, &path).await { + std::result::Result::Ok(()) => {} + // don't log for file not found errors if this is the first time we are checking + // for computes that don't use local_proxy, this is not an error. + Err(RefreshConfigError::Read(e)) + if init && e.kind() == std::io::ErrorKind::NotFound => + { + debug!(error=?e, ?path, "could not read config file"); + } + Err(RefreshConfigError::Tls(e)) => { + error!(error=?e, ?path, "could not read TLS certificates"); + } + Err(e) => { + error!(error=?e, ?path, "could not read config file"); + } + } + + init = false; + } +} + +pub(crate) async fn refresh_config_inner( + config: &ProxyConfig, + path: &Utf8Path, +) -> Result<(), RefreshConfigError> { + let bytes = tokio::fs::read(&path).await?; + let data: LocalProxySpec = serde_json::from_slice(&bytes)?; + + let mut jwks_set = vec![]; + + fn parse_jwks_settings(jwks: compute_api::spec::JwksSettings) -> anyhow::Result { + let mut jwks_url = url::Url::from_str(&jwks.jwks_url).context("parsing JWKS url")?; + + ensure!( + jwks_url.has_authority() + && (jwks_url.scheme() == "http" || jwks_url.scheme() == "https"), + "Invalid JWKS url. Must be HTTP", + ); + + ensure!( + jwks_url.host().is_some_and(|h| h != url::Host::Domain("")), + "Invalid JWKS url. No domain listed", + ); + + // clear username, password and ports + jwks_url + .set_username("") + .expect("url can be a base and has a valid host and is not a file. should not error"); + jwks_url + .set_password(None) + .expect("url can be a base and has a valid host and is not a file. should not error"); + // local testing is hard if we need to have a specific restricted port + if cfg!(not(feature = "testing")) { + jwks_url.set_port(None).expect( + "url can be a base and has a valid host and is not a file. should not error", + ); + } + + // clear query params + jwks_url.set_fragment(None); + jwks_url.query_pairs_mut().clear().finish(); + + if jwks_url.scheme() != "https" { + // local testing is hard if we need to set up https support. 
+ if cfg!(not(feature = "testing")) { + jwks_url + .set_scheme("https") + .expect("should not error to set the scheme to https if it was http"); + } else { + warn!(scheme = jwks_url.scheme(), "JWKS url is not HTTPS"); + } + } + + Ok(JwksSettings { + id: jwks.id, + jwks_url, + _provider_name: jwks.provider_name, + jwt_audience: jwks.jwt_audience, + role_names: jwks + .role_names + .into_iter() + .map(RoleName::from) + .map(|s| RoleNameInt::from(&s)) + .collect(), + }) + } + + for jwks in data.jwks.into_iter().flatten() { + jwks_set.push(parse_jwks_settings(jwks).map_err(RefreshConfigError::Validate)?); + } + + info!("successfully loaded new config"); + JWKS_ROLE_MAP.store(Some(Arc::new(EndpointJwksResponse { jwks: jwks_set }))); + + if let Some(tls_config) = data.tls { + let tls_config = tokio::task::spawn_blocking(move || { + crate::tls::server_config::configure_tls( + tls_config.key_path.as_ref(), + tls_config.cert_path.as_ref(), + None, + false, + ) + }) + .await + .propagate_task_panic() + .map_err(RefreshConfigError::Tls)?; + config.tls_config.store(Some(Arc::new(tls_config))); + } + + std::result::Result::Ok(()) +} + #[cfg(test)] mod tests { use super::*; diff --git a/proxy/src/serverless/backend.rs b/proxy/src/serverless/backend.rs index 7708342ae3..4b3f379e76 100644 --- a/proxy/src/serverless/backend.rs +++ b/proxy/src/serverless/backend.rs @@ -115,7 +115,8 @@ impl PoolingBackend { match &self.auth_backend { crate::auth::Backend::ControlPlane(console, ()) => { - self.config + let keys = self + .config .authentication_config .jwks_cache .check_jwt( @@ -129,7 +130,7 @@ impl PoolingBackend { Ok(ComputeCredentials { info: user_info.clone(), - keys: crate::auth::backend::ComputeCredentialKeys::None, + keys, }) } crate::auth::Backend::Local(_) => { @@ -256,6 +257,7 @@ impl PoolingBackend { &self, ctx: &RequestContext, conn_info: ConnInfo, + disable_pg_session_jwt: bool, ) -> Result, HttpConnError> { if let Some(client) = self.local_pool.get(ctx, &conn_info)? 
{ return Ok(client); @@ -277,7 +279,7 @@ impl PoolingBackend { .expect("semaphore should never be closed"); // check again for race - if !self.local_pool.initialized(&conn_info) { + if !self.local_pool.initialized(&conn_info) && !disable_pg_session_jwt { local_backend .compute_ctl .install_extension(&ExtensionInstallRequest { @@ -313,14 +315,16 @@ impl PoolingBackend { .to_postgres_client_config(); config .user(&conn_info.user_info.user) - .dbname(&conn_info.dbname) - .set_param( + .dbname(&conn_info.dbname); + if !disable_pg_session_jwt { + config.set_param( "options", &format!( "-c pg_session_jwt.jwk={}", serde_json::to_string(&jwk).expect("serializing jwk to json should not fail") ), ); + } let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute); let (client, connection) = config.connect(&postgres_client::NoTls).await?; @@ -345,9 +349,11 @@ impl PoolingBackend { debug!("setting up backend session state"); // initiates the auth session - if let Err(e) = client.batch_execute("select auth.init();").await { - discard.discard(); - return Err(e.into()); + if !disable_pg_session_jwt { + if let Err(e) = client.batch_execute("select auth.init();").await { + discard.discard(); + return Err(e.into()); + } } info!("backend session state initialized"); diff --git a/proxy/src/serverless/error.rs b/proxy/src/serverless/error.rs index 323c91baa5..786964e764 100644 --- a/proxy/src/serverless/error.rs +++ b/proxy/src/serverless/error.rs @@ -1,5 +1,93 @@ use http::StatusCode; +use http::header::HeaderName; + +use crate::auth::ComputeUserInfoParseError; +use crate::error::{ErrorKind, ReportableError, UserFacingError}; +use crate::http::ReadBodyError; pub trait HttpCodeError { fn get_http_status_code(&self) -> StatusCode; } + +#[derive(Debug, thiserror::Error)] +pub(crate) enum ConnInfoError { + #[error("invalid header: {0}")] + InvalidHeader(&'static HeaderName), + #[error("invalid connection string: {0}")] + UrlParseError(#[from] url::ParseError), + #[error("incorrect scheme")] + IncorrectScheme, + #[error("missing database name")] + MissingDbName, + #[error("invalid database name")] + InvalidDbName, + #[error("missing username")] + MissingUsername, + #[error("invalid username: {0}")] + InvalidUsername(#[from] std::string::FromUtf8Error), + #[error("missing authentication credentials: {0}")] + MissingCredentials(Credentials), + #[error("missing hostname")] + MissingHostname, + #[error("invalid hostname: {0}")] + InvalidEndpoint(#[from] ComputeUserInfoParseError), +} + +#[derive(Debug, thiserror::Error)] +pub(crate) enum Credentials { + #[error("required password")] + Password, + #[error("required authorization bearer token in JWT format")] + BearerJwt, +} + +impl ReportableError for ConnInfoError { + fn get_error_kind(&self) -> ErrorKind { + ErrorKind::User + } +} + +impl UserFacingError for ConnInfoError { + fn to_string_client(&self) -> String { + self.to_string() + } +} + +#[derive(Debug, thiserror::Error)] +pub(crate) enum ReadPayloadError { + #[error("could not read the HTTP request body: {0}")] + Read(#[from] hyper::Error), + #[error("request is too large (max is {limit} bytes)")] + BodyTooLarge { limit: usize }, + #[error("could not parse the HTTP request body: {0}")] + Parse(#[from] serde_json::Error), +} + +impl From> for ReadPayloadError { + fn from(value: ReadBodyError) -> Self { + match value { + ReadBodyError::BodyTooLarge { limit } => Self::BodyTooLarge { limit }, + ReadBodyError::Read(e) => Self::Read(e), + } + } +} + +impl ReportableError for ReadPayloadError { + fn 
get_error_kind(&self) -> ErrorKind { + match self { + ReadPayloadError::Read(_) => ErrorKind::ClientDisconnect, + ReadPayloadError::BodyTooLarge { .. } => ErrorKind::User, + ReadPayloadError::Parse(_) => ErrorKind::User, + } + } +} + +impl HttpCodeError for ReadPayloadError { + fn get_http_status_code(&self) -> StatusCode { + match self { + ReadPayloadError::Read(_) => StatusCode::BAD_REQUEST, + ReadPayloadError::BodyTooLarge { .. } => StatusCode::PAYLOAD_TOO_LARGE, + ReadPayloadError::Parse(_) => StatusCode::BAD_REQUEST, + } + } +} diff --git a/proxy/src/serverless/http_conn_pool.rs b/proxy/src/serverless/http_conn_pool.rs index 1c6574e57e..18f7ecc0b1 100644 --- a/proxy/src/serverless/http_conn_pool.rs +++ b/proxy/src/serverless/http_conn_pool.rs @@ -20,9 +20,12 @@ use crate::metrics::{HttpEndpointPoolsGuard, Metrics}; use crate::protocol2::ConnectionInfoExtra; use crate::types::EndpointCacheKey; use crate::usage_metrics::{Ids, MetricCounter, USAGE_METRICS}; +use bytes::Bytes; +use http_body_util::combinators::BoxBody; -pub(crate) type Send = http2::SendRequest; -pub(crate) type Connect = http2::Connection, hyper::body::Incoming, TokioExecutor>; +pub(crate) type Send = http2::SendRequest>; +pub(crate) type Connect = + http2::Connection, BoxBody, TokioExecutor>; #[derive(Clone)] pub(crate) struct ClientDataHttp(); diff --git a/proxy/src/serverless/http_util.rs b/proxy/src/serverless/http_util.rs index 95a28663a5..c876d8f096 100644 --- a/proxy/src/serverless/http_util.rs +++ b/proxy/src/serverless/http_util.rs @@ -3,11 +3,43 @@ use anyhow::Context; use bytes::Bytes; -use http::{Response, StatusCode}; +use http::header::AUTHORIZATION; +use http::{HeaderMap, HeaderName, HeaderValue, Response, StatusCode}; use http_body_util::combinators::BoxBody; use http_body_util::{BodyExt, Full}; use http_utils::error::ApiError; use serde::Serialize; +use url::Url; +use uuid::Uuid; + +use super::conn_pool::AuthData; +use super::conn_pool::ConnInfoWithAuth; +use super::conn_pool_lib::ConnInfo; +use super::error::{ConnInfoError, Credentials}; +use crate::auth::backend::ComputeUserInfo; +use crate::config::AuthenticationConfig; +use crate::context::RequestContext; +use crate::metrics::{Metrics, SniGroup, SniKind}; +use crate::pqproto::StartupMessageParams; +use crate::proxy::NeonOptions; +use crate::types::{DbName, EndpointId, RoleName}; + +// Common header names used across serverless modules +pub(super) static NEON_REQUEST_ID: HeaderName = HeaderName::from_static("neon-request-id"); +pub(super) static CONN_STRING: HeaderName = HeaderName::from_static("neon-connection-string"); +pub(super) static RAW_TEXT_OUTPUT: HeaderName = HeaderName::from_static("neon-raw-text-output"); +pub(super) static ARRAY_MODE: HeaderName = HeaderName::from_static("neon-array-mode"); +pub(super) static ALLOW_POOL: HeaderName = HeaderName::from_static("neon-pool-opt-in"); +pub(super) static TXN_ISOLATION_LEVEL: HeaderName = + HeaderName::from_static("neon-batch-isolation-level"); +pub(super) static TXN_READ_ONLY: HeaderName = HeaderName::from_static("neon-batch-read-only"); +pub(super) static TXN_DEFERRABLE: HeaderName = HeaderName::from_static("neon-batch-deferrable"); + +pub(crate) fn uuid_to_header_value(id: Uuid) -> HeaderValue { + let mut uuid = [0; uuid::fmt::Hyphenated::LENGTH]; + HeaderValue::from_str(id.as_hyphenated().encode_lower(&mut uuid[..])) + .expect("uuid hyphenated format should be all valid header characters") +} /// Like [`ApiError::into_response`] pub(crate) fn api_error_into_response(this: ApiError) -> 
Response> { @@ -107,3 +139,136 @@ pub(crate) fn json_response( .map_err(|e| ApiError::InternalServerError(e.into()))?; Ok(response) } + +pub(crate) fn get_conn_info( + config: &'static AuthenticationConfig, + ctx: &RequestContext, + connection_string: Option<&str>, + headers: &HeaderMap, +) -> Result { + let connection_url = match connection_string { + Some(connection_string) => Url::parse(connection_string)?, + None => { + let connection_string = headers + .get(&CONN_STRING) + .ok_or(ConnInfoError::InvalidHeader(&CONN_STRING))? + .to_str() + .map_err(|_| ConnInfoError::InvalidHeader(&CONN_STRING))?; + Url::parse(connection_string)? + } + }; + + let protocol = connection_url.scheme(); + if protocol != "postgres" && protocol != "postgresql" { + return Err(ConnInfoError::IncorrectScheme); + } + + let mut url_path = connection_url + .path_segments() + .ok_or(ConnInfoError::MissingDbName)?; + + let dbname: DbName = + urlencoding::decode(url_path.next().ok_or(ConnInfoError::InvalidDbName)?)?.into(); + ctx.set_dbname(dbname.clone()); + + let username = RoleName::from(urlencoding::decode(connection_url.username())?); + if username.is_empty() { + return Err(ConnInfoError::MissingUsername); + } + ctx.set_user(username.clone()); + // TODO: make sure this is right in the context of rest broker + let auth = if let Some(auth) = headers.get(&AUTHORIZATION) { + if !config.accept_jwts { + return Err(ConnInfoError::MissingCredentials(Credentials::Password)); + } + + let auth = auth + .to_str() + .map_err(|_| ConnInfoError::InvalidHeader(&AUTHORIZATION))?; + AuthData::Jwt( + auth.strip_prefix("Bearer ") + .ok_or(ConnInfoError::MissingCredentials(Credentials::BearerJwt))? + .into(), + ) + } else if let Some(pass) = connection_url.password() { + // wrong credentials provided + if config.accept_jwts { + return Err(ConnInfoError::MissingCredentials(Credentials::BearerJwt)); + } + + AuthData::Password(match urlencoding::decode_binary(pass.as_bytes()) { + std::borrow::Cow::Borrowed(b) => b.into(), + std::borrow::Cow::Owned(b) => b.into(), + }) + } else if config.accept_jwts { + return Err(ConnInfoError::MissingCredentials(Credentials::BearerJwt)); + } else { + return Err(ConnInfoError::MissingCredentials(Credentials::Password)); + }; + let endpoint: EndpointId = match connection_url.host() { + Some(url::Host::Domain(hostname)) => hostname + .split_once('.') + .map_or(hostname, |(prefix, _)| prefix) + .into(), + Some(url::Host::Ipv4(_) | url::Host::Ipv6(_)) | None => { + return Err(ConnInfoError::MissingHostname); + } + }; + ctx.set_endpoint_id(endpoint.clone()); + + let pairs = connection_url.query_pairs(); + + let mut options = Option::None; + + let mut params = StartupMessageParams::default(); + params.insert("user", &username); + params.insert("database", &dbname); + for (key, value) in pairs { + params.insert(&key, &value); + if key == "options" { + options = Some(NeonOptions::parse_options_raw(&value)); + } + } + + // check the URL that was used, for metrics + { + let host_endpoint = headers + // get the host header + .get("host") + // extract the domain + .and_then(|h| { + let (host, _port) = h.to_str().ok()?.split_once(':')?; + Some(host) + }) + // get the endpoint prefix + .map(|h| h.split_once('.').map_or(h, |(prefix, _)| prefix)); + + let kind = if host_endpoint == Some(&*endpoint) { + SniKind::Sni + } else { + SniKind::NoSni + }; + + let protocol = ctx.protocol(); + Metrics::get() + .proxy + .accepted_connections_by_sni + .inc(SniGroup { protocol, kind }); + } + + ctx.set_user_agent( + headers + 
.get(hyper::header::USER_AGENT) + .and_then(|h| h.to_str().ok()) + .map(Into::into), + ); + + let user_info = ComputeUserInfo { + endpoint, + user: username, + options: options.unwrap_or_default(), + }; + + let conn_info = ConnInfo { user_info, dbname }; + Ok(ConnInfoWithAuth { conn_info, auth }) +} diff --git a/proxy/src/serverless/mod.rs b/proxy/src/serverless/mod.rs index d8942bb814..5b7289c53d 100644 --- a/proxy/src/serverless/mod.rs +++ b/proxy/src/serverless/mod.rs @@ -29,13 +29,13 @@ use futures::future::{Either, select}; use http::{Method, Response, StatusCode}; use http_body_util::combinators::BoxBody; use http_body_util::{BodyExt, Empty}; +use http_util::{NEON_REQUEST_ID, uuid_to_header_value}; use http_utils::error::ApiError; use hyper::body::Incoming; use hyper_util::rt::TokioExecutor; use hyper_util::server::conn::auto::Builder; use rand::SeedableRng; use rand::rngs::StdRng; -use sql_over_http::{NEON_REQUEST_ID, uuid_to_header_value}; use tokio::io::{AsyncRead, AsyncWrite}; use tokio::net::{TcpListener, TcpStream}; use tokio::time::timeout; diff --git a/proxy/src/serverless/sql_over_http.rs b/proxy/src/serverless/sql_over_http.rs index 5b348d59af..a901a47746 100644 --- a/proxy/src/serverless/sql_over_http.rs +++ b/proxy/src/serverless/sql_over_http.rs @@ -1,49 +1,45 @@ -use std::pin::pin; -use std::sync::Arc; - use bytes::Bytes; use futures::future::{Either, select, try_join}; use futures::{StreamExt, TryFutureExt}; -use http::Method; -use http::header::AUTHORIZATION; -use http_body_util::combinators::BoxBody; -use http_body_util::{BodyExt, Full}; +use http::{Method, header::AUTHORIZATION}; +use http_body_util::{BodyExt, Full, combinators::BoxBody}; use http_utils::error::ApiError; use hyper::body::Incoming; -use hyper::http::{HeaderName, HeaderValue}; -use hyper::{HeaderMap, Request, Response, StatusCode, header}; +use hyper::{ + Request, Response, StatusCode, header, + http::{HeaderName, HeaderValue}, +}; use indexmap::IndexMap; use postgres_client::error::{DbError, ErrorPosition, SqlState}; use postgres_client::{ GenericClient, IsolationLevel, NoTls, ReadyForQueryStatus, RowStream, Transaction, }; use serde::Serialize; -use serde_json::Value; -use serde_json::value::RawValue; +use serde_json::{Value, value::RawValue}; +use std::pin::pin; +use std::sync::Arc; use tokio::time::{self, Instant}; use tokio_util::sync::CancellationToken; use tracing::{Level, debug, error, info}; use typed_json::json; -use url::Url; -use uuid::Uuid; use super::backend::{LocalProxyConnError, PoolingBackend}; -use super::conn_pool::{AuthData, ConnInfoWithAuth}; +use super::conn_pool::AuthData; use super::conn_pool_lib::{self, ConnInfo}; -use super::error::HttpCodeError; -use super::http_util::json_response; +use super::error::{ConnInfoError, HttpCodeError, ReadPayloadError}; +use super::http_util::{ + ALLOW_POOL, ARRAY_MODE, CONN_STRING, NEON_REQUEST_ID, RAW_TEXT_OUTPUT, TXN_DEFERRABLE, + TXN_ISOLATION_LEVEL, TXN_READ_ONLY, get_conn_info, json_response, uuid_to_header_value, +}; use super::json::{JsonConversionError, json_to_pg_text, pg_text_row_to_json}; -use crate::auth::ComputeUserInfoParseError; -use crate::auth::backend::{ComputeCredentialKeys, ComputeUserInfo}; -use crate::config::{AuthenticationConfig, HttpConfig, ProxyConfig}; +use crate::auth::backend::ComputeCredentialKeys; + +use crate::config::{HttpConfig, ProxyConfig}; use crate::context::RequestContext; use crate::error::{ErrorKind, ReportableError, UserFacingError}; -use crate::http::{ReadBodyError, read_body_with_limit}; -use 
crate::metrics::{HttpDirection, Metrics, SniGroup, SniKind}; -use crate::pqproto::StartupMessageParams; -use crate::proxy::NeonOptions; +use crate::http::read_body_with_limit; +use crate::metrics::{HttpDirection, Metrics}; use crate::serverless::backend::HttpConnError; -use crate::types::{DbName, EndpointId, RoleName}; use crate::usage_metrics::{MetricCounter, MetricCounterRecorder}; use crate::util::run_until_cancelled; @@ -70,16 +66,6 @@ enum Payload { Batch(BatchQueryData), } -pub(super) static NEON_REQUEST_ID: HeaderName = HeaderName::from_static("neon-request-id"); - -static CONN_STRING: HeaderName = HeaderName::from_static("neon-connection-string"); -static RAW_TEXT_OUTPUT: HeaderName = HeaderName::from_static("neon-raw-text-output"); -static ARRAY_MODE: HeaderName = HeaderName::from_static("neon-array-mode"); -static ALLOW_POOL: HeaderName = HeaderName::from_static("neon-pool-opt-in"); -static TXN_ISOLATION_LEVEL: HeaderName = HeaderName::from_static("neon-batch-isolation-level"); -static TXN_READ_ONLY: HeaderName = HeaderName::from_static("neon-batch-read-only"); -static TXN_DEFERRABLE: HeaderName = HeaderName::from_static("neon-batch-deferrable"); - static HEADER_VALUE_TRUE: HeaderValue = HeaderValue::from_static("true"); fn bytes_to_pg_text<'de, D>(deserializer: D) -> Result>, D::Error> @@ -91,179 +77,6 @@ where Ok(json_to_pg_text(json)) } -#[derive(Debug, thiserror::Error)] -pub(crate) enum ConnInfoError { - #[error("invalid header: {0}")] - InvalidHeader(&'static HeaderName), - #[error("invalid connection string: {0}")] - UrlParseError(#[from] url::ParseError), - #[error("incorrect scheme")] - IncorrectScheme, - #[error("missing database name")] - MissingDbName, - #[error("invalid database name")] - InvalidDbName, - #[error("missing username")] - MissingUsername, - #[error("invalid username: {0}")] - InvalidUsername(#[from] std::string::FromUtf8Error), - #[error("missing authentication credentials: {0}")] - MissingCredentials(Credentials), - #[error("missing hostname")] - MissingHostname, - #[error("invalid hostname: {0}")] - InvalidEndpoint(#[from] ComputeUserInfoParseError), -} - -#[derive(Debug, thiserror::Error)] -pub(crate) enum Credentials { - #[error("required password")] - Password, - #[error("required authorization bearer token in JWT format")] - BearerJwt, -} - -impl ReportableError for ConnInfoError { - fn get_error_kind(&self) -> ErrorKind { - ErrorKind::User - } -} - -impl UserFacingError for ConnInfoError { - fn to_string_client(&self) -> String { - self.to_string() - } -} - -fn get_conn_info( - config: &'static AuthenticationConfig, - ctx: &RequestContext, - headers: &HeaderMap, -) -> Result { - let connection_string = headers - .get(&CONN_STRING) - .ok_or(ConnInfoError::InvalidHeader(&CONN_STRING))? 
- .to_str() - .map_err(|_| ConnInfoError::InvalidHeader(&CONN_STRING))?; - - let connection_url = Url::parse(connection_string)?; - - let protocol = connection_url.scheme(); - if protocol != "postgres" && protocol != "postgresql" { - return Err(ConnInfoError::IncorrectScheme); - } - - let mut url_path = connection_url - .path_segments() - .ok_or(ConnInfoError::MissingDbName)?; - - let dbname: DbName = - urlencoding::decode(url_path.next().ok_or(ConnInfoError::InvalidDbName)?)?.into(); - ctx.set_dbname(dbname.clone()); - - let username = RoleName::from(urlencoding::decode(connection_url.username())?); - if username.is_empty() { - return Err(ConnInfoError::MissingUsername); - } - ctx.set_user(username.clone()); - - let auth = if let Some(auth) = headers.get(&AUTHORIZATION) { - if !config.accept_jwts { - return Err(ConnInfoError::MissingCredentials(Credentials::Password)); - } - - let auth = auth - .to_str() - .map_err(|_| ConnInfoError::InvalidHeader(&AUTHORIZATION))?; - AuthData::Jwt( - auth.strip_prefix("Bearer ") - .ok_or(ConnInfoError::MissingCredentials(Credentials::BearerJwt))? - .into(), - ) - } else if let Some(pass) = connection_url.password() { - // wrong credentials provided - if config.accept_jwts { - return Err(ConnInfoError::MissingCredentials(Credentials::BearerJwt)); - } - - AuthData::Password(match urlencoding::decode_binary(pass.as_bytes()) { - std::borrow::Cow::Borrowed(b) => b.into(), - std::borrow::Cow::Owned(b) => b.into(), - }) - } else if config.accept_jwts { - return Err(ConnInfoError::MissingCredentials(Credentials::BearerJwt)); - } else { - return Err(ConnInfoError::MissingCredentials(Credentials::Password)); - }; - - let endpoint: EndpointId = match connection_url.host() { - Some(url::Host::Domain(hostname)) => hostname - .split_once('.') - .map_or(hostname, |(prefix, _)| prefix) - .into(), - Some(url::Host::Ipv4(_) | url::Host::Ipv6(_)) | None => { - return Err(ConnInfoError::MissingHostname); - } - }; - ctx.set_endpoint_id(endpoint.clone()); - - let pairs = connection_url.query_pairs(); - - let mut options = Option::None; - - let mut params = StartupMessageParams::default(); - params.insert("user", &username); - params.insert("database", &dbname); - for (key, value) in pairs { - params.insert(&key, &value); - if key == "options" { - options = Some(NeonOptions::parse_options_raw(&value)); - } - } - - // check the URL that was used, for metrics - { - let host_endpoint = headers - // get the host header - .get("host") - // extract the domain - .and_then(|h| { - let (host, _port) = h.to_str().ok()?.split_once(':')?; - Some(host) - }) - // get the endpoint prefix - .map(|h| h.split_once('.').map_or(h, |(prefix, _)| prefix)); - - let kind = if host_endpoint == Some(&*endpoint) { - SniKind::Sni - } else { - SniKind::NoSni - }; - - let protocol = ctx.protocol(); - Metrics::get() - .proxy - .accepted_connections_by_sni - .inc(SniGroup { protocol, kind }); - } - - ctx.set_user_agent( - headers - .get(hyper::header::USER_AGENT) - .and_then(|h| h.to_str().ok()) - .map(Into::into), - ); - - let user_info = ComputeUserInfo { - endpoint, - user: username, - options: options.unwrap_or_default(), - }; - - let conn_info = ConnInfo { user_info, dbname }; - Ok(ConnInfoWithAuth { conn_info, auth }) -} - pub(crate) async fn handle( config: &'static ProxyConfig, ctx: RequestContext, @@ -532,45 +345,6 @@ impl HttpCodeError for SqlOverHttpError { } } -#[derive(Debug, thiserror::Error)] -pub(crate) enum ReadPayloadError { - #[error("could not read the HTTP request body: {0}")] - 
Read(#[from] hyper::Error), - #[error("request is too large (max is {limit} bytes)")] - BodyTooLarge { limit: usize }, - #[error("could not parse the HTTP request body: {0}")] - Parse(#[from] serde_json::Error), -} - -impl From> for ReadPayloadError { - fn from(value: ReadBodyError) -> Self { - match value { - ReadBodyError::BodyTooLarge { limit } => Self::BodyTooLarge { limit }, - ReadBodyError::Read(e) => Self::Read(e), - } - } -} - -impl ReportableError for ReadPayloadError { - fn get_error_kind(&self) -> ErrorKind { - match self { - ReadPayloadError::Read(_) => ErrorKind::ClientDisconnect, - ReadPayloadError::BodyTooLarge { .. } => ErrorKind::User, - ReadPayloadError::Parse(_) => ErrorKind::User, - } - } -} - -impl HttpCodeError for ReadPayloadError { - fn get_http_status_code(&self) -> StatusCode { - match self { - ReadPayloadError::Read(_) => StatusCode::BAD_REQUEST, - ReadPayloadError::BodyTooLarge { .. } => StatusCode::PAYLOAD_TOO_LARGE, - ReadPayloadError::Parse(_) => StatusCode::BAD_REQUEST, - } - } -} - #[derive(Debug, thiserror::Error)] pub(crate) enum SqlOverHttpCancel { #[error("query was cancelled")] @@ -661,7 +435,7 @@ async fn handle_inner( "handling interactive connection from client" ); - let conn_info = get_conn_info(&config.authentication_config, ctx, request.headers())?; + let conn_info = get_conn_info(&config.authentication_config, ctx, None, request.headers())?; info!( user = conn_info.conn_info.user_info.user.as_str(), "credentials" @@ -747,9 +521,17 @@ async fn handle_db_inner( ComputeCredentialKeys::JwtPayload(payload) if backend.auth_backend.is_local_proxy() => { - let mut client = backend.connect_to_local_postgres(ctx, conn_info).await?; - let (cli_inner, _dsc) = client.client_inner(); - cli_inner.set_jwt_session(&payload).await?; + #[cfg(feature = "testing")] + let disable_pg_session_jwt = config.disable_pg_session_jwt; + #[cfg(not(feature = "testing"))] + let disable_pg_session_jwt = false; + let mut client = backend + .connect_to_local_postgres(ctx, conn_info, disable_pg_session_jwt) + .await?; + if !disable_pg_session_jwt { + let (cli_inner, _dsc) = client.client_inner(); + cli_inner.set_jwt_session(&payload).await?; + } Client::Local(client) } _ => { @@ -848,12 +630,6 @@ static HEADERS_TO_FORWARD: &[&HeaderName] = &[ &TXN_DEFERRABLE, ]; -pub(crate) fn uuid_to_header_value(id: Uuid) -> HeaderValue { - let mut uuid = [0; uuid::fmt::Hyphenated::LENGTH]; - HeaderValue::from_str(id.as_hyphenated().encode_lower(&mut uuid[..])) - .expect("uuid hyphenated format should be all valid header characters") -} - async fn handle_auth_broker_inner( ctx: &RequestContext, request: Request, @@ -883,7 +659,7 @@ async fn handle_auth_broker_inner( req = req.header(&NEON_REQUEST_ID, uuid_to_header_value(ctx.session_id())); let req = req - .body(body) + .body(body.map_err(|e| e).boxed()) //TODO: is there a potential for a regression here? .expect("all headers and params received via hyper should be valid for request"); // todo: map body to count egress
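
The final hunk above ends on a `// todo: map body to count egress` note where the auth broker forwards the upstream response body to the client. Purely as an illustrative sketch (not part of this diff, and not the proxy's actual usage-metrics API), one way to approach that TODO is an `http_body::Body` wrapper that passes frames through unchanged while adding each data frame's length to a counter. The `CountedBody` name and the bare `AtomicU64` are assumptions for the sketch; a real change would feed the count into the proxy's existing egress metrics instead.

```rust
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::task::{Context, Poll};

use bytes::Buf;
use http_body::{Body, Frame};

/// Hypothetical wrapper: forwards frames from the inner body unchanged while
/// accumulating the number of data bytes that pass through it.
struct CountedBody<B> {
    inner: B,
    egress_bytes: Arc<AtomicU64>,
}

impl<B: Body + Unpin> Body for CountedBody<B> {
    type Data = B::Data;
    type Error = B::Error;

    fn poll_frame(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
        let poll = Pin::new(&mut self.inner).poll_frame(cx);
        // Count only data frames; trailer frames carry no egress payload.
        if let Poll::Ready(Some(Ok(frame))) = &poll {
            if let Some(data) = frame.data_ref() {
                self.egress_bytes
                    .fetch_add(data.remaining() as u64, Ordering::Relaxed);
            }
        }
        poll
    }
}
```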