Compare commits

...

2 Commits

Author SHA1 Message Date
Conrad Ludgate
86f7396eb7 permit for unauthenticated connection attempts 2023-12-14 10:52:20 +00:00
Conrad Ludgate
b739ea1f0e parse ippattern eagerly 2023-12-14 10:52:20 +00:00
9 changed files with 129 additions and 73 deletions

View File

@@ -4,7 +4,7 @@ pub mod backend;
pub use backend::BackendType;
mod credentials;
pub use credentials::{check_peer_addr_is_in_list, ClientCredentials};
pub use credentials::{check_peer_addr_is_in_list, ClientCredentials, IpPattern};
mod password_hack;
pub use password_hack::parse_endpoint_param;

View File

@@ -32,6 +32,8 @@ use std::sync::Arc;
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::{error, info, warn};
use super::IpPattern;
/// This type serves two purposes:
///
/// * When `T` is `()`, it's just a regular auth backend selector
@@ -55,7 +57,7 @@ pub enum BackendType<'a, T> {
pub trait TestBackend: Send + Sync + 'static {
fn wake_compute(&self) -> Result<CachedNodeInfo, console::errors::WakeComputeError>;
fn get_allowed_ips(&self) -> Result<Arc<Vec<String>>, console::errors::GetAuthInfoError>;
fn get_allowed_ips(&self) -> Result<Arc<Vec<IpPattern>>, console::errors::GetAuthInfoError>;
}
impl std::fmt::Display for BackendType<'_, ()> {
@@ -388,7 +390,7 @@ impl BackendType<'_, ComputeUserInfo> {
pub async fn get_allowed_ips(
&self,
extra: &ConsoleReqExtra,
) -> Result<Arc<Vec<String>>, GetAuthInfoError> {
) -> Result<Arc<Vec<IpPattern>>, GetAuthInfoError> {
use BackendType::*;
match self {
Console(api, creds) => api.get_allowed_ips(extra, creds).await,

View File

@@ -153,44 +153,65 @@ impl ClientCredentials {
}
}
pub fn check_peer_addr_is_in_list(peer_addr: &IpAddr, ip_list: &Vec<String>) -> bool {
pub fn check_peer_addr_is_in_list(peer_addr: &IpAddr, ip_list: &[IpPattern]) -> bool {
if ip_list.is_empty() {
return true;
}
for ip in ip_list {
// We expect that all ip addresses from control plane are correct.
// However, if some of them are broken, we still can check the others.
match parse_ip_pattern(ip) {
Ok(pattern) => {
if check_ip(peer_addr, &pattern) {
return true;
}
}
Err(err) => warn!("Cannot parse ip: {}; err: {}", ip, err),
if check_ip(peer_addr, ip) {
return true;
}
}
false
}
#[derive(Debug, Clone, Eq, PartialEq)]
enum IpPattern {
pub enum IpPattern {
Subnet(ipnet::IpNet),
Range(IpAddr, IpAddr),
Single(IpAddr),
}
fn parse_ip_pattern(pattern: &str) -> anyhow::Result<IpPattern> {
if pattern.contains('/') {
let subnet: ipnet::IpNet = pattern.parse()?;
return Ok(IpPattern::Subnet(subnet));
impl<'de> serde::Deserialize<'de> for IpPattern {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let ip = <&'de str>::deserialize(deserializer)?;
Ok(Self::from_str_lossy(ip))
}
if let Some((start, end)) = pattern.split_once('-') {
let start: IpAddr = start.parse()?;
let end: IpAddr = end.parse()?;
return Ok(IpPattern::Range(start, end));
}
impl std::str::FromStr for IpPattern {
type Err = anyhow::Error;
fn from_str(pattern: &str) -> Result<Self, Self::Err> {
if pattern.contains('/') {
let subnet: ipnet::IpNet = pattern.parse()?;
return Ok(IpPattern::Subnet(subnet));
}
if let Some((start, end)) = pattern.split_once('-') {
let start: IpAddr = start.parse()?;
let end: IpAddr = end.parse()?;
return Ok(IpPattern::Range(start, end));
}
let addr: IpAddr = pattern.parse()?;
Ok(IpPattern::Single(addr))
}
}
impl IpPattern {
pub fn from_str_lossy(pattern: &str) -> Self {
match pattern.parse() {
Ok(pattern) => pattern,
Err(err) => {
warn!("Cannot parse ip: {}; err: {}", pattern, err);
// We expect that all ip addresses from control plane are correct.
// However, if some of them are broken, we still can check the others.
Self::Single([0, 0, 0, 0].into())
}
}
}
let addr: IpAddr = pattern.parse()?;
Ok(IpPattern::Single(addr))
}
fn check_ip(ip: &IpAddr, pattern: &IpPattern) -> bool {
@@ -213,6 +234,8 @@ fn subdomain_from_sni(sni: &str, common_name: &str) -> Option<SmolStr> {
#[cfg(test)]
mod tests {
use std::str::FromStr;
use super::*;
use ClientCredsParseError::*;
@@ -414,41 +437,47 @@ mod tests {
#[test]
fn test_check_peer_addr_is_in_list() {
let peer_addr = IpAddr::from([127, 0, 0, 1]);
assert!(check_peer_addr_is_in_list(&peer_addr, &vec![]));
assert!(check_peer_addr_is_in_list(&peer_addr, &[]));
assert!(check_peer_addr_is_in_list(
&peer_addr,
&vec!["127.0.0.1".into()]
&[IpPattern::from_str_lossy("127.0.0.1")]
));
assert!(!check_peer_addr_is_in_list(
&peer_addr,
&vec!["8.8.8.8".into()]
&[IpPattern::from_str_lossy("8.8.8.8")]
));
// If there is an incorrect address, it will be skipped.
assert!(check_peer_addr_is_in_list(
&peer_addr,
&vec!["88.8.8".into(), "127.0.0.1".into()]
&[
IpPattern::from_str_lossy("88.8.8"),
IpPattern::from_str_lossy("127.0.0.1")
]
));
}
#[test]
fn test_parse_ip_v4() -> anyhow::Result<()> {
let peer_addr = IpAddr::from([127, 0, 0, 1]);
// Ok
assert_eq!(parse_ip_pattern("127.0.0.1")?, IpPattern::Single(peer_addr));
assert_eq!(
parse_ip_pattern("127.0.0.1/31")?,
IpPattern::from_str("127.0.0.1")?,
IpPattern::Single(peer_addr)
);
assert_eq!(
IpPattern::from_str("127.0.0.1/31")?,
IpPattern::Subnet(ipnet::IpNet::new(peer_addr, 31)?)
);
assert_eq!(
parse_ip_pattern("0.0.0.0-200.0.1.2")?,
IpPattern::from_str("0.0.0.0-200.0.1.2")?,
IpPattern::Range(IpAddr::from([0, 0, 0, 0]), IpAddr::from([200, 0, 1, 2]))
);
// Error
assert!(parse_ip_pattern("300.0.1.2").is_err());
assert!(parse_ip_pattern("30.1.2").is_err());
assert!(parse_ip_pattern("127.0.0.1/33").is_err());
assert!(parse_ip_pattern("127.0.0.1-127.0.3").is_err());
assert!(parse_ip_pattern("1234.0.0.1-127.0.3.0").is_err());
assert!(IpPattern::from_str("300.0.1.2").is_err());
assert!(IpPattern::from_str("30.1.2").is_err());
assert!(IpPattern::from_str("127.0.0.1/33").is_err());
assert!(IpPattern::from_str("127.0.0.1-127.0.3").is_err());
assert!(IpPattern::from_str("1234.0.0.1-127.0.3.0").is_err());
Ok(())
}

View File

@@ -2,6 +2,8 @@ use serde::Deserialize;
use smol_str::SmolStr;
use std::fmt;
use crate::auth::IpPattern;
/// Generic error response with human-readable description.
/// Note that we can't always present it to user as is.
#[derive(Debug, Deserialize)]
@@ -14,7 +16,7 @@ pub struct ConsoleError {
#[derive(Deserialize)]
pub struct GetRoleSecret {
pub role_secret: Box<str>,
pub allowed_ips: Option<Vec<Box<str>>>,
pub allowed_ips: Option<Vec<IpPattern>>,
}
// Manually implement debug to omit sensitive info.

View File

@@ -4,12 +4,13 @@ pub mod neon;
use super::messages::MetricsAuxInfo;
use crate::{
auth::backend::ComputeUserInfo,
auth::{backend::ComputeUserInfo, IpPattern},
cache::{timed_lru, TimedLru},
compute, scram,
};
use async_trait::async_trait;
use dashmap::DashMap;
use smol_str::SmolStr;
use std::{sync::Arc, time::Duration};
use tokio::{
sync::{OwnedSemaphorePermit, Semaphore},
@@ -229,7 +230,7 @@ pub enum AuthSecret {
pub struct AuthInfo {
pub secret: Option<AuthSecret>,
/// List of IP addresses allowed for the autorization.
pub allowed_ips: Vec<String>,
pub allowed_ips: Vec<IpPattern>,
}
/// Info for establishing a connection to a compute node.
@@ -250,7 +251,7 @@ pub struct NodeInfo {
pub type NodeInfoCache = TimedLru<Arc<str>, NodeInfo>;
pub type CachedNodeInfo = timed_lru::Cached<&'static NodeInfoCache>;
pub type AllowedIpsCache = TimedLru<Arc<str>, Arc<Vec<String>>>;
pub type AllowedIpsCache = TimedLru<Arc<str>, Arc<Vec<IpPattern>>>;
/// This will allocate per each call, but the http requests alone
/// already require a few allocations, so it should be fine.
@@ -267,7 +268,7 @@ pub trait Api {
&self,
extra: &ConsoleReqExtra,
creds: &ComputeUserInfo,
) -> Result<Arc<Vec<String>>, errors::GetAuthInfoError>;
) -> Result<Arc<Vec<IpPattern>>, errors::GetAuthInfoError>;
/// Wake up the compute node and return the corresponding connection info.
async fn wake_compute(
@@ -282,13 +283,13 @@ pub struct ApiCaches {
/// Cache for the `wake_compute` API method.
pub node_info: NodeInfoCache,
/// Cache for the `get_allowed_ips`. TODO(anna): use notifications listener instead.
pub allowed_ips: TimedLru<Arc<str>, Arc<Vec<String>>>,
pub allowed_ips: TimedLru<Arc<str>, Arc<Vec<IpPattern>>>,
}
/// Various caches for [`console`](super).
/// Per-endpoint semaphore
pub struct ApiLocks {
name: &'static str,
node_locks: DashMap<Arc<str>, Arc<Semaphore>>,
node_locks: DashMap<SmolStr, Arc<Semaphore>>,
permits: usize,
timeout: Duration,
registered: prometheus::IntCounter,
@@ -354,9 +355,9 @@ impl ApiLocks {
})
}
pub async fn get_wake_compute_permit(
pub async fn get_permit(
&self,
key: &Arc<str>,
key: &SmolStr,
) -> Result<WakeComputePermit, errors::WakeComputeError> {
if self.permits == 0 {
return Ok(WakeComputePermit { permit: None });

View File

@@ -6,7 +6,13 @@ use super::{
errors::{ApiError, GetAuthInfoError, WakeComputeError},
AuthInfo, AuthSecret, CachedNodeInfo, ConsoleReqExtra, NodeInfo,
};
use crate::{auth::backend::ComputeUserInfo, compute, error::io_error, scram, url::ApiUrl};
use crate::{
auth::{backend::ComputeUserInfo, IpPattern},
compute,
error::io_error,
scram,
url::ApiUrl,
};
use async_trait::async_trait;
use futures::TryFutureExt;
use thiserror::Error;
@@ -85,7 +91,7 @@ impl Api {
{
Some(s) => {
info!("got allowed_ips: {s}");
s.split(',').map(String::from).collect()
s.split(',').map(IpPattern::from_str_lossy).collect()
}
None => vec![],
};
@@ -154,7 +160,7 @@ impl super::Api for Api {
&self,
_extra: &ConsoleReqExtra,
creds: &ComputeUserInfo,
) -> Result<Arc<Vec<String>>, GetAuthInfoError> {
) -> Result<Arc<Vec<IpPattern>>, GetAuthInfoError> {
Ok(Arc::new(self.do_get_auth_info(creds).await?.allowed_ips))
}

View File

@@ -5,11 +5,13 @@ use super::{
errors::{ApiError, GetAuthInfoError, WakeComputeError},
ApiCaches, ApiLocks, AuthInfo, AuthSecret, CachedNodeInfo, ConsoleReqExtra, NodeInfo,
};
use crate::proxy::{ALLOWED_IPS_BY_CACHE_OUTCOME, ALLOWED_IPS_NUMBER};
use crate::{auth::backend::ComputeUserInfo, compute, http, scram};
use crate::{
auth::IpPattern,
proxy::{ALLOWED_IPS_BY_CACHE_OUTCOME, ALLOWED_IPS_NUMBER},
};
use async_trait::async_trait;
use futures::TryFutureExt;
use itertools::Itertools;
use std::{net::SocketAddr, sync::Arc};
use tokio::time::Instant;
use tokio_postgres::config::SslMode;
@@ -82,12 +84,7 @@ impl Api {
let secret = scram::ServerSecret::parse(&body.role_secret)
.map(AuthSecret::Scram)
.ok_or(GetAuthInfoError::BadSecret)?;
let allowed_ips = body
.allowed_ips
.into_iter()
.flatten()
.map(String::from)
.collect_vec();
let allowed_ips = body.allowed_ips.unwrap_or_default();
ALLOWED_IPS_NUMBER.observe(allowed_ips.len() as f64);
Ok(AuthInfo {
secret: Some(secret),
@@ -171,7 +168,7 @@ impl super::Api for Api {
&self,
extra: &ConsoleReqExtra,
creds: &ComputeUserInfo,
) -> Result<Arc<Vec<String>>, GetAuthInfoError> {
) -> Result<Arc<Vec<IpPattern>>, GetAuthInfoError> {
let key: &str = &creds.endpoint;
if let Some(allowed_ips) = self.caches.allowed_ips.get(key) {
ALLOWED_IPS_BY_CACHE_OUTCOME
@@ -195,33 +192,31 @@ impl super::Api for Api {
extra: &ConsoleReqExtra,
creds: &ComputeUserInfo,
) -> Result<CachedNodeInfo, WakeComputeError> {
let key: &str = &creds.inner.cache_key;
let key = &creds.inner.cache_key;
// Every time we do a wakeup http request, the compute node will stay up
// for some time (highly depends on the console's scale-to-zero policy);
// The connection info remains the same during that period of time,
// which means that we might cache it to reduce the load and latency.
if let Some(cached) = self.caches.node_info.get(key) {
info!(key = key, "found cached compute node info");
if let Some(cached) = self.caches.node_info.get(&**key) {
info!(key = &**key, "found cached compute node info");
return Ok(cached);
}
let key: Arc<str> = key.into();
let permit = self.locks.get_wake_compute_permit(&key).await?;
let permit = self.locks.get_permit(key).await?;
// after getting back a permit - it's possible the cache was filled
// double check
if permit.should_check_cache() {
if let Some(cached) = self.caches.node_info.get(&key) {
info!(key = &*key, "found cached compute node info");
if let Some(cached) = self.caches.node_info.get(&**key) {
info!(key = &**key, "found cached compute node info");
return Ok(cached);
}
}
let node = self.do_wake_compute(extra, creds).await?;
let (_, cached) = self.caches.node_info.insert(key.clone(), node);
info!(key = &*key, "created a cache entry for compute node info");
let (_, cached) = self.caches.node_info.insert(key.as_str().into(), node);
info!(key = &**key, "created a cache entry for compute node info");
Ok(cached)
}

View File

@@ -4,6 +4,7 @@ mod mitm;
use super::*;
use crate::auth::backend::{ComputeUserInfo, TestBackend};
use crate::auth::IpPattern;
use crate::config::CertResolver;
use crate::console::{CachedNodeInfo, NodeInfo};
use crate::{auth, http, sasl, scram};
@@ -466,7 +467,7 @@ impl TestBackend for TestConnectMechanism {
}
}
fn get_allowed_ips(&self) -> Result<Arc<Vec<String>>, console::errors::GetAuthInfoError> {
fn get_allowed_ips(&self) -> Result<Arc<Vec<IpPattern>>, console::errors::GetAuthInfoError> {
unimplemented!("not used in tests")
}
}

View File

@@ -9,7 +9,7 @@ use pbkdf2::{
};
use pq_proto::StartupMessageParams;
use smol_str::SmolStr;
use std::{collections::HashMap, net::IpAddr, sync::Arc};
use std::{collections::HashMap, net::IpAddr, sync::Arc, time::Duration};
use std::{
fmt,
task::{ready, Poll},
@@ -23,7 +23,7 @@ use tokio_postgres::{AsyncMessage, ReadyForQueryStatus};
use crate::{
auth::{self, backend::ComputeUserInfo, check_peer_addr_is_in_list},
console,
console::{self, locks::ApiLocks},
proxy::{
neon_options, LatencyTimer, NUM_DB_CONNECTIONS_CLOSED_COUNTER,
NUM_DB_CONNECTIONS_OPENED_COUNTER,
@@ -114,17 +114,32 @@ pub struct GlobalConnPool {
// Using a lock to remove any race conditions.
// Eg cleaning up connections while a new connection is returned
closed: RwLock<bool>,
// semaphore guarding unauthenticated postgres connections
connection_lock: ApiLocks,
}
impl GlobalConnPool {
pub fn new(config: &'static crate::config::ProxyConfig) -> Arc<Self> {
Arc::new(Self {
let connection_lock =
ApiLocks::new("http_connect_lock", 2, 32, Duration::from_secs(10)).unwrap();
let this = Arc::new(Self {
global_pool: DashMap::new(),
global_pool_size: AtomicUsize::new(0),
max_conns_per_endpoint: MAX_CONNS_PER_ENDPOINT,
proxy_config: config,
closed: RwLock::new(false),
})
connection_lock,
});
let this2 = this.clone();
tokio::spawn(async move {
this2
.connection_lock
.garbage_collect_worker(Duration::from_secs(600))
.await
});
this
}
pub fn shutdown(&self) {
@@ -221,6 +236,11 @@ impl GlobalConnPool {
return Ok(Client::new(client, pool).await);
}
} else {
// acquire a permit for un-authenticated access to the compute.
// to be clear, postgres will authenticate, but we want to limit the connections
// that have potential to be unauthenticated.
let _permit = self.connection_lock.get_permit(&conn_info.hostname).await?;
let conn_id = uuid::Uuid::new_v4();
info!(%conn_id, "pool: opening a new connection '{conn_info}'");
connect_to_compute(