diff --git a/Cargo.lock b/Cargo.lock index c18bddfbcb..97515ca24d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -639,6 +639,12 @@ dependencies = [ "vsimd", ] +[[package]] +name = "base64ct" +version = "1.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8c3c1a368f70d6cf7302d78f8f7093da241fb8e8807c05cc9e51a125895a6d5b" + [[package]] name = "bincode" version = "1.3.3" @@ -1010,9 +1016,9 @@ checksum = "e496a50fda8aacccc86d7529e2c1e0892dbd0f898a6b5645b5561b89c3210efa" [[package]] name = "cpufeatures" -version = "0.2.7" +version = "0.2.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3e4c1eaa2012c47becbbad2ab175484c2a84d1185b566fb2cc5b8707343dfe58" +checksum = "a17b76ff3a4162b0b27f354a0c87015ddad39d35f9c0c36607a3bdd175dde1f1" dependencies = [ "libc", ] @@ -1192,15 +1198,15 @@ dependencies = [ [[package]] name = "dashmap" -version = "5.4.0" +version = "5.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "907076dfda823b0b36d2a1bb5f90c96660a5bbcd7729e10727f07858f22c4edc" +checksum = "6943ae99c34386c84a470c499d3414f66502a41340aa895406e0d2e4a207b91d" dependencies = [ "cfg-if", - "hashbrown 0.12.3", + "hashbrown 0.14.0", "lock_api", "once_cell", - "parking_lot_core 0.9.7", + "parking_lot_core 0.9.8", ] [[package]] @@ -1649,6 +1655,12 @@ dependencies = [ "ahash", ] +[[package]] +name = "hashbrown" +version = "0.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2c6201b9ff9fd90a5a3bac2e56a830d0caa509576f0e503818ee82c181b3437a" + [[package]] name = "hashlink" version = "0.8.2" @@ -2073,9 +2085,9 @@ checksum = "ef53942eb7bf7ff43a617b3e2c1c4a5ecf5944a7c1bc12d7ee39bbb15e5c1519" [[package]] name = "lock_api" -version = "0.4.9" +version = "0.4.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "435011366fe56583b16cf956f9df0095b405b82d76425bc8981c0e22e60ec4df" +checksum = "c1cc9717a20b1bb222f333e6a92fd32f7d8a18ddc5a3191a11af45dcbf4dcd16" dependencies = [ "autocfg", "scopeguard", @@ -2339,9 +2351,9 @@ dependencies = [ [[package]] name = "once_cell" -version = "1.17.1" +version = "1.18.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b7e5500299e16ebb147ae15a00a942af264cf3688f47923b8fc2cd5858f23ad3" +checksum = "dd8b5dd2ae5ed71462c540258bedcb51965123ad7e7ccf4b9a8cafaa4a63576d" [[package]] name = "oorandom" @@ -2640,7 +2652,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3742b2c103b9f06bc9fff0a37ff4912935851bee6d36f3c02bcc755bcfec228f" dependencies = [ "lock_api", - "parking_lot_core 0.9.7", + "parking_lot_core 0.9.8", ] [[package]] @@ -2659,15 +2671,26 @@ dependencies = [ [[package]] name = "parking_lot_core" -version = "0.9.7" +version = "0.9.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9069cbb9f99e3a5083476ccb29ceb1de18b9118cafa53e90c9551235de2b9521" +checksum = "93f00c865fe7cabf650081affecd3871070f26767e7b2070a3ffae14c654b447" dependencies = [ "cfg-if", "libc", - "redox_syscall 0.2.16", + "redox_syscall 0.3.5", "smallvec", - "windows-sys 0.45.0", + "windows-targets 0.48.0", +] + +[[package]] +name = "password-hash" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "346f04948ba92c43e8469c1ee6736c7563d71012b17d40745260fe106aac2166" +dependencies = [ + "base64ct", + "rand_core", + "subtle", ] [[package]] @@ -2678,6 +2701,8 @@ checksum = "f0ca0b5a68607598bf3bad68f32227a8164f6254833f84eafaac409cd6746c31" dependencies = [ "digest", "hmac", + "password-hash", + "sha2", ] [[package]] @@ -3056,6 +3081,7 @@ dependencies = [ "chrono", "clap", "consumption_metrics", + "dashmap", "futures", "git-version", "hashbrown 0.13.2", diff --git a/Cargo.toml b/Cargo.toml index a0acc061fb..5eab28e2e5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -54,6 +54,7 @@ comfy-table = "6.1" const_format = "0.2" crc32c = "0.6" crossbeam-utils = "0.8.5" +dashmap = "5.5.0" either = "1.8" enum-map = "2.4.2" enumset = "1.0.12" @@ -88,7 +89,7 @@ opentelemetry = "0.19.0" opentelemetry-otlp = { version = "0.12.0", default_features=false, features = ["http-proto", "trace", "http", "reqwest-client"] } opentelemetry-semantic-conventions = "0.11.0" parking_lot = "0.12" -pbkdf2 = "0.12.1" +pbkdf2 = { version = "0.12.1", features = ["simple", "std"] } pin-project-lite = "0.2" prometheus = {version = "0.13", default_features=false, features = ["process"]} # removes protobuf dependency prost = "0.11" diff --git a/proxy/Cargo.toml b/proxy/Cargo.toml index 438dd62315..cbab0c6f07 100644 --- a/proxy/Cargo.toml +++ b/proxy/Cargo.toml @@ -13,6 +13,7 @@ bytes = { workspace = true, features = ["serde"] } chrono.workspace = true clap.workspace = true consumption_metrics.workspace = true +dashmap.workspace = true futures.workspace = true git-version.workspace = true hashbrown.workspace = true @@ -29,7 +30,7 @@ metrics.workspace = true once_cell.workspace = true opentelemetry.workspace = true parking_lot.workspace = true -pbkdf2.workspace = true +pbkdf2 = { workspace = true, features = ["simple", "std"] } pin-project-lite.workspace = true postgres_backend.workspace = true pq_proto.workspace = true diff --git a/proxy/src/http/conn_pool.rs b/proxy/src/http/conn_pool.rs index 703632a511..9bba846d57 100644 --- a/proxy/src/http/conn_pool.rs +++ b/proxy/src/http/conn_pool.rs @@ -1,8 +1,14 @@ use anyhow::Context; use async_trait::async_trait; -use parking_lot::Mutex; +use dashmap::DashMap; +use parking_lot::RwLock; +use pbkdf2::{ + password_hash::{PasswordHashString, PasswordHasher, PasswordVerifier, SaltString}, + Params, Pbkdf2, +}; use pq_proto::StartupMessageParams; use std::fmt; +use std::sync::atomic::{self, AtomicUsize}; use std::{collections::HashMap, sync::Arc}; use tokio::time; @@ -46,19 +52,40 @@ struct ConnPoolEntry { _last_access: std::time::Instant, } -// Per-endpoint connection pool, (dbname, username) -> Vec +// Per-endpoint connection pool, (dbname, username) -> DbUserConnPool // Number of open connections is limited by the `max_conns_per_endpoint`. pub struct EndpointConnPool { - pools: HashMap<(String, String), Vec>, + pools: HashMap<(String, String), DbUserConnPool>, total_conns: usize, } +/// This is cheap and not hugely secure. +/// But probably good enough for in memory only hashes. +/// +/// Still takes 3.5ms to hash on my hardware. +/// We don't want to ruin the latency improvements of using the pool by making password verification take too long +const PARAMS: Params = Params { + rounds: 10_000, + output_length: 32, +}; + +#[derive(Default)] +pub struct DbUserConnPool { + conns: Vec, + password_hash: Option, +} + pub struct GlobalConnPool { // endpoint -> per-endpoint connection pool // // That should be a fairly conteded map, so return reference to the per-endpoint // pool as early as possible and release the lock. - global_pool: Mutex>>>, + global_pool: DashMap>>, + + /// [`DashMap::len`] iterates over all inner pools and acquires a read lock on each. + /// That seems like far too much effort, so we're using a relaxed increment counter instead. + /// It's only used for diagnostics. + global_pool_size: AtomicUsize, // Maximum number of connections per one endpoint. // Can mix different (dbname, username) connections. @@ -72,7 +99,8 @@ pub struct GlobalConnPool { impl GlobalConnPool { pub fn new(config: &'static crate::config::ProxyConfig) -> Arc { Arc::new(Self { - global_pool: Mutex::new(HashMap::new()), + global_pool: DashMap::new(), + global_pool_size: AtomicUsize::new(0), max_conns_per_endpoint: MAX_CONNS_PER_ENDPOINT, proxy_config: config, }) @@ -85,33 +113,92 @@ impl GlobalConnPool { ) -> anyhow::Result { let mut client: Option = None; + let mut hash_valid = false; if !force_new { - let pool = self.get_endpoint_pool(&conn_info.hostname).await; + let pool = self.get_or_create_endpoint_pool(&conn_info.hostname); + let mut hash = None; // find a pool entry by (dbname, username) if exists - let mut pool = pool.lock(); - let pool_entries = pool.pools.get_mut(&conn_info.db_and_user()); - if let Some(pool_entries) = pool_entries { - if let Some(entry) = pool_entries.pop() { - client = Some(entry.conn); - pool.total_conns -= 1; + { + let pool = pool.read(); + if let Some(pool_entries) = pool.pools.get(&conn_info.db_and_user()) { + if !pool_entries.conns.is_empty() { + hash = pool_entries.password_hash.clone(); + } + } + } + + // a connection exists in the pool, verify the password hash + if let Some(hash) = hash { + let pw = conn_info.password.clone(); + let validate = tokio::task::spawn_blocking(move || { + Pbkdf2.verify_password(pw.as_bytes(), &hash.password_hash()) + }) + .await?; + + // if the hash is invalid, don't error + // we will continue with the regular connection flow + if validate.is_ok() { + hash_valid = true; + let mut pool = pool.write(); + if let Some(pool_entries) = pool.pools.get_mut(&conn_info.db_and_user()) { + if let Some(entry) = pool_entries.conns.pop() { + client = Some(entry.conn); + pool.total_conns -= 1; + } + } } } } // ok return cached connection if found and establish a new one otherwise - if let Some(client) = client { + let new_client = if let Some(client) = client { if client.is_closed() { info!("pool: cached connection '{conn_info}' is closed, opening a new one"); connect_to_compute(self.proxy_config, conn_info).await } else { info!("pool: reusing connection '{conn_info}'"); - Ok(client) + return Ok(client); } } else { info!("pool: opening a new connection '{conn_info}'"); connect_to_compute(self.proxy_config, conn_info).await + }; + + match &new_client { + // clear the hash. it's no longer valid + // TODO: update tokio-postgres fork to allow access to this error kind directly + Err(err) + if hash_valid && err.to_string().contains("password authentication failed") => + { + let pool = self.get_or_create_endpoint_pool(&conn_info.hostname); + let mut pool = pool.write(); + if let Some(entry) = pool.pools.get_mut(&conn_info.db_and_user()) { + entry.password_hash = None; + } + } + // new password is valid and we should insert/update it + Ok(_) if !force_new && !hash_valid => { + let pw = conn_info.password.clone(); + let new_hash = tokio::task::spawn_blocking(move || { + let salt = SaltString::generate(rand::rngs::OsRng); + Pbkdf2 + .hash_password_customized(pw.as_bytes(), None, None, PARAMS, &salt) + .map(|s| s.serialize()) + }) + .await??; + + let pool = self.get_or_create_endpoint_pool(&conn_info.hostname); + let mut pool = pool.write(); + pool.pools + .entry(conn_info.db_and_user()) + .or_default() + .password_hash = Some(new_hash); + } + _ => {} } + + new_client } pub async fn put( @@ -119,33 +206,31 @@ impl GlobalConnPool { conn_info: &ConnInfo, client: tokio_postgres::Client, ) -> anyhow::Result<()> { - let pool = self.get_endpoint_pool(&conn_info.hostname).await; + let pool = self.get_or_create_endpoint_pool(&conn_info.hostname); // return connection to the pool - let mut total_conns; let mut returned = false; let mut per_db_size = 0; - { - let mut pool = pool.lock(); - total_conns = pool.total_conns; + let total_conns = { + let mut pool = pool.write(); - let pool_entries: &mut Vec = pool - .pools - .entry(conn_info.db_and_user()) - .or_insert_with(|| Vec::with_capacity(1)); - if total_conns < self.max_conns_per_endpoint { - pool_entries.push(ConnPoolEntry { - conn: client, - _last_access: std::time::Instant::now(), - }); + if pool.total_conns < self.max_conns_per_endpoint { + // we create this db-user entry in get, so it should not be None + if let Some(pool_entries) = pool.pools.get_mut(&conn_info.db_and_user()) { + pool_entries.conns.push(ConnPoolEntry { + conn: client, + _last_access: std::time::Instant::now(), + }); - total_conns += 1; - returned = true; - per_db_size = pool_entries.len(); + returned = true; + per_db_size = pool_entries.conns.len(); - pool.total_conns += 1; + pool.total_conns += 1; + } } - } + + pool.total_conns + }; // do logging outside of the mutex if returned { @@ -157,25 +242,35 @@ impl GlobalConnPool { Ok(()) } - async fn get_endpoint_pool(&self, endpoint: &String) -> Arc> { + fn get_or_create_endpoint_pool(&self, endpoint: &String) -> Arc> { + // fast path + if let Some(pool) = self.global_pool.get(endpoint) { + return pool.clone(); + } + + // slow path + let new_pool = Arc::new(RwLock::new(EndpointConnPool { + pools: HashMap::new(), + total_conns: 0, + })); + // find or create a pool for this endpoint let mut created = false; - let mut global_pool = self.global_pool.lock(); - let pool = global_pool + let pool = self + .global_pool .entry(endpoint.clone()) .or_insert_with(|| { created = true; - Arc::new(Mutex::new(EndpointConnPool { - pools: HashMap::new(), - total_conns: 0, - })) + new_pool }) .clone(); - let global_pool_size = global_pool.len(); - drop(global_pool); // log new global pool size if created { + let global_pool_size = self + .global_pool_size + .fetch_add(1, atomic::Ordering::Relaxed) + + 1; info!( "pool: created new pool for '{endpoint}', global pool size now {global_pool_size}" ); diff --git a/proxy/src/http/sql_over_http.rs b/proxy/src/http/sql_over_http.rs index 82e78796c6..33375e63e9 100644 --- a/proxy/src/http/sql_over_http.rs +++ b/proxy/src/http/sql_over_http.rs @@ -44,6 +44,7 @@ const MAX_REQUEST_SIZE: u64 = 1024 * 1024; // 1 MB static RAW_TEXT_OUTPUT: HeaderName = HeaderName::from_static("neon-raw-text-output"); static ARRAY_MODE: HeaderName = HeaderName::from_static("neon-array-mode"); +static ALLOW_POOL: HeaderName = HeaderName::from_static("neon-pool-opt-in"); static TXN_ISOLATION_LEVEL: HeaderName = HeaderName::from_static("neon-batch-isolation-level"); static TXN_READ_ONLY: HeaderName = HeaderName::from_static("neon-batch-read-only"); static TXN_DEFERRABLE: HeaderName = HeaderName::from_static("neon-batch-deferrable"); @@ -193,7 +194,7 @@ pub async fn handle( let array_mode = headers.get(&ARRAY_MODE) == Some(&HEADER_VALUE_TRUE); // Allow connection pooling only if explicitly requested - let allow_pool = false; + let allow_pool = headers.get(&ALLOW_POOL) == Some(&HEADER_VALUE_TRUE); // isolation level, read only and deferrable diff --git a/test_runner/fixtures/neon_fixtures.py b/test_runner/fixtures/neon_fixtures.py index 62c5bd9ba9..61cd169fa3 100644 --- a/test_runner/fixtures/neon_fixtures.py +++ b/test_runner/fixtures/neon_fixtures.py @@ -1769,6 +1769,15 @@ class VanillaPostgres(PgProtocol): with open(os.path.join(self.pgdatadir, "postgresql.conf"), "a") as conf_file: conf_file.write("\n".join(options)) + def edit_hba(self, hba: List[str]): + """Prepend hba lines into pg_hba.conf file.""" + assert not self.running + with open(os.path.join(self.pgdatadir, "pg_hba.conf"), "r+") as conf_file: + data = conf_file.read() + conf_file.seek(0) + conf_file.write("\n".join(hba) + "\n") + conf_file.write(data) + def start(self, log_path: Optional[str] = None): assert not self.running self.running = True @@ -2166,15 +2175,18 @@ def static_proxy( ) -> Iterator[NeonProxy]: """Neon proxy that routes directly to vanilla postgres.""" - # For simplicity, we use the same user for both `--auth-endpoint` and `safe_psql` - vanilla_pg.start() - vanilla_pg.safe_psql("create user proxy with login superuser password 'password'") - port = vanilla_pg.default_options["port"] host = vanilla_pg.default_options["host"] dbname = vanilla_pg.default_options["dbname"] auth_endpoint = f"postgres://proxy:password@{host}:{port}/{dbname}" + # require password for 'http_auth' user + vanilla_pg.edit_hba([f"host {dbname} http_auth {host} password"]) + + # For simplicity, we use the same user for both `--auth-endpoint` and `safe_psql` + vanilla_pg.start() + vanilla_pg.safe_psql("create user proxy with login superuser password 'password'") + proxy_port = port_distributor.get_port() mgmt_port = port_distributor.get_port() http_port = port_distributor.get_port() diff --git a/test_runner/regress/test_proxy.py b/test_runner/regress/test_proxy.py index dd767e14b7..598a1bd084 100644 --- a/test_runner/regress/test_proxy.py +++ b/test_runner/regress/test_proxy.py @@ -340,3 +340,50 @@ def test_sql_over_http_batch(static_proxy: NeonProxy): assert headers["Neon-Batch-Deferrable"] == "true" assert result[0]["rows"] == [{"answer": 42}] + + +def test_sql_over_http_pool(static_proxy: NeonProxy): + static_proxy.safe_psql("create user http_auth with password 'http' superuser") + + def get_pid(status: int, pw: str) -> Any: + connstr = ( + f"postgresql://http_auth:{pw}@{static_proxy.domain}:{static_proxy.proxy_port}/postgres" + ) + response = requests.post( + f"https://{static_proxy.domain}:{static_proxy.external_http_port}/sql", + data=json.dumps( + {"query": "SELECT pid FROM pg_stat_activity WHERE state = 'active'", "params": []} + ), + headers={ + "Content-Type": "application/sql", + "Neon-Connection-String": connstr, + "Neon-Pool-Opt-In": "true", + }, + verify=str(static_proxy.test_output_dir / "proxy.crt"), + ) + assert response.status_code == status + return response.json() + + pid1 = get_pid(200, "http")["rows"][0]["pid"] + + # query should be on the same connection + rows = get_pid(200, "http")["rows"] + assert rows == [{"pid": pid1}] + + # incorrect password should not work + res = get_pid(400, "foobar") + assert "password authentication failed for user" in res["message"] + + static_proxy.safe_psql("alter user http_auth with password 'http2'") + + # after password change, should open a new connection to verify it + pid2 = get_pid(200, "http2")["rows"][0]["pid"] + assert pid1 != pid2 + + # query should be on an existing connection + pid = get_pid(200, "http2")["rows"][0]["pid"] + assert pid in [pid1, pid2] + + # old password should not work + res = get_pid(400, "http") + assert "password authentication failed for user" in res["message"]