parse ippattern eagerly

This commit is contained in:
Conrad Ludgate
2023-12-11 11:52:09 +00:00
parent 6987b5c44e
commit b739ea1f0e
8 changed files with 92 additions and 55 deletions

View File

@@ -4,7 +4,7 @@ pub mod backend;
pub use backend::BackendType;
mod credentials;
pub use credentials::{check_peer_addr_is_in_list, ClientCredentials};
pub use credentials::{check_peer_addr_is_in_list, ClientCredentials, IpPattern};
mod password_hack;
pub use password_hack::parse_endpoint_param;

View File

@@ -32,6 +32,8 @@ use std::sync::Arc;
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::{error, info, warn};
use super::IpPattern;
/// This type serves two purposes:
///
/// * When `T` is `()`, it's just a regular auth backend selector
@@ -55,7 +57,7 @@ pub enum BackendType<'a, T> {
pub trait TestBackend: Send + Sync + 'static {
fn wake_compute(&self) -> Result<CachedNodeInfo, console::errors::WakeComputeError>;
fn get_allowed_ips(&self) -> Result<Arc<Vec<String>>, console::errors::GetAuthInfoError>;
fn get_allowed_ips(&self) -> Result<Arc<Vec<IpPattern>>, console::errors::GetAuthInfoError>;
}
impl std::fmt::Display for BackendType<'_, ()> {
@@ -388,7 +390,7 @@ impl BackendType<'_, ComputeUserInfo> {
pub async fn get_allowed_ips(
&self,
extra: &ConsoleReqExtra,
) -> Result<Arc<Vec<String>>, GetAuthInfoError> {
) -> Result<Arc<Vec<IpPattern>>, GetAuthInfoError> {
use BackendType::*;
match self {
Console(api, creds) => api.get_allowed_ips(extra, creds).await,

View File

@@ -153,44 +153,65 @@ impl ClientCredentials {
}
}
pub fn check_peer_addr_is_in_list(peer_addr: &IpAddr, ip_list: &Vec<String>) -> bool {
pub fn check_peer_addr_is_in_list(peer_addr: &IpAddr, ip_list: &[IpPattern]) -> bool {
if ip_list.is_empty() {
return true;
}
for ip in ip_list {
// We expect that all ip addresses from control plane are correct.
// However, if some of them are broken, we still can check the others.
match parse_ip_pattern(ip) {
Ok(pattern) => {
if check_ip(peer_addr, &pattern) {
return true;
}
}
Err(err) => warn!("Cannot parse ip: {}; err: {}", ip, err),
if check_ip(peer_addr, ip) {
return true;
}
}
false
}
#[derive(Debug, Clone, Eq, PartialEq)]
enum IpPattern {
pub enum IpPattern {
Subnet(ipnet::IpNet),
Range(IpAddr, IpAddr),
Single(IpAddr),
}
fn parse_ip_pattern(pattern: &str) -> anyhow::Result<IpPattern> {
if pattern.contains('/') {
let subnet: ipnet::IpNet = pattern.parse()?;
return Ok(IpPattern::Subnet(subnet));
impl<'de> serde::Deserialize<'de> for IpPattern {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let ip = <&'de str>::deserialize(deserializer)?;
Ok(Self::from_str_lossy(ip))
}
if let Some((start, end)) = pattern.split_once('-') {
let start: IpAddr = start.parse()?;
let end: IpAddr = end.parse()?;
return Ok(IpPattern::Range(start, end));
}
impl std::str::FromStr for IpPattern {
type Err = anyhow::Error;
fn from_str(pattern: &str) -> Result<Self, Self::Err> {
if pattern.contains('/') {
let subnet: ipnet::IpNet = pattern.parse()?;
return Ok(IpPattern::Subnet(subnet));
}
if let Some((start, end)) = pattern.split_once('-') {
let start: IpAddr = start.parse()?;
let end: IpAddr = end.parse()?;
return Ok(IpPattern::Range(start, end));
}
let addr: IpAddr = pattern.parse()?;
Ok(IpPattern::Single(addr))
}
}
impl IpPattern {
pub fn from_str_lossy(pattern: &str) -> Self {
match pattern.parse() {
Ok(pattern) => pattern,
Err(err) => {
warn!("Cannot parse ip: {}; err: {}", pattern, err);
// We expect that all ip addresses from control plane are correct.
// However, if some of them are broken, we still can check the others.
Self::Single([0, 0, 0, 0].into())
}
}
}
let addr: IpAddr = pattern.parse()?;
Ok(IpPattern::Single(addr))
}
fn check_ip(ip: &IpAddr, pattern: &IpPattern) -> bool {
@@ -213,6 +234,8 @@ fn subdomain_from_sni(sni: &str, common_name: &str) -> Option<SmolStr> {
#[cfg(test)]
mod tests {
use std::str::FromStr;
use super::*;
use ClientCredsParseError::*;
@@ -414,41 +437,47 @@ mod tests {
#[test]
fn test_check_peer_addr_is_in_list() {
let peer_addr = IpAddr::from([127, 0, 0, 1]);
assert!(check_peer_addr_is_in_list(&peer_addr, &vec![]));
assert!(check_peer_addr_is_in_list(&peer_addr, &[]));
assert!(check_peer_addr_is_in_list(
&peer_addr,
&vec!["127.0.0.1".into()]
&[IpPattern::from_str_lossy("127.0.0.1")]
));
assert!(!check_peer_addr_is_in_list(
&peer_addr,
&vec!["8.8.8.8".into()]
&[IpPattern::from_str_lossy("8.8.8.8")]
));
// If there is an incorrect address, it will be skipped.
assert!(check_peer_addr_is_in_list(
&peer_addr,
&vec!["88.8.8".into(), "127.0.0.1".into()]
&[
IpPattern::from_str_lossy("88.8.8"),
IpPattern::from_str_lossy("127.0.0.1")
]
));
}
#[test]
fn test_parse_ip_v4() -> anyhow::Result<()> {
let peer_addr = IpAddr::from([127, 0, 0, 1]);
// Ok
assert_eq!(parse_ip_pattern("127.0.0.1")?, IpPattern::Single(peer_addr));
assert_eq!(
parse_ip_pattern("127.0.0.1/31")?,
IpPattern::from_str("127.0.0.1")?,
IpPattern::Single(peer_addr)
);
assert_eq!(
IpPattern::from_str("127.0.0.1/31")?,
IpPattern::Subnet(ipnet::IpNet::new(peer_addr, 31)?)
);
assert_eq!(
parse_ip_pattern("0.0.0.0-200.0.1.2")?,
IpPattern::from_str("0.0.0.0-200.0.1.2")?,
IpPattern::Range(IpAddr::from([0, 0, 0, 0]), IpAddr::from([200, 0, 1, 2]))
);
// Error
assert!(parse_ip_pattern("300.0.1.2").is_err());
assert!(parse_ip_pattern("30.1.2").is_err());
assert!(parse_ip_pattern("127.0.0.1/33").is_err());
assert!(parse_ip_pattern("127.0.0.1-127.0.3").is_err());
assert!(parse_ip_pattern("1234.0.0.1-127.0.3.0").is_err());
assert!(IpPattern::from_str("300.0.1.2").is_err());
assert!(IpPattern::from_str("30.1.2").is_err());
assert!(IpPattern::from_str("127.0.0.1/33").is_err());
assert!(IpPattern::from_str("127.0.0.1-127.0.3").is_err());
assert!(IpPattern::from_str("1234.0.0.1-127.0.3.0").is_err());
Ok(())
}

View File

@@ -2,6 +2,8 @@ use serde::Deserialize;
use smol_str::SmolStr;
use std::fmt;
use crate::auth::IpPattern;
/// Generic error response with human-readable description.
/// Note that we can't always present it to user as is.
#[derive(Debug, Deserialize)]
@@ -14,7 +16,7 @@ pub struct ConsoleError {
#[derive(Deserialize)]
pub struct GetRoleSecret {
pub role_secret: Box<str>,
pub allowed_ips: Option<Vec<Box<str>>>,
pub allowed_ips: Option<Vec<IpPattern>>,
}
// Manually implement debug to omit sensitive info.

View File

@@ -4,7 +4,7 @@ pub mod neon;
use super::messages::MetricsAuxInfo;
use crate::{
auth::backend::ComputeUserInfo,
auth::{backend::ComputeUserInfo, IpPattern},
cache::{timed_lru, TimedLru},
compute, scram,
};
@@ -229,7 +229,7 @@ pub enum AuthSecret {
pub struct AuthInfo {
pub secret: Option<AuthSecret>,
/// List of IP addresses allowed for the autorization.
pub allowed_ips: Vec<String>,
pub allowed_ips: Vec<IpPattern>,
}
/// Info for establishing a connection to a compute node.
@@ -250,7 +250,7 @@ pub struct NodeInfo {
pub type NodeInfoCache = TimedLru<Arc<str>, NodeInfo>;
pub type CachedNodeInfo = timed_lru::Cached<&'static NodeInfoCache>;
pub type AllowedIpsCache = TimedLru<Arc<str>, Arc<Vec<String>>>;
pub type AllowedIpsCache = TimedLru<Arc<str>, Arc<Vec<IpPattern>>>;
/// This will allocate per each call, but the http requests alone
/// already require a few allocations, so it should be fine.
@@ -267,7 +267,7 @@ pub trait Api {
&self,
extra: &ConsoleReqExtra,
creds: &ComputeUserInfo,
) -> Result<Arc<Vec<String>>, errors::GetAuthInfoError>;
) -> Result<Arc<Vec<IpPattern>>, errors::GetAuthInfoError>;
/// Wake up the compute node and return the corresponding connection info.
async fn wake_compute(
@@ -282,7 +282,7 @@ pub struct ApiCaches {
/// Cache for the `wake_compute` API method.
pub node_info: NodeInfoCache,
/// Cache for the `get_allowed_ips`. TODO(anna): use notifications listener instead.
pub allowed_ips: TimedLru<Arc<str>, Arc<Vec<String>>>,
pub allowed_ips: TimedLru<Arc<str>, Arc<Vec<IpPattern>>>,
}
/// Various caches for [`console`](super).

View File

@@ -6,7 +6,13 @@ use super::{
errors::{ApiError, GetAuthInfoError, WakeComputeError},
AuthInfo, AuthSecret, CachedNodeInfo, ConsoleReqExtra, NodeInfo,
};
use crate::{auth::backend::ComputeUserInfo, compute, error::io_error, scram, url::ApiUrl};
use crate::{
auth::{backend::ComputeUserInfo, IpPattern},
compute,
error::io_error,
scram,
url::ApiUrl,
};
use async_trait::async_trait;
use futures::TryFutureExt;
use thiserror::Error;
@@ -85,7 +91,7 @@ impl Api {
{
Some(s) => {
info!("got allowed_ips: {s}");
s.split(',').map(String::from).collect()
s.split(',').map(IpPattern::from_str_lossy).collect()
}
None => vec![],
};
@@ -154,7 +160,7 @@ impl super::Api for Api {
&self,
_extra: &ConsoleReqExtra,
creds: &ComputeUserInfo,
) -> Result<Arc<Vec<String>>, GetAuthInfoError> {
) -> Result<Arc<Vec<IpPattern>>, GetAuthInfoError> {
Ok(Arc::new(self.do_get_auth_info(creds).await?.allowed_ips))
}

View File

@@ -5,11 +5,13 @@ use super::{
errors::{ApiError, GetAuthInfoError, WakeComputeError},
ApiCaches, ApiLocks, AuthInfo, AuthSecret, CachedNodeInfo, ConsoleReqExtra, NodeInfo,
};
use crate::proxy::{ALLOWED_IPS_BY_CACHE_OUTCOME, ALLOWED_IPS_NUMBER};
use crate::{auth::backend::ComputeUserInfo, compute, http, scram};
use crate::{
auth::IpPattern,
proxy::{ALLOWED_IPS_BY_CACHE_OUTCOME, ALLOWED_IPS_NUMBER},
};
use async_trait::async_trait;
use futures::TryFutureExt;
use itertools::Itertools;
use std::{net::SocketAddr, sync::Arc};
use tokio::time::Instant;
use tokio_postgres::config::SslMode;
@@ -82,12 +84,7 @@ impl Api {
let secret = scram::ServerSecret::parse(&body.role_secret)
.map(AuthSecret::Scram)
.ok_or(GetAuthInfoError::BadSecret)?;
let allowed_ips = body
.allowed_ips
.into_iter()
.flatten()
.map(String::from)
.collect_vec();
let allowed_ips = body.allowed_ips.unwrap_or_default();
ALLOWED_IPS_NUMBER.observe(allowed_ips.len() as f64);
Ok(AuthInfo {
secret: Some(secret),
@@ -171,7 +168,7 @@ impl super::Api for Api {
&self,
extra: &ConsoleReqExtra,
creds: &ComputeUserInfo,
) -> Result<Arc<Vec<String>>, GetAuthInfoError> {
) -> Result<Arc<Vec<IpPattern>>, GetAuthInfoError> {
let key: &str = &creds.endpoint;
if let Some(allowed_ips) = self.caches.allowed_ips.get(key) {
ALLOWED_IPS_BY_CACHE_OUTCOME

View File

@@ -4,6 +4,7 @@ mod mitm;
use super::*;
use crate::auth::backend::{ComputeUserInfo, TestBackend};
use crate::auth::IpPattern;
use crate::config::CertResolver;
use crate::console::{CachedNodeInfo, NodeInfo};
use crate::{auth, http, sasl, scram};
@@ -466,7 +467,7 @@ impl TestBackend for TestConnectMechanism {
}
}
fn get_allowed_ips(&self) -> Result<Arc<Vec<String>>, console::errors::GetAuthInfoError> {
fn get_allowed_ips(&self) -> Result<Arc<Vec<IpPattern>>, console::errors::GetAuthInfoError> {
unimplemented!("not used in tests")
}
}