diff --git a/Cargo.lock b/Cargo.lock
index 5f71af118c..cf16fd95c0 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -5273,6 +5273,7 @@ dependencies = [
  "tokio-rustls 0.26.2",
  "tokio-tungstenite 0.21.0",
  "tokio-util",
+ "toml",
  "tracing",
  "tracing-log",
  "tracing-opentelemetry",
diff --git a/proxy/Cargo.toml b/proxy/Cargo.toml
index ce8610be24..c9d504f8bb 100644
--- a/proxy/Cargo.toml
+++ b/proxy/Cargo.toml
@@ -89,6 +89,7 @@ tokio-postgres = { workspace = true, optional = true }
 tokio-rustls.workspace = true
 tokio-util.workspace = true
 tokio = { workspace = true, features = ["signal"] }
+toml.workspace = true
 tracing-subscriber.workspace = true
 tracing-utils.workspace = true
 tracing.workspace = true
diff --git a/proxy/src/binary/proxy.rs b/proxy/src/binary/proxy.rs
index 6c5b098a3e..70030839df 100644
--- a/proxy/src/binary/proxy.rs
+++ b/proxy/src/binary/proxy.rs
@@ -8,10 +8,11 @@ use std::time::Duration;
 #[cfg(any(test, feature = "testing"))]
 use anyhow::Context;
-use anyhow::{bail, ensure};
+use anyhow::{anyhow, bail};
 use arc_swap::ArcSwapOption;
 use futures::future::Either;
 use remote_storage::RemoteStorageConfig;
+use serde::Deserialize;
 use tokio::net::TcpListener;
 use tokio::task::JoinSet;
 use tokio_util::sync::CancellationToken;
@@ -39,7 +40,7 @@ use crate::serverless::cancel_set::CancelSet;
 use crate::tls::client_config::compute_client_config_with_root_certs;
 #[cfg(any(test, feature = "testing"))]
 use crate::url::ApiUrl;
-use crate::{auth, control_plane, http, serverless, usage_metrics};
+use crate::{auth, control_plane, http, pglb, serverless, usage_metrics};
 
 project_git_version!(GIT_VERSION);
 project_build_tag!(BUILD_TAG);
@@ -59,6 +60,262 @@ enum AuthBackendType {
     Postgres,
 }
 
+#[derive(Deserialize)]
+struct Root {
+    #[serde(flatten)]
+    legacy: LegacyModes,
+    introspection: Introspection,
+}
+
+#[derive(Deserialize)]
+#[serde(untagged)]
+enum LegacyModes {
+    Proxy {
+        pglb: Pglb,
+        neonkeeper: NeonKeeper,
+        http: Option<Http>,
+        pg_sni_router: Option<PgSniRouter>,
+    },
+    AuthBroker {
+        neonkeeper: NeonKeeper,
+        http: Http,
+    },
+    ConsoleRedirect {
+        console_redirect: ConsoleRedirect,
+    },
+}
+
+#[derive(Deserialize)]
+struct Pglb {
+    listener: Listener,
+}
+
+#[derive(Deserialize)]
+struct Listener {
+    /// address to bind to
+    addr: SocketAddr,
+    /// which header should we expect to see on this socket
+    /// from the load balancer
+    header: Option<ProxyHeader>,
+
+    /// certificates used for TLS.
+    /// first cert is the default.
+    /// TLS not used if no certs provided.
+    certs: Vec<KeyPair>,
+
+    /// Timeout to use for TLS handshake
+    timeout: Option<Duration>,
+}
+
+#[derive(Deserialize)]
+enum ProxyHeader {
+    /// Accept the PROXY protocol V2.
+    ProxyProtocolV2(ProxyProtocolV2Kind),
+}
+
+#[derive(Deserialize)]
+enum ProxyProtocolV2Kind {
+    /// Expect AWS TLVs in the header.
+    Aws,
+    /// Expect Azure TLVs in the header.
+    Azure,
+}
+
+#[derive(Deserialize)]
+struct KeyPair {
+    key: PathBuf,
+    cert: PathBuf,
+}
+
+/// The service that authenticates all incoming connection attempts,
+/// provides monitoring and also wakes computes.
+#[derive(Deserialize)]
+struct NeonKeeper {
+    cplane: ControlPlaneBackend,
+    redis: Option<Redis>,
+    auth: Vec<AuthMechanism>,
+
+    /// map of endpoint -> compute info
+    compute: Cache,
+    /// cache for GetEndpointAccessControls.
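+    /// Invalidated by the Redis notifications task spawned below.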
+    project_info_cache: config::ProjectInfoCacheOptions,
+    /// cache for all valid endpoints
+    endpoint_cache_config: config::EndpointCacheConfig,
+
+    request_log_export: Option<RequestLogExport>,
+    data_transfer_export: Option<DataTransferExport>,
+}
+
+#[derive(Deserialize)]
+struct Redis {
+    /// Cancellation channel size (max queue size for redis kv client)
+    cancellation_ch_size: usize,
+    /// Cancellation ops batch size for redis
+    cancellation_batch_size: usize,
+
+    auth: RedisAuthentication,
+}
+
+#[derive(Deserialize)]
+enum RedisAuthentication {
+    /// IRSA: IAM Roles for Service Accounts.
+    Irsa {
+        host: String,
+        port: u16,
+        cluster_name: Option<String>,
+        user_id: Option<String>,
+        aws_region: String,
+    },
+    Basic {
+        url: url::Url,
+    },
+}
+
+#[derive(Deserialize)]
+struct PgSniRouter {
+    /// The listener to use to proxy connections to compute,
+    /// assuming the compute does not support TLS.
+    listener: Listener,
+
+    /// The listener to use to proxy connections to compute,
+    /// assuming the compute requires TLS.
+    listener_tls: Listener,
+
+    /// append this domain zone to the SNI hostname to get the destination address
+    dest: String,
+}
+
+/// `psql -h pg.neon.tech`.
+#[derive(Deserialize)]
+struct ConsoleRedirect {
+    /// Connection requests from clients.
+    listener: Listener,
+    /// Messages from control plane to accept the connection.
+    cplane: Listener,
+
+    /// The base url to use for redirects.
+    console: url::Url,
+
+    timeout: Duration,
+}
+
+#[derive(Deserialize)]
+enum ControlPlaneBackend {
+    /// Use the HTTP API to access the control plane.
+    Http(url::Url),
+    /// Stub the control plane with a postgres instance.
+    #[cfg(feature = "testing")]
+    PostgresMock(url::Url),
+}
+
+#[derive(Deserialize)]
+struct Http {
+    listener: Listener,
+    sql_over_http: SqlOverHttp,
+
+    // todo: move into Pglb.
+    websockets: Option<Websockets>,
+}
+
+#[derive(Deserialize)]
+struct SqlOverHttp {
+    pool_max_conns_per_endpoint: usize,
+    pool_max_total_conns: usize,
+    pool_idle_timeout: Duration,
+    pool_gc_epoch: Duration,
+    pool_shards: usize,
+
+    client_conn_threshold: u64,
+    cancel_set_shards: usize,
+
+    timeout: Duration,
+    max_request_size_bytes: usize,
+    max_response_size_bytes: usize,
+
+    auth: Vec<AuthMechanism>,
+}
+
+#[derive(Deserialize)]
+enum AuthMechanism {
+    Sasl {
+        /// timeout for SASL handshake
+        timeout: Duration,
+    },
+    CleartextPassword {
+        /// number of threads for the thread pool
+        threads: usize,
+    },
+    // TODO: add JWKS cache configuration.
+    Jwt {},
+}
+
+#[derive(Deserialize)]
+struct Websockets {
+    auth: Vec<AuthMechanism>,
+}
+
+/// The HTTP API used for internal monitoring.
+#[derive(Deserialize)]
+struct Introspection {
+    listener: Listener,
+}
+
+#[derive(Deserialize)]
+enum RequestLogExport {
+    Parquet {
+        location: RemoteStorageConfig,
+        disconnect: RemoteStorageConfig,
+
+        /// The region identifier to tag the entries with.
+        region: String,
+
+        /// How many rows to include in a row group
+        row_group_size: usize,
+
+        /// How large each column page should be in bytes
+        page_size: usize,
+
+        /// How large the total parquet file should be in bytes
+        size: i64,
+
+        /// How long to wait before forcing a file upload
+        maximum_duration: tokio::time::Duration,
+        // /// What level of compression to use
+        // compression: Compression,
+    },
+}
+
+#[derive(Deserialize)]
+enum Cache {
+    /// Expire by LRU or by idle.
+    /// Note: "live" in "time-to-live" actually means idle here.
+    LruTtl {
+        /// Max number of entries.
+        size: usize,
+        /// Entry's time-to-live.
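+        /// (reset on access, per the note above — effectively an idle timeout.)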
+        ttl: Duration,
+    },
+}
+
+#[derive(Deserialize)]
+struct DataTransferExport {
+    /// http endpoint to receive periodic metric updates
+    endpoint: Option<url::Url>,
+    /// how often metrics should be sent to a collection endpoint
+    interval: Option<Duration>,
+
+    /// interval for backup metric collection
+    backup_interval: std::time::Duration,
+    /// remote storage configuration for backup metric collection
+    /// Encoded as toml (same format as pageservers), eg
+    /// `{bucket_name='the-bucket',bucket_region='us-east-1',prefix_in_bucket='proxy',endpoint='http://minio:9000'}`
+    backup_remote_storage: Option<RemoteStorageConfig>,
+    /// chunk size for backup metric collection
+    /// Size of each event is no more than 400 bytes, so 2**22 events is about 1.6GB before compression.
+    backup_chunk_size: usize,
+}
+
 /// Neon proxy/router
 #[derive(Parser)]
 #[command(version = GIT_VERSION, about)]
@@ -311,180 +568,113 @@ pub async fn run() -> anyhow::Result<()> {
         }
     };
 
-    let args = ProxyCliArgs::parse();
-    let config = build_config(&args)?;
-    let auth_backend = build_auth_backend(&args)?;
-
-    match auth_backend {
-        Either::Left(auth_backend) => info!("Authentication backend: {auth_backend}"),
-        Either::Right(auth_backend) => info!("Authentication backend: {auth_backend:?}"),
-    }
-    info!("Using region: {}", args.aws_region);
-    let redis_client = configure_redis(&args).await?;
-
-    // Check that we can bind to address before further initialization
-    info!("Starting http on {}", args.http);
-    let http_listener = TcpListener::bind(args.http).await?.into_std()?;
-
-    info!("Starting mgmt on {}", args.mgmt);
-    let mgmt_listener = TcpListener::bind(args.mgmt).await?;
-
-    let proxy_listener = if args.is_auth_broker {
-        None
-    } else {
-        info!("Starting proxy on {}", args.proxy);
-        Some(TcpListener::bind(args.proxy).await?)
-    };
-
-    let sni_router_listeners = {
-        let args = &args.pg_sni_router;
-        if args.dest.is_some() {
-            ensure!(
-                args.tls_key.is_some(),
-                "sni-router-tls-key must be provided"
-            );
-            ensure!(
-                args.tls_cert.is_some(),
-                "sni-router-tls-cert must be provided"
-            );
-
-            info!(
-                "Starting pg-sni-router on {} and {}",
-                args.listen, args.listen_tls
-            );
-
-            Some((
-                TcpListener::bind(args.listen).await?,
-                TcpListener::bind(args.listen_tls).await?,
-            ))
-        } else {
-            None
-        }
-    };
-
-    // TODO: rename the argument to something like serverless.
-    // It now covers more than just websockets, it also covers SQL over HTTP.
-    let serverless_listener = if let Some(serverless_address) = args.wss {
-        info!("Starting wss on {serverless_address}");
-        Some(TcpListener::bind(serverless_address).await?)
-    } else if args.is_auth_broker {
-        bail!("wss arg must be present for auth-broker")
-    } else {
-        None
-    };
-
-    let cancellation_token = CancellationToken::new();
-
-    // channel size should be higher than redis client limit to avoid blocking
-    let cancel_ch_size = args.cancellation_ch_size;
-    let (tx_cancel, rx_cancel) = tokio::sync::mpsc::channel(cancel_ch_size);
-    let cancellation_handler = Arc::new(CancellationHandler::new(
-        &config.connect_to_compute,
-        Some(tx_cancel),
-    ));
-
-    let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new_with_shards(
-        RateBucketInfo::to_leaky_bucket(&args.endpoint_rps_limit)
-            .unwrap_or(EndpointRateLimiter::DEFAULT),
-        64,
-    ));
+    let config: Root = toml::from_str(&tokio::fs::read_to_string("proxy.toml").await?)?;
 
     // client facing tasks. these will exit on error or on cancellation
     // cancellation returns Ok(())
     let mut client_tasks = JoinSet::new();
 
-    match auth_backend {
-        Either::Left(auth_backend) => {
-            if let Some(proxy_listener) = proxy_listener {
-                client_tasks.spawn(crate::proxy::task_main(
-                    config,
-                    auth_backend,
-                    proxy_listener,
-                    cancellation_token.clone(),
-                    cancellation_handler.clone(),
-                    endpoint_rate_limiter.clone(),
-                ));
-            }
-
-            if let Some(serverless_listener) = serverless_listener {
-                client_tasks.spawn(serverless::task_main(
-                    config,
-                    auth_backend,
-                    serverless_listener,
-                    cancellation_token.clone(),
-                    cancellation_handler.clone(),
-                    endpoint_rate_limiter.clone(),
-                ));
-            }
-        }
-        Either::Right(auth_backend) => {
-            if let Some(proxy_listener) = proxy_listener {
-                client_tasks.spawn(crate::console_redirect_proxy::task_main(
-                    config,
-                    auth_backend,
-                    proxy_listener,
-                    cancellation_token.clone(),
-                    cancellation_handler.clone(),
-                ));
-            }
-        }
-    }
-
-    // spawn pg-sni-router mode.
-    if let Some((listen, listen_tls)) = sni_router_listeners {
-        let args = args.pg_sni_router;
-        let dest = args.dest.expect("already asserted it is set");
-        let key_path = args.tls_key.expect("already asserted it is set");
-        let cert_path = args.tls_cert.expect("already asserted it is set");
-
-        let tls_config = super::pg_sni_router::parse_tls(&key_path, &cert_path)?;
-
-        let dest = Arc::new(dest);
-
-        client_tasks.spawn(super::pg_sni_router::task_main(
-            dest.clone(),
-            tls_config.clone(),
-            None,
-            listen,
-            cancellation_token.clone(),
-        ));
-
-        client_tasks.spawn(super::pg_sni_router::task_main(
-            dest,
-            tls_config,
-            Some(config.connect_to_compute.tls.clone()),
-            listen_tls,
-            cancellation_token.clone(),
-        ));
-    }
-
-    client_tasks.spawn(crate::context::parquet::worker(
-        cancellation_token.clone(),
-        args.parquet_upload,
-        args.region,
-    ));
 
     // maintenance tasks. these never return unless there's an error
     let mut maintenance_tasks = JoinSet::new();
 
-    maintenance_tasks.spawn(crate::signals::handle(cancellation_token.clone(), || {}));
-    maintenance_tasks.spawn(http::health_server::task_main(
-        http_listener,
-        AppMetrics {
-            jemalloc,
-            neon_metrics,
-            proxy: crate::metrics::Metrics::get(),
-        },
-    ));
-    maintenance_tasks.spawn(control_plane::mgmt::task_main(mgmt_listener));
-
-    if let Some(metrics_config) = &config.metric_collection {
-        // TODO: Add gc regardles of the metric collection being enabled.
-        maintenance_tasks.spawn(usage_metrics::task_main(metrics_config));
-    }
+    let cancellation_token = CancellationToken::new();
 
-    #[cfg_attr(not(any(test, feature = "testing")), expect(irrefutable_let_patterns))]
-    if let Either::Left(auth::Backend::ControlPlane(api, ())) = &auth_backend {
-        if let crate::control_plane::client::ControlPlaneClient::ProxyV1(api) = &**api {
-            if let Some(client) = redis_client {
+    match config.legacy {
+        LegacyModes::Proxy {
+            pglb,
+            neonkeeper,
+            http,
+            pg_sni_router,
+        } => {
+            let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new_with_shards(
+                // todo: use neonkeeper config.
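+                // (the removed CLI path derived this from --endpoint-rps-limit
+                // via RateBucketInfo::to_leaky_bucket.)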
+                EndpointRateLimiter::DEFAULT,
+                64,
+            ));
+
+            info!("Starting proxy on {}", pglb.listener.addr);
+            let proxy_listener = TcpListener::bind(pglb.listener.addr).await?;
+
+            client_tasks.spawn(crate::proxy::task_main(
+                config,
+                auth_backend,
+                proxy_listener,
+                cancellation_token.clone(),
+                cancellation_handler.clone(),
+                endpoint_rate_limiter.clone(),
+            ));
+
+            if let Some(http) = http {
+                info!("Starting wss on {}", http.listener.addr);
+                let http_listener = TcpListener::bind(http.listener.addr).await?;
+
+                client_tasks.spawn(serverless::task_main(
+                    config,
+                    auth_backend,
+                    http_listener,
+                    cancellation_token.clone(),
+                    cancellation_handler.clone(),
+                    endpoint_rate_limiter.clone(),
+                ));
+            }
+
+            if let Some(redis) = neonkeeper.redis {
+                let client = configure_redis(redis.auth).await;
+            }
+
+            if let Some(sni_router) = pg_sni_router {
+                let listen = TcpListener::bind(sni_router.listener.addr).await?;
+                let listen_tls = TcpListener::bind(sni_router.listener_tls.addr).await?;
+
+                let [KeyPair { key, cert }] = sni_router
+                    .listener
+                    .certs
+                    .try_into()
+                    .map_err(|_| anyhow!("only 1 keypair is supported for pg-sni-router"))?;
+
+                let tls_config = super::pg_sni_router::parse_tls(&key, &cert)?;
+
+                let dest = Arc::new(sni_router.dest);
+
+                client_tasks.spawn(super::pg_sni_router::task_main(
+                    dest.clone(),
+                    tls_config.clone(),
+                    None,
+                    listen,
+                    cancellation_token.clone(),
+                ));
+
+                client_tasks.spawn(super::pg_sni_router::task_main(
+                    dest,
+                    tls_config,
+                    Some(config.connect_to_compute.tls.clone()),
+                    listen_tls,
+                    cancellation_token.clone(),
+                ));
+            }
+
+            if let Some(RequestLogExport::Parquet {
+                location,
+                disconnect,
+                region,
+                row_group_size,
+                page_size,
+                size,
+                maximum_duration,
+            }) = neonkeeper.request_log_export
+            {
+                // todo: build the upload config from the fields destructured above.
+                client_tasks.spawn(crate::context::parquet::worker(
+                    cancellation_token.clone(),
+                    args.parquet_upload,
+                    region,
+                ));
+            }
+
+            if let (ControlPlaneBackend::Http(api), Some(redis)) =
+                (neonkeeper.cplane, neonkeeper.redis)
+            {
                 // project info cache and invalidation of that cache.
                 let cache = api.caches.project_info.clone();
                 maintenance_tasks.spawn(notifications::task_main(client.clone(), cache.clone()));
@@ -518,6 +708,129 @@ pub async fn run() -> anyhow::Result<()> {
         );
     }
 }
+        LegacyModes::AuthBroker { neonkeeper, http } => {
+            let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new_with_shards(
+                // todo: use neonkeeper config.
+                EndpointRateLimiter::DEFAULT,
+                64,
+            ));
+
+            info!("Starting wss on {}", http.listener.addr);
+            let http_listener = TcpListener::bind(http.listener.addr).await?;
+
+            if let Some(redis) = neonkeeper.redis {
+                let client = configure_redis(redis.auth).await;
+            }
+
+            client_tasks.spawn(serverless::task_main(
+                config,
+                auth_backend,
+                http_listener,
+                cancellation_token.clone(),
+                cancellation_handler.clone(),
+                endpoint_rate_limiter.clone(),
+            ));
+
+            if let Some(RequestLogExport::Parquet {
+                location,
+                disconnect,
+                region,
+                row_group_size,
+                page_size,
+                size,
+                maximum_duration,
+            }) = neonkeeper.request_log_export
+            {
+                // todo: build the upload config from the fields destructured above.
+                client_tasks.spawn(crate::context::parquet::worker(
+                    cancellation_token.clone(),
+                    args.parquet_upload,
+                    region,
+                ));
+            }
+
+            if let (ControlPlaneBackend::Http(api), Some(redis)) =
+                (neonkeeper.cplane, neonkeeper.redis)
+            {
+                // project info cache and invalidation of that cache.
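+                // invalidation events arrive via the notifications task; a gc
+                // task below evicts expired entries.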
+                let cache = api.caches.project_info.clone();
+                maintenance_tasks.spawn(notifications::task_main(client.clone(), cache.clone()));
+                maintenance_tasks.spawn(async move { cache.clone().gc_worker().await });
+
+                // cancellation key management
+                let mut redis_kv_client = RedisKVClient::new(client.clone());
+                maintenance_tasks.spawn(async move {
+                    redis_kv_client.try_connect().await?;
+                    handle_cancel_messages(
+                        &mut redis_kv_client,
+                        rx_cancel,
+                        args.cancellation_batch_size,
+                    )
+                    .await?;
+
+                    drop(redis_kv_client);
+
+                    // `handle_cancel_messages` terminated because tx_cancel was
+                    // dropped. that is not worth an error, and this task can only
+                    // return `Err`, so wait forever instead.
+                    std::future::pending().await
+                });
+
+                // listen for notifications of new projects/endpoints/branches
+                let cache = api.caches.endpoints_cache.clone();
+                let span = tracing::info_span!("endpoints_cache");
+                maintenance_tasks.spawn(
+                    async move { cache.do_read(client, cancellation_token.clone()).await }
+                        .instrument(span),
+                );
+            }
+        }
+        LegacyModes::ConsoleRedirect { console_redirect } => {
+            info!("Starting proxy on {}", console_redirect.listener.addr);
+            let proxy_listener = TcpListener::bind(console_redirect.listener.addr).await?;
+
+            info!("Starting mgmt on {}", console_redirect.cplane.addr);
+            let mgmt_listener = TcpListener::bind(console_redirect.cplane.addr).await?;
+
+            client_tasks.spawn(crate::console_redirect_proxy::task_main(
+                config,
+                auth_backend,
+                proxy_listener,
+                cancellation_token.clone(),
+                cancellation_handler.clone(),
+            ));
+            maintenance_tasks.spawn(control_plane::mgmt::task_main(mgmt_listener));
+        }
+    }
+
+    // Check that we can bind to address before further initialization
+    info!("Starting http on {}", config.introspection.listener.addr);
+    let http_listener = TcpListener::bind(config.introspection.listener.addr)
+        .await?
+        .into_std()?;
+
+    // channel size should be higher than redis client limit to avoid blocking
+    let cancel_ch_size = args.cancellation_ch_size;
+    let (tx_cancel, rx_cancel) = tokio::sync::mpsc::channel(cancel_ch_size);
+    let cancellation_handler = Arc::new(CancellationHandler::new(
+        &config.connect_to_compute,
+        Some(tx_cancel),
+    ));
+
+    maintenance_tasks.spawn(crate::signals::handle(cancellation_token.clone(), || {}));
+    maintenance_tasks.spawn(http::health_server::task_main(
+        http_listener,
+        AppMetrics {
+            jemalloc,
+            neon_metrics,
+            proxy: crate::metrics::Metrics::get(),
+        },
+    ));
+
+    if let Some(metrics_config) = &config.metric_collection {
+        // TODO: Add gc regardless of the metric collection being enabled.
+        maintenance_tasks.spawn(usage_metrics::task_main(metrics_config));
     }
 
     let maintenance = loop {
@@ -801,67 +1114,35 @@ fn build_auth_backend(
     }
 }
 
-async fn configure_redis(
-    args: &ProxyCliArgs,
-) -> anyhow::Result<Option<ConnectionWithCredentialsProvider>> {
-    // For some reason, we have two redis'.
-    // Why?
-    // In the past, we used to have a single global redis instance,
-    // as redis was only used for console<->cplane communication.
-    //
-    // After proxy started using redis, this needed fixing so we added the regional
-    // redis instances after.
-    //
-    //
-    // regional_redis is used for:
-    // 1. Stream of new endpoints/projects/branches.
-    // 2. KV for cancellation keys
-    //
-    // redis_notifications is used for:
-    // 1. Stream of new endpoints/projects/branches.
-    //
-    // In AWS, these are different[citation needed]
-    // In Azure, these are the same[citation needed]
-    //
-    // I think we can get rid of the notifications junk by now. To confirm.
-
-    // TODO: untangle the config args
-    let redis_client = match args.redis_auth_type.as_deref() {
-        Some("plain") => match &args.redis_plain {
-            None => {
-                bail!("plain auth requires redis_plain to be set");
-            }
-            Some(url) => {
-                Some(ConnectionWithCredentialsProvider::new_with_static_credentials(url.clone()))
-            }
-        },
-        Some("irsa") => match (&args.redis_host, args.redis_port) {
-            (Some(host), Some(port)) => Some(
-                ConnectionWithCredentialsProvider::new_with_credentials_provider(
-                    host.clone(),
-                    port,
-                    elasticache::CredentialsProvider::new(
-                        args.aws_region.clone(),
-                        args.redis_cluster_name.clone(),
-                        args.redis_user_id.clone(),
-                    )
-                    .await,
-                ),
-            ),
-            // (None, None) => {
-            //     // todo: upgrade to error?
-            //     warn!(
-            //         "irsa auth requires redis-host and redis-port to be set, continuing without regional_redis_client"
-            //     );
-            //     None
-            // }
-            _ => {
-                bail!("redis-host and redis-port must be specified together")
-            }
-        },
-        Some(auth_type) => {
-            bail!("unknown auth type {auth_type:?} given")
-        }
-        None => None,
-    };
-
-    // let redis_notifications_client = if let Some(url) = &args.redis_notifications {
-    //     Some(ConnectionWithCredentialsProvider::new_with_static_credentials(&**url))
-    // } else {
-    //     regional_redis_client.clone()
-    // };
-
-    Ok(redis_client)
+async fn configure_redis(auth: RedisAuthentication) -> ConnectionWithCredentialsProvider {
+    match auth {
+        RedisAuthentication::Irsa {
+            host,
+            port,
+            cluster_name,
+            user_id,
+            aws_region,
+        } => ConnectionWithCredentialsProvider::new_with_credentials_provider(
+            host,
+            port,
+            elasticache::CredentialsProvider::new(aws_region, cluster_name, user_id).await,
+        ),
+        RedisAuthentication::Basic { url } => {
+            ConnectionWithCredentialsProvider::new_with_static_credentials(url.clone())
+        }
+    }
 }
diff --git a/proxy/src/config.rs b/proxy/src/config.rs
index cee15ac7fa..6e8a0756b9 100644
--- a/proxy/src/config.rs
+++ b/proxy/src/config.rs
@@ -69,7 +69,7 @@ pub struct AuthenticationConfig {
     pub console_redirect_confirmation_timeout: tokio::time::Duration,
 }
 
-#[derive(Debug)]
+#[derive(Debug, serde::Deserialize)]
 pub struct EndpointCacheConfig {
     /// Batch size to receive all endpoints on the startup.
     pub initial_batch_size: usize,
@@ -205,7 +205,7 @@ impl FromStr for CacheOptions {
     }
 }
 
 /// Helper for cmdline cache options parsing.
-#[derive(Debug)]
+#[derive(Debug, serde::Deserialize)]
 pub struct ProjectInfoCacheOptions {
     /// Max number of entries.
     pub size: usize,
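For reference, a rough sketch of what a `proxy.toml` for the `Proxy` mode could look like under this schema. All values are illustrative; the required `neonkeeper` cache tables are elided; `LegacyModes` is `#[serde(untagged)]` and flattened into `Root`, so the mode is selected by which top-level tables are present; and `Duration` fields are spelled in serde's derived `{ secs, nanos }` form, since no humantime helper is wired up here.

```toml
# illustrative only — field names mirror the structs above, values are made up.

[pglb.listener]
addr = "0.0.0.0:5432"
header = { ProxyProtocolV2 = "Aws" }                   # optional
certs = [{ key = "server.key", cert = "server.crt" }]
timeout = { secs = 10, nanos = 0 }                     # optional TLS handshake timeout

[neonkeeper]
cplane = { Http = "http://127.0.0.1:3000/" }
auth = [{ Sasl = { timeout = { secs = 15, nanos = 0 } } }]
compute = { LruTtl = { size = 4096, ttl = { secs = 300, nanos = 0 } } }
# project_info_cache, endpoint_cache_config, and the optional
# redis / request_log_export / data_transfer_export tables are omitted.

[introspection.listener]
addr = "127.0.0.1:7001"
certs = []
```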