mirror of
https://github.com/neondatabase/neon.git
synced 2026-01-08 22:12:56 +00:00
eager parsing of ip addr (#6446)
## Problem Parsing the IP address at check time is a little wasteful. ## Summary of changes Parse the IP when we get it from cplane. Adding a `None` variant to still allow malformed patterns
This commit is contained in:
@@ -4,7 +4,9 @@ pub mod backend;
|
||||
pub use backend::BackendType;
|
||||
|
||||
mod credentials;
|
||||
pub use credentials::{check_peer_addr_is_in_list, endpoint_sni, ComputeUserInfoMaybeEndpoint};
|
||||
pub use credentials::{
|
||||
check_peer_addr_is_in_list, endpoint_sni, ComputeUserInfoMaybeEndpoint, IpPattern,
|
||||
};
|
||||
|
||||
mod password_hack;
|
||||
pub use password_hack::parse_endpoint_param;
|
||||
|
||||
@@ -35,6 +35,8 @@ use std::sync::Arc;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tracing::{error, info, warn};
|
||||
|
||||
use super::IpPattern;
|
||||
|
||||
/// This type serves two purposes:
|
||||
///
|
||||
/// * When `T` is `()`, it's just a regular auth backend selector
|
||||
@@ -55,7 +57,7 @@ pub enum BackendType<'a, T> {
|
||||
|
||||
pub trait TestBackend: Send + Sync + 'static {
|
||||
fn wake_compute(&self) -> Result<CachedNodeInfo, console::errors::WakeComputeError>;
|
||||
fn get_allowed_ips(&self) -> Result<Vec<SmolStr>, console::errors::GetAuthInfoError>;
|
||||
fn get_allowed_ips(&self) -> Result<Vec<IpPattern>, console::errors::GetAuthInfoError>;
|
||||
}
|
||||
|
||||
impl std::fmt::Display for BackendType<'_, ()> {
|
||||
|
||||
@@ -7,7 +7,7 @@ use crate::{
|
||||
use itertools::Itertools;
|
||||
use pq_proto::StartupMessageParams;
|
||||
use smol_str::SmolStr;
|
||||
use std::{collections::HashSet, net::IpAddr};
|
||||
use std::{collections::HashSet, net::IpAddr, str::FromStr};
|
||||
use thiserror::Error;
|
||||
use tracing::{info, warn};
|
||||
|
||||
@@ -151,30 +151,51 @@ impl ComputeUserInfoMaybeEndpoint {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn check_peer_addr_is_in_list(peer_addr: &IpAddr, ip_list: &Vec<SmolStr>) -> bool {
|
||||
if ip_list.is_empty() {
|
||||
return true;
|
||||
}
|
||||
for ip in ip_list {
|
||||
// We expect that all ip addresses from control plane are correct.
|
||||
// However, if some of them are broken, we still can check the others.
|
||||
match parse_ip_pattern(ip) {
|
||||
Ok(pattern) => {
|
||||
if check_ip(peer_addr, &pattern) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
Err(err) => warn!("Cannot parse ip: {}; err: {}", ip, err),
|
||||
}
|
||||
}
|
||||
false
|
||||
pub fn check_peer_addr_is_in_list(peer_addr: &IpAddr, ip_list: &[IpPattern]) -> bool {
|
||||
ip_list.is_empty() || ip_list.iter().any(|pattern| check_ip(peer_addr, pattern))
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Eq, PartialEq)]
|
||||
enum IpPattern {
|
||||
pub enum IpPattern {
|
||||
Subnet(ipnet::IpNet),
|
||||
Range(IpAddr, IpAddr),
|
||||
Single(IpAddr),
|
||||
None,
|
||||
}
|
||||
|
||||
impl<'de> serde::de::Deserialize<'de> for IpPattern {
|
||||
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||
where
|
||||
D: serde::Deserializer<'de>,
|
||||
{
|
||||
struct StrVisitor;
|
||||
impl<'de> serde::de::Visitor<'de> for StrVisitor {
|
||||
type Value = IpPattern;
|
||||
|
||||
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||
write!(formatter, "comma separated list with ip address, ip address range, or ip address subnet mask")
|
||||
}
|
||||
|
||||
fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
|
||||
where
|
||||
E: serde::de::Error,
|
||||
{
|
||||
Ok(parse_ip_pattern(v).unwrap_or_else(|e| {
|
||||
warn!("Cannot parse ip pattern {v}: {e}");
|
||||
IpPattern::None
|
||||
}))
|
||||
}
|
||||
}
|
||||
deserializer.deserialize_str(StrVisitor)
|
||||
}
|
||||
}
|
||||
|
||||
impl FromStr for IpPattern {
|
||||
type Err = anyhow::Error;
|
||||
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
parse_ip_pattern(s)
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_ip_pattern(pattern: &str) -> anyhow::Result<IpPattern> {
|
||||
@@ -196,6 +217,7 @@ fn check_ip(ip: &IpAddr, pattern: &IpPattern) -> bool {
|
||||
IpPattern::Subnet(subnet) => subnet.contains(ip),
|
||||
IpPattern::Range(start, end) => start <= ip && ip <= end,
|
||||
IpPattern::Single(addr) => addr == ip,
|
||||
IpPattern::None => false,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -206,6 +228,7 @@ fn project_name_valid(name: &str) -> bool {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use serde_json::json;
|
||||
use ComputeUserInfoParseError::*;
|
||||
|
||||
#[test]
|
||||
@@ -415,21 +438,17 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_check_peer_addr_is_in_list() {
|
||||
let peer_addr = IpAddr::from([127, 0, 0, 1]);
|
||||
assert!(check_peer_addr_is_in_list(&peer_addr, &vec![]));
|
||||
assert!(check_peer_addr_is_in_list(
|
||||
&peer_addr,
|
||||
&vec!["127.0.0.1".into()]
|
||||
));
|
||||
assert!(!check_peer_addr_is_in_list(
|
||||
&peer_addr,
|
||||
&vec!["8.8.8.8".into()]
|
||||
));
|
||||
fn check(v: serde_json::Value) -> bool {
|
||||
let peer_addr = IpAddr::from([127, 0, 0, 1]);
|
||||
let ip_list: Vec<IpPattern> = serde_json::from_value(v).unwrap();
|
||||
check_peer_addr_is_in_list(&peer_addr, &ip_list)
|
||||
}
|
||||
|
||||
assert!(check(json!([])));
|
||||
assert!(check(json!(["127.0.0.1"])));
|
||||
assert!(!check(json!(["8.8.8.8"])));
|
||||
// If there is an incorrect address, it will be skipped.
|
||||
assert!(check_peer_addr_is_in_list(
|
||||
&peer_addr,
|
||||
&vec!["88.8.8".into(), "127.0.0.1".into()]
|
||||
));
|
||||
assert!(check(json!(["88.8.8", "127.0.0.1"])));
|
||||
}
|
||||
#[test]
|
||||
fn test_parse_ip_v4() -> anyhow::Result<()> {
|
||||
|
||||
25
proxy/src/cache/project_info.rs
vendored
25
proxy/src/cache/project_info.rs
vendored
@@ -11,7 +11,7 @@ use smol_str::SmolStr;
|
||||
use tokio::time::Instant;
|
||||
use tracing::{debug, info};
|
||||
|
||||
use crate::{config::ProjectInfoCacheOptions, console::AuthSecret};
|
||||
use crate::{auth::IpPattern, config::ProjectInfoCacheOptions, console::AuthSecret};
|
||||
|
||||
use super::{Cache, Cached};
|
||||
|
||||
@@ -45,7 +45,7 @@ impl<T> From<T> for Entry<T> {
|
||||
#[derive(Default)]
|
||||
struct EndpointInfo {
|
||||
secret: std::collections::HashMap<SmolStr, Entry<Option<AuthSecret>>>,
|
||||
allowed_ips: Option<Entry<Arc<Vec<SmolStr>>>>,
|
||||
allowed_ips: Option<Entry<Arc<Vec<IpPattern>>>>,
|
||||
}
|
||||
|
||||
impl EndpointInfo {
|
||||
@@ -76,7 +76,7 @@ impl EndpointInfo {
|
||||
&self,
|
||||
valid_since: Instant,
|
||||
ignore_cache_since: Option<Instant>,
|
||||
) -> Option<(Arc<Vec<SmolStr>>, bool)> {
|
||||
) -> Option<(Arc<Vec<IpPattern>>, bool)> {
|
||||
if let Some(allowed_ips) = &self.allowed_ips {
|
||||
if valid_since < allowed_ips.created_at {
|
||||
return Some((
|
||||
@@ -189,7 +189,7 @@ impl ProjectInfoCacheImpl {
|
||||
pub fn get_allowed_ips(
|
||||
&self,
|
||||
endpoint_id: &SmolStr,
|
||||
) -> Option<Cached<&Self, Arc<Vec<SmolStr>>>> {
|
||||
) -> Option<Cached<&Self, Arc<Vec<IpPattern>>>> {
|
||||
let (valid_since, ignore_cache_since) = self.get_cache_times();
|
||||
let endpoint_info = self.cache.get(endpoint_id)?;
|
||||
let value = endpoint_info.get_allowed_ips(valid_since, ignore_cache_since);
|
||||
@@ -224,7 +224,7 @@ impl ProjectInfoCacheImpl {
|
||||
&self,
|
||||
project_id: &SmolStr,
|
||||
endpoint_id: &SmolStr,
|
||||
allowed_ips: Arc<Vec<SmolStr>>,
|
||||
allowed_ips: Arc<Vec<IpPattern>>,
|
||||
) {
|
||||
if self.cache.len() >= self.config.size {
|
||||
// If there are too many entries, wait until the next gc cycle.
|
||||
@@ -369,7 +369,10 @@ mod tests {
|
||||
[1; 32],
|
||||
)));
|
||||
let secret2 = None;
|
||||
let allowed_ips = Arc::new(vec!["allowed_ip1".into(), "allowed_ip2".into()]);
|
||||
let allowed_ips = Arc::new(vec![
|
||||
"127.0.0.1".parse().unwrap(),
|
||||
"127.0.0.2".parse().unwrap(),
|
||||
]);
|
||||
cache.insert_role_secret(&project_id, &endpoint_id, &user1, secret1.clone());
|
||||
cache.insert_role_secret(&project_id, &endpoint_id, &user2, secret2.clone());
|
||||
cache.insert_allowed_ips(&project_id, &endpoint_id, allowed_ips.clone());
|
||||
@@ -427,7 +430,10 @@ mod tests {
|
||||
user2.as_str(),
|
||||
[2; 32],
|
||||
)));
|
||||
let allowed_ips = Arc::new(vec!["allowed_ip1".into(), "allowed_ip2".into()]);
|
||||
let allowed_ips = Arc::new(vec![
|
||||
"127.0.0.1".parse().unwrap(),
|
||||
"127.0.0.2".parse().unwrap(),
|
||||
]);
|
||||
cache.insert_role_secret(&project_id, &endpoint_id, &user1, secret1.clone());
|
||||
cache.insert_role_secret(&project_id, &endpoint_id, &user2, secret2.clone());
|
||||
cache.insert_allowed_ips(&project_id, &endpoint_id, allowed_ips.clone());
|
||||
@@ -479,7 +485,10 @@ mod tests {
|
||||
user2.as_str(),
|
||||
[2; 32],
|
||||
)));
|
||||
let allowed_ips = Arc::new(vec!["allowed_ip1".into(), "allowed_ip2".into()]);
|
||||
let allowed_ips = Arc::new(vec![
|
||||
"127.0.0.1".parse().unwrap(),
|
||||
"127.0.0.2".parse().unwrap(),
|
||||
]);
|
||||
cache.insert_role_secret(&project_id, &endpoint_id, &user1, secret1.clone());
|
||||
cache.clone().disable_ttl();
|
||||
tokio::time::advance(Duration::from_millis(100)).await;
|
||||
|
||||
@@ -2,6 +2,8 @@ use serde::Deserialize;
|
||||
use smol_str::SmolStr;
|
||||
use std::fmt;
|
||||
|
||||
use crate::auth::IpPattern;
|
||||
|
||||
/// Generic error response with human-readable description.
|
||||
/// Note that we can't always present it to user as is.
|
||||
#[derive(Debug, Deserialize)]
|
||||
@@ -14,7 +16,7 @@ pub struct ConsoleError {
|
||||
#[derive(Deserialize)]
|
||||
pub struct GetRoleSecret {
|
||||
pub role_secret: Box<str>,
|
||||
pub allowed_ips: Option<Vec<Box<str>>>,
|
||||
pub allowed_ips: Option<Vec<IpPattern>>,
|
||||
pub project_id: Option<Box<str>>,
|
||||
}
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ pub mod neon;
|
||||
|
||||
use super::messages::MetricsAuxInfo;
|
||||
use crate::{
|
||||
auth::backend::ComputeUserInfo,
|
||||
auth::{backend::ComputeUserInfo, IpPattern},
|
||||
cache::{project_info::ProjectInfoCacheImpl, Cached, TimedLru},
|
||||
compute,
|
||||
config::{CacheOptions, ProjectInfoCacheOptions},
|
||||
@@ -212,7 +212,7 @@ pub enum AuthSecret {
|
||||
pub struct AuthInfo {
|
||||
pub secret: Option<AuthSecret>,
|
||||
/// List of IP addresses allowed for the autorization.
|
||||
pub allowed_ips: Vec<SmolStr>,
|
||||
pub allowed_ips: Vec<IpPattern>,
|
||||
/// Project ID. This is used for cache invalidation.
|
||||
pub project_id: Option<SmolStr>,
|
||||
}
|
||||
@@ -236,7 +236,7 @@ pub struct NodeInfo {
|
||||
pub type NodeInfoCache = TimedLru<SmolStr, NodeInfo>;
|
||||
pub type CachedNodeInfo = Cached<&'static NodeInfoCache>;
|
||||
pub type CachedRoleSecret = Cached<&'static ProjectInfoCacheImpl, Option<AuthSecret>>;
|
||||
pub type CachedAllowedIps = Cached<&'static ProjectInfoCacheImpl, Arc<Vec<SmolStr>>>;
|
||||
pub type CachedAllowedIps = Cached<&'static ProjectInfoCacheImpl, Arc<Vec<IpPattern>>>;
|
||||
|
||||
/// This will allocate per each call, but the http requests alone
|
||||
/// already require a few allocations, so it should be fine.
|
||||
|
||||
@@ -4,14 +4,13 @@ use super::{
|
||||
errors::{ApiError, GetAuthInfoError, WakeComputeError},
|
||||
AuthInfo, AuthSecret, CachedNodeInfo, NodeInfo,
|
||||
};
|
||||
use crate::cache::Cached;
|
||||
use crate::console::provider::{CachedAllowedIps, CachedRoleSecret};
|
||||
use crate::context::RequestMonitoring;
|
||||
use crate::{auth::backend::ComputeUserInfo, compute, error::io_error, scram, url::ApiUrl};
|
||||
use crate::{auth::IpPattern, cache::Cached};
|
||||
use async_trait::async_trait;
|
||||
use futures::TryFutureExt;
|
||||
use smol_str::SmolStr;
|
||||
use std::sync::Arc;
|
||||
use std::{str::FromStr, sync::Arc};
|
||||
use thiserror::Error;
|
||||
use tokio_postgres::{config::SslMode, Client};
|
||||
use tracing::{error, info, info_span, warn, Instrument};
|
||||
@@ -88,7 +87,9 @@ impl Api {
|
||||
{
|
||||
Some(s) => {
|
||||
info!("got allowed_ips: {s}");
|
||||
s.split(',').map(String::from).collect()
|
||||
s.split(',')
|
||||
.map(|s| IpPattern::from_str(s).unwrap())
|
||||
.collect()
|
||||
}
|
||||
None => vec![],
|
||||
};
|
||||
@@ -100,7 +101,7 @@ impl Api {
|
||||
.await?;
|
||||
Ok(AuthInfo {
|
||||
secret,
|
||||
allowed_ips: allowed_ips.iter().map(SmolStr::from).collect(),
|
||||
allowed_ips,
|
||||
project_id: None,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -14,7 +14,6 @@ use crate::{
|
||||
};
|
||||
use async_trait::async_trait;
|
||||
use futures::TryFutureExt;
|
||||
use itertools::Itertools;
|
||||
use smol_str::SmolStr;
|
||||
use std::sync::Arc;
|
||||
use tokio::time::Instant;
|
||||
@@ -94,12 +93,7 @@ impl Api {
|
||||
.ok_or(GetAuthInfoError::BadSecret)?;
|
||||
Some(secret)
|
||||
};
|
||||
let allowed_ips = body
|
||||
.allowed_ips
|
||||
.into_iter()
|
||||
.flatten()
|
||||
.map(SmolStr::from)
|
||||
.collect_vec();
|
||||
let allowed_ips = body.allowed_ips.unwrap_or_default();
|
||||
ALLOWED_IPS_NUMBER.observe(allowed_ips.len() as f64);
|
||||
Ok(AuthInfo {
|
||||
secret,
|
||||
|
||||
@@ -6,13 +6,13 @@ use super::connect_compute::ConnectMechanism;
|
||||
use super::retry::ShouldRetry;
|
||||
use super::*;
|
||||
use crate::auth::backend::{ComputeUserInfo, TestBackend};
|
||||
use crate::auth::IpPattern;
|
||||
use crate::config::CertResolver;
|
||||
use crate::console::{self, CachedNodeInfo, NodeInfo};
|
||||
use crate::proxy::retry::{retry_after, NUM_RETRIES_CONNECT};
|
||||
use crate::{auth, http, sasl, scram};
|
||||
use async_trait::async_trait;
|
||||
use rstest::rstest;
|
||||
use smol_str::SmolStr;
|
||||
use tokio_postgres::config::SslMode;
|
||||
use tokio_postgres::tls::{MakeTlsConnect, NoTls};
|
||||
use tokio_postgres_rustls::{MakeRustlsConnect, RustlsStream};
|
||||
@@ -471,7 +471,7 @@ impl TestBackend for TestConnectMechanism {
|
||||
}
|
||||
}
|
||||
|
||||
fn get_allowed_ips(&self) -> Result<Vec<SmolStr>, console::errors::GetAuthInfoError> {
|
||||
fn get_allowed_ips(&self) -> Result<Vec<IpPattern>, console::errors::GetAuthInfoError> {
|
||||
unimplemented!("not used in tests")
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user