Compare commits

..

3 Commits

Author SHA1 Message Date
Anna Khanova
de4449281d Fix 2024-04-10 11:51:10 +02:00
Anna Khanova
2876eaba61 Proper fix 2024-04-10 11:48:44 +02:00
Anna Khanova
5efe95a008 proxy: fix credentials cache lookup (#7349)
## Problem

Incorrect processing of `-pooler` connections.

## Summary of changes

Fix

TODO: add e2e tests for caching
2024-04-10 08:30:09 +00:00
6 changed files with 32 additions and 35 deletions

View File

@@ -403,7 +403,6 @@ async fn main() -> anyhow::Result<()> {
if let auth::BackendType::Console(api, _) = &config.auth_backend {
if let proxy::console::provider::ConsoleBackend::Console(api) = &**api {
maintenance_tasks.spawn(api.locks.garbage_collect_worker());
if let Some(redis_notifications_client) = redis_notifications_client {
let cache = api.caches.project_info.clone();
maintenance_tasks.spawn(notifications::task_main(
@@ -520,6 +519,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
console::locks::ApiLocks::new("wake_compute_lock", permits, shards, timeout, epoch)
.unwrap(),
));
tokio::spawn(locks.garbage_collect_worker());
let url = args.auth_endpoint.parse()?;
let endpoint = http::Endpoint::new(url, http::new_client(rate_limiter_config));

View File

@@ -21,7 +21,7 @@ use crate::{
metrics::REDIS_BROKEN_MESSAGES,
rate_limiter::GlobalRateLimiter,
redis::connection_with_credentials_provider::ConnectionWithCredentialsProvider,
EndpointId, Normalize,
EndpointId,
};
#[derive(Deserialize, Debug, Clone)]
@@ -72,9 +72,8 @@ impl EndpointsCache {
!rejected
}
fn should_reject(&self, endpoint: &EndpointId) -> bool {
let endpoint = endpoint.normalize();
if endpoint.is_endpoint() {
!self.endpoints.contains(&EndpointIdInt::from(&endpoint))
!self.endpoints.contains(&EndpointIdInt::from(endpoint))
} else if endpoint.is_branch() {
!self
.branches

View File

@@ -16,7 +16,7 @@ use crate::{
config::ProjectInfoCacheOptions,
console::AuthSecret,
intern::{EndpointIdInt, ProjectIdInt, RoleNameInt},
EndpointCacheKey, EndpointId, RoleName,
EndpointId, RoleName,
};
use super::{Cache, Cached};
@@ -196,7 +196,7 @@ impl ProjectInfoCacheImpl {
}
pub fn get_allowed_ips(
&self,
endpoint_id: &EndpointCacheKey,
endpoint_id: &EndpointId,
) -> Option<Cached<&Self, Arc<Vec<IpPattern>>>> {
let endpoint_id = EndpointIdInt::get(endpoint_id)?;
let (valid_since, ignore_cache_since) = self.get_cache_times();

View File

@@ -543,7 +543,10 @@ impl ApiLocks {
})
}
pub async fn garbage_collect_worker(&self) -> anyhow::Result<Infallible> {
pub async fn garbage_collect_worker(&self) {
if self.permits == 0 {
return;
}
let mut interval =
tokio::time::interval(self.epoch / (self.node_locks.shards().len()) as u32);
loop {

View File

@@ -59,7 +59,7 @@ impl Api {
if !self
.caches
.endpoints_cache
.is_valid(ctx, &user_info.endpoint)
.is_valid(ctx, &user_info.endpoint.normalize())
.await
{
info!("endpoint is not valid, skipping the request");
@@ -68,7 +68,7 @@ impl Api {
let request_id = ctx.session_id.to_string();
let application_name = ctx.console_application_name();
async {
let mut request_builder = self
let request = self
.endpoint
.get("proxy_get_role_secret")
.header("X-Request-ID", &request_id)
@@ -78,14 +78,8 @@ impl Api {
("application_name", application_name.as_str()),
("project", user_info.endpoint.as_str()),
("role", user_info.user.as_str()),
]);
let options = user_info.options.to_deep_object();
if !options.is_empty() {
request_builder = request_builder.query(&options);
}
let request = request_builder.build()?;
])
.build()?;
info!(url = request.url().as_str(), "sending http request");
let start = Instant::now();
@@ -192,23 +186,27 @@ impl super::Api for Api {
ctx: &mut RequestMonitoring,
user_info: &ComputeUserInfo,
) -> Result<CachedRoleSecret, GetAuthInfoError> {
let ep = &user_info.endpoint;
let normalized_ep = &user_info.endpoint.normalize();
let user = &user_info.user;
if let Some(role_secret) = self.caches.project_info.get_role_secret(ep, user) {
if let Some(role_secret) = self
.caches
.project_info
.get_role_secret(normalized_ep, user)
{
return Ok(role_secret);
}
let auth_info = self.do_get_auth_info(ctx, user_info).await?;
if let Some(project_id) = auth_info.project_id {
let ep_int = ep.normalize().into();
let normalized_ep_int = normalized_ep.into();
self.caches.project_info.insert_role_secret(
project_id,
ep_int,
normalized_ep_int,
user.into(),
auth_info.secret.clone(),
);
self.caches.project_info.insert_allowed_ips(
project_id,
ep_int,
normalized_ep_int,
Arc::new(auth_info.allowed_ips),
);
ctx.set_project_id(project_id);
@@ -222,8 +220,8 @@ impl super::Api for Api {
ctx: &mut RequestMonitoring,
user_info: &ComputeUserInfo,
) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), GetAuthInfoError> {
let cache_key = user_info.endpoint_cache_key();
if let Some(allowed_ips) = self.caches.project_info.get_allowed_ips(&cache_key) {
let normalized_ep = &user_info.endpoint.normalize();
if let Some(allowed_ips) = self.caches.project_info.get_allowed_ips(normalized_ep) {
ALLOWED_IPS_BY_CACHE_OUTCOME
.with_label_values(&["hit"])
.inc();
@@ -236,16 +234,18 @@ impl super::Api for Api {
let allowed_ips = Arc::new(auth_info.allowed_ips);
let user = &user_info.user;
if let Some(project_id) = auth_info.project_id {
let ep_int = cache_key.normalize().into();
let normalized_ep_int = normalized_ep.into();
self.caches.project_info.insert_role_secret(
project_id,
ep_int,
normalized_ep_int,
user.into(),
auth_info.secret.clone(),
);
self.caches
.project_info
.insert_allowed_ips(project_id, ep_int, allowed_ips.clone());
self.caches.project_info.insert_allowed_ips(
project_id,
normalized_ep_int,
allowed_ips.clone(),
);
ctx.set_project_id(project_id);
}
Ok((

View File

@@ -5,7 +5,7 @@ use std::{
use lasso::{Capacity, MemoryLimits, Spur, ThreadedRodeo};
use rustc_hash::FxHasher;
use crate::{BranchId, EndpointCacheKey, EndpointId, ProjectId, RoleName};
use crate::{BranchId, EndpointId, ProjectId, RoleName};
pub trait InternId: Sized + 'static {
fn get_interner() -> &'static StringInterner<Self>;
@@ -165,11 +165,6 @@ impl From<EndpointId> for EndpointIdInt {
EndpointIdTag::get_interner().get_or_intern(&value)
}
}
impl From<EndpointCacheKey> for EndpointIdInt {
fn from(value: EndpointCacheKey) -> Self {
EndpointIdTag::get_interner().get_or_intern(&value)
}
}
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub struct BranchIdTag;