Compare commits

...

4 Commits

Author SHA1 Message Date
Anna Khanova
9c989692ea Fix config 2023-12-18 15:04:38 +01:00
Anna Khanova
485572bc62 Fix 2023-12-18 14:44:23 +01:00
Anna Khanova
1d9a756859 Remove todo 2023-12-18 14:38:51 +01:00
Anna Khanova
c70f30d2c9 Added cache for get role secret 2023-12-18 14:37:44 +01:00
9 changed files with 55 additions and 28 deletions

View File

@@ -9,7 +9,6 @@ use tokio_postgres::config::AuthKeys;
use crate::auth::credentials::check_peer_addr_is_in_list; use crate::auth::credentials::check_peer_addr_is_in_list;
use crate::auth::validate_password_and_exchange; use crate::auth::validate_password_and_exchange;
use crate::console::errors::GetAuthInfoError; use crate::console::errors::GetAuthInfoError;
use crate::console::provider::AuthInfo;
use crate::console::AuthSecret; use crate::console::AuthSecret;
use crate::proxy::connect_compute::handle_try_wake; use crate::proxy::connect_compute::handle_try_wake;
use crate::proxy::retry::retry_after; use crate::proxy::retry::retry_after;
@@ -187,17 +186,13 @@ async fn auth_quirks(
}; };
info!("fetching user's authentication info"); info!("fetching user's authentication info");
// TODO(anna): this will slow down both "hacks" below; we probably need a cache. let allowed_ips = api.get_allowed_ips(extra, &info).await?;
let AuthInfo {
secret,
allowed_ips,
} = api.get_auth_info(extra, &info).await?;
// check allowed list // check allowed list
if !check_peer_addr_is_in_list(&info.inner.peer_addr, &allowed_ips) { if !check_peer_addr_is_in_list(&info.inner.peer_addr, &allowed_ips) {
return Err(auth::AuthError::ip_address_not_allowed()); return Err(auth::AuthError::ip_address_not_allowed());
} }
let secret = secret.unwrap_or_else(|| { let secret = api.get_role_secret(extra, &info).await?.unwrap_or_else(|| {
// If we don't have an authentication secret, we mock one to // If we don't have an authentication secret, we mock one to
// prevent malicious probing (possible due to missing protocol steps). // prevent malicious probing (possible due to missing protocol steps).
// This mocked secret will never lead to successful authentication. // This mocked secret will never lead to successful authentication.

View File

@@ -6,6 +6,7 @@ use proxy::config::HttpConfig;
use proxy::console; use proxy::console;
use proxy::console::provider::AllowedIpsCache; use proxy::console::provider::AllowedIpsCache;
use proxy::console::provider::NodeInfoCache; use proxy::console::provider::NodeInfoCache;
use proxy::console::provider::RoleSecretCache;
use proxy::http; use proxy::http;
use proxy::rate_limiter::EndpointRateLimiter; use proxy::rate_limiter::EndpointRateLimiter;
use proxy::rate_limiter::RateBucketInfo; use proxy::rate_limiter::RateBucketInfo;
@@ -86,7 +87,7 @@ struct ProxyCliArgs {
#[clap(long)] #[clap(long)]
metric_collection_interval: Option<String>, metric_collection_interval: Option<String>,
/// cache for `wake_compute` api method (use `size=0` to disable) /// cache for `wake_compute` api method (use `size=0` to disable)
#[clap(long, default_value = config::CacheOptions::DEFAULT_OPTIONS_NODE_INFO)] #[clap(long, default_value = config::CacheOptions::CACHE_DEFAULT_OPTIONS)]
wake_compute_cache: String, wake_compute_cache: String,
/// lock for `wake_compute` api method. example: "shards=32,permits=4,epoch=10m,timeout=1s". (use `permits=0` to disable). /// lock for `wake_compute` api method. example: "shards=32,permits=4,epoch=10m,timeout=1s". (use `permits=0` to disable).
#[clap(long, default_value = config::WakeComputeLockOptions::DEFAULT_OPTIONS_WAKE_COMPUTE_LOCK)] #[clap(long, default_value = config::WakeComputeLockOptions::DEFAULT_OPTIONS_WAKE_COMPUTE_LOCK)]
@@ -127,8 +128,11 @@ struct ProxyCliArgs {
#[clap(flatten)] #[clap(flatten)]
aimd_config: proxy::rate_limiter::AimdConfig, aimd_config: proxy::rate_limiter::AimdConfig,
/// cache for `allowed_ips` (use `size=0` to disable) /// cache for `allowed_ips` (use `size=0` to disable)
#[clap(long, default_value = config::CacheOptions::DEFAULT_OPTIONS_NODE_INFO)] #[clap(long, default_value = config::CacheOptions::CACHE_DEFAULT_OPTIONS)]
allowed_ips_cache: String, allowed_ips_cache: String,
/// cache for `role_secret` (use `size=0` to disable)
#[clap(long, default_value = config::CacheOptions::CACHE_DEFAULT_OPTIONS)]
role_secret_cache: String,
/// disable ip check for http requests. If it is too time consuming, it could be turned off. /// disable ip check for http requests. If it is too time consuming, it could be turned off.
#[clap(long, default_value_t = false, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)] #[clap(long, default_value_t = false, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)]
disable_ip_check_for_http: bool, disable_ip_check_for_http: bool,
@@ -266,9 +270,11 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
AuthBackend::Console => { AuthBackend::Console => {
let wake_compute_cache_config: CacheOptions = args.wake_compute_cache.parse()?; let wake_compute_cache_config: CacheOptions = args.wake_compute_cache.parse()?;
let allowed_ips_cache_config: CacheOptions = args.allowed_ips_cache.parse()?; let allowed_ips_cache_config: CacheOptions = args.allowed_ips_cache.parse()?;
let role_secret_cache_config: CacheOptions = args.role_secret_cache.parse()?;
info!("Using NodeInfoCache (wake_compute) with options={wake_compute_cache_config:?}"); info!("Using NodeInfoCache (wake_compute) with options={wake_compute_cache_config:?}");
info!("Using AllowedIpsCache (wake_compute) with options={allowed_ips_cache_config:?}"); info!("Using AllowedIpsCache (wake_compute) with options={allowed_ips_cache_config:?}");
info!("Using RoleSecretCache (wake_compute) with options={role_secret_cache_config:?}");
let caches = Box::leak(Box::new(console::caches::ApiCaches { let caches = Box::leak(Box::new(console::caches::ApiCaches {
node_info: NodeInfoCache::new( node_info: NodeInfoCache::new(
"node_info_cache", "node_info_cache",
@@ -282,6 +288,12 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
allowed_ips_cache_config.ttl, allowed_ips_cache_config.ttl,
false, false,
), ),
role_secret: RoleSecretCache::new(
"role_secret_cache",
role_secret_cache_config.size,
role_secret_cache_config.ttl,
false,
),
})); }));
let config::WakeComputeLockOptions { let config::WakeComputeLockOptions {

View File

@@ -310,10 +310,10 @@ pub struct CacheOptions {
impl CacheOptions { impl CacheOptions {
/// Default options for [`crate::console::provider::NodeInfoCache`]. /// Default options for [`crate::console::provider::NodeInfoCache`].
pub const DEFAULT_OPTIONS_NODE_INFO: &'static str = "size=4000,ttl=4m"; pub const CACHE_DEFAULT_OPTIONS: &'static str = "size=4000,ttl=4m";
/// Parse cache options passed via cmdline. /// Parse cache options passed via cmdline.
/// Example: [`Self::DEFAULT_OPTIONS_NODE_INFO`]. /// Example: [`Self::CACHE_DEFAULT_OPTIONS`].
fn parse(options: &str) -> anyhow::Result<Self> { fn parse(options: &str) -> anyhow::Result<Self> {
let mut size = None; let mut size = None;
let mut ttl = None; let mut ttl = None;

View File

@@ -10,6 +10,7 @@ use crate::{
}; };
use async_trait::async_trait; use async_trait::async_trait;
use dashmap::DashMap; use dashmap::DashMap;
use smol_str::SmolStr;
use std::{sync::Arc, time::Duration}; use std::{sync::Arc, time::Duration};
use tokio::{ use tokio::{
sync::{OwnedSemaphorePermit, Semaphore}, sync::{OwnedSemaphorePermit, Semaphore},
@@ -216,6 +217,7 @@ impl ConsoleReqExtra {
} }
/// Auth secret which is managed by the cloud. /// Auth secret which is managed by the cloud.
#[derive(Clone)]
pub enum AuthSecret { pub enum AuthSecret {
#[cfg(feature = "testing")] #[cfg(feature = "testing")]
/// Md5 hash of user's password. /// Md5 hash of user's password.
@@ -250,18 +252,19 @@ pub struct NodeInfo {
pub type NodeInfoCache = TimedLru<Arc<str>, NodeInfo>; pub type NodeInfoCache = TimedLru<Arc<str>, NodeInfo>;
pub type CachedNodeInfo = timed_lru::Cached<&'static NodeInfoCache>; pub type CachedNodeInfo = timed_lru::Cached<&'static NodeInfoCache>;
pub type AllowedIpsCache = TimedLru<Arc<str>, Arc<Vec<String>>>; pub type AllowedIpsCache = TimedLru<SmolStr, Arc<Vec<String>>>;
pub type RoleSecretCache = TimedLru<(SmolStr, SmolStr), Option<AuthSecret>>;
/// This will allocate per each call, but the http requests alone /// This will allocate per each call, but the http requests alone
/// already require a few allocations, so it should be fine. /// already require a few allocations, so it should be fine.
#[async_trait] #[async_trait]
pub trait Api { pub trait Api {
/// Get the client's auth secret for authentication. /// Get the client's auth secret for authentication.
async fn get_auth_info( async fn get_role_secret(
&self, &self,
extra: &ConsoleReqExtra, extra: &ConsoleReqExtra,
creds: &ComputeUserInfo, creds: &ComputeUserInfo,
) -> Result<AuthInfo, errors::GetAuthInfoError>; ) -> Result<Option<AuthSecret>, errors::GetAuthInfoError>;
async fn get_allowed_ips( async fn get_allowed_ips(
&self, &self,
@@ -282,7 +285,9 @@ pub struct ApiCaches {
/// Cache for the `wake_compute` API method. /// Cache for the `wake_compute` API method.
pub node_info: NodeInfoCache, pub node_info: NodeInfoCache,
/// Cache for the `get_allowed_ips`. TODO(anna): use notifications listener instead. /// Cache for the `get_allowed_ips`. TODO(anna): use notifications listener instead.
pub allowed_ips: TimedLru<Arc<str>, Arc<Vec<String>>>, pub allowed_ips: AllowedIpsCache,
/// Cache for the `get_role_secret`. TODO(anna): use notifications listener instead.
pub role_secret: RoleSecretCache,
} }
/// Various caches for [`console`](super). /// Various caches for [`console`](super).

View File

@@ -142,12 +142,12 @@ async fn get_execute_postgres_query(
#[async_trait] #[async_trait]
impl super::Api for Api { impl super::Api for Api {
#[tracing::instrument(skip_all)] #[tracing::instrument(skip_all)]
async fn get_auth_info( async fn get_role_secret(
&self, &self,
_extra: &ConsoleReqExtra, _extra: &ConsoleReqExtra,
creds: &ComputeUserInfo, creds: &ComputeUserInfo,
) -> Result<AuthInfo, GetAuthInfoError> { ) -> Result<Option<AuthSecret>, GetAuthInfoError> {
self.do_get_auth_info(creds).await Ok(self.do_get_auth_info(creds).await?.secret)
} }
async fn get_allowed_ips( async fn get_allowed_ips(

View File

@@ -159,12 +159,24 @@ impl Api {
#[async_trait] #[async_trait]
impl super::Api for Api { impl super::Api for Api {
#[tracing::instrument(skip_all)] #[tracing::instrument(skip_all)]
async fn get_auth_info( async fn get_role_secret(
&self, &self,
extra: &ConsoleReqExtra, extra: &ConsoleReqExtra,
creds: &ComputeUserInfo, creds: &ComputeUserInfo,
) -> Result<AuthInfo, GetAuthInfoError> { ) -> Result<Option<AuthSecret>, GetAuthInfoError> {
self.do_get_auth_info(extra, creds).await let ep = creds.endpoint.clone();
let user = creds.inner.user.clone();
if let Some(role_secret) = self.caches.role_secret.get(&(ep.clone(), user.clone())) {
return Ok(role_secret.clone());
}
let auth_info = self.do_get_auth_info(extra, creds).await?;
self.caches
.role_secret
.insert((ep.clone(), user), auth_info.secret.clone());
self.caches
.allowed_ips
.insert(ep, Arc::new(auth_info.allowed_ips));
Ok(auth_info.secret)
} }
async fn get_allowed_ips( async fn get_allowed_ips(
@@ -172,8 +184,7 @@ impl super::Api for Api {
extra: &ConsoleReqExtra, extra: &ConsoleReqExtra,
creds: &ComputeUserInfo, creds: &ComputeUserInfo,
) -> Result<Arc<Vec<String>>, GetAuthInfoError> { ) -> Result<Arc<Vec<String>>, GetAuthInfoError> {
let key: &str = &creds.endpoint; if let Some(allowed_ips) = self.caches.allowed_ips.get(&creds.endpoint) {
if let Some(allowed_ips) = self.caches.allowed_ips.get(key) {
ALLOWED_IPS_BY_CACHE_OUTCOME ALLOWED_IPS_BY_CACHE_OUTCOME
.with_label_values(&["hit"]) .with_label_values(&["hit"])
.inc(); .inc();
@@ -182,10 +193,14 @@ impl super::Api for Api {
ALLOWED_IPS_BY_CACHE_OUTCOME ALLOWED_IPS_BY_CACHE_OUTCOME
.with_label_values(&["miss"]) .with_label_values(&["miss"])
.inc(); .inc();
let allowed_ips = Arc::new(self.do_get_auth_info(extra, creds).await?.allowed_ips); let auth_info = self.do_get_auth_info(extra, creds).await?;
let allowed_ips = Arc::new(auth_info.allowed_ips);
let ep = creds.endpoint.clone();
let user = creds.inner.user.clone();
self.caches self.caches
.allowed_ips .role_secret
.insert(key.into(), allowed_ips.clone()); .insert((ep.clone(), user), auth_info.secret);
self.caches.allowed_ips.insert(ep, allowed_ips.clone());
Ok(allowed_ips) Ok(allowed_ips)
} }

View File

@@ -6,7 +6,7 @@ pub const SCRAM_KEY_LEN: usize = 32;
/// One of the keys derived from the [password](super::password::SaltedPassword). /// One of the keys derived from the [password](super::password::SaltedPassword).
/// We use the same structure for all keys, i.e. /// We use the same structure for all keys, i.e.
/// `ClientKey`, `StoredKey`, and `ServerKey`. /// `ClientKey`, `StoredKey`, and `ServerKey`.
#[derive(Default, PartialEq, Eq)] #[derive(Clone, Default, PartialEq, Eq)]
#[repr(transparent)] #[repr(transparent)]
pub struct ScramKey { pub struct ScramKey {
bytes: [u8; SCRAM_KEY_LEN], bytes: [u8; SCRAM_KEY_LEN],

View File

@@ -5,6 +5,7 @@ use super::key::ScramKey;
/// Server secret is produced from [password](super::password::SaltedPassword) /// Server secret is produced from [password](super::password::SaltedPassword)
/// and is used throughout the authentication process. /// and is used throughout the authentication process.
#[derive(Clone)]
pub struct ServerSecret { pub struct ServerSecret {
/// Number of iterations for `PBKDF2` function. /// Number of iterations for `PBKDF2` function.
pub iterations: u32, pub iterations: u32,

View File

@@ -431,7 +431,6 @@ async fn connect_to_compute(
application_name: APP_NAME.to_string(), application_name: APP_NAME.to_string(),
options: console_options, options: console_options,
}; };
// TODO(anna): this is a bit hacky way, consider using console notification listener.
if !config.disable_ip_check_for_http { if !config.disable_ip_check_for_http {
let allowed_ips = backend.get_allowed_ips(&extra).await?; let allowed_ips = backend.get_allowed_ips(&extra).await?;
if !check_peer_addr_is_in_list(&peer_addr, &allowed_ips) { if !check_peer_addr_is_in_list(&peer_addr, &allowed_ips) {