diff --git a/proxy/src/auth.rs b/proxy/src/auth.rs index 0707c1331f..8d1b861a66 100644 --- a/proxy/src/auth.rs +++ b/proxy/src/auth.rs @@ -4,7 +4,9 @@ pub mod backend; pub use backend::BackendType; mod credentials; -pub use credentials::{check_peer_addr_is_in_list, endpoint_sni, ComputeUserInfoMaybeEndpoint}; +pub use credentials::{ + check_peer_addr_is_in_list, endpoint_sni, ComputeUserInfoMaybeEndpoint, IpPattern, +}; mod password_hack; pub use password_hack::parse_endpoint_param; diff --git a/proxy/src/auth/backend.rs b/proxy/src/auth/backend.rs index a6164f7bfb..1e03510119 100644 --- a/proxy/src/auth/backend.rs +++ b/proxy/src/auth/backend.rs @@ -35,6 +35,8 @@ use std::sync::Arc; use tokio::io::{AsyncRead, AsyncWrite}; use tracing::{error, info, warn}; +use super::IpPattern; + /// This type serves two purposes: /// /// * When `T` is `()`, it's just a regular auth backend selector @@ -55,7 +57,7 @@ pub enum BackendType<'a, T> { pub trait TestBackend: Send + Sync + 'static { fn wake_compute(&self) -> Result; - fn get_allowed_ips(&self) -> Result, console::errors::GetAuthInfoError>; + fn get_allowed_ips(&self) -> Result, console::errors::GetAuthInfoError>; } impl std::fmt::Display for BackendType<'_, ()> { diff --git a/proxy/src/auth/credentials.rs b/proxy/src/auth/credentials.rs index ada7f3614c..342fd6fce9 100644 --- a/proxy/src/auth/credentials.rs +++ b/proxy/src/auth/credentials.rs @@ -7,7 +7,7 @@ use crate::{ use itertools::Itertools; use pq_proto::StartupMessageParams; use smol_str::SmolStr; -use std::{collections::HashSet, net::IpAddr}; +use std::{collections::HashSet, net::IpAddr, str::FromStr}; use thiserror::Error; use tracing::{info, warn}; @@ -151,30 +151,51 @@ impl ComputeUserInfoMaybeEndpoint { } } -pub fn check_peer_addr_is_in_list(peer_addr: &IpAddr, ip_list: &Vec) -> bool { - if ip_list.is_empty() { - return true; - } - for ip in ip_list { - // We expect that all ip addresses from control plane are correct. - // However, if some of them are broken, we still can check the others. - match parse_ip_pattern(ip) { - Ok(pattern) => { - if check_ip(peer_addr, &pattern) { - return true; - } - } - Err(err) => warn!("Cannot parse ip: {}; err: {}", ip, err), - } - } - false +pub fn check_peer_addr_is_in_list(peer_addr: &IpAddr, ip_list: &[IpPattern]) -> bool { + ip_list.is_empty() || ip_list.iter().any(|pattern| check_ip(peer_addr, pattern)) } #[derive(Debug, Clone, Eq, PartialEq)] -enum IpPattern { +pub enum IpPattern { Subnet(ipnet::IpNet), Range(IpAddr, IpAddr), Single(IpAddr), + None, +} + +impl<'de> serde::de::Deserialize<'de> for IpPattern { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + struct StrVisitor; + impl<'de> serde::de::Visitor<'de> for StrVisitor { + type Value = IpPattern; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(formatter, "comma separated list with ip address, ip address range, or ip address subnet mask") + } + + fn visit_str(self, v: &str) -> Result + where + E: serde::de::Error, + { + Ok(parse_ip_pattern(v).unwrap_or_else(|e| { + warn!("Cannot parse ip pattern {v}: {e}"); + IpPattern::None + })) + } + } + deserializer.deserialize_str(StrVisitor) + } +} + +impl FromStr for IpPattern { + type Err = anyhow::Error; + + fn from_str(s: &str) -> Result { + parse_ip_pattern(s) + } } fn parse_ip_pattern(pattern: &str) -> anyhow::Result { @@ -196,6 +217,7 @@ fn check_ip(ip: &IpAddr, pattern: &IpPattern) -> bool { IpPattern::Subnet(subnet) => subnet.contains(ip), IpPattern::Range(start, end) => start <= ip && ip <= end, IpPattern::Single(addr) => addr == ip, + IpPattern::None => false, } } @@ -206,6 +228,7 @@ fn project_name_valid(name: &str) -> bool { #[cfg(test)] mod tests { use super::*; + use serde_json::json; use ComputeUserInfoParseError::*; #[test] @@ -415,21 +438,17 @@ mod tests { #[test] fn test_check_peer_addr_is_in_list() { - let peer_addr = IpAddr::from([127, 0, 0, 1]); - assert!(check_peer_addr_is_in_list(&peer_addr, &vec![])); - assert!(check_peer_addr_is_in_list( - &peer_addr, - &vec!["127.0.0.1".into()] - )); - assert!(!check_peer_addr_is_in_list( - &peer_addr, - &vec!["8.8.8.8".into()] - )); + fn check(v: serde_json::Value) -> bool { + let peer_addr = IpAddr::from([127, 0, 0, 1]); + let ip_list: Vec = serde_json::from_value(v).unwrap(); + check_peer_addr_is_in_list(&peer_addr, &ip_list) + } + + assert!(check(json!([]))); + assert!(check(json!(["127.0.0.1"]))); + assert!(!check(json!(["8.8.8.8"]))); // If there is an incorrect address, it will be skipped. - assert!(check_peer_addr_is_in_list( - &peer_addr, - &vec!["88.8.8".into(), "127.0.0.1".into()] - )); + assert!(check(json!(["88.8.8", "127.0.0.1"]))); } #[test] fn test_parse_ip_v4() -> anyhow::Result<()> { diff --git a/proxy/src/cache/project_info.rs b/proxy/src/cache/project_info.rs index b04208556e..fa3d5d0e31 100644 --- a/proxy/src/cache/project_info.rs +++ b/proxy/src/cache/project_info.rs @@ -11,7 +11,7 @@ use smol_str::SmolStr; use tokio::time::Instant; use tracing::{debug, info}; -use crate::{config::ProjectInfoCacheOptions, console::AuthSecret}; +use crate::{auth::IpPattern, config::ProjectInfoCacheOptions, console::AuthSecret}; use super::{Cache, Cached}; @@ -45,7 +45,7 @@ impl From for Entry { #[derive(Default)] struct EndpointInfo { secret: std::collections::HashMap>>, - allowed_ips: Option>>>, + allowed_ips: Option>>>, } impl EndpointInfo { @@ -76,7 +76,7 @@ impl EndpointInfo { &self, valid_since: Instant, ignore_cache_since: Option, - ) -> Option<(Arc>, bool)> { + ) -> Option<(Arc>, bool)> { if let Some(allowed_ips) = &self.allowed_ips { if valid_since < allowed_ips.created_at { return Some(( @@ -189,7 +189,7 @@ impl ProjectInfoCacheImpl { pub fn get_allowed_ips( &self, endpoint_id: &SmolStr, - ) -> Option>>> { + ) -> Option>>> { let (valid_since, ignore_cache_since) = self.get_cache_times(); let endpoint_info = self.cache.get(endpoint_id)?; let value = endpoint_info.get_allowed_ips(valid_since, ignore_cache_since); @@ -224,7 +224,7 @@ impl ProjectInfoCacheImpl { &self, project_id: &SmolStr, endpoint_id: &SmolStr, - allowed_ips: Arc>, + allowed_ips: Arc>, ) { if self.cache.len() >= self.config.size { // If there are too many entries, wait until the next gc cycle. @@ -369,7 +369,10 @@ mod tests { [1; 32], ))); let secret2 = None; - let allowed_ips = Arc::new(vec!["allowed_ip1".into(), "allowed_ip2".into()]); + let allowed_ips = Arc::new(vec![ + "127.0.0.1".parse().unwrap(), + "127.0.0.2".parse().unwrap(), + ]); cache.insert_role_secret(&project_id, &endpoint_id, &user1, secret1.clone()); cache.insert_role_secret(&project_id, &endpoint_id, &user2, secret2.clone()); cache.insert_allowed_ips(&project_id, &endpoint_id, allowed_ips.clone()); @@ -427,7 +430,10 @@ mod tests { user2.as_str(), [2; 32], ))); - let allowed_ips = Arc::new(vec!["allowed_ip1".into(), "allowed_ip2".into()]); + let allowed_ips = Arc::new(vec![ + "127.0.0.1".parse().unwrap(), + "127.0.0.2".parse().unwrap(), + ]); cache.insert_role_secret(&project_id, &endpoint_id, &user1, secret1.clone()); cache.insert_role_secret(&project_id, &endpoint_id, &user2, secret2.clone()); cache.insert_allowed_ips(&project_id, &endpoint_id, allowed_ips.clone()); @@ -479,7 +485,10 @@ mod tests { user2.as_str(), [2; 32], ))); - let allowed_ips = Arc::new(vec!["allowed_ip1".into(), "allowed_ip2".into()]); + let allowed_ips = Arc::new(vec![ + "127.0.0.1".parse().unwrap(), + "127.0.0.2".parse().unwrap(), + ]); cache.insert_role_secret(&project_id, &endpoint_id, &user1, secret1.clone()); cache.clone().disable_ttl(); tokio::time::advance(Duration::from_millis(100)).await; diff --git a/proxy/src/console/messages.rs b/proxy/src/console/messages.rs index c02d65668f..1cfa2d6192 100644 --- a/proxy/src/console/messages.rs +++ b/proxy/src/console/messages.rs @@ -2,6 +2,8 @@ use serde::Deserialize; use smol_str::SmolStr; use std::fmt; +use crate::auth::IpPattern; + /// Generic error response with human-readable description. /// Note that we can't always present it to user as is. #[derive(Debug, Deserialize)] @@ -14,7 +16,7 @@ pub struct ConsoleError { #[derive(Deserialize)] pub struct GetRoleSecret { pub role_secret: Box, - pub allowed_ips: Option>>, + pub allowed_ips: Option>, pub project_id: Option>, } diff --git a/proxy/src/console/provider.rs b/proxy/src/console/provider.rs index bbcddae86c..53c394f52f 100644 --- a/proxy/src/console/provider.rs +++ b/proxy/src/console/provider.rs @@ -4,7 +4,7 @@ pub mod neon; use super::messages::MetricsAuxInfo; use crate::{ - auth::backend::ComputeUserInfo, + auth::{backend::ComputeUserInfo, IpPattern}, cache::{project_info::ProjectInfoCacheImpl, Cached, TimedLru}, compute, config::{CacheOptions, ProjectInfoCacheOptions}, @@ -212,7 +212,7 @@ pub enum AuthSecret { pub struct AuthInfo { pub secret: Option, /// List of IP addresses allowed for the autorization. - pub allowed_ips: Vec, + pub allowed_ips: Vec, /// Project ID. This is used for cache invalidation. pub project_id: Option, } @@ -236,7 +236,7 @@ pub struct NodeInfo { pub type NodeInfoCache = TimedLru; pub type CachedNodeInfo = Cached<&'static NodeInfoCache>; pub type CachedRoleSecret = Cached<&'static ProjectInfoCacheImpl, Option>; -pub type CachedAllowedIps = Cached<&'static ProjectInfoCacheImpl, Arc>>; +pub type CachedAllowedIps = Cached<&'static ProjectInfoCacheImpl, Arc>>; /// This will allocate per each call, but the http requests alone /// already require a few allocations, so it should be fine. diff --git a/proxy/src/console/provider/mock.rs b/proxy/src/console/provider/mock.rs index 49db96a613..55f395a403 100644 --- a/proxy/src/console/provider/mock.rs +++ b/proxy/src/console/provider/mock.rs @@ -4,14 +4,13 @@ use super::{ errors::{ApiError, GetAuthInfoError, WakeComputeError}, AuthInfo, AuthSecret, CachedNodeInfo, NodeInfo, }; -use crate::cache::Cached; use crate::console::provider::{CachedAllowedIps, CachedRoleSecret}; use crate::context::RequestMonitoring; use crate::{auth::backend::ComputeUserInfo, compute, error::io_error, scram, url::ApiUrl}; +use crate::{auth::IpPattern, cache::Cached}; use async_trait::async_trait; use futures::TryFutureExt; -use smol_str::SmolStr; -use std::sync::Arc; +use std::{str::FromStr, sync::Arc}; use thiserror::Error; use tokio_postgres::{config::SslMode, Client}; use tracing::{error, info, info_span, warn, Instrument}; @@ -88,7 +87,9 @@ impl Api { { Some(s) => { info!("got allowed_ips: {s}"); - s.split(',').map(String::from).collect() + s.split(',') + .map(|s| IpPattern::from_str(s).unwrap()) + .collect() } None => vec![], }; @@ -100,7 +101,7 @@ impl Api { .await?; Ok(AuthInfo { secret, - allowed_ips: allowed_ips.iter().map(SmolStr::from).collect(), + allowed_ips, project_id: None, }) } diff --git a/proxy/src/console/provider/neon.rs b/proxy/src/console/provider/neon.rs index 5cee86a9b6..6574e079d5 100644 --- a/proxy/src/console/provider/neon.rs +++ b/proxy/src/console/provider/neon.rs @@ -14,7 +14,6 @@ use crate::{ }; use async_trait::async_trait; use futures::TryFutureExt; -use itertools::Itertools; use smol_str::SmolStr; use std::sync::Arc; use tokio::time::Instant; @@ -94,12 +93,7 @@ impl Api { .ok_or(GetAuthInfoError::BadSecret)?; Some(secret) }; - let allowed_ips = body - .allowed_ips - .into_iter() - .flatten() - .map(SmolStr::from) - .collect_vec(); + let allowed_ips = body.allowed_ips.unwrap_or_default(); ALLOWED_IPS_NUMBER.observe(allowed_ips.len() as f64); Ok(AuthInfo { secret, diff --git a/proxy/src/proxy/tests.rs b/proxy/src/proxy/tests.rs index 73fde2d7d0..a552a857b9 100644 --- a/proxy/src/proxy/tests.rs +++ b/proxy/src/proxy/tests.rs @@ -6,13 +6,13 @@ use super::connect_compute::ConnectMechanism; use super::retry::ShouldRetry; use super::*; use crate::auth::backend::{ComputeUserInfo, TestBackend}; +use crate::auth::IpPattern; use crate::config::CertResolver; use crate::console::{self, CachedNodeInfo, NodeInfo}; use crate::proxy::retry::{retry_after, NUM_RETRIES_CONNECT}; use crate::{auth, http, sasl, scram}; use async_trait::async_trait; use rstest::rstest; -use smol_str::SmolStr; use tokio_postgres::config::SslMode; use tokio_postgres::tls::{MakeTlsConnect, NoTls}; use tokio_postgres_rustls::{MakeRustlsConnect, RustlsStream}; @@ -471,7 +471,7 @@ impl TestBackend for TestConnectMechanism { } } - fn get_allowed_ips(&self) -> Result, console::errors::GetAuthInfoError> { + fn get_allowed_ips(&self) -> Result, console::errors::GetAuthInfoError> { unimplemented!("not used in tests") } }