From 9b99d4caa9f442fe52a1e1f44fe0aa6ee4b16b92 Mon Sep 17 00:00:00 2001 From: Dmitry Ivanov Date: Wed, 15 Mar 2023 14:13:41 +0300 Subject: [PATCH] New cache impl --- Cargo.lock | 21 ++ Cargo.toml | 1 + proxy/Cargo.toml | 3 +- proxy/src/auth/backend.rs | 9 +- proxy/src/auth/backend/classic.rs | 23 +- proxy/src/auth/backend/hacks.rs | 26 +- proxy/src/auth/backend/link.rs | 42 +--- proxy/src/bin/proxy.rs | 3 +- proxy/src/cache.rs | 383 ++++++++--------------------- proxy/src/cache/timed_lru.rs | 272 ++++++++++++++++++++ proxy/src/cache/timed_lru/tests.rs | 104 ++++++++ proxy/src/compute.rs | 169 ++++++++----- proxy/src/config.rs | 1 - proxy/src/console/messages.rs | 48 +++- proxy/src/console/provider.rs | 34 +-- proxy/src/console/provider/mock.rs | 19 +- proxy/src/console/provider/neon.rs | 47 +--- proxy/src/proxy.rs | 86 ++----- 18 files changed, 747 insertions(+), 544 deletions(-) create mode 100644 proxy/src/cache/timed_lru.rs create mode 100644 proxy/src/cache/timed_lru/tests.rs diff --git a/Cargo.lock b/Cargo.lock index 2223453a08..ace68dd670 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3104,6 +3104,7 @@ dependencies = [ "prometheus", "rand", "rcgen", + "ref-cast", "regex", "reqwest", "reqwest-middleware", @@ -3228,6 +3229,26 @@ dependencies = [ "bitflags", ] +[[package]] +name = "ref-cast" +version = "1.0.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f43faa91b1c8b36841ee70e97188a869d37ae21759da6846d4be66de5bf7b12c" +dependencies = [ + "ref-cast-impl", +] + +[[package]] +name = "ref-cast-impl" +version = "1.0.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8d2275aab483050ab2a7364c1a46604865ee7d6906684e08db0f090acf74f9e7" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.15", +] + [[package]] name = "regex" version = "1.7.3" diff --git a/Cargo.toml b/Cargo.toml index 7895459841..78004b639e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -77,6 +77,7 @@ pin-project-lite = "0.2" prometheus = 
{version = "0.13", default_features=false, features = ["process"]} # removes protobuf dependency prost = "0.11" rand = "0.8" +ref-cast = "1.0" regex = "1.4" reqwest = { version = "0.11", default-features = false, features = ["rustls-tls"] } reqwest-tracing = { version = "0.4.0", features = ["opentelemetry_0_18"] } diff --git a/proxy/Cargo.toml b/proxy/Cargo.toml index e7a4fd236e..d58a2b0ac2 100644 --- a/proxy/Cargo.toml +++ b/proxy/Cargo.toml @@ -35,6 +35,7 @@ postgres_backend.workspace = true pq_proto.workspace = true prometheus.workspace = true rand.workspace = true +ref-cast.workspace = true regex.workspace = true reqwest = { workspace = true, features = ["json"] } reqwest-middleware.workspace = true @@ -50,9 +51,9 @@ socket2.workspace = true sync_wrapper.workspace = true thiserror.workspace = true tls-listener.workspace = true +tokio = { workspace = true, features = ["signal"] } tokio-postgres.workspace = true tokio-rustls.workspace = true -tokio = { workspace = true, features = ["signal"] } tracing-opentelemetry.workspace = true tracing-subscriber.workspace = true tracing-utils.workspace = true diff --git a/proxy/src/auth/backend.rs b/proxy/src/auth/backend.rs index 9322e4f9ff..7e9dc8c235 100644 --- a/proxy/src/auth/backend.rs +++ b/proxy/src/auth/backend.rs @@ -6,6 +6,7 @@ pub use link::LinkAuthError; use crate::{ auth::{self, ClientCredentials}, + compute::ComputeNode, console::{ self, provider::{CachedNodeInfo, ConsoleReqExtra}, @@ -114,7 +115,7 @@ async fn auth_quirks( creds: &mut ClientCredentials<'_>, client: &mut stream::PqStream, allow_cleartext: bool, -) -> auth::Result> { +) -> auth::Result> { // If there's no project so far, that entails that client doesn't // support SNI or other means of passing the endpoint (project) name. // We now expect to see a very specific payload in the place of password. 
@@ -156,7 +157,7 @@ impl BackendType<'_, ClientCredentials<'_>> { extra: &ConsoleReqExtra<'_>, client: &mut stream::PqStream, allow_cleartext: bool, - ) -> auth::Result> { + ) -> auth::Result> { use BackendType::*; let res = match self { @@ -184,9 +185,7 @@ impl BackendType<'_, ClientCredentials<'_>> { Link(url) => { info!("performing link authentication"); - link::authenticate(url, client) - .await? - .map(CachedNodeInfo::new_uncached) + link::authenticate(url, client).await? } }; diff --git a/proxy/src/auth/backend/classic.rs b/proxy/src/auth/backend/classic.rs index 6753e7ed7f..f5ba54ae20 100644 --- a/proxy/src/auth/backend/classic.rs +++ b/proxy/src/auth/backend/classic.rs @@ -1,8 +1,8 @@ use super::AuthSuccess; use crate::{ auth::{self, AuthFlow, ClientCredentials}, - compute, - console::{self, AuthInfo, CachedNodeInfo, ConsoleReqExtra}, + compute::{self, ComputeNode, Password}, + console::{self, AuthInfo, ConsoleReqExtra}, sasl, scram, stream::PqStream, }; @@ -14,7 +14,7 @@ pub(super) async fn authenticate( extra: &ConsoleReqExtra<'_>, creds: &ClientCredentials<'_>, client: &mut PqStream, -) -> auth::Result> { +) -> auth::Result> { info!("fetching user's authentication info"); let info = api.get_auth_info(extra, creds).await?.unwrap_or_else(|| { // If we don't have an authentication secret, we mock one to @@ -25,7 +25,7 @@ pub(super) async fn authenticate( }); let flow = AuthFlow::new(client); - let scram_keys = match info { + let keys = match info { AuthInfo::Md5(_) => { info!("auth endpoint chooses MD5"); return Err(auth::AuthError::bad_auth_method("MD5")); @@ -41,21 +41,20 @@ pub(super) async fn authenticate( } }; - Some(compute::ScramKeys { + compute::ScramKeys { client_key: client_key.as_bytes(), server_key: secret.server_key.as_bytes(), - }) + } } }; - let mut node = api.wake_compute(extra, creds).await?; - if let Some(keys) = scram_keys { - use tokio_postgres::config::AuthKeys; - node.config.auth_keys(AuthKeys::ScramSha256(keys)); - } + let info = 
api.wake_compute(extra, creds).await?; Ok(AuthSuccess { reported_auth_ok: false, - value: node, + value: ComputeNode::Static { + password: Password::ScramKeys(keys), + info, + }, }) } diff --git a/proxy/src/auth/backend/hacks.rs b/proxy/src/auth/backend/hacks.rs index dcc93ec04c..261a739f7e 100644 --- a/proxy/src/auth/backend/hacks.rs +++ b/proxy/src/auth/backend/hacks.rs @@ -1,10 +1,8 @@ use super::AuthSuccess; use crate::{ auth::{self, AuthFlow, ClientCredentials}, - console::{ - self, - provider::{CachedNodeInfo, ConsoleReqExtra}, - }, + compute::{ComputeNode, Password}, + console::{self, provider::ConsoleReqExtra}, stream, }; use tokio::io::{AsyncRead, AsyncWrite}; @@ -19,7 +17,7 @@ pub async fn cleartext_hack( extra: &ConsoleReqExtra<'_>, creds: &mut ClientCredentials<'_>, client: &mut stream::PqStream, -) -> auth::Result> { +) -> auth::Result> { warn!("cleartext auth flow override is enabled, proceeding"); let password = AuthFlow::new(client) .begin(auth::CleartextPassword) @@ -27,13 +25,15 @@ pub async fn cleartext_hack( .authenticate() .await?; - let mut node = api.wake_compute(extra, creds).await?; - node.config.password(password); + let info = api.wake_compute(extra, creds).await?; // Report tentative success; compute node will check the password anyway. 
Ok(AuthSuccess { reported_auth_ok: false, - value: node, + value: ComputeNode::Static { + password: Password::ClearText(password), + info, + }, }) } @@ -44,7 +44,7 @@ pub async fn password_hack( extra: &ConsoleReqExtra<'_>, creds: &mut ClientCredentials<'_>, client: &mut stream::PqStream, -) -> auth::Result> { +) -> auth::Result> { warn!("project not specified, resorting to the password hack auth flow"); let payload = AuthFlow::new(client) .begin(auth::PasswordHack) @@ -55,12 +55,14 @@ pub async fn password_hack( info!(project = &payload.endpoint, "received missing parameter"); creds.project = Some(payload.endpoint); - let mut node = api.wake_compute(extra, creds).await?; - node.config.password(payload.password); + let info = api.wake_compute(extra, creds).await?; // Report tentative success; compute node will check the password anyway. Ok(AuthSuccess { reported_auth_ok: false, - value: node, + value: ComputeNode::Static { + password: Password::ClearText(payload.password), + info, + }, }) } diff --git a/proxy/src/auth/backend/link.rs b/proxy/src/auth/backend/link.rs index da43cf11c4..7e56788f6f 100644 --- a/proxy/src/auth/backend/link.rs +++ b/proxy/src/auth/backend/link.rs @@ -1,15 +1,10 @@ use super::AuthSuccess; use crate::{ - auth, compute, - console::{self, provider::NodeInfo}, - error::UserFacingError, - stream::PqStream, - waiters, + auth, compute::ComputeNode, console, error::UserFacingError, stream::PqStream, waiters, }; use pq_proto::BeMessage as Be; use thiserror::Error; use tokio::io::{AsyncRead, AsyncWrite}; -use tokio_postgres::config::SslMode; use tracing::{info, info_span}; #[derive(Debug, Error)] @@ -57,12 +52,12 @@ pub fn new_psql_session_id() -> String { pub(super) async fn authenticate( link_uri: &reqwest::Url, client: &mut PqStream, -) -> auth::Result> { +) -> auth::Result> { let psql_session_id = new_psql_session_id(); - let span = info_span!("link", psql_session_id = &psql_session_id); + let span = info_span!("link", psql_session_id); let 
greeting = hello_message(link_uri, &psql_session_id); - let db_info = console::mgmt::with_waiter(psql_session_id, |waiter| async { + let info = console::mgmt::with_waiter(psql_session_id, |waiter| async { // Give user a URL to spawn a new database. info!(parent: &span, "sending the auth URL to the user"); client @@ -79,35 +74,8 @@ pub(super) async fn authenticate( client.write_message_noflush(&Be::NoticeResponse("Connecting to database."))?; - // This config should be self-contained, because we won't - // take username or dbname from client's startup message. - let mut config = compute::ConnCfg::new(); - config - .host(&db_info.host) - .port(db_info.port) - .dbname(&db_info.dbname) - .user(&db_info.user); - - // Backwards compatibility. pg_sni_proxy uses "--" in domain names - // while direct connections do not. Once we migrate to pg_sni_proxy - // everywhere, we can remove this. - if db_info.host.contains("--") { - // we need TLS connection with SNI info to properly route it - config.ssl_mode(SslMode::Require); - } else { - config.ssl_mode(SslMode::Disable); - } - - if let Some(password) = db_info.password { - config.password(password.as_ref()); - } - Ok(AuthSuccess { reported_auth_ok: true, - value: NodeInfo { - config, - aux: db_info.aux.into(), - allow_self_signed_compute: false, // caller may override - }, + value: ComputeNode::Link(info), }) } diff --git a/proxy/src/bin/proxy.rs b/proxy/src/bin/proxy.rs index 28e6e25317..db51688612 100644 --- a/proxy/src/bin/proxy.rs +++ b/proxy/src/bin/proxy.rs @@ -6,6 +6,7 @@ use proxy::metrics; use anyhow::bail; use clap::{self, Arg}; use proxy::config::{self, ProxyConfig}; +use std::sync::atomic::Ordering; use std::{borrow::Cow, net::SocketAddr}; use tokio::net::TcpListener; use tokio_util::sync::CancellationToken; @@ -103,6 +104,7 @@ fn build_config(args: &clap::ArgMatches) -> anyhow::Result<&'static ProxyConfig> .parse()?; if allow_self_signed_compute { warn!("allowing self-signed compute certificates"); + 
proxy::compute::ALLOW_SELF_SIGNED_COMPUTE.store(true, Ordering::Relaxed); } let metric_collection = match ( @@ -154,7 +156,6 @@ fn build_config(args: &clap::ArgMatches) -> anyhow::Result<&'static ProxyConfig> tls_config, auth_backend, metric_collection, - allow_self_signed_compute, })); Ok(config) diff --git a/proxy/src/cache.rs b/proxy/src/cache.rs index 4e16cc39ec..4c0cffeb7f 100644 --- a/proxy/src/cache.rs +++ b/proxy/src/cache.rs @@ -1,304 +1,117 @@ -use std::{ - borrow::Borrow, - hash::Hash, - ops::{Deref, DerefMut}, - time::{Duration, Instant}, -}; -use tracing::debug; +use std::{any::Any, sync::Arc, time::Instant}; -// This seems to make more sense than `lru` or `cached`: -// -// * `near/nearcore` ditched `cached` in favor of `lru` -// (https://github.com/near/nearcore/issues?q=is%3Aissue+lru+is%3Aclosed). -// -// * `lru` methods use an obscure `KeyRef` type in their contraints (which is deliberately excluded from docs). -// This severely hinders its usage both in terms of creating wrappers and supported key types. -// -// On the other hand, `hashlink` has good download stats and appears to be maintained. -use hashlink::{linked_hash_map::RawEntryMut, LruCache}; +/// A variant of LRU where every entry has a TTL. +pub mod timed_lru; +pub use timed_lru::TimedLru; -/// A generic trait which exposes types of cache's key and value, -/// as well as the notion of cache entry invalidation. -/// This is useful for [`timed_lru::Cached`]. -pub trait Cache { - /// Entry's key. - type Key; +/// Useful type aliases. +pub mod types { + pub type Cached = super::Cached<'static, T>; +} - /// Entry's value. - type Value; +/// Lookup information for cache entry invalidation. +#[derive(Clone)] +pub struct LookupInfo { + /// Cache entry creation time. + /// We use this during invalidation lookups to prevent eviction of a newer + /// entry sharing the same key (it might've been inserted by a different + /// task after we got the entry we're trying to invalidate now). 
+ created_at: Instant, - /// Used for entry invalidation. - type LookupInfo; + /// Search by this key. + key: K, +} +/// This type incapsulates everything needed for cache entry invalidation. +/// For convenience, we completely erase the types of a cache ref and a key. +/// This lets us store multiple tokens in a homogeneous collection (e.g. Vec). +#[derive(Clone)] +pub struct InvalidationToken<'a> { + // TODO: allow more than one type of references (e.g. Arc) if it's ever needed. + cache: &'a (dyn Cache + Sync + Send), + info: LookupInfo>, +} + +impl InvalidationToken<'_> { + /// Invalidate a corresponding cache entry. + pub fn invalidate(&self) { + let info = LookupInfo { + created_at: self.info.created_at, + key: self.info.key.as_ref(), + }; + self.cache.invalidate_entry(info); + } +} + +/// A combination of a cache entry and its invalidation token. +/// Makes it easier to see how those two are connected. +#[derive(Clone)] +pub struct Cached<'a, T> { + pub token: Option>, + pub value: Arc, +} + +impl Cached<'_, T> { + /// Place any entry into this wrapper; invalidation will be a no-op. + pub fn new_uncached(value: T) -> Self { + Self { + token: None, + value: value.into(), + } + } + + /// Invalidate a corresponding cache entry. + pub fn invalidate(&self) { + if let Some(token) = &self.token { + token.invalidate(); + } + } +} + +impl std::ops::Deref for Cached<'_, T> { + type Target = T; + + fn deref(&self) -> &Self::Target { + &self.value + } +} + +/// This trait captures the notion of cache entry invalidation. +/// It doesn't have any associated types because we use dyn-based type erasure. +trait Cache { /// Invalidate an entry using a lookup info. /// We don't have an empty default impl because it's error-prone. 
- fn invalidate(&self, _: &Self::LookupInfo); + fn invalidate_entry(&self, info: LookupInfo<&(dyn Any + Send + Sync)>); } -impl Cache for &C { - type Key = C::Key; - type Value = C::Value; - type LookupInfo = C::LookupInfo; - - fn invalidate(&self, info: &Self::LookupInfo) { - C::invalidate(self, info) - } -} - -pub use timed_lru::TimedLru; -pub mod timed_lru { +#[cfg(test)] +mod tests { use super::*; - /// An implementation of timed LRU cache with fixed capacity. - /// Key properties: - /// - /// * Whenever a new entry is inserted, the least recently accessed one is evicted. - /// The cache also keeps track of entry's insertion time (`created_at`) and TTL (`expires_at`). - /// - /// * When the entry is about to be retrieved, we check its expiration timestamp. - /// If the entry has expired, we remove it from the cache; Otherwise we bump the - /// expiration timestamp (e.g. +5mins) and change its place in LRU list to prolong - /// its existence. - /// - /// * There's an API for immediate invalidation (removal) of a cache entry; - /// It's useful in case we know for sure that the entry is no longer correct. - /// See [`timed_lru::LookupInfo`] & [`timed_lru::Cached`] for more information. - /// - /// * Expired entries are kept in the cache, until they are evicted by the LRU policy, - /// or by a successful lookup (i.e. the entry hasn't expired yet). - /// There is no background job to reap the expired records. - /// - /// * It's possible for an entry that has not yet expired entry to be evicted - /// before expired items. That's a bit wasteful, but probably fine in practice. - pub struct TimedLru { - /// Cache's name for tracing. - name: &'static str, - - /// The underlying cache implementation. - cache: parking_lot::Mutex>>, - - /// Default time-to-live of a single entry. 
- ttl: Duration, + #[test] + fn trivial_properties_of_cached() { + let cached = Cached::new_uncached(0); + assert_eq!(*cached, 0); + cached.invalidate(); } - impl Cache for TimedLru { - type Key = K; - type Value = V; - type LookupInfo = LookupInfo; + #[test] + fn invalidation_token_type_erasure() { + let lifetime = std::time::Duration::from_secs(10); + let foo = TimedLru::::new("foo", 128, lifetime); + let bar = TimedLru::::new("bar", 128, lifetime); - fn invalidate(&self, info: &Self::LookupInfo) { - self.invalidate_raw(info) - } - } + let (_, x) = foo.insert(100.into(), 0.into()); + let (_, y) = bar.insert(String::new().into(), 404.into()); - struct Entry { - created_at: Instant, - expires_at: Instant, - value: T, - } - - impl TimedLru { - /// Construct a new LRU cache with timed entries. - pub fn new(name: &'static str, capacity: usize, ttl: Duration) -> Self { - Self { - name, - cache: LruCache::new(capacity).into(), - ttl, - } + // Invalidation tokens should be cloneable and homogeneous (same type). + let tokens = [x.token.clone().unwrap(), y.token.clone().unwrap()]; + for token in tokens { + token.invalidate(); } - /// Drop an entry from the cache if it's outdated. - #[tracing::instrument(level = "debug", fields(cache = self.name), skip_all)] - fn invalidate_raw(&self, info: &LookupInfo) { - let now = Instant::now(); - - // Do costly things before taking the lock. - let mut cache = self.cache.lock(); - let raw_entry = match cache.raw_entry_mut().from_key(&info.key) { - RawEntryMut::Vacant(_) => return, - RawEntryMut::Occupied(x) => x, - }; - - // Remove the entry if it was created prior to lookup timestamp. 
- let entry = raw_entry.get(); - let (created_at, expires_at) = (entry.created_at, entry.expires_at); - let should_remove = created_at <= info.created_at || expires_at <= now; - - if should_remove { - raw_entry.remove(); - } - - drop(cache); // drop lock before logging - debug!( - created_at = format_args!("{created_at:?}"), - expires_at = format_args!("{expires_at:?}"), - entry_removed = should_remove, - "processed a cache entry invalidation event" - ); - } - - /// Try retrieving an entry by its key, then execute `extract` if it exists. - #[tracing::instrument(level = "debug", fields(cache = self.name), skip_all)] - fn get_raw(&self, key: &Q, extract: impl FnOnce(&K, &Entry) -> R) -> Option - where - K: Borrow, - Q: Hash + Eq + ?Sized, - { - let now = Instant::now(); - let deadline = now.checked_add(self.ttl).expect("time overflow"); - - // Do costly things before taking the lock. - let mut cache = self.cache.lock(); - let mut raw_entry = match cache.raw_entry_mut().from_key(key) { - RawEntryMut::Vacant(_) => return None, - RawEntryMut::Occupied(x) => x, - }; - - // Immeditely drop the entry if it has expired. - let entry = raw_entry.get(); - if entry.expires_at <= now { - raw_entry.remove(); - return None; - } - - let value = extract(raw_entry.key(), entry); - let (created_at, expires_at) = (entry.created_at, entry.expires_at); - - // Update the deadline and the entry's position in the LRU list. - raw_entry.get_mut().expires_at = deadline; - raw_entry.to_back(); - - drop(cache); // drop lock before logging - debug!( - created_at = format_args!("{created_at:?}"), - old_expires_at = format_args!("{expires_at:?}"), - new_expires_at = format_args!("{deadline:?}"), - "accessed a cache entry" - ); - - Some(value) - } - - /// Insert an entry to the cache. If an entry with the same key already - /// existed, return the previous value and its creation timestamp. 
- #[tracing::instrument(level = "debug", fields(cache = self.name), skip_all)] - fn insert_raw(&self, key: K, value: V) -> (Instant, Option) { - let created_at = Instant::now(); - let expires_at = created_at.checked_add(self.ttl).expect("time overflow"); - - let entry = Entry { - created_at, - expires_at, - value, - }; - - // Do costly things before taking the lock. - let old = self - .cache - .lock() - .insert(key, entry) - .map(|entry| entry.value); - - debug!( - created_at = format_args!("{created_at:?}"), - expires_at = format_args!("{expires_at:?}"), - replaced = old.is_some(), - "created a cache entry" - ); - - (created_at, old) - } - } - - impl TimedLru { - pub fn insert(&self, key: K, value: V) -> (Option, Cached<&Self>) { - let (created_at, old) = self.insert_raw(key.clone(), value.clone()); - - let cached = Cached { - token: Some((self, LookupInfo { created_at, key })), - value, - }; - - (old, cached) - } - } - - impl TimedLru { - /// Retrieve a cached entry in convenient wrapper. - pub fn get(&self, key: &Q) -> Option> - where - K: Borrow + Clone, - Q: Hash + Eq + ?Sized, - { - self.get_raw(key, |key, entry| { - let info = LookupInfo { - created_at: entry.created_at, - key: key.clone(), - }; - - Cached { - token: Some((self, info)), - value: entry.value.clone(), - } - }) - } - } - - /// Lookup information for key invalidation. - pub struct LookupInfo { - /// Time of creation of a cache [`Entry`]. - /// We use this during invalidation lookups to prevent eviction of a newer - /// entry sharing the same key (it might've been inserted by a different - /// task after we got the entry we're trying to invalidate now). - created_at: Instant, - - /// Search by this key. - key: K, - } - - /// Wrapper for convenient entry invalidation. - pub struct Cached { - /// Cache + lookup info. - token: Option<(C, C::LookupInfo)>, - - /// The value itself. - pub value: C::Value, - } - - impl Cached { - /// Place any entry into this wrapper; invalidation will be a no-op. 
- /// Unfortunately, rust doesn't let us implement [`From`] or [`Into`]. - pub fn new_uncached(value: impl Into) -> Self { - Self { - token: None, - value: value.into(), - } - } - - /// Drop this entry from a cache if it's still there. - pub fn invalidate(&self) { - if let Some((cache, info)) = &self.token { - cache.invalidate(info); - } - } - - /// Tell if this entry is actually cached. - pub fn cached(&self) -> bool { - self.token.is_some() - } - } - - impl Deref for Cached { - type Target = C::Value; - - fn deref(&self) -> &Self::Target { - &self.value - } - } - - impl DerefMut for Cached { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.value - } + // Values are still there. + assert_eq!(*x, 0); + assert_eq!(*y, 404); } } diff --git a/proxy/src/cache/timed_lru.rs b/proxy/src/cache/timed_lru.rs new file mode 100644 index 0000000000..efbf9bca37 --- /dev/null +++ b/proxy/src/cache/timed_lru.rs @@ -0,0 +1,272 @@ +#[cfg(test)] +mod tests; + +use super::{Cache, Cached, InvalidationToken, LookupInfo}; +use ref_cast::RefCast; +use std::{ + any::Any, + borrow::Borrow, + hash::Hash, + sync::Arc, + time::{Duration, Instant}, +}; +use tracing::debug; + +// This seems to make more sense than `lru` or `cached`: +// +// * `near/nearcore` ditched `cached` in favor of `lru` +// (https://github.com/near/nearcore/issues?q=is%3Aissue+lru+is%3Aclosed). +// +// * `lru` methods use an obscure `KeyRef` type in their contraints (which is deliberately excluded from docs). +// This severely hinders its usage both in terms of creating wrappers and supported key types. +// +// On the other hand, `hashlink` has good download stats and appears to be maintained. +use hashlink::{linked_hash_map::RawEntryMut, LruCache}; + +/// An implementation of timed LRU cache with fixed capacity. +/// Key properties: +/// +/// * Whenever a new entry is inserted, the least recently accessed one is evicted. 
+/// The cache also keeps track of entry's insertion time (`created_at`) and TTL (`expires_at`). +/// +/// * When the entry is about to be retrieved, we check its expiration timestamp. +/// If the entry has expired, we remove it from the cache; Otherwise we bump the +/// expiration timestamp (e.g. +5mins) and change its place in LRU list to prolong +/// its existence. +/// +/// * There's an API for immediate invalidation (removal) of a cache entry; +/// It's useful in case we know for sure that the entry is no longer correct. +/// See [`timed_lru::LookupInfo`] & [`timed_lru::Cached`] for more information. +/// +/// * Expired entries are kept in the cache, until they are evicted by the LRU policy, +/// or by a successful lookup (i.e. the entry hasn't expired yet). +/// There is no background job to reap the expired records. +/// +/// * It's possible for an entry that has not yet expired entry to be evicted +/// before expired items. That's a bit wasteful, but probably fine in practice. +pub struct TimedLru { + /// Cache's name for tracing. + name: &'static str, + + /// The underlying cache implementation. + cache: parking_lot::Mutex, Entry>>, + + /// Default time-to-live of a single entry. + ttl: Duration, +} + +#[derive(RefCast, Hash, PartialEq, Eq)] +#[repr(transparent)] +struct Query(Q); + +#[derive(Hash, PartialEq, Eq)] +#[repr(transparent)] +struct Key(Arc); + +/// It's impossible to implement this without [`Key`] & [`Query`]: +/// * We can't implement std traits for [`Arc`]. +/// * Even if we could, it'd conflict with `impl Borrow for T`. +impl Borrow> for Key +where + Q: ?Sized, + T: Borrow, +{ + #[inline(always)] + fn borrow(&self) -> &Query { + RefCast::ref_cast(self.0.as_ref().borrow()) + } +} + +struct Entry { + created_at: Instant, + expires_at: Instant, + value: Arc, +} + +impl TimedLru { + /// Construct a new LRU cache with timed entries. 
+ pub fn new(name: &'static str, capacity: usize, ttl: Duration) -> Self { + Self { + name, + cache: LruCache::new(capacity).into(), + ttl, + } + } + + /// Get the number of entries in the cache. + /// Note that this method will not try to evict stale entries. + pub fn size(&self) -> usize { + self.cache.lock().len() + } + + /// Try retrieving an entry by its key, then execute `extract` if it exists. + #[tracing::instrument(level = "debug", fields(cache = self.name), skip_all)] + fn get_raw(&self, key: &Q, extract: F) -> Option + where + K: Borrow, + Q: Hash + Eq + ?Sized, + F: FnOnce(&Arc, &Entry) -> R, + { + let key: &Query = RefCast::ref_cast(key); + let now = Instant::now(); + let deadline = now.checked_add(self.ttl).expect("time overflow"); + + // Do costly things before taking the lock. + let mut cache = self.cache.lock(); + let mut raw_entry = match cache.raw_entry_mut().from_key(key) { + RawEntryMut::Vacant(_) => return None, + RawEntryMut::Occupied(x) => x, + }; + + // Immeditely drop the entry if it has expired. + let entry = raw_entry.get(); + if entry.expires_at <= now { + raw_entry.remove(); + return None; + } + + let value = extract(&raw_entry.key().0, entry); + let (created_at, expires_at) = (entry.created_at, entry.expires_at); + + // Update the deadline and the entry's position in the LRU list. + raw_entry.get_mut().expires_at = deadline; + raw_entry.to_back(); + + drop(cache); // drop lock before logging + debug!( + ?created_at, + old_expires_at = ?expires_at, + new_expires_at = ?deadline, + "accessed a cache entry" + ); + + Some(value) + } + + /// Insert an entry to the cache. If an entry with the same key already + /// existed, return the previous value and its creation timestamp. 
+ #[tracing::instrument(level = "debug", fields(cache = self.name), skip_all)] + fn insert_raw(&self, key: Arc, value: Arc) -> (Option>, Instant) { + let created_at = Instant::now(); + let expires_at = created_at.checked_add(self.ttl).expect("time overflow"); + + let entry = Entry { + created_at, + expires_at, + value, + }; + + // Do costly things before taking the lock. + let old = self + .cache + .lock() + .insert(Key(key), entry) + .map(|entry| entry.value); + + debug!( + ?created_at, + ?expires_at, + replaced = old.is_some(), + "created a cache entry" + ); + + (old, created_at) + } +} + +/// Convenient wrappers for raw methods. +impl TimedLru +where + Self: Sync + Send, +{ + pub fn get(&self, key: &Q) -> Option> + where + K: Borrow, + Q: Hash + Eq + ?Sized, + { + self.get_raw(key, |key, entry| { + let info = LookupInfo { + created_at: entry.created_at, + key: Arc::clone(key), + }; + + Cached { + token: Some(Self::invalidation_token(self, info)), + value: entry.value.clone(), + } + }) + } + + pub fn insert(&self, key: Arc, value: Arc) -> (Option>, Cached) { + let (old, created_at) = self.insert_raw(key.clone(), value.clone()); + + let info = LookupInfo { created_at, key }; + let cached = Cached { + token: Some(Self::invalidation_token(self, info)), + value, + }; + + (old, cached) + } +} + +/// Implementation details of the entry invalidation machinery. +impl TimedLru +where + Self: Sync + Send, +{ + /// This is a proper (safe) way to create an invalidation token for [`TimedLru`]. + fn invalidation_token(&self, info: LookupInfo>) -> InvalidationToken<'_> { + InvalidationToken { + cache: self, + info: LookupInfo { + created_at: info.created_at, + key: info.key, + }, + } + } +} + +/// This implementation depends on [`Self::invalidation_token`]. +impl Cache for TimedLru { + fn invalidate_entry(&self, info: LookupInfo<&(dyn Any + Sync + Send)>) { + let info = LookupInfo { + created_at: info.created_at, + // NOTE: it's important to downcast to the correct type! 
+ key: info.key.downcast_ref::().expect("bad key type"), + }; + self.invalidate_raw(info) + } +} + +impl TimedLru { + #[tracing::instrument(level = "debug", fields(cache = self.name), skip_all)] + fn invalidate_raw(&self, info: LookupInfo<&K>) { + let key: &Query = RefCast::ref_cast(info.key); + let now = Instant::now(); + + // Do costly things before taking the lock. + let mut cache = self.cache.lock(); + let raw_entry = match cache.raw_entry_mut().from_key(key) { + RawEntryMut::Vacant(_) => return, + RawEntryMut::Occupied(x) => x, + }; + + // Remove the entry if it was created prior to lookup timestamp. + let entry = raw_entry.get(); + let (created_at, expires_at) = (entry.created_at, entry.expires_at); + let should_remove = created_at <= info.created_at || expires_at <= now; + + if should_remove { + raw_entry.remove(); + } + + drop(cache); // drop lock before logging + debug!( + ?created_at, + ?expires_at, + entry_removed = should_remove, + "processed a cache entry invalidation event" + ); + } +} diff --git a/proxy/src/cache/timed_lru/tests.rs b/proxy/src/cache/timed_lru/tests.rs new file mode 100644 index 0000000000..4980a8071b --- /dev/null +++ b/proxy/src/cache/timed_lru/tests.rs @@ -0,0 +1,104 @@ +use super::*; + +/// Check that we can define the cache for certain types. +#[test] +fn definition() { + // Check for trivial yet possible types. + let cache = TimedLru::<(), ()>::new("test", 128, Duration::from_secs(0)); + let _ = cache.insert(Default::default(), Default::default()); + let _ = cache.get(&()); + + // Now for something less trivial. + let cache = TimedLru::::new("test", 128, Duration::from_secs(0)); + let _ = cache.insert(Default::default(), Default::default()); + let _ = cache.get(&String::default()); + let _ = cache.get("str should work"); + + // It should also work for non-cloneable values. 
+ struct NoClone; + let cache = TimedLru::, NoClone>::new("test", 128, Duration::from_secs(0)); + let _ = cache.insert(Default::default(), NoClone.into()); + let _ = cache.get(&Box::::from("boxed str")); + let _ = cache.get("str should work"); +} + +#[test] +fn insert() { + const CAPACITY: usize = 2; + let cache = TimedLru::::new("test", CAPACITY, Duration::from_secs(10)); + assert_eq!(cache.size(), 0); + + let key = Arc::new(String::from("key")); + + let (old, cached) = cache.insert(key.clone(), 42.into()); + assert_eq!(old, None); + assert_eq!(*cached, 42); + assert_eq!(cache.size(), 1); + + let (old, cached) = cache.insert(key, 1.into()); + assert_eq!(old.as_deref(), Some(&42)); + assert_eq!(*cached, 1); + assert_eq!(cache.size(), 1); + + let (old, cached) = cache.insert(Arc::new("N1".to_owned()), 10.into()); + assert_eq!(old, None); + assert_eq!(*cached, 10); + assert_eq!(cache.size(), 2); + + let (old, cached) = cache.insert(Arc::new("N2".to_owned()), 20.into()); + assert_eq!(old, None); + assert_eq!(*cached, 20); + assert_eq!(cache.size(), 2); +} + +#[test] +fn get_none() { + let cache = TimedLru::::new("test", 2, Duration::from_secs(10)); + let cached = cache.get("missing"); + assert!(matches!(cached, None)); +} + +#[test] +fn invalidation_simple() { + let cache = TimedLru::::new("test", 2, Duration::from_secs(10)); + let (_, cached) = cache.insert(String::from("key").into(), 100.into()); + assert_eq!(cache.size(), 1); + + cached.invalidate(); + + assert_eq!(cache.size(), 0); + assert!(matches!(cache.get("key"), None)); +} + +#[test] +fn invalidation_preserve_newer() { + let cache = TimedLru::::new("test", 2, Duration::from_secs(10)); + let key = Arc::new(String::from("key")); + + let (_, cached) = cache.insert(key.clone(), 100.into()); + assert_eq!(cache.size(), 1); + let _ = cache.insert(key.clone(), 200.into()); + assert_eq!(cache.size(), 1); + cached.invalidate(); + assert_eq!(cache.size(), 1); + + let cached = cache.get(key.as_ref()); + 
assert_eq!(cached.as_deref(), Some(&200)); +} + +#[test] +fn auto_expiry() { + let lifetime = Duration::from_millis(300); + let cache = TimedLru::::new("test", 2, lifetime); + + let key = Arc::new(String::from("key")); + let _ = cache.insert(key.clone(), 42.into()); + + let cached = cache.get(key.as_ref()); + assert_eq!(cached.as_deref(), Some(&42)); + + std::thread::sleep(lifetime); + + let cached = cache.get(key.as_ref()); + assert_eq!(cached.as_deref(), None); +} diff --git a/proxy/src/compute.rs b/proxy/src/compute.rs index 480acb88d9..7bcf3120c0 100644 --- a/proxy/src/compute.rs +++ b/proxy/src/compute.rs @@ -1,13 +1,28 @@ -use crate::{auth::parse_endpoint_param, cancellation::CancelClosure, error::UserFacingError}; +use crate::{ + auth::parse_endpoint_param, + cancellation::CancelClosure, + console::messages::{DatabaseInfo, MetricsAuxInfo}, + console::CachedNodeInfo, + error::UserFacingError, +}; use futures::{FutureExt, TryFutureExt}; use itertools::Itertools; use pq_proto::StartupMessageParams; -use std::{io, net::SocketAddr, time::Duration}; +use std::{ + io, + net::SocketAddr, + sync::atomic::{AtomicBool, Ordering}, + time::Duration, +}; use thiserror::Error; use tokio::net::TcpStream; use tokio_postgres::tls::MakeTlsConnect; use tracing::{error, info, warn}; +/// Should we allow self-signed certificates in TLS connections? +/// Most definitely, this shouldn't be allowed in production. +pub static ALLOW_SELF_SIGNED_COMPUTE: AtomicBool = AtomicBool::new(false); + const COULD_NOT_CONNECT: &str = "Couldn't connect to compute node"; #[derive(Debug, Error)] @@ -42,6 +57,88 @@ impl UserFacingError for ConnectionError { /// A pair of `ClientKey` & `ServerKey` for `SCRAM-SHA-256`. pub type ScramKeys = tokio_postgres::config::ScramKeys<32>; +#[derive(Clone)] +pub enum Password { + /// A regular cleartext password. + ClearText(Vec), + /// A pair of `ClientKey` & `ServerKey` for `SCRAM-SHA-256`. 
+ ScramKeys(ScramKeys), +} + +pub enum ComputeNode { + /// Route via link auth. + Link(DatabaseInfo), + /// Regular compute node. + Static { + password: Password, + info: CachedNodeInfo, + }, +} + +impl ComputeNode { + /// Get metrics auxiliary info. + pub fn metrics_aux_info(&self) -> &MetricsAuxInfo { + match self { + Self::Link(info) => &info.aux, + Self::Static { info, .. } => &info.aux, + } + } + + /// Invalidate compute node info if it's cached. + pub fn invalidate(&self) -> bool { + if let Self::Static { info, .. } = self { + warn!("invalidating compute node info cache entry"); + info.invalidate(); + return true; + } + + false + } + + /// Turn compute node info into a postgres connection config. + pub fn to_conn_config(&self) -> ConnCfg { + let mut config = ConnCfg::new(); + + let (host, port) = match self { + Self::Link(info) => { + // NB: use pre-supplied dbname, user and password for link auth. + // See `ConnCfg::set_startup_params` below. + config.0.dbname(&info.dbname).user(&info.user); + if let Some(password) = &info.password { + config.0.password(password.as_bytes()); + } + + (&info.host, info.port) + } + Self::Static { info, password } => { + // NB: setup auth keys (for SCRAM) or plaintext password. + match password { + Password::ClearText(text) => config.0.password(text), + Password::ScramKeys(keys) => { + use tokio_postgres::config::AuthKeys; + config.0.auth_keys(AuthKeys::ScramSha256(keys.to_owned())) + } + }; + + (&info.address.host, info.address.port) + } + }; + + // Backwards compatibility. pg_sni_proxy uses "--" in domain names + // while direct connections do not. Once we migrate to pg_sni_proxy + // everywhere, we can remove this. + config.0.ssl_mode(if host.contains("--") { + // We need TLS connection with SNI info to properly route it. 
+ tokio_postgres::config::SslMode::Require + } else { + tokio_postgres::config::SslMode::Disable + }); + + config.0.host(host).port(port); + config + } +} + /// A config for establishing a connection to compute node. /// Eventually, `tokio_postgres` will be replaced with something better. /// Newtype allows us to implement methods on top of it. @@ -51,43 +148,32 @@ pub struct ConnCfg(Box); /// Creation and initialization routines. impl ConnCfg { - pub fn new() -> Self { + fn new() -> Self { Self(Default::default()) } - /// Reuse password or auth keys from the other config. - pub fn reuse_password(&mut self, other: &Self) { - if let Some(password) = other.get_password() { - self.password(password); - } - - if let Some(keys) = other.get_auth_keys() { - self.auth_keys(keys); - } - } - /// Apply startup message params to the connection config. pub fn set_startup_params(&mut self, params: &StartupMessageParams) { // Only set `user` if it's not present in the config. // Link auth flow takes username from the console's response. - if let (None, Some(user)) = (self.get_user(), params.get("user")) { - self.user(user); + if let (None, Some(user)) = (self.0.get_user(), params.get("user")) { + self.0.user(user); } // Only set `dbname` if it's not present in the config. // Link auth flow takes dbname from the console's response. - if let (None, Some(dbname)) = (self.get_dbname(), params.get("database")) { - self.dbname(dbname); + if let (None, Some(dbname)) = (self.0.get_dbname(), params.get("database")) { + self.0.dbname(dbname); } // Don't add `options` if they were only used for specifying a project. // Connection pools don't support `options`, because they affect backend startup. if let Some(options) = filtered_options(params) { - self.options(&options); + self.0.options(&options); } if let Some(app_name) = params.get("application_name") { - self.application_name(app_name); + self.0.application_name(app_name); } // TODO: This is especially ugly... 
@@ -95,10 +181,10 @@ impl ConnCfg { use tokio_postgres::config::ReplicationMode; match replication { "true" | "on" | "yes" | "1" => { - self.replication_mode(ReplicationMode::Physical); + self.0.replication_mode(ReplicationMode::Physical); } "database" => { - self.replication_mode(ReplicationMode::Logical); + self.0.replication_mode(ReplicationMode::Logical); } _other => {} } @@ -113,27 +199,6 @@ impl ConnCfg { } } -impl std::ops::Deref for ConnCfg { - type Target = tokio_postgres::Config; - - fn deref(&self) -> &Self::Target { - &self.0 - } -} - -/// For now, let's make it easier to setup the config. -impl std::ops::DerefMut for ConnCfg { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.0 - } -} - -impl Default for ConnCfg { - fn default() -> Self { - Self::new() - } -} - impl ConnCfg { /// Establish a raw TCP connection to the compute node. async fn connect_raw(&self) -> io::Result<(SocketAddr, TcpStream, &str)> { @@ -220,16 +285,13 @@ pub struct PostgresConnection { } impl ConnCfg { - async fn do_connect( - &self, - allow_self_signed_compute: bool, - ) -> Result { + async fn do_connect(&self) -> Result { let (socket_addr, stream, host) = self.connect_raw().await?; let tls_connector = native_tls::TlsConnector::builder() - .danger_accept_invalid_certs(allow_self_signed_compute) - .build() - .unwrap(); + .danger_accept_invalid_certs(ALLOW_SELF_SIGNED_COMPUTE.load(Ordering::Relaxed)) + .build()?; + let mut mk_tls = postgres_native_tls::MakeTlsConnector::new(tls_connector); let tls = MakeTlsConnect::::make_tls_connect(&mut mk_tls, host)?; @@ -261,11 +323,8 @@ impl ConnCfg { } /// Connect to a corresponding compute node. - pub async fn connect( - &self, - allow_self_signed_compute: bool, - ) -> Result { - self.do_connect(allow_self_signed_compute) + pub async fn connect(&self) -> Result { + self.do_connect() .inspect_err(|err| { // Immediately log the error we have at our disposal. 
error!("couldn't connect to compute node: {err}"); diff --git a/proxy/src/config.rs b/proxy/src/config.rs index 6a26cea78e..497a8916e5 100644 --- a/proxy/src/config.rs +++ b/proxy/src/config.rs @@ -12,7 +12,6 @@ pub struct ProxyConfig { pub tls_config: Option, pub auth_backend: auth::BackendType<'static, ()>, pub metric_collection: Option, - pub allow_self_signed_compute: bool, } #[derive(Debug)] diff --git a/proxy/src/console/messages.rs b/proxy/src/console/messages.rs index 0d321c077a..1ef64eca83 100644 --- a/proxy/src/console/messages.rs +++ b/proxy/src/console/messages.rs @@ -22,11 +22,37 @@ impl fmt::Debug for GetRoleSecret { } } +/// Represents compute node's host & port pair. +#[derive(Debug)] +pub struct PgEndpoint { + pub host: Box, + pub port: u16, +} + +impl<'de> Deserialize<'de> for PgEndpoint { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + use serde::de::Error; + + let raw = String::deserialize(deserializer)?; + let (host, port) = raw + .rsplit_once(':') + .ok_or_else(|| Error::custom(format!("bad compute address: {raw}")))?; + + Ok(PgEndpoint { + host: host.into(), + port: port.parse().map_err(Error::custom)?, + }) + } +} + /// Response which holds compute node's `host:port` pair. /// Returned by the `/proxy_wake_compute` API method. 
#[derive(Debug, Deserialize)] pub struct WakeCompute { - pub address: Box, + pub address: PgEndpoint, pub aux: MetricsAuxInfo, } @@ -187,4 +213,24 @@ mod tests { Ok(()) } + + #[test] + fn parse_wake_compute() -> anyhow::Result<()> { + let _: WakeCompute = serde_json::from_value(json!({ + "address": "127.0.0.1:5432", + "aux": dummy_aux(), + }))?; + + let _: WakeCompute = serde_json::from_value(json!({ + "address": "[::1]:5432", + "aux": dummy_aux(), + }))?; + + serde_json::from_value::(json!({ + "address": "localhost:5432", + "aux": dummy_aux(), + }))?; + + Ok(()) + } } diff --git a/proxy/src/console/provider.rs b/proxy/src/console/provider.rs index 44e23e0adf..ce84aa0b40 100644 --- a/proxy/src/console/provider.rs +++ b/proxy/src/console/provider.rs @@ -1,14 +1,13 @@ pub mod mock; pub mod neon; -use super::messages::MetricsAuxInfo; +use super::messages::{MetricsAuxInfo, PgEndpoint}; use crate::{ auth::ClientCredentials, - cache::{timed_lru, TimedLru}, - compute, scram, + cache::{types::Cached, TimedLru}, + scram, }; use async_trait::async_trait; -use std::sync::Arc; pub mod errors { use crate::{ @@ -112,11 +111,9 @@ pub mod errors { } } } + #[derive(Debug, Error)] pub enum WakeComputeError { - #[error("Console responded with a malformed compute address: {0}")] - BadComputeAddress(Box), - #[error(transparent)] ApiError(ApiError), } @@ -132,10 +129,7 @@ pub mod errors { fn to_string_client(&self) -> String { use WakeComputeError::*; match self { - // We shouldn't show user the address even if it's broken. - // Besides, user is unlikely to care about this detail. - BadComputeAddress(_) => REQUEST_FAILED.to_owned(), - // However, API might return a meaningful error. + // API might return a meaningful error. ApiError(e) => e.to_string_client(), } } @@ -160,23 +154,17 @@ pub enum AuthInfo { } /// Info for establishing a connection to a compute node. -/// This is what we get after auth succeeded, but not before! 
-#[derive(Clone)] +/// This struct is cached, so we shouldn't store any user creds here. pub struct NodeInfo { - /// Compute node connection params. - /// It's sad that we have to clone this, but this will improve - /// once we migrate to a bespoke connection logic. - pub config: compute::ConnCfg, + /// Address of the compute node. + pub address: PgEndpoint, /// Labels for proxy's metrics. - pub aux: Arc, - - /// Whether we should accept self-signed certificates (for testing) - pub allow_self_signed_compute: bool, + pub aux: MetricsAuxInfo, } -pub type NodeInfoCache = TimedLru, NodeInfo>; -pub type CachedNodeInfo = timed_lru::Cached<&'static NodeInfoCache>; +pub type NodeInfoCache = TimedLru, NodeInfo>; +pub type CachedNodeInfo = Cached; /// This will allocate per each call, but the http requests alone /// already require a few allocations, so it should be fine. diff --git a/proxy/src/console/provider/mock.rs b/proxy/src/console/provider/mock.rs index 3b42c73a34..03eac37b26 100644 --- a/proxy/src/console/provider/mock.rs +++ b/proxy/src/console/provider/mock.rs @@ -2,13 +2,12 @@ use super::{ errors::{ApiError, GetAuthInfoError, WakeComputeError}, - AuthInfo, CachedNodeInfo, ConsoleReqExtra, NodeInfo, + AuthInfo, CachedNodeInfo, ConsoleReqExtra, NodeInfo, PgEndpoint, }; -use crate::{auth::ClientCredentials, compute, error::io_error, scram, url::ApiUrl}; +use crate::{auth::ClientCredentials, error::io_error, scram, url::ApiUrl}; use async_trait::async_trait; use futures::TryFutureExt; use thiserror::Error; -use tokio_postgres::config::SslMode; use tracing::{error, info, info_span, warn, Instrument}; #[derive(Debug, Error)] @@ -84,19 +83,15 @@ impl Api { } async fn do_wake_compute(&self) -> Result { - let mut config = compute::ConnCfg::new(); - config - .host(self.endpoint.host_str().unwrap_or("localhost")) - .port(self.endpoint.port().unwrap_or(5432)) - .ssl_mode(SslMode::Disable); + let host = self.endpoint.host_str().unwrap_or("localhost").into(); + let port = 
self.endpoint.port().unwrap_or(5432); - let node = NodeInfo { - config, + let info = NodeInfo { + address: PgEndpoint { host, port }, aux: Default::default(), - allow_self_signed_compute: false, }; - Ok(node) + Ok(info) } } diff --git a/proxy/src/console/provider/neon.rs b/proxy/src/console/provider/neon.rs index a8e855b2c8..7cd2af9856 100644 --- a/proxy/src/console/provider/neon.rs +++ b/proxy/src/console/provider/neon.rs @@ -5,10 +5,9 @@ use super::{ errors::{ApiError, GetAuthInfoError, WakeComputeError}, ApiCaches, AuthInfo, CachedNodeInfo, ConsoleReqExtra, NodeInfo, }; -use crate::{auth::ClientCredentials, compute, http, scram}; +use crate::{auth::ClientCredentials, http, scram}; use async_trait::async_trait; use futures::TryFutureExt; -use tokio_postgres::config::SslMode; use tracing::{error, info, info_span, warn, Instrument}; #[derive(Clone)] @@ -91,22 +90,9 @@ impl Api { let response = self.endpoint.execute(request).await?; let body = parse_body::(response).await?; - // Unfortunately, ownership won't let us use `Option::ok_or` here. - let (host, port) = match parse_host_port(&body.address) { - None => return Err(WakeComputeError::BadComputeAddress(body.address)), - Some(x) => x, - }; - - // Don't set anything but host and port! This config will be cached. - // We'll set username and such later using the startup message. - // TODO: add more type safety (in progress). - let mut config = compute::ConnCfg::new(); - config.host(host).port(port).ssl_mode(SslMode::Disable); // TLS is not configured on compute nodes. - let node = NodeInfo { - config, - aux: body.aux.into(), - allow_self_signed_compute: false, + address: body.address, + aux: body.aux, }; Ok(node) @@ -141,13 +127,15 @@ impl super::Api for Api { // The connection info remains the same during that period of time, // which means that we might cache it to reduce the load and latency. 
if let Some(cached) = self.caches.node_info.get(key) { - info!(key = key, "found cached compute node info"); + info!(key, "found cached compute node info"); return Ok(cached); } - let node = self.do_wake_compute(extra, creds).await?; - let (_, cached) = self.caches.node_info.insert(key.into(), node); - info!(key = key, "created a cache entry for compute node info"); + let info = self.do_wake_compute(extra, creds).await?; + + let owned_key = Box::::from(key); + let (_, cached) = self.caches.node_info.insert(owned_key.into(), info.into()); + info!(key, "created a cache entry for compute node info"); Ok(cached) } @@ -177,20 +165,3 @@ async fn parse_body serde::Deserialize<'a>>( error!("console responded with an error ({status}): {text}"); Err(ApiError::Console { status, text }) } - -fn parse_host_port(input: &str) -> Option<(&str, u16)> { - let (host, port) = input.split_once(':')?; - Some((host, port.parse().ok()?)) -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_parse_host_port() { - let (host, port) = parse_host_port("127.0.0.1:5432").expect("failed to parse"); - assert_eq!(host, "127.0.0.1"); - assert_eq!(port, 5432); - } -} diff --git a/proxy/src/proxy.rs b/proxy/src/proxy.rs index cf2dd000db..3de25881b6 100644 --- a/proxy/src/proxy.rs +++ b/proxy/src/proxy.rs @@ -4,7 +4,7 @@ mod tests; use crate::{ auth::{self, backend::AuthSuccess}, cancellation::{self, CancelMap}, - compute::{self, PostgresConnection}, + compute::{self, ComputeNode, PostgresConnection}, config::{ProxyConfig, TlsConfig}, console::{self, messages::MetricsAuxInfo}, error::io_error, @@ -155,7 +155,7 @@ pub async fn handle_ws_client( async { result }.or_else(|e| stream.throw_error(e)).await? 
}; - let client = Client::new(stream, creds, &params, session_id, false); + let client = Client::new(stream, creds, &params, session_id); cancel_map .with_session(|session| client.connect_to_db(session, true)) .await @@ -194,15 +194,7 @@ async fn handle_client( async { result }.or_else(|e| stream.throw_error(e)).await? }; - let allow_self_signed_compute = config.allow_self_signed_compute; - - let client = Client::new( - stream, - creds, - &params, - session_id, - allow_self_signed_compute, - ); + let client = Client::new(stream, creds, &params, session_id); cancel_map .with_session(|session| client.connect_to_db(session, false)) .await @@ -283,61 +275,38 @@ async fn handshake( } } -/// Try to connect to the compute node once. -#[tracing::instrument(name = "connect_once", skip_all)] -async fn connect_to_compute_once( - node_info: &console::CachedNodeInfo, -) -> Result { - // If we couldn't connect, a cached connection info might be to blame - // (e.g. the compute node's address might've changed at the wrong time). - // Invalidate the cache entry (if any) to prevent subsequent errors. - let invalidate_cache = |_: &compute::ConnectionError| { - let is_cached = node_info.cached(); - if is_cached { - warn!("invalidating stalled compute node info cache entry"); - node_info.invalidate(); - } - - let label = match is_cached { - true => "compute_cached", - false => "compute_uncached", - }; - NUM_CONNECTION_FAILURES.with_label_values(&[label]).inc(); - }; - - let allow_self_signed_compute = node_info.allow_self_signed_compute; - - node_info - .config - .connect(allow_self_signed_compute) - .inspect_err(invalidate_cache) - .await -} - /// Try to connect to the compute node, retrying if necessary. /// This function might update `node_info`, so we take it by `&mut`.
#[tracing::instrument(skip_all)] async fn connect_to_compute( - node_info: &mut console::CachedNodeInfo, + node_info: &mut ComputeNode, params: &StartupMessageParams, extra: &console::ConsoleReqExtra<'_>, creds: &auth::BackendType<'_, auth::ClientCredentials<'_>>, ) -> Result { - let mut num_retries: usize = NUM_RETRIES_WAKE_COMPUTE; + let mut num_retries = NUM_RETRIES_WAKE_COMPUTE; loop { - // Apply startup params to the (possibly, cached) compute node info. - node_info.config.set_startup_params(params); - match connect_to_compute_once(node_info).await { + let mut config = node_info.to_conn_config(); + config.set_startup_params(params); + match config.connect().await { Err(e) if num_retries > 0 => { - info!("compute node's state has changed; requesting a wake-up"); - match creds.wake_compute(extra).map_err(io_error).await? { + let label = match node_info.invalidate() { + true => "compute_cached", + false => "compute_uncached", + }; + NUM_CONNECTION_FAILURES.with_label_values(&[label]).inc(); + + let res = creds.wake_compute(extra).map_err(io_error).await?; + match (res, &node_info) { // Update `node_info` and try one more time. - Some(mut new) => { - new.config.reuse_password(&node_info.config); - *node_info = new; + (Some(new), ComputeNode::Static { password, .. }) => { + *node_info = ComputeNode::Static { + password: password.to_owned(), + info: new, + } } // Link auth doesn't work that way, so we just exit. - None => return Err(e), + _ => return Err(e), } } other => return other, @@ -430,8 +399,6 @@ struct Client<'a, S> { params: &'a StartupMessageParams, /// Unique connection ID. session_id: uuid::Uuid, - /// Allow self-signed certificates (for testing). 
- allow_self_signed_compute: bool, } impl<'a, S> Client<'a, S> { @@ -441,14 +408,12 @@ impl<'a, S> Client<'a, S> { creds: auth::BackendType<'a, auth::ClientCredentials<'a>>, params: &'a StartupMessageParams, session_id: uuid::Uuid, - allow_self_signed_compute: bool, ) -> Self { Self { stream, creds, params, session_id, - allow_self_signed_compute, } } } @@ -468,7 +433,6 @@ impl Client<'_, S> { mut creds, params, session_id, - allow_self_signed_compute, } = self; let extra = console::ConsoleReqExtra { @@ -491,19 +455,19 @@ impl Client<'_, S> { value: mut node_info, } = auth_result; - node_info.allow_self_signed_compute = allow_self_signed_compute; - let mut node = connect_to_compute(&mut node_info, params, &extra, &creds) .or_else(|e| stream.throw_error(e)) .await?; prepare_client_connection(&node, reported_auth_ok, session, &mut stream).await?; + // Before proxy passing, forward to compute whatever data is left in the // PqStream input buffer. Normally there is none, but our serverless npm // driver in pipeline mode sends startup, password and first query // immediately after opening the connection. let (stream, read_buf) = stream.into_inner(); node.stream.write_all(&read_buf).await?; - proxy_pass(stream, node.stream, &node_info.aux).await + + proxy_pass(stream, node.stream, node_info.metrics_aux_info()).await } }