mirror of
https://github.com/neondatabase/neon.git
synced 2026-05-19 14:10:37 +00:00
New cache impl
This commit is contained in:
committed by
Vadim Kharitonov
parent
f3769d45ae
commit
9b99d4caa9
21
Cargo.lock
generated
21
Cargo.lock
generated
@@ -3104,6 +3104,7 @@ dependencies = [
|
||||
"prometheus",
|
||||
"rand",
|
||||
"rcgen",
|
||||
"ref-cast",
|
||||
"regex",
|
||||
"reqwest",
|
||||
"reqwest-middleware",
|
||||
@@ -3228,6 +3229,26 @@ dependencies = [
|
||||
"bitflags",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ref-cast"
|
||||
version = "1.0.16"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f43faa91b1c8b36841ee70e97188a869d37ae21759da6846d4be66de5bf7b12c"
|
||||
dependencies = [
|
||||
"ref-cast-impl",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ref-cast-impl"
|
||||
version = "1.0.16"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8d2275aab483050ab2a7364c1a46604865ee7d6906684e08db0f090acf74f9e7"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.15",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "regex"
|
||||
version = "1.7.3"
|
||||
|
||||
@@ -77,6 +77,7 @@ pin-project-lite = "0.2"
|
||||
prometheus = {version = "0.13", default_features=false, features = ["process"]} # removes protobuf dependency
|
||||
prost = "0.11"
|
||||
rand = "0.8"
|
||||
ref-cast = "1.0"
|
||||
regex = "1.4"
|
||||
reqwest = { version = "0.11", default-features = false, features = ["rustls-tls"] }
|
||||
reqwest-tracing = { version = "0.4.0", features = ["opentelemetry_0_18"] }
|
||||
|
||||
@@ -35,6 +35,7 @@ postgres_backend.workspace = true
|
||||
pq_proto.workspace = true
|
||||
prometheus.workspace = true
|
||||
rand.workspace = true
|
||||
ref-cast.workspace = true
|
||||
regex.workspace = true
|
||||
reqwest = { workspace = true, features = ["json"] }
|
||||
reqwest-middleware.workspace = true
|
||||
@@ -50,9 +51,9 @@ socket2.workspace = true
|
||||
sync_wrapper.workspace = true
|
||||
thiserror.workspace = true
|
||||
tls-listener.workspace = true
|
||||
tokio = { workspace = true, features = ["signal"] }
|
||||
tokio-postgres.workspace = true
|
||||
tokio-rustls.workspace = true
|
||||
tokio = { workspace = true, features = ["signal"] }
|
||||
tracing-opentelemetry.workspace = true
|
||||
tracing-subscriber.workspace = true
|
||||
tracing-utils.workspace = true
|
||||
|
||||
@@ -6,6 +6,7 @@ pub use link::LinkAuthError;
|
||||
|
||||
use crate::{
|
||||
auth::{self, ClientCredentials},
|
||||
compute::ComputeNode,
|
||||
console::{
|
||||
self,
|
||||
provider::{CachedNodeInfo, ConsoleReqExtra},
|
||||
@@ -114,7 +115,7 @@ async fn auth_quirks(
|
||||
creds: &mut ClientCredentials<'_>,
|
||||
client: &mut stream::PqStream<impl AsyncRead + AsyncWrite + Unpin>,
|
||||
allow_cleartext: bool,
|
||||
) -> auth::Result<AuthSuccess<CachedNodeInfo>> {
|
||||
) -> auth::Result<AuthSuccess<ComputeNode>> {
|
||||
// If there's no project so far, that entails that client doesn't
|
||||
// support SNI or other means of passing the endpoint (project) name.
|
||||
// We now expect to see a very specific payload in the place of password.
|
||||
@@ -156,7 +157,7 @@ impl BackendType<'_, ClientCredentials<'_>> {
|
||||
extra: &ConsoleReqExtra<'_>,
|
||||
client: &mut stream::PqStream<impl AsyncRead + AsyncWrite + Unpin>,
|
||||
allow_cleartext: bool,
|
||||
) -> auth::Result<AuthSuccess<CachedNodeInfo>> {
|
||||
) -> auth::Result<AuthSuccess<ComputeNode>> {
|
||||
use BackendType::*;
|
||||
|
||||
let res = match self {
|
||||
@@ -184,9 +185,7 @@ impl BackendType<'_, ClientCredentials<'_>> {
|
||||
Link(url) => {
|
||||
info!("performing link authentication");
|
||||
|
||||
link::authenticate(url, client)
|
||||
.await?
|
||||
.map(CachedNodeInfo::new_uncached)
|
||||
link::authenticate(url, client).await?
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
use super::AuthSuccess;
|
||||
use crate::{
|
||||
auth::{self, AuthFlow, ClientCredentials},
|
||||
compute,
|
||||
console::{self, AuthInfo, CachedNodeInfo, ConsoleReqExtra},
|
||||
compute::{self, ComputeNode, Password},
|
||||
console::{self, AuthInfo, ConsoleReqExtra},
|
||||
sasl, scram,
|
||||
stream::PqStream,
|
||||
};
|
||||
@@ -14,7 +14,7 @@ pub(super) async fn authenticate(
|
||||
extra: &ConsoleReqExtra<'_>,
|
||||
creds: &ClientCredentials<'_>,
|
||||
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
|
||||
) -> auth::Result<AuthSuccess<CachedNodeInfo>> {
|
||||
) -> auth::Result<AuthSuccess<ComputeNode>> {
|
||||
info!("fetching user's authentication info");
|
||||
let info = api.get_auth_info(extra, creds).await?.unwrap_or_else(|| {
|
||||
// If we don't have an authentication secret, we mock one to
|
||||
@@ -25,7 +25,7 @@ pub(super) async fn authenticate(
|
||||
});
|
||||
|
||||
let flow = AuthFlow::new(client);
|
||||
let scram_keys = match info {
|
||||
let keys = match info {
|
||||
AuthInfo::Md5(_) => {
|
||||
info!("auth endpoint chooses MD5");
|
||||
return Err(auth::AuthError::bad_auth_method("MD5"));
|
||||
@@ -41,21 +41,20 @@ pub(super) async fn authenticate(
|
||||
}
|
||||
};
|
||||
|
||||
Some(compute::ScramKeys {
|
||||
compute::ScramKeys {
|
||||
client_key: client_key.as_bytes(),
|
||||
server_key: secret.server_key.as_bytes(),
|
||||
})
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let mut node = api.wake_compute(extra, creds).await?;
|
||||
if let Some(keys) = scram_keys {
|
||||
use tokio_postgres::config::AuthKeys;
|
||||
node.config.auth_keys(AuthKeys::ScramSha256(keys));
|
||||
}
|
||||
let info = api.wake_compute(extra, creds).await?;
|
||||
|
||||
Ok(AuthSuccess {
|
||||
reported_auth_ok: false,
|
||||
value: node,
|
||||
value: ComputeNode::Static {
|
||||
password: Password::ScramKeys(keys),
|
||||
info,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,10 +1,8 @@
|
||||
use super::AuthSuccess;
|
||||
use crate::{
|
||||
auth::{self, AuthFlow, ClientCredentials},
|
||||
console::{
|
||||
self,
|
||||
provider::{CachedNodeInfo, ConsoleReqExtra},
|
||||
},
|
||||
compute::{ComputeNode, Password},
|
||||
console::{self, provider::ConsoleReqExtra},
|
||||
stream,
|
||||
};
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
@@ -19,7 +17,7 @@ pub async fn cleartext_hack(
|
||||
extra: &ConsoleReqExtra<'_>,
|
||||
creds: &mut ClientCredentials<'_>,
|
||||
client: &mut stream::PqStream<impl AsyncRead + AsyncWrite + Unpin>,
|
||||
) -> auth::Result<AuthSuccess<CachedNodeInfo>> {
|
||||
) -> auth::Result<AuthSuccess<ComputeNode>> {
|
||||
warn!("cleartext auth flow override is enabled, proceeding");
|
||||
let password = AuthFlow::new(client)
|
||||
.begin(auth::CleartextPassword)
|
||||
@@ -27,13 +25,15 @@ pub async fn cleartext_hack(
|
||||
.authenticate()
|
||||
.await?;
|
||||
|
||||
let mut node = api.wake_compute(extra, creds).await?;
|
||||
node.config.password(password);
|
||||
let info = api.wake_compute(extra, creds).await?;
|
||||
|
||||
// Report tentative success; compute node will check the password anyway.
|
||||
Ok(AuthSuccess {
|
||||
reported_auth_ok: false,
|
||||
value: node,
|
||||
value: ComputeNode::Static {
|
||||
password: Password::ClearText(password),
|
||||
info,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
@@ -44,7 +44,7 @@ pub async fn password_hack(
|
||||
extra: &ConsoleReqExtra<'_>,
|
||||
creds: &mut ClientCredentials<'_>,
|
||||
client: &mut stream::PqStream<impl AsyncRead + AsyncWrite + Unpin>,
|
||||
) -> auth::Result<AuthSuccess<CachedNodeInfo>> {
|
||||
) -> auth::Result<AuthSuccess<ComputeNode>> {
|
||||
warn!("project not specified, resorting to the password hack auth flow");
|
||||
let payload = AuthFlow::new(client)
|
||||
.begin(auth::PasswordHack)
|
||||
@@ -55,12 +55,14 @@ pub async fn password_hack(
|
||||
info!(project = &payload.endpoint, "received missing parameter");
|
||||
creds.project = Some(payload.endpoint);
|
||||
|
||||
let mut node = api.wake_compute(extra, creds).await?;
|
||||
node.config.password(payload.password);
|
||||
let info = api.wake_compute(extra, creds).await?;
|
||||
|
||||
// Report tentative success; compute node will check the password anyway.
|
||||
Ok(AuthSuccess {
|
||||
reported_auth_ok: false,
|
||||
value: node,
|
||||
value: ComputeNode::Static {
|
||||
password: Password::ClearText(payload.password),
|
||||
info,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,15 +1,10 @@
|
||||
use super::AuthSuccess;
|
||||
use crate::{
|
||||
auth, compute,
|
||||
console::{self, provider::NodeInfo},
|
||||
error::UserFacingError,
|
||||
stream::PqStream,
|
||||
waiters,
|
||||
auth, compute::ComputeNode, console, error::UserFacingError, stream::PqStream, waiters,
|
||||
};
|
||||
use pq_proto::BeMessage as Be;
|
||||
use thiserror::Error;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tokio_postgres::config::SslMode;
|
||||
use tracing::{info, info_span};
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
@@ -57,12 +52,12 @@ pub fn new_psql_session_id() -> String {
|
||||
pub(super) async fn authenticate(
|
||||
link_uri: &reqwest::Url,
|
||||
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
|
||||
) -> auth::Result<AuthSuccess<NodeInfo>> {
|
||||
) -> auth::Result<AuthSuccess<ComputeNode>> {
|
||||
let psql_session_id = new_psql_session_id();
|
||||
let span = info_span!("link", psql_session_id = &psql_session_id);
|
||||
let span = info_span!("link", psql_session_id);
|
||||
let greeting = hello_message(link_uri, &psql_session_id);
|
||||
|
||||
let db_info = console::mgmt::with_waiter(psql_session_id, |waiter| async {
|
||||
let info = console::mgmt::with_waiter(psql_session_id, |waiter| async {
|
||||
// Give user a URL to spawn a new database.
|
||||
info!(parent: &span, "sending the auth URL to the user");
|
||||
client
|
||||
@@ -79,35 +74,8 @@ pub(super) async fn authenticate(
|
||||
|
||||
client.write_message_noflush(&Be::NoticeResponse("Connecting to database."))?;
|
||||
|
||||
// This config should be self-contained, because we won't
|
||||
// take username or dbname from client's startup message.
|
||||
let mut config = compute::ConnCfg::new();
|
||||
config
|
||||
.host(&db_info.host)
|
||||
.port(db_info.port)
|
||||
.dbname(&db_info.dbname)
|
||||
.user(&db_info.user);
|
||||
|
||||
// Backwards compatibility. pg_sni_proxy uses "--" in domain names
|
||||
// while direct connections do not. Once we migrate to pg_sni_proxy
|
||||
// everywhere, we can remove this.
|
||||
if db_info.host.contains("--") {
|
||||
// we need TLS connection with SNI info to properly route it
|
||||
config.ssl_mode(SslMode::Require);
|
||||
} else {
|
||||
config.ssl_mode(SslMode::Disable);
|
||||
}
|
||||
|
||||
if let Some(password) = db_info.password {
|
||||
config.password(password.as_ref());
|
||||
}
|
||||
|
||||
Ok(AuthSuccess {
|
||||
reported_auth_ok: true,
|
||||
value: NodeInfo {
|
||||
config,
|
||||
aux: db_info.aux.into(),
|
||||
allow_self_signed_compute: false, // caller may override
|
||||
},
|
||||
value: ComputeNode::Link(info),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ use proxy::metrics;
|
||||
use anyhow::bail;
|
||||
use clap::{self, Arg};
|
||||
use proxy::config::{self, ProxyConfig};
|
||||
use std::sync::atomic::Ordering;
|
||||
use std::{borrow::Cow, net::SocketAddr};
|
||||
use tokio::net::TcpListener;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
@@ -103,6 +104,7 @@ fn build_config(args: &clap::ArgMatches) -> anyhow::Result<&'static ProxyConfig>
|
||||
.parse()?;
|
||||
if allow_self_signed_compute {
|
||||
warn!("allowing self-signed compute certificates");
|
||||
proxy::compute::ALLOW_SELF_SIGNED_COMPUTE.store(true, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
let metric_collection = match (
|
||||
@@ -154,7 +156,6 @@ fn build_config(args: &clap::ArgMatches) -> anyhow::Result<&'static ProxyConfig>
|
||||
tls_config,
|
||||
auth_backend,
|
||||
metric_collection,
|
||||
allow_self_signed_compute,
|
||||
}));
|
||||
|
||||
Ok(config)
|
||||
|
||||
@@ -1,304 +1,117 @@
|
||||
use std::{
|
||||
borrow::Borrow,
|
||||
hash::Hash,
|
||||
ops::{Deref, DerefMut},
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
use tracing::debug;
|
||||
use std::{any::Any, sync::Arc, time::Instant};
|
||||
|
||||
// This seems to make more sense than `lru` or `cached`:
|
||||
//
|
||||
// * `near/nearcore` ditched `cached` in favor of `lru`
|
||||
// (https://github.com/near/nearcore/issues?q=is%3Aissue+lru+is%3Aclosed).
|
||||
//
|
||||
// * `lru` methods use an obscure `KeyRef` type in their contraints (which is deliberately excluded from docs).
|
||||
// This severely hinders its usage both in terms of creating wrappers and supported key types.
|
||||
//
|
||||
// On the other hand, `hashlink` has good download stats and appears to be maintained.
|
||||
use hashlink::{linked_hash_map::RawEntryMut, LruCache};
|
||||
/// A variant of LRU where every entry has a TTL.
|
||||
pub mod timed_lru;
|
||||
pub use timed_lru::TimedLru;
|
||||
|
||||
/// A generic trait which exposes types of cache's key and value,
|
||||
/// as well as the notion of cache entry invalidation.
|
||||
/// This is useful for [`timed_lru::Cached`].
|
||||
pub trait Cache {
|
||||
/// Entry's key.
|
||||
type Key;
|
||||
/// Useful type aliases.
|
||||
pub mod types {
|
||||
pub type Cached<T> = super::Cached<'static, T>;
|
||||
}
|
||||
|
||||
/// Entry's value.
|
||||
type Value;
|
||||
/// Lookup information for cache entry invalidation.
|
||||
#[derive(Clone)]
|
||||
pub struct LookupInfo<K> {
|
||||
/// Cache entry creation time.
|
||||
/// We use this during invalidation lookups to prevent eviction of a newer
|
||||
/// entry sharing the same key (it might've been inserted by a different
|
||||
/// task after we got the entry we're trying to invalidate now).
|
||||
created_at: Instant,
|
||||
|
||||
/// Used for entry invalidation.
|
||||
type LookupInfo<Key>;
|
||||
/// Search by this key.
|
||||
key: K,
|
||||
}
|
||||
|
||||
/// This type incapsulates everything needed for cache entry invalidation.
|
||||
/// For convenience, we completely erase the types of a cache ref and a key.
|
||||
/// This lets us store multiple tokens in a homogeneous collection (e.g. Vec).
|
||||
#[derive(Clone)]
|
||||
pub struct InvalidationToken<'a> {
|
||||
// TODO: allow more than one type of references (e.g. Arc) if it's ever needed.
|
||||
cache: &'a (dyn Cache + Sync + Send),
|
||||
info: LookupInfo<Arc<dyn Any + Sync + Send>>,
|
||||
}
|
||||
|
||||
impl InvalidationToken<'_> {
|
||||
/// Invalidate a corresponding cache entry.
|
||||
pub fn invalidate(&self) {
|
||||
let info = LookupInfo {
|
||||
created_at: self.info.created_at,
|
||||
key: self.info.key.as_ref(),
|
||||
};
|
||||
self.cache.invalidate_entry(info);
|
||||
}
|
||||
}
|
||||
|
||||
/// A combination of a cache entry and its invalidation token.
|
||||
/// Makes it easier to see how those two are connected.
|
||||
#[derive(Clone)]
|
||||
pub struct Cached<'a, T> {
|
||||
pub token: Option<InvalidationToken<'a>>,
|
||||
pub value: Arc<T>,
|
||||
}
|
||||
|
||||
impl<T> Cached<'_, T> {
|
||||
/// Place any entry into this wrapper; invalidation will be a no-op.
|
||||
pub fn new_uncached(value: T) -> Self {
|
||||
Self {
|
||||
token: None,
|
||||
value: value.into(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Invalidate a corresponding cache entry.
|
||||
pub fn invalidate(&self) {
|
||||
if let Some(token) = &self.token {
|
||||
token.invalidate();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> std::ops::Deref for Cached<'_, T> {
|
||||
type Target = T;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.value
|
||||
}
|
||||
}
|
||||
|
||||
/// This trait captures the notion of cache entry invalidation.
|
||||
/// It doesn't have any associated types because we use dyn-based type erasure.
|
||||
trait Cache {
|
||||
/// Invalidate an entry using a lookup info.
|
||||
/// We don't have an empty default impl because it's error-prone.
|
||||
fn invalidate(&self, _: &Self::LookupInfo<Self::Key>);
|
||||
fn invalidate_entry(&self, info: LookupInfo<&(dyn Any + Send + Sync)>);
|
||||
}
|
||||
|
||||
impl<C: Cache> Cache for &C {
|
||||
type Key = C::Key;
|
||||
type Value = C::Value;
|
||||
type LookupInfo<Key> = C::LookupInfo<Key>;
|
||||
|
||||
fn invalidate(&self, info: &Self::LookupInfo<Self::Key>) {
|
||||
C::invalidate(self, info)
|
||||
}
|
||||
}
|
||||
|
||||
pub use timed_lru::TimedLru;
|
||||
pub mod timed_lru {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
/// An implementation of timed LRU cache with fixed capacity.
|
||||
/// Key properties:
|
||||
///
|
||||
/// * Whenever a new entry is inserted, the least recently accessed one is evicted.
|
||||
/// The cache also keeps track of entry's insertion time (`created_at`) and TTL (`expires_at`).
|
||||
///
|
||||
/// * When the entry is about to be retrieved, we check its expiration timestamp.
|
||||
/// If the entry has expired, we remove it from the cache; Otherwise we bump the
|
||||
/// expiration timestamp (e.g. +5mins) and change its place in LRU list to prolong
|
||||
/// its existence.
|
||||
///
|
||||
/// * There's an API for immediate invalidation (removal) of a cache entry;
|
||||
/// It's useful in case we know for sure that the entry is no longer correct.
|
||||
/// See [`timed_lru::LookupInfo`] & [`timed_lru::Cached`] for more information.
|
||||
///
|
||||
/// * Expired entries are kept in the cache, until they are evicted by the LRU policy,
|
||||
/// or by a successful lookup (i.e. the entry hasn't expired yet).
|
||||
/// There is no background job to reap the expired records.
|
||||
///
|
||||
/// * It's possible for an entry that has not yet expired entry to be evicted
|
||||
/// before expired items. That's a bit wasteful, but probably fine in practice.
|
||||
pub struct TimedLru<K, V> {
|
||||
/// Cache's name for tracing.
|
||||
name: &'static str,
|
||||
|
||||
/// The underlying cache implementation.
|
||||
cache: parking_lot::Mutex<LruCache<K, Entry<V>>>,
|
||||
|
||||
/// Default time-to-live of a single entry.
|
||||
ttl: Duration,
|
||||
#[test]
|
||||
fn trivial_properties_of_cached() {
|
||||
let cached = Cached::new_uncached(0);
|
||||
assert_eq!(*cached, 0);
|
||||
cached.invalidate();
|
||||
}
|
||||
|
||||
impl<K: Hash + Eq, V> Cache for TimedLru<K, V> {
|
||||
type Key = K;
|
||||
type Value = V;
|
||||
type LookupInfo<Key> = LookupInfo<Key>;
|
||||
#[test]
|
||||
fn invalidation_token_type_erasure() {
|
||||
let lifetime = std::time::Duration::from_secs(10);
|
||||
let foo = TimedLru::<u32, u32>::new("foo", 128, lifetime);
|
||||
let bar = TimedLru::<String, usize>::new("bar", 128, lifetime);
|
||||
|
||||
fn invalidate(&self, info: &Self::LookupInfo<K>) {
|
||||
self.invalidate_raw(info)
|
||||
}
|
||||
}
|
||||
let (_, x) = foo.insert(100.into(), 0.into());
|
||||
let (_, y) = bar.insert(String::new().into(), 404.into());
|
||||
|
||||
struct Entry<T> {
|
||||
created_at: Instant,
|
||||
expires_at: Instant,
|
||||
value: T,
|
||||
}
|
||||
|
||||
impl<K: Hash + Eq, V> TimedLru<K, V> {
|
||||
/// Construct a new LRU cache with timed entries.
|
||||
pub fn new(name: &'static str, capacity: usize, ttl: Duration) -> Self {
|
||||
Self {
|
||||
name,
|
||||
cache: LruCache::new(capacity).into(),
|
||||
ttl,
|
||||
}
|
||||
// Invalidation tokens should be cloneable and homogeneous (same type).
|
||||
let tokens = [x.token.clone().unwrap(), y.token.clone().unwrap()];
|
||||
for token in tokens {
|
||||
token.invalidate();
|
||||
}
|
||||
|
||||
/// Drop an entry from the cache if it's outdated.
|
||||
#[tracing::instrument(level = "debug", fields(cache = self.name), skip_all)]
|
||||
fn invalidate_raw(&self, info: &LookupInfo<K>) {
|
||||
let now = Instant::now();
|
||||
|
||||
// Do costly things before taking the lock.
|
||||
let mut cache = self.cache.lock();
|
||||
let raw_entry = match cache.raw_entry_mut().from_key(&info.key) {
|
||||
RawEntryMut::Vacant(_) => return,
|
||||
RawEntryMut::Occupied(x) => x,
|
||||
};
|
||||
|
||||
// Remove the entry if it was created prior to lookup timestamp.
|
||||
let entry = raw_entry.get();
|
||||
let (created_at, expires_at) = (entry.created_at, entry.expires_at);
|
||||
let should_remove = created_at <= info.created_at || expires_at <= now;
|
||||
|
||||
if should_remove {
|
||||
raw_entry.remove();
|
||||
}
|
||||
|
||||
drop(cache); // drop lock before logging
|
||||
debug!(
|
||||
created_at = format_args!("{created_at:?}"),
|
||||
expires_at = format_args!("{expires_at:?}"),
|
||||
entry_removed = should_remove,
|
||||
"processed a cache entry invalidation event"
|
||||
);
|
||||
}
|
||||
|
||||
/// Try retrieving an entry by its key, then execute `extract` if it exists.
|
||||
#[tracing::instrument(level = "debug", fields(cache = self.name), skip_all)]
|
||||
fn get_raw<Q, R>(&self, key: &Q, extract: impl FnOnce(&K, &Entry<V>) -> R) -> Option<R>
|
||||
where
|
||||
K: Borrow<Q>,
|
||||
Q: Hash + Eq + ?Sized,
|
||||
{
|
||||
let now = Instant::now();
|
||||
let deadline = now.checked_add(self.ttl).expect("time overflow");
|
||||
|
||||
// Do costly things before taking the lock.
|
||||
let mut cache = self.cache.lock();
|
||||
let mut raw_entry = match cache.raw_entry_mut().from_key(key) {
|
||||
RawEntryMut::Vacant(_) => return None,
|
||||
RawEntryMut::Occupied(x) => x,
|
||||
};
|
||||
|
||||
// Immeditely drop the entry if it has expired.
|
||||
let entry = raw_entry.get();
|
||||
if entry.expires_at <= now {
|
||||
raw_entry.remove();
|
||||
return None;
|
||||
}
|
||||
|
||||
let value = extract(raw_entry.key(), entry);
|
||||
let (created_at, expires_at) = (entry.created_at, entry.expires_at);
|
||||
|
||||
// Update the deadline and the entry's position in the LRU list.
|
||||
raw_entry.get_mut().expires_at = deadline;
|
||||
raw_entry.to_back();
|
||||
|
||||
drop(cache); // drop lock before logging
|
||||
debug!(
|
||||
created_at = format_args!("{created_at:?}"),
|
||||
old_expires_at = format_args!("{expires_at:?}"),
|
||||
new_expires_at = format_args!("{deadline:?}"),
|
||||
"accessed a cache entry"
|
||||
);
|
||||
|
||||
Some(value)
|
||||
}
|
||||
|
||||
/// Insert an entry to the cache. If an entry with the same key already
|
||||
/// existed, return the previous value and its creation timestamp.
|
||||
#[tracing::instrument(level = "debug", fields(cache = self.name), skip_all)]
|
||||
fn insert_raw(&self, key: K, value: V) -> (Instant, Option<V>) {
|
||||
let created_at = Instant::now();
|
||||
let expires_at = created_at.checked_add(self.ttl).expect("time overflow");
|
||||
|
||||
let entry = Entry {
|
||||
created_at,
|
||||
expires_at,
|
||||
value,
|
||||
};
|
||||
|
||||
// Do costly things before taking the lock.
|
||||
let old = self
|
||||
.cache
|
||||
.lock()
|
||||
.insert(key, entry)
|
||||
.map(|entry| entry.value);
|
||||
|
||||
debug!(
|
||||
created_at = format_args!("{created_at:?}"),
|
||||
expires_at = format_args!("{expires_at:?}"),
|
||||
replaced = old.is_some(),
|
||||
"created a cache entry"
|
||||
);
|
||||
|
||||
(created_at, old)
|
||||
}
|
||||
}
|
||||
|
||||
impl<K: Hash + Eq + Clone, V: Clone> TimedLru<K, V> {
|
||||
pub fn insert(&self, key: K, value: V) -> (Option<V>, Cached<&Self>) {
|
||||
let (created_at, old) = self.insert_raw(key.clone(), value.clone());
|
||||
|
||||
let cached = Cached {
|
||||
token: Some((self, LookupInfo { created_at, key })),
|
||||
value,
|
||||
};
|
||||
|
||||
(old, cached)
|
||||
}
|
||||
}
|
||||
|
||||
impl<K: Hash + Eq, V: Clone> TimedLru<K, V> {
|
||||
/// Retrieve a cached entry in convenient wrapper.
|
||||
pub fn get<Q>(&self, key: &Q) -> Option<timed_lru::Cached<&Self>>
|
||||
where
|
||||
K: Borrow<Q> + Clone,
|
||||
Q: Hash + Eq + ?Sized,
|
||||
{
|
||||
self.get_raw(key, |key, entry| {
|
||||
let info = LookupInfo {
|
||||
created_at: entry.created_at,
|
||||
key: key.clone(),
|
||||
};
|
||||
|
||||
Cached {
|
||||
token: Some((self, info)),
|
||||
value: entry.value.clone(),
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Lookup information for key invalidation.
|
||||
pub struct LookupInfo<K> {
|
||||
/// Time of creation of a cache [`Entry`].
|
||||
/// We use this during invalidation lookups to prevent eviction of a newer
|
||||
/// entry sharing the same key (it might've been inserted by a different
|
||||
/// task after we got the entry we're trying to invalidate now).
|
||||
created_at: Instant,
|
||||
|
||||
/// Search by this key.
|
||||
key: K,
|
||||
}
|
||||
|
||||
/// Wrapper for convenient entry invalidation.
|
||||
pub struct Cached<C: Cache> {
|
||||
/// Cache + lookup info.
|
||||
token: Option<(C, C::LookupInfo<C::Key>)>,
|
||||
|
||||
/// The value itself.
|
||||
pub value: C::Value,
|
||||
}
|
||||
|
||||
impl<C: Cache> Cached<C> {
|
||||
/// Place any entry into this wrapper; invalidation will be a no-op.
|
||||
/// Unfortunately, rust doesn't let us implement [`From`] or [`Into`].
|
||||
pub fn new_uncached(value: impl Into<C::Value>) -> Self {
|
||||
Self {
|
||||
token: None,
|
||||
value: value.into(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Drop this entry from a cache if it's still there.
|
||||
pub fn invalidate(&self) {
|
||||
if let Some((cache, info)) = &self.token {
|
||||
cache.invalidate(info);
|
||||
}
|
||||
}
|
||||
|
||||
/// Tell if this entry is actually cached.
|
||||
pub fn cached(&self) -> bool {
|
||||
self.token.is_some()
|
||||
}
|
||||
}
|
||||
|
||||
impl<C: Cache> Deref for Cached<C> {
|
||||
type Target = C::Value;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.value
|
||||
}
|
||||
}
|
||||
|
||||
impl<C: Cache> DerefMut for Cached<C> {
|
||||
fn deref_mut(&mut self) -> &mut Self::Target {
|
||||
&mut self.value
|
||||
}
|
||||
// Values are still there.
|
||||
assert_eq!(*x, 0);
|
||||
assert_eq!(*y, 404);
|
||||
}
|
||||
}
|
||||
|
||||
272
proxy/src/cache/timed_lru.rs
vendored
Normal file
272
proxy/src/cache/timed_lru.rs
vendored
Normal file
@@ -0,0 +1,272 @@
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
|
||||
use super::{Cache, Cached, InvalidationToken, LookupInfo};
|
||||
use ref_cast::RefCast;
|
||||
use std::{
|
||||
any::Any,
|
||||
borrow::Borrow,
|
||||
hash::Hash,
|
||||
sync::Arc,
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
use tracing::debug;
|
||||
|
||||
// This seems to make more sense than `lru` or `cached`:
|
||||
//
|
||||
// * `near/nearcore` ditched `cached` in favor of `lru`
|
||||
// (https://github.com/near/nearcore/issues?q=is%3Aissue+lru+is%3Aclosed).
|
||||
//
|
||||
// * `lru` methods use an obscure `KeyRef` type in their contraints (which is deliberately excluded from docs).
|
||||
// This severely hinders its usage both in terms of creating wrappers and supported key types.
|
||||
//
|
||||
// On the other hand, `hashlink` has good download stats and appears to be maintained.
|
||||
use hashlink::{linked_hash_map::RawEntryMut, LruCache};
|
||||
|
||||
/// An implementation of timed LRU cache with fixed capacity.
|
||||
/// Key properties:
|
||||
///
|
||||
/// * Whenever a new entry is inserted, the least recently accessed one is evicted.
|
||||
/// The cache also keeps track of entry's insertion time (`created_at`) and TTL (`expires_at`).
|
||||
///
|
||||
/// * When the entry is about to be retrieved, we check its expiration timestamp.
|
||||
/// If the entry has expired, we remove it from the cache; Otherwise we bump the
|
||||
/// expiration timestamp (e.g. +5mins) and change its place in LRU list to prolong
|
||||
/// its existence.
|
||||
///
|
||||
/// * There's an API for immediate invalidation (removal) of a cache entry;
|
||||
/// It's useful in case we know for sure that the entry is no longer correct.
|
||||
/// See [`timed_lru::LookupInfo`] & [`timed_lru::Cached`] for more information.
|
||||
///
|
||||
/// * Expired entries are kept in the cache, until they are evicted by the LRU policy,
|
||||
/// or by a successful lookup (i.e. the entry hasn't expired yet).
|
||||
/// There is no background job to reap the expired records.
|
||||
///
|
||||
/// * It's possible for an entry that has not yet expired entry to be evicted
|
||||
/// before expired items. That's a bit wasteful, but probably fine in practice.
|
||||
pub struct TimedLru<K, V> {
|
||||
/// Cache's name for tracing.
|
||||
name: &'static str,
|
||||
|
||||
/// The underlying cache implementation.
|
||||
cache: parking_lot::Mutex<LruCache<Key<K>, Entry<V>>>,
|
||||
|
||||
/// Default time-to-live of a single entry.
|
||||
ttl: Duration,
|
||||
}
|
||||
|
||||
#[derive(RefCast, Hash, PartialEq, Eq)]
|
||||
#[repr(transparent)]
|
||||
struct Query<Q: ?Sized>(Q);
|
||||
|
||||
#[derive(Hash, PartialEq, Eq)]
|
||||
#[repr(transparent)]
|
||||
struct Key<T: ?Sized>(Arc<T>);
|
||||
|
||||
/// It's impossible to implement this without [`Key`] & [`Query`]:
|
||||
/// * We can't implement std traits for [`Arc`].
|
||||
/// * Even if we could, it'd conflict with `impl<T> Borrow<T> for T`.
|
||||
impl<Q, T> Borrow<Query<Q>> for Key<T>
|
||||
where
|
||||
Q: ?Sized,
|
||||
T: Borrow<Q>,
|
||||
{
|
||||
#[inline(always)]
|
||||
fn borrow(&self) -> &Query<Q> {
|
||||
RefCast::ref_cast(self.0.as_ref().borrow())
|
||||
}
|
||||
}
|
||||
|
||||
struct Entry<T> {
|
||||
created_at: Instant,
|
||||
expires_at: Instant,
|
||||
value: Arc<T>,
|
||||
}
|
||||
|
||||
impl<K: Hash + Eq, V> TimedLru<K, V> {
|
||||
/// Construct a new LRU cache with timed entries.
|
||||
pub fn new(name: &'static str, capacity: usize, ttl: Duration) -> Self {
|
||||
Self {
|
||||
name,
|
||||
cache: LruCache::new(capacity).into(),
|
||||
ttl,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the number of entries in the cache.
|
||||
/// Note that this method will not try to evict stale entries.
|
||||
pub fn size(&self) -> usize {
|
||||
self.cache.lock().len()
|
||||
}
|
||||
|
||||
/// Try retrieving an entry by its key, then execute `extract` if it exists.
|
||||
#[tracing::instrument(level = "debug", fields(cache = self.name), skip_all)]
|
||||
fn get_raw<Q, F, R>(&self, key: &Q, extract: F) -> Option<R>
|
||||
where
|
||||
K: Borrow<Q>,
|
||||
Q: Hash + Eq + ?Sized,
|
||||
F: FnOnce(&Arc<K>, &Entry<V>) -> R,
|
||||
{
|
||||
let key: &Query<Q> = RefCast::ref_cast(key);
|
||||
let now = Instant::now();
|
||||
let deadline = now.checked_add(self.ttl).expect("time overflow");
|
||||
|
||||
// Do costly things before taking the lock.
|
||||
let mut cache = self.cache.lock();
|
||||
let mut raw_entry = match cache.raw_entry_mut().from_key(key) {
|
||||
RawEntryMut::Vacant(_) => return None,
|
||||
RawEntryMut::Occupied(x) => x,
|
||||
};
|
||||
|
||||
// Immeditely drop the entry if it has expired.
|
||||
let entry = raw_entry.get();
|
||||
if entry.expires_at <= now {
|
||||
raw_entry.remove();
|
||||
return None;
|
||||
}
|
||||
|
||||
let value = extract(&raw_entry.key().0, entry);
|
||||
let (created_at, expires_at) = (entry.created_at, entry.expires_at);
|
||||
|
||||
// Update the deadline and the entry's position in the LRU list.
|
||||
raw_entry.get_mut().expires_at = deadline;
|
||||
raw_entry.to_back();
|
||||
|
||||
drop(cache); // drop lock before logging
|
||||
debug!(
|
||||
?created_at,
|
||||
old_expires_at = ?expires_at,
|
||||
new_expires_at = ?deadline,
|
||||
"accessed a cache entry"
|
||||
);
|
||||
|
||||
Some(value)
|
||||
}
|
||||
|
||||
/// Insert an entry to the cache. If an entry with the same key already
|
||||
/// existed, return the previous value and its creation timestamp.
|
||||
#[tracing::instrument(level = "debug", fields(cache = self.name), skip_all)]
|
||||
fn insert_raw(&self, key: Arc<K>, value: Arc<V>) -> (Option<Arc<V>>, Instant) {
|
||||
let created_at = Instant::now();
|
||||
let expires_at = created_at.checked_add(self.ttl).expect("time overflow");
|
||||
|
||||
let entry = Entry {
|
||||
created_at,
|
||||
expires_at,
|
||||
value,
|
||||
};
|
||||
|
||||
// Do costly things before taking the lock.
|
||||
let old = self
|
||||
.cache
|
||||
.lock()
|
||||
.insert(Key(key), entry)
|
||||
.map(|entry| entry.value);
|
||||
|
||||
debug!(
|
||||
?created_at,
|
||||
?expires_at,
|
||||
replaced = old.is_some(),
|
||||
"created a cache entry"
|
||||
);
|
||||
|
||||
(old, created_at)
|
||||
}
|
||||
}
|
||||
|
||||
/// Convenient wrappers for raw methods.
|
||||
impl<K: Hash + Eq + Sync + Send + 'static, V> TimedLru<K, V>
|
||||
where
|
||||
Self: Sync + Send,
|
||||
{
|
||||
pub fn get<Q>(&self, key: &Q) -> Option<Cached<V>>
|
||||
where
|
||||
K: Borrow<Q>,
|
||||
Q: Hash + Eq + ?Sized,
|
||||
{
|
||||
self.get_raw(key, |key, entry| {
|
||||
let info = LookupInfo {
|
||||
created_at: entry.created_at,
|
||||
key: Arc::clone(key),
|
||||
};
|
||||
|
||||
Cached {
|
||||
token: Some(Self::invalidation_token(self, info)),
|
||||
value: entry.value.clone(),
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
pub fn insert(&self, key: Arc<K>, value: Arc<V>) -> (Option<Arc<V>>, Cached<V>) {
|
||||
let (old, created_at) = self.insert_raw(key.clone(), value.clone());
|
||||
|
||||
let info = LookupInfo { created_at, key };
|
||||
let cached = Cached {
|
||||
token: Some(Self::invalidation_token(self, info)),
|
||||
value,
|
||||
};
|
||||
|
||||
(old, cached)
|
||||
}
|
||||
}
|
||||
|
||||
/// Implementation details of the entry invalidation machinery.
|
||||
impl<K: Hash + Eq + Sync + Send + 'static, V> TimedLru<K, V>
|
||||
where
|
||||
Self: Sync + Send,
|
||||
{
|
||||
/// This is a proper (safe) way to create an invalidation token for [`TimedLru`].
|
||||
fn invalidation_token(&self, info: LookupInfo<Arc<K>>) -> InvalidationToken<'_> {
|
||||
InvalidationToken {
|
||||
cache: self,
|
||||
info: LookupInfo {
|
||||
created_at: info.created_at,
|
||||
key: info.key,
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// This implementation depends on [`Self::invalidation_token`].
|
||||
impl<K: Hash + Eq + 'static, V> Cache for TimedLru<K, V> {
|
||||
fn invalidate_entry(&self, info: LookupInfo<&(dyn Any + Sync + Send)>) {
|
||||
let info = LookupInfo {
|
||||
created_at: info.created_at,
|
||||
// NOTE: it's important to downcast to the correct type!
|
||||
key: info.key.downcast_ref::<K>().expect("bad key type"),
|
||||
};
|
||||
self.invalidate_raw(info)
|
||||
}
|
||||
}
|
||||
|
||||
impl<K: Hash + Eq, V> TimedLru<K, V> {
|
||||
#[tracing::instrument(level = "debug", fields(cache = self.name), skip_all)]
|
||||
fn invalidate_raw(&self, info: LookupInfo<&K>) {
|
||||
let key: &Query<K> = RefCast::ref_cast(info.key);
|
||||
let now = Instant::now();
|
||||
|
||||
// Do costly things before taking the lock.
|
||||
let mut cache = self.cache.lock();
|
||||
let raw_entry = match cache.raw_entry_mut().from_key(key) {
|
||||
RawEntryMut::Vacant(_) => return,
|
||||
RawEntryMut::Occupied(x) => x,
|
||||
};
|
||||
|
||||
// Remove the entry if it was created prior to lookup timestamp.
|
||||
let entry = raw_entry.get();
|
||||
let (created_at, expires_at) = (entry.created_at, entry.expires_at);
|
||||
let should_remove = created_at <= info.created_at || expires_at <= now;
|
||||
|
||||
if should_remove {
|
||||
raw_entry.remove();
|
||||
}
|
||||
|
||||
drop(cache); // drop lock before logging
|
||||
debug!(
|
||||
?created_at,
|
||||
?expires_at,
|
||||
entry_removed = should_remove,
|
||||
"processed a cache entry invalidation event"
|
||||
);
|
||||
}
|
||||
}
|
||||
104
proxy/src/cache/timed_lru/tests.rs
vendored
Normal file
104
proxy/src/cache/timed_lru/tests.rs
vendored
Normal file
@@ -0,0 +1,104 @@
|
||||
use super::*;
|
||||
|
||||
/// Check that we can define the cache for certain types.
|
||||
#[test]
|
||||
fn definition() {
|
||||
// Check for trivial yet possible types.
|
||||
let cache = TimedLru::<(), ()>::new("test", 128, Duration::from_secs(0));
|
||||
let _ = cache.insert(Default::default(), Default::default());
|
||||
let _ = cache.get(&());
|
||||
|
||||
// Now for something less trivial.
|
||||
let cache = TimedLru::<String, String>::new("test", 128, Duration::from_secs(0));
|
||||
let _ = cache.insert(Default::default(), Default::default());
|
||||
let _ = cache.get(&String::default());
|
||||
let _ = cache.get("str should work");
|
||||
|
||||
// It should also work for non-cloneable values.
|
||||
struct NoClone;
|
||||
let cache = TimedLru::<Box<str>, NoClone>::new("test", 128, Duration::from_secs(0));
|
||||
let _ = cache.insert(Default::default(), NoClone.into());
|
||||
let _ = cache.get(&Box::<str>::from("boxed str"));
|
||||
let _ = cache.get("str should work");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn insert() {
|
||||
const CAPACITY: usize = 2;
|
||||
let cache = TimedLru::<String, u32>::new("test", CAPACITY, Duration::from_secs(10));
|
||||
assert_eq!(cache.size(), 0);
|
||||
|
||||
let key = Arc::new(String::from("key"));
|
||||
|
||||
let (old, cached) = cache.insert(key.clone(), 42.into());
|
||||
assert_eq!(old, None);
|
||||
assert_eq!(*cached, 42);
|
||||
assert_eq!(cache.size(), 1);
|
||||
|
||||
let (old, cached) = cache.insert(key, 1.into());
|
||||
assert_eq!(old.as_deref(), Some(&42));
|
||||
assert_eq!(*cached, 1);
|
||||
assert_eq!(cache.size(), 1);
|
||||
|
||||
let (old, cached) = cache.insert(Arc::new("N1".to_owned()), 10.into());
|
||||
assert_eq!(old, None);
|
||||
assert_eq!(*cached, 10);
|
||||
assert_eq!(cache.size(), 2);
|
||||
|
||||
let (old, cached) = cache.insert(Arc::new("N2".to_owned()), 20.into());
|
||||
assert_eq!(old, None);
|
||||
assert_eq!(*cached, 20);
|
||||
assert_eq!(cache.size(), 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn get_none() {
|
||||
let cache = TimedLru::<String, u32>::new("test", 2, Duration::from_secs(10));
|
||||
let cached = cache.get("missing");
|
||||
assert!(matches!(cached, None));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn invalidation_simple() {
|
||||
let cache = TimedLru::<String, u32>::new("test", 2, Duration::from_secs(10));
|
||||
let (_, cached) = cache.insert(String::from("key").into(), 100.into());
|
||||
assert_eq!(cache.size(), 1);
|
||||
|
||||
cached.invalidate();
|
||||
|
||||
assert_eq!(cache.size(), 0);
|
||||
assert!(matches!(cache.get("key"), None));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn invalidation_preserve_newer() {
|
||||
let cache = TimedLru::<String, u32>::new("test", 2, Duration::from_secs(10));
|
||||
let key = Arc::new(String::from("key"));
|
||||
|
||||
let (_, cached) = cache.insert(key.clone(), 100.into());
|
||||
assert_eq!(cache.size(), 1);
|
||||
let _ = cache.insert(key.clone(), 200.into());
|
||||
assert_eq!(cache.size(), 1);
|
||||
cached.invalidate();
|
||||
assert_eq!(cache.size(), 1);
|
||||
|
||||
let cached = cache.get(key.as_ref());
|
||||
assert_eq!(cached.as_deref(), Some(&200));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn auto_expiry() {
|
||||
let lifetime = Duration::from_millis(300);
|
||||
let cache = TimedLru::<String, u32>::new("test", 2, lifetime);
|
||||
|
||||
let key = Arc::new(String::from("key"));
|
||||
let _ = cache.insert(key.clone(), 42.into());
|
||||
|
||||
let cached = cache.get(key.as_ref());
|
||||
assert_eq!(cached.as_deref(), Some(&42));
|
||||
|
||||
std::thread::sleep(lifetime);
|
||||
|
||||
let cached = cache.get(key.as_ref());
|
||||
assert_eq!(cached.as_deref(), None);
|
||||
}
|
||||
@@ -1,13 +1,28 @@
|
||||
use crate::{auth::parse_endpoint_param, cancellation::CancelClosure, error::UserFacingError};
|
||||
use crate::{
|
||||
auth::parse_endpoint_param,
|
||||
cancellation::CancelClosure,
|
||||
console::messages::{DatabaseInfo, MetricsAuxInfo},
|
||||
console::CachedNodeInfo,
|
||||
error::UserFacingError,
|
||||
};
|
||||
use futures::{FutureExt, TryFutureExt};
|
||||
use itertools::Itertools;
|
||||
use pq_proto::StartupMessageParams;
|
||||
use std::{io, net::SocketAddr, time::Duration};
|
||||
use std::{
|
||||
io,
|
||||
net::SocketAddr,
|
||||
sync::atomic::{AtomicBool, Ordering},
|
||||
time::Duration,
|
||||
};
|
||||
use thiserror::Error;
|
||||
use tokio::net::TcpStream;
|
||||
use tokio_postgres::tls::MakeTlsConnect;
|
||||
use tracing::{error, info, warn};
|
||||
|
||||
/// Should we allow self-signed certificates in TLS connections?
|
||||
/// Most definitely, this shouldn't be allowed in production.
|
||||
pub static ALLOW_SELF_SIGNED_COMPUTE: AtomicBool = AtomicBool::new(false);
|
||||
|
||||
const COULD_NOT_CONNECT: &str = "Couldn't connect to compute node";
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
@@ -42,6 +57,88 @@ impl UserFacingError for ConnectionError {
|
||||
/// A pair of `ClientKey` & `ServerKey` for `SCRAM-SHA-256`.
|
||||
pub type ScramKeys = tokio_postgres::config::ScramKeys<32>;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub enum Password {
|
||||
/// A regular cleartext password.
|
||||
ClearText(Vec<u8>),
|
||||
/// A pair of `ClientKey` & `ServerKey` for `SCRAM-SHA-256`.
|
||||
ScramKeys(ScramKeys),
|
||||
}
|
||||
|
||||
pub enum ComputeNode {
|
||||
/// Route via link auth.
|
||||
Link(DatabaseInfo),
|
||||
/// Regular compute node.
|
||||
Static {
|
||||
password: Password,
|
||||
info: CachedNodeInfo,
|
||||
},
|
||||
}
|
||||
|
||||
impl ComputeNode {
|
||||
/// Get metrics auxiliary info.
|
||||
pub fn metrics_aux_info(&self) -> &MetricsAuxInfo {
|
||||
match self {
|
||||
Self::Link(info) => &info.aux,
|
||||
Self::Static { info, .. } => &info.aux,
|
||||
}
|
||||
}
|
||||
|
||||
/// Invalidate compute node info if it's cached.
|
||||
pub fn invalidate(&self) -> bool {
|
||||
if let Self::Static { info, .. } = self {
|
||||
warn!("invalidating compute node info cache entry");
|
||||
info.invalidate();
|
||||
return true;
|
||||
}
|
||||
|
||||
false
|
||||
}
|
||||
|
||||
/// Turn compute node info into a postgres connection config.
|
||||
pub fn to_conn_config(&self) -> ConnCfg {
|
||||
let mut config = ConnCfg::new();
|
||||
|
||||
let (host, port) = match self {
|
||||
Self::Link(info) => {
|
||||
// NB: use pre-supplied dbname, user and password for link auth.
|
||||
// See `ConnCfg::set_startup_params` below.
|
||||
config.0.dbname(&info.dbname).user(&info.user);
|
||||
if let Some(password) = &info.password {
|
||||
config.0.password(password.as_bytes());
|
||||
}
|
||||
|
||||
(&info.host, info.port)
|
||||
}
|
||||
Self::Static { info, password } => {
|
||||
// NB: setup auth keys (for SCRAM) or plaintext password.
|
||||
match password {
|
||||
Password::ClearText(text) => config.0.password(text),
|
||||
Password::ScramKeys(keys) => {
|
||||
use tokio_postgres::config::AuthKeys;
|
||||
config.0.auth_keys(AuthKeys::ScramSha256(keys.to_owned()))
|
||||
}
|
||||
};
|
||||
|
||||
(&info.address.host, info.address.port)
|
||||
}
|
||||
};
|
||||
|
||||
// Backwards compatibility. pg_sni_proxy uses "--" in domain names
|
||||
// while direct connections do not. Once we migrate to pg_sni_proxy
|
||||
// everywhere, we can remove this.
|
||||
config.0.ssl_mode(if host.contains("--") {
|
||||
// We need TLS connection with SNI info to properly route it.
|
||||
tokio_postgres::config::SslMode::Require
|
||||
} else {
|
||||
tokio_postgres::config::SslMode::Disable
|
||||
});
|
||||
|
||||
config.0.host(host).port(port);
|
||||
config
|
||||
}
|
||||
}
|
||||
|
||||
/// A config for establishing a connection to compute node.
|
||||
/// Eventually, `tokio_postgres` will be replaced with something better.
|
||||
/// Newtype allows us to implement methods on top of it.
|
||||
@@ -51,43 +148,32 @@ pub struct ConnCfg(Box<tokio_postgres::Config>);
|
||||
|
||||
/// Creation and initialization routines.
|
||||
impl ConnCfg {
|
||||
pub fn new() -> Self {
|
||||
fn new() -> Self {
|
||||
Self(Default::default())
|
||||
}
|
||||
|
||||
/// Reuse password or auth keys from the other config.
|
||||
pub fn reuse_password(&mut self, other: &Self) {
|
||||
if let Some(password) = other.get_password() {
|
||||
self.password(password);
|
||||
}
|
||||
|
||||
if let Some(keys) = other.get_auth_keys() {
|
||||
self.auth_keys(keys);
|
||||
}
|
||||
}
|
||||
|
||||
/// Apply startup message params to the connection config.
|
||||
pub fn set_startup_params(&mut self, params: &StartupMessageParams) {
|
||||
// Only set `user` if it's not present in the config.
|
||||
// Link auth flow takes username from the console's response.
|
||||
if let (None, Some(user)) = (self.get_user(), params.get("user")) {
|
||||
self.user(user);
|
||||
if let (None, Some(user)) = (self.0.get_user(), params.get("user")) {
|
||||
self.0.user(user);
|
||||
}
|
||||
|
||||
// Only set `dbname` if it's not present in the config.
|
||||
// Link auth flow takes dbname from the console's response.
|
||||
if let (None, Some(dbname)) = (self.get_dbname(), params.get("database")) {
|
||||
self.dbname(dbname);
|
||||
if let (None, Some(dbname)) = (self.0.get_dbname(), params.get("database")) {
|
||||
self.0.dbname(dbname);
|
||||
}
|
||||
|
||||
// Don't add `options` if they were only used for specifying a project.
|
||||
// Connection pools don't support `options`, because they affect backend startup.
|
||||
if let Some(options) = filtered_options(params) {
|
||||
self.options(&options);
|
||||
self.0.options(&options);
|
||||
}
|
||||
|
||||
if let Some(app_name) = params.get("application_name") {
|
||||
self.application_name(app_name);
|
||||
self.0.application_name(app_name);
|
||||
}
|
||||
|
||||
// TODO: This is especially ugly...
|
||||
@@ -95,10 +181,10 @@ impl ConnCfg {
|
||||
use tokio_postgres::config::ReplicationMode;
|
||||
match replication {
|
||||
"true" | "on" | "yes" | "1" => {
|
||||
self.replication_mode(ReplicationMode::Physical);
|
||||
self.0.replication_mode(ReplicationMode::Physical);
|
||||
}
|
||||
"database" => {
|
||||
self.replication_mode(ReplicationMode::Logical);
|
||||
self.0.replication_mode(ReplicationMode::Logical);
|
||||
}
|
||||
_other => {}
|
||||
}
|
||||
@@ -113,27 +199,6 @@ impl ConnCfg {
|
||||
}
|
||||
}
|
||||
|
||||
impl std::ops::Deref for ConnCfg {
|
||||
type Target = tokio_postgres::Config;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
|
||||
/// For now, let's make it easier to setup the config.
|
||||
impl std::ops::DerefMut for ConnCfg {
|
||||
fn deref_mut(&mut self) -> &mut Self::Target {
|
||||
&mut self.0
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for ConnCfg {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl ConnCfg {
|
||||
/// Establish a raw TCP connection to the compute node.
|
||||
async fn connect_raw(&self) -> io::Result<(SocketAddr, TcpStream, &str)> {
|
||||
@@ -220,16 +285,13 @@ pub struct PostgresConnection {
|
||||
}
|
||||
|
||||
impl ConnCfg {
|
||||
async fn do_connect(
|
||||
&self,
|
||||
allow_self_signed_compute: bool,
|
||||
) -> Result<PostgresConnection, ConnectionError> {
|
||||
async fn do_connect(&self) -> Result<PostgresConnection, ConnectionError> {
|
||||
let (socket_addr, stream, host) = self.connect_raw().await?;
|
||||
|
||||
let tls_connector = native_tls::TlsConnector::builder()
|
||||
.danger_accept_invalid_certs(allow_self_signed_compute)
|
||||
.build()
|
||||
.unwrap();
|
||||
.danger_accept_invalid_certs(ALLOW_SELF_SIGNED_COMPUTE.load(Ordering::Relaxed))
|
||||
.build()?;
|
||||
|
||||
let mut mk_tls = postgres_native_tls::MakeTlsConnector::new(tls_connector);
|
||||
let tls = MakeTlsConnect::<tokio::net::TcpStream>::make_tls_connect(&mut mk_tls, host)?;
|
||||
|
||||
@@ -261,11 +323,8 @@ impl ConnCfg {
|
||||
}
|
||||
|
||||
/// Connect to a corresponding compute node.
|
||||
pub async fn connect(
|
||||
&self,
|
||||
allow_self_signed_compute: bool,
|
||||
) -> Result<PostgresConnection, ConnectionError> {
|
||||
self.do_connect(allow_self_signed_compute)
|
||||
pub async fn connect(&self) -> Result<PostgresConnection, ConnectionError> {
|
||||
self.do_connect()
|
||||
.inspect_err(|err| {
|
||||
// Immediately log the error we have at our disposal.
|
||||
error!("couldn't connect to compute node: {err}");
|
||||
|
||||
@@ -12,7 +12,6 @@ pub struct ProxyConfig {
|
||||
pub tls_config: Option<TlsConfig>,
|
||||
pub auth_backend: auth::BackendType<'static, ()>,
|
||||
pub metric_collection: Option<MetricCollectionConfig>,
|
||||
pub allow_self_signed_compute: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
|
||||
@@ -22,11 +22,37 @@ impl fmt::Debug for GetRoleSecret {
|
||||
}
|
||||
}
|
||||
|
||||
/// Represents compute node's host & port pair.
|
||||
#[derive(Debug)]
|
||||
pub struct PgEndpoint {
|
||||
pub host: Box<str>,
|
||||
pub port: u16,
|
||||
}
|
||||
|
||||
impl<'de> Deserialize<'de> for PgEndpoint {
|
||||
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||
where
|
||||
D: serde::Deserializer<'de>,
|
||||
{
|
||||
use serde::de::Error;
|
||||
|
||||
let raw = String::deserialize(deserializer)?;
|
||||
let (host, port) = raw
|
||||
.rsplit_once(':')
|
||||
.ok_or_else(|| Error::custom(format!("bad compute address: {raw}")))?;
|
||||
|
||||
Ok(PgEndpoint {
|
||||
host: host.into(),
|
||||
port: port.parse().map_err(Error::custom)?,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Response which holds compute node's `host:port` pair.
|
||||
/// Returned by the `/proxy_wake_compute` API method.
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct WakeCompute {
|
||||
pub address: Box<str>,
|
||||
pub address: PgEndpoint,
|
||||
pub aux: MetricsAuxInfo,
|
||||
}
|
||||
|
||||
@@ -187,4 +213,24 @@ mod tests {
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_wake_compute() -> anyhow::Result<()> {
|
||||
let _: WakeCompute = serde_json::from_value(json!({
|
||||
"address": "127.0.0.1:5432",
|
||||
"aux": dummy_aux(),
|
||||
}))?;
|
||||
|
||||
let _: WakeCompute = serde_json::from_value(json!({
|
||||
"address": "[::1]:5432",
|
||||
"aux": dummy_aux(),
|
||||
}))?;
|
||||
|
||||
serde_json::from_value::<WakeCompute>(json!({
|
||||
"address": "localhost:5432",
|
||||
"aux": dummy_aux(),
|
||||
}))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,14 +1,13 @@
|
||||
pub mod mock;
|
||||
pub mod neon;
|
||||
|
||||
use super::messages::MetricsAuxInfo;
|
||||
use super::messages::{MetricsAuxInfo, PgEndpoint};
|
||||
use crate::{
|
||||
auth::ClientCredentials,
|
||||
cache::{timed_lru, TimedLru},
|
||||
compute, scram,
|
||||
cache::{types::Cached, TimedLru},
|
||||
scram,
|
||||
};
|
||||
use async_trait::async_trait;
|
||||
use std::sync::Arc;
|
||||
|
||||
pub mod errors {
|
||||
use crate::{
|
||||
@@ -112,11 +111,9 @@ pub mod errors {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum WakeComputeError {
|
||||
#[error("Console responded with a malformed compute address: {0}")]
|
||||
BadComputeAddress(Box<str>),
|
||||
|
||||
#[error(transparent)]
|
||||
ApiError(ApiError),
|
||||
}
|
||||
@@ -132,10 +129,7 @@ pub mod errors {
|
||||
fn to_string_client(&self) -> String {
|
||||
use WakeComputeError::*;
|
||||
match self {
|
||||
// We shouldn't show user the address even if it's broken.
|
||||
// Besides, user is unlikely to care about this detail.
|
||||
BadComputeAddress(_) => REQUEST_FAILED.to_owned(),
|
||||
// However, API might return a meaningful error.
|
||||
// API might return a meaningful error.
|
||||
ApiError(e) => e.to_string_client(),
|
||||
}
|
||||
}
|
||||
@@ -160,23 +154,17 @@ pub enum AuthInfo {
|
||||
}
|
||||
|
||||
/// Info for establishing a connection to a compute node.
|
||||
/// This is what we get after auth succeeded, but not before!
|
||||
#[derive(Clone)]
|
||||
/// This struct is cached, so we shouldn't store any user creds here.
|
||||
pub struct NodeInfo {
|
||||
/// Compute node connection params.
|
||||
/// It's sad that we have to clone this, but this will improve
|
||||
/// once we migrate to a bespoke connection logic.
|
||||
pub config: compute::ConnCfg,
|
||||
/// Address of the compute node.
|
||||
pub address: PgEndpoint,
|
||||
|
||||
/// Labels for proxy's metrics.
|
||||
pub aux: Arc<MetricsAuxInfo>,
|
||||
|
||||
/// Whether we should accept self-signed certificates (for testing)
|
||||
pub allow_self_signed_compute: bool,
|
||||
pub aux: MetricsAuxInfo,
|
||||
}
|
||||
|
||||
pub type NodeInfoCache = TimedLru<Arc<str>, NodeInfo>;
|
||||
pub type CachedNodeInfo = timed_lru::Cached<&'static NodeInfoCache>;
|
||||
pub type NodeInfoCache = TimedLru<Box<str>, NodeInfo>;
|
||||
pub type CachedNodeInfo = Cached<NodeInfo>;
|
||||
|
||||
/// This will allocate per each call, but the http requests alone
|
||||
/// already require a few allocations, so it should be fine.
|
||||
|
||||
@@ -2,13 +2,12 @@
|
||||
|
||||
use super::{
|
||||
errors::{ApiError, GetAuthInfoError, WakeComputeError},
|
||||
AuthInfo, CachedNodeInfo, ConsoleReqExtra, NodeInfo,
|
||||
AuthInfo, CachedNodeInfo, ConsoleReqExtra, NodeInfo, PgEndpoint,
|
||||
};
|
||||
use crate::{auth::ClientCredentials, compute, error::io_error, scram, url::ApiUrl};
|
||||
use crate::{auth::ClientCredentials, error::io_error, scram, url::ApiUrl};
|
||||
use async_trait::async_trait;
|
||||
use futures::TryFutureExt;
|
||||
use thiserror::Error;
|
||||
use tokio_postgres::config::SslMode;
|
||||
use tracing::{error, info, info_span, warn, Instrument};
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
@@ -84,19 +83,15 @@ impl Api {
|
||||
}
|
||||
|
||||
async fn do_wake_compute(&self) -> Result<NodeInfo, WakeComputeError> {
|
||||
let mut config = compute::ConnCfg::new();
|
||||
config
|
||||
.host(self.endpoint.host_str().unwrap_or("localhost"))
|
||||
.port(self.endpoint.port().unwrap_or(5432))
|
||||
.ssl_mode(SslMode::Disable);
|
||||
let host = self.endpoint.host_str().unwrap_or("localhost").into();
|
||||
let port = self.endpoint.port().unwrap_or(5432);
|
||||
|
||||
let node = NodeInfo {
|
||||
config,
|
||||
let info = NodeInfo {
|
||||
address: PgEndpoint { host, port },
|
||||
aux: Default::default(),
|
||||
allow_self_signed_compute: false,
|
||||
};
|
||||
|
||||
Ok(node)
|
||||
Ok(info)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -5,10 +5,9 @@ use super::{
|
||||
errors::{ApiError, GetAuthInfoError, WakeComputeError},
|
||||
ApiCaches, AuthInfo, CachedNodeInfo, ConsoleReqExtra, NodeInfo,
|
||||
};
|
||||
use crate::{auth::ClientCredentials, compute, http, scram};
|
||||
use crate::{auth::ClientCredentials, http, scram};
|
||||
use async_trait::async_trait;
|
||||
use futures::TryFutureExt;
|
||||
use tokio_postgres::config::SslMode;
|
||||
use tracing::{error, info, info_span, warn, Instrument};
|
||||
|
||||
#[derive(Clone)]
|
||||
@@ -91,22 +90,9 @@ impl Api {
|
||||
let response = self.endpoint.execute(request).await?;
|
||||
let body = parse_body::<WakeCompute>(response).await?;
|
||||
|
||||
// Unfortunately, ownership won't let us use `Option::ok_or` here.
|
||||
let (host, port) = match parse_host_port(&body.address) {
|
||||
None => return Err(WakeComputeError::BadComputeAddress(body.address)),
|
||||
Some(x) => x,
|
||||
};
|
||||
|
||||
// Don't set anything but host and port! This config will be cached.
|
||||
// We'll set username and such later using the startup message.
|
||||
// TODO: add more type safety (in progress).
|
||||
let mut config = compute::ConnCfg::new();
|
||||
config.host(host).port(port).ssl_mode(SslMode::Disable); // TLS is not configured on compute nodes.
|
||||
|
||||
let node = NodeInfo {
|
||||
config,
|
||||
aux: body.aux.into(),
|
||||
allow_self_signed_compute: false,
|
||||
address: body.address,
|
||||
aux: body.aux,
|
||||
};
|
||||
|
||||
Ok(node)
|
||||
@@ -141,13 +127,15 @@ impl super::Api for Api {
|
||||
// The connection info remains the same during that period of time,
|
||||
// which means that we might cache it to reduce the load and latency.
|
||||
if let Some(cached) = self.caches.node_info.get(key) {
|
||||
info!(key = key, "found cached compute node info");
|
||||
info!(key, "found cached compute node info");
|
||||
return Ok(cached);
|
||||
}
|
||||
|
||||
let node = self.do_wake_compute(extra, creds).await?;
|
||||
let (_, cached) = self.caches.node_info.insert(key.into(), node);
|
||||
info!(key = key, "created a cache entry for compute node info");
|
||||
let info = self.do_wake_compute(extra, creds).await?;
|
||||
|
||||
let owned_key = Box::<str>::from(key);
|
||||
let (_, cached) = self.caches.node_info.insert(owned_key.into(), info.into());
|
||||
info!(key, "created a cache entry for compute node info");
|
||||
|
||||
Ok(cached)
|
||||
}
|
||||
@@ -177,20 +165,3 @@ async fn parse_body<T: for<'a> serde::Deserialize<'a>>(
|
||||
error!("console responded with an error ({status}): {text}");
|
||||
Err(ApiError::Console { status, text })
|
||||
}
|
||||
|
||||
fn parse_host_port(input: &str) -> Option<(&str, u16)> {
|
||||
let (host, port) = input.split_once(':')?;
|
||||
Some((host, port.parse().ok()?))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_parse_host_port() {
|
||||
let (host, port) = parse_host_port("127.0.0.1:5432").expect("failed to parse");
|
||||
assert_eq!(host, "127.0.0.1");
|
||||
assert_eq!(port, 5432);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,7 +4,7 @@ mod tests;
|
||||
use crate::{
|
||||
auth::{self, backend::AuthSuccess},
|
||||
cancellation::{self, CancelMap},
|
||||
compute::{self, PostgresConnection},
|
||||
compute::{self, ComputeNode, PostgresConnection},
|
||||
config::{ProxyConfig, TlsConfig},
|
||||
console::{self, messages::MetricsAuxInfo},
|
||||
error::io_error,
|
||||
@@ -155,7 +155,7 @@ pub async fn handle_ws_client(
|
||||
async { result }.or_else(|e| stream.throw_error(e)).await?
|
||||
};
|
||||
|
||||
let client = Client::new(stream, creds, ¶ms, session_id, false);
|
||||
let client = Client::new(stream, creds, ¶ms, session_id);
|
||||
cancel_map
|
||||
.with_session(|session| client.connect_to_db(session, true))
|
||||
.await
|
||||
@@ -194,15 +194,7 @@ async fn handle_client(
|
||||
async { result }.or_else(|e| stream.throw_error(e)).await?
|
||||
};
|
||||
|
||||
let allow_self_signed_compute = config.allow_self_signed_compute;
|
||||
|
||||
let client = Client::new(
|
||||
stream,
|
||||
creds,
|
||||
¶ms,
|
||||
session_id,
|
||||
allow_self_signed_compute,
|
||||
);
|
||||
let client = Client::new(stream, creds, ¶ms, session_id);
|
||||
cancel_map
|
||||
.with_session(|session| client.connect_to_db(session, false))
|
||||
.await
|
||||
@@ -283,61 +275,38 @@ async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
}
|
||||
}
|
||||
|
||||
/// Try to connect to the compute node once.
|
||||
#[tracing::instrument(name = "connect_once", skip_all)]
|
||||
async fn connect_to_compute_once(
|
||||
node_info: &console::CachedNodeInfo,
|
||||
) -> Result<PostgresConnection, compute::ConnectionError> {
|
||||
// If we couldn't connect, a cached connection info might be to blame
|
||||
// (e.g. the compute node's address might've changed at the wrong time).
|
||||
// Invalidate the cache entry (if any) to prevent subsequent errors.
|
||||
let invalidate_cache = |_: &compute::ConnectionError| {
|
||||
let is_cached = node_info.cached();
|
||||
if is_cached {
|
||||
warn!("invalidating stalled compute node info cache entry");
|
||||
node_info.invalidate();
|
||||
}
|
||||
|
||||
let label = match is_cached {
|
||||
true => "compute_cached",
|
||||
false => "compute_uncached",
|
||||
};
|
||||
NUM_CONNECTION_FAILURES.with_label_values(&[label]).inc();
|
||||
};
|
||||
|
||||
let allow_self_signed_compute = node_info.allow_self_signed_compute;
|
||||
|
||||
node_info
|
||||
.config
|
||||
.connect(allow_self_signed_compute)
|
||||
.inspect_err(invalidate_cache)
|
||||
.await
|
||||
}
|
||||
|
||||
/// Try to connect to the compute node, retrying if necessary.
|
||||
/// This function might update `node_info`, so we take it by `&mut`.
|
||||
#[tracing::instrument(skip_all)]
|
||||
async fn connect_to_compute(
|
||||
node_info: &mut console::CachedNodeInfo,
|
||||
node_info: &mut ComputeNode,
|
||||
params: &StartupMessageParams,
|
||||
extra: &console::ConsoleReqExtra<'_>,
|
||||
creds: &auth::BackendType<'_, auth::ClientCredentials<'_>>,
|
||||
) -> Result<PostgresConnection, compute::ConnectionError> {
|
||||
let mut num_retries: usize = NUM_RETRIES_WAKE_COMPUTE;
|
||||
let mut num_retries = NUM_RETRIES_WAKE_COMPUTE;
|
||||
loop {
|
||||
// Apply startup params to the (possibly, cached) compute node info.
|
||||
node_info.config.set_startup_params(params);
|
||||
match connect_to_compute_once(node_info).await {
|
||||
let mut config = node_info.to_conn_config();
|
||||
config.set_startup_params(params);
|
||||
match config.connect().await {
|
||||
Err(e) if num_retries > 0 => {
|
||||
info!("compute node's state has changed; requesting a wake-up");
|
||||
match creds.wake_compute(extra).map_err(io_error).await? {
|
||||
let label = match node_info.invalidate() {
|
||||
true => "compute_cached",
|
||||
false => "compute_uncached",
|
||||
};
|
||||
NUM_CONNECTION_FAILURES.with_label_values(&[label]).inc();
|
||||
|
||||
let res = creds.wake_compute(extra).map_err(io_error).await?;
|
||||
match (res, &node_info) {
|
||||
// Update `node_info` and try one more time.
|
||||
Some(mut new) => {
|
||||
new.config.reuse_password(&node_info.config);
|
||||
*node_info = new;
|
||||
(Some(new), ComputeNode::Static { password, .. }) => {
|
||||
*node_info = ComputeNode::Static {
|
||||
password: password.to_owned(),
|
||||
info: new,
|
||||
}
|
||||
}
|
||||
// Link auth doesn't work that way, so we just exit.
|
||||
None => return Err(e),
|
||||
_ => return Err(e),
|
||||
}
|
||||
}
|
||||
other => return other,
|
||||
@@ -430,8 +399,6 @@ struct Client<'a, S> {
|
||||
params: &'a StartupMessageParams,
|
||||
/// Unique connection ID.
|
||||
session_id: uuid::Uuid,
|
||||
/// Allow self-signed certificates (for testing).
|
||||
allow_self_signed_compute: bool,
|
||||
}
|
||||
|
||||
impl<'a, S> Client<'a, S> {
|
||||
@@ -441,14 +408,12 @@ impl<'a, S> Client<'a, S> {
|
||||
creds: auth::BackendType<'a, auth::ClientCredentials<'a>>,
|
||||
params: &'a StartupMessageParams,
|
||||
session_id: uuid::Uuid,
|
||||
allow_self_signed_compute: bool,
|
||||
) -> Self {
|
||||
Self {
|
||||
stream,
|
||||
creds,
|
||||
params,
|
||||
session_id,
|
||||
allow_self_signed_compute,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -468,7 +433,6 @@ impl<S: AsyncRead + AsyncWrite + Unpin> Client<'_, S> {
|
||||
mut creds,
|
||||
params,
|
||||
session_id,
|
||||
allow_self_signed_compute,
|
||||
} = self;
|
||||
|
||||
let extra = console::ConsoleReqExtra {
|
||||
@@ -491,19 +455,19 @@ impl<S: AsyncRead + AsyncWrite + Unpin> Client<'_, S> {
|
||||
value: mut node_info,
|
||||
} = auth_result;
|
||||
|
||||
node_info.allow_self_signed_compute = allow_self_signed_compute;
|
||||
|
||||
let mut node = connect_to_compute(&mut node_info, params, &extra, &creds)
|
||||
.or_else(|e| stream.throw_error(e))
|
||||
.await?;
|
||||
|
||||
prepare_client_connection(&node, reported_auth_ok, session, &mut stream).await?;
|
||||
|
||||
// Before proxy passing, forward to compute whatever data is left in the
|
||||
// PqStream input buffer. Normally there is none, but our serverless npm
|
||||
// driver in pipeline mode sends startup, password and first query
|
||||
// immediately after opening the connection.
|
||||
let (stream, read_buf) = stream.into_inner();
|
||||
node.stream.write_all(&read_buf).await?;
|
||||
proxy_pass(stream, node.stream, &node_info.aux).await
|
||||
|
||||
proxy_pass(stream, node.stream, node_info.metrics_aux_info()).await
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user