mirror of
https://github.com/neondatabase/neon.git
synced 2026-01-19 03:12:55 +00:00
Compare commits
7 Commits
RemoteExte
...
proxy-role
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4283e4c19e | ||
|
|
565d2b2494 | ||
|
|
d5db782f26 | ||
|
|
2f0e3428bb | ||
|
|
3928d3bc8a | ||
|
|
e7d36cc41f | ||
|
|
5e150c9376 |
4
Cargo.lock
generated
4
Cargo.lock
generated
@@ -3973,6 +3973,7 @@ dependencies = [
|
||||
"serde",
|
||||
"serde_json",
|
||||
"sha2",
|
||||
"smallvec",
|
||||
"smol_str",
|
||||
"socket2 0.5.5",
|
||||
"sync_wrapper",
|
||||
@@ -5107,6 +5108,9 @@ name = "smallvec"
|
||||
version = "1.11.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "62bb4feee49fdd9f707ef802e22365a35de4b7b299de4763d44bfea899442ff9"
|
||||
dependencies = [
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "smol_str"
|
||||
|
||||
@@ -81,6 +81,7 @@ postgres-native-tls.workspace = true
|
||||
postgres-protocol.workspace = true
|
||||
redis.workspace = true
|
||||
smol_str.workspace = true
|
||||
smallvec = { workspace = true, features = ["serde"] }
|
||||
|
||||
workspace_hack.workspace = true
|
||||
|
||||
|
||||
@@ -4,7 +4,9 @@ pub mod backend;
|
||||
pub use backend::BackendType;
|
||||
|
||||
mod credentials;
|
||||
pub use credentials::{check_peer_addr_is_in_list, endpoint_sni, ComputeUserInfoMaybeEndpoint};
|
||||
pub use credentials::{
|
||||
check_peer_addr_is_in_list, endpoint_sni, ComputeUserInfoMaybeEndpoint, IpPattern,
|
||||
};
|
||||
|
||||
mod password_hack;
|
||||
pub use password_hack::parse_endpoint_param;
|
||||
|
||||
@@ -35,6 +35,8 @@ use std::sync::Arc;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tracing::{error, info, warn};
|
||||
|
||||
use super::IpPattern;
|
||||
|
||||
/// This type serves two purposes:
|
||||
///
|
||||
/// * When `T` is `()`, it's just a regular auth backend selector
|
||||
@@ -55,7 +57,7 @@ pub enum BackendType<'a, T> {
|
||||
|
||||
pub trait TestBackend: Send + Sync + 'static {
|
||||
fn wake_compute(&self) -> Result<CachedNodeInfo, console::errors::WakeComputeError>;
|
||||
fn get_allowed_ips(&self) -> Result<Vec<SmolStr>, console::errors::GetAuthInfoError>;
|
||||
fn get_allowed_ips(&self) -> Result<Vec<IpPattern>, console::errors::GetAuthInfoError>;
|
||||
}
|
||||
|
||||
impl std::fmt::Display for BackendType<'_, ()> {
|
||||
|
||||
@@ -7,7 +7,7 @@ use crate::{
|
||||
use itertools::Itertools;
|
||||
use pq_proto::StartupMessageParams;
|
||||
use smol_str::SmolStr;
|
||||
use std::{collections::HashSet, net::IpAddr};
|
||||
use std::{collections::HashSet, net::IpAddr, str::FromStr};
|
||||
use thiserror::Error;
|
||||
use tracing::{info, warn};
|
||||
|
||||
@@ -151,30 +151,51 @@ impl ComputeUserInfoMaybeEndpoint {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn check_peer_addr_is_in_list(peer_addr: &IpAddr, ip_list: &Vec<SmolStr>) -> bool {
|
||||
if ip_list.is_empty() {
|
||||
return true;
|
||||
}
|
||||
for ip in ip_list {
|
||||
// We expect that all ip addresses from control plane are correct.
|
||||
// However, if some of them are broken, we still can check the others.
|
||||
match parse_ip_pattern(ip) {
|
||||
Ok(pattern) => {
|
||||
if check_ip(peer_addr, &pattern) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
Err(err) => warn!("Cannot parse ip: {}; err: {}", ip, err),
|
||||
}
|
||||
}
|
||||
false
|
||||
pub fn check_peer_addr_is_in_list(peer_addr: &IpAddr, ip_list: &[IpPattern]) -> bool {
|
||||
ip_list.is_empty() || ip_list.iter().any(|pattern| check_ip(peer_addr, pattern))
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Eq, PartialEq)]
|
||||
enum IpPattern {
|
||||
pub enum IpPattern {
|
||||
Subnet(ipnet::IpNet),
|
||||
Range(IpAddr, IpAddr),
|
||||
Single(IpAddr),
|
||||
None,
|
||||
}
|
||||
|
||||
impl<'de> serde::de::Deserialize<'de> for IpPattern {
|
||||
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||
where
|
||||
D: serde::Deserializer<'de>,
|
||||
{
|
||||
struct StrVisitor;
|
||||
impl<'de> serde::de::Visitor<'de> for StrVisitor {
|
||||
type Value = IpPattern;
|
||||
|
||||
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||
write!(formatter, "comma separated list with ip address, ip address range, or ip address subnet mask")
|
||||
}
|
||||
|
||||
fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
|
||||
where
|
||||
E: serde::de::Error,
|
||||
{
|
||||
Ok(parse_ip_pattern(v).unwrap_or_else(|e| {
|
||||
warn!("Cannot parse ip pattern {v}: {e}");
|
||||
IpPattern::None
|
||||
}))
|
||||
}
|
||||
}
|
||||
deserializer.deserialize_str(StrVisitor)
|
||||
}
|
||||
}
|
||||
|
||||
impl FromStr for IpPattern {
|
||||
type Err = anyhow::Error;
|
||||
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
parse_ip_pattern(s)
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_ip_pattern(pattern: &str) -> anyhow::Result<IpPattern> {
|
||||
@@ -196,6 +217,7 @@ fn check_ip(ip: &IpAddr, pattern: &IpPattern) -> bool {
|
||||
IpPattern::Subnet(subnet) => subnet.contains(ip),
|
||||
IpPattern::Range(start, end) => start <= ip && ip <= end,
|
||||
IpPattern::Single(addr) => addr == ip,
|
||||
IpPattern::None => false,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -206,6 +228,7 @@ fn project_name_valid(name: &str) -> bool {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use serde_json::json;
|
||||
use ComputeUserInfoParseError::*;
|
||||
|
||||
#[test]
|
||||
@@ -415,21 +438,17 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_check_peer_addr_is_in_list() {
|
||||
let peer_addr = IpAddr::from([127, 0, 0, 1]);
|
||||
assert!(check_peer_addr_is_in_list(&peer_addr, &vec![]));
|
||||
assert!(check_peer_addr_is_in_list(
|
||||
&peer_addr,
|
||||
&vec!["127.0.0.1".into()]
|
||||
));
|
||||
assert!(!check_peer_addr_is_in_list(
|
||||
&peer_addr,
|
||||
&vec!["8.8.8.8".into()]
|
||||
));
|
||||
fn check(v: serde_json::Value) -> bool {
|
||||
let peer_addr = IpAddr::from([127, 0, 0, 1]);
|
||||
let ip_list: Vec<IpPattern> = serde_json::from_value(v).unwrap();
|
||||
check_peer_addr_is_in_list(&peer_addr, &ip_list)
|
||||
}
|
||||
|
||||
assert!(check(json!([])));
|
||||
assert!(check(json!(["127.0.0.1"])));
|
||||
assert!(!check(json!(["8.8.8.8"])));
|
||||
// If there is an incorrect address, it will be skipped.
|
||||
assert!(check_peer_addr_is_in_list(
|
||||
&peer_addr,
|
||||
&vec!["88.8.8".into(), "127.0.0.1".into()]
|
||||
));
|
||||
assert!(check(json!(["88.8.8", "127.0.0.1"])));
|
||||
}
|
||||
#[test]
|
||||
fn test_parse_ip_v4() -> anyhow::Result<()> {
|
||||
|
||||
@@ -257,7 +257,6 @@ async fn main() -> anyhow::Result<()> {
|
||||
maintenance_tasks
|
||||
.spawn(notifications::task_main(url.to_owned(), cache.clone()));
|
||||
}
|
||||
maintenance_tasks.spawn(async move { cache.clone().gc_worker().await });
|
||||
}
|
||||
#[cfg(feature = "testing")]
|
||||
proxy::console::provider::ConsoleBackend::Postgres(_) => {}
|
||||
|
||||
248
proxy/src/cache/project_info.rs
vendored
248
proxy/src/cache/project_info.rs
vendored
@@ -1,17 +1,17 @@
|
||||
use std::{
|
||||
collections::HashSet,
|
||||
convert::Infallible,
|
||||
sync::{atomic::AtomicU64, Arc},
|
||||
time::Duration,
|
||||
};
|
||||
|
||||
use dashmap::DashMap;
|
||||
use rand::{thread_rng, Rng};
|
||||
use hashlink::LruCache;
|
||||
use parking_lot::Mutex;
|
||||
use smallvec::SmallVec;
|
||||
use smol_str::SmolStr;
|
||||
use tokio::time::Instant;
|
||||
use tracing::{debug, info};
|
||||
use tracing::info;
|
||||
|
||||
use crate::{config::ProjectInfoCacheOptions, console::AuthSecret};
|
||||
use crate::{auth::IpPattern, config::ProjectInfoCacheOptions, console::AuthSecret};
|
||||
|
||||
use super::{Cache, Cached};
|
||||
|
||||
@@ -42,56 +42,10 @@ impl<T> From<T> for Entry<T> {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
struct EndpointInfo {
|
||||
secret: std::collections::HashMap<SmolStr, Entry<AuthSecret>>,
|
||||
allowed_ips: Option<Entry<Arc<Vec<SmolStr>>>>,
|
||||
}
|
||||
|
||||
impl EndpointInfo {
|
||||
fn check_ignore_cache(ignore_cache_since: Option<Instant>, created_at: Instant) -> bool {
|
||||
match ignore_cache_since {
|
||||
None => false,
|
||||
Some(t) => t < created_at,
|
||||
}
|
||||
}
|
||||
pub fn get_role_secret(
|
||||
&self,
|
||||
role_name: &SmolStr,
|
||||
valid_since: Instant,
|
||||
ignore_cache_since: Option<Instant>,
|
||||
) -> Option<(AuthSecret, bool)> {
|
||||
if let Some(secret) = self.secret.get(role_name) {
|
||||
if valid_since < secret.created_at {
|
||||
return Some((
|
||||
secret.value.clone(),
|
||||
Self::check_ignore_cache(ignore_cache_since, secret.created_at),
|
||||
));
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
pub fn get_allowed_ips(
|
||||
&self,
|
||||
valid_since: Instant,
|
||||
ignore_cache_since: Option<Instant>,
|
||||
) -> Option<(Arc<Vec<SmolStr>>, bool)> {
|
||||
if let Some(allowed_ips) = &self.allowed_ips {
|
||||
if valid_since < allowed_ips.created_at {
|
||||
return Some((
|
||||
allowed_ips.value.clone(),
|
||||
Self::check_ignore_cache(ignore_cache_since, allowed_ips.created_at),
|
||||
));
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
pub fn invalidate_allowed_ips(&mut self) {
|
||||
self.allowed_ips = None;
|
||||
}
|
||||
pub fn invalidate_role_secret(&mut self, role_name: &SmolStr) {
|
||||
self.secret.remove(role_name);
|
||||
fn check_ignore_cache(ignore_cache_since: Option<Instant>, created_at: Instant) -> bool {
|
||||
match ignore_cache_since {
|
||||
None => false,
|
||||
Some(t) => t < created_at,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -103,12 +57,33 @@ impl EndpointInfo {
|
||||
/// One may ask, why the data is stored per project, when on the user request there is only data about the endpoint available?
|
||||
/// On the cplane side updates are done per project (or per branch), so it's easier to invalidate the whole project cache.
|
||||
pub struct ProjectInfoCacheImpl {
|
||||
cache: DashMap<SmolStr, EndpointInfo>,
|
||||
ip_cache: Mutex<LruCache<SmolStr, Entry<Arc<Vec<IpPattern>>>>>,
|
||||
role_cache: Mutex<LruCache<(SmolStr, SmolStr), Entry<AuthSecret>>>,
|
||||
|
||||
project2ep: DashMap<SmolStr, HashSet<SmolStr>>,
|
||||
config: ProjectInfoCacheOptions,
|
||||
// endpoints per project:
|
||||
// P90: 1
|
||||
// P99: 2
|
||||
// P995: 3
|
||||
// P999: 10
|
||||
// P9999: 186
|
||||
//
|
||||
// Assuming 1 million projects with this distribution:
|
||||
// (0.9 * 1 + 0.09 * 2 + 0.005 * 3 + 0.004 * 10 + 0.0009 * 186) * 1,000,000
|
||||
// =~ 1,500,000 endpoints
|
||||
//
|
||||
// 1,000,000 * size_of(SmolStr) = 24MB
|
||||
// 1,500,000 * size_of(SmolStr) = 36MB
|
||||
// SmallVec inline overhead: 8B * 0.9 * 1,000,000 = 7.2MB
|
||||
// SmallVec outline overhead: 32B * 0.1 * 1,000,000 = 3.2MB
|
||||
//
|
||||
// Total size: 70.4MB.
|
||||
//
|
||||
// We do not need to prune this hashmap and can safely
|
||||
// keep it in memory up until 100s of millions of projects
|
||||
project2ep: DashMap<SmolStr, SmallVec<[SmolStr; 1]>>,
|
||||
|
||||
start_time: Instant,
|
||||
ttl: Duration,
|
||||
ttl_disabled_since_us: AtomicU64,
|
||||
}
|
||||
|
||||
@@ -121,9 +96,7 @@ impl ProjectInfoCache for ProjectInfoCacheImpl {
|
||||
.map(|kv| kv.value().clone())
|
||||
.unwrap_or_default();
|
||||
for endpoint_id in endpoints {
|
||||
if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) {
|
||||
endpoint_info.invalidate_allowed_ips();
|
||||
}
|
||||
self.ip_cache.lock().remove(&endpoint_id);
|
||||
}
|
||||
}
|
||||
fn invalidate_role_secret_for_project(&self, project_id: &SmolStr, role_name: &SmolStr) {
|
||||
@@ -137,9 +110,9 @@ impl ProjectInfoCache for ProjectInfoCacheImpl {
|
||||
.map(|kv| kv.value().clone())
|
||||
.unwrap_or_default();
|
||||
for endpoint_id in endpoints {
|
||||
if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) {
|
||||
endpoint_info.invalidate_role_secret(role_name);
|
||||
}
|
||||
self.role_cache
|
||||
.lock()
|
||||
.remove(&(endpoint_id, role_name.clone()));
|
||||
}
|
||||
}
|
||||
fn enable_ttl(&self) {
|
||||
@@ -148,7 +121,7 @@ impl ProjectInfoCache for ProjectInfoCacheImpl {
|
||||
}
|
||||
|
||||
fn disable_ttl(&self) {
|
||||
let new_ttl = (self.start_time.elapsed() + self.config.ttl).as_micros() as u64;
|
||||
let new_ttl = (self.start_time.elapsed() + self.ttl).as_micros() as u64;
|
||||
self.ttl_disabled_since_us
|
||||
.store(new_ttl, std::sync::atomic::Ordering::Relaxed);
|
||||
}
|
||||
@@ -157,9 +130,10 @@ impl ProjectInfoCache for ProjectInfoCacheImpl {
|
||||
impl ProjectInfoCacheImpl {
|
||||
pub fn new(config: ProjectInfoCacheOptions) -> Self {
|
||||
Self {
|
||||
cache: DashMap::new(),
|
||||
ip_cache: Mutex::new(LruCache::new(config.size)),
|
||||
role_cache: Mutex::new(LruCache::new(config.size * config.max_roles)),
|
||||
project2ep: DashMap::new(),
|
||||
config,
|
||||
ttl: config.ttl,
|
||||
ttl_disabled_since_us: AtomicU64::new(u64::MAX),
|
||||
start_time: Instant::now(),
|
||||
}
|
||||
@@ -171,9 +145,17 @@ impl ProjectInfoCacheImpl {
|
||||
role_name: &SmolStr,
|
||||
) -> Option<Cached<&Self, AuthSecret>> {
|
||||
let (valid_since, ignore_cache_since) = self.get_cache_times();
|
||||
let endpoint_info = self.cache.get(endpoint_id)?;
|
||||
let (value, ignore_cache) =
|
||||
endpoint_info.get_role_secret(role_name, valid_since, ignore_cache_since)?;
|
||||
let (value, ignore_cache) = {
|
||||
let mut cache = self.role_cache.lock();
|
||||
let secret = cache.get(&(endpoint_id.clone(), role_name.clone()))?;
|
||||
if secret.created_at <= valid_since {
|
||||
return None;
|
||||
}
|
||||
(
|
||||
secret.value.clone(),
|
||||
check_ignore_cache(ignore_cache_since, secret.created_at),
|
||||
)
|
||||
};
|
||||
if !ignore_cache {
|
||||
let cached = Cached {
|
||||
token: Some((
|
||||
@@ -186,14 +168,23 @@ impl ProjectInfoCacheImpl {
|
||||
}
|
||||
Some(Cached::new_uncached(value))
|
||||
}
|
||||
|
||||
pub fn get_allowed_ips(
|
||||
&self,
|
||||
endpoint_id: &SmolStr,
|
||||
) -> Option<Cached<&Self, Arc<Vec<SmolStr>>>> {
|
||||
) -> Option<Cached<&Self, Arc<Vec<IpPattern>>>> {
|
||||
let (valid_since, ignore_cache_since) = self.get_cache_times();
|
||||
let endpoint_info = self.cache.get(endpoint_id)?;
|
||||
let value = endpoint_info.get_allowed_ips(valid_since, ignore_cache_since);
|
||||
let (value, ignore_cache) = value?;
|
||||
let (value, ignore_cache) = {
|
||||
let mut cache = self.ip_cache.lock();
|
||||
let allowed_ips = cache.get(endpoint_id)?;
|
||||
if allowed_ips.created_at <= valid_since {
|
||||
return None;
|
||||
}
|
||||
(
|
||||
allowed_ips.value.clone(),
|
||||
check_ignore_cache(ignore_cache_since, allowed_ips.created_at),
|
||||
)
|
||||
};
|
||||
if !ignore_cache {
|
||||
let cached = Cached {
|
||||
token: Some((self, CachedLookupInfo::new_allowed_ips(endpoint_id.clone()))),
|
||||
@@ -203,6 +194,7 @@ impl ProjectInfoCacheImpl {
|
||||
}
|
||||
Some(Cached::new_uncached(value))
|
||||
}
|
||||
|
||||
pub fn insert_role_secret(
|
||||
&self,
|
||||
project_id: &SmolStr,
|
||||
@@ -210,42 +202,33 @@ impl ProjectInfoCacheImpl {
|
||||
role_name: &SmolStr,
|
||||
secret: AuthSecret,
|
||||
) {
|
||||
if self.cache.len() >= self.config.size {
|
||||
// If there are too many entries, wait until the next gc cycle.
|
||||
return;
|
||||
}
|
||||
self.inser_project2endpoint(project_id, endpoint_id);
|
||||
let mut entry = self.cache.entry(endpoint_id.clone()).or_default();
|
||||
if entry.secret.len() < self.config.max_roles {
|
||||
entry.secret.insert(role_name.clone(), secret.into());
|
||||
}
|
||||
self.insert_project2endpoint(project_id, endpoint_id);
|
||||
self.role_cache
|
||||
.lock()
|
||||
.insert((endpoint_id.clone(), role_name.clone()), secret.into());
|
||||
}
|
||||
|
||||
pub fn insert_allowed_ips(
|
||||
&self,
|
||||
project_id: &SmolStr,
|
||||
endpoint_id: &SmolStr,
|
||||
allowed_ips: Arc<Vec<SmolStr>>,
|
||||
allowed_ips: Arc<Vec<IpPattern>>,
|
||||
) {
|
||||
if self.cache.len() >= self.config.size {
|
||||
// If there are too many entries, wait until the next gc cycle.
|
||||
return;
|
||||
}
|
||||
self.inser_project2endpoint(project_id, endpoint_id);
|
||||
self.cache
|
||||
.entry(endpoint_id.clone())
|
||||
self.insert_project2endpoint(project_id, endpoint_id);
|
||||
self.ip_cache
|
||||
.lock()
|
||||
.insert(endpoint_id.clone(), allowed_ips.into());
|
||||
}
|
||||
|
||||
fn insert_project2endpoint(&self, project_id: &SmolStr, endpoint_id: &SmolStr) {
|
||||
self.project2ep
|
||||
.entry(project_id.clone())
|
||||
.or_default()
|
||||
.allowed_ips = Some(allowed_ips.into());
|
||||
}
|
||||
fn inser_project2endpoint(&self, project_id: &SmolStr, endpoint_id: &SmolStr) {
|
||||
if let Some(mut endpoints) = self.project2ep.get_mut(project_id) {
|
||||
endpoints.insert(endpoint_id.clone());
|
||||
} else {
|
||||
self.project2ep
|
||||
.insert(project_id.clone(), HashSet::from([endpoint_id.clone()]));
|
||||
}
|
||||
.push(endpoint_id.clone());
|
||||
}
|
||||
|
||||
fn get_cache_times(&self) -> (Instant, Option<Instant>) {
|
||||
let mut valid_since = Instant::now() - self.config.ttl;
|
||||
let mut valid_since = Instant::now() - self.ttl;
|
||||
// Only ignore cache if ttl is disabled.
|
||||
let ttl_disabled_since_us = self
|
||||
.ttl_disabled_since_us
|
||||
@@ -260,37 +243,6 @@ impl ProjectInfoCacheImpl {
|
||||
};
|
||||
(valid_since, ignore_cache_since)
|
||||
}
|
||||
|
||||
pub async fn gc_worker(&self) -> anyhow::Result<Infallible> {
|
||||
let mut interval =
|
||||
tokio::time::interval(self.config.gc_interval / (self.cache.shards().len()) as u32);
|
||||
loop {
|
||||
interval.tick().await;
|
||||
if self.cache.len() < self.config.size {
|
||||
// If there are not too many entries, wait until the next gc cycle.
|
||||
continue;
|
||||
}
|
||||
self.gc();
|
||||
}
|
||||
}
|
||||
|
||||
fn gc(&self) {
|
||||
let shard = thread_rng().gen_range(0..self.project2ep.shards().len());
|
||||
debug!(shard, "project_info_cache: performing epoch reclamation");
|
||||
|
||||
// acquire a random shard lock
|
||||
let mut removed = 0;
|
||||
let shard = self.project2ep.shards()[shard].write();
|
||||
for (_, endpoints) in shard.iter() {
|
||||
for endpoint in endpoints.get().iter() {
|
||||
self.cache.remove(endpoint);
|
||||
removed += 1;
|
||||
}
|
||||
}
|
||||
// We can drop this shard only after making sure that all endpoints are removed.
|
||||
drop(shard);
|
||||
info!("project_info_cache: removed {removed} endpoints");
|
||||
}
|
||||
}
|
||||
|
||||
/// Lookup info for project info cache.
|
||||
@@ -331,14 +283,12 @@ impl Cache for ProjectInfoCacheImpl {
|
||||
fn invalidate(&self, key: &Self::LookupInfo<SmolStr>) {
|
||||
match &key.lookup_type {
|
||||
LookupType::RoleSecret(role_name) => {
|
||||
if let Some(mut endpoint_info) = self.cache.get_mut(&key.endpoint_id) {
|
||||
endpoint_info.invalidate_role_secret(role_name);
|
||||
}
|
||||
self.role_cache
|
||||
.lock()
|
||||
.remove(&(key.endpoint_id.clone(), role_name.clone()));
|
||||
}
|
||||
LookupType::AllowedIps => {
|
||||
if let Some(mut endpoint_info) = self.cache.get_mut(&key.endpoint_id) {
|
||||
endpoint_info.invalidate_allowed_ips();
|
||||
}
|
||||
self.ip_cache.lock().remove(&key.endpoint_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -356,9 +306,8 @@ mod tests {
|
||||
tokio::time::pause();
|
||||
let cache = ProjectInfoCacheImpl::new(ProjectInfoCacheOptions {
|
||||
size: 2,
|
||||
max_roles: 2,
|
||||
max_roles: 1,
|
||||
ttl: Duration::from_secs(1),
|
||||
gc_interval: Duration::from_secs(600),
|
||||
});
|
||||
let project_id = "project".into();
|
||||
let endpoint_id = "endpoint".into();
|
||||
@@ -366,7 +315,10 @@ mod tests {
|
||||
let user2: SmolStr = "user2".into();
|
||||
let secret1 = AuthSecret::Scram(ServerSecret::mock(user1.as_str(), [1; 32]));
|
||||
let secret2 = AuthSecret::Scram(ServerSecret::mock(user2.as_str(), [2; 32]));
|
||||
let allowed_ips = Arc::new(vec!["allowed_ip1".into(), "allowed_ip2".into()]);
|
||||
let allowed_ips = Arc::new(vec![
|
||||
"127.0.0.1".parse().unwrap(),
|
||||
"127.0.0.2".parse().unwrap(),
|
||||
]);
|
||||
cache.insert_role_secret(&project_id, &endpoint_id, &user1, secret1.clone());
|
||||
cache.insert_role_secret(&project_id, &endpoint_id, &user2, secret2.clone());
|
||||
cache.insert_allowed_ips(&project_id, &endpoint_id, allowed_ips.clone());
|
||||
@@ -382,7 +334,7 @@ mod tests {
|
||||
let user3: SmolStr = "user3".into();
|
||||
let secret3 = AuthSecret::Scram(ServerSecret::mock(user3.as_str(), [3; 32]));
|
||||
cache.insert_role_secret(&project_id, &endpoint_id, &user3, secret3.clone());
|
||||
assert!(cache.get_role_secret(&endpoint_id, &user3).is_none());
|
||||
assert!(cache.get_role_secret(&endpoint_id, &user1).is_none(),);
|
||||
|
||||
let cached = cache.get_allowed_ips(&endpoint_id).unwrap();
|
||||
assert!(cached.cached());
|
||||
@@ -404,7 +356,6 @@ mod tests {
|
||||
size: 2,
|
||||
max_roles: 2,
|
||||
ttl: Duration::from_secs(1),
|
||||
gc_interval: Duration::from_secs(600),
|
||||
}));
|
||||
cache.clone().disable_ttl();
|
||||
tokio::time::advance(Duration::from_secs(2)).await;
|
||||
@@ -415,7 +366,10 @@ mod tests {
|
||||
let user2: SmolStr = "user2".into();
|
||||
let secret1 = AuthSecret::Scram(ServerSecret::mock(user1.as_str(), [1; 32]));
|
||||
let secret2 = AuthSecret::Scram(ServerSecret::mock(user2.as_str(), [2; 32]));
|
||||
let allowed_ips = Arc::new(vec!["allowed_ip1".into(), "allowed_ip2".into()]);
|
||||
let allowed_ips = Arc::new(vec![
|
||||
"127.0.0.1".parse().unwrap(),
|
||||
"127.0.0.2".parse().unwrap(),
|
||||
]);
|
||||
cache.insert_role_secret(&project_id, &endpoint_id, &user1, secret1.clone());
|
||||
cache.insert_role_secret(&project_id, &endpoint_id, &user2, secret2.clone());
|
||||
cache.insert_allowed_ips(&project_id, &endpoint_id, allowed_ips.clone());
|
||||
@@ -452,7 +406,6 @@ mod tests {
|
||||
size: 2,
|
||||
max_roles: 2,
|
||||
ttl: Duration::from_secs(1),
|
||||
gc_interval: Duration::from_secs(600),
|
||||
}));
|
||||
|
||||
let project_id = "project".into();
|
||||
@@ -461,7 +414,10 @@ mod tests {
|
||||
let user2: SmolStr = "user2".into();
|
||||
let secret1 = AuthSecret::Scram(ServerSecret::mock(user1.as_str(), [1; 32]));
|
||||
let secret2 = AuthSecret::Scram(ServerSecret::mock(user2.as_str(), [2; 32]));
|
||||
let allowed_ips = Arc::new(vec!["allowed_ip1".into(), "allowed_ip2".into()]);
|
||||
let allowed_ips = Arc::new(vec![
|
||||
"127.0.0.1".parse().unwrap(),
|
||||
"127.0.0.2".parse().unwrap(),
|
||||
]);
|
||||
cache.insert_role_secret(&project_id, &endpoint_id, &user1, secret1.clone());
|
||||
cache.clone().disable_ttl();
|
||||
tokio::time::advance(Duration::from_millis(100)).await;
|
||||
|
||||
@@ -361,14 +361,11 @@ pub struct ProjectInfoCacheOptions {
|
||||
pub ttl: Duration,
|
||||
/// Max number of roles per endpoint.
|
||||
pub max_roles: usize,
|
||||
/// Gc interval.
|
||||
pub gc_interval: Duration,
|
||||
}
|
||||
|
||||
impl ProjectInfoCacheOptions {
|
||||
/// Default options for [`crate::console::provider::NodeInfoCache`].
|
||||
pub const CACHE_DEFAULT_OPTIONS: &'static str =
|
||||
"size=10000,ttl=4m,max_roles=10,gc_interval=60m";
|
||||
pub const CACHE_DEFAULT_OPTIONS: &'static str = "size=10000,ttl=4m,max_roles=5,gc_interval=60m";
|
||||
|
||||
/// Parse cache options passed via cmdline.
|
||||
/// Example: [`Self::CACHE_DEFAULT_OPTIONS`].
|
||||
@@ -376,7 +373,7 @@ impl ProjectInfoCacheOptions {
|
||||
let mut size = None;
|
||||
let mut ttl = None;
|
||||
let mut max_roles = None;
|
||||
let mut gc_interval = None;
|
||||
let mut _gc_interval = None;
|
||||
|
||||
for option in options.split(',') {
|
||||
let (key, value) = option
|
||||
@@ -387,7 +384,7 @@ impl ProjectInfoCacheOptions {
|
||||
"size" => size = Some(value.parse()?),
|
||||
"ttl" => ttl = Some(humantime::parse_duration(value)?),
|
||||
"max_roles" => max_roles = Some(value.parse()?),
|
||||
"gc_interval" => gc_interval = Some(humantime::parse_duration(value)?),
|
||||
"gc_interval" => _gc_interval = Some(humantime::parse_duration(value)?),
|
||||
unknown => bail!("unknown key: {unknown}"),
|
||||
}
|
||||
}
|
||||
@@ -401,7 +398,6 @@ impl ProjectInfoCacheOptions {
|
||||
size: size.context("missing `size`")?,
|
||||
ttl: ttl.context("missing `ttl`")?,
|
||||
max_roles: max_roles.context("missing `max_roles`")?,
|
||||
gc_interval: gc_interval.context("missing `gc_interval`")?,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,6 +2,8 @@ use serde::Deserialize;
|
||||
use smol_str::SmolStr;
|
||||
use std::fmt;
|
||||
|
||||
use crate::auth::IpPattern;
|
||||
|
||||
/// Generic error response with human-readable description.
|
||||
/// Note that we can't always present it to user as is.
|
||||
#[derive(Debug, Deserialize)]
|
||||
@@ -14,7 +16,7 @@ pub struct ConsoleError {
|
||||
#[derive(Deserialize)]
|
||||
pub struct GetRoleSecret {
|
||||
pub role_secret: Box<str>,
|
||||
pub allowed_ips: Option<Vec<Box<str>>>,
|
||||
pub allowed_ips: Option<Vec<IpPattern>>,
|
||||
pub project_id: Option<Box<str>>,
|
||||
}
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ pub mod neon;
|
||||
|
||||
use super::messages::MetricsAuxInfo;
|
||||
use crate::{
|
||||
auth::backend::ComputeUserInfo,
|
||||
auth::{backend::ComputeUserInfo, IpPattern},
|
||||
cache::{project_info::ProjectInfoCacheImpl, Cached, TimedLru},
|
||||
compute,
|
||||
config::{CacheOptions, ProjectInfoCacheOptions},
|
||||
@@ -212,7 +212,7 @@ pub enum AuthSecret {
|
||||
pub struct AuthInfo {
|
||||
pub secret: Option<AuthSecret>,
|
||||
/// List of IP addresses allowed for the autorization.
|
||||
pub allowed_ips: Vec<SmolStr>,
|
||||
pub allowed_ips: Vec<IpPattern>,
|
||||
/// Project ID. This is used for cache invalidation.
|
||||
pub project_id: Option<SmolStr>,
|
||||
}
|
||||
@@ -236,7 +236,7 @@ pub struct NodeInfo {
|
||||
pub type NodeInfoCache = TimedLru<SmolStr, NodeInfo>;
|
||||
pub type CachedNodeInfo = Cached<&'static NodeInfoCache>;
|
||||
pub type CachedRoleSecret = Cached<&'static ProjectInfoCacheImpl, AuthSecret>;
|
||||
pub type CachedAllowedIps = Cached<&'static ProjectInfoCacheImpl, Arc<Vec<SmolStr>>>;
|
||||
pub type CachedAllowedIps = Cached<&'static ProjectInfoCacheImpl, Arc<Vec<IpPattern>>>;
|
||||
|
||||
/// This will allocate per each call, but the http requests alone
|
||||
/// already require a few allocations, so it should be fine.
|
||||
|
||||
@@ -4,14 +4,13 @@ use super::{
|
||||
errors::{ApiError, GetAuthInfoError, WakeComputeError},
|
||||
AuthInfo, AuthSecret, CachedNodeInfo, NodeInfo,
|
||||
};
|
||||
use crate::cache::Cached;
|
||||
use crate::console::provider::{CachedAllowedIps, CachedRoleSecret};
|
||||
use crate::context::RequestMonitoring;
|
||||
use crate::{auth::backend::ComputeUserInfo, compute, error::io_error, scram, url::ApiUrl};
|
||||
use crate::{auth::IpPattern, cache::Cached};
|
||||
use async_trait::async_trait;
|
||||
use futures::TryFutureExt;
|
||||
use smol_str::SmolStr;
|
||||
use std::sync::Arc;
|
||||
use std::{str::FromStr, sync::Arc};
|
||||
use thiserror::Error;
|
||||
use tokio_postgres::{config::SslMode, Client};
|
||||
use tracing::{error, info, info_span, warn, Instrument};
|
||||
@@ -88,7 +87,9 @@ impl Api {
|
||||
{
|
||||
Some(s) => {
|
||||
info!("got allowed_ips: {s}");
|
||||
s.split(',').map(String::from).collect()
|
||||
s.split(',')
|
||||
.map(|s| IpPattern::from_str(s).unwrap())
|
||||
.collect()
|
||||
}
|
||||
None => vec![],
|
||||
};
|
||||
@@ -100,7 +101,7 @@ impl Api {
|
||||
.await?;
|
||||
Ok(AuthInfo {
|
||||
secret,
|
||||
allowed_ips: allowed_ips.iter().map(SmolStr::from).collect(),
|
||||
allowed_ips,
|
||||
project_id: None,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -14,7 +14,6 @@ use crate::{
|
||||
};
|
||||
use async_trait::async_trait;
|
||||
use futures::TryFutureExt;
|
||||
use itertools::Itertools;
|
||||
use smol_str::SmolStr;
|
||||
use std::sync::Arc;
|
||||
use tokio::time::Instant;
|
||||
@@ -89,12 +88,7 @@ impl Api {
|
||||
let secret = scram::ServerSecret::parse(&body.role_secret)
|
||||
.map(AuthSecret::Scram)
|
||||
.ok_or(GetAuthInfoError::BadSecret)?;
|
||||
let allowed_ips = body
|
||||
.allowed_ips
|
||||
.into_iter()
|
||||
.flatten()
|
||||
.map(SmolStr::from)
|
||||
.collect_vec();
|
||||
let allowed_ips = body.allowed_ips.unwrap_or_default();
|
||||
ALLOWED_IPS_NUMBER.observe(allowed_ips.len() as f64);
|
||||
Ok(AuthInfo {
|
||||
secret: Some(secret),
|
||||
@@ -195,6 +189,7 @@ impl super::Api for Api {
|
||||
Ok(auth_info.secret.map(Cached::new_uncached))
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all)]
|
||||
async fn get_allowed_ips(
|
||||
&self,
|
||||
ctx: &mut RequestMonitoring,
|
||||
|
||||
@@ -6,13 +6,13 @@ use super::connect_compute::ConnectMechanism;
|
||||
use super::retry::ShouldRetry;
|
||||
use super::*;
|
||||
use crate::auth::backend::{ComputeUserInfo, TestBackend};
|
||||
use crate::auth::IpPattern;
|
||||
use crate::config::CertResolver;
|
||||
use crate::console::{self, CachedNodeInfo, NodeInfo};
|
||||
use crate::proxy::retry::{retry_after, NUM_RETRIES_CONNECT};
|
||||
use crate::{auth, http, sasl, scram};
|
||||
use async_trait::async_trait;
|
||||
use rstest::rstest;
|
||||
use smol_str::SmolStr;
|
||||
use tokio_postgres::config::SslMode;
|
||||
use tokio_postgres::tls::{MakeTlsConnect, NoTls};
|
||||
use tokio_postgres_rustls::{MakeRustlsConnect, RustlsStream};
|
||||
@@ -471,7 +471,7 @@ impl TestBackend for TestConnectMechanism {
|
||||
}
|
||||
}
|
||||
|
||||
fn get_allowed_ips(&self) -> Result<Vec<SmolStr>, console::errors::GetAuthInfoError> {
|
||||
fn get_allowed_ips(&self) -> Result<Vec<IpPattern>, console::errors::GetAuthInfoError> {
|
||||
unimplemented!("not used in tests")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -66,7 +66,7 @@ rustls = { version = "0.21", features = ["dangerous_configuration"] }
|
||||
scopeguard = { version = "1" }
|
||||
serde = { version = "1", features = ["alloc", "derive"] }
|
||||
serde_json = { version = "1", features = ["raw_value"] }
|
||||
smallvec = { version = "1", default-features = false, features = ["write"] }
|
||||
smallvec = { version = "1", default-features = false, features = ["serde", "write"] }
|
||||
subtle = { version = "2" }
|
||||
time = { version = "0.3", features = ["local-offset", "macros", "serde-well-known"] }
|
||||
tokio = { version = "1", features = ["fs", "io-std", "io-util", "macros", "net", "process", "rt-multi-thread", "signal", "test-util"] }
|
||||
|
||||
Reference in New Issue
Block a user