diff --git a/Cargo.lock b/Cargo.lock index 86e2ac3ba3..9233976f47 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2578,6 +2578,7 @@ dependencies = [ "hyper", "jsonwebtoken", "lazy_static", + "nix", "postgres", "rand", "routerify", diff --git a/pageserver/src/bin/pageserver.rs b/pageserver/src/bin/pageserver.rs index 3a577476dc..aae37a29c8 100644 --- a/pageserver/src/bin/pageserver.rs +++ b/pageserver/src/bin/pageserver.rs @@ -5,13 +5,12 @@ use serde::{Deserialize, Serialize}; use std::{ env, - net::TcpListener, path::{Path, PathBuf}, str::FromStr, thread, }; use tracing::*; -use zenith_utils::{auth::JwtAuth, logging, postgres_backend::AuthType}; +use zenith_utils::{auth::JwtAuth, logging, postgres_backend::AuthType, tcp_listener}; use anyhow::{bail, ensure, Context, Result}; use signal_hook::consts::signal::*; @@ -480,13 +479,13 @@ fn start_pageserver(conf: &'static PageServerConf) -> Result<()> { "Starting pageserver http handler on {}", conf.listen_http_addr ); - let http_listener = TcpListener::bind(conf.listen_http_addr.clone())?; + let http_listener = tcp_listener::bind(conf.listen_http_addr.clone())?; info!( "Starting pageserver pg protocol handler on {}", conf.listen_pg_addr ); - let pageserver_listener = TcpListener::bind(conf.listen_pg_addr.clone())?; + let pageserver_listener = tcp_listener::bind(conf.listen_pg_addr.clone())?; if conf.daemonize { info!("daemonizing..."); diff --git a/proxy/src/main.rs b/proxy/src/main.rs index c183785635..d62ff50ff8 100644 --- a/proxy/src/main.rs +++ b/proxy/src/main.rs @@ -7,7 +7,7 @@ /// use std::{ collections::HashMap, - net::{SocketAddr, TcpListener}, + net::SocketAddr, sync::{mpsc, Arc, Mutex}, thread, }; @@ -17,6 +17,7 @@ use clap::{App, Arg, ArgMatches}; use cplane_api::DatabaseInfo; use rustls::{internal::pemfile, NoClientAuth, ProtocolVersion, ServerConfig}; +use zenith_utils::tcp_listener; mod cplane_api; mod mgmt; @@ -140,10 +141,10 @@ fn main() -> anyhow::Result<()> { // Check that we can bind to address before further initialization println!("Starting proxy on {}", state.conf.proxy_address); - let pageserver_listener = TcpListener::bind(state.conf.proxy_address)?; + let pageserver_listener = tcp_listener::bind(state.conf.proxy_address)?; println!("Starting mgmt on {}", state.conf.mgmt_address); - let mgmt_listener = TcpListener::bind(state.conf.mgmt_address)?; + let mgmt_listener = tcp_listener::bind(state.conf.mgmt_address)?; let threads = [ // Spawn a thread to listen for connections. It will spawn further threads diff --git a/walkeeper/src/bin/safekeeper.rs b/walkeeper/src/bin/safekeeper.rs index 618467ef65..2c7283d51a 100644 --- a/walkeeper/src/bin/safekeeper.rs +++ b/walkeeper/src/bin/safekeeper.rs @@ -7,11 +7,10 @@ use const_format::formatcp; use daemonize::Daemonize; use log::*; use std::env; -use std::net::TcpListener; use std::path::{Path, PathBuf}; use std::thread; use zenith_utils::http::endpoint; -use zenith_utils::logging; +use zenith_utils::{logging, tcp_listener}; use walkeeper::defaults::{DEFAULT_HTTP_LISTEN_ADDR, DEFAULT_PG_LISTEN_ADDR}; use walkeeper::http; @@ -132,13 +131,13 @@ fn main() -> Result<()> { fn start_safekeeper(conf: SafeKeeperConf) -> Result<()> { let log_file = logging::init("safekeeper.log", conf.daemonize)?; - let http_listener = TcpListener::bind(conf.listen_http_addr.clone()).map_err(|e| { + let http_listener = tcp_listener::bind(conf.listen_http_addr.clone()).map_err(|e| { error!("failed to bind to address {}: {}", conf.listen_http_addr, e); e })?; info!("Starting safekeeper on {}", conf.listen_pg_addr); - let pg_listener = TcpListener::bind(conf.listen_pg_addr.clone()).map_err(|e| { + let pg_listener = tcp_listener::bind(conf.listen_pg_addr.clone()).map_err(|e| { error!("failed to bind to address {}: {}", conf.listen_pg_addr, e); e })?; diff --git a/zenith_utils/Cargo.toml b/zenith_utils/Cargo.toml index e9e12283b9..050a8dcda3 100644 --- a/zenith_utils/Cargo.toml +++ b/zenith_utils/Cargo.toml @@ -19,6 +19,7 @@ thiserror = "1.0" tokio = "1.11" tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } +nix = "0.23.0" zenith_metrics = { path = "../zenith_metrics" } workspace_hack = { path = "../workspace_hack" } diff --git a/zenith_utils/src/lib.rs b/zenith_utils/src/lib.rs index 96b3cf5066..912d8308e5 100644 --- a/zenith_utils/src/lib.rs +++ b/zenith_utils/src/lib.rs @@ -40,3 +40,6 @@ pub mod logging; // Misc pub mod accum; + +// Utility for binding TcpListeners with proper socket options. +pub mod tcp_listener; diff --git a/zenith_utils/src/tcp_listener.rs b/zenith_utils/src/tcp_listener.rs new file mode 100644 index 0000000000..7666ad138c --- /dev/null +++ b/zenith_utils/src/tcp_listener.rs @@ -0,0 +1,16 @@ +use std::{ + io, + net::{TcpListener, ToSocketAddrs}, + os::unix::prelude::AsRawFd, +}; + +use nix::sys::socket::{setsockopt, sockopt::ReuseAddr}; + +/// Bind a [`TcpListener`] to addr with `SO_REUSEADDR` set to true. +pub fn bind(addr: A) -> io::Result { + let listener = TcpListener::bind(addr)?; + + setsockopt(listener.as_raw_fd(), ReuseAddr, &true)?; + + Ok(listener) +}