From e12e2681e97d997924383e3785bd734c9be8c21a Mon Sep 17 00:00:00 2001 From: Anna Khanova <32508607+khanova@users.noreply.github.com> Date: Thu, 30 Nov 2023 14:14:33 +0100 Subject: [PATCH] IP allowlist on the proxy side (#5906) ## Problem Per-project IP allowlist: https://github.com/neondatabase/cloud/issues/8116 ## Summary of changes Implemented IP filtering on the proxy side. To retrieve ip allowlist for all scenarios, added `get_auth_info` call to the control plane for: * sql-over-http * password_hack * cleartext_hack Added cache with ttl for sql-over-http path This might slow down a bit, consider using redis in the future. --------- Co-authored-by: Conrad Ludgate --- Cargo.lock | 5 +- Cargo.toml | 1 + proxy/Cargo.toml | 1 + proxy/src/auth.rs | 13 +- proxy/src/auth/backend.rs | 52 ++++- proxy/src/auth/backend/classic.rs | 22 +- proxy/src/auth/credentials.rs | 197 +++++++++++++++--- proxy/src/bin/proxy.rs | 29 ++- proxy/src/cache.rs | 16 +- proxy/src/config.rs | 2 + proxy/src/console.rs | 2 +- proxy/src/console/provider.rs | 20 +- proxy/src/console/provider/mock.rs | 95 ++++++--- proxy/src/console/provider/neon.rs | 53 ++++- proxy/src/http.rs | 10 +- proxy/src/proxy.rs | 46 +++- proxy/src/proxy/tests.rs | 4 + proxy/src/serverless.rs | 17 +- proxy/src/serverless/conn_pool.rs | 34 ++- proxy/src/serverless/sql_over_http.rs | 16 +- proxy/src/serverless/websocket.rs | 3 + test_runner/fixtures/neon_fixtures.py | 4 + test_runner/regress/test_proxy_allowed_ips.py | 74 +++++++ 23 files changed, 601 insertions(+), 115 deletions(-) create mode 100644 test_runner/regress/test_proxy_allowed_ips.py diff --git a/Cargo.lock b/Cargo.lock index 3f0f21eb4a..6546590f6c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2382,9 +2382,9 @@ dependencies = [ [[package]] name = "ipnet" -version = "2.7.2" +version = "2.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "12b6ee2129af8d4fb011108c73d99a1b83a85977f23b82460c0ae2e25bb4b57f" +checksum = "8f518f335dce6725a761382244631d86cf0ccb2863413590b31338feb467f9c3" [[package]] name = "is-terminal" @@ -3612,6 +3612,7 @@ dependencies = [ "humantime", "hyper", "hyper-tungstenite", + "ipnet", "itertools", "md5", "metrics", diff --git a/Cargo.toml b/Cargo.toml index 28b58179ea..cbcb25359d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -88,6 +88,7 @@ humantime-serde = "1.1.1" hyper = "0.14" hyper-tungstenite = "0.11" inotify = "0.10.2" +ipnet = "2.9.0" itertools = "0.10" jsonwebtoken = "8" libc = "0.2" diff --git a/proxy/Cargo.toml b/proxy/Cargo.toml index 39a9c3ddb0..0822718bae 100644 --- a/proxy/Cargo.toml +++ b/proxy/Cargo.toml @@ -24,6 +24,7 @@ hostname.workspace = true humantime.workspace = true hyper-tungstenite.workspace = true hyper.workspace = true +ipnet.workspace = true itertools.workspace = true md5.workspace = true metrics.workspace = true diff --git a/proxy/src/auth.rs b/proxy/src/auth.rs index 58dceb3bb6..7d79d34045 100644 --- a/proxy/src/auth.rs +++ b/proxy/src/auth.rs @@ -4,7 +4,7 @@ pub mod backend; pub use backend::BackendType; mod credentials; -pub use credentials::ClientCredentials; +pub use credentials::{check_peer_addr_is_in_list, ClientCredentials}; mod password_hack; pub use password_hack::parse_endpoint_param; @@ -56,6 +56,12 @@ pub enum AuthErrorImpl { /// Errors produced by e.g. [`crate::stream::PqStream`]. #[error(transparent)] Io(#[from] io::Error), + + #[error( + "This IP address is not allowed to connect to this endpoint. \ + Please add it to the allowed list in the Neon console." + )] + IpAddressNotAllowed, } #[derive(Debug, Error)] @@ -70,6 +76,10 @@ impl AuthError { pub fn auth_failed(user: impl Into>) -> Self { AuthErrorImpl::AuthFailed(user.into()).into() } + + pub fn ip_address_not_allowed() -> Self { + AuthErrorImpl::IpAddressNotAllowed.into() + } } impl> From for AuthError { @@ -91,6 +101,7 @@ impl UserFacingError for AuthError { MalformedPassword(_) => self.to_string(), MissingEndpointName => self.to_string(), Io(_) => "Internal error".to_string(), + IpAddressNotAllowed => self.to_string(), } } } diff --git a/proxy/src/auth/backend.rs b/proxy/src/auth/backend.rs index f0197cc31b..aa872285b1 100644 --- a/proxy/src/auth/backend.rs +++ b/proxy/src/auth/backend.rs @@ -5,7 +5,12 @@ mod link; pub use link::LinkAuthError; use tokio_postgres::config::AuthKeys; +use crate::auth::credentials::check_peer_addr_is_in_list; +use crate::console::errors::GetAuthInfoError; +use crate::console::provider::AuthInfo; +use crate::console::AuthSecret; use crate::proxy::{handle_try_wake, retry_after, LatencyTimer}; +use crate::scram; use crate::stream::Stream; use crate::{ auth::{self, ClientCredentials}, @@ -20,6 +25,7 @@ use crate::{ use futures::TryFutureExt; use std::borrow::Cow; use std::ops::ControlFlow; +use std::sync::Arc; use tokio::io::{AsyncRead, AsyncWrite}; use tracing::{error, info, warn}; @@ -64,6 +70,7 @@ pub enum BackendType<'a, T> { pub trait TestBackend: Send + Sync + 'static { fn wake_compute(&self) -> Result; + fn get_allowed_ips(&self) -> Result>, console::errors::GetAuthInfoError>; } impl std::fmt::Display for BackendType<'_, ()> { @@ -140,14 +147,38 @@ async fn auth_quirks_creds( // If there's no project so far, that entails that client doesn't // support SNI or other means of passing the endpoint (project) name. // We now expect to see a very specific payload in the place of password. - if creds.project.is_none() { + let maybe_success = if creds.project.is_none() { // Password will be checked by the compute node later. - return hacks::password_hack(creds, client, latency_timer).await; - } + Some(hacks::password_hack(creds, client, latency_timer).await?) + } else { + None + }; // Password hack should set the project name. // TODO: make `creds.project` more type-safe. assert!(creds.project.is_some()); + info!("fetching user's authentication info"); + // TODO(anna): this will slow down both "hacks" below; we probably need a cache. + let AuthInfo { + secret, + allowed_ips, + } = api.get_auth_info(extra, creds).await?; + + // check allowed list + if !check_peer_addr_is_in_list(&creds.peer_addr.ip(), &allowed_ips) { + return Err(auth::AuthError::ip_address_not_allowed()); + } + let secret = secret.unwrap_or_else(|| { + // If we don't have an authentication secret, we mock one to + // prevent malicious probing (possible due to missing protocol steps). + // This mocked secret will never lead to successful authentication. + info!("authentication info not found, mocking it"); + AuthSecret::Scram(scram::ServerSecret::mock(creds.user, rand::random())) + }); + + if let Some(success) = maybe_success { + return Ok(success); + } // Perform cleartext auth if we're allowed to do that. // Currently, we use it for websocket connections (latency). @@ -157,7 +188,7 @@ async fn auth_quirks_creds( } // Finally, proceed with the main auth flow (SCRAM-based). - classic::authenticate(api, extra, creds, client, config, latency_timer).await + classic::authenticate(creds, client, config, latency_timer, secret).await } /// True to its name, this function encapsulates our current auth trade-offs. @@ -305,6 +336,19 @@ impl BackendType<'_, ClientCredentials<'_>> { Ok(res) } + pub async fn get_allowed_ips( + &self, + extra: &ConsoleReqExtra<'_>, + ) -> Result>, GetAuthInfoError> { + use BackendType::*; + match self { + Console(api, creds) => api.get_allowed_ips(extra, creds).await, + Postgres(api, creds) => api.get_allowed_ips(extra, creds).await, + Link(_) => Ok(Arc::new(vec![])), + Test(x) => x.get_allowed_ips(), + } + } + /// When applicable, wake the compute node, gaining its connection info in the process. /// The link auth flow doesn't support this, so we return [`None`] in that case. pub async fn wake_compute( diff --git a/proxy/src/auth/backend/classic.rs b/proxy/src/auth/backend/classic.rs index ac0d490db1..bb210821cd 100644 --- a/proxy/src/auth/backend/classic.rs +++ b/proxy/src/auth/backend/classic.rs @@ -3,38 +3,28 @@ use crate::{ auth::{self, AuthFlow, ClientCredentials}, compute, config::AuthenticationConfig, - console::{self, AuthInfo, ConsoleReqExtra}, + console::AuthSecret, proxy::LatencyTimer, - sasl, scram, + sasl, stream::{PqStream, Stream}, }; use tokio::io::{AsyncRead, AsyncWrite}; use tracing::{info, warn}; pub(super) async fn authenticate( - api: &impl console::Api, - extra: &ConsoleReqExtra<'_>, creds: &ClientCredentials<'_>, client: &mut PqStream>, config: &'static AuthenticationConfig, latency_timer: &mut LatencyTimer, + secret: AuthSecret, ) -> auth::Result> { - info!("fetching user's authentication info"); - let info = api.get_auth_info(extra, creds).await?.unwrap_or_else(|| { - // If we don't have an authentication secret, we mock one to - // prevent malicious probing (possible due to missing protocol steps). - // This mocked secret will never lead to successful authentication. - info!("authentication info not found, mocking it"); - AuthInfo::Scram(scram::ServerSecret::mock(creds.user, rand::random())) - }); - let flow = AuthFlow::new(client); - let scram_keys = match info { - AuthInfo::Md5(_) => { + let scram_keys = match secret { + AuthSecret::Md5(_) => { info!("auth endpoint chooses MD5"); return Err(auth::AuthError::bad_auth_method("MD5")); } - AuthInfo::Scram(secret) => { + AuthSecret::Scram(secret) => { info!("auth endpoint chooses SCRAM"); let scram = auth::Scram(&secret); diff --git a/proxy/src/auth/credentials.rs b/proxy/src/auth/credentials.rs index 9fe9c26f0c..facb8da8cd 100644 --- a/proxy/src/auth/credentials.rs +++ b/proxy/src/auth/credentials.rs @@ -7,9 +7,12 @@ use crate::{ }; use itertools::Itertools; use pq_proto::StartupMessageParams; -use std::collections::HashSet; +use std::{ + collections::HashSet, + net::{IpAddr, SocketAddr}, +}; use thiserror::Error; -use tracing::info; +use tracing::{info, warn}; #[derive(Debug, Error, PartialEq, Eq, Clone)] pub enum ClientCredsParseError { @@ -44,6 +47,7 @@ pub struct ClientCredentials<'a> { pub project: Option, pub cache_key: String, + pub peer_addr: SocketAddr, } impl ClientCredentials<'_> { @@ -54,19 +58,11 @@ impl ClientCredentials<'_> { } impl<'a> ClientCredentials<'a> { - #[cfg(test)] - pub fn new_noop() -> Self { - ClientCredentials { - user: "", - project: None, - cache_key: "".to_string(), - } - } - pub fn parse( params: &'a StartupMessageParams, sni: Option<&str>, common_names: Option>, + peer_addr: SocketAddr, ) -> Result { use ClientCredsParseError::*; @@ -153,10 +149,59 @@ impl<'a> ClientCredentials<'a> { user, project, cache_key, + peer_addr, }) } } +pub fn check_peer_addr_is_in_list(peer_addr: &IpAddr, ip_list: &Vec) -> bool { + if ip_list.is_empty() { + return true; + } + for ip in ip_list { + // We expect that all ip addresses from control plane are correct. + // However, if some of them are broken, we still can check the others. + match parse_ip_pattern(ip) { + Ok(pattern) => { + if check_ip(peer_addr, &pattern) { + return true; + } + } + Err(err) => warn!("Cannot parse ip: {}; err: {}", ip, err), + } + } + false +} + +#[derive(Debug, Clone, Eq, PartialEq)] +enum IpPattern { + Subnet(ipnet::IpNet), + Range(IpAddr, IpAddr), + Single(IpAddr), +} + +fn parse_ip_pattern(pattern: &str) -> anyhow::Result { + if pattern.contains('/') { + let subnet: ipnet::IpNet = pattern.parse()?; + return Ok(IpPattern::Subnet(subnet)); + } + if let Some((start, end)) = pattern.split_once('-') { + let start: IpAddr = start.parse()?; + let end: IpAddr = end.parse()?; + return Ok(IpPattern::Range(start, end)); + } + let addr: IpAddr = pattern.parse()?; + Ok(IpPattern::Single(addr)) +} + +fn check_ip(ip: &IpAddr, pattern: &IpPattern) -> bool { + match pattern { + IpPattern::Subnet(subnet) => subnet.contains(ip), + IpPattern::Range(start, end) => start <= ip && ip <= end, + IpPattern::Single(addr) => addr == ip, + } +} + fn project_name_valid(name: &str) -> bool { name.chars().all(|c| c.is_alphanumeric() || c == '-') } @@ -176,8 +221,8 @@ mod tests { fn parse_bare_minimum() -> anyhow::Result<()> { // According to postgresql, only `user` should be required. let options = StartupMessageParams::new([("user", "john_doe")]); - - let creds = ClientCredentials::parse(&options, None, None)?; + let peer_addr = SocketAddr::from(([127, 0, 0, 1], 1234)); + let creds = ClientCredentials::parse(&options, None, None, peer_addr)?; assert_eq!(creds.user, "john_doe"); assert_eq!(creds.project, None); @@ -191,8 +236,8 @@ mod tests { ("database", "world"), // should be ignored ("foo", "bar"), // should be ignored ]); - - let creds = ClientCredentials::parse(&options, None, None)?; + let peer_addr = SocketAddr::from(([127, 0, 0, 1], 1234)); + let creds = ClientCredentials::parse(&options, None, None, peer_addr)?; assert_eq!(creds.user, "john_doe"); assert_eq!(creds.project, None); @@ -206,7 +251,8 @@ mod tests { let sni = Some("foo.localhost"); let common_names = Some(["localhost".into()].into()); - let creds = ClientCredentials::parse(&options, sni, common_names)?; + let peer_addr = SocketAddr::from(([127, 0, 0, 1], 1234)); + let creds = ClientCredentials::parse(&options, sni, common_names, peer_addr)?; assert_eq!(creds.user, "john_doe"); assert_eq!(creds.project.as_deref(), Some("foo")); assert_eq!(creds.cache_key, "foo"); @@ -221,7 +267,8 @@ mod tests { ("options", "-ckey=1 project=bar -c geqo=off"), ]); - let creds = ClientCredentials::parse(&options, None, None)?; + let peer_addr = SocketAddr::from(([127, 0, 0, 1], 1234)); + let creds = ClientCredentials::parse(&options, None, None, peer_addr)?; assert_eq!(creds.user, "john_doe"); assert_eq!(creds.project.as_deref(), Some("bar")); @@ -235,7 +282,8 @@ mod tests { ("options", "-ckey=1 endpoint=bar -c geqo=off"), ]); - let creds = ClientCredentials::parse(&options, None, None)?; + let peer_addr = SocketAddr::from(([127, 0, 0, 1], 1234)); + let creds = ClientCredentials::parse(&options, None, None, peer_addr)?; assert_eq!(creds.user, "john_doe"); assert_eq!(creds.project.as_deref(), Some("bar")); @@ -252,7 +300,8 @@ mod tests { ), ]); - let creds = ClientCredentials::parse(&options, None, None)?; + let peer_addr = SocketAddr::from(([127, 0, 0, 1], 1234)); + let creds = ClientCredentials::parse(&options, None, None, peer_addr)?; assert_eq!(creds.user, "john_doe"); assert!(creds.project.is_none()); @@ -266,7 +315,8 @@ mod tests { ("options", "-ckey=1 endpoint=bar project=foo -c geqo=off"), ]); - let creds = ClientCredentials::parse(&options, None, None)?; + let peer_addr = SocketAddr::from(([127, 0, 0, 1], 1234)); + let creds = ClientCredentials::parse(&options, None, None, peer_addr)?; assert_eq!(creds.user, "john_doe"); assert!(creds.project.is_none()); @@ -280,7 +330,8 @@ mod tests { let sni = Some("baz.localhost"); let common_names = Some(["localhost".into()].into()); - let creds = ClientCredentials::parse(&options, sni, common_names)?; + let peer_addr = SocketAddr::from(([127, 0, 0, 1], 1234)); + let creds = ClientCredentials::parse(&options, sni, common_names, peer_addr)?; assert_eq!(creds.user, "john_doe"); assert_eq!(creds.project.as_deref(), Some("baz")); @@ -293,12 +344,14 @@ mod tests { let common_names = Some(["a.com".into(), "b.com".into()].into()); let sni = Some("p1.a.com"); - let creds = ClientCredentials::parse(&options, sni, common_names)?; + let peer_addr = SocketAddr::from(([127, 0, 0, 1], 1234)); + let creds = ClientCredentials::parse(&options, sni, common_names, peer_addr)?; assert_eq!(creds.project.as_deref(), Some("p1")); let common_names = Some(["a.com".into(), "b.com".into()].into()); let sni = Some("p1.b.com"); - let creds = ClientCredentials::parse(&options, sni, common_names)?; + let peer_addr = SocketAddr::from(([127, 0, 0, 1], 1234)); + let creds = ClientCredentials::parse(&options, sni, common_names, peer_addr)?; assert_eq!(creds.project.as_deref(), Some("p1")); Ok(()) @@ -312,7 +365,9 @@ mod tests { let sni = Some("second.localhost"); let common_names = Some(["localhost".into()].into()); - let err = ClientCredentials::parse(&options, sni, common_names).expect_err("should fail"); + let peer_addr = SocketAddr::from(([127, 0, 0, 1], 1234)); + let err = ClientCredentials::parse(&options, sni, common_names, peer_addr) + .expect_err("should fail"); match err { InconsistentProjectNames { domain, option } => { assert_eq!(option, "first"); @@ -329,7 +384,9 @@ mod tests { let sni = Some("project.localhost"); let common_names = Some(["example.com".into()].into()); - let err = ClientCredentials::parse(&options, sni, common_names).expect_err("should fail"); + let peer_addr = SocketAddr::from(([127, 0, 0, 1], 1234)); + let err = ClientCredentials::parse(&options, sni, common_names, peer_addr) + .expect_err("should fail"); match err { UnknownCommonName { cn } => { assert_eq!(cn, "localhost"); @@ -347,7 +404,8 @@ mod tests { let sni = Some("project.localhost"); let common_names = Some(["localhost".into()].into()); - let creds = ClientCredentials::parse(&options, sni, common_names)?; + let peer_addr = SocketAddr::from(([127, 0, 0, 1], 1234)); + let creds = ClientCredentials::parse(&options, sni, common_names, peer_addr)?; assert_eq!(creds.project.as_deref(), Some("project")); assert_eq!( creds.cache_key, @@ -356,4 +414,91 @@ mod tests { Ok(()) } + + #[test] + fn test_check_peer_addr_is_in_list() { + let peer_addr = IpAddr::from([127, 0, 0, 1]); + assert!(check_peer_addr_is_in_list(&peer_addr, &vec![])); + assert!(check_peer_addr_is_in_list( + &peer_addr, + &vec!["127.0.0.1".into()] + )); + assert!(!check_peer_addr_is_in_list( + &peer_addr, + &vec!["8.8.8.8".into()] + )); + // If there is an incorrect address, it will be skipped. + assert!(check_peer_addr_is_in_list( + &peer_addr, + &vec!["88.8.8".into(), "127.0.0.1".into()] + )); + } + #[test] + fn test_parse_ip_v4() -> anyhow::Result<()> { + let peer_addr = IpAddr::from([127, 0, 0, 1]); + // Ok + assert_eq!(parse_ip_pattern("127.0.0.1")?, IpPattern::Single(peer_addr)); + assert_eq!( + parse_ip_pattern("127.0.0.1/31")?, + IpPattern::Subnet(ipnet::IpNet::new(peer_addr, 31)?) + ); + assert_eq!( + parse_ip_pattern("0.0.0.0-200.0.1.2")?, + IpPattern::Range(IpAddr::from([0, 0, 0, 0]), IpAddr::from([200, 0, 1, 2])) + ); + + // Error + assert!(parse_ip_pattern("300.0.1.2").is_err()); + assert!(parse_ip_pattern("30.1.2").is_err()); + assert!(parse_ip_pattern("127.0.0.1/33").is_err()); + assert!(parse_ip_pattern("127.0.0.1-127.0.3").is_err()); + assert!(parse_ip_pattern("1234.0.0.1-127.0.3.0").is_err()); + Ok(()) + } + + #[test] + fn test_check_ipv4() -> anyhow::Result<()> { + let peer_addr = IpAddr::from([127, 0, 0, 1]); + let peer_addr_next = IpAddr::from([127, 0, 0, 2]); + let peer_addr_prev = IpAddr::from([127, 0, 0, 0]); + // Success + assert!(check_ip(&peer_addr, &IpPattern::Single(peer_addr))); + assert!(check_ip( + &peer_addr, + &IpPattern::Subnet(ipnet::IpNet::new(peer_addr_prev, 31)?) + )); + assert!(check_ip( + &peer_addr, + &IpPattern::Subnet(ipnet::IpNet::new(peer_addr_next, 30)?) + )); + assert!(check_ip( + &peer_addr, + &IpPattern::Range(IpAddr::from([0, 0, 0, 0]), IpAddr::from([200, 0, 1, 2])) + )); + assert!(check_ip( + &peer_addr, + &IpPattern::Range(peer_addr, peer_addr) + )); + + // Not success + assert!(!check_ip(&peer_addr, &IpPattern::Single(peer_addr_prev))); + assert!(!check_ip( + &peer_addr, + &IpPattern::Subnet(ipnet::IpNet::new(peer_addr_next, 31)?) + )); + assert!(!check_ip( + &peer_addr, + &IpPattern::Range(IpAddr::from([0, 0, 0, 0]), peer_addr_prev) + )); + assert!(!check_ip( + &peer_addr, + &IpPattern::Range(peer_addr_next, IpAddr::from([128, 0, 0, 0])) + )); + // There is no check that for range start <= end. But it's fine as long as for all this cases the result is false. + assert!(!check_ip( + &peer_addr, + &IpPattern::Range(peer_addr, peer_addr_prev) + )); + Ok(()) + } } diff --git a/proxy/src/bin/proxy.rs b/proxy/src/bin/proxy.rs index d90ac86a82..7457e26867 100644 --- a/proxy/src/bin/proxy.rs +++ b/proxy/src/bin/proxy.rs @@ -1,8 +1,11 @@ use futures::future::Either; use proxy::auth; use proxy::config::AuthenticationConfig; +use proxy::config::CacheOptions; use proxy::config::HttpConfig; use proxy::console; +use proxy::console::provider::AllowedIpsCache; +use proxy::console::provider::NodeInfoCache; use proxy::http; use proxy::rate_limiter::RateLimiterConfig; use proxy::usage_metrics; @@ -113,6 +116,12 @@ struct ProxyCliArgs { initial_limit: usize, #[clap(flatten)] aimd_config: proxy::rate_limiter::AimdConfig, + /// cache for `allowed_ips` (use `size=0` to disable) + #[clap(long, default_value = config::CacheOptions::DEFAULT_OPTIONS_NODE_INFO)] + allowed_ips_cache: String, + /// disable ip check for http requests. If it is too time consuming, it could be turned off. + #[clap(long, default_value_t = false, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)] + disable_ip_check_for_http: bool, } #[tokio::main] @@ -241,11 +250,24 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> { let auth_backend = match &args.auth_backend { AuthBackend::Console => { - let config::CacheOptions { size, ttl } = args.wake_compute_cache.parse()?; + let wake_compute_cache_config: CacheOptions = args.wake_compute_cache.parse()?; + let allowed_ips_cache_config: CacheOptions = args.allowed_ips_cache.parse()?; - info!("Using NodeInfoCache (wake_compute) with size={size} ttl={ttl:?}"); + info!("Using NodeInfoCache (wake_compute) with options={wake_compute_cache_config:?}"); + info!("Using AllowedIpsCache (wake_compute) with options={allowed_ips_cache_config:?}"); let caches = Box::leak(Box::new(console::caches::ApiCaches { - node_info: console::caches::NodeInfoCache::new("node_info_cache", size, ttl), + node_info: NodeInfoCache::new( + "node_info_cache", + wake_compute_cache_config.size, + wake_compute_cache_config.ttl, + true, + ), + allowed_ips: AllowedIpsCache::new( + "allowed_ips_cache", + allowed_ips_cache_config.size, + allowed_ips_cache_config.ttl, + false, + ), })); let config::WakeComputeLockOptions { @@ -292,6 +314,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> { http_config, authentication_config, require_client_ip: args.require_client_ip, + disable_ip_check_for_http: args.disable_ip_check_for_http, })); Ok(config) diff --git a/proxy/src/cache.rs b/proxy/src/cache.rs index a9d6793bbd..f54f360b01 100644 --- a/proxy/src/cache.rs +++ b/proxy/src/cache.rs @@ -55,7 +55,7 @@ pub mod timed_lru { /// * Whenever a new entry is inserted, the least recently accessed one is evicted. /// The cache also keeps track of entry's insertion time (`created_at`) and TTL (`expires_at`). /// - /// * When the entry is about to be retrieved, we check its expiration timestamp. + /// * If `update_ttl_on_retrieval` is `true`. When the entry is about to be retrieved, we check its expiration timestamp. /// If the entry has expired, we remove it from the cache; Otherwise we bump the /// expiration timestamp (e.g. +5mins) and change its place in LRU list to prolong /// its existence. @@ -79,6 +79,8 @@ pub mod timed_lru { /// Default time-to-live of a single entry. ttl: Duration, + + update_ttl_on_retrieval: bool, } impl Cache for TimedLru { @@ -99,11 +101,17 @@ pub mod timed_lru { impl TimedLru { /// Construct a new LRU cache with timed entries. - pub fn new(name: &'static str, capacity: usize, ttl: Duration) -> Self { + pub fn new( + name: &'static str, + capacity: usize, + ttl: Duration, + update_ttl_on_retrieval: bool, + ) -> Self { Self { name, cache: LruCache::new(capacity).into(), ttl, + update_ttl_on_retrieval, } } @@ -165,7 +173,9 @@ pub mod timed_lru { let (created_at, expires_at) = (entry.created_at, entry.expires_at); // Update the deadline and the entry's position in the LRU list. - raw_entry.get_mut().expires_at = deadline; + if self.update_ttl_on_retrieval { + raw_entry.get_mut().expires_at = deadline; + } raw_entry.to_back(); drop(cache); // drop lock before logging diff --git a/proxy/src/config.rs b/proxy/src/config.rs index 89b432df92..182d71f9be 100644 --- a/proxy/src/config.rs +++ b/proxy/src/config.rs @@ -19,6 +19,7 @@ pub struct ProxyConfig { pub http_config: HttpConfig, pub authentication_config: AuthenticationConfig, pub require_client_ip: bool, + pub disable_ip_check_for_http: bool, } #[derive(Debug)] @@ -298,6 +299,7 @@ impl CertResolver { } /// Helper for cmdline cache options parsing. +#[derive(Debug)] pub struct CacheOptions { /// Max number of entries. pub size: usize, diff --git a/proxy/src/console.rs b/proxy/src/console.rs index 6da627389e..07bc807950 100644 --- a/proxy/src/console.rs +++ b/proxy/src/console.rs @@ -6,7 +6,7 @@ pub mod messages; /// Wrappers for console APIs and their mocks. pub mod provider; -pub use provider::{errors, Api, AuthInfo, CachedNodeInfo, ConsoleReqExtra, NodeInfo}; +pub use provider::{errors, Api, AuthSecret, CachedNodeInfo, ConsoleReqExtra, NodeInfo}; /// Various cache-related types. pub mod caches { diff --git a/proxy/src/console/provider.rs b/proxy/src/console/provider.rs index 54bcd1f081..a525de8e53 100644 --- a/proxy/src/console/provider.rs +++ b/proxy/src/console/provider.rs @@ -204,7 +204,7 @@ pub struct ConsoleReqExtra<'a> { } /// Auth secret which is managed by the cloud. -pub enum AuthInfo { +pub enum AuthSecret { /// Md5 hash of user's password. Md5([u8; 16]), @@ -212,6 +212,13 @@ pub enum AuthInfo { Scram(scram::ServerSecret), } +#[derive(Default)] +pub struct AuthInfo { + pub secret: Option, + /// List of IP addresses allowed for the autorization. + pub allowed_ips: Vec, +} + /// Info for establishing a connection to a compute node. /// This is what we get after auth succeeded, but not before! #[derive(Clone)] @@ -230,6 +237,7 @@ pub struct NodeInfo { pub type NodeInfoCache = TimedLru, NodeInfo>; pub type CachedNodeInfo = timed_lru::Cached<&'static NodeInfoCache>; +pub type AllowedIpsCache = TimedLru, Arc>>; /// This will allocate per each call, but the http requests alone /// already require a few allocations, so it should be fine. @@ -240,7 +248,13 @@ pub trait Api { &self, extra: &ConsoleReqExtra<'_>, creds: &ClientCredentials, - ) -> Result, errors::GetAuthInfoError>; + ) -> Result; + + async fn get_allowed_ips( + &self, + extra: &ConsoleReqExtra<'_>, + creds: &ClientCredentials, + ) -> Result>, errors::GetAuthInfoError>; /// Wake up the compute node and return the corresponding connection info. async fn wake_compute( @@ -254,6 +268,8 @@ pub trait Api { pub struct ApiCaches { /// Cache for the `wake_compute` API method. pub node_info: NodeInfoCache, + /// Cache for the `get_allowed_ips`. TODO(anna): use notifications listener instead. + pub allowed_ips: TimedLru, Arc>>, } /// Various caches for [`console`](super). diff --git a/proxy/src/console/provider/mock.rs b/proxy/src/console/provider/mock.rs index 750a2d141e..4cc68f0ac1 100644 --- a/proxy/src/console/provider/mock.rs +++ b/proxy/src/console/provider/mock.rs @@ -1,14 +1,16 @@ //! Mock console backend which relies on a user-provided postgres instance. +use std::sync::Arc; + use super::{ errors::{ApiError, GetAuthInfoError, WakeComputeError}, - AuthInfo, CachedNodeInfo, ConsoleReqExtra, NodeInfo, + AuthInfo, AuthSecret, CachedNodeInfo, ConsoleReqExtra, NodeInfo, }; use crate::{auth::ClientCredentials, compute, error::io_error, scram, url::ApiUrl}; use async_trait::async_trait; use futures::TryFutureExt; use thiserror::Error; -use tokio_postgres::config::SslMode; +use tokio_postgres::{config::SslMode, Client}; use tracing::{error, info, info_span, warn, Instrument}; #[derive(Debug, Error)] @@ -46,8 +48,8 @@ impl Api { async fn do_get_auth_info( &self, creds: &ClientCredentials<'_>, - ) -> Result, GetAuthInfoError> { - async { + ) -> Result { + let (secret, allowed_ips) = async { // Perhaps we could persist this connection, but then we'd have to // write more code for reopening it if it got closed, which doesn't // seem worth it. @@ -55,32 +57,48 @@ impl Api { tokio_postgres::connect(self.endpoint.as_str(), tokio_postgres::NoTls).await?; tokio::spawn(connection); - let query = "select rolpassword from pg_catalog.pg_authid where rolname = $1"; - let rows = client.query(query, &[&creds.user]).await?; - - // We can get at most one row, because `rolname` is unique. - let row = match rows.first() { - Some(row) => row, - // This means that the user doesn't exist, so there can be no secret. - // However, this is still a *valid* outcome which is very similar - // to getting `404 Not found` from the Neon console. + let secret = match get_execute_postgres_query( + &client, + "select rolpassword from pg_catalog.pg_authid where rolname = $1", + &[&creds.user], + "rolpassword", + ) + .await? + { + Some(entry) => { + info!("got a secret: {entry}"); // safe since it's not a prod scenario + let secret = scram::ServerSecret::parse(&entry).map(AuthSecret::Scram); + secret.or_else(|| parse_md5(&entry).map(AuthSecret::Md5)) + } None => { warn!("user '{}' does not exist", creds.user); - return Ok(None); + None } }; + let allowed_ips = match get_execute_postgres_query( + &client, + "select allowed_ips from neon_control_plane.endpoints where endpoint_id = $1", + &[&creds.project.clone().unwrap_or_default().as_str()], + "allowed_ips", + ) + .await? + { + Some(s) => { + info!("got allowed_ips: {s}"); + s.split(',').map(String::from).collect() + } + None => vec![], + }; - let entry = row - .try_get("rolpassword") - .map_err(MockApiError::PasswordNotSet)?; - - info!("got a secret: {entry}"); // safe since it's not a prod scenario - let secret = scram::ServerSecret::parse(entry).map(AuthInfo::Scram); - Ok(secret.or_else(|| parse_md5(entry).map(AuthInfo::Md5))) + Ok((secret, allowed_ips)) } - .map_err(crate::error::log_error) + .map_err(crate::error::log_error::) .instrument(info_span!("postgres", url = self.endpoint.as_str())) - .await + .await?; + Ok(AuthInfo { + secret, + allowed_ips, + }) } async fn do_wake_compute(&self) -> Result { @@ -100,6 +118,27 @@ impl Api { } } +async fn get_execute_postgres_query( + client: &Client, + query: &str, + params: &[&(dyn tokio_postgres::types::ToSql + Sync)], + idx: &str, +) -> Result, GetAuthInfoError> { + let rows = client.query(query, params).await?; + + // We can get at most one row, because `rolname` is unique. + let row = match rows.first() { + Some(row) => row, + // This means that the user doesn't exist, so there can be no secret. + // However, this is still a *valid* outcome which is very similar + // to getting `404 Not found` from the Neon console. + None => return Ok(None), + }; + + let entry = row.try_get(idx).map_err(MockApiError::PasswordNotSet)?; + Ok(Some(entry)) +} + #[async_trait] impl super::Api for Api { #[tracing::instrument(skip_all)] @@ -107,10 +146,18 @@ impl super::Api for Api { &self, _extra: &ConsoleReqExtra<'_>, creds: &ClientCredentials, - ) -> Result, GetAuthInfoError> { + ) -> Result { self.do_get_auth_info(creds).await } + async fn get_allowed_ips( + &self, + _extra: &ConsoleReqExtra<'_>, + creds: &ClientCredentials, + ) -> Result>, GetAuthInfoError> { + Ok(Arc::new(self.do_get_auth_info(creds).await?.allowed_ips)) + } + #[tracing::instrument(skip_all)] async fn wake_compute( &self, diff --git a/proxy/src/console/provider/neon.rs b/proxy/src/console/provider/neon.rs index 0dc7c71534..117d0ec190 100644 --- a/proxy/src/console/provider/neon.rs +++ b/proxy/src/console/provider/neon.rs @@ -3,11 +3,17 @@ use super::{ super::messages::{ConsoleError, GetRoleSecret, WakeCompute}, errors::{ApiError, GetAuthInfoError, WakeComputeError}, - ApiCaches, ApiLocks, AuthInfo, CachedNodeInfo, ConsoleReqExtra, NodeInfo, + ApiCaches, ApiLocks, AuthInfo, AuthSecret, CachedNodeInfo, ConsoleReqExtra, NodeInfo, +}; +use crate::{ + auth::ClientCredentials, + compute, http, + proxy::{ALLOWED_IPS_BY_CACHE_OUTCOME, ALLOWED_IPS_NUMBER}, + scram, }; -use crate::{auth::ClientCredentials, compute, http, scram}; use async_trait::async_trait; use futures::TryFutureExt; +use itertools::Itertools; use std::{net::SocketAddr, sync::Arc}; use tokio::time::Instant; use tokio_postgres::config::SslMode; @@ -48,7 +54,7 @@ impl Api { &self, extra: &ConsoleReqExtra<'_>, creds: &ClientCredentials<'_>, - ) -> Result, GetAuthInfoError> { + ) -> Result { let request_id = uuid::Uuid::new_v4().to_string(); async { let request = self @@ -72,16 +78,25 @@ impl Api { Ok(body) => body, // Error 404 is special: it's ok not to have a secret. Err(e) => match e.http_status_code() { - Some(http::StatusCode::NOT_FOUND) => return Ok(None), + Some(http::StatusCode::NOT_FOUND) => return Ok(AuthInfo::default()), _otherwise => return Err(e.into()), }, }; let secret = scram::ServerSecret::parse(&body.role_secret) - .map(AuthInfo::Scram) + .map(AuthSecret::Scram) .ok_or(GetAuthInfoError::BadSecret)?; - - Ok(Some(secret)) + let allowed_ips = body + .allowed_ips + .into_iter() + .flatten() + .map(String::from) + .collect_vec(); + ALLOWED_IPS_NUMBER.observe(allowed_ips.len() as f64); + Ok(AuthInfo { + secret: Some(secret), + allowed_ips, + }) } .map_err(crate::error::log_error) .instrument(info_span!("http", id = request_id)) @@ -148,10 +163,32 @@ impl super::Api for Api { &self, extra: &ConsoleReqExtra<'_>, creds: &ClientCredentials, - ) -> Result, GetAuthInfoError> { + ) -> Result { self.do_get_auth_info(extra, creds).await } + async fn get_allowed_ips( + &self, + extra: &ConsoleReqExtra<'_>, + creds: &ClientCredentials, + ) -> Result>, GetAuthInfoError> { + let key: &str = creds.project().expect("impossible"); + if let Some(allowed_ips) = self.caches.allowed_ips.get(key) { + ALLOWED_IPS_BY_CACHE_OUTCOME + .with_label_values(&["hit"]) + .inc(); + return Ok(Arc::new(allowed_ips.to_vec())); + } + ALLOWED_IPS_BY_CACHE_OUTCOME + .with_label_values(&["miss"]) + .inc(); + let allowed_ips = Arc::new(self.do_get_auth_info(extra, creds).await?.allowed_ips); + self.caches + .allowed_ips + .insert(key.into(), allowed_ips.clone()); + Ok(allowed_ips) + } + #[tracing::instrument(skip_all)] async fn wake_compute( &self, diff --git a/proxy/src/http.rs b/proxy/src/http.rs index 159b949da3..638705d3e9 100644 --- a/proxy/src/http.rs +++ b/proxy/src/http.rs @@ -13,7 +13,7 @@ pub use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware}; use tokio::time::Instant; use tracing::trace; -use crate::{rate_limiter, url::ApiUrl}; +use crate::{proxy::CONSOLE_REQUEST_LATENCY, rate_limiter, url::ApiUrl}; use reqwest_middleware::RequestBuilder; /// This is the preferred way to create new http clients, @@ -90,7 +90,13 @@ impl Endpoint { /// Execute a [request](reqwest::Request). pub async fn execute(&self, request: Request) -> Result { - self.client.execute(request).await + let path = request.url().path().to_string(); + let start = Instant::now(); + let res = self.client.execute(request).await; + CONSOLE_REQUEST_LATENCY + .with_label_values(&[&path]) + .observe(start.elapsed().as_micros() as f64); + res } } diff --git a/proxy/src/proxy.rs b/proxy/src/proxy.rs index 9560c8546a..2af2dd5562 100644 --- a/proxy/src/proxy.rs +++ b/proxy/src/proxy.rs @@ -24,7 +24,7 @@ use prometheus::{ IntGaugeVec, }; use regex::Regex; -use std::{error::Error, io, ops::ControlFlow, sync::Arc, time::Instant}; +use std::{error::Error, io, net::SocketAddr, ops::ControlFlow, sync::Arc, time::Instant}; use tokio::{ io::{AsyncRead, AsyncWrite, AsyncWriteExt}, time, @@ -110,12 +110,34 @@ static COMPUTE_CONNECTION_LATENCY: Lazy = Lazy::new(|| { .unwrap() }); +pub static CONSOLE_REQUEST_LATENCY: Lazy = Lazy::new(|| { + register_histogram_vec!( + "proxy_console_request_latency", + "Time it took for proxy to establish a connection to the compute endpoint", + // proxy_wake_compute/proxy_get_role_info + &["request"], + // largest bucket = 2^16 * 0.2ms = 13s + exponential_buckets(0.2, 2.0, 16).unwrap(), + ) + .unwrap() +}); + +pub static ALLOWED_IPS_BY_CACHE_OUTCOME: Lazy = Lazy::new(|| { + register_int_counter_vec!( + "proxy_allowed_ips_cache_misses", + "Number of cache hits/misses for allowed ips", + // hit/miss + &["outcome"], + ) + .unwrap() +}); + pub static RATE_LIMITER_ACQUIRE_LATENCY: Lazy = Lazy::new(|| { register_histogram!( "semaphore_control_plane_token_acquire_seconds", "Time it took for proxy to establish a connection to the compute endpoint", - // largest bucket = 2^16 * 0.5ms = 32s - exponential_buckets(0.0005, 2.0, 16).unwrap(), + // largest bucket = 3^16 * 0.00005s = 3.28s + exponential_buckets(0.00005, 3.0, 16).unwrap(), ) .unwrap() }); @@ -138,6 +160,15 @@ pub static NUM_CONNECTION_ACCEPTED_BY_SNI: Lazy = Lazy::new(|| { .unwrap() }); +pub static ALLOWED_IPS_NUMBER: Lazy = Lazy::new(|| { + register_histogram!( + "proxy_allowed_ips_number", + "Number of allowed ips", + vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 10.0, 20.0, 50.0, 100.0], + ) + .unwrap() +}); + pub struct LatencyTimer { // time since the stopwatch was started start: Option, @@ -265,7 +296,7 @@ pub async fn task_main( loop { tokio::select! { accept_result = listener.accept() => { - let (socket, _) = accept_result?; + let (socket, peer_addr) = accept_result?; let session_id = uuid::Uuid::new_v4(); let cancel_map = Arc::clone(&cancel_map); @@ -274,7 +305,9 @@ pub async fn task_main( info!("accepted postgres client connection"); let mut socket = WithClientIp::new(socket); + let mut peer_addr = peer_addr; if let Some(ip) = socket.wait_for_addr().await? { + peer_addr = ip; tracing::Span::current().record("peer_addr", &tracing::field::display(ip)); } else if config.require_client_ip { bail!("missing required client IP"); @@ -285,7 +318,7 @@ pub async fn task_main( .set_nodelay(true) .context("failed to set socket option")?; - handle_client(config, &cancel_map, session_id, socket, ClientMode::Tcp).await + handle_client(config, &cancel_map, session_id, socket, ClientMode::Tcp, peer_addr).await } .instrument(info_span!("handle_client", ?session_id, peer_addr = tracing::field::Empty)) .unwrap_or_else(move |e| { @@ -375,6 +408,7 @@ pub async fn handle_client( session_id: uuid::Uuid, stream: S, mode: ClientMode, + peer_addr: SocketAddr, ) -> anyhow::Result<()> { info!( protocol = mode.protocol_label(), @@ -408,7 +442,7 @@ pub async fn handle_client( let result = config .auth_backend .as_ref() - .map(|_| auth::ClientCredentials::parse(¶ms, hostname, common_names)) + .map(|_| auth::ClientCredentials::parse(¶ms, hostname, common_names, peer_addr)) .transpose(); match result { diff --git a/proxy/src/proxy/tests.rs b/proxy/src/proxy/tests.rs index de9cc0800b..b97c0efce4 100644 --- a/proxy/src/proxy/tests.rs +++ b/proxy/src/proxy/tests.rs @@ -466,6 +466,10 @@ impl TestBackend for TestConnectMechanism { x => panic!("expecting action {:?}, wake_compute is called instead", x), } } + + fn get_allowed_ips(&self) -> Result>, console::errors::GetAuthInfoError> { + unimplemented!("not used in tests") + } } fn helper_create_cached_node_info() -> CachedNodeInfo { diff --git a/proxy/src/serverless.rs b/proxy/src/serverless.rs index 23deda3ae6..45f8132393 100644 --- a/proxy/src/serverless.rs +++ b/proxy/src/serverless.rs @@ -23,6 +23,7 @@ use hyper::{ Body, Method, Request, Response, }; +use std::net::SocketAddr; use std::task::Poll; use std::{future::ready, sync::Arc}; use tls_listener::TlsListener; @@ -102,7 +103,7 @@ pub async fn task_main( let session_id = uuid::Uuid::new_v4(); request_handler( - req, config, conn_pool, cancel_map, session_id, sni_name, + req, config, conn_pool, cancel_map, session_id, sni_name, peer_addr, ) .instrument(info_span!( "serverless", @@ -170,6 +171,7 @@ async fn request_handler( cancel_map: Arc, session_id: uuid::Uuid, sni_hostname: Option, + peer_addr: SocketAddr, ) -> Result, ApiError> { let host = request .headers() @@ -187,9 +189,15 @@ async fn request_handler( tokio::spawn( async move { - if let Err(e) = - websocket::serve_websocket(websocket, config, &cancel_map, session_id, host) - .await + if let Err(e) = websocket::serve_websocket( + websocket, + config, + &cancel_map, + session_id, + host, + peer_addr, + ) + .await { error!(session_id = ?session_id, "error in websocket connection: {e:#}"); } @@ -205,6 +213,7 @@ async fn request_handler( sni_hostname, conn_pool, session_id, + peer_addr, &config.http_config, ) .await diff --git a/proxy/src/serverless/conn_pool.rs b/proxy/src/serverless/conn_pool.rs index b753bc8918..2072cadc3a 100644 --- a/proxy/src/serverless/conn_pool.rs +++ b/proxy/src/serverless/conn_pool.rs @@ -8,7 +8,7 @@ use pbkdf2::{ Params, Pbkdf2, }; use pq_proto::StartupMessageParams; -use std::{collections::HashMap, sync::Arc}; +use std::{collections::HashMap, net::SocketAddr, sync::Arc}; use std::{ fmt, task::{ready, Poll}, @@ -21,7 +21,8 @@ use tokio::time; use tokio_postgres::{AsyncMessage, ReadyForQueryStatus}; use crate::{ - auth, console, + auth::{self, check_peer_addr_is_in_list}, + console, proxy::{ neon_options, LatencyTimer, NUM_DB_CONNECTIONS_CLOSED_COUNTER, NUM_DB_CONNECTIONS_OPENED_COUNTER, @@ -144,6 +145,7 @@ impl GlobalConnPool { conn_info: &ConnInfo, force_new: bool, session_id: uuid::Uuid, + peer_addr: SocketAddr, ) -> anyhow::Result { let mut client: Option = None; let mut latency_timer = LatencyTimer::new("http"); @@ -203,6 +205,7 @@ impl GlobalConnPool { conn_id, session_id, latency_timer, + peer_addr, ) .await } else { @@ -225,6 +228,7 @@ impl GlobalConnPool { conn_id, session_id, latency_timer, + peer_addr, ) .await }; @@ -401,6 +405,7 @@ async fn connect_to_compute( conn_id: uuid::Uuid, session_id: uuid::Uuid, latency_timer: LatencyTimer, + peer_addr: SocketAddr, ) -> anyhow::Result { let tls = config.tls_config.as_ref(); let common_names = tls.and_then(|tls| tls.common_names.clone()); @@ -411,12 +416,13 @@ async fn connect_to_compute( ("application_name", APP_NAME), ("options", conn_info.options.as_deref().unwrap_or("")), ]); - - let creds = config - .auth_backend - .as_ref() - .map(|_| auth::ClientCredentials::parse(¶ms, Some(&conn_info.hostname), common_names)) - .transpose()?; + let creds = auth::ClientCredentials::parse( + ¶ms, + Some(&conn_info.hostname), + common_names, + peer_addr, + )?; + let backend = config.auth_backend.as_ref().map(|_| creds); let console_options = neon_options(¶ms); @@ -425,8 +431,14 @@ async fn connect_to_compute( application_name: Some(APP_NAME), options: console_options.as_deref(), }; - - let node_info = creds + // TODO(anna): this is a bit hacky way, consider using console notification listener. + if !config.disable_ip_check_for_http { + let allowed_ips = backend.get_allowed_ips(&extra).await?; + if !check_peer_addr_is_in_list(&peer_addr.ip(), &allowed_ips) { + return Err(auth::AuthError::ip_address_not_allowed().into()); + } + } + let node_info = backend .wake_compute(&extra) .await? .context("missing cache entry from wake_compute")?; @@ -439,7 +451,7 @@ async fn connect_to_compute( }, node_info, &extra, - &creds, + &backend, latency_timer, ) .await diff --git a/proxy/src/serverless/sql_over_http.rs b/proxy/src/serverless/sql_over_http.rs index 2df2be1d3d..25b96668de 100644 --- a/proxy/src/serverless/sql_over_http.rs +++ b/proxy/src/serverless/sql_over_http.rs @@ -1,3 +1,4 @@ +use std::net::SocketAddr; use std::sync::Arc; use anyhow::bail; @@ -201,11 +202,19 @@ pub async fn handle( sni_hostname: Option, conn_pool: Arc, session_id: uuid::Uuid, + peer_addr: SocketAddr, config: &'static HttpConfig, ) -> Result, ApiError> { let result = tokio::time::timeout( config.timeout, - handle_inner(config, request, sni_hostname, conn_pool, session_id), + handle_inner( + config, + request, + sni_hostname, + conn_pool, + session_id, + peer_addr, + ), ) .await; let mut response = match result { @@ -292,6 +301,7 @@ async fn handle_inner( sni_hostname: Option, conn_pool: Arc, session_id: uuid::Uuid, + peer_addr: SocketAddr, ) -> anyhow::Result> { NUM_CONNECTIONS_ACCEPTED_COUNTER .with_label_values(&["http"]) @@ -351,7 +361,9 @@ async fn handle_inner( let body = hyper::body::to_bytes(request.into_body()).await?; let payload: Payload = serde_json::from_slice(&body)?; - let mut client = conn_pool.get(&conn_info, !allow_pool, session_id).await?; + let mut client = conn_pool + .get(&conn_info, !allow_pool, session_id, peer_addr) + .await?; let mut response = Response::builder() .status(StatusCode::OK) diff --git a/proxy/src/serverless/websocket.rs b/proxy/src/serverless/websocket.rs index 86141ab64f..8fb9a3dee4 100644 --- a/proxy/src/serverless/websocket.rs +++ b/proxy/src/serverless/websocket.rs @@ -11,6 +11,7 @@ use hyper_tungstenite::{tungstenite::Message, HyperWebsocket, WebSocketStream}; use pin_project_lite::pin_project; use std::{ + net::SocketAddr, pin::Pin, task::{ready, Context, Poll}, }; @@ -132,6 +133,7 @@ pub async fn serve_websocket( cancel_map: &CancelMap, session_id: uuid::Uuid, hostname: Option, + peer_addr: SocketAddr, ) -> anyhow::Result<()> { let websocket = websocket.await?; handle_client( @@ -140,6 +142,7 @@ pub async fn serve_websocket( session_id, WebSocketRw::new(websocket), ClientMode::Websockets { hostname }, + peer_addr, ) .await?; Ok(()) diff --git a/test_runner/fixtures/neon_fixtures.py b/test_runner/fixtures/neon_fixtures.py index 23a36ad6c9..862aab84dc 100644 --- a/test_runner/fixtures/neon_fixtures.py +++ b/test_runner/fixtures/neon_fixtures.py @@ -2390,6 +2390,10 @@ def static_proxy( # For simplicity, we use the same user for both `--auth-endpoint` and `safe_psql` vanilla_pg.start() vanilla_pg.safe_psql("create user proxy with login superuser password 'password'") + vanilla_pg.safe_psql("CREATE SCHEMA IF NOT EXISTS neon_control_plane") + vanilla_pg.safe_psql( + "CREATE TABLE neon_control_plane.endpoints (endpoint_id VARCHAR(255) PRIMARY KEY, allowed_ips VARCHAR(255))" + ) proxy_port = port_distributor.get_port() mgmt_port = port_distributor.get_port() diff --git a/test_runner/regress/test_proxy_allowed_ips.py b/test_runner/regress/test_proxy_allowed_ips.py new file mode 100644 index 0000000000..f533579811 --- /dev/null +++ b/test_runner/regress/test_proxy_allowed_ips.py @@ -0,0 +1,74 @@ +import psycopg2 +import pytest +from fixtures.neon_fixtures import ( + NeonProxy, + VanillaPostgres, +) + +TABLE_NAME = "neon_control_plane.endpoints" + + +# Proxy uses the same logic for psql and websockets. +@pytest.mark.asyncio +async def test_proxy_psql_allowed_ips(static_proxy: NeonProxy, vanilla_pg: VanillaPostgres): + # Shouldn't be able to connect to this project + vanilla_pg.safe_psql( + f"INSERT INTO {TABLE_NAME} (endpoint_id, allowed_ips) VALUES ('private-project', '8.8.8.8')" + ) + # Should be able to connect to this project + vanilla_pg.safe_psql( + f"INSERT INTO {TABLE_NAME} (endpoint_id, allowed_ips) VALUES ('generic-project', '::1,127.0.0.1')" + ) + + def check_cannot_connect(**kwargs): + with pytest.raises(psycopg2.Error) as exprinfo: + static_proxy.safe_psql(**kwargs) + text = str(exprinfo.value).strip() + assert "This IP address is not allowed to connect" in text + + # no SNI, deprecated `options=project` syntax (before we had several endpoint in project) + check_cannot_connect(query="select 1", sslsni=0, options="project=private-project") + + # no SNI, new `options=endpoint` syntax + check_cannot_connect(query="select 1", sslsni=0, options="endpoint=private-project") + + # with SNI + check_cannot_connect(query="select 1", host="private-project.localtest.me") + + # no SNI, deprecated `options=project` syntax (before we had several endpoint in project) + out = static_proxy.safe_psql(query="select 1", sslsni=0, options="project=generic-project") + assert out[0][0] == 1 + + # no SNI, new `options=endpoint` syntax + out = static_proxy.safe_psql(query="select 1", sslsni=0, options="endpoint=generic-project") + assert out[0][0] == 1 + + # with SNI + out = static_proxy.safe_psql(query="select 1", host="generic-project.localtest.me") + assert out[0][0] == 1 + + +@pytest.mark.asyncio +async def test_proxy_http_allowed_ips(static_proxy: NeonProxy, vanilla_pg: VanillaPostgres): + static_proxy.safe_psql("create user http_auth with password 'http' superuser") + + # Shouldn't be able to connect to this project + vanilla_pg.safe_psql( + f"INSERT INTO {TABLE_NAME} (endpoint_id, allowed_ips) VALUES ('proxy', '8.8.8.8')" + ) + + def query(status: int, query: str, *args): + static_proxy.http_query( + query, + args, + user="http_auth", + password="http", + expected_code=status, + ) + + query(400, "select 1;") # ip address is not allowed + # Should be able to connect to this project + vanilla_pg.safe_psql( + f"UPDATE {TABLE_NAME} SET allowed_ips = '8.8.8.8,127.0.0.1' WHERE endpoint_id = 'proxy'" + ) + query(200, "select 1;") # should work now