mirror of
https://github.com/neondatabase/neon.git
synced 2026-01-17 10:22:56 +00:00
Compare commits
2 Commits
min_inflig
...
scram-for-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
86f7396eb7 | ||
|
|
b739ea1f0e |
@@ -4,7 +4,7 @@ pub mod backend;
|
||||
pub use backend::BackendType;
|
||||
|
||||
mod credentials;
|
||||
pub use credentials::{check_peer_addr_is_in_list, ClientCredentials};
|
||||
pub use credentials::{check_peer_addr_is_in_list, ClientCredentials, IpPattern};
|
||||
|
||||
mod password_hack;
|
||||
pub use password_hack::parse_endpoint_param;
|
||||
|
||||
@@ -32,6 +32,8 @@ use std::sync::Arc;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tracing::{error, info, warn};
|
||||
|
||||
use super::IpPattern;
|
||||
|
||||
/// This type serves two purposes:
|
||||
///
|
||||
/// * When `T` is `()`, it's just a regular auth backend selector
|
||||
@@ -55,7 +57,7 @@ pub enum BackendType<'a, T> {
|
||||
|
||||
pub trait TestBackend: Send + Sync + 'static {
|
||||
fn wake_compute(&self) -> Result<CachedNodeInfo, console::errors::WakeComputeError>;
|
||||
fn get_allowed_ips(&self) -> Result<Arc<Vec<String>>, console::errors::GetAuthInfoError>;
|
||||
fn get_allowed_ips(&self) -> Result<Arc<Vec<IpPattern>>, console::errors::GetAuthInfoError>;
|
||||
}
|
||||
|
||||
impl std::fmt::Display for BackendType<'_, ()> {
|
||||
@@ -388,7 +390,7 @@ impl BackendType<'_, ComputeUserInfo> {
|
||||
pub async fn get_allowed_ips(
|
||||
&self,
|
||||
extra: &ConsoleReqExtra,
|
||||
) -> Result<Arc<Vec<String>>, GetAuthInfoError> {
|
||||
) -> Result<Arc<Vec<IpPattern>>, GetAuthInfoError> {
|
||||
use BackendType::*;
|
||||
match self {
|
||||
Console(api, creds) => api.get_allowed_ips(extra, creds).await,
|
||||
|
||||
@@ -153,44 +153,65 @@ impl ClientCredentials {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn check_peer_addr_is_in_list(peer_addr: &IpAddr, ip_list: &Vec<String>) -> bool {
|
||||
pub fn check_peer_addr_is_in_list(peer_addr: &IpAddr, ip_list: &[IpPattern]) -> bool {
|
||||
if ip_list.is_empty() {
|
||||
return true;
|
||||
}
|
||||
for ip in ip_list {
|
||||
// We expect that all ip addresses from control plane are correct.
|
||||
// However, if some of them are broken, we still can check the others.
|
||||
match parse_ip_pattern(ip) {
|
||||
Ok(pattern) => {
|
||||
if check_ip(peer_addr, &pattern) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
Err(err) => warn!("Cannot parse ip: {}; err: {}", ip, err),
|
||||
if check_ip(peer_addr, ip) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Eq, PartialEq)]
|
||||
enum IpPattern {
|
||||
pub enum IpPattern {
|
||||
Subnet(ipnet::IpNet),
|
||||
Range(IpAddr, IpAddr),
|
||||
Single(IpAddr),
|
||||
}
|
||||
|
||||
fn parse_ip_pattern(pattern: &str) -> anyhow::Result<IpPattern> {
|
||||
if pattern.contains('/') {
|
||||
let subnet: ipnet::IpNet = pattern.parse()?;
|
||||
return Ok(IpPattern::Subnet(subnet));
|
||||
impl<'de> serde::Deserialize<'de> for IpPattern {
|
||||
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||
where
|
||||
D: serde::Deserializer<'de>,
|
||||
{
|
||||
let ip = <&'de str>::deserialize(deserializer)?;
|
||||
Ok(Self::from_str_lossy(ip))
|
||||
}
|
||||
if let Some((start, end)) = pattern.split_once('-') {
|
||||
let start: IpAddr = start.parse()?;
|
||||
let end: IpAddr = end.parse()?;
|
||||
return Ok(IpPattern::Range(start, end));
|
||||
}
|
||||
|
||||
impl std::str::FromStr for IpPattern {
|
||||
type Err = anyhow::Error;
|
||||
|
||||
fn from_str(pattern: &str) -> Result<Self, Self::Err> {
|
||||
if pattern.contains('/') {
|
||||
let subnet: ipnet::IpNet = pattern.parse()?;
|
||||
return Ok(IpPattern::Subnet(subnet));
|
||||
}
|
||||
if let Some((start, end)) = pattern.split_once('-') {
|
||||
let start: IpAddr = start.parse()?;
|
||||
let end: IpAddr = end.parse()?;
|
||||
return Ok(IpPattern::Range(start, end));
|
||||
}
|
||||
let addr: IpAddr = pattern.parse()?;
|
||||
Ok(IpPattern::Single(addr))
|
||||
}
|
||||
}
|
||||
|
||||
impl IpPattern {
|
||||
pub fn from_str_lossy(pattern: &str) -> Self {
|
||||
match pattern.parse() {
|
||||
Ok(pattern) => pattern,
|
||||
Err(err) => {
|
||||
warn!("Cannot parse ip: {}; err: {}", pattern, err);
|
||||
// We expect that all ip addresses from control plane are correct.
|
||||
// However, if some of them are broken, we still can check the others.
|
||||
Self::Single([0, 0, 0, 0].into())
|
||||
}
|
||||
}
|
||||
}
|
||||
let addr: IpAddr = pattern.parse()?;
|
||||
Ok(IpPattern::Single(addr))
|
||||
}
|
||||
|
||||
fn check_ip(ip: &IpAddr, pattern: &IpPattern) -> bool {
|
||||
@@ -213,6 +234,8 @@ fn subdomain_from_sni(sni: &str, common_name: &str) -> Option<SmolStr> {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::str::FromStr;
|
||||
|
||||
use super::*;
|
||||
use ClientCredsParseError::*;
|
||||
|
||||
@@ -414,41 +437,47 @@ mod tests {
|
||||
#[test]
|
||||
fn test_check_peer_addr_is_in_list() {
|
||||
let peer_addr = IpAddr::from([127, 0, 0, 1]);
|
||||
assert!(check_peer_addr_is_in_list(&peer_addr, &vec![]));
|
||||
assert!(check_peer_addr_is_in_list(&peer_addr, &[]));
|
||||
assert!(check_peer_addr_is_in_list(
|
||||
&peer_addr,
|
||||
&vec!["127.0.0.1".into()]
|
||||
&[IpPattern::from_str_lossy("127.0.0.1")]
|
||||
));
|
||||
assert!(!check_peer_addr_is_in_list(
|
||||
&peer_addr,
|
||||
&vec!["8.8.8.8".into()]
|
||||
&[IpPattern::from_str_lossy("8.8.8.8")]
|
||||
));
|
||||
// If there is an incorrect address, it will be skipped.
|
||||
assert!(check_peer_addr_is_in_list(
|
||||
&peer_addr,
|
||||
&vec!["88.8.8".into(), "127.0.0.1".into()]
|
||||
&[
|
||||
IpPattern::from_str_lossy("88.8.8"),
|
||||
IpPattern::from_str_lossy("127.0.0.1")
|
||||
]
|
||||
));
|
||||
}
|
||||
#[test]
|
||||
fn test_parse_ip_v4() -> anyhow::Result<()> {
|
||||
let peer_addr = IpAddr::from([127, 0, 0, 1]);
|
||||
// Ok
|
||||
assert_eq!(parse_ip_pattern("127.0.0.1")?, IpPattern::Single(peer_addr));
|
||||
assert_eq!(
|
||||
parse_ip_pattern("127.0.0.1/31")?,
|
||||
IpPattern::from_str("127.0.0.1")?,
|
||||
IpPattern::Single(peer_addr)
|
||||
);
|
||||
assert_eq!(
|
||||
IpPattern::from_str("127.0.0.1/31")?,
|
||||
IpPattern::Subnet(ipnet::IpNet::new(peer_addr, 31)?)
|
||||
);
|
||||
assert_eq!(
|
||||
parse_ip_pattern("0.0.0.0-200.0.1.2")?,
|
||||
IpPattern::from_str("0.0.0.0-200.0.1.2")?,
|
||||
IpPattern::Range(IpAddr::from([0, 0, 0, 0]), IpAddr::from([200, 0, 1, 2]))
|
||||
);
|
||||
|
||||
// Error
|
||||
assert!(parse_ip_pattern("300.0.1.2").is_err());
|
||||
assert!(parse_ip_pattern("30.1.2").is_err());
|
||||
assert!(parse_ip_pattern("127.0.0.1/33").is_err());
|
||||
assert!(parse_ip_pattern("127.0.0.1-127.0.3").is_err());
|
||||
assert!(parse_ip_pattern("1234.0.0.1-127.0.3.0").is_err());
|
||||
assert!(IpPattern::from_str("300.0.1.2").is_err());
|
||||
assert!(IpPattern::from_str("30.1.2").is_err());
|
||||
assert!(IpPattern::from_str("127.0.0.1/33").is_err());
|
||||
assert!(IpPattern::from_str("127.0.0.1-127.0.3").is_err());
|
||||
assert!(IpPattern::from_str("1234.0.0.1-127.0.3.0").is_err());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
||||
@@ -2,6 +2,8 @@ use serde::Deserialize;
|
||||
use smol_str::SmolStr;
|
||||
use std::fmt;
|
||||
|
||||
use crate::auth::IpPattern;
|
||||
|
||||
/// Generic error response with human-readable description.
|
||||
/// Note that we can't always present it to user as is.
|
||||
#[derive(Debug, Deserialize)]
|
||||
@@ -14,7 +16,7 @@ pub struct ConsoleError {
|
||||
#[derive(Deserialize)]
|
||||
pub struct GetRoleSecret {
|
||||
pub role_secret: Box<str>,
|
||||
pub allowed_ips: Option<Vec<Box<str>>>,
|
||||
pub allowed_ips: Option<Vec<IpPattern>>,
|
||||
}
|
||||
|
||||
// Manually implement debug to omit sensitive info.
|
||||
|
||||
@@ -4,12 +4,13 @@ pub mod neon;
|
||||
|
||||
use super::messages::MetricsAuxInfo;
|
||||
use crate::{
|
||||
auth::backend::ComputeUserInfo,
|
||||
auth::{backend::ComputeUserInfo, IpPattern},
|
||||
cache::{timed_lru, TimedLru},
|
||||
compute, scram,
|
||||
};
|
||||
use async_trait::async_trait;
|
||||
use dashmap::DashMap;
|
||||
use smol_str::SmolStr;
|
||||
use std::{sync::Arc, time::Duration};
|
||||
use tokio::{
|
||||
sync::{OwnedSemaphorePermit, Semaphore},
|
||||
@@ -229,7 +230,7 @@ pub enum AuthSecret {
|
||||
pub struct AuthInfo {
|
||||
pub secret: Option<AuthSecret>,
|
||||
/// List of IP addresses allowed for the autorization.
|
||||
pub allowed_ips: Vec<String>,
|
||||
pub allowed_ips: Vec<IpPattern>,
|
||||
}
|
||||
|
||||
/// Info for establishing a connection to a compute node.
|
||||
@@ -250,7 +251,7 @@ pub struct NodeInfo {
|
||||
|
||||
pub type NodeInfoCache = TimedLru<Arc<str>, NodeInfo>;
|
||||
pub type CachedNodeInfo = timed_lru::Cached<&'static NodeInfoCache>;
|
||||
pub type AllowedIpsCache = TimedLru<Arc<str>, Arc<Vec<String>>>;
|
||||
pub type AllowedIpsCache = TimedLru<Arc<str>, Arc<Vec<IpPattern>>>;
|
||||
|
||||
/// This will allocate per each call, but the http requests alone
|
||||
/// already require a few allocations, so it should be fine.
|
||||
@@ -267,7 +268,7 @@ pub trait Api {
|
||||
&self,
|
||||
extra: &ConsoleReqExtra,
|
||||
creds: &ComputeUserInfo,
|
||||
) -> Result<Arc<Vec<String>>, errors::GetAuthInfoError>;
|
||||
) -> Result<Arc<Vec<IpPattern>>, errors::GetAuthInfoError>;
|
||||
|
||||
/// Wake up the compute node and return the corresponding connection info.
|
||||
async fn wake_compute(
|
||||
@@ -282,13 +283,13 @@ pub struct ApiCaches {
|
||||
/// Cache for the `wake_compute` API method.
|
||||
pub node_info: NodeInfoCache,
|
||||
/// Cache for the `get_allowed_ips`. TODO(anna): use notifications listener instead.
|
||||
pub allowed_ips: TimedLru<Arc<str>, Arc<Vec<String>>>,
|
||||
pub allowed_ips: TimedLru<Arc<str>, Arc<Vec<IpPattern>>>,
|
||||
}
|
||||
|
||||
/// Various caches for [`console`](super).
|
||||
/// Per-endpoint semaphore
|
||||
pub struct ApiLocks {
|
||||
name: &'static str,
|
||||
node_locks: DashMap<Arc<str>, Arc<Semaphore>>,
|
||||
node_locks: DashMap<SmolStr, Arc<Semaphore>>,
|
||||
permits: usize,
|
||||
timeout: Duration,
|
||||
registered: prometheus::IntCounter,
|
||||
@@ -354,9 +355,9 @@ impl ApiLocks {
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn get_wake_compute_permit(
|
||||
pub async fn get_permit(
|
||||
&self,
|
||||
key: &Arc<str>,
|
||||
key: &SmolStr,
|
||||
) -> Result<WakeComputePermit, errors::WakeComputeError> {
|
||||
if self.permits == 0 {
|
||||
return Ok(WakeComputePermit { permit: None });
|
||||
|
||||
@@ -6,7 +6,13 @@ use super::{
|
||||
errors::{ApiError, GetAuthInfoError, WakeComputeError},
|
||||
AuthInfo, AuthSecret, CachedNodeInfo, ConsoleReqExtra, NodeInfo,
|
||||
};
|
||||
use crate::{auth::backend::ComputeUserInfo, compute, error::io_error, scram, url::ApiUrl};
|
||||
use crate::{
|
||||
auth::{backend::ComputeUserInfo, IpPattern},
|
||||
compute,
|
||||
error::io_error,
|
||||
scram,
|
||||
url::ApiUrl,
|
||||
};
|
||||
use async_trait::async_trait;
|
||||
use futures::TryFutureExt;
|
||||
use thiserror::Error;
|
||||
@@ -85,7 +91,7 @@ impl Api {
|
||||
{
|
||||
Some(s) => {
|
||||
info!("got allowed_ips: {s}");
|
||||
s.split(',').map(String::from).collect()
|
||||
s.split(',').map(IpPattern::from_str_lossy).collect()
|
||||
}
|
||||
None => vec![],
|
||||
};
|
||||
@@ -154,7 +160,7 @@ impl super::Api for Api {
|
||||
&self,
|
||||
_extra: &ConsoleReqExtra,
|
||||
creds: &ComputeUserInfo,
|
||||
) -> Result<Arc<Vec<String>>, GetAuthInfoError> {
|
||||
) -> Result<Arc<Vec<IpPattern>>, GetAuthInfoError> {
|
||||
Ok(Arc::new(self.do_get_auth_info(creds).await?.allowed_ips))
|
||||
}
|
||||
|
||||
|
||||
@@ -5,11 +5,13 @@ use super::{
|
||||
errors::{ApiError, GetAuthInfoError, WakeComputeError},
|
||||
ApiCaches, ApiLocks, AuthInfo, AuthSecret, CachedNodeInfo, ConsoleReqExtra, NodeInfo,
|
||||
};
|
||||
use crate::proxy::{ALLOWED_IPS_BY_CACHE_OUTCOME, ALLOWED_IPS_NUMBER};
|
||||
use crate::{auth::backend::ComputeUserInfo, compute, http, scram};
|
||||
use crate::{
|
||||
auth::IpPattern,
|
||||
proxy::{ALLOWED_IPS_BY_CACHE_OUTCOME, ALLOWED_IPS_NUMBER},
|
||||
};
|
||||
use async_trait::async_trait;
|
||||
use futures::TryFutureExt;
|
||||
use itertools::Itertools;
|
||||
use std::{net::SocketAddr, sync::Arc};
|
||||
use tokio::time::Instant;
|
||||
use tokio_postgres::config::SslMode;
|
||||
@@ -82,12 +84,7 @@ impl Api {
|
||||
let secret = scram::ServerSecret::parse(&body.role_secret)
|
||||
.map(AuthSecret::Scram)
|
||||
.ok_or(GetAuthInfoError::BadSecret)?;
|
||||
let allowed_ips = body
|
||||
.allowed_ips
|
||||
.into_iter()
|
||||
.flatten()
|
||||
.map(String::from)
|
||||
.collect_vec();
|
||||
let allowed_ips = body.allowed_ips.unwrap_or_default();
|
||||
ALLOWED_IPS_NUMBER.observe(allowed_ips.len() as f64);
|
||||
Ok(AuthInfo {
|
||||
secret: Some(secret),
|
||||
@@ -171,7 +168,7 @@ impl super::Api for Api {
|
||||
&self,
|
||||
extra: &ConsoleReqExtra,
|
||||
creds: &ComputeUserInfo,
|
||||
) -> Result<Arc<Vec<String>>, GetAuthInfoError> {
|
||||
) -> Result<Arc<Vec<IpPattern>>, GetAuthInfoError> {
|
||||
let key: &str = &creds.endpoint;
|
||||
if let Some(allowed_ips) = self.caches.allowed_ips.get(key) {
|
||||
ALLOWED_IPS_BY_CACHE_OUTCOME
|
||||
@@ -195,33 +192,31 @@ impl super::Api for Api {
|
||||
extra: &ConsoleReqExtra,
|
||||
creds: &ComputeUserInfo,
|
||||
) -> Result<CachedNodeInfo, WakeComputeError> {
|
||||
let key: &str = &creds.inner.cache_key;
|
||||
let key = &creds.inner.cache_key;
|
||||
|
||||
// Every time we do a wakeup http request, the compute node will stay up
|
||||
// for some time (highly depends on the console's scale-to-zero policy);
|
||||
// The connection info remains the same during that period of time,
|
||||
// which means that we might cache it to reduce the load and latency.
|
||||
if let Some(cached) = self.caches.node_info.get(key) {
|
||||
info!(key = key, "found cached compute node info");
|
||||
if let Some(cached) = self.caches.node_info.get(&**key) {
|
||||
info!(key = &**key, "found cached compute node info");
|
||||
return Ok(cached);
|
||||
}
|
||||
|
||||
let key: Arc<str> = key.into();
|
||||
|
||||
let permit = self.locks.get_wake_compute_permit(&key).await?;
|
||||
let permit = self.locks.get_permit(key).await?;
|
||||
|
||||
// after getting back a permit - it's possible the cache was filled
|
||||
// double check
|
||||
if permit.should_check_cache() {
|
||||
if let Some(cached) = self.caches.node_info.get(&key) {
|
||||
info!(key = &*key, "found cached compute node info");
|
||||
if let Some(cached) = self.caches.node_info.get(&**key) {
|
||||
info!(key = &**key, "found cached compute node info");
|
||||
return Ok(cached);
|
||||
}
|
||||
}
|
||||
|
||||
let node = self.do_wake_compute(extra, creds).await?;
|
||||
let (_, cached) = self.caches.node_info.insert(key.clone(), node);
|
||||
info!(key = &*key, "created a cache entry for compute node info");
|
||||
let (_, cached) = self.caches.node_info.insert(key.as_str().into(), node);
|
||||
info!(key = &**key, "created a cache entry for compute node info");
|
||||
|
||||
Ok(cached)
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ mod mitm;
|
||||
|
||||
use super::*;
|
||||
use crate::auth::backend::{ComputeUserInfo, TestBackend};
|
||||
use crate::auth::IpPattern;
|
||||
use crate::config::CertResolver;
|
||||
use crate::console::{CachedNodeInfo, NodeInfo};
|
||||
use crate::{auth, http, sasl, scram};
|
||||
@@ -466,7 +467,7 @@ impl TestBackend for TestConnectMechanism {
|
||||
}
|
||||
}
|
||||
|
||||
fn get_allowed_ips(&self) -> Result<Arc<Vec<String>>, console::errors::GetAuthInfoError> {
|
||||
fn get_allowed_ips(&self) -> Result<Arc<Vec<IpPattern>>, console::errors::GetAuthInfoError> {
|
||||
unimplemented!("not used in tests")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,7 +9,7 @@ use pbkdf2::{
|
||||
};
|
||||
use pq_proto::StartupMessageParams;
|
||||
use smol_str::SmolStr;
|
||||
use std::{collections::HashMap, net::IpAddr, sync::Arc};
|
||||
use std::{collections::HashMap, net::IpAddr, sync::Arc, time::Duration};
|
||||
use std::{
|
||||
fmt,
|
||||
task::{ready, Poll},
|
||||
@@ -23,7 +23,7 @@ use tokio_postgres::{AsyncMessage, ReadyForQueryStatus};
|
||||
|
||||
use crate::{
|
||||
auth::{self, backend::ComputeUserInfo, check_peer_addr_is_in_list},
|
||||
console,
|
||||
console::{self, locks::ApiLocks},
|
||||
proxy::{
|
||||
neon_options, LatencyTimer, NUM_DB_CONNECTIONS_CLOSED_COUNTER,
|
||||
NUM_DB_CONNECTIONS_OPENED_COUNTER,
|
||||
@@ -114,17 +114,32 @@ pub struct GlobalConnPool {
|
||||
// Using a lock to remove any race conditions.
|
||||
// Eg cleaning up connections while a new connection is returned
|
||||
closed: RwLock<bool>,
|
||||
|
||||
// semaphore guarding unauthenticated postgres connections
|
||||
connection_lock: ApiLocks,
|
||||
}
|
||||
|
||||
impl GlobalConnPool {
|
||||
pub fn new(config: &'static crate::config::ProxyConfig) -> Arc<Self> {
|
||||
Arc::new(Self {
|
||||
let connection_lock =
|
||||
ApiLocks::new("http_connect_lock", 2, 32, Duration::from_secs(10)).unwrap();
|
||||
let this = Arc::new(Self {
|
||||
global_pool: DashMap::new(),
|
||||
global_pool_size: AtomicUsize::new(0),
|
||||
max_conns_per_endpoint: MAX_CONNS_PER_ENDPOINT,
|
||||
proxy_config: config,
|
||||
closed: RwLock::new(false),
|
||||
})
|
||||
connection_lock,
|
||||
});
|
||||
|
||||
let this2 = this.clone();
|
||||
tokio::spawn(async move {
|
||||
this2
|
||||
.connection_lock
|
||||
.garbage_collect_worker(Duration::from_secs(600))
|
||||
.await
|
||||
});
|
||||
this
|
||||
}
|
||||
|
||||
pub fn shutdown(&self) {
|
||||
@@ -221,6 +236,11 @@ impl GlobalConnPool {
|
||||
return Ok(Client::new(client, pool).await);
|
||||
}
|
||||
} else {
|
||||
// acquire a permit for un-authenticated access to the compute.
|
||||
// to be clear, postgres will authenticate, but we want to limit the connections
|
||||
// that have potential to be unauthenticated.
|
||||
let _permit = self.connection_lock.get_permit(&conn_info.hostname).await?;
|
||||
|
||||
let conn_id = uuid::Uuid::new_v4();
|
||||
info!(%conn_id, "pool: opening a new connection '{conn_info}'");
|
||||
connect_to_compute(
|
||||
|
||||
Reference in New Issue
Block a user