Compare commits


14 Commits

Author            SHA1        Message                                     Date
Vadim Kharitonov  8f49af45a7  Try to fix test                             2023-05-24 15:39:39 +02:00
Vadim Kharitonov  c7872d3f6d  Make clippy happy                           2023-05-24 14:18:32 +02:00
Vadim Kharitonov  c82b00110e  Fix SQL over HTTP endpoint                  2023-05-24 14:05:03 +02:00
Vadim Kharitonov  d425f2cb14  Fix ruff after rebase                       2023-05-24 13:56:57 +02:00
Dmitry Ivanov     2b0066f67c  Minor edits                                 2023-05-24 12:54:27 +02:00
Dmitry Ivanov     01be823084  Tune console notifications handler          2023-05-24 12:54:27 +02:00
Dmitry Ivanov     69bb1eee18  Update workspace_hack                       2023-05-24 12:54:27 +02:00
Dmitry Ivanov     0c040aca63  Fix mypy warnings                           2023-05-24 12:54:25 +02:00
Dmitry Ivanov     d6c7b4d994  Fix formatting                              2023-05-24 12:53:23 +02:00
Dmitry Ivanov     251e410add  Implement console notifications listener    2023-05-24 12:51:04 +02:00
Dmitry Ivanov     ce416c8160  Implement a cache for get_auth_info method  2023-05-24 12:51:04 +02:00
Dmitry Ivanov     f7e9ec49be  Add stub for auth info cache                2023-05-24 12:51:04 +02:00
Dmitry Ivanov     2b60ad0285  Test compute node cache invalidation        2023-05-24 12:51:00 +02:00
Dmitry Ivanov     9b99d4caa9  New cache impl                              2023-05-24 12:45:57 +02:00
24 changed files with 1166 additions and 619 deletions

Cargo.lock (generated)
View File

@@ -905,6 +905,20 @@ version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "acbf1af155f9b9ef647e42cdc158db4b64a1b61f743629225fde6f3e0be2a7c7"
[[package]]
name = "combine"
version = "4.6.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "35ed6e9d84f0b51a7f52daf1c7d71dd136fd7a3f41a8462b8cdb8c78d920fad4"
dependencies = [
"bytes",
"futures-core",
"memchr",
"pin-project-lite",
"tokio",
"tokio-util",
]
[[package]]
name = "comfy-table"
version = "6.1.4"
@@ -3104,6 +3118,8 @@ dependencies = [
"prometheus",
"rand",
"rcgen",
"redis",
"ref-cast",
"regex",
"reqwest",
"reqwest-middleware",
@@ -3210,6 +3226,29 @@ dependencies = [
"yasna",
]
[[package]]
name = "redis"
version = "0.23.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3ea8c51b5dc1d8e5fd3350ec8167f464ec0995e79f2e90a075b63371500d557f"
dependencies = [
"async-trait",
"bytes",
"combine",
"futures-util",
"itoa",
"percent-encoding",
"pin-project-lite",
"rustls 0.21.0",
"rustls-native-certs",
"ryu",
"sha1_smol",
"tokio",
"tokio-rustls 0.24.0",
"tokio-util",
"url",
]
[[package]]
name = "redox_syscall"
version = "0.2.16"
@@ -3228,6 +3267,26 @@ dependencies = [
"bitflags",
]
[[package]]
name = "ref-cast"
version = "1.0.16"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f43faa91b1c8b36841ee70e97188a869d37ae21759da6846d4be66de5bf7b12c"
dependencies = [
"ref-cast-impl",
]
[[package]]
name = "ref-cast-impl"
version = "1.0.16"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8d2275aab483050ab2a7364c1a46604865ee7d6906684e08db0f090acf74f9e7"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.15",
]
[[package]]
name = "regex"
version = "1.7.3"
@@ -3844,6 +3903,12 @@ dependencies = [
"digest",
]
[[package]]
name = "sha1_smol"
version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ae1a47186c03a32177042e55dbc5fd5aee900b8e0069a8d70fba96a9375cd012"
[[package]]
name = "sha2"
version = "0.10.6"
@@ -5326,6 +5391,7 @@ dependencies = [
"reqwest",
"ring",
"rustls 0.20.8",
"rustls 0.21.0",
"scopeguard",
"serde",
"serde_json",

View File

@@ -77,10 +77,12 @@ pin-project-lite = "0.2"
prometheus = {version = "0.13", default_features=false, features = ["process"]} # removes protobuf dependency
prost = "0.11"
rand = "0.8"
ref-cast = "1.0"
redis = { version = "0.23.0", features = ["tokio-rustls-comp"] }
regex = "1.4"
reqwest = { version = "0.11", default-features = false, features = ["rustls-tls"] }
reqwest-tracing = { version = "0.4.0", features = ["opentelemetry_0_18"] }
reqwest-middleware = "0.2.0"
reqwest-tracing = { version = "0.4.0", features = ["opentelemetry_0_18"] }
routerify = "3"
rpds = "0.13"
rustls = "0.20"

View File

@@ -35,7 +35,9 @@ postgres_backend.workspace = true
pq_proto.workspace = true
prometheus.workspace = true
rand.workspace = true
ref-cast.workspace = true
regex.workspace = true
redis.workspace = true
reqwest = { workspace = true, features = ["json"] }
reqwest-middleware.workspace = true
reqwest-tracing.workspace = true
@@ -50,9 +52,9 @@ socket2.workspace = true
sync_wrapper.workspace = true
thiserror.workspace = true
tls-listener.workspace = true
tokio = { workspace = true, features = ["signal"] }
tokio-postgres.workspace = true
tokio-rustls.workspace = true
tokio = { workspace = true, features = ["signal"] }
tracing-opentelemetry.workspace = true
tracing-subscriber.workspace = true
tracing-utils.workspace = true

View File

@@ -6,6 +6,7 @@ pub use link::LinkAuthError;
use crate::{
auth::{self, ClientCredentials},
compute::ComputeNode,
console::{
self,
provider::{CachedNodeInfo, ConsoleReqExtra},
@@ -114,7 +115,7 @@ async fn auth_quirks(
creds: &mut ClientCredentials<'_>,
client: &mut stream::PqStream<impl AsyncRead + AsyncWrite + Unpin>,
allow_cleartext: bool,
) -> auth::Result<AuthSuccess<CachedNodeInfo>> {
) -> auth::Result<AuthSuccess<ComputeNode>> {
// If there's no project so far, that entails that client doesn't
// support SNI or other means of passing the endpoint (project) name.
// We now expect to see a very specific payload in the place of password.
@@ -156,7 +157,7 @@ impl BackendType<'_, ClientCredentials<'_>> {
extra: &ConsoleReqExtra<'_>,
client: &mut stream::PqStream<impl AsyncRead + AsyncWrite + Unpin>,
allow_cleartext: bool,
) -> auth::Result<AuthSuccess<CachedNodeInfo>> {
) -> auth::Result<AuthSuccess<ComputeNode>> {
use BackendType::*;
let res = match self {
@@ -184,9 +185,7 @@ impl BackendType<'_, ClientCredentials<'_>> {
Link(url) => {
info!("performing link authentication");
link::authenticate(url, client)
.await?
.map(CachedNodeInfo::new_uncached)
link::authenticate(url, client).await?
}
};

View File

@@ -1,8 +1,8 @@
use super::AuthSuccess;
use crate::{
auth::{self, AuthFlow, ClientCredentials},
compute,
console::{self, AuthInfo, CachedNodeInfo, ConsoleReqExtra},
compute::{self, ComputeNode, Password},
console::{self, AuthInfo, CachedAuthInfo, ConsoleReqExtra},
sasl, scram,
stream::PqStream,
};
@@ -14,25 +14,26 @@ pub(super) async fn authenticate(
extra: &ConsoleReqExtra<'_>,
creds: &ClientCredentials<'_>,
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
) -> auth::Result<AuthSuccess<CachedNodeInfo>> {
) -> auth::Result<AuthSuccess<ComputeNode>> {
info!("fetching user's authentication info");
let info = api.get_auth_info(extra, creds).await?.unwrap_or_else(|| {
// If we don't have an authentication secret, we mock one to
// prevent malicious probing (possible due to missing protocol steps).
// This mocked secret will never lead to successful authentication.
info!("authentication info not found, mocking it");
AuthInfo::Scram(scram::ServerSecret::mock(creds.user, rand::random()))
let info = scram::ServerSecret::mock(creds.user, rand::random());
CachedAuthInfo::new_uncached(AuthInfo::Scram(info))
});
let flow = AuthFlow::new(client);
let scram_keys = match info {
let keys = match &*info {
AuthInfo::Md5(_) => {
info!("auth endpoint chooses MD5");
return Err(auth::AuthError::bad_auth_method("MD5"));
}
AuthInfo::Scram(secret) => {
info!("auth endpoint chooses SCRAM");
let scram = auth::Scram(&secret);
let scram = auth::Scram(secret);
let client_key = match flow.begin(scram).await?.authenticate().await? {
sasl::Outcome::Success(key) => key,
sasl::Outcome::Failure(reason) => {
@@ -41,21 +42,20 @@ pub(super) async fn authenticate(
}
};
Some(compute::ScramKeys {
compute::ScramKeys {
client_key: client_key.as_bytes(),
server_key: secret.server_key.as_bytes(),
})
}
}
};
let mut node = api.wake_compute(extra, creds).await?;
if let Some(keys) = scram_keys {
use tokio_postgres::config::AuthKeys;
node.config.auth_keys(AuthKeys::ScramSha256(keys));
}
let info = api.wake_compute(extra, creds).await?;
Ok(AuthSuccess {
reported_auth_ok: false,
value: node,
value: ComputeNode::Static {
password: Password::ScramKeys(keys),
info,
},
})
}

View File

@@ -1,10 +1,8 @@
use super::AuthSuccess;
use crate::{
auth::{self, AuthFlow, ClientCredentials},
console::{
self,
provider::{CachedNodeInfo, ConsoleReqExtra},
},
compute::{ComputeNode, Password},
console::{self, provider::ConsoleReqExtra},
stream,
};
use tokio::io::{AsyncRead, AsyncWrite};
@@ -19,7 +17,7 @@ pub async fn cleartext_hack(
extra: &ConsoleReqExtra<'_>,
creds: &mut ClientCredentials<'_>,
client: &mut stream::PqStream<impl AsyncRead + AsyncWrite + Unpin>,
) -> auth::Result<AuthSuccess<CachedNodeInfo>> {
) -> auth::Result<AuthSuccess<ComputeNode>> {
warn!("cleartext auth flow override is enabled, proceeding");
let password = AuthFlow::new(client)
.begin(auth::CleartextPassword)
@@ -27,13 +25,15 @@ pub async fn cleartext_hack(
.authenticate()
.await?;
let mut node = api.wake_compute(extra, creds).await?;
node.config.password(password);
let info = api.wake_compute(extra, creds).await?;
// Report tentative success; compute node will check the password anyway.
Ok(AuthSuccess {
reported_auth_ok: false,
value: node,
value: ComputeNode::Static {
password: Password::ClearText(password),
info,
},
})
}
@@ -44,7 +44,7 @@ pub async fn password_hack(
extra: &ConsoleReqExtra<'_>,
creds: &mut ClientCredentials<'_>,
client: &mut stream::PqStream<impl AsyncRead + AsyncWrite + Unpin>,
) -> auth::Result<AuthSuccess<CachedNodeInfo>> {
) -> auth::Result<AuthSuccess<ComputeNode>> {
warn!("project not specified, resorting to the password hack auth flow");
let payload = AuthFlow::new(client)
.begin(auth::PasswordHack)
@@ -55,12 +55,14 @@ pub async fn password_hack(
info!(project = &payload.endpoint, "received missing parameter");
creds.project = Some(payload.endpoint);
let mut node = api.wake_compute(extra, creds).await?;
node.config.password(payload.password);
let info = api.wake_compute(extra, creds).await?;
// Report tentative success; compute node will check the password anyway.
Ok(AuthSuccess {
reported_auth_ok: false,
value: node,
value: ComputeNode::Static {
password: Password::ClearText(payload.password),
info,
},
})
}

View File

@@ -1,15 +1,10 @@
use super::AuthSuccess;
use crate::{
auth, compute,
console::{self, provider::NodeInfo},
error::UserFacingError,
stream::PqStream,
waiters,
auth, compute::ComputeNode, console, error::UserFacingError, stream::PqStream, waiters,
};
use pq_proto::BeMessage as Be;
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_postgres::config::SslMode;
use tracing::{info, info_span};
#[derive(Debug, Error)]
@@ -57,12 +52,12 @@ pub fn new_psql_session_id() -> String {
pub(super) async fn authenticate(
link_uri: &reqwest::Url,
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
) -> auth::Result<AuthSuccess<NodeInfo>> {
) -> auth::Result<AuthSuccess<ComputeNode>> {
let psql_session_id = new_psql_session_id();
let span = info_span!("link", psql_session_id = &psql_session_id);
let span = info_span!("link", psql_session_id);
let greeting = hello_message(link_uri, &psql_session_id);
let db_info = console::mgmt::with_waiter(psql_session_id, |waiter| async {
let info = console::mgmt::with_waiter(psql_session_id, |waiter| async {
// Give user a URL to spawn a new database.
info!(parent: &span, "sending the auth URL to the user");
client
@@ -79,35 +74,8 @@ pub(super) async fn authenticate(
client.write_message_noflush(&Be::NoticeResponse("Connecting to database."))?;
// This config should be self-contained, because we won't
// take username or dbname from client's startup message.
let mut config = compute::ConnCfg::new();
config
.host(&db_info.host)
.port(db_info.port)
.dbname(&db_info.dbname)
.user(&db_info.user);
// Backwards compatibility. pg_sni_proxy uses "--" in domain names
// while direct connections do not. Once we migrate to pg_sni_proxy
// everywhere, we can remove this.
if db_info.host.contains("--") {
// we need TLS connection with SNI info to properly route it
config.ssl_mode(SslMode::Require);
} else {
config.ssl_mode(SslMode::Disable);
}
if let Some(password) = db_info.password {
config.password(password.as_ref());
}
Ok(AuthSuccess {
reported_auth_ok: true,
value: NodeInfo {
config,
aux: db_info.aux.into(),
allow_self_signed_compute: false, // caller may override
},
value: ComputeNode::Link(info),
})
}

View File

@@ -1,16 +1,15 @@
use proxy::auth;
use proxy::console;
use proxy::http;
use proxy::metrics;
use anyhow::bail;
use anyhow::{bail, Context};
use clap::{self, Arg};
use proxy::config::{self, ProxyConfig};
use std::{borrow::Cow, net::SocketAddr};
use futures::future::try_join_all;
use proxy::{
auth,
config::{self, MetricCollectionConfig, ProxyConfig, TlsConfig},
console, http, metrics,
};
use std::{borrow::Cow, net::SocketAddr, sync::atomic::Ordering};
use tokio::net::TcpListener;
use tokio_util::sync::CancellationToken;
use tracing::info;
use tracing::warn;
use tracing::{info, warn};
use utils::{project_git_version, sentry_init::init_sentry};
project_git_version!(GIT_VERSION);
@@ -25,8 +24,7 @@ async fn main() -> anyhow::Result<()> {
::metrics::set_build_info_metric(GIT_VERSION);
let args = cli().get_matches();
let config = build_config(&args)?;
let config: &ProxyConfig = Box::leak(Box::new(build_config(&args)?));
info!("Authentication backend: {}", config.auth_backend);
// Check that we can bind to address before further initialization
@@ -71,20 +69,29 @@ async fn main() -> anyhow::Result<()> {
tasks.push(tokio::spawn(metrics::task_main(metrics_config)));
}
let tasks = futures::future::try_join_all(tasks.into_iter().map(proxy::flatten_err));
let client_tasks =
futures::future::try_join_all(client_tasks.into_iter().map(proxy::flatten_err));
if let auth::BackendType::Console(api, _) = &config.auth_backend {
if let Some(url) = args.get_one::<String>("redis-notifications") {
info!("Starting redis notifications listener ({url})");
tasks.push(tokio::spawn(console::notifications::task_main(
url.to_owned(),
api.caches,
)));
}
}
let tasks = try_join_all(tasks.into_iter().map(proxy::flatten_err));
let client_tasks = try_join_all(client_tasks.into_iter().map(proxy::flatten_err));
tokio::select! {
// We are only expecting an error from these forever tasks
res = tasks => { res?; },
res = client_tasks => { res?; },
}
Ok(())
}
/// ProxyConfig is created at proxy startup, and lives forever.
fn build_config(args: &clap::ArgMatches) -> anyhow::Result<&'static ProxyConfig> {
let tls_config = match (
fn build_tls_config(args: &clap::ArgMatches) -> anyhow::Result<Option<TlsConfig>> {
let config = match (
args.get_one::<String>("tls-key"),
args.get_one::<String>("tls-cert"),
) {
@@ -101,16 +108,22 @@ fn build_config(args: &clap::ArgMatches) -> anyhow::Result<&'static ProxyConfig>
.get_one::<String>("allow-self-signed-compute")
.unwrap()
.parse()?;
if allow_self_signed_compute {
warn!("allowing self-signed compute certificates");
proxy::compute::ALLOW_SELF_SIGNED_COMPUTE.store(true, Ordering::Relaxed);
}
let metric_collection = match (
args.get_one::<String>("metric-collection-endpoint"),
args.get_one::<String>("metric-collection-interval"),
) {
Ok(config)
}
fn build_metrics_config(args: &clap::ArgMatches) -> anyhow::Result<Option<MetricCollectionConfig>> {
let endpoint = args.get_one::<String>("metric-collection-endpoint");
let interval = args.get_one::<String>("metric-collection-interval");
let config = match (endpoint, interval) {
(Some(endpoint), Some(interval)) => Some(config::MetricCollectionConfig {
endpoint: endpoint.parse()?,
endpoint: endpoint.parse().context("bad metrics endpoint")?,
interval: humantime::parse_duration(interval)?,
}),
(None, None) => None,
@@ -120,21 +133,40 @@ fn build_config(args: &clap::ArgMatches) -> anyhow::Result<&'static ProxyConfig>
),
};
let auth_backend = match args.get_one::<String>("auth-backend").unwrap().as_str() {
Ok(config)
}
fn make_caches(args: &clap::ArgMatches) -> anyhow::Result<console::caches::ApiCaches> {
let config::CacheOptions { size, ttl } = args
.get_one::<String>("get-auth-info-cache")
.unwrap()
.parse()?;
info!("Using AuthInfoCache (get_auth_info) with size={size} ttl={ttl:?}");
let auth_info = console::caches::AuthInfoCache::new("auth_info_cache", size, ttl);
let config::CacheOptions { size, ttl } = args
.get_one::<String>("wake-compute-cache")
.unwrap()
.parse()?;
info!("Using NodeInfoCache (wake_compute) with size={size} ttl={ttl:?}");
let node_info = console::caches::NodeInfoCache::new("node_info_cache", size, ttl);
let caches = console::caches::ApiCaches {
auth_info,
node_info,
};
Ok(caches)
}
fn build_auth_config(args: &clap::ArgMatches) -> anyhow::Result<auth::BackendType<'static, ()>> {
let config = match args.get_one::<String>("auth-backend").unwrap().as_str() {
"console" => {
let config::CacheOptions { size, ttl } = args
.get_one::<String>("wake-compute-cache")
.unwrap()
.parse()?;
info!("Using NodeInfoCache (wake_compute) with size={size} ttl={ttl:?}");
let caches = Box::leak(Box::new(console::caches::ApiCaches {
node_info: console::caches::NodeInfoCache::new("node_info_cache", size, ttl),
}));
let url = args.get_one::<String>("auth-endpoint").unwrap().parse()?;
let endpoint = http::Endpoint::new(url, http::new_client());
let caches = Box::leak(Box::new(make_caches(args)?));
let api = console::provider::neon::Api::new(endpoint, caches);
auth::BackendType::Console(Cow::Owned(api), ())
}
@@ -150,12 +182,16 @@ fn build_config(args: &clap::ArgMatches) -> anyhow::Result<&'static ProxyConfig>
other => bail!("unsupported auth backend: {other}"),
};
let config = Box::leak(Box::new(ProxyConfig {
tls_config,
auth_backend,
metric_collection,
allow_self_signed_compute,
}));
Ok(config)
}
/// ProxyConfig is created at proxy startup, and lives forever.
fn build_config(args: &clap::ArgMatches) -> anyhow::Result<ProxyConfig> {
let config = ProxyConfig {
tls_config: build_tls_config(args)?,
auth_backend: build_auth_config(args)?,
metric_collection: build_metrics_config(args)?,
};
Ok(config)
}
@@ -239,11 +275,22 @@ fn cli() -> clap::Command {
.long("metric-collection-interval")
.help("how often metrics should be sent to a collection endpoint"),
)
.arg(
Arg::new("redis-notifications")
.long("redis-notifications")
.help("for receiving notifications from console (e.g. redis://127.0.0.1:6379)"),
)
.arg(
Arg::new("get-auth-info-cache")
.long("get-auth-info-cache")
.help("cache for `get_auth_info` api method (use `size=0` to disable)")
.default_value(config::CacheOptions::DEFAULT_AUTH_INFO),
)
.arg(
Arg::new("wake-compute-cache")
.long("wake-compute-cache")
.help("cache for `wake_compute` api method (use `size=0` to disable)")
.default_value(config::CacheOptions::DEFAULT_OPTIONS_NODE_INFO),
.default_value(config::CacheOptions::DEFAULT_NODE_INFO),
)
.arg(
Arg::new("allow-self-signed-compute")

View File

@@ -1,304 +1,117 @@
use std::{
borrow::Borrow,
hash::Hash,
ops::{Deref, DerefMut},
time::{Duration, Instant},
};
use tracing::debug;
use std::{any::Any, sync::Arc, time::Instant};
// This seems to make more sense than `lru` or `cached`:
//
// * `near/nearcore` ditched `cached` in favor of `lru`
// (https://github.com/near/nearcore/issues?q=is%3Aissue+lru+is%3Aclosed).
//
// * `lru` methods use an obscure `KeyRef` type in their constraints (which is deliberately excluded from docs).
// This severely hinders its usage both in terms of creating wrappers and supported key types.
//
// On the other hand, `hashlink` has good download stats and appears to be maintained.
use hashlink::{linked_hash_map::RawEntryMut, LruCache};
/// A variant of LRU where every entry has a TTL.
pub mod timed_lru;
pub use timed_lru::TimedLru;
/// A generic trait which exposes types of cache's key and value,
/// as well as the notion of cache entry invalidation.
/// This is useful for [`timed_lru::Cached`].
pub trait Cache {
/// Entry's key.
type Key;
/// Useful type aliases.
pub mod types {
pub type Cached<T> = super::Cached<'static, T>;
}
/// Entry's value.
type Value;
/// Lookup information for cache entry invalidation.
#[derive(Clone)]
pub struct LookupInfo<K> {
/// Cache entry creation time.
/// We use this during invalidation lookups to prevent eviction of a newer
/// entry sharing the same key (it might've been inserted by a different
/// task after we got the entry we're trying to invalidate now).
created_at: Instant,
/// Used for entry invalidation.
type LookupInfo<Key>;
/// Search by this key.
key: K,
}
/// This type encapsulates everything needed for cache entry invalidation.
/// For convenience, we completely erase the types of a cache ref and a key.
/// This lets us store multiple tokens in a homogeneous collection (e.g. Vec).
#[derive(Clone)]
pub struct InvalidationToken<'a> {
// TODO: allow more than one type of references (e.g. Arc) if it's ever needed.
cache: &'a (dyn Cache + Sync + Send),
info: LookupInfo<Arc<dyn Any + Sync + Send>>,
}
impl InvalidationToken<'_> {
/// Invalidate a corresponding cache entry.
pub fn invalidate(&self) {
let info = LookupInfo {
created_at: self.info.created_at,
key: self.info.key.as_ref(),
};
self.cache.invalidate_entry(info);
}
}
/// A combination of a cache entry and its invalidation token.
/// Makes it easier to see how those two are connected.
#[derive(Clone)]
pub struct Cached<'a, T> {
pub token: Option<InvalidationToken<'a>>,
pub value: Arc<T>,
}
impl<T> Cached<'_, T> {
/// Place any entry into this wrapper; invalidation will be a no-op.
pub fn new_uncached(value: T) -> Self {
Self {
token: None,
value: value.into(),
}
}
/// Invalidate a corresponding cache entry.
pub fn invalidate(&self) {
if let Some(token) = &self.token {
token.invalidate();
}
}
}
impl<T> std::ops::Deref for Cached<'_, T> {
type Target = T;
fn deref(&self) -> &Self::Target {
&self.value
}
}
/// This trait captures the notion of cache entry invalidation.
/// It doesn't have any associated types because we use dyn-based type erasure.
trait Cache {
/// Invalidate an entry using a lookup info.
/// We don't have an empty default impl because it's error-prone.
fn invalidate(&self, _: &Self::LookupInfo<Self::Key>);
fn invalidate_entry(&self, info: LookupInfo<&(dyn Any + Send + Sync)>);
}
impl<C: Cache> Cache for &C {
type Key = C::Key;
type Value = C::Value;
type LookupInfo<Key> = C::LookupInfo<Key>;
fn invalidate(&self, info: &Self::LookupInfo<Self::Key>) {
C::invalidate(self, info)
}
}
pub use timed_lru::TimedLru;
pub mod timed_lru {
#[cfg(test)]
mod tests {
use super::*;
/// An implementation of timed LRU cache with fixed capacity.
/// Key properties:
///
/// * Whenever a new entry is inserted, the least recently accessed one is evicted.
/// The cache also keeps track of each entry's insertion time (`created_at`) and TTL (`expires_at`).
///
/// * When the entry is about to be retrieved, we check its expiration timestamp.
/// If the entry has expired, we remove it from the cache; Otherwise we bump the
/// expiration timestamp (e.g. +5mins) and change its place in LRU list to prolong
/// its existence.
///
/// * There's an API for immediate invalidation (removal) of a cache entry;
/// It's useful in case we know for sure that the entry is no longer correct.
/// See [`timed_lru::LookupInfo`] & [`timed_lru::Cached`] for more information.
///
/// * Expired entries are kept in the cache, until they are evicted by the LRU policy,
/// or by a successful lookup (i.e. the entry hasn't expired yet).
/// There is no background job to reap the expired records.
///
/// * It's possible for an entry that has not yet expired to be evicted
/// before expired items. That's a bit wasteful, but probably fine in practice.
pub struct TimedLru<K, V> {
/// Cache's name for tracing.
name: &'static str,
/// The underlying cache implementation.
cache: parking_lot::Mutex<LruCache<K, Entry<V>>>,
/// Default time-to-live of a single entry.
ttl: Duration,
#[test]
fn trivial_properties_of_cached() {
let cached = Cached::new_uncached(0);
assert_eq!(*cached, 0);
cached.invalidate();
}
impl<K: Hash + Eq, V> Cache for TimedLru<K, V> {
type Key = K;
type Value = V;
type LookupInfo<Key> = LookupInfo<Key>;
#[test]
fn invalidation_token_type_erasure() {
let lifetime = std::time::Duration::from_secs(10);
let foo = TimedLru::<u32, u32>::new("foo", 128, lifetime);
let bar = TimedLru::<String, usize>::new("bar", 128, lifetime);
fn invalidate(&self, info: &Self::LookupInfo<K>) {
self.invalidate_raw(info)
}
}
let (_, x) = foo.insert(100.into(), 0.into());
let (_, y) = bar.insert(String::new().into(), 404.into());
struct Entry<T> {
created_at: Instant,
expires_at: Instant,
value: T,
}
impl<K: Hash + Eq, V> TimedLru<K, V> {
/// Construct a new LRU cache with timed entries.
pub fn new(name: &'static str, capacity: usize, ttl: Duration) -> Self {
Self {
name,
cache: LruCache::new(capacity).into(),
ttl,
}
// Invalidation tokens should be cloneable and homogeneous (same type).
let tokens = [x.token.clone().unwrap(), y.token.clone().unwrap()];
for token in tokens {
token.invalidate();
}
/// Drop an entry from the cache if it's outdated.
#[tracing::instrument(level = "debug", fields(cache = self.name), skip_all)]
fn invalidate_raw(&self, info: &LookupInfo<K>) {
let now = Instant::now();
// Do costly things before taking the lock.
let mut cache = self.cache.lock();
let raw_entry = match cache.raw_entry_mut().from_key(&info.key) {
RawEntryMut::Vacant(_) => return,
RawEntryMut::Occupied(x) => x,
};
// Remove the entry if it was created prior to lookup timestamp.
let entry = raw_entry.get();
let (created_at, expires_at) = (entry.created_at, entry.expires_at);
let should_remove = created_at <= info.created_at || expires_at <= now;
if should_remove {
raw_entry.remove();
}
drop(cache); // drop lock before logging
debug!(
created_at = format_args!("{created_at:?}"),
expires_at = format_args!("{expires_at:?}"),
entry_removed = should_remove,
"processed a cache entry invalidation event"
);
}
/// Try retrieving an entry by its key, then execute `extract` if it exists.
#[tracing::instrument(level = "debug", fields(cache = self.name), skip_all)]
fn get_raw<Q, R>(&self, key: &Q, extract: impl FnOnce(&K, &Entry<V>) -> R) -> Option<R>
where
K: Borrow<Q>,
Q: Hash + Eq + ?Sized,
{
let now = Instant::now();
let deadline = now.checked_add(self.ttl).expect("time overflow");
// Do costly things before taking the lock.
let mut cache = self.cache.lock();
let mut raw_entry = match cache.raw_entry_mut().from_key(key) {
RawEntryMut::Vacant(_) => return None,
RawEntryMut::Occupied(x) => x,
};
// Immediately drop the entry if it has expired.
let entry = raw_entry.get();
if entry.expires_at <= now {
raw_entry.remove();
return None;
}
let value = extract(raw_entry.key(), entry);
let (created_at, expires_at) = (entry.created_at, entry.expires_at);
// Update the deadline and the entry's position in the LRU list.
raw_entry.get_mut().expires_at = deadline;
raw_entry.to_back();
drop(cache); // drop lock before logging
debug!(
created_at = format_args!("{created_at:?}"),
old_expires_at = format_args!("{expires_at:?}"),
new_expires_at = format_args!("{deadline:?}"),
"accessed a cache entry"
);
Some(value)
}
/// Insert an entry to the cache. If an entry with the same key already
/// existed, return the previous value and its creation timestamp.
#[tracing::instrument(level = "debug", fields(cache = self.name), skip_all)]
fn insert_raw(&self, key: K, value: V) -> (Instant, Option<V>) {
let created_at = Instant::now();
let expires_at = created_at.checked_add(self.ttl).expect("time overflow");
let entry = Entry {
created_at,
expires_at,
value,
};
// Do costly things before taking the lock.
let old = self
.cache
.lock()
.insert(key, entry)
.map(|entry| entry.value);
debug!(
created_at = format_args!("{created_at:?}"),
expires_at = format_args!("{expires_at:?}"),
replaced = old.is_some(),
"created a cache entry"
);
(created_at, old)
}
}
impl<K: Hash + Eq + Clone, V: Clone> TimedLru<K, V> {
pub fn insert(&self, key: K, value: V) -> (Option<V>, Cached<&Self>) {
let (created_at, old) = self.insert_raw(key.clone(), value.clone());
let cached = Cached {
token: Some((self, LookupInfo { created_at, key })),
value,
};
(old, cached)
}
}
impl<K: Hash + Eq, V: Clone> TimedLru<K, V> {
/// Retrieve a cached entry in convenient wrapper.
pub fn get<Q>(&self, key: &Q) -> Option<timed_lru::Cached<&Self>>
where
K: Borrow<Q> + Clone,
Q: Hash + Eq + ?Sized,
{
self.get_raw(key, |key, entry| {
let info = LookupInfo {
created_at: entry.created_at,
key: key.clone(),
};
Cached {
token: Some((self, info)),
value: entry.value.clone(),
}
})
}
}
/// Lookup information for key invalidation.
pub struct LookupInfo<K> {
/// Time of creation of a cache [`Entry`].
/// We use this during invalidation lookups to prevent eviction of a newer
/// entry sharing the same key (it might've been inserted by a different
/// task after we got the entry we're trying to invalidate now).
created_at: Instant,
/// Search by this key.
key: K,
}
/// Wrapper for convenient entry invalidation.
pub struct Cached<C: Cache> {
/// Cache + lookup info.
token: Option<(C, C::LookupInfo<C::Key>)>,
/// The value itself.
pub value: C::Value,
}
impl<C: Cache> Cached<C> {
/// Place any entry into this wrapper; invalidation will be a no-op.
/// Unfortunately, rust doesn't let us implement [`From`] or [`Into`].
pub fn new_uncached(value: impl Into<C::Value>) -> Self {
Self {
token: None,
value: value.into(),
}
}
/// Drop this entry from a cache if it's still there.
pub fn invalidate(&self) {
if let Some((cache, info)) = &self.token {
cache.invalidate(info);
}
}
/// Tell if this entry is actually cached.
pub fn cached(&self) -> bool {
self.token.is_some()
}
}
impl<C: Cache> Deref for Cached<C> {
type Target = C::Value;
fn deref(&self) -> &Self::Target {
&self.value
}
}
impl<C: Cache> DerefMut for Cached<C> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.value
}
// Values are still there.
assert_eq!(*x, 0);
assert_eq!(*y, 404);
}
}

proxy/src/cache/timed_lru.rs (new file)
View File

@@ -0,0 +1,293 @@
#[cfg(test)]
mod tests;
use super::{Cache, Cached, InvalidationToken, LookupInfo};
use ref_cast::RefCast;
use std::{
any::Any,
borrow::Borrow,
hash::Hash,
sync::Arc,
time::{Duration, Instant},
};
use tracing::debug;
// This seems to make more sense than `lru` or `cached`:
//
// * `near/nearcore` ditched `cached` in favor of `lru`
// (https://github.com/near/nearcore/issues?q=is%3Aissue+lru+is%3Aclosed).
//
// * `lru` methods use an obscure `KeyRef` type in their constraints (which is deliberately excluded from docs).
// This severely hinders its usage both in terms of creating wrappers and supported key types.
//
// On the other hand, `hashlink` has good download stats and appears to be maintained.
use hashlink::{linked_hash_map::RawEntryMut, LruCache};
/// An implementation of timed LRU cache with fixed capacity.
/// Key properties:
///
/// * Whenever a new entry is inserted, the least recently accessed one is evicted.
/// The cache also keeps track of each entry's insertion time (`created_at`) and TTL (`expires_at`).
///
/// * When the entry is about to be retrieved, we check its expiration timestamp.
/// If the entry has expired, we remove it from the cache; Otherwise we bump the
/// expiration timestamp (e.g. +5mins) and change its place in LRU list to prolong
/// its existence.
///
/// * There's an API for immediate invalidation (removal) of a cache entry;
/// It's useful in case we know for sure that the entry is no longer correct.
/// See [`timed_lru::LookupInfo`] & [`timed_lru::Cached`] for more information.
///
/// * Expired entries are kept in the cache, until they are evicted by the LRU policy,
/// or by a successful lookup (i.e. the entry hasn't expired yet).
/// There is no background job to reap the expired records.
///
/// * It's possible for an entry that has not yet expired to be evicted
/// before expired items. That's a bit wasteful, but probably fine in practice.
pub struct TimedLru<K, V> {
/// Cache's name for tracing.
name: &'static str,
/// The underlying cache implementation.
cache: parking_lot::Mutex<LruCache<Key<K>, Entry<V>>>,
/// Default time-to-live of a single entry.
ttl: Duration,
}
#[derive(RefCast, Hash, PartialEq, Eq)]
#[repr(transparent)]
struct Query<Q: ?Sized>(Q);
#[derive(Hash, PartialEq, Eq)]
#[repr(transparent)]
struct Key<T: ?Sized>(Arc<T>);
/// It's impossible to implement this without [`Key`] & [`Query`]:
/// * We can't implement std traits for [`Arc`].
/// * Even if we could, it'd conflict with `impl<T> Borrow<T> for T`.
impl<Q, T> Borrow<Query<Q>> for Key<T>
where
Q: ?Sized,
T: Borrow<Q>,
{
#[inline(always)]
fn borrow(&self) -> &Query<Q> {
RefCast::ref_cast(self.0.as_ref().borrow())
}
}
struct Entry<T> {
created_at: Instant,
expires_at: Instant,
value: Arc<T>,
}
impl<K: Hash + Eq, V> TimedLru<K, V> {
/// Construct a new LRU cache with timed entries.
pub fn new(name: &'static str, capacity: usize, ttl: Duration) -> Self {
Self {
name,
cache: LruCache::new(capacity).into(),
ttl,
}
}
/// Get the number of entries in the cache.
/// Note that this method will not try to evict stale entries.
pub fn size(&self) -> usize {
self.cache.lock().len()
}
/// Try retrieving an entry by its key, then execute `extract` if it exists.
#[tracing::instrument(level = "debug", fields(cache = self.name), skip_all)]
fn get_raw<Q, F, R>(&self, key: &Q, extract: F) -> Option<R>
where
K: Borrow<Q>,
Q: Hash + Eq + ?Sized,
F: FnOnce(&Arc<K>, &Entry<V>) -> R,
{
let key: &Query<Q> = RefCast::ref_cast(key);
let now = Instant::now();
let deadline = now.checked_add(self.ttl).expect("time overflow");
// Do costly things before taking the lock.
let mut cache = self.cache.lock();
let mut raw_entry = match cache.raw_entry_mut().from_key(key) {
RawEntryMut::Vacant(_) => return None,
RawEntryMut::Occupied(x) => x,
};
// Immediately drop the entry if it has expired.
let entry = raw_entry.get();
if entry.expires_at <= now {
raw_entry.remove();
return None;
}
let value = extract(&raw_entry.key().0, entry);
let (created_at, expires_at) = (entry.created_at, entry.expires_at);
// Update the deadline and the entry's position in the LRU list.
raw_entry.get_mut().expires_at = deadline;
raw_entry.to_back();
drop(cache); // drop lock before logging
debug!(
?created_at,
old_expires_at = ?expires_at,
new_expires_at = ?deadline,
"accessed a cache entry"
);
Some(value)
}
/// Insert an entry to the cache. If an entry with the same key already
/// existed, return the previous value and its creation timestamp.
#[tracing::instrument(level = "debug", fields(cache = self.name), skip_all)]
fn insert_raw(&self, key: Arc<K>, value: Arc<V>) -> (Option<Arc<V>>, Instant) {
let created_at = Instant::now();
let expires_at = created_at.checked_add(self.ttl).expect("time overflow");
let entry = Entry {
created_at,
expires_at,
value,
};
// Do costly things before taking the lock.
let old = self
.cache
.lock()
.insert(Key(key), entry)
.map(|entry| entry.value);
debug!(
?created_at,
?expires_at,
replaced = old.is_some(),
"created a cache entry"
);
(old, created_at)
}
#[tracing::instrument(level = "debug", fields(cache = self.name), skip_all)]
fn remove_raw<Q>(&self, key: &Q)
where
K: Borrow<Q>,
Q: Hash + Eq + ?Sized,
{
let key: &Query<Q> = RefCast::ref_cast(key);
let mut cache = self.cache.lock();
cache.remove(key);
debug!("removed a cache entry");
}
}
/// Convenient wrappers for raw methods.
impl<K: Hash + Eq + Sync + Send + 'static, V> TimedLru<K, V>
where
Self: Sync + Send,
{
pub fn get<Q>(&self, key: &Q) -> Option<Cached<V>>
where
K: Borrow<Q>,
Q: Hash + Eq + ?Sized,
{
self.get_raw(key, |key, entry| {
let info = LookupInfo {
created_at: entry.created_at,
key: Arc::clone(key),
};
Cached {
token: Some(Self::invalidation_token(self, info)),
value: entry.value.clone(),
}
})
}
pub fn insert(&self, key: Arc<K>, value: Arc<V>) -> (Option<Arc<V>>, Cached<V>) {
let (old, created_at) = self.insert_raw(key.clone(), value.clone());
let info = LookupInfo { created_at, key };
let cached = Cached {
token: Some(Self::invalidation_token(self, info)),
value,
};
(old, cached)
}
pub fn remove<Q>(&self, key: &Q)
where
K: Borrow<Q>,
Q: Hash + Eq + ?Sized,
{
self.remove_raw(key)
}
}
/// Implementation details of the entry invalidation machinery.
impl<K: Hash + Eq + Sync + Send + 'static, V> TimedLru<K, V>
where
Self: Sync + Send,
{
/// This is a proper (safe) way to create an invalidation token for [`TimedLru`].
fn invalidation_token(&self, info: LookupInfo<Arc<K>>) -> InvalidationToken<'_> {
InvalidationToken {
cache: self,
info: LookupInfo {
created_at: info.created_at,
key: info.key,
},
}
}
}
/// This implementation depends on [`Self::invalidation_token`].
impl<K: Hash + Eq + 'static, V> Cache for TimedLru<K, V> {
fn invalidate_entry(&self, info: LookupInfo<&(dyn Any + Sync + Send)>) {
let info = LookupInfo {
created_at: info.created_at,
// NOTE: it's important to downcast to the correct type!
key: info.key.downcast_ref::<K>().expect("bad key type"),
};
self.invalidate_raw(info)
}
}
impl<K: Hash + Eq, V> TimedLru<K, V> {
#[tracing::instrument(level = "debug", fields(cache = self.name), skip_all)]
fn invalidate_raw(&self, info: LookupInfo<&K>) {
let key: &Query<K> = RefCast::ref_cast(info.key);
let now = Instant::now();
// Do costly things before taking the lock.
let mut cache = self.cache.lock();
let raw_entry = match cache.raw_entry_mut().from_key(key) {
RawEntryMut::Vacant(_) => return,
RawEntryMut::Occupied(x) => x,
};
// Remove the entry if it was created prior to lookup timestamp.
let entry = raw_entry.get();
let (created_at, expires_at) = (entry.created_at, entry.expires_at);
let should_remove = created_at <= info.created_at || expires_at <= now;
if should_remove {
raw_entry.remove();
}
drop(cache); // drop lock before logging
debug!(
?created_at,
?expires_at,
entry_removed = should_remove,
"processed a cache entry invalidation event"
);
}
}
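
A note on the `Key`/`Query` machinery above: it lets a cache keyed by `Arc<K>` answer lookups for borrowed forms of `K`, e.g. `&str` against `Arc<String>` (the tests below rely on this). A minimal standalone sketch of the same pattern, with a plain `HashMap` standing in for `hashlink::LruCache`:

    use ref_cast::RefCast;
    use std::{borrow::Borrow, collections::HashMap, hash::Hash, sync::Arc};

    #[derive(RefCast, Hash, PartialEq, Eq)]
    #[repr(transparent)]
    struct Query<Q: ?Sized>(Q);

    #[derive(Hash, PartialEq, Eq)]
    struct Key<T: ?Sized>(Arc<T>);

    // `HashMap::get` accepts any `Q` such that `Key<T>: Borrow<Q>`, so this
    // impl lets an `Arc<String>` key be looked up by a plain `&str`.
    impl<Q: ?Sized, T: Borrow<Q>> Borrow<Query<Q>> for Key<T> {
        fn borrow(&self) -> &Query<Q> {
            RefCast::ref_cast(self.0.as_ref().borrow())
        }
    }

    fn lookup<'a>(map: &'a HashMap<Key<String>, u32>, key: &str) -> Option<&'a u32> {
        let q: &Query<str> = RefCast::ref_cast(key);
        map.get(q)
    }

The `#[repr(transparent)]` + `RefCast` combination is what makes the `&Q -> &Query<Q>` cast free and safe.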

proxy/src/cache/timed_lru/tests.rs (new file)
View File

@@ -0,0 +1,104 @@
use super::*;
/// Check that we can define the cache for certain types.
#[test]
fn definition() {
// Check for trivial yet possible types.
let cache = TimedLru::<(), ()>::new("test", 128, Duration::from_secs(0));
let _ = cache.insert(Default::default(), Default::default());
let _ = cache.get(&());
// Now for something less trivial.
let cache = TimedLru::<String, String>::new("test", 128, Duration::from_secs(0));
let _ = cache.insert(Default::default(), Default::default());
let _ = cache.get(&String::default());
let _ = cache.get("str should work");
// It should also work for non-cloneable values.
struct NoClone;
let cache = TimedLru::<Box<str>, NoClone>::new("test", 128, Duration::from_secs(0));
let _ = cache.insert(Default::default(), NoClone.into());
let _ = cache.get(&Box::<str>::from("boxed str"));
let _ = cache.get("str should work");
}
#[test]
fn insert() {
const CAPACITY: usize = 2;
let cache = TimedLru::<String, u32>::new("test", CAPACITY, Duration::from_secs(10));
assert_eq!(cache.size(), 0);
let key = Arc::new(String::from("key"));
let (old, cached) = cache.insert(key.clone(), 42.into());
assert_eq!(old, None);
assert_eq!(*cached, 42);
assert_eq!(cache.size(), 1);
let (old, cached) = cache.insert(key, 1.into());
assert_eq!(old.as_deref(), Some(&42));
assert_eq!(*cached, 1);
assert_eq!(cache.size(), 1);
let (old, cached) = cache.insert(Arc::new("N1".to_owned()), 10.into());
assert_eq!(old, None);
assert_eq!(*cached, 10);
assert_eq!(cache.size(), 2);
let (old, cached) = cache.insert(Arc::new("N2".to_owned()), 20.into());
assert_eq!(old, None);
assert_eq!(*cached, 20);
assert_eq!(cache.size(), 2);
}
#[test]
fn get_none() {
let cache = TimedLru::<String, u32>::new("test", 2, Duration::from_secs(10));
let cached = cache.get("missing");
assert!(matches!(cached, None));
}
#[test]
fn invalidation_simple() {
let cache = TimedLru::<String, u32>::new("test", 2, Duration::from_secs(10));
let (_, cached) = cache.insert(String::from("key").into(), 100.into());
assert_eq!(cache.size(), 1);
cached.invalidate();
assert_eq!(cache.size(), 0);
assert!(matches!(cache.get("key"), None));
}
#[test]
fn invalidation_preserve_newer() {
let cache = TimedLru::<String, u32>::new("test", 2, Duration::from_secs(10));
let key = Arc::new(String::from("key"));
let (_, cached) = cache.insert(key.clone(), 100.into());
assert_eq!(cache.size(), 1);
let _ = cache.insert(key.clone(), 200.into());
assert_eq!(cache.size(), 1);
cached.invalidate();
assert_eq!(cache.size(), 1);
let cached = cache.get(key.as_ref());
assert_eq!(cached.as_deref(), Some(&200));
}
#[test]
fn auto_expiry() {
let lifetime = Duration::from_millis(300);
let cache = TimedLru::<String, u32>::new("test", 2, lifetime);
let key = Arc::new(String::from("key"));
let _ = cache.insert(key.clone(), 42.into());
let cached = cache.get(key.as_ref());
assert_eq!(cached.as_deref(), Some(&42));
std::thread::sleep(lifetime);
let cached = cache.get(key.as_ref());
assert_eq!(cached.as_deref(), None);
}

View File

@@ -1,13 +1,28 @@
use crate::{auth::parse_endpoint_param, cancellation::CancelClosure, error::UserFacingError};
use crate::{
auth::parse_endpoint_param,
cancellation::CancelClosure,
console::messages::{DatabaseInfo, MetricsAuxInfo},
console::CachedNodeInfo,
error::UserFacingError,
};
use futures::{FutureExt, TryFutureExt};
use itertools::Itertools;
use pq_proto::StartupMessageParams;
use std::{io, net::SocketAddr, time::Duration};
use std::{
io,
net::SocketAddr,
sync::atomic::{AtomicBool, Ordering},
time::Duration,
};
use thiserror::Error;
use tokio::net::TcpStream;
use tokio_postgres::tls::MakeTlsConnect;
use tracing::{error, info, warn};
/// Should we allow self-signed certificates in TLS connections?
/// Most definitely, this shouldn't be allowed in production.
pub static ALLOW_SELF_SIGNED_COMPUTE: AtomicBool = AtomicBool::new(false);
const COULD_NOT_CONNECT: &str = "Couldn't connect to compute node";
#[derive(Debug, Error)]
@@ -42,6 +57,88 @@ impl UserFacingError for ConnectionError {
/// A pair of `ClientKey` & `ServerKey` for `SCRAM-SHA-256`.
pub type ScramKeys = tokio_postgres::config::ScramKeys<32>;
#[derive(Clone)]
pub enum Password {
/// A regular cleartext password.
ClearText(Vec<u8>),
/// A pair of `ClientKey` & `ServerKey` for `SCRAM-SHA-256`.
ScramKeys(ScramKeys),
}
pub enum ComputeNode {
/// Route via link auth.
Link(DatabaseInfo),
/// Regular compute node.
Static {
password: Password,
info: CachedNodeInfo,
},
}
impl ComputeNode {
/// Get metrics auxiliary info.
pub fn metrics_aux_info(&self) -> &MetricsAuxInfo {
match self {
Self::Link(info) => &info.aux,
Self::Static { info, .. } => &info.aux,
}
}
/// Invalidate compute node info if it's cached.
pub fn invalidate(&self) -> bool {
if let Self::Static { info, .. } = self {
warn!("invalidating compute node info cache entry");
info.invalidate();
return true;
}
false
}
/// Turn compute node info into a postgres connection config.
pub fn to_conn_config(&self) -> ConnCfg {
let mut config = ConnCfg::new();
let (host, port) = match self {
Self::Link(info) => {
// NB: use pre-supplied dbname, user and password for link auth.
// See `ConnCfg::set_startup_params` below.
config.0.dbname(&info.dbname).user(&info.user);
if let Some(password) = &info.password {
config.0.password(password.as_bytes());
}
(&info.host, info.port)
}
Self::Static { info, password } => {
// NB: setup auth keys (for SCRAM) or plaintext password.
match password {
Password::ClearText(text) => config.0.password(text),
Password::ScramKeys(keys) => {
use tokio_postgres::config::AuthKeys;
config.0.auth_keys(AuthKeys::ScramSha256(keys.to_owned()))
}
};
(&info.address.host, info.address.port)
}
};
// Backwards compatibility. pg_sni_proxy uses "--" in domain names
// while direct connections do not. Once we migrate to pg_sni_proxy
// everywhere, we can remove this.
config.0.ssl_mode(if host.contains("--") {
// We need TLS connection with SNI info to properly route it.
tokio_postgres::config::SslMode::Require
} else {
tokio_postgres::config::SslMode::Disable
});
config.0.host(host).port(port);
config
}
}
/// A config for establishing a connection to compute node.
/// Eventually, `tokio_postgres` will be replaced with something better.
/// Newtype allows us to implement methods on top of it.
@@ -51,43 +148,32 @@ pub struct ConnCfg(Box<tokio_postgres::Config>);
/// Creation and initialization routines.
impl ConnCfg {
pub fn new() -> Self {
fn new() -> Self {
Self(Default::default())
}
/// Reuse password or auth keys from the other config.
pub fn reuse_password(&mut self, other: &Self) {
if let Some(password) = other.get_password() {
self.password(password);
}
if let Some(keys) = other.get_auth_keys() {
self.auth_keys(keys);
}
}
/// Apply startup message params to the connection config.
pub fn set_startup_params(&mut self, params: &StartupMessageParams) {
// Only set `user` if it's not present in the config.
// Link auth flow takes username from the console's response.
if let (None, Some(user)) = (self.get_user(), params.get("user")) {
self.user(user);
if let (None, Some(user)) = (self.0.get_user(), params.get("user")) {
self.0.user(user);
}
// Only set `dbname` if it's not present in the config.
// Link auth flow takes dbname from the console's response.
if let (None, Some(dbname)) = (self.get_dbname(), params.get("database")) {
self.dbname(dbname);
if let (None, Some(dbname)) = (self.0.get_dbname(), params.get("database")) {
self.0.dbname(dbname);
}
// Don't add `options` if they were only used for specifying a project.
// Connection pools don't support `options`, because they affect backend startup.
if let Some(options) = filtered_options(params) {
self.options(&options);
self.0.options(&options);
}
if let Some(app_name) = params.get("application_name") {
self.application_name(app_name);
self.0.application_name(app_name);
}
// TODO: This is especially ugly...
@@ -95,10 +181,10 @@ impl ConnCfg {
use tokio_postgres::config::ReplicationMode;
match replication {
"true" | "on" | "yes" | "1" => {
self.replication_mode(ReplicationMode::Physical);
self.0.replication_mode(ReplicationMode::Physical);
}
"database" => {
self.replication_mode(ReplicationMode::Logical);
self.0.replication_mode(ReplicationMode::Logical);
}
_other => {}
}
@@ -113,27 +199,6 @@ impl ConnCfg {
}
}
impl std::ops::Deref for ConnCfg {
type Target = tokio_postgres::Config;
fn deref(&self) -> &Self::Target {
&self.0
}
}
/// For now, let's make it easier to setup the config.
impl std::ops::DerefMut for ConnCfg {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}
impl Default for ConnCfg {
fn default() -> Self {
Self::new()
}
}
impl ConnCfg {
/// Establish a raw TCP connection to the compute node.
async fn connect_raw(&self) -> io::Result<(SocketAddr, TcpStream, &str)> {
@@ -220,16 +285,13 @@ pub struct PostgresConnection {
}
impl ConnCfg {
async fn do_connect(
&self,
allow_self_signed_compute: bool,
) -> Result<PostgresConnection, ConnectionError> {
async fn do_connect(&self) -> Result<PostgresConnection, ConnectionError> {
let (socket_addr, stream, host) = self.connect_raw().await?;
let tls_connector = native_tls::TlsConnector::builder()
.danger_accept_invalid_certs(allow_self_signed_compute)
.build()
.unwrap();
.danger_accept_invalid_certs(ALLOW_SELF_SIGNED_COMPUTE.load(Ordering::Relaxed))
.build()?;
let mut mk_tls = postgres_native_tls::MakeTlsConnector::new(tls_connector);
let tls = MakeTlsConnect::<tokio::net::TcpStream>::make_tls_connect(&mut mk_tls, host)?;
@@ -261,11 +323,8 @@ impl ConnCfg {
}
/// Connect to a corresponding compute node.
pub async fn connect(
&self,
allow_self_signed_compute: bool,
) -> Result<PostgresConnection, ConnectionError> {
self.do_connect(allow_self_signed_compute)
pub async fn connect(&self) -> Result<PostgresConnection, ConnectionError> {
self.do_connect()
.inspect_err(|err| {
// Immediately log the error we have at our disposal.
error!("couldn't connect to compute node: {err}");
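
One detail worth a standalone illustration is the `--` routing rule in `to_conn_config` above. A minimal sketch of just that decision (the host names in the trailing comment are made up):

    use tokio_postgres::config::SslMode;

    /// Mirrors the backwards-compatibility rule above: hosts routed through
    /// pg_sni_proxy contain "--" and need TLS with SNI info for routing;
    /// everything else connects without TLS.
    fn ssl_mode_for(host: &str) -> SslMode {
        if host.contains("--") {
            SslMode::Require
        } else {
            SslMode::Disable
        }
    }

    // ssl_mode_for("ep-example--pooler.example.com") -> SslMode::Require
    // ssl_mode_for("10.0.0.1")                       -> SslMode::Disable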

View File

@@ -12,7 +12,6 @@ pub struct ProxyConfig {
pub tls_config: Option<TlsConfig>,
pub auth_backend: auth::BackendType<'static, ()>,
pub metric_collection: Option<MetricCollectionConfig>,
pub allow_self_signed_compute: bool,
}
#[derive(Debug)]
@@ -211,8 +210,11 @@ pub struct CacheOptions {
}
impl CacheOptions {
/// Default options for [`crate::console::caches::AuthInfoCache`].
pub const DEFAULT_AUTH_INFO: &str = "size=4000,ttl=5s";
/// Default options for [`crate::console::caches::NodeInfoCache`].
pub const DEFAULT_OPTIONS_NODE_INFO: &str = "size=4000,ttl=5m";
pub const DEFAULT_NODE_INFO: &str = "size=4000,ttl=5m";
/// Parse cache options passed via cmdline.
/// Example: [`Self::DEFAULT_NODE_INFO`].
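
The same `size=...,ttl=...` strings can be parsed programmatically. A minimal sketch, assuming `CacheOptions` keeps its `FromStr` impl and public `size`/`ttl` fields (both visible in `make_caches` above), with the `ttl` duration written in humantime syntax:

    use proxy::config::CacheOptions;

    fn cache_options_example() -> anyhow::Result<()> {
        // Same destructuring pattern as `make_caches` in proxy.rs.
        let CacheOptions { size, ttl } = "size=4000,ttl=5s".parse()?;
        println!("size={size} ttl={ttl:?}");
        Ok(())
    }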

View File

@@ -6,12 +6,17 @@ pub mod messages;
/// Wrappers for console APIs and their mocks.
pub mod provider;
pub use provider::{errors, Api, AuthInfo, CachedNodeInfo, ConsoleReqExtra, NodeInfo};
pub use provider::{errors, Api, ConsoleReqExtra};
pub use provider::{AuthInfo, NodeInfo};
pub use provider::{CachedAuthInfo, CachedNodeInfo};
/// Various cache-related types.
pub mod caches {
pub use super::provider::{ApiCaches, NodeInfoCache};
pub use super::provider::{ApiCaches, AuthInfoCache, AuthInfoCacheKey, NodeInfoCache};
}
/// Console's management API.
pub mod mgmt;
/// Console's notification bus.
pub mod notifications;

View File

@@ -22,11 +22,37 @@ impl fmt::Debug for GetRoleSecret {
}
}
/// Represents compute node's host & port pair.
#[derive(Debug)]
pub struct PgEndpoint {
pub host: Box<str>,
pub port: u16,
}
impl<'de> Deserialize<'de> for PgEndpoint {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
use serde::de::Error;
let raw = String::deserialize(deserializer)?;
let (host, port) = raw
.rsplit_once(':')
.ok_or_else(|| Error::custom(format!("bad compute address: {raw}")))?;
Ok(PgEndpoint {
host: host.into(),
port: port.parse().map_err(Error::custom)?,
})
}
}
/// Response which holds compute node's `host:port` pair.
/// Returned by the `/proxy_wake_compute` API method.
#[derive(Debug, Deserialize)]
pub struct WakeCompute {
pub address: Box<str>,
pub address: PgEndpoint,
pub aux: MetricsAuxInfo,
}
@@ -187,4 +213,24 @@ mod tests {
Ok(())
}
#[test]
fn parse_wake_compute() -> anyhow::Result<()> {
let _: WakeCompute = serde_json::from_value(json!({
"address": "127.0.0.1:5432",
"aux": dummy_aux(),
}))?;
let _: WakeCompute = serde_json::from_value(json!({
"address": "[::1]:5432",
"aux": dummy_aux(),
}))?;
serde_json::from_value::<WakeCompute>(json!({
"address": "localhost:5432",
"aux": dummy_aux(),
}))?;
Ok(())
}
}

View File

@@ -0,0 +1,70 @@
use crate::console::caches::{ApiCaches, AuthInfoCacheKey};
use futures::StreamExt;
use serde::Deserialize;
const CHANNEL_NAME: &str = "proxy_notifications";
#[derive(Debug, Deserialize)]
#[serde(tag = "type")]
enum Notification<'a> {
#[serde(rename = "password_changed")]
PasswordChanged { project: &'a str, role: &'a str },
}
#[tracing::instrument(skip(caches))]
fn handle_message(msg: redis::Msg, caches: &ApiCaches) -> anyhow::Result<()> {
let payload: String = msg.get_payload()?;
use Notification::*;
match serde_json::from_str(&payload) {
Ok(PasswordChanged { project, role }) => {
let key = AuthInfoCacheKey {
project: project.into(),
role: role.into(),
};
tracing::info!(key = ?key, "invalidating auth info");
caches.auth_info.remove(&key);
}
Err(e) => tracing::error!("broken message: {e}"),
}
Ok(())
}
/// Handle console's invalidation messages.
#[tracing::instrument(name = "console_notifications", skip_all)]
pub async fn task_main(url: String, caches: &ApiCaches) -> anyhow::Result<()> {
let client = redis::Client::open(url.as_ref())?;
let mut conn = client.get_async_connection().await?.into_pubsub();
tracing::info!("subscribing to a channel `{CHANNEL_NAME}`");
conn.subscribe(CHANNEL_NAME).await?;
let mut stream = conn.on_message();
while let Some(msg) = stream.next().await {
handle_message(msg, caches)?;
}
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
#[test]
fn parse_notification() -> anyhow::Result<()> {
let text = json!({
"type": "password_changed",
"project": "very-nice",
"role": "borat",
})
.to_string();
let _: Notification = serde_json::from_str(&text)?;
Ok(())
}
}
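
For completeness, the publishing side: the console is expected to push the same JSON shape to the `proxy_notifications` channel. A hedged sketch using the blocking API of the same `redis` crate (the function name and URL are illustrative, not part of this PR):

    fn notify_password_changed(url: &str, project: &str, role: &str) -> anyhow::Result<()> {
        use redis::Commands;
        let payload = serde_json::json!({
            "type": "password_changed",
            "project": project,
            "role": role,
        })
        .to_string();
        // `task_main` above subscribes to this channel and removes the
        // matching `AuthInfoCacheKey` from the auth info cache.
        let client = redis::Client::open(url)?;
        let mut conn = client.get_connection()?;
        conn.publish::<_, _, ()>("proxy_notifications", payload)?;
        Ok(())
    }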

View File

@@ -1,14 +1,13 @@
pub mod mock;
pub mod neon;
use super::messages::MetricsAuxInfo;
use super::messages::{MetricsAuxInfo, PgEndpoint};
use crate::{
auth::ClientCredentials,
cache::{timed_lru, TimedLru},
compute, scram,
cache::{types::Cached, TimedLru},
scram,
};
use async_trait::async_trait;
use std::sync::Arc;
pub mod errors {
use crate::{
@@ -112,11 +111,9 @@ pub mod errors {
}
}
}
#[derive(Debug, Error)]
pub enum WakeComputeError {
#[error("Console responded with a malformed compute address: {0}")]
BadComputeAddress(Box<str>),
#[error(transparent)]
ApiError(ApiError),
}
@@ -132,10 +129,7 @@ pub mod errors {
fn to_string_client(&self) -> String {
use WakeComputeError::*;
match self {
// We shouldn't show user the address even if it's broken.
// Besides, user is unlikely to care about this detail.
BadComputeAddress(_) => REQUEST_FAILED.to_owned(),
// However, API might return a meaningful error.
// API might return a meaningful error.
ApiError(e) => e.to_string_client(),
}
}
@@ -160,23 +154,26 @@ pub enum AuthInfo {
}
/// Info for establishing a connection to a compute node.
/// This is what we get after auth succeeded, but not before!
#[derive(Clone)]
/// This struct is cached, so we shouldn't store any user creds here.
pub struct NodeInfo {
/// Compute node connection params.
/// It's sad that we have to clone this, but this will improve
/// once we migrate to a bespoke connection logic.
pub config: compute::ConnCfg,
/// Address of the compute node.
pub address: PgEndpoint,
/// Labels for proxy's metrics.
pub aux: Arc<MetricsAuxInfo>,
/// Whether we should accept self-signed certificates (for testing)
pub allow_self_signed_compute: bool,
pub aux: MetricsAuxInfo,
}
pub type NodeInfoCache = TimedLru<Arc<str>, NodeInfo>;
pub type CachedNodeInfo = timed_lru::Cached<&'static NodeInfoCache>;
pub type NodeInfoCache = TimedLru<Box<str>, NodeInfo>;
pub type CachedNodeInfo = Cached<NodeInfo>;
#[derive(Debug, Hash, PartialEq, Eq)]
pub struct AuthInfoCacheKey {
pub project: Box<str>,
pub role: Box<str>,
}
pub type AuthInfoCache = TimedLru<AuthInfoCacheKey, AuthInfo>;
pub type CachedAuthInfo = Cached<AuthInfo>;
/// This will allocate per each call, but the http requests alone
/// already require a few allocations, so it should be fine.
@@ -187,7 +184,7 @@ pub trait Api {
&self,
extra: &ConsoleReqExtra<'_>,
creds: &ClientCredentials<'_>,
) -> Result<Option<AuthInfo>, errors::GetAuthInfoError>;
) -> Result<Option<CachedAuthInfo>, errors::GetAuthInfoError>;
/// Wake up the compute node and return the corresponding connection info.
async fn wake_compute(
@@ -199,6 +196,8 @@ pub trait Api {
/// Various caches for [`console`].
pub struct ApiCaches {
/// Cache for the `wake_compute` API method.
/// Cache for the `get_auth_info` method.
pub auth_info: AuthInfoCache,
/// Cache for the `wake_compute` method.
pub node_info: NodeInfoCache,
}

View File

@@ -2,13 +2,14 @@
use super::{
errors::{ApiError, GetAuthInfoError, WakeComputeError},
AuthInfo, CachedNodeInfo, ConsoleReqExtra, NodeInfo,
ConsoleReqExtra, PgEndpoint,
};
use crate::{auth::ClientCredentials, compute, error::io_error, scram, url::ApiUrl};
use super::{AuthInfo, NodeInfo};
use super::{CachedAuthInfo, CachedNodeInfo};
use crate::{auth::ClientCredentials, error::io_error, scram, url::ApiUrl};
use async_trait::async_trait;
use futures::TryFutureExt;
use thiserror::Error;
use tokio_postgres::config::SslMode;
use tracing::{error, info, info_span, warn, Instrument};
#[derive(Debug, Error)]
@@ -84,19 +85,15 @@ impl Api {
}
async fn do_wake_compute(&self) -> Result<NodeInfo, WakeComputeError> {
let mut config = compute::ConnCfg::new();
config
.host(self.endpoint.host_str().unwrap_or("localhost"))
.port(self.endpoint.port().unwrap_or(5432))
.ssl_mode(SslMode::Disable);
let host = self.endpoint.host_str().unwrap_or("localhost").into();
let port = self.endpoint.port().unwrap_or(5432);
let node = NodeInfo {
config,
let info = NodeInfo {
address: PgEndpoint { host, port },
aux: Default::default(),
allow_self_signed_compute: false,
};
Ok(node)
Ok(info)
}
}
@@ -107,8 +104,9 @@ impl super::Api for Api {
&self,
_extra: &ConsoleReqExtra<'_>,
creds: &ClientCredentials<'_>,
) -> Result<Option<AuthInfo>, GetAuthInfoError> {
self.do_get_auth_info(creds).await
) -> Result<Option<CachedAuthInfo>, GetAuthInfoError> {
let res = self.do_get_auth_info(creds).await?;
Ok(res.map(CachedAuthInfo::new_uncached))
}
#[tracing::instrument(skip_all)]
@@ -117,9 +115,8 @@ impl super::Api for Api {
_extra: &ConsoleReqExtra<'_>,
_creds: &ClientCredentials<'_>,
) -> Result<CachedNodeInfo, WakeComputeError> {
self.do_wake_compute()
.map_ok(CachedNodeInfo::new_uncached)
.await
let res = self.do_wake_compute().await?;
Ok(CachedNodeInfo::new_uncached(res))
}
}
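`Cached::new_uncached` appears in both mock methods above: backends that never touch the caches (tests, link auth) still have to return the same handle type as the console backend. A plausible shape for that handle, with the eviction hook hidden behind a closure (an assumption; the real type presumably points back into the cache instead):

struct Cached<V> {
    value: V,
    /// Present only when the value is backed by a cache entry.
    evict: Option<Box<dyn Fn() + Send + Sync>>,
}

impl<V> Cached<V> {
    /// Wrap a value that bypassed the cache entirely.
    fn new_uncached(value: V) -> Self {
        Self { value, evict: None }
    }

    /// Evict the backing cache entry, if any. The returned flag says
    /// whether the value was cached, which is handy for metric labels.
    fn invalidate(&self) -> bool {
        match &self.evict {
            Some(evict) => {
                evict();
                true
            }
            None => false,
        }
    }
}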

View File

@@ -3,18 +3,19 @@
use super::{
super::messages::{ConsoleError, GetRoleSecret, WakeCompute},
errors::{ApiError, GetAuthInfoError, WakeComputeError},
ApiCaches, AuthInfo, CachedNodeInfo, ConsoleReqExtra, NodeInfo,
ApiCaches, ConsoleReqExtra,
};
use crate::{auth::ClientCredentials, compute, http, scram};
use super::{AuthInfo, AuthInfoCacheKey, CachedAuthInfo};
use super::{CachedNodeInfo, NodeInfo};
use crate::{auth::ClientCredentials, http, scram};
use async_trait::async_trait;
use futures::TryFutureExt;
use tokio_postgres::config::SslMode;
use tracing::{error, info, info_span, warn, Instrument};
#[derive(Clone)]
pub struct Api {
endpoint: http::Endpoint,
caches: &'static ApiCaches,
pub endpoint: http::Endpoint,
pub caches: &'static ApiCaches,
}
impl Api {
@@ -91,22 +92,9 @@ impl Api {
let response = self.endpoint.execute(request).await?;
let body = parse_body::<WakeCompute>(response).await?;
// Unfortunately, ownership won't let us use `Option::ok_or` here.
let (host, port) = match parse_host_port(&body.address) {
None => return Err(WakeComputeError::BadComputeAddress(body.address)),
Some(x) => x,
};
// Don't set anything but host and port! This config will be cached.
// We'll set username and such later using the startup message.
// TODO: add more type safety (in progress).
let mut config = compute::ConnCfg::new();
config.host(host).port(port).ssl_mode(SslMode::Disable); // TLS is not configured on compute nodes.
let node = NodeInfo {
config,
aux: body.aux.into(),
allow_self_signed_compute: false,
address: body.address,
aux: body.aux,
};
Ok(node)
@@ -124,8 +112,28 @@ impl super::Api for Api {
&self,
extra: &ConsoleReqExtra<'_>,
creds: &ClientCredentials<'_>,
) -> Result<Option<AuthInfo>, GetAuthInfoError> {
self.do_get_auth_info(extra, creds).await
) -> Result<Option<CachedAuthInfo>, GetAuthInfoError> {
let key = AuthInfoCacheKey {
project: creds.project().expect("impossible").into(),
role: creds.user.into(),
};
// Check if we already have cached auth info for this project + user combo.
// Beware! We shouldn't evict entries on unsuccessful auth attempts; otherwise
// the cache would be useless in the presence of misbehaving clients.
// Instead, we listen to an invalidation queue to keep the cache up to date.
if let Some(cached) = self.caches.auth_info.get(&key) {
info!(key = ?key, "found cached auth info");
return Ok(Some(cached));
}
let info = self.do_get_auth_info(extra, creds).await?;
Ok(info.map(|info| {
info!(key = ?key, "creating a cache entry for auth info");
let (_, cached) = self.caches.auth_info.insert(key.into(), info.into());
cached
}))
}
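The eviction half of this design is not in the hunk: per the comment above, stale secrets are removed by an invalidation listener rather than by failed logins. A sketch of such a loop, assuming a tokio channel delivers project/role pairs from the console (the message and function names here are made up):

use tokio::sync::mpsc;

struct InvalidateAuthInfo {
    project: Box<str>,
    role: Box<str>,
}

/// Evict auth info entries named by console notifications.
async fn auth_info_invalidator(
    mut rx: mpsc::Receiver<InvalidateAuthInfo>,
    evict: impl Fn(&InvalidateAuthInfo), // stand-in for &'static ApiCaches
) {
    while let Some(msg) = rx.recv().await {
        // Drop the entry so the next auth attempt refetches the secret.
        evict(&msg);
    }
}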
#[tracing::instrument(skip_all)]
@@ -134,20 +142,21 @@ impl super::Api for Api {
extra: &ConsoleReqExtra<'_>,
creds: &ClientCredentials<'_>,
) -> Result<CachedNodeInfo, WakeComputeError> {
let key = creds.project().expect("impossible");
let key: Box<str> = creds.project().expect("impossible").into();
// Every time we issue a wake-up HTTP request, the compute node stays up
// for some time (how long depends on the console's scale-to-zero policy);
// the connection info remains the same during that window, which means
// we can cache it to reduce both load and latency.
if let Some(cached) = self.caches.node_info.get(key) {
info!(key = key, "found cached compute node info");
if let Some(cached) = self.caches.node_info.get(&key) {
info!(key = ?key, "found cached compute node info");
return Ok(cached);
}
let node = self.do_wake_compute(extra, creds).await?;
let (_, cached) = self.caches.node_info.insert(key.into(), node);
info!(key = key, "created a cache entry for compute node info");
let info = self.do_wake_compute(extra, creds).await?;
info!(key = ?key, "creating a cache entry for compute node info");
let (_, cached) = self.caches.node_info.insert(key.into(), info.into());
Ok(cached)
}
@@ -177,20 +186,3 @@ async fn parse_body<T: for<'a> serde::Deserialize<'a>>(
error!("console responded with an error ({status}): {text}");
Err(ApiError::Console { status, text })
}
fn parse_host_port(input: &str) -> Option<(&str, u16)> {
let (host, port) = input.split_once(':')?;
Some((host, port.parse().ok()?))
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_parse_host_port() {
let (host, port) = parse_host_port("127.0.0.1:5432").expect("failed to parse");
assert_eq!(host, "127.0.0.1");
assert_eq!(port, 5432);
}
}
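`parse_host_port` and its test can go because the raw `host:port` string no longer reaches this module: the mock console in the tests below returns `"address": "127.0.0.1:<port>"` and it lands directly in `address: body.address`, which suggests `PgEndpoint` deserializes itself from that format. A sketch of what such an impl might look like (field types are guesses):

use serde::{de, Deserialize, Deserializer};

#[derive(Debug, Clone)]
pub struct PgEndpoint {
    pub host: String, // the exact string type isn't shown in the diff
    pub port: u16,
}

impl<'de> Deserialize<'de> for PgEndpoint {
    fn deserialize<D: Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
        let raw = String::deserialize(d)?;
        // rsplit_once splits on the last ':', so IPv6-style hosts survive.
        let (host, port) = raw
            .rsplit_once(':')
            .ok_or_else(|| de::Error::custom("expected host:port"))?;
        Ok(PgEndpoint {
            host: host.to_owned(),
            port: port.parse().map_err(de::Error::custom)?,
        })
    }
}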

View File

@@ -175,14 +175,9 @@ pub async fn handle(
application_name: Some(APP_NAME),
};
let node = creds.wake_compute(&extra).await?.expect("wake_compute returned no node info");
let conf = node.value.config;
let port = *conf.get_ports().first().expect("no port");
let host = match conf.get_hosts().first().expect("no host") {
tokio_postgres::config::Host::Tcp(host) => host,
tokio_postgres::config::Host::Unix(_) => {
return Err(anyhow::anyhow!("unix socket is not supported"));
}
};
let pg_endpoint = &node.value.address;
let port = pg_endpoint.port;
let host = &pg_endpoint.host;
let request_content_length = match request.body().size_hint().upper() {
Some(v) => v,

View File

@@ -4,7 +4,7 @@ mod tests;
use crate::{
auth::{self, backend::AuthSuccess},
cancellation::{self, CancelMap},
compute::{self, PostgresConnection},
compute::{self, ComputeNode, PostgresConnection},
config::{ProxyConfig, TlsConfig},
console::{self, messages::MetricsAuxInfo},
error::io_error,
@@ -155,7 +155,7 @@ pub async fn handle_ws_client(
async { result }.or_else(|e| stream.throw_error(e)).await?
};
let client = Client::new(stream, creds, &params, session_id, false);
let client = Client::new(stream, creds, &params, session_id);
cancel_map
.with_session(|session| client.connect_to_db(session, true))
.await
@@ -194,15 +194,7 @@ async fn handle_client(
async { result }.or_else(|e| stream.throw_error(e)).await?
};
let allow_self_signed_compute = config.allow_self_signed_compute;
let client = Client::new(
stream,
creds,
&params,
session_id,
allow_self_signed_compute,
);
let client = Client::new(stream, creds, &params, session_id);
cancel_map
.with_session(|session| client.connect_to_db(session, false))
.await
@@ -283,61 +275,38 @@ async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
}
}
/// Try to connect to the compute node once.
#[tracing::instrument(name = "connect_once", skip_all)]
async fn connect_to_compute_once(
node_info: &console::CachedNodeInfo,
) -> Result<PostgresConnection, compute::ConnectionError> {
// If we couldn't connect, cached connection info might be to blame
// (e.g. the compute node's address might've changed at the wrong time).
// Invalidate the cache entry (if any) to prevent repeated failures.
let invalidate_cache = |_: &compute::ConnectionError| {
let is_cached = node_info.cached();
if is_cached {
warn!("invalidating stalled compute node info cache entry");
node_info.invalidate();
}
let label = match is_cached {
true => "compute_cached",
false => "compute_uncached",
};
NUM_CONNECTION_FAILURES.with_label_values(&[label]).inc();
};
let allow_self_signed_compute = node_info.allow_self_signed_compute;
node_info
.config
.connect(allow_self_signed_compute)
.inspect_err(invalidate_cache)
.await
}
/// Try to connect to the compute node, retrying if necessary.
/// This function might update `node_info`, so we take it by `&mut`.
#[tracing::instrument(skip_all)]
async fn connect_to_compute(
node_info: &mut console::CachedNodeInfo,
node_info: &mut ComputeNode,
params: &StartupMessageParams,
extra: &console::ConsoleReqExtra<'_>,
creds: &auth::BackendType<'_, auth::ClientCredentials<'_>>,
) -> Result<PostgresConnection, compute::ConnectionError> {
let mut num_retries: usize = NUM_RETRIES_WAKE_COMPUTE;
let mut num_retries = NUM_RETRIES_WAKE_COMPUTE;
loop {
// Apply startup params to the (possibly cached) compute node info.
node_info.config.set_startup_params(params);
match connect_to_compute_once(node_info).await {
let mut config = node_info.to_conn_config();
config.set_startup_params(params);
match config.connect().await {
Err(e) if num_retries > 0 => {
info!("compute node's state has changed; requesting a wake-up");
match creds.wake_compute(extra).map_err(io_error).await? {
let label = match node_info.invalidate() {
true => "compute_cached",
false => "compute_uncached",
};
NUM_CONNECTION_FAILURES.with_label_values(&[label]).inc();
let res = creds.wake_compute(extra).map_err(io_error).await?;
match (res, &node_info) {
// Update `node_info` and try one more time.
Some(mut new) => {
new.config.reuse_password(&node_info.config);
*node_info = new;
(Some(new), ComputeNode::Static { password, .. }) => {
*node_info = ComputeNode::Static {
password: password.to_owned(),
info: new,
}
}
// Link auth doesn't work that way, so we just exit.
None => return Err(e),
_ => return Err(e),
}
}
other => return other,
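Stripped of the proxy types, the loop above reads: try to connect; on failure (with retries left) invalidate whatever was cached, count the failure, ask the console for fresh connection info, and go again. A condensed sketch with connect and wake left abstract:

struct Conn; // stand-in for compute::PostgresConnection

fn connect_with_retry<T, E>(
    info: &mut T,
    mut connect: impl FnMut(&T) -> Result<Conn, E>,
    mut wake: impl FnMut() -> Result<Option<T>, E>,
    mut retries: usize,
) -> Result<Conn, E> {
    loop {
        match connect(info) {
            Err(e) if retries > 0 => {
                retries -= 1;
                // A stale cached address is the usual suspect, so refresh
                // the connection info before the next attempt.
                match wake()? {
                    Some(fresh) => *info = fresh,
                    // Link auth has no wake-up path; surface the error.
                    None => return Err(e),
                }
            }
            other => return other,
        }
    }
}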
@@ -430,8 +399,6 @@ struct Client<'a, S> {
params: &'a StartupMessageParams,
/// Unique connection ID.
session_id: uuid::Uuid,
/// Allow self-signed certificates (for testing).
allow_self_signed_compute: bool,
}
impl<'a, S> Client<'a, S> {
@@ -441,14 +408,12 @@ impl<'a, S> Client<'a, S> {
creds: auth::BackendType<'a, auth::ClientCredentials<'a>>,
params: &'a StartupMessageParams,
session_id: uuid::Uuid,
allow_self_signed_compute: bool,
) -> Self {
Self {
stream,
creds,
params,
session_id,
allow_self_signed_compute,
}
}
}
@@ -468,7 +433,6 @@ impl<S: AsyncRead + AsyncWrite + Unpin> Client<'_, S> {
mut creds,
params,
session_id,
allow_self_signed_compute,
} = self;
let extra = console::ConsoleReqExtra {
@@ -491,19 +455,19 @@ impl<S: AsyncRead + AsyncWrite + Unpin> Client<'_, S> {
value: mut node_info,
} = auth_result;
node_info.allow_self_signed_compute = allow_self_signed_compute;
let mut node = connect_to_compute(&mut node_info, params, &extra, &creds)
.or_else(|e| stream.throw_error(e))
.await?;
prepare_client_connection(&node, reported_auth_ok, session, &mut stream).await?;
// Before proxy passing, forward to compute whatever data is left in the
// PqStream input buffer. Normally there is none, but our serverless npm
// driver in pipeline mode sends startup, password and first query
// immediately after opening the connection.
let (stream, read_buf) = stream.into_inner();
node.stream.write_all(&read_buf).await?;
proxy_pass(stream, node.stream, &node_info.aux).await
proxy_pass(stream, node.stream, node_info.metrics_aux_info()).await
}
}

View File

@@ -1845,13 +1845,15 @@ class VanillaPostgres(PgProtocol):
]
)
def configure(self, options: List[str]):
def configure(self, options: List[str]) -> "VanillaPostgres":
"""Append lines into postgresql.conf file."""
assert not self.running
with open(os.path.join(self.pgdatadir, "postgresql.conf"), "a") as conf_file:
conf_file.write("\n".join(options))
def start(self, log_path: Optional[str] = None):
return self
def start(self, log_path: Optional[str] = None) -> "VanillaPostgres":
assert not self.running
self.running = True
@@ -1862,11 +1864,15 @@ class VanillaPostgres(PgProtocol):
["pg_ctl", "-w", "-D", str(self.pgdatadir), "-l", log_path, "start"]
)
def stop(self):
return self
def stop(self) -> "VanillaPostgres":
assert self.running
self.running = False
self.pg_bin.run_capture(["pg_ctl", "-w", "-D", str(self.pgdatadir), "stop"])
return self
def get_subdir_size(self, subdir) -> int:
"""Return size of pgdatadir subdirectory in bytes."""
return get_dir_size(os.path.join(self.pgdatadir, subdir))
@@ -2035,6 +2041,17 @@ class NeonProxy(PgProtocol):
*["--auth-endpoint", self.pg_conn_url],
]
@dataclass(frozen=True)
class Console(AuthBackend):
console_url: str
def extra_args(self) -> list[str]:
return [
# Postgres auth backend params
*["--auth-backend", "console"],
*["--auth-endpoint", self.console_url],
]
def __init__(
self,
neon_binpath: Path,
@@ -2240,6 +2257,33 @@ def link_proxy(
yield proxy
@pytest.fixture(scope="function")
def console_proxy(
port_distributor: PortDistributor, neon_binpath: Path, test_output_dir: Path
) -> Iterator[NeonProxy]:
"""Neon proxy that routes through link auth."""
http_port = port_distributor.get_port()
proxy_port = port_distributor.get_port()
mgmt_port = port_distributor.get_port()
console_port = port_distributor.get_port()
external_http_port = port_distributor.get_port()
console_url = f"http://127.0.0.1:{console_port}"
with NeonProxy(
neon_binpath=neon_binpath,
test_output_dir=test_output_dir,
proxy_port=proxy_port,
http_port=http_port,
mgmt_port=mgmt_port,
external_http_port=external_http_port,
auth_backend=NeonProxy.Console(console_url),
) as proxy:
proxy.start()
yield proxy
@pytest.fixture(scope="function")
def static_proxy(
vanilla_pg: VanillaPostgres,

View File

@@ -1,11 +1,13 @@
import json
import logging
import subprocess
from typing import Any, List
from typing import Any, List, cast
import psycopg2
import pytest
import requests
from fixtures.neon_fixtures import PSQL, NeonProxy, VanillaPostgres
from aiohttp import web
from fixtures.neon_fixtures import PSQL, NeonProxy, PortDistributor, VanillaPostgres
def test_proxy_select_1(static_proxy: NeonProxy):
@@ -225,3 +227,78 @@ def test_sql_over_http(static_proxy: NeonProxy):
res = q("drop table t")
assert res["command"] == "DROP"
assert res["rowCount"] is None
@pytest.mark.asyncio
async def test_compute_cache_invalidation(
port_distributor: PortDistributor, vanilla_pg: VanillaPostgres, console_proxy: NeonProxy
):
console_url = cast(NeonProxy.Console, console_proxy.auth_backend).console_url
logging.info(f"mocked console's url is {console_url}")
console_port = int(console_url.split(":")[-1])
logging.info(f"mocked console's port is {console_port}")
routes = web.RouteTableDef()
@routes.get("/proxy_get_role_secret")
async def get_role_secret(request):
# corresponds to password "password"
secret = ":".join(
[
"SCRAM-SHA-256$4096",
"t33UQcz/cs1D+n9INqThsw==$1NYlCbuxtK7YF2sgECBDTv1Myf8PpHJCT3RgKSXlZL0=",
"9iLeGY91MqBQ4ez1389Smo7h+STsJJ5jvu7kNofxj08=",
]
)
return web.json_response({"role_secret": secret})
wake_compute_called = 0
postgres_port = vanilla_pg.default_options["port"]
@routes.get("/proxy_wake_compute")
async def wake_compute(request):
nonlocal wake_compute_called
wake_compute_called += 1
nonlocal postgres_port
logging.info(f"compute's port is {postgres_port}")
return web.json_response(
{
"address": f"127.0.0.1:{postgres_port}",
"aux": {
"endpoint_id": "",
"project_id": "",
"branch_id": "",
},
}
)
console = web.Application()
console.add_routes(routes)
runner = web.AppRunner(console)
await runner.setup()
await web.TCPSite(runner, "127.0.0.1", console_port).start()
# Create a user we're going to use in the test sequence
user, password = "borat", "password"
vanilla_pg.start().safe_psql(f"create role {user} with login password '{password}'")
async def try_connect():
await console_proxy.connect_async(user=user, password=password, dbname="postgres")
assert wake_compute_called == 0
# Try connecting to compute
await try_connect()
assert wake_compute_called == 1
# Change compute's port
postgres_port = port_distributor.get_port()
vanilla_pg.stop().configure([f"port = {postgres_port}"]).start()
# Try connecting to compute
await try_connect()
assert wake_compute_called == 2

View File

@@ -42,7 +42,8 @@ regex = { version = "1" }
regex-syntax = { version = "0.6" }
reqwest = { version = "0.11", default-features = false, features = ["blocking", "json", "multipart", "rustls-tls"] }
ring = { version = "0.16", features = ["std"] }
rustls = { version = "0.20", features = ["dangerous_configuration"] }
rustls-56bd22fc3884b12 = { package = "rustls", version = "0.20", features = ["dangerous_configuration"] }
rustls-647d43efb71741da = { package = "rustls", version = "0.21" }
scopeguard = { version = "1" }
serde = { version = "1", features = ["alloc", "derive"] }
serde_json = { version = "1", features = ["raw_value"] }
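The pair of hash-suffixed `rustls` entries is the standard Cargo trick for keeping two major versions of one crate in a single dependency table: each gets a unique dependency name while `package = "rustls"` points both at the real crate (the hashed names look like cargo-hakari output, though that's an inference from the naming style). The same renaming works in any manifest; a small sketch with names of our own choosing:

[dependencies]
# Two majors of the same crate can coexist under distinct names.
tls-old = { package = "rustls", version = "0.20", features = ["dangerous_configuration"] }
tls-new = { package = "rustls", version = "0.21" }

Rust code would then refer to them as `tls_old` and `tls_new`.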