From b739ea1f0e248f48dc5e357d975124d24e3690cc Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Mon, 11 Dec 2023 11:52:09 +0000 Subject: [PATCH] parse ippattern eagerly --- proxy/src/auth.rs | 2 +- proxy/src/auth/backend.rs | 6 +- proxy/src/auth/credentials.rs | 95 +++++++++++++++++++----------- proxy/src/console/messages.rs | 4 +- proxy/src/console/provider.rs | 10 ++-- proxy/src/console/provider/mock.rs | 12 +++- proxy/src/console/provider/neon.rs | 15 ++--- proxy/src/proxy/tests.rs | 3 +- 8 files changed, 92 insertions(+), 55 deletions(-) diff --git a/proxy/src/auth.rs b/proxy/src/auth.rs index eadb9abd43..b011c894d2 100644 --- a/proxy/src/auth.rs +++ b/proxy/src/auth.rs @@ -4,7 +4,7 @@ pub mod backend; pub use backend::BackendType; mod credentials; -pub use credentials::{check_peer_addr_is_in_list, ClientCredentials}; +pub use credentials::{check_peer_addr_is_in_list, ClientCredentials, IpPattern}; mod password_hack; pub use password_hack::parse_endpoint_param; diff --git a/proxy/src/auth/backend.rs b/proxy/src/auth/backend.rs index ba054b53eb..b0e07df045 100644 --- a/proxy/src/auth/backend.rs +++ b/proxy/src/auth/backend.rs @@ -32,6 +32,8 @@ use std::sync::Arc; use tokio::io::{AsyncRead, AsyncWrite}; use tracing::{error, info, warn}; +use super::IpPattern; + /// This type serves two purposes: /// /// * When `T` is `()`, it's just a regular auth backend selector @@ -55,7 +57,7 @@ pub enum BackendType<'a, T> { pub trait TestBackend: Send + Sync + 'static { fn wake_compute(&self) -> Result; - fn get_allowed_ips(&self) -> Result>, console::errors::GetAuthInfoError>; + fn get_allowed_ips(&self) -> Result>, console::errors::GetAuthInfoError>; } impl std::fmt::Display for BackendType<'_, ()> { @@ -388,7 +390,7 @@ impl BackendType<'_, ComputeUserInfo> { pub async fn get_allowed_ips( &self, extra: &ConsoleReqExtra, - ) -> Result>, GetAuthInfoError> { + ) -> Result>, GetAuthInfoError> { use BackendType::*; match self { Console(api, creds) => api.get_allowed_ips(extra, creds).await, diff --git a/proxy/src/auth/credentials.rs b/proxy/src/auth/credentials.rs index 72149e8e29..61ff99d271 100644 --- a/proxy/src/auth/credentials.rs +++ b/proxy/src/auth/credentials.rs @@ -153,44 +153,65 @@ impl ClientCredentials { } } -pub fn check_peer_addr_is_in_list(peer_addr: &IpAddr, ip_list: &Vec) -> bool { +pub fn check_peer_addr_is_in_list(peer_addr: &IpAddr, ip_list: &[IpPattern]) -> bool { if ip_list.is_empty() { return true; } for ip in ip_list { - // We expect that all ip addresses from control plane are correct. - // However, if some of them are broken, we still can check the others. - match parse_ip_pattern(ip) { - Ok(pattern) => { - if check_ip(peer_addr, &pattern) { - return true; - } - } - Err(err) => warn!("Cannot parse ip: {}; err: {}", ip, err), + if check_ip(peer_addr, ip) { + return true; } } false } #[derive(Debug, Clone, Eq, PartialEq)] -enum IpPattern { +pub enum IpPattern { Subnet(ipnet::IpNet), Range(IpAddr, IpAddr), Single(IpAddr), } -fn parse_ip_pattern(pattern: &str) -> anyhow::Result { - if pattern.contains('/') { - let subnet: ipnet::IpNet = pattern.parse()?; - return Ok(IpPattern::Subnet(subnet)); +impl<'de> serde::Deserialize<'de> for IpPattern { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + let ip = <&'de str>::deserialize(deserializer)?; + Ok(Self::from_str_lossy(ip)) } - if let Some((start, end)) = pattern.split_once('-') { - let start: IpAddr = start.parse()?; - let end: IpAddr = end.parse()?; - return Ok(IpPattern::Range(start, end)); +} + +impl std::str::FromStr for IpPattern { + type Err = anyhow::Error; + + fn from_str(pattern: &str) -> Result { + if pattern.contains('/') { + let subnet: ipnet::IpNet = pattern.parse()?; + return Ok(IpPattern::Subnet(subnet)); + } + if let Some((start, end)) = pattern.split_once('-') { + let start: IpAddr = start.parse()?; + let end: IpAddr = end.parse()?; + return Ok(IpPattern::Range(start, end)); + } + let addr: IpAddr = pattern.parse()?; + Ok(IpPattern::Single(addr)) + } +} + +impl IpPattern { + pub fn from_str_lossy(pattern: &str) -> Self { + match pattern.parse() { + Ok(pattern) => pattern, + Err(err) => { + warn!("Cannot parse ip: {}; err: {}", pattern, err); + // We expect that all ip addresses from control plane are correct. + // However, if some of them are broken, we still can check the others. + Self::Single([0, 0, 0, 0].into()) + } + } } - let addr: IpAddr = pattern.parse()?; - Ok(IpPattern::Single(addr)) } fn check_ip(ip: &IpAddr, pattern: &IpPattern) -> bool { @@ -213,6 +234,8 @@ fn subdomain_from_sni(sni: &str, common_name: &str) -> Option { #[cfg(test)] mod tests { + use std::str::FromStr; + use super::*; use ClientCredsParseError::*; @@ -414,41 +437,47 @@ mod tests { #[test] fn test_check_peer_addr_is_in_list() { let peer_addr = IpAddr::from([127, 0, 0, 1]); - assert!(check_peer_addr_is_in_list(&peer_addr, &vec![])); + assert!(check_peer_addr_is_in_list(&peer_addr, &[])); assert!(check_peer_addr_is_in_list( &peer_addr, - &vec!["127.0.0.1".into()] + &[IpPattern::from_str_lossy("127.0.0.1")] )); assert!(!check_peer_addr_is_in_list( &peer_addr, - &vec!["8.8.8.8".into()] + &[IpPattern::from_str_lossy("8.8.8.8")] )); // If there is an incorrect address, it will be skipped. assert!(check_peer_addr_is_in_list( &peer_addr, - &vec!["88.8.8".into(), "127.0.0.1".into()] + &[ + IpPattern::from_str_lossy("88.8.8"), + IpPattern::from_str_lossy("127.0.0.1") + ] )); } #[test] fn test_parse_ip_v4() -> anyhow::Result<()> { let peer_addr = IpAddr::from([127, 0, 0, 1]); // Ok - assert_eq!(parse_ip_pattern("127.0.0.1")?, IpPattern::Single(peer_addr)); assert_eq!( - parse_ip_pattern("127.0.0.1/31")?, + IpPattern::from_str("127.0.0.1")?, + IpPattern::Single(peer_addr) + ); + assert_eq!( + IpPattern::from_str("127.0.0.1/31")?, IpPattern::Subnet(ipnet::IpNet::new(peer_addr, 31)?) ); assert_eq!( - parse_ip_pattern("0.0.0.0-200.0.1.2")?, + IpPattern::from_str("0.0.0.0-200.0.1.2")?, IpPattern::Range(IpAddr::from([0, 0, 0, 0]), IpAddr::from([200, 0, 1, 2])) ); // Error - assert!(parse_ip_pattern("300.0.1.2").is_err()); - assert!(parse_ip_pattern("30.1.2").is_err()); - assert!(parse_ip_pattern("127.0.0.1/33").is_err()); - assert!(parse_ip_pattern("127.0.0.1-127.0.3").is_err()); - assert!(parse_ip_pattern("1234.0.0.1-127.0.3.0").is_err()); + assert!(IpPattern::from_str("300.0.1.2").is_err()); + assert!(IpPattern::from_str("30.1.2").is_err()); + assert!(IpPattern::from_str("127.0.0.1/33").is_err()); + assert!(IpPattern::from_str("127.0.0.1-127.0.3").is_err()); + assert!(IpPattern::from_str("1234.0.0.1-127.0.3.0").is_err()); Ok(()) } diff --git a/proxy/src/console/messages.rs b/proxy/src/console/messages.rs index 837379b21f..4613490978 100644 --- a/proxy/src/console/messages.rs +++ b/proxy/src/console/messages.rs @@ -2,6 +2,8 @@ use serde::Deserialize; use smol_str::SmolStr; use std::fmt; +use crate::auth::IpPattern; + /// Generic error response with human-readable description. /// Note that we can't always present it to user as is. #[derive(Debug, Deserialize)] @@ -14,7 +16,7 @@ pub struct ConsoleError { #[derive(Deserialize)] pub struct GetRoleSecret { pub role_secret: Box, - pub allowed_ips: Option>>, + pub allowed_ips: Option>, } // Manually implement debug to omit sensitive info. diff --git a/proxy/src/console/provider.rs b/proxy/src/console/provider.rs index deab966d9e..ebdc20ce41 100644 --- a/proxy/src/console/provider.rs +++ b/proxy/src/console/provider.rs @@ -4,7 +4,7 @@ pub mod neon; use super::messages::MetricsAuxInfo; use crate::{ - auth::backend::ComputeUserInfo, + auth::{backend::ComputeUserInfo, IpPattern}, cache::{timed_lru, TimedLru}, compute, scram, }; @@ -229,7 +229,7 @@ pub enum AuthSecret { pub struct AuthInfo { pub secret: Option, /// List of IP addresses allowed for the autorization. - pub allowed_ips: Vec, + pub allowed_ips: Vec, } /// Info for establishing a connection to a compute node. @@ -250,7 +250,7 @@ pub struct NodeInfo { pub type NodeInfoCache = TimedLru, NodeInfo>; pub type CachedNodeInfo = timed_lru::Cached<&'static NodeInfoCache>; -pub type AllowedIpsCache = TimedLru, Arc>>; +pub type AllowedIpsCache = TimedLru, Arc>>; /// This will allocate per each call, but the http requests alone /// already require a few allocations, so it should be fine. @@ -267,7 +267,7 @@ pub trait Api { &self, extra: &ConsoleReqExtra, creds: &ComputeUserInfo, - ) -> Result>, errors::GetAuthInfoError>; + ) -> Result>, errors::GetAuthInfoError>; /// Wake up the compute node and return the corresponding connection info. async fn wake_compute( @@ -282,7 +282,7 @@ pub struct ApiCaches { /// Cache for the `wake_compute` API method. pub node_info: NodeInfoCache, /// Cache for the `get_allowed_ips`. TODO(anna): use notifications listener instead. - pub allowed_ips: TimedLru, Arc>>, + pub allowed_ips: TimedLru, Arc>>, } /// Various caches for [`console`](super). diff --git a/proxy/src/console/provider/mock.rs b/proxy/src/console/provider/mock.rs index c464b4daf2..982e0c4741 100644 --- a/proxy/src/console/provider/mock.rs +++ b/proxy/src/console/provider/mock.rs @@ -6,7 +6,13 @@ use super::{ errors::{ApiError, GetAuthInfoError, WakeComputeError}, AuthInfo, AuthSecret, CachedNodeInfo, ConsoleReqExtra, NodeInfo, }; -use crate::{auth::backend::ComputeUserInfo, compute, error::io_error, scram, url::ApiUrl}; +use crate::{ + auth::{backend::ComputeUserInfo, IpPattern}, + compute, + error::io_error, + scram, + url::ApiUrl, +}; use async_trait::async_trait; use futures::TryFutureExt; use thiserror::Error; @@ -85,7 +91,7 @@ impl Api { { Some(s) => { info!("got allowed_ips: {s}"); - s.split(',').map(String::from).collect() + s.split(',').map(IpPattern::from_str_lossy).collect() } None => vec![], }; @@ -154,7 +160,7 @@ impl super::Api for Api { &self, _extra: &ConsoleReqExtra, creds: &ComputeUserInfo, - ) -> Result>, GetAuthInfoError> { + ) -> Result>, GetAuthInfoError> { Ok(Arc::new(self.do_get_auth_info(creds).await?.allowed_ips)) } diff --git a/proxy/src/console/provider/neon.rs b/proxy/src/console/provider/neon.rs index 192252a0df..d5b4e55a04 100644 --- a/proxy/src/console/provider/neon.rs +++ b/proxy/src/console/provider/neon.rs @@ -5,11 +5,13 @@ use super::{ errors::{ApiError, GetAuthInfoError, WakeComputeError}, ApiCaches, ApiLocks, AuthInfo, AuthSecret, CachedNodeInfo, ConsoleReqExtra, NodeInfo, }; -use crate::proxy::{ALLOWED_IPS_BY_CACHE_OUTCOME, ALLOWED_IPS_NUMBER}; use crate::{auth::backend::ComputeUserInfo, compute, http, scram}; +use crate::{ + auth::IpPattern, + proxy::{ALLOWED_IPS_BY_CACHE_OUTCOME, ALLOWED_IPS_NUMBER}, +}; use async_trait::async_trait; use futures::TryFutureExt; -use itertools::Itertools; use std::{net::SocketAddr, sync::Arc}; use tokio::time::Instant; use tokio_postgres::config::SslMode; @@ -82,12 +84,7 @@ impl Api { let secret = scram::ServerSecret::parse(&body.role_secret) .map(AuthSecret::Scram) .ok_or(GetAuthInfoError::BadSecret)?; - let allowed_ips = body - .allowed_ips - .into_iter() - .flatten() - .map(String::from) - .collect_vec(); + let allowed_ips = body.allowed_ips.unwrap_or_default(); ALLOWED_IPS_NUMBER.observe(allowed_ips.len() as f64); Ok(AuthInfo { secret: Some(secret), @@ -171,7 +168,7 @@ impl super::Api for Api { &self, extra: &ConsoleReqExtra, creds: &ComputeUserInfo, - ) -> Result>, GetAuthInfoError> { + ) -> Result>, GetAuthInfoError> { let key: &str = &creds.endpoint; if let Some(allowed_ips) = self.caches.allowed_ips.get(key) { ALLOWED_IPS_BY_CACHE_OUTCOME diff --git a/proxy/src/proxy/tests.rs b/proxy/src/proxy/tests.rs index 4691abbfb9..d0768c6365 100644 --- a/proxy/src/proxy/tests.rs +++ b/proxy/src/proxy/tests.rs @@ -4,6 +4,7 @@ mod mitm; use super::*; use crate::auth::backend::{ComputeUserInfo, TestBackend}; +use crate::auth::IpPattern; use crate::config::CertResolver; use crate::console::{CachedNodeInfo, NodeInfo}; use crate::{auth, http, sasl, scram}; @@ -466,7 +467,7 @@ impl TestBackend for TestConnectMechanism { } } - fn get_allowed_ips(&self) -> Result>, console::errors::GetAuthInfoError> { + fn get_allowed_ips(&self) -> Result>, console::errors::GetAuthInfoError> { unimplemented!("not used in tests") } }