diff --git a/safekeeper/src/bin/safekeeper.rs b/safekeeper/src/bin/safekeeper.rs index 5ad88276e8..394a4815bb 100644 --- a/safekeeper/src/bin/safekeeper.rs +++ b/safekeeper/src/bin/safekeeper.rs @@ -143,6 +143,19 @@ fn main() -> anyhow::Result<()> { return Ok(()); } + let auth = match args.auth_validation_public_key_path.as_ref() { + None => { + info!("auth is disabled"); + None + } + Some(path) => { + info!("loading JWT auth key from {}", path.display()); + Some(Arc::new( + JwtAuth::from_key_path(path).context("failed to load the auth key")?, + )) + } + }; + let conf = SafeKeeperConf { workdir, my_id: id, @@ -156,7 +169,7 @@ fn main() -> anyhow::Result<()> { max_offloader_lag_bytes: args.max_offloader_lag, backup_runtime_threads: args.wal_backup_threads, wal_backup_enabled: !args.disable_wal_backup, - auth_validation_public_key_path: args.auth_validation_public_key_path, + auth, }; // initialize sentry if SENTRY_DSN is provided @@ -186,19 +199,6 @@ fn start_safekeeper(conf: SafeKeeperConf) -> Result<()> { e })?; - let auth = match conf.auth_validation_public_key_path.as_ref() { - None => { - info!("auth is disabled"); - None - } - Some(path) => { - info!("loading JWT auth key from {}", path.display()); - Some(Arc::new( - JwtAuth::from_key_path(path).context("failed to load the auth key")?, - )) - } - }; - // Register metrics collector for active timelines. It's important to do this // after daemonizing, otherwise process collector will be upset. let timeline_collector = safekeeper::metrics::TimelineCollector::new(); @@ -212,12 +212,11 @@ fn start_safekeeper(conf: SafeKeeperConf) -> Result<()> { GlobalTimelines::init(conf.clone(), wal_backup_launcher_tx)?; let conf_ = conf.clone(); - let auth_ = auth.clone(); threads.push( thread::Builder::new() .name("http_endpoint_thread".into()) .spawn(|| { - let router = http::make_router(conf_, auth_); + let router = http::make_router(conf_); endpoint::serve_thread_main( router, http_listener, @@ -231,7 +230,7 @@ fn start_safekeeper(conf: SafeKeeperConf) -> Result<()> { let safekeeper_thread = thread::Builder::new() .name("safekeeper thread".into()) .spawn(|| { - if let Err(e) = wal_service::thread_main(conf_cloned, pg_listener, auth) { + if let Err(e) = wal_service::thread_main(conf_cloned, pg_listener) { info!("safekeeper thread terminated: {e}"); } }) @@ -244,7 +243,6 @@ fn start_safekeeper(conf: SafeKeeperConf) -> Result<()> { thread::Builder::new() .name("broker thread".into()) .spawn(|| { - // TODO: add auth? broker::thread_main(conf_); })?, ); diff --git a/safekeeper/src/handler.rs b/safekeeper/src/handler.rs index 05527303ca..c692e9fc12 100644 --- a/safekeeper/src/handler.rs +++ b/safekeeper/src/handler.rs @@ -15,9 +15,8 @@ use regex::Regex; use pq_proto::{BeMessage, FeStartupPacket, RowDescriptor, INT4_OID, TEXT_OID}; use std::str; -use std::sync::Arc; use tracing::info; -use utils::auth::{Claims, JwtAuth, Scope}; +use utils::auth::{Claims, Scope}; use utils::{ id::{TenantId, TenantTimelineId, TimelineId}, lsn::Lsn, @@ -32,7 +31,6 @@ pub struct SafekeeperPostgresHandler { pub tenant_id: Option, pub timeline_id: Option, pub ttid: TenantTimelineId, - auth: Option>, claims: Option, } @@ -107,6 +105,7 @@ impl postgres_backend::Handler for SafekeeperPostgresHandler { // this unwrap is never triggered, because check_auth_jwt only called when auth_type is NeonJWT // which requires auth to be present let data = self + .conf .auth .as_ref() .unwrap() @@ -166,14 +165,13 @@ impl postgres_backend::Handler for SafekeeperPostgresHandler { } impl SafekeeperPostgresHandler { - pub fn new(conf: SafeKeeperConf, auth: Option>) -> Self { + pub fn new(conf: SafeKeeperConf) -> Self { SafekeeperPostgresHandler { conf, appname: None, tenant_id: None, timeline_id: None, ttid: TenantTimelineId::empty(), - auth, claims: None, } } @@ -181,7 +179,7 @@ impl SafekeeperPostgresHandler { // when accessing management api supply None as an argument // when using to authorize tenant pass corresponding tenant id fn check_permission(&self, tenant_id: Option) -> Result<()> { - if self.auth.is_none() { + if self.conf.auth.is_none() { // auth is set to Trust, nothing to check so just return ok return Ok(()); } diff --git a/safekeeper/src/http/routes.rs b/safekeeper/src/http/routes.rs index a9a9eb3388..a917d61678 100644 --- a/safekeeper/src/http/routes.rs +++ b/safekeeper/src/http/routes.rs @@ -277,12 +277,9 @@ async fn record_safekeeper_info(mut request: Request) -> Result>, -) -> RouterBuilder { +pub fn make_router(conf: SafeKeeperConf) -> RouterBuilder { let mut router = endpoint::make_router(); - if auth.is_some() { + if conf.auth.is_some() { router = router.middleware(auth_middleware(|request| { #[allow(clippy::mutable_key_type)] static ALLOWLIST_ROUTES: Lazy> = @@ -298,6 +295,7 @@ pub fn make_router( // NB: on any changes do not forget to update the OpenAPI spec // located nearby (/safekeeper/src/http/openapi_spec.yaml). + let auth = conf.auth.clone(); router .data(Arc::new(conf)) .data(auth) diff --git a/safekeeper/src/lib.rs b/safekeeper/src/lib.rs index 5decfe64de..891d73533f 100644 --- a/safekeeper/src/lib.rs +++ b/safekeeper/src/lib.rs @@ -24,7 +24,9 @@ pub mod wal_service; pub mod wal_storage; mod timelines_global_map; +use std::sync::Arc; pub use timelines_global_map::GlobalTimelines; +use utils::auth::JwtAuth; pub mod defaults { pub use safekeeper_api::{ @@ -57,7 +59,7 @@ pub struct SafeKeeperConf { pub max_offloader_lag_bytes: u64, pub backup_runtime_threads: Option, pub wal_backup_enabled: bool, - pub auth_validation_public_key_path: Option, + pub auth: Option>, } impl SafeKeeperConf { @@ -87,7 +89,7 @@ impl SafeKeeperConf { broker_keepalive_interval: Duration::from_secs(5), backup_runtime_threads: None, wal_backup_enabled: true, - auth_validation_public_key_path: None, + auth: None, heartbeat_timeout: Duration::new(5, 0), max_offloader_lag_bytes: defaults::DEFAULT_MAX_OFFLOADER_LAG_BYTES, } diff --git a/safekeeper/src/wal_service.rs b/safekeeper/src/wal_service.rs index fd8f9d9dcf..0fea00fe1b 100644 --- a/safekeeper/src/wal_service.rs +++ b/safekeeper/src/wal_service.rs @@ -5,32 +5,25 @@ use anyhow::Result; use regex::Regex; use std::net::{TcpListener, TcpStream}; -use std::sync::Arc; use std::thread; use tracing::*; -use utils::auth::JwtAuth; use crate::handler::SafekeeperPostgresHandler; use crate::SafeKeeperConf; use utils::postgres_backend::{AuthType, PostgresBackend}; /// Accept incoming TCP connections and spawn them into a background thread. -pub fn thread_main( - conf: SafeKeeperConf, - listener: TcpListener, - auth: Option>, -) -> Result<()> { +pub fn thread_main(conf: SafeKeeperConf, listener: TcpListener) -> Result<()> { loop { match listener.accept() { Ok((socket, peer_addr)) => { debug!("accepted connection from {}", peer_addr); let conf = conf.clone(); - let auth = auth.clone(); let _ = thread::Builder::new() .name("WAL service thread".into()) .spawn(move || { - if let Err(err) = handle_socket(socket, conf, auth) { + if let Err(err) = handle_socket(socket, conf) { error!("connection handler exited: {}", err); } }) @@ -51,25 +44,17 @@ fn get_tid() -> u64 { /// This is run by `thread_main` above, inside a background thread. /// -fn handle_socket( - socket: TcpStream, - conf: SafeKeeperConf, - auth: Option>, -) -> Result<()> { +fn handle_socket(socket: TcpStream, conf: SafeKeeperConf) -> Result<()> { let _enter = info_span!("", tid = ?get_tid()).entered(); socket.set_nodelay(true)?; - let mut conn_handler = SafekeeperPostgresHandler::new(conf, auth.clone()); - let pgbackend = PostgresBackend::new( - socket, - match auth { - None => AuthType::Trust, - Some(_) => AuthType::NeonJWT, - }, - None, - false, - )?; + let auth_type = match conf.auth { + None => AuthType::Trust, + Some(_) => AuthType::NeonJWT, + }; + let mut conn_handler = SafekeeperPostgresHandler::new(conf); + let pgbackend = PostgresBackend::new(socket, auth_type, None, false)?; // libpq replication protocol between safekeeper and replicas/pagers pgbackend.run(&mut conn_handler)?;