New cache impl

This commit is contained in:
Dmitry Ivanov
2023-03-15 14:13:41 +03:00
committed by Vadim Kharitonov
parent f3769d45ae
commit 9b99d4caa9
18 changed files with 747 additions and 544 deletions

21
Cargo.lock generated
View File

@@ -3104,6 +3104,7 @@ dependencies = [
"prometheus",
"rand",
"rcgen",
"ref-cast",
"regex",
"reqwest",
"reqwest-middleware",
@@ -3228,6 +3229,26 @@ dependencies = [
"bitflags",
]
[[package]]
name = "ref-cast"
version = "1.0.16"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f43faa91b1c8b36841ee70e97188a869d37ae21759da6846d4be66de5bf7b12c"
dependencies = [
"ref-cast-impl",
]
[[package]]
name = "ref-cast-impl"
version = "1.0.16"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8d2275aab483050ab2a7364c1a46604865ee7d6906684e08db0f090acf74f9e7"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.15",
]
[[package]]
name = "regex"
version = "1.7.3"

View File

@@ -77,6 +77,7 @@ pin-project-lite = "0.2"
prometheus = {version = "0.13", default_features=false, features = ["process"]} # removes protobuf dependency
prost = "0.11"
rand = "0.8"
ref-cast = "1.0"
regex = "1.4"
reqwest = { version = "0.11", default-features = false, features = ["rustls-tls"] }
reqwest-tracing = { version = "0.4.0", features = ["opentelemetry_0_18"] }

View File

@@ -35,6 +35,7 @@ postgres_backend.workspace = true
pq_proto.workspace = true
prometheus.workspace = true
rand.workspace = true
ref-cast.workspace = true
regex.workspace = true
reqwest = { workspace = true, features = ["json"] }
reqwest-middleware.workspace = true
@@ -50,9 +51,9 @@ socket2.workspace = true
sync_wrapper.workspace = true
thiserror.workspace = true
tls-listener.workspace = true
tokio = { workspace = true, features = ["signal"] }
tokio-postgres.workspace = true
tokio-rustls.workspace = true
tokio = { workspace = true, features = ["signal"] }
tracing-opentelemetry.workspace = true
tracing-subscriber.workspace = true
tracing-utils.workspace = true

View File

@@ -6,6 +6,7 @@ pub use link::LinkAuthError;
use crate::{
auth::{self, ClientCredentials},
compute::ComputeNode,
console::{
self,
provider::{CachedNodeInfo, ConsoleReqExtra},
@@ -114,7 +115,7 @@ async fn auth_quirks(
creds: &mut ClientCredentials<'_>,
client: &mut stream::PqStream<impl AsyncRead + AsyncWrite + Unpin>,
allow_cleartext: bool,
) -> auth::Result<AuthSuccess<CachedNodeInfo>> {
) -> auth::Result<AuthSuccess<ComputeNode>> {
// If there's no project so far, that entails that client doesn't
// support SNI or other means of passing the endpoint (project) name.
// We now expect to see a very specific payload in the place of password.
@@ -156,7 +157,7 @@ impl BackendType<'_, ClientCredentials<'_>> {
extra: &ConsoleReqExtra<'_>,
client: &mut stream::PqStream<impl AsyncRead + AsyncWrite + Unpin>,
allow_cleartext: bool,
) -> auth::Result<AuthSuccess<CachedNodeInfo>> {
) -> auth::Result<AuthSuccess<ComputeNode>> {
use BackendType::*;
let res = match self {
@@ -184,9 +185,7 @@ impl BackendType<'_, ClientCredentials<'_>> {
Link(url) => {
info!("performing link authentication");
link::authenticate(url, client)
.await?
.map(CachedNodeInfo::new_uncached)
link::authenticate(url, client).await?
}
};

View File

@@ -1,8 +1,8 @@
use super::AuthSuccess;
use crate::{
auth::{self, AuthFlow, ClientCredentials},
compute,
console::{self, AuthInfo, CachedNodeInfo, ConsoleReqExtra},
compute::{self, ComputeNode, Password},
console::{self, AuthInfo, ConsoleReqExtra},
sasl, scram,
stream::PqStream,
};
@@ -14,7 +14,7 @@ pub(super) async fn authenticate(
extra: &ConsoleReqExtra<'_>,
creds: &ClientCredentials<'_>,
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
) -> auth::Result<AuthSuccess<CachedNodeInfo>> {
) -> auth::Result<AuthSuccess<ComputeNode>> {
info!("fetching user's authentication info");
let info = api.get_auth_info(extra, creds).await?.unwrap_or_else(|| {
// If we don't have an authentication secret, we mock one to
@@ -25,7 +25,7 @@ pub(super) async fn authenticate(
});
let flow = AuthFlow::new(client);
let scram_keys = match info {
let keys = match info {
AuthInfo::Md5(_) => {
info!("auth endpoint chooses MD5");
return Err(auth::AuthError::bad_auth_method("MD5"));
@@ -41,21 +41,20 @@ pub(super) async fn authenticate(
}
};
Some(compute::ScramKeys {
compute::ScramKeys {
client_key: client_key.as_bytes(),
server_key: secret.server_key.as_bytes(),
})
}
}
};
let mut node = api.wake_compute(extra, creds).await?;
if let Some(keys) = scram_keys {
use tokio_postgres::config::AuthKeys;
node.config.auth_keys(AuthKeys::ScramSha256(keys));
}
let info = api.wake_compute(extra, creds).await?;
Ok(AuthSuccess {
reported_auth_ok: false,
value: node,
value: ComputeNode::Static {
password: Password::ScramKeys(keys),
info,
},
})
}

View File

@@ -1,10 +1,8 @@
use super::AuthSuccess;
use crate::{
auth::{self, AuthFlow, ClientCredentials},
console::{
self,
provider::{CachedNodeInfo, ConsoleReqExtra},
},
compute::{ComputeNode, Password},
console::{self, provider::ConsoleReqExtra},
stream,
};
use tokio::io::{AsyncRead, AsyncWrite};
@@ -19,7 +17,7 @@ pub async fn cleartext_hack(
extra: &ConsoleReqExtra<'_>,
creds: &mut ClientCredentials<'_>,
client: &mut stream::PqStream<impl AsyncRead + AsyncWrite + Unpin>,
) -> auth::Result<AuthSuccess<CachedNodeInfo>> {
) -> auth::Result<AuthSuccess<ComputeNode>> {
warn!("cleartext auth flow override is enabled, proceeding");
let password = AuthFlow::new(client)
.begin(auth::CleartextPassword)
@@ -27,13 +25,15 @@ pub async fn cleartext_hack(
.authenticate()
.await?;
let mut node = api.wake_compute(extra, creds).await?;
node.config.password(password);
let info = api.wake_compute(extra, creds).await?;
// Report tentative success; compute node will check the password anyway.
Ok(AuthSuccess {
reported_auth_ok: false,
value: node,
value: ComputeNode::Static {
password: Password::ClearText(password),
info,
},
})
}
@@ -44,7 +44,7 @@ pub async fn password_hack(
extra: &ConsoleReqExtra<'_>,
creds: &mut ClientCredentials<'_>,
client: &mut stream::PqStream<impl AsyncRead + AsyncWrite + Unpin>,
) -> auth::Result<AuthSuccess<CachedNodeInfo>> {
) -> auth::Result<AuthSuccess<ComputeNode>> {
warn!("project not specified, resorting to the password hack auth flow");
let payload = AuthFlow::new(client)
.begin(auth::PasswordHack)
@@ -55,12 +55,14 @@ pub async fn password_hack(
info!(project = &payload.endpoint, "received missing parameter");
creds.project = Some(payload.endpoint);
let mut node = api.wake_compute(extra, creds).await?;
node.config.password(payload.password);
let info = api.wake_compute(extra, creds).await?;
// Report tentative success; compute node will check the password anyway.
Ok(AuthSuccess {
reported_auth_ok: false,
value: node,
value: ComputeNode::Static {
password: Password::ClearText(payload.password),
info,
},
})
}

View File

@@ -1,15 +1,10 @@
use super::AuthSuccess;
use crate::{
auth, compute,
console::{self, provider::NodeInfo},
error::UserFacingError,
stream::PqStream,
waiters,
auth, compute::ComputeNode, console, error::UserFacingError, stream::PqStream, waiters,
};
use pq_proto::BeMessage as Be;
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_postgres::config::SslMode;
use tracing::{info, info_span};
#[derive(Debug, Error)]
@@ -57,12 +52,12 @@ pub fn new_psql_session_id() -> String {
pub(super) async fn authenticate(
link_uri: &reqwest::Url,
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
) -> auth::Result<AuthSuccess<NodeInfo>> {
) -> auth::Result<AuthSuccess<ComputeNode>> {
let psql_session_id = new_psql_session_id();
let span = info_span!("link", psql_session_id = &psql_session_id);
let span = info_span!("link", psql_session_id);
let greeting = hello_message(link_uri, &psql_session_id);
let db_info = console::mgmt::with_waiter(psql_session_id, |waiter| async {
let info = console::mgmt::with_waiter(psql_session_id, |waiter| async {
// Give user a URL to spawn a new database.
info!(parent: &span, "sending the auth URL to the user");
client
@@ -79,35 +74,8 @@ pub(super) async fn authenticate(
client.write_message_noflush(&Be::NoticeResponse("Connecting to database."))?;
// This config should be self-contained, because we won't
// take username or dbname from client's startup message.
let mut config = compute::ConnCfg::new();
config
.host(&db_info.host)
.port(db_info.port)
.dbname(&db_info.dbname)
.user(&db_info.user);
// Backwards compatibility. pg_sni_proxy uses "--" in domain names
// while direct connections do not. Once we migrate to pg_sni_proxy
// everywhere, we can remove this.
if db_info.host.contains("--") {
// we need TLS connection with SNI info to properly route it
config.ssl_mode(SslMode::Require);
} else {
config.ssl_mode(SslMode::Disable);
}
if let Some(password) = db_info.password {
config.password(password.as_ref());
}
Ok(AuthSuccess {
reported_auth_ok: true,
value: NodeInfo {
config,
aux: db_info.aux.into(),
allow_self_signed_compute: false, // caller may override
},
value: ComputeNode::Link(info),
})
}

View File

@@ -6,6 +6,7 @@ use proxy::metrics;
use anyhow::bail;
use clap::{self, Arg};
use proxy::config::{self, ProxyConfig};
use std::sync::atomic::Ordering;
use std::{borrow::Cow, net::SocketAddr};
use tokio::net::TcpListener;
use tokio_util::sync::CancellationToken;
@@ -103,6 +104,7 @@ fn build_config(args: &clap::ArgMatches) -> anyhow::Result<&'static ProxyConfig>
.parse()?;
if allow_self_signed_compute {
warn!("allowing self-signed compute certificates");
proxy::compute::ALLOW_SELF_SIGNED_COMPUTE.store(true, Ordering::Relaxed);
}
let metric_collection = match (
@@ -154,7 +156,6 @@ fn build_config(args: &clap::ArgMatches) -> anyhow::Result<&'static ProxyConfig>
tls_config,
auth_backend,
metric_collection,
allow_self_signed_compute,
}));
Ok(config)

View File

@@ -1,304 +1,117 @@
use std::{
borrow::Borrow,
hash::Hash,
ops::{Deref, DerefMut},
time::{Duration, Instant},
};
use tracing::debug;
use std::{any::Any, sync::Arc, time::Instant};
// This seems to make more sense than `lru` or `cached`:
//
// * `near/nearcore` ditched `cached` in favor of `lru`
// (https://github.com/near/nearcore/issues?q=is%3Aissue+lru+is%3Aclosed).
//
// * `lru` methods use an obscure `KeyRef` type in their contraints (which is deliberately excluded from docs).
// This severely hinders its usage both in terms of creating wrappers and supported key types.
//
// On the other hand, `hashlink` has good download stats and appears to be maintained.
use hashlink::{linked_hash_map::RawEntryMut, LruCache};
/// A variant of LRU where every entry has a TTL.
pub mod timed_lru;
pub use timed_lru::TimedLru;
/// A generic trait which exposes types of cache's key and value,
/// as well as the notion of cache entry invalidation.
/// This is useful for [`timed_lru::Cached`].
pub trait Cache {
/// Entry's key.
type Key;
/// Useful type aliases.
pub mod types {
pub type Cached<T> = super::Cached<'static, T>;
}
/// Entry's value.
type Value;
/// Lookup information for cache entry invalidation.
#[derive(Clone)]
pub struct LookupInfo<K> {
/// Cache entry creation time.
/// We use this during invalidation lookups to prevent eviction of a newer
/// entry sharing the same key (it might've been inserted by a different
/// task after we got the entry we're trying to invalidate now).
created_at: Instant,
/// Used for entry invalidation.
type LookupInfo<Key>;
/// Search by this key.
key: K,
}
/// This type incapsulates everything needed for cache entry invalidation.
/// For convenience, we completely erase the types of a cache ref and a key.
/// This lets us store multiple tokens in a homogeneous collection (e.g. Vec).
#[derive(Clone)]
pub struct InvalidationToken<'a> {
// TODO: allow more than one type of references (e.g. Arc) if it's ever needed.
cache: &'a (dyn Cache + Sync + Send),
info: LookupInfo<Arc<dyn Any + Sync + Send>>,
}
impl InvalidationToken<'_> {
/// Invalidate a corresponding cache entry.
pub fn invalidate(&self) {
let info = LookupInfo {
created_at: self.info.created_at,
key: self.info.key.as_ref(),
};
self.cache.invalidate_entry(info);
}
}
/// A combination of a cache entry and its invalidation token.
/// Makes it easier to see how those two are connected.
#[derive(Clone)]
pub struct Cached<'a, T> {
pub token: Option<InvalidationToken<'a>>,
pub value: Arc<T>,
}
impl<T> Cached<'_, T> {
/// Place any entry into this wrapper; invalidation will be a no-op.
pub fn new_uncached(value: T) -> Self {
Self {
token: None,
value: value.into(),
}
}
/// Invalidate a corresponding cache entry.
pub fn invalidate(&self) {
if let Some(token) = &self.token {
token.invalidate();
}
}
}
impl<T> std::ops::Deref for Cached<'_, T> {
type Target = T;
fn deref(&self) -> &Self::Target {
&self.value
}
}
/// This trait captures the notion of cache entry invalidation.
/// It doesn't have any associated types because we use dyn-based type erasure.
trait Cache {
/// Invalidate an entry using a lookup info.
/// We don't have an empty default impl because it's error-prone.
fn invalidate(&self, _: &Self::LookupInfo<Self::Key>);
fn invalidate_entry(&self, info: LookupInfo<&(dyn Any + Send + Sync)>);
}
impl<C: Cache> Cache for &C {
type Key = C::Key;
type Value = C::Value;
type LookupInfo<Key> = C::LookupInfo<Key>;
fn invalidate(&self, info: &Self::LookupInfo<Self::Key>) {
C::invalidate(self, info)
}
}
pub use timed_lru::TimedLru;
pub mod timed_lru {
#[cfg(test)]
mod tests {
use super::*;
/// An implementation of timed LRU cache with fixed capacity.
/// Key properties:
///
/// * Whenever a new entry is inserted, the least recently accessed one is evicted.
/// The cache also keeps track of entry's insertion time (`created_at`) and TTL (`expires_at`).
///
/// * When the entry is about to be retrieved, we check its expiration timestamp.
/// If the entry has expired, we remove it from the cache; Otherwise we bump the
/// expiration timestamp (e.g. +5mins) and change its place in LRU list to prolong
/// its existence.
///
/// * There's an API for immediate invalidation (removal) of a cache entry;
/// It's useful in case we know for sure that the entry is no longer correct.
/// See [`timed_lru::LookupInfo`] & [`timed_lru::Cached`] for more information.
///
/// * Expired entries are kept in the cache, until they are evicted by the LRU policy,
/// or by a successful lookup (i.e. the entry hasn't expired yet).
/// There is no background job to reap the expired records.
///
/// * It's possible for an entry that has not yet expired entry to be evicted
/// before expired items. That's a bit wasteful, but probably fine in practice.
pub struct TimedLru<K, V> {
/// Cache's name for tracing.
name: &'static str,
/// The underlying cache implementation.
cache: parking_lot::Mutex<LruCache<K, Entry<V>>>,
/// Default time-to-live of a single entry.
ttl: Duration,
#[test]
fn trivial_properties_of_cached() {
let cached = Cached::new_uncached(0);
assert_eq!(*cached, 0);
cached.invalidate();
}
impl<K: Hash + Eq, V> Cache for TimedLru<K, V> {
type Key = K;
type Value = V;
type LookupInfo<Key> = LookupInfo<Key>;
#[test]
fn invalidation_token_type_erasure() {
let lifetime = std::time::Duration::from_secs(10);
let foo = TimedLru::<u32, u32>::new("foo", 128, lifetime);
let bar = TimedLru::<String, usize>::new("bar", 128, lifetime);
fn invalidate(&self, info: &Self::LookupInfo<K>) {
self.invalidate_raw(info)
}
}
let (_, x) = foo.insert(100.into(), 0.into());
let (_, y) = bar.insert(String::new().into(), 404.into());
struct Entry<T> {
created_at: Instant,
expires_at: Instant,
value: T,
}
impl<K: Hash + Eq, V> TimedLru<K, V> {
/// Construct a new LRU cache with timed entries.
pub fn new(name: &'static str, capacity: usize, ttl: Duration) -> Self {
Self {
name,
cache: LruCache::new(capacity).into(),
ttl,
}
// Invalidation tokens should be cloneable and homogeneous (same type).
let tokens = [x.token.clone().unwrap(), y.token.clone().unwrap()];
for token in tokens {
token.invalidate();
}
/// Drop an entry from the cache if it's outdated.
#[tracing::instrument(level = "debug", fields(cache = self.name), skip_all)]
fn invalidate_raw(&self, info: &LookupInfo<K>) {
let now = Instant::now();
// Do costly things before taking the lock.
let mut cache = self.cache.lock();
let raw_entry = match cache.raw_entry_mut().from_key(&info.key) {
RawEntryMut::Vacant(_) => return,
RawEntryMut::Occupied(x) => x,
};
// Remove the entry if it was created prior to lookup timestamp.
let entry = raw_entry.get();
let (created_at, expires_at) = (entry.created_at, entry.expires_at);
let should_remove = created_at <= info.created_at || expires_at <= now;
if should_remove {
raw_entry.remove();
}
drop(cache); // drop lock before logging
debug!(
created_at = format_args!("{created_at:?}"),
expires_at = format_args!("{expires_at:?}"),
entry_removed = should_remove,
"processed a cache entry invalidation event"
);
}
/// Try retrieving an entry by its key, then execute `extract` if it exists.
#[tracing::instrument(level = "debug", fields(cache = self.name), skip_all)]
fn get_raw<Q, R>(&self, key: &Q, extract: impl FnOnce(&K, &Entry<V>) -> R) -> Option<R>
where
K: Borrow<Q>,
Q: Hash + Eq + ?Sized,
{
let now = Instant::now();
let deadline = now.checked_add(self.ttl).expect("time overflow");
// Do costly things before taking the lock.
let mut cache = self.cache.lock();
let mut raw_entry = match cache.raw_entry_mut().from_key(key) {
RawEntryMut::Vacant(_) => return None,
RawEntryMut::Occupied(x) => x,
};
// Immeditely drop the entry if it has expired.
let entry = raw_entry.get();
if entry.expires_at <= now {
raw_entry.remove();
return None;
}
let value = extract(raw_entry.key(), entry);
let (created_at, expires_at) = (entry.created_at, entry.expires_at);
// Update the deadline and the entry's position in the LRU list.
raw_entry.get_mut().expires_at = deadline;
raw_entry.to_back();
drop(cache); // drop lock before logging
debug!(
created_at = format_args!("{created_at:?}"),
old_expires_at = format_args!("{expires_at:?}"),
new_expires_at = format_args!("{deadline:?}"),
"accessed a cache entry"
);
Some(value)
}
/// Insert an entry to the cache. If an entry with the same key already
/// existed, return the previous value and its creation timestamp.
#[tracing::instrument(level = "debug", fields(cache = self.name), skip_all)]
fn insert_raw(&self, key: K, value: V) -> (Instant, Option<V>) {
let created_at = Instant::now();
let expires_at = created_at.checked_add(self.ttl).expect("time overflow");
let entry = Entry {
created_at,
expires_at,
value,
};
// Do costly things before taking the lock.
let old = self
.cache
.lock()
.insert(key, entry)
.map(|entry| entry.value);
debug!(
created_at = format_args!("{created_at:?}"),
expires_at = format_args!("{expires_at:?}"),
replaced = old.is_some(),
"created a cache entry"
);
(created_at, old)
}
}
impl<K: Hash + Eq + Clone, V: Clone> TimedLru<K, V> {
pub fn insert(&self, key: K, value: V) -> (Option<V>, Cached<&Self>) {
let (created_at, old) = self.insert_raw(key.clone(), value.clone());
let cached = Cached {
token: Some((self, LookupInfo { created_at, key })),
value,
};
(old, cached)
}
}
impl<K: Hash + Eq, V: Clone> TimedLru<K, V> {
/// Retrieve a cached entry in convenient wrapper.
pub fn get<Q>(&self, key: &Q) -> Option<timed_lru::Cached<&Self>>
where
K: Borrow<Q> + Clone,
Q: Hash + Eq + ?Sized,
{
self.get_raw(key, |key, entry| {
let info = LookupInfo {
created_at: entry.created_at,
key: key.clone(),
};
Cached {
token: Some((self, info)),
value: entry.value.clone(),
}
})
}
}
/// Lookup information for key invalidation.
pub struct LookupInfo<K> {
/// Time of creation of a cache [`Entry`].
/// We use this during invalidation lookups to prevent eviction of a newer
/// entry sharing the same key (it might've been inserted by a different
/// task after we got the entry we're trying to invalidate now).
created_at: Instant,
/// Search by this key.
key: K,
}
/// Wrapper for convenient entry invalidation.
pub struct Cached<C: Cache> {
/// Cache + lookup info.
token: Option<(C, C::LookupInfo<C::Key>)>,
/// The value itself.
pub value: C::Value,
}
impl<C: Cache> Cached<C> {
/// Place any entry into this wrapper; invalidation will be a no-op.
/// Unfortunately, rust doesn't let us implement [`From`] or [`Into`].
pub fn new_uncached(value: impl Into<C::Value>) -> Self {
Self {
token: None,
value: value.into(),
}
}
/// Drop this entry from a cache if it's still there.
pub fn invalidate(&self) {
if let Some((cache, info)) = &self.token {
cache.invalidate(info);
}
}
/// Tell if this entry is actually cached.
pub fn cached(&self) -> bool {
self.token.is_some()
}
}
impl<C: Cache> Deref for Cached<C> {
type Target = C::Value;
fn deref(&self) -> &Self::Target {
&self.value
}
}
impl<C: Cache> DerefMut for Cached<C> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.value
}
// Values are still there.
assert_eq!(*x, 0);
assert_eq!(*y, 404);
}
}

272
proxy/src/cache/timed_lru.rs vendored Normal file
View File

@@ -0,0 +1,272 @@
#[cfg(test)]
mod tests;
use super::{Cache, Cached, InvalidationToken, LookupInfo};
use ref_cast::RefCast;
use std::{
any::Any,
borrow::Borrow,
hash::Hash,
sync::Arc,
time::{Duration, Instant},
};
use tracing::debug;
// This seems to make more sense than `lru` or `cached`:
//
// * `near/nearcore` ditched `cached` in favor of `lru`
// (https://github.com/near/nearcore/issues?q=is%3Aissue+lru+is%3Aclosed).
//
// * `lru` methods use an obscure `KeyRef` type in their contraints (which is deliberately excluded from docs).
// This severely hinders its usage both in terms of creating wrappers and supported key types.
//
// On the other hand, `hashlink` has good download stats and appears to be maintained.
use hashlink::{linked_hash_map::RawEntryMut, LruCache};
/// An implementation of timed LRU cache with fixed capacity.
/// Key properties:
///
/// * Whenever a new entry is inserted, the least recently accessed one is evicted.
/// The cache also keeps track of entry's insertion time (`created_at`) and TTL (`expires_at`).
///
/// * When the entry is about to be retrieved, we check its expiration timestamp.
/// If the entry has expired, we remove it from the cache; Otherwise we bump the
/// expiration timestamp (e.g. +5mins) and change its place in LRU list to prolong
/// its existence.
///
/// * There's an API for immediate invalidation (removal) of a cache entry;
/// It's useful in case we know for sure that the entry is no longer correct.
/// See [`timed_lru::LookupInfo`] & [`timed_lru::Cached`] for more information.
///
/// * Expired entries are kept in the cache, until they are evicted by the LRU policy,
/// or by a successful lookup (i.e. the entry hasn't expired yet).
/// There is no background job to reap the expired records.
///
/// * It's possible for an entry that has not yet expired entry to be evicted
/// before expired items. That's a bit wasteful, but probably fine in practice.
pub struct TimedLru<K, V> {
/// Cache's name for tracing.
name: &'static str,
/// The underlying cache implementation.
cache: parking_lot::Mutex<LruCache<Key<K>, Entry<V>>>,
/// Default time-to-live of a single entry.
ttl: Duration,
}
#[derive(RefCast, Hash, PartialEq, Eq)]
#[repr(transparent)]
struct Query<Q: ?Sized>(Q);
#[derive(Hash, PartialEq, Eq)]
#[repr(transparent)]
struct Key<T: ?Sized>(Arc<T>);
/// It's impossible to implement this without [`Key`] & [`Query`]:
/// * We can't implement std traits for [`Arc`].
/// * Even if we could, it'd conflict with `impl<T> Borrow<T> for T`.
impl<Q, T> Borrow<Query<Q>> for Key<T>
where
Q: ?Sized,
T: Borrow<Q>,
{
#[inline(always)]
fn borrow(&self) -> &Query<Q> {
RefCast::ref_cast(self.0.as_ref().borrow())
}
}
struct Entry<T> {
created_at: Instant,
expires_at: Instant,
value: Arc<T>,
}
impl<K: Hash + Eq, V> TimedLru<K, V> {
/// Construct a new LRU cache with timed entries.
pub fn new(name: &'static str, capacity: usize, ttl: Duration) -> Self {
Self {
name,
cache: LruCache::new(capacity).into(),
ttl,
}
}
/// Get the number of entries in the cache.
/// Note that this method will not try to evict stale entries.
pub fn size(&self) -> usize {
self.cache.lock().len()
}
/// Try retrieving an entry by its key, then execute `extract` if it exists.
#[tracing::instrument(level = "debug", fields(cache = self.name), skip_all)]
fn get_raw<Q, F, R>(&self, key: &Q, extract: F) -> Option<R>
where
K: Borrow<Q>,
Q: Hash + Eq + ?Sized,
F: FnOnce(&Arc<K>, &Entry<V>) -> R,
{
let key: &Query<Q> = RefCast::ref_cast(key);
let now = Instant::now();
let deadline = now.checked_add(self.ttl).expect("time overflow");
// Do costly things before taking the lock.
let mut cache = self.cache.lock();
let mut raw_entry = match cache.raw_entry_mut().from_key(key) {
RawEntryMut::Vacant(_) => return None,
RawEntryMut::Occupied(x) => x,
};
// Immeditely drop the entry if it has expired.
let entry = raw_entry.get();
if entry.expires_at <= now {
raw_entry.remove();
return None;
}
let value = extract(&raw_entry.key().0, entry);
let (created_at, expires_at) = (entry.created_at, entry.expires_at);
// Update the deadline and the entry's position in the LRU list.
raw_entry.get_mut().expires_at = deadline;
raw_entry.to_back();
drop(cache); // drop lock before logging
debug!(
?created_at,
old_expires_at = ?expires_at,
new_expires_at = ?deadline,
"accessed a cache entry"
);
Some(value)
}
/// Insert an entry to the cache. If an entry with the same key already
/// existed, return the previous value and its creation timestamp.
#[tracing::instrument(level = "debug", fields(cache = self.name), skip_all)]
fn insert_raw(&self, key: Arc<K>, value: Arc<V>) -> (Option<Arc<V>>, Instant) {
let created_at = Instant::now();
let expires_at = created_at.checked_add(self.ttl).expect("time overflow");
let entry = Entry {
created_at,
expires_at,
value,
};
// Do costly things before taking the lock.
let old = self
.cache
.lock()
.insert(Key(key), entry)
.map(|entry| entry.value);
debug!(
?created_at,
?expires_at,
replaced = old.is_some(),
"created a cache entry"
);
(old, created_at)
}
}
/// Convenient wrappers for raw methods.
impl<K: Hash + Eq + Sync + Send + 'static, V> TimedLru<K, V>
where
Self: Sync + Send,
{
pub fn get<Q>(&self, key: &Q) -> Option<Cached<V>>
where
K: Borrow<Q>,
Q: Hash + Eq + ?Sized,
{
self.get_raw(key, |key, entry| {
let info = LookupInfo {
created_at: entry.created_at,
key: Arc::clone(key),
};
Cached {
token: Some(Self::invalidation_token(self, info)),
value: entry.value.clone(),
}
})
}
pub fn insert(&self, key: Arc<K>, value: Arc<V>) -> (Option<Arc<V>>, Cached<V>) {
let (old, created_at) = self.insert_raw(key.clone(), value.clone());
let info = LookupInfo { created_at, key };
let cached = Cached {
token: Some(Self::invalidation_token(self, info)),
value,
};
(old, cached)
}
}
/// Implementation details of the entry invalidation machinery.
impl<K: Hash + Eq + Sync + Send + 'static, V> TimedLru<K, V>
where
Self: Sync + Send,
{
/// This is a proper (safe) way to create an invalidation token for [`TimedLru`].
fn invalidation_token(&self, info: LookupInfo<Arc<K>>) -> InvalidationToken<'_> {
InvalidationToken {
cache: self,
info: LookupInfo {
created_at: info.created_at,
key: info.key,
},
}
}
}
/// This implementation depends on [`Self::invalidation_token`].
impl<K: Hash + Eq + 'static, V> Cache for TimedLru<K, V> {
fn invalidate_entry(&self, info: LookupInfo<&(dyn Any + Sync + Send)>) {
let info = LookupInfo {
created_at: info.created_at,
// NOTE: it's important to downcast to the correct type!
key: info.key.downcast_ref::<K>().expect("bad key type"),
};
self.invalidate_raw(info)
}
}
impl<K: Hash + Eq, V> TimedLru<K, V> {
#[tracing::instrument(level = "debug", fields(cache = self.name), skip_all)]
fn invalidate_raw(&self, info: LookupInfo<&K>) {
let key: &Query<K> = RefCast::ref_cast(info.key);
let now = Instant::now();
// Do costly things before taking the lock.
let mut cache = self.cache.lock();
let raw_entry = match cache.raw_entry_mut().from_key(key) {
RawEntryMut::Vacant(_) => return,
RawEntryMut::Occupied(x) => x,
};
// Remove the entry if it was created prior to lookup timestamp.
let entry = raw_entry.get();
let (created_at, expires_at) = (entry.created_at, entry.expires_at);
let should_remove = created_at <= info.created_at || expires_at <= now;
if should_remove {
raw_entry.remove();
}
drop(cache); // drop lock before logging
debug!(
?created_at,
?expires_at,
entry_removed = should_remove,
"processed a cache entry invalidation event"
);
}
}

104
proxy/src/cache/timed_lru/tests.rs vendored Normal file
View File

@@ -0,0 +1,104 @@
use super::*;
/// Check that we can define the cache for certain types.
#[test]
fn definition() {
// Check for trivial yet possible types.
let cache = TimedLru::<(), ()>::new("test", 128, Duration::from_secs(0));
let _ = cache.insert(Default::default(), Default::default());
let _ = cache.get(&());
// Now for something less trivial.
let cache = TimedLru::<String, String>::new("test", 128, Duration::from_secs(0));
let _ = cache.insert(Default::default(), Default::default());
let _ = cache.get(&String::default());
let _ = cache.get("str should work");
// It should also work for non-cloneable values.
struct NoClone;
let cache = TimedLru::<Box<str>, NoClone>::new("test", 128, Duration::from_secs(0));
let _ = cache.insert(Default::default(), NoClone.into());
let _ = cache.get(&Box::<str>::from("boxed str"));
let _ = cache.get("str should work");
}
#[test]
fn insert() {
const CAPACITY: usize = 2;
let cache = TimedLru::<String, u32>::new("test", CAPACITY, Duration::from_secs(10));
assert_eq!(cache.size(), 0);
let key = Arc::new(String::from("key"));
let (old, cached) = cache.insert(key.clone(), 42.into());
assert_eq!(old, None);
assert_eq!(*cached, 42);
assert_eq!(cache.size(), 1);
let (old, cached) = cache.insert(key, 1.into());
assert_eq!(old.as_deref(), Some(&42));
assert_eq!(*cached, 1);
assert_eq!(cache.size(), 1);
let (old, cached) = cache.insert(Arc::new("N1".to_owned()), 10.into());
assert_eq!(old, None);
assert_eq!(*cached, 10);
assert_eq!(cache.size(), 2);
let (old, cached) = cache.insert(Arc::new("N2".to_owned()), 20.into());
assert_eq!(old, None);
assert_eq!(*cached, 20);
assert_eq!(cache.size(), 2);
}
#[test]
fn get_none() {
let cache = TimedLru::<String, u32>::new("test", 2, Duration::from_secs(10));
let cached = cache.get("missing");
assert!(matches!(cached, None));
}
#[test]
fn invalidation_simple() {
let cache = TimedLru::<String, u32>::new("test", 2, Duration::from_secs(10));
let (_, cached) = cache.insert(String::from("key").into(), 100.into());
assert_eq!(cache.size(), 1);
cached.invalidate();
assert_eq!(cache.size(), 0);
assert!(matches!(cache.get("key"), None));
}
#[test]
fn invalidation_preserve_newer() {
let cache = TimedLru::<String, u32>::new("test", 2, Duration::from_secs(10));
let key = Arc::new(String::from("key"));
let (_, cached) = cache.insert(key.clone(), 100.into());
assert_eq!(cache.size(), 1);
let _ = cache.insert(key.clone(), 200.into());
assert_eq!(cache.size(), 1);
cached.invalidate();
assert_eq!(cache.size(), 1);
let cached = cache.get(key.as_ref());
assert_eq!(cached.as_deref(), Some(&200));
}
#[test]
fn auto_expiry() {
let lifetime = Duration::from_millis(300);
let cache = TimedLru::<String, u32>::new("test", 2, lifetime);
let key = Arc::new(String::from("key"));
let _ = cache.insert(key.clone(), 42.into());
let cached = cache.get(key.as_ref());
assert_eq!(cached.as_deref(), Some(&42));
std::thread::sleep(lifetime);
let cached = cache.get(key.as_ref());
assert_eq!(cached.as_deref(), None);
}

View File

@@ -1,13 +1,28 @@
use crate::{auth::parse_endpoint_param, cancellation::CancelClosure, error::UserFacingError};
use crate::{
auth::parse_endpoint_param,
cancellation::CancelClosure,
console::messages::{DatabaseInfo, MetricsAuxInfo},
console::CachedNodeInfo,
error::UserFacingError,
};
use futures::{FutureExt, TryFutureExt};
use itertools::Itertools;
use pq_proto::StartupMessageParams;
use std::{io, net::SocketAddr, time::Duration};
use std::{
io,
net::SocketAddr,
sync::atomic::{AtomicBool, Ordering},
time::Duration,
};
use thiserror::Error;
use tokio::net::TcpStream;
use tokio_postgres::tls::MakeTlsConnect;
use tracing::{error, info, warn};
/// Should we allow self-signed certificates in TLS connections?
/// Most definitely, this shouldn't be allowed in production.
pub static ALLOW_SELF_SIGNED_COMPUTE: AtomicBool = AtomicBool::new(false);
const COULD_NOT_CONNECT: &str = "Couldn't connect to compute node";
#[derive(Debug, Error)]
@@ -42,6 +57,88 @@ impl UserFacingError for ConnectionError {
/// A pair of `ClientKey` & `ServerKey` for `SCRAM-SHA-256`.
pub type ScramKeys = tokio_postgres::config::ScramKeys<32>;
#[derive(Clone)]
pub enum Password {
/// A regular cleartext password.
ClearText(Vec<u8>),
/// A pair of `ClientKey` & `ServerKey` for `SCRAM-SHA-256`.
ScramKeys(ScramKeys),
}
pub enum ComputeNode {
/// Route via link auth.
Link(DatabaseInfo),
/// Regular compute node.
Static {
password: Password,
info: CachedNodeInfo,
},
}
impl ComputeNode {
/// Get metrics auxiliary info.
pub fn metrics_aux_info(&self) -> &MetricsAuxInfo {
match self {
Self::Link(info) => &info.aux,
Self::Static { info, .. } => &info.aux,
}
}
/// Invalidate compute node info if it's cached.
pub fn invalidate(&self) -> bool {
if let Self::Static { info, .. } = self {
warn!("invalidating compute node info cache entry");
info.invalidate();
return true;
}
false
}
/// Turn compute node info into a postgres connection config.
pub fn to_conn_config(&self) -> ConnCfg {
let mut config = ConnCfg::new();
let (host, port) = match self {
Self::Link(info) => {
// NB: use pre-supplied dbname, user and password for link auth.
// See `ConnCfg::set_startup_params` below.
config.0.dbname(&info.dbname).user(&info.user);
if let Some(password) = &info.password {
config.0.password(password.as_bytes());
}
(&info.host, info.port)
}
Self::Static { info, password } => {
// NB: setup auth keys (for SCRAM) or plaintext password.
match password {
Password::ClearText(text) => config.0.password(text),
Password::ScramKeys(keys) => {
use tokio_postgres::config::AuthKeys;
config.0.auth_keys(AuthKeys::ScramSha256(keys.to_owned()))
}
};
(&info.address.host, info.address.port)
}
};
// Backwards compatibility. pg_sni_proxy uses "--" in domain names
// while direct connections do not. Once we migrate to pg_sni_proxy
// everywhere, we can remove this.
config.0.ssl_mode(if host.contains("--") {
// We need TLS connection with SNI info to properly route it.
tokio_postgres::config::SslMode::Require
} else {
tokio_postgres::config::SslMode::Disable
});
config.0.host(host).port(port);
config
}
}
/// A config for establishing a connection to compute node.
/// Eventually, `tokio_postgres` will be replaced with something better.
/// Newtype allows us to implement methods on top of it.
@@ -51,43 +148,32 @@ pub struct ConnCfg(Box<tokio_postgres::Config>);
/// Creation and initialization routines.
impl ConnCfg {
pub fn new() -> Self {
fn new() -> Self {
Self(Default::default())
}
/// Reuse password or auth keys from the other config.
pub fn reuse_password(&mut self, other: &Self) {
if let Some(password) = other.get_password() {
self.password(password);
}
if let Some(keys) = other.get_auth_keys() {
self.auth_keys(keys);
}
}
/// Apply startup message params to the connection config.
pub fn set_startup_params(&mut self, params: &StartupMessageParams) {
// Only set `user` if it's not present in the config.
// Link auth flow takes username from the console's response.
if let (None, Some(user)) = (self.get_user(), params.get("user")) {
self.user(user);
if let (None, Some(user)) = (self.0.get_user(), params.get("user")) {
self.0.user(user);
}
// Only set `dbname` if it's not present in the config.
// Link auth flow takes dbname from the console's response.
if let (None, Some(dbname)) = (self.get_dbname(), params.get("database")) {
self.dbname(dbname);
if let (None, Some(dbname)) = (self.0.get_dbname(), params.get("database")) {
self.0.dbname(dbname);
}
// Don't add `options` if they were only used for specifying a project.
// Connection pools don't support `options`, because they affect backend startup.
if let Some(options) = filtered_options(params) {
self.options(&options);
self.0.options(&options);
}
if let Some(app_name) = params.get("application_name") {
self.application_name(app_name);
self.0.application_name(app_name);
}
// TODO: This is especially ugly...
@@ -95,10 +181,10 @@ impl ConnCfg {
use tokio_postgres::config::ReplicationMode;
match replication {
"true" | "on" | "yes" | "1" => {
self.replication_mode(ReplicationMode::Physical);
self.0.replication_mode(ReplicationMode::Physical);
}
"database" => {
self.replication_mode(ReplicationMode::Logical);
self.0.replication_mode(ReplicationMode::Logical);
}
_other => {}
}
@@ -113,27 +199,6 @@ impl ConnCfg {
}
}
impl std::ops::Deref for ConnCfg {
type Target = tokio_postgres::Config;
fn deref(&self) -> &Self::Target {
&self.0
}
}
/// For now, let's make it easier to setup the config.
impl std::ops::DerefMut for ConnCfg {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}
impl Default for ConnCfg {
fn default() -> Self {
Self::new()
}
}
impl ConnCfg {
/// Establish a raw TCP connection to the compute node.
async fn connect_raw(&self) -> io::Result<(SocketAddr, TcpStream, &str)> {
@@ -220,16 +285,13 @@ pub struct PostgresConnection {
}
impl ConnCfg {
async fn do_connect(
&self,
allow_self_signed_compute: bool,
) -> Result<PostgresConnection, ConnectionError> {
async fn do_connect(&self) -> Result<PostgresConnection, ConnectionError> {
let (socket_addr, stream, host) = self.connect_raw().await?;
let tls_connector = native_tls::TlsConnector::builder()
.danger_accept_invalid_certs(allow_self_signed_compute)
.build()
.unwrap();
.danger_accept_invalid_certs(ALLOW_SELF_SIGNED_COMPUTE.load(Ordering::Relaxed))
.build()?;
let mut mk_tls = postgres_native_tls::MakeTlsConnector::new(tls_connector);
let tls = MakeTlsConnect::<tokio::net::TcpStream>::make_tls_connect(&mut mk_tls, host)?;
@@ -261,11 +323,8 @@ impl ConnCfg {
}
/// Connect to a corresponding compute node.
pub async fn connect(
&self,
allow_self_signed_compute: bool,
) -> Result<PostgresConnection, ConnectionError> {
self.do_connect(allow_self_signed_compute)
pub async fn connect(&self) -> Result<PostgresConnection, ConnectionError> {
self.do_connect()
.inspect_err(|err| {
// Immediately log the error we have at our disposal.
error!("couldn't connect to compute node: {err}");

View File

@@ -12,7 +12,6 @@ pub struct ProxyConfig {
pub tls_config: Option<TlsConfig>,
pub auth_backend: auth::BackendType<'static, ()>,
pub metric_collection: Option<MetricCollectionConfig>,
pub allow_self_signed_compute: bool,
}
#[derive(Debug)]

View File

@@ -22,11 +22,37 @@ impl fmt::Debug for GetRoleSecret {
}
}
/// Represents compute node's host & port pair.
#[derive(Debug)]
pub struct PgEndpoint {
pub host: Box<str>,
pub port: u16,
}
impl<'de> Deserialize<'de> for PgEndpoint {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
use serde::de::Error;
let raw = String::deserialize(deserializer)?;
let (host, port) = raw
.rsplit_once(':')
.ok_or_else(|| Error::custom(format!("bad compute address: {raw}")))?;
Ok(PgEndpoint {
host: host.into(),
port: port.parse().map_err(Error::custom)?,
})
}
}
/// Response which holds compute node's `host:port` pair.
/// Returned by the `/proxy_wake_compute` API method.
#[derive(Debug, Deserialize)]
pub struct WakeCompute {
pub address: Box<str>,
pub address: PgEndpoint,
pub aux: MetricsAuxInfo,
}
@@ -187,4 +213,24 @@ mod tests {
Ok(())
}
#[test]
fn parse_wake_compute() -> anyhow::Result<()> {
let _: WakeCompute = serde_json::from_value(json!({
"address": "127.0.0.1:5432",
"aux": dummy_aux(),
}))?;
let _: WakeCompute = serde_json::from_value(json!({
"address": "[::1]:5432",
"aux": dummy_aux(),
}))?;
serde_json::from_value::<WakeCompute>(json!({
"address": "localhost:5432",
"aux": dummy_aux(),
}))?;
Ok(())
}
}

View File

@@ -1,14 +1,13 @@
pub mod mock;
pub mod neon;
use super::messages::MetricsAuxInfo;
use super::messages::{MetricsAuxInfo, PgEndpoint};
use crate::{
auth::ClientCredentials,
cache::{timed_lru, TimedLru},
compute, scram,
cache::{types::Cached, TimedLru},
scram,
};
use async_trait::async_trait;
use std::sync::Arc;
pub mod errors {
use crate::{
@@ -112,11 +111,9 @@ pub mod errors {
}
}
}
#[derive(Debug, Error)]
pub enum WakeComputeError {
#[error("Console responded with a malformed compute address: {0}")]
BadComputeAddress(Box<str>),
#[error(transparent)]
ApiError(ApiError),
}
@@ -132,10 +129,7 @@ pub mod errors {
fn to_string_client(&self) -> String {
use WakeComputeError::*;
match self {
// We shouldn't show user the address even if it's broken.
// Besides, user is unlikely to care about this detail.
BadComputeAddress(_) => REQUEST_FAILED.to_owned(),
// However, API might return a meaningful error.
// API might return a meaningful error.
ApiError(e) => e.to_string_client(),
}
}
@@ -160,23 +154,17 @@ pub enum AuthInfo {
}
/// Info for establishing a connection to a compute node.
/// This is what we get after auth succeeded, but not before!
#[derive(Clone)]
/// This struct is cached, so we shouldn't store any user creds here.
pub struct NodeInfo {
/// Compute node connection params.
/// It's sad that we have to clone this, but this will improve
/// once we migrate to a bespoke connection logic.
pub config: compute::ConnCfg,
/// Address of the compute node.
pub address: PgEndpoint,
/// Labels for proxy's metrics.
pub aux: Arc<MetricsAuxInfo>,
/// Whether we should accept self-signed certificates (for testing)
pub allow_self_signed_compute: bool,
pub aux: MetricsAuxInfo,
}
pub type NodeInfoCache = TimedLru<Arc<str>, NodeInfo>;
pub type CachedNodeInfo = timed_lru::Cached<&'static NodeInfoCache>;
pub type NodeInfoCache = TimedLru<Box<str>, NodeInfo>;
pub type CachedNodeInfo = Cached<NodeInfo>;
/// This will allocate per each call, but the http requests alone
/// already require a few allocations, so it should be fine.

View File

@@ -2,13 +2,12 @@
use super::{
errors::{ApiError, GetAuthInfoError, WakeComputeError},
AuthInfo, CachedNodeInfo, ConsoleReqExtra, NodeInfo,
AuthInfo, CachedNodeInfo, ConsoleReqExtra, NodeInfo, PgEndpoint,
};
use crate::{auth::ClientCredentials, compute, error::io_error, scram, url::ApiUrl};
use crate::{auth::ClientCredentials, error::io_error, scram, url::ApiUrl};
use async_trait::async_trait;
use futures::TryFutureExt;
use thiserror::Error;
use tokio_postgres::config::SslMode;
use tracing::{error, info, info_span, warn, Instrument};
#[derive(Debug, Error)]
@@ -84,19 +83,15 @@ impl Api {
}
async fn do_wake_compute(&self) -> Result<NodeInfo, WakeComputeError> {
let mut config = compute::ConnCfg::new();
config
.host(self.endpoint.host_str().unwrap_or("localhost"))
.port(self.endpoint.port().unwrap_or(5432))
.ssl_mode(SslMode::Disable);
let host = self.endpoint.host_str().unwrap_or("localhost").into();
let port = self.endpoint.port().unwrap_or(5432);
let node = NodeInfo {
config,
let info = NodeInfo {
address: PgEndpoint { host, port },
aux: Default::default(),
allow_self_signed_compute: false,
};
Ok(node)
Ok(info)
}
}

View File

@@ -5,10 +5,9 @@ use super::{
errors::{ApiError, GetAuthInfoError, WakeComputeError},
ApiCaches, AuthInfo, CachedNodeInfo, ConsoleReqExtra, NodeInfo,
};
use crate::{auth::ClientCredentials, compute, http, scram};
use crate::{auth::ClientCredentials, http, scram};
use async_trait::async_trait;
use futures::TryFutureExt;
use tokio_postgres::config::SslMode;
use tracing::{error, info, info_span, warn, Instrument};
#[derive(Clone)]
@@ -91,22 +90,9 @@ impl Api {
let response = self.endpoint.execute(request).await?;
let body = parse_body::<WakeCompute>(response).await?;
// Unfortunately, ownership won't let us use `Option::ok_or` here.
let (host, port) = match parse_host_port(&body.address) {
None => return Err(WakeComputeError::BadComputeAddress(body.address)),
Some(x) => x,
};
// Don't set anything but host and port! This config will be cached.
// We'll set username and such later using the startup message.
// TODO: add more type safety (in progress).
let mut config = compute::ConnCfg::new();
config.host(host).port(port).ssl_mode(SslMode::Disable); // TLS is not configured on compute nodes.
let node = NodeInfo {
config,
aux: body.aux.into(),
allow_self_signed_compute: false,
address: body.address,
aux: body.aux,
};
Ok(node)
@@ -141,13 +127,15 @@ impl super::Api for Api {
// The connection info remains the same during that period of time,
// which means that we might cache it to reduce the load and latency.
if let Some(cached) = self.caches.node_info.get(key) {
info!(key = key, "found cached compute node info");
info!(key, "found cached compute node info");
return Ok(cached);
}
let node = self.do_wake_compute(extra, creds).await?;
let (_, cached) = self.caches.node_info.insert(key.into(), node);
info!(key = key, "created a cache entry for compute node info");
let info = self.do_wake_compute(extra, creds).await?;
let owned_key = Box::<str>::from(key);
let (_, cached) = self.caches.node_info.insert(owned_key.into(), info.into());
info!(key, "created a cache entry for compute node info");
Ok(cached)
}
@@ -177,20 +165,3 @@ async fn parse_body<T: for<'a> serde::Deserialize<'a>>(
error!("console responded with an error ({status}): {text}");
Err(ApiError::Console { status, text })
}
fn parse_host_port(input: &str) -> Option<(&str, u16)> {
let (host, port) = input.split_once(':')?;
Some((host, port.parse().ok()?))
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_parse_host_port() {
let (host, port) = parse_host_port("127.0.0.1:5432").expect("failed to parse");
assert_eq!(host, "127.0.0.1");
assert_eq!(port, 5432);
}
}

View File

@@ -4,7 +4,7 @@ mod tests;
use crate::{
auth::{self, backend::AuthSuccess},
cancellation::{self, CancelMap},
compute::{self, PostgresConnection},
compute::{self, ComputeNode, PostgresConnection},
config::{ProxyConfig, TlsConfig},
console::{self, messages::MetricsAuxInfo},
error::io_error,
@@ -155,7 +155,7 @@ pub async fn handle_ws_client(
async { result }.or_else(|e| stream.throw_error(e)).await?
};
let client = Client::new(stream, creds, &params, session_id, false);
let client = Client::new(stream, creds, &params, session_id);
cancel_map
.with_session(|session| client.connect_to_db(session, true))
.await
@@ -194,15 +194,7 @@ async fn handle_client(
async { result }.or_else(|e| stream.throw_error(e)).await?
};
let allow_self_signed_compute = config.allow_self_signed_compute;
let client = Client::new(
stream,
creds,
&params,
session_id,
allow_self_signed_compute,
);
let client = Client::new(stream, creds, &params, session_id);
cancel_map
.with_session(|session| client.connect_to_db(session, false))
.await
@@ -283,61 +275,38 @@ async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
}
}
/// Try to connect to the compute node once.
#[tracing::instrument(name = "connect_once", skip_all)]
async fn connect_to_compute_once(
node_info: &console::CachedNodeInfo,
) -> Result<PostgresConnection, compute::ConnectionError> {
// If we couldn't connect, a cached connection info might be to blame
// (e.g. the compute node's address might've changed at the wrong time).
// Invalidate the cache entry (if any) to prevent subsequent errors.
let invalidate_cache = |_: &compute::ConnectionError| {
let is_cached = node_info.cached();
if is_cached {
warn!("invalidating stalled compute node info cache entry");
node_info.invalidate();
}
let label = match is_cached {
true => "compute_cached",
false => "compute_uncached",
};
NUM_CONNECTION_FAILURES.with_label_values(&[label]).inc();
};
let allow_self_signed_compute = node_info.allow_self_signed_compute;
node_info
.config
.connect(allow_self_signed_compute)
.inspect_err(invalidate_cache)
.await
}
/// Try to connect to the compute node, retrying if necessary.
/// This function might update `node_info`, so we take it by `&mut`.
#[tracing::instrument(skip_all)]
async fn connect_to_compute(
node_info: &mut console::CachedNodeInfo,
node_info: &mut ComputeNode,
params: &StartupMessageParams,
extra: &console::ConsoleReqExtra<'_>,
creds: &auth::BackendType<'_, auth::ClientCredentials<'_>>,
) -> Result<PostgresConnection, compute::ConnectionError> {
let mut num_retries: usize = NUM_RETRIES_WAKE_COMPUTE;
let mut num_retries = NUM_RETRIES_WAKE_COMPUTE;
loop {
// Apply startup params to the (possibly, cached) compute node info.
node_info.config.set_startup_params(params);
match connect_to_compute_once(node_info).await {
let mut config = node_info.to_conn_config();
config.set_startup_params(params);
match config.connect().await {
Err(e) if num_retries > 0 => {
info!("compute node's state has changed; requesting a wake-up");
match creds.wake_compute(extra).map_err(io_error).await? {
let label = match node_info.invalidate() {
true => "compute_cached",
false => "compute_uncached",
};
NUM_CONNECTION_FAILURES.with_label_values(&[label]).inc();
let res = creds.wake_compute(extra).map_err(io_error).await?;
match (res, &node_info) {
// Update `node_info` and try one more time.
Some(mut new) => {
new.config.reuse_password(&node_info.config);
*node_info = new;
(Some(new), ComputeNode::Static { password, .. }) => {
*node_info = ComputeNode::Static {
password: password.to_owned(),
info: new,
}
}
// Link auth doesn't work that way, so we just exit.
None => return Err(e),
_ => return Err(e),
}
}
other => return other,
@@ -430,8 +399,6 @@ struct Client<'a, S> {
params: &'a StartupMessageParams,
/// Unique connection ID.
session_id: uuid::Uuid,
/// Allow self-signed certificates (for testing).
allow_self_signed_compute: bool,
}
impl<'a, S> Client<'a, S> {
@@ -441,14 +408,12 @@ impl<'a, S> Client<'a, S> {
creds: auth::BackendType<'a, auth::ClientCredentials<'a>>,
params: &'a StartupMessageParams,
session_id: uuid::Uuid,
allow_self_signed_compute: bool,
) -> Self {
Self {
stream,
creds,
params,
session_id,
allow_self_signed_compute,
}
}
}
@@ -468,7 +433,6 @@ impl<S: AsyncRead + AsyncWrite + Unpin> Client<'_, S> {
mut creds,
params,
session_id,
allow_self_signed_compute,
} = self;
let extra = console::ConsoleReqExtra {
@@ -491,19 +455,19 @@ impl<S: AsyncRead + AsyncWrite + Unpin> Client<'_, S> {
value: mut node_info,
} = auth_result;
node_info.allow_self_signed_compute = allow_self_signed_compute;
let mut node = connect_to_compute(&mut node_info, params, &extra, &creds)
.or_else(|e| stream.throw_error(e))
.await?;
prepare_client_connection(&node, reported_auth_ok, session, &mut stream).await?;
// Before proxy passing, forward to compute whatever data is left in the
// PqStream input buffer. Normally there is none, but our serverless npm
// driver in pipeline mode sends startup, password and first query
// immediately after opening the connection.
let (stream, read_buf) = stream.into_inner();
node.stream.write_all(&read_buf).await?;
proxy_pass(stream, node.stream, &node_info.aux).await
proxy_pass(stream, node.stream, node_info.metrics_aux_info()).await
}
}