mirror of https://github.com/neondatabase/neon.git
synced 2026-01-18 02:42:56 +00:00

Compare commits: 14 commits (branch statement_..., author funbringer)
| Author | SHA1 | Date |
|---|---|---|
|  | 8f49af45a7 |  |
|  | c7872d3f6d |  |
|  | c82b00110e |  |
|  | d425f2cb14 |  |
|  | 2b0066f67c |  |
|  | 01be823084 |  |
|  | 69bb1eee18 |  |
|  | 0c040aca63 |  |
|  | d6c7b4d994 |  |
|  | 251e410add |  |
|  | ce416c8160 |  |
|  | f7e9ec49be |  |
|  | 2b60ad0285 |  |
|  | 9b99d4caa9 |  |
Cargo.lock (generated): 66 lines changed
@@ -905,6 +905,20 @@ version = "1.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "acbf1af155f9b9ef647e42cdc158db4b64a1b61f743629225fde6f3e0be2a7c7"
|
||||
|
||||
[[package]]
|
||||
name = "combine"
|
||||
version = "4.6.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "35ed6e9d84f0b51a7f52daf1c7d71dd136fd7a3f41a8462b8cdb8c78d920fad4"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"futures-core",
|
||||
"memchr",
|
||||
"pin-project-lite",
|
||||
"tokio",
|
||||
"tokio-util",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "comfy-table"
|
||||
version = "6.1.4"
|
||||
@@ -3104,6 +3118,8 @@ dependencies = [
|
||||
"prometheus",
|
||||
"rand",
|
||||
"rcgen",
|
||||
"redis",
|
||||
"ref-cast",
|
||||
"regex",
|
||||
"reqwest",
|
||||
"reqwest-middleware",
|
||||
@@ -3210,6 +3226,29 @@ dependencies = [
|
||||
"yasna",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "redis"
|
||||
version = "0.23.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3ea8c51b5dc1d8e5fd3350ec8167f464ec0995e79f2e90a075b63371500d557f"
|
||||
dependencies = [
|
||||
"async-trait",
|
||||
"bytes",
|
||||
"combine",
|
||||
"futures-util",
|
||||
"itoa",
|
||||
"percent-encoding",
|
||||
"pin-project-lite",
|
||||
"rustls 0.21.0",
|
||||
"rustls-native-certs",
|
||||
"ryu",
|
||||
"sha1_smol",
|
||||
"tokio",
|
||||
"tokio-rustls 0.24.0",
|
||||
"tokio-util",
|
||||
"url",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "redox_syscall"
|
||||
version = "0.2.16"
|
||||
@@ -3228,6 +3267,26 @@ dependencies = [
|
||||
"bitflags",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ref-cast"
|
||||
version = "1.0.16"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f43faa91b1c8b36841ee70e97188a869d37ae21759da6846d4be66de5bf7b12c"
|
||||
dependencies = [
|
||||
"ref-cast-impl",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ref-cast-impl"
|
||||
version = "1.0.16"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8d2275aab483050ab2a7364c1a46604865ee7d6906684e08db0f090acf74f9e7"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.15",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "regex"
|
||||
version = "1.7.3"
|
||||
@@ -3844,6 +3903,12 @@ dependencies = [
|
||||
"digest",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "sha1_smol"
|
||||
version = "1.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ae1a47186c03a32177042e55dbc5fd5aee900b8e0069a8d70fba96a9375cd012"
|
||||
|
||||
[[package]]
|
||||
name = "sha2"
|
||||
version = "0.10.6"
|
||||
@@ -5326,6 +5391,7 @@ dependencies = [
|
||||
"reqwest",
|
||||
"ring",
|
||||
"rustls 0.20.8",
|
||||
"rustls 0.21.0",
|
||||
"scopeguard",
|
||||
"serde",
|
||||
"serde_json",
|
||||
|
||||
@@ -77,10 +77,12 @@ pin-project-lite = "0.2"
|
||||
prometheus = {version = "0.13", default_features=false, features = ["process"]} # removes protobuf dependency
|
||||
prost = "0.11"
|
||||
rand = "0.8"
|
||||
ref-cast = "1.0"
|
||||
redis = { version = "0.23.0", features = ["tokio-rustls-comp"] }
|
||||
regex = "1.4"
|
||||
reqwest = { version = "0.11", default-features = false, features = ["rustls-tls"] }
|
||||
reqwest-tracing = { version = "0.4.0", features = ["opentelemetry_0_18"] }
|
||||
reqwest-middleware = "0.2.0"
|
||||
reqwest-tracing = { version = "0.4.0", features = ["opentelemetry_0_18"] }
|
||||
routerify = "3"
|
||||
rpds = "0.13"
|
||||
rustls = "0.20"
|
||||
|
||||
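The new `redis` dependency is pulled in with the `tokio-rustls-comp` feature, i.e. async connections with rustls-based TLS. A minimal sketch of what that enables; the URL and helper name here are illustrative, not taken from this PR:

```rust
// Hypothetical helper: connect over TLS ("rediss://") and issue a PING.
async fn ping(url: &str) -> redis::RedisResult<()> {
    let client = redis::Client::open(url)?; // e.g. "rediss://user:pass@host:6380"
    let mut conn = client.get_async_connection().await?;
    let pong: String = redis::cmd("PING").query_async(&mut conn).await?;
    assert_eq!(pong, "PONG");
    Ok(())
}
```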
proxy/Cargo.toml

@@ -35,7 +35,9 @@ postgres_backend.workspace = true
 pq_proto.workspace = true
 prometheus.workspace = true
 rand.workspace = true
+ref-cast.workspace = true
 regex.workspace = true
+redis.workspace = true
 reqwest = { workspace = true, features = ["json"] }
 reqwest-middleware.workspace = true
 reqwest-tracing.workspace = true

@@ -50,9 +52,9 @@ socket2.workspace = true
 sync_wrapper.workspace = true
 thiserror.workspace = true
 tls-listener.workspace = true
-tokio = { workspace = true, features = ["signal"] }
 tokio-postgres.workspace = true
 tokio-rustls.workspace = true
+tokio = { workspace = true, features = ["signal"] }
 tracing-opentelemetry.workspace = true
 tracing-subscriber.workspace = true
 tracing-utils.workspace = true
proxy/src/auth/backend.rs

@@ -6,6 +6,7 @@ pub use link::LinkAuthError;

 use crate::{
     auth::{self, ClientCredentials},
+    compute::ComputeNode,
     console::{
         self,
         provider::{CachedNodeInfo, ConsoleReqExtra},

@@ -114,7 +115,7 @@ async fn auth_quirks(
     creds: &mut ClientCredentials<'_>,
     client: &mut stream::PqStream<impl AsyncRead + AsyncWrite + Unpin>,
     allow_cleartext: bool,
-) -> auth::Result<AuthSuccess<CachedNodeInfo>> {
+) -> auth::Result<AuthSuccess<ComputeNode>> {
     // If there's no project so far, that entails that client doesn't
     // support SNI or other means of passing the endpoint (project) name.
     // We now expect to see a very specific payload in the place of password.

@@ -156,7 +157,7 @@ impl BackendType<'_, ClientCredentials<'_>> {
         extra: &ConsoleReqExtra<'_>,
         client: &mut stream::PqStream<impl AsyncRead + AsyncWrite + Unpin>,
         allow_cleartext: bool,
-    ) -> auth::Result<AuthSuccess<CachedNodeInfo>> {
+    ) -> auth::Result<AuthSuccess<ComputeNode>> {
         use BackendType::*;

         let res = match self {

@@ -184,9 +185,7 @@ impl BackendType<'_, ClientCredentials<'_>> {
             Link(url) => {
                 info!("performing link authentication");

-                link::authenticate(url, client)
-                    .await?
-                    .map(CachedNodeInfo::new_uncached)
+                link::authenticate(url, client).await?
             }
         };

proxy/src/auth/backend/classic.rs

@@ -1,8 +1,8 @@
 use super::AuthSuccess;
 use crate::{
     auth::{self, AuthFlow, ClientCredentials},
-    compute,
-    console::{self, AuthInfo, CachedNodeInfo, ConsoleReqExtra},
+    compute::{self, ComputeNode, Password},
+    console::{self, AuthInfo, CachedAuthInfo, ConsoleReqExtra},
     sasl, scram,
     stream::PqStream,
 };

@@ -14,25 +14,26 @@ pub(super) async fn authenticate(
     extra: &ConsoleReqExtra<'_>,
     creds: &ClientCredentials<'_>,
     client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
-) -> auth::Result<AuthSuccess<CachedNodeInfo>> {
+) -> auth::Result<AuthSuccess<ComputeNode>> {
     info!("fetching user's authentication info");
     let info = api.get_auth_info(extra, creds).await?.unwrap_or_else(|| {
         // If we don't have an authentication secret, we mock one to
         // prevent malicious probing (possible due to missing protocol steps).
         // This mocked secret will never lead to successful authentication.
         info!("authentication info not found, mocking it");
-        AuthInfo::Scram(scram::ServerSecret::mock(creds.user, rand::random()))
+        let info = scram::ServerSecret::mock(creds.user, rand::random());
+        CachedAuthInfo::new_uncached(AuthInfo::Scram(info))
     });

     let flow = AuthFlow::new(client);
-    let scram_keys = match info {
+    let keys = match &*info {
         AuthInfo::Md5(_) => {
             info!("auth endpoint chooses MD5");
             return Err(auth::AuthError::bad_auth_method("MD5"));
         }
         AuthInfo::Scram(secret) => {
             info!("auth endpoint chooses SCRAM");
-            let scram = auth::Scram(&secret);
+            let scram = auth::Scram(secret);
             let client_key = match flow.begin(scram).await?.authenticate().await? {
                 sasl::Outcome::Success(key) => key,
                 sasl::Outcome::Failure(reason) => {

@@ -41,21 +42,20 @@ pub(super) async fn authenticate(
                 }
             };

-            Some(compute::ScramKeys {
+            compute::ScramKeys {
                 client_key: client_key.as_bytes(),
                 server_key: secret.server_key.as_bytes(),
-            })
+            }
         }
     };

-    let mut node = api.wake_compute(extra, creds).await?;
-    if let Some(keys) = scram_keys {
-        use tokio_postgres::config::AuthKeys;
-        node.config.auth_keys(AuthKeys::ScramSha256(keys));
-    }
+    let info = api.wake_compute(extra, creds).await?;

     Ok(AuthSuccess {
         reported_auth_ok: false,
-        value: node,
+        value: ComputeNode::Static {
+            password: Password::ScramKeys(keys),
+            info,
+        },
     })
 }
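The mocked SCRAM secret above deserves a note: when the role doesn't exist, the proxy still runs a full SCRAM exchange against a freshly generated secret, so a probing client can't distinguish a missing role from a wrong password. A standalone sketch of the idea; the field layout is illustrative, not the real `scram::ServerSecret`:

```rust
/// Illustrative stand-in for `scram::ServerSecret`; the real type differs.
struct MockServerSecret {
    salt: [u8; 16],
    iterations: u32,
}

fn mock(_user: &str, nonce: [u8; 16]) -> MockServerSecret {
    MockServerSecret {
        salt: nonce,      // random salt: the handshake proceeds normally
        iterations: 4096, // a typical SCRAM-SHA-256 iteration count
    }
    // No client can produce a proof matching a key that was never derived
    // from a real password, so authentication against this secret always fails.
}
```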
proxy/src/auth/backend/hacks.rs

@@ -1,10 +1,8 @@
 use super::AuthSuccess;
 use crate::{
     auth::{self, AuthFlow, ClientCredentials},
-    console::{
-        self,
-        provider::{CachedNodeInfo, ConsoleReqExtra},
-    },
+    compute::{ComputeNode, Password},
+    console::{self, provider::ConsoleReqExtra},
     stream,
 };
 use tokio::io::{AsyncRead, AsyncWrite};

@@ -19,7 +17,7 @@ pub async fn cleartext_hack(
     extra: &ConsoleReqExtra<'_>,
     creds: &mut ClientCredentials<'_>,
     client: &mut stream::PqStream<impl AsyncRead + AsyncWrite + Unpin>,
-) -> auth::Result<AuthSuccess<CachedNodeInfo>> {
+) -> auth::Result<AuthSuccess<ComputeNode>> {
     warn!("cleartext auth flow override is enabled, proceeding");
     let password = AuthFlow::new(client)
         .begin(auth::CleartextPassword)

@@ -27,13 +25,15 @@ pub async fn cleartext_hack(
         .authenticate()
         .await?;

-    let mut node = api.wake_compute(extra, creds).await?;
-    node.config.password(password);
+    let info = api.wake_compute(extra, creds).await?;

     // Report tentative success; compute node will check the password anyway.
     Ok(AuthSuccess {
         reported_auth_ok: false,
-        value: node,
+        value: ComputeNode::Static {
+            password: Password::ClearText(password),
+            info,
+        },
     })
 }

@@ -44,7 +44,7 @@ pub async fn password_hack(
     extra: &ConsoleReqExtra<'_>,
     creds: &mut ClientCredentials<'_>,
     client: &mut stream::PqStream<impl AsyncRead + AsyncWrite + Unpin>,
-) -> auth::Result<AuthSuccess<CachedNodeInfo>> {
+) -> auth::Result<AuthSuccess<ComputeNode>> {
     warn!("project not specified, resorting to the password hack auth flow");
     let payload = AuthFlow::new(client)
         .begin(auth::PasswordHack)

@@ -55,12 +55,14 @@ pub async fn password_hack(
     info!(project = &payload.endpoint, "received missing parameter");
     creds.project = Some(payload.endpoint);

-    let mut node = api.wake_compute(extra, creds).await?;
-    node.config.password(payload.password);
+    let info = api.wake_compute(extra, creds).await?;

     // Report tentative success; compute node will check the password anyway.
     Ok(AuthSuccess {
         reported_auth_ok: false,
-        value: node,
+        value: ComputeNode::Static {
+            password: Password::ClearText(payload.password),
+            info,
+        },
     })
 }
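Both hacks funnel extra information through the password field. For `password_hack`, the client packs the missing endpoint name into the password payload; a sketch of splitting such a payload follows. The exact wire format shown, `project=<name>;<password>`, is an assumption for illustration, not taken from this diff:

```rust
struct Payload<'a> {
    endpoint: &'a str,
    password: &'a [u8],
}

// Hypothetical parser for a `project=<name>;<password>` payload.
fn parse_password_hack(bytes: &[u8]) -> Option<Payload<'_>> {
    let s = std::str::from_utf8(bytes).ok()?;
    let rest = s.strip_prefix("project=")?;
    let (endpoint, password) = rest.split_once(';')?;
    Some(Payload {
        endpoint,
        password: password.as_bytes(),
    })
}
```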
proxy/src/auth/backend/link.rs

@@ -1,15 +1,10 @@
 use super::AuthSuccess;
 use crate::{
-    auth, compute,
-    console::{self, provider::NodeInfo},
-    error::UserFacingError,
-    stream::PqStream,
-    waiters,
+    auth, compute::ComputeNode, console, error::UserFacingError, stream::PqStream, waiters,
 };
 use pq_proto::BeMessage as Be;
 use thiserror::Error;
 use tokio::io::{AsyncRead, AsyncWrite};
-use tokio_postgres::config::SslMode;
 use tracing::{info, info_span};

 #[derive(Debug, Error)]

@@ -57,12 +52,12 @@ pub fn new_psql_session_id() -> String {
 pub(super) async fn authenticate(
     link_uri: &reqwest::Url,
     client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
-) -> auth::Result<AuthSuccess<NodeInfo>> {
+) -> auth::Result<AuthSuccess<ComputeNode>> {
     let psql_session_id = new_psql_session_id();
-    let span = info_span!("link", psql_session_id = &psql_session_id);
+    let span = info_span!("link", psql_session_id);
     let greeting = hello_message(link_uri, &psql_session_id);

-    let db_info = console::mgmt::with_waiter(psql_session_id, |waiter| async {
+    let info = console::mgmt::with_waiter(psql_session_id, |waiter| async {
         // Give user a URL to spawn a new database.
         info!(parent: &span, "sending the auth URL to the user");
         client

@@ -79,35 +74,8 @@ pub(super) async fn authenticate(

     client.write_message_noflush(&Be::NoticeResponse("Connecting to database."))?;

-    // This config should be self-contained, because we won't
-    // take username or dbname from client's startup message.
-    let mut config = compute::ConnCfg::new();
-    config
-        .host(&db_info.host)
-        .port(db_info.port)
-        .dbname(&db_info.dbname)
-        .user(&db_info.user);
-
-    // Backwards compatibility. pg_sni_proxy uses "--" in domain names
-    // while direct connections do not. Once we migrate to pg_sni_proxy
-    // everywhere, we can remove this.
-    if db_info.host.contains("--") {
-        // we need TLS connection with SNI info to properly route it
-        config.ssl_mode(SslMode::Require);
-    } else {
-        config.ssl_mode(SslMode::Disable);
-    }
-
-    if let Some(password) = db_info.password {
-        config.password(password.as_ref());
-    }
-
     Ok(AuthSuccess {
         reported_auth_ok: true,
-        value: NodeInfo {
-            config,
-            aux: db_info.aux.into(),
-            allow_self_signed_compute: false, // caller may override
-        },
+        value: ComputeNode::Link(info),
     })
 }
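`console::mgmt::with_waiter` suspends the session until the management API reports that the database is ready. A sketch of such a waiter registry, assuming a oneshot-channel design; names and types are illustrative, not the real `waiters` module:

```rust
use std::collections::HashMap;
use std::sync::Mutex;
use tokio::sync::oneshot;

struct Waiters<T>(Mutex<HashMap<String, oneshot::Sender<T>>>);

impl<T> Waiters<T> {
    /// Register under `id`, then suspend until `notify` delivers a value.
    async fn wait(&self, id: String) -> Option<T> {
        let (tx, rx) = oneshot::channel();
        self.0.lock().unwrap().insert(id, tx);
        rx.await.ok()
    }

    /// Called by the management API handler once the database is ready.
    fn notify(&self, id: &str, value: T) {
        if let Some(tx) = self.0.lock().unwrap().remove(id) {
            let _ = tx.send(value);
        }
    }
}
```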
proxy/src/bin/proxy.rs

@@ -1,16 +1,15 @@
-use proxy::auth;
-use proxy::console;
-use proxy::http;
-use proxy::metrics;
-
-use anyhow::bail;
+use anyhow::{bail, Context};
 use clap::{self, Arg};
-use proxy::config::{self, ProxyConfig};
-use std::{borrow::Cow, net::SocketAddr};
+use futures::future::try_join_all;
+use proxy::{
+    auth,
+    config::{self, MetricCollectionConfig, ProxyConfig, TlsConfig},
+    console, http, metrics,
+};
+use std::{borrow::Cow, net::SocketAddr, sync::atomic::Ordering};
 use tokio::net::TcpListener;
 use tokio_util::sync::CancellationToken;
-use tracing::info;
-use tracing::warn;
+use tracing::{info, warn};
 use utils::{project_git_version, sentry_init::init_sentry};

 project_git_version!(GIT_VERSION);

@@ -25,8 +24,7 @@ async fn main() -> anyhow::Result<()> {
     ::metrics::set_build_info_metric(GIT_VERSION);

     let args = cli().get_matches();
-    let config = build_config(&args)?;
-
+    let config: &ProxyConfig = Box::leak(Box::new(build_config(&args)?));
     info!("Authentication backend: {}", config.auth_backend);

     // Check that we can bind to address before further initialization

@@ -71,20 +69,29 @@ async fn main() -> anyhow::Result<()> {
         tasks.push(tokio::spawn(metrics::task_main(metrics_config)));
     }

-    let tasks = futures::future::try_join_all(tasks.into_iter().map(proxy::flatten_err));
-    let client_tasks =
-        futures::future::try_join_all(client_tasks.into_iter().map(proxy::flatten_err));
+    if let auth::BackendType::Console(api, _) = &config.auth_backend {
+        if let Some(url) = args.get_one::<String>("redis-notifications") {
+            info!("Starting redis notifications listener ({url})");
+            tasks.push(tokio::spawn(console::notifications::task_main(
+                url.to_owned(),
+                api.caches,
+            )));
+        }
+    }
+
+    let tasks = try_join_all(tasks.into_iter().map(proxy::flatten_err));
+    let client_tasks = try_join_all(client_tasks.into_iter().map(proxy::flatten_err));
     tokio::select! {
         // We are only expecting an error from these forever tasks
         res = tasks => { res?; },
         res = client_tasks => { res?; },
     }

     Ok(())
 }

-/// ProxyConfig is created at proxy startup, and lives forever.
-fn build_config(args: &clap::ArgMatches) -> anyhow::Result<&'static ProxyConfig> {
-    let tls_config = match (
+fn build_tls_config(args: &clap::ArgMatches) -> anyhow::Result<Option<TlsConfig>> {
+    let config = match (
         args.get_one::<String>("tls-key"),
         args.get_one::<String>("tls-cert"),
     ) {

@@ -101,16 +108,22 @@ fn build_config(args: &clap::ArgMatches) -> anyhow::Result<&'static ProxyConfig>
         .get_one::<String>("allow-self-signed-compute")
         .unwrap()
         .parse()?;

     if allow_self_signed_compute {
         warn!("allowing self-signed compute certificates");
+        proxy::compute::ALLOW_SELF_SIGNED_COMPUTE.store(true, Ordering::Relaxed);
     }

-    let metric_collection = match (
-        args.get_one::<String>("metric-collection-endpoint"),
-        args.get_one::<String>("metric-collection-interval"),
-    ) {
+    Ok(config)
+}
+
+fn build_metrics_config(args: &clap::ArgMatches) -> anyhow::Result<Option<MetricCollectionConfig>> {
+    let endpoint = args.get_one::<String>("metric-collection-endpoint");
+    let interval = args.get_one::<String>("metric-collection-interval");
+
+    let config = match (endpoint, interval) {
         (Some(endpoint), Some(interval)) => Some(config::MetricCollectionConfig {
-            endpoint: endpoint.parse()?,
+            endpoint: endpoint.parse().context("bad metrics endpoint")?,
             interval: humantime::parse_duration(interval)?,
         }),
         (None, None) => None,

@@ -120,21 +133,40 @@ fn build_config(args: &clap::ArgMatches) -> anyhow::Result<&'static ProxyConfig>
         ),
     };

-    let auth_backend = match args.get_one::<String>("auth-backend").unwrap().as_str() {
+    Ok(config)
+}
+
+fn make_caches(args: &clap::ArgMatches) -> anyhow::Result<console::caches::ApiCaches> {
+    let config::CacheOptions { size, ttl } = args
+        .get_one::<String>("get-auth-info-cache")
+        .unwrap()
+        .parse()?;
+
+    info!("Using AuthInfoCache (get_auth_info) with size={size} ttl={ttl:?}");
+    let auth_info = console::caches::AuthInfoCache::new("auth_info_cache", size, ttl);
+
+    let config::CacheOptions { size, ttl } = args
+        .get_one::<String>("wake-compute-cache")
+        .unwrap()
+        .parse()?;
+
+    info!("Using NodeInfoCache (wake_compute) with size={size} ttl={ttl:?}");
+    let node_info = console::caches::NodeInfoCache::new("node_info_cache", size, ttl);
+
+    let caches = console::caches::ApiCaches {
+        auth_info,
+        node_info,
+    };
+
+    Ok(caches)
+}
+
+fn build_auth_config(args: &clap::ArgMatches) -> anyhow::Result<auth::BackendType<'static, ()>> {
+    let config = match args.get_one::<String>("auth-backend").unwrap().as_str() {
         "console" => {
-            let config::CacheOptions { size, ttl } = args
-                .get_one::<String>("wake-compute-cache")
-                .unwrap()
-                .parse()?;
-
-            info!("Using NodeInfoCache (wake_compute) with size={size} ttl={ttl:?}");
-            let caches = Box::leak(Box::new(console::caches::ApiCaches {
-                node_info: console::caches::NodeInfoCache::new("node_info_cache", size, ttl),
-            }));
-
             let url = args.get_one::<String>("auth-endpoint").unwrap().parse()?;
             let endpoint = http::Endpoint::new(url, http::new_client());

+            let caches = Box::leak(Box::new(make_caches(args)?));
             let api = console::provider::neon::Api::new(endpoint, caches);
             auth::BackendType::Console(Cow::Owned(api), ())
         }

@@ -150,12 +182,16 @@ fn build_config(args: &clap::ArgMatches) -> anyhow::Result<&'static ProxyConfig>
         other => bail!("unsupported auth backend: {other}"),
     };

-    let config = Box::leak(Box::new(ProxyConfig {
-        tls_config,
-        auth_backend,
-        metric_collection,
-        allow_self_signed_compute,
-    }));
     Ok(config)
 }

+/// ProxyConfig is created at proxy startup, and lives forever.
+fn build_config(args: &clap::ArgMatches) -> anyhow::Result<ProxyConfig> {
+    let config = ProxyConfig {
+        tls_config: build_tls_config(args)?,
+        auth_backend: build_auth_config(args)?,
+        metric_collection: build_metrics_config(args)?,
+    };
+
+    Ok(config)
+}

@@ -239,11 +275,22 @@ fn cli() -> clap::Command {
                 .long("metric-collection-interval")
                 .help("how often metrics should be sent to a collection endpoint"),
         )
         .arg(
+            Arg::new("redis-notifications")
+                .long("redis-notifications")
+                .help("for receiving notifications from console (e.g. redis://127.0.0.1:6379)"),
+        )
+        .arg(
+            Arg::new("get-auth-info-cache")
+                .long("get-auth-info-cache")
+                .help("cache for `get_auth_info` api method (use `size=0` to disable)")
+                .default_value(config::CacheOptions::DEFAULT_AUTH_INFO),
+        )
+        .arg(
             Arg::new("wake-compute-cache")
                 .long("wake-compute-cache")
                 .help("cache for `wake_compute` api method (use `size=0` to disable)")
-                .default_value(config::CacheOptions::DEFAULT_OPTIONS_NODE_INFO),
+                .default_value(config::CacheOptions::DEFAULT_NODE_INFO),
         )
         .arg(
             Arg::new("allow-self-signed-compute")
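The `Box::leak` calls above are a deliberate pattern: the config is built once at startup, then leaked to obtain a `&'static` reference that spawned tasks can share without reference counting. In isolation:

```rust
struct Config {
    name: String,
}

// Leak the one-time allocation to get a `&'static Config`; the memory is
// reclaimed by the OS at process exit, which is fine for process-lifetime data.
fn build() -> &'static Config {
    Box::leak(Box::new(Config { name: "proxy".into() }))
}
```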
proxy/src/cache/mod.rs

@@ -1,304 +1,117 @@
-use std::{
-    borrow::Borrow,
-    hash::Hash,
-    ops::{Deref, DerefMut},
-    time::{Duration, Instant},
-};
-use tracing::debug;
+use std::{any::Any, sync::Arc, time::Instant};

-// This seems to make more sense than `lru` or `cached`:
-//
-// * `near/nearcore` ditched `cached` in favor of `lru`
-//   (https://github.com/near/nearcore/issues?q=is%3Aissue+lru+is%3Aclosed).
-//
-// * `lru` methods use an obscure `KeyRef` type in their constraints (which is deliberately excluded from docs).
-//   This severely hinders its usage both in terms of creating wrappers and supported key types.
-//
-// On the other hand, `hashlink` has good download stats and appears to be maintained.
-use hashlink::{linked_hash_map::RawEntryMut, LruCache};
+/// A variant of LRU where every entry has a TTL.
+pub mod timed_lru;
+pub use timed_lru::TimedLru;

-/// A generic trait which exposes types of cache's key and value,
-/// as well as the notion of cache entry invalidation.
-/// This is useful for [`timed_lru::Cached`].
-pub trait Cache {
-    /// Entry's key.
-    type Key;
+/// Useful type aliases.
+pub mod types {
+    pub type Cached<T> = super::Cached<'static, T>;
+}

-    /// Entry's value.
-    type Value;
+/// Lookup information for cache entry invalidation.
+#[derive(Clone)]
+pub struct LookupInfo<K> {
+    /// Cache entry creation time.
+    /// We use this during invalidation lookups to prevent eviction of a newer
+    /// entry sharing the same key (it might've been inserted by a different
+    /// task after we got the entry we're trying to invalidate now).
+    created_at: Instant,

-    /// Used for entry invalidation.
-    type LookupInfo<Key>;
+    /// Search by this key.
+    key: K,
 }

+/// This type encapsulates everything needed for cache entry invalidation.
+/// For convenience, we completely erase the types of a cache ref and a key.
+/// This lets us store multiple tokens in a homogeneous collection (e.g. Vec).
+#[derive(Clone)]
+pub struct InvalidationToken<'a> {
+    // TODO: allow more than one type of references (e.g. Arc) if it's ever needed.
+    cache: &'a (dyn Cache + Sync + Send),
+    info: LookupInfo<Arc<dyn Any + Sync + Send>>,
+}
+
+impl InvalidationToken<'_> {
+    /// Invalidate a corresponding cache entry.
+    pub fn invalidate(&self) {
+        let info = LookupInfo {
+            created_at: self.info.created_at,
+            key: self.info.key.as_ref(),
+        };
+        self.cache.invalidate_entry(info);
+    }
+}
+
+/// A combination of a cache entry and its invalidation token.
+/// Makes it easier to see how those two are connected.
+#[derive(Clone)]
+pub struct Cached<'a, T> {
+    pub token: Option<InvalidationToken<'a>>,
+    pub value: Arc<T>,
+}
+
+impl<T> Cached<'_, T> {
+    /// Place any entry into this wrapper; invalidation will be a no-op.
+    pub fn new_uncached(value: T) -> Self {
+        Self {
+            token: None,
+            value: value.into(),
+        }
+    }
+
+    /// Invalidate a corresponding cache entry.
+    pub fn invalidate(&self) {
+        if let Some(token) = &self.token {
+            token.invalidate();
+        }
+    }
+}
+
+impl<T> std::ops::Deref for Cached<'_, T> {
+    type Target = T;
+
+    fn deref(&self) -> &Self::Target {
+        &self.value
+    }
+}
+
+/// This trait captures the notion of cache entry invalidation.
+/// It doesn't have any associated types because we use dyn-based type erasure.
+trait Cache {
     /// Invalidate an entry using a lookup info.
     /// We don't have an empty default impl because it's error-prone.
-    fn invalidate(&self, _: &Self::LookupInfo<Self::Key>);
+    fn invalidate_entry(&self, info: LookupInfo<&(dyn Any + Send + Sync)>);
 }

-impl<C: Cache> Cache for &C {
-    type Key = C::Key;
-    type Value = C::Value;
-    type LookupInfo<Key> = C::LookupInfo<Key>;
-
-    fn invalidate(&self, info: &Self::LookupInfo<Self::Key>) {
-        C::invalidate(self, info)
-    }
-}
-
-pub use timed_lru::TimedLru;
-pub mod timed_lru {
+#[cfg(test)]
+mod tests {
+    use super::*;

-    /// An implementation of timed LRU cache with fixed capacity.
-    /// Key properties:
-    ///
-    /// * Whenever a new entry is inserted, the least recently accessed one is evicted.
-    ///   The cache also keeps track of entry's insertion time (`created_at`) and TTL (`expires_at`).
-    ///
-    /// * When the entry is about to be retrieved, we check its expiration timestamp.
-    ///   If the entry has expired, we remove it from the cache; Otherwise we bump the
-    ///   expiration timestamp (e.g. +5mins) and change its place in LRU list to prolong
-    ///   its existence.
-    ///
-    /// * There's an API for immediate invalidation (removal) of a cache entry;
-    ///   It's useful in case we know for sure that the entry is no longer correct.
-    ///   See [`timed_lru::LookupInfo`] & [`timed_lru::Cached`] for more information.
-    ///
-    /// * Expired entries are kept in the cache, until they are evicted by the LRU policy,
-    ///   or by a successful lookup (i.e. the entry hasn't expired yet).
-    ///   There is no background job to reap the expired records.
-    ///
-    /// * It's possible for an entry that has not yet expired to be evicted
-    ///   before expired items. That's a bit wasteful, but probably fine in practice.
-    pub struct TimedLru<K, V> {
-        /// Cache's name for tracing.
-        name: &'static str,
-
-        /// The underlying cache implementation.
-        cache: parking_lot::Mutex<LruCache<K, Entry<V>>>,
-
-        /// Default time-to-live of a single entry.
-        ttl: Duration,
+    #[test]
+    fn trivial_properties_of_cached() {
+        let cached = Cached::new_uncached(0);
+        assert_eq!(*cached, 0);
+        cached.invalidate();
     }

-    impl<K: Hash + Eq, V> Cache for TimedLru<K, V> {
-        type Key = K;
-        type Value = V;
-        type LookupInfo<Key> = LookupInfo<Key>;
+    #[test]
+    fn invalidation_token_type_erasure() {
+        let lifetime = std::time::Duration::from_secs(10);
+        let foo = TimedLru::<u32, u32>::new("foo", 128, lifetime);
+        let bar = TimedLru::<String, usize>::new("bar", 128, lifetime);

-        fn invalidate(&self, info: &Self::LookupInfo<K>) {
-            self.invalidate_raw(info)
-        }
-    }
+        let (_, x) = foo.insert(100.into(), 0.into());
+        let (_, y) = bar.insert(String::new().into(), 404.into());

-    struct Entry<T> {
-        created_at: Instant,
-        expires_at: Instant,
-        value: T,
-    }
-
-    impl<K: Hash + Eq, V> TimedLru<K, V> {
-        /// Construct a new LRU cache with timed entries.
-        pub fn new(name: &'static str, capacity: usize, ttl: Duration) -> Self {
-            Self {
-                name,
-                cache: LruCache::new(capacity).into(),
-                ttl,
-            }
-        }
+        // Invalidation tokens should be cloneable and homogeneous (same type).
+        let tokens = [x.token.clone().unwrap(), y.token.clone().unwrap()];
+        for token in tokens {
+            token.invalidate();
+        }

-        /// Drop an entry from the cache if it's outdated.
-        #[tracing::instrument(level = "debug", fields(cache = self.name), skip_all)]
-        fn invalidate_raw(&self, info: &LookupInfo<K>) {
-            let now = Instant::now();
-
-            // Do costly things before taking the lock.
-            let mut cache = self.cache.lock();
-            let raw_entry = match cache.raw_entry_mut().from_key(&info.key) {
-                RawEntryMut::Vacant(_) => return,
-                RawEntryMut::Occupied(x) => x,
-            };
-
-            // Remove the entry if it was created prior to lookup timestamp.
-            let entry = raw_entry.get();
-            let (created_at, expires_at) = (entry.created_at, entry.expires_at);
-            let should_remove = created_at <= info.created_at || expires_at <= now;
-
-            if should_remove {
-                raw_entry.remove();
-            }
-
-            drop(cache); // drop lock before logging
-            debug!(
-                created_at = format_args!("{created_at:?}"),
-                expires_at = format_args!("{expires_at:?}"),
-                entry_removed = should_remove,
-                "processed a cache entry invalidation event"
-            );
-        }
-
-        /// Try retrieving an entry by its key, then execute `extract` if it exists.
-        #[tracing::instrument(level = "debug", fields(cache = self.name), skip_all)]
-        fn get_raw<Q, R>(&self, key: &Q, extract: impl FnOnce(&K, &Entry<V>) -> R) -> Option<R>
-        where
-            K: Borrow<Q>,
-            Q: Hash + Eq + ?Sized,
-        {
-            let now = Instant::now();
-            let deadline = now.checked_add(self.ttl).expect("time overflow");
-
-            // Do costly things before taking the lock.
-            let mut cache = self.cache.lock();
-            let mut raw_entry = match cache.raw_entry_mut().from_key(key) {
-                RawEntryMut::Vacant(_) => return None,
-                RawEntryMut::Occupied(x) => x,
-            };
-
-            // Immediately drop the entry if it has expired.
-            let entry = raw_entry.get();
-            if entry.expires_at <= now {
-                raw_entry.remove();
-                return None;
-            }
-
-            let value = extract(raw_entry.key(), entry);
-            let (created_at, expires_at) = (entry.created_at, entry.expires_at);
-
-            // Update the deadline and the entry's position in the LRU list.
-            raw_entry.get_mut().expires_at = deadline;
-            raw_entry.to_back();
-
-            drop(cache); // drop lock before logging
-            debug!(
-                created_at = format_args!("{created_at:?}"),
-                old_expires_at = format_args!("{expires_at:?}"),
-                new_expires_at = format_args!("{deadline:?}"),
-                "accessed a cache entry"
-            );
-
-            Some(value)
-        }
-
-        /// Insert an entry to the cache. If an entry with the same key already
-        /// existed, return the previous value and its creation timestamp.
-        #[tracing::instrument(level = "debug", fields(cache = self.name), skip_all)]
-        fn insert_raw(&self, key: K, value: V) -> (Instant, Option<V>) {
-            let created_at = Instant::now();
-            let expires_at = created_at.checked_add(self.ttl).expect("time overflow");
-
-            let entry = Entry {
-                created_at,
-                expires_at,
-                value,
-            };
-
-            // Do costly things before taking the lock.
-            let old = self
-                .cache
-                .lock()
-                .insert(key, entry)
-                .map(|entry| entry.value);
-
-            debug!(
-                created_at = format_args!("{created_at:?}"),
-                expires_at = format_args!("{expires_at:?}"),
-                replaced = old.is_some(),
-                "created a cache entry"
-            );
-
-            (created_at, old)
-        }
-    }
-
-    impl<K: Hash + Eq + Clone, V: Clone> TimedLru<K, V> {
-        pub fn insert(&self, key: K, value: V) -> (Option<V>, Cached<&Self>) {
-            let (created_at, old) = self.insert_raw(key.clone(), value.clone());
-
-            let cached = Cached {
-                token: Some((self, LookupInfo { created_at, key })),
-                value,
-            };
-
-            (old, cached)
-        }
-    }
-
-    impl<K: Hash + Eq, V: Clone> TimedLru<K, V> {
-        /// Retrieve a cached entry in convenient wrapper.
-        pub fn get<Q>(&self, key: &Q) -> Option<timed_lru::Cached<&Self>>
-        where
-            K: Borrow<Q> + Clone,
-            Q: Hash + Eq + ?Sized,
-        {
-            self.get_raw(key, |key, entry| {
-                let info = LookupInfo {
-                    created_at: entry.created_at,
-                    key: key.clone(),
-                };
-
-                Cached {
-                    token: Some((self, info)),
-                    value: entry.value.clone(),
-                }
-            })
-        }
-    }
-
-    /// Lookup information for key invalidation.
-    pub struct LookupInfo<K> {
-        /// Time of creation of a cache [`Entry`].
-        /// We use this during invalidation lookups to prevent eviction of a newer
-        /// entry sharing the same key (it might've been inserted by a different
-        /// task after we got the entry we're trying to invalidate now).
-        created_at: Instant,
-
-        /// Search by this key.
-        key: K,
-    }
-
-    /// Wrapper for convenient entry invalidation.
-    pub struct Cached<C: Cache> {
-        /// Cache + lookup info.
-        token: Option<(C, C::LookupInfo<C::Key>)>,
-
-        /// The value itself.
-        pub value: C::Value,
-    }
-
-    impl<C: Cache> Cached<C> {
-        /// Place any entry into this wrapper; invalidation will be a no-op.
-        /// Unfortunately, rust doesn't let us implement [`From`] or [`Into`].
-        pub fn new_uncached(value: impl Into<C::Value>) -> Self {
-            Self {
-                token: None,
-                value: value.into(),
-            }
-        }
-
-        /// Drop this entry from a cache if it's still there.
-        pub fn invalidate(&self) {
-            if let Some((cache, info)) = &self.token {
-                cache.invalidate(info);
-            }
-        }
-
-        /// Tell if this entry is actually cached.
-        pub fn cached(&self) -> bool {
-            self.token.is_some()
-        }
-    }
-
-    impl<C: Cache> Deref for Cached<C> {
-        type Target = C::Value;
-
-        fn deref(&self) -> &Self::Target {
-            &self.value
-        }
-    }
-
-    impl<C: Cache> DerefMut for Cached<C> {
-        fn deref_mut(&mut self) -> &mut Self::Target {
-            &mut self.value
-        }
+        // Values are still there.
+        assert_eq!(*x, 0);
+        assert_eq!(*y, 404);
     }
 }
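The type erasure in `InvalidationToken` boils down to storing the key as `Arc<dyn Any + Sync + Send>` and downcasting it back inside the cache that knows its concrete key type. The core trick, in isolation:

```rust
use std::any::Any;
use std::sync::Arc;

fn main() {
    // Erase the key type so tokens over different caches share one type...
    let key: Arc<dyn Any + Sync + Send> = Arc::new(42u32);
    // ...and recover it where the concrete cache type is known again.
    let back: &u32 = key.downcast_ref::<u32>().expect("bad key type");
    assert_eq!(*back, 42);
}
```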
proxy/src/cache/timed_lru.rs (new file, 293 lines)

@@ -0,0 +1,293 @@
#[cfg(test)]
mod tests;

use super::{Cache, Cached, InvalidationToken, LookupInfo};
use ref_cast::RefCast;
use std::{
    any::Any,
    borrow::Borrow,
    hash::Hash,
    sync::Arc,
    time::{Duration, Instant},
};
use tracing::debug;

// This seems to make more sense than `lru` or `cached`:
//
// * `near/nearcore` ditched `cached` in favor of `lru`
//   (https://github.com/near/nearcore/issues?q=is%3Aissue+lru+is%3Aclosed).
//
// * `lru` methods use an obscure `KeyRef` type in their constraints (which is deliberately excluded from docs).
//   This severely hinders its usage both in terms of creating wrappers and supported key types.
//
// On the other hand, `hashlink` has good download stats and appears to be maintained.
use hashlink::{linked_hash_map::RawEntryMut, LruCache};

/// An implementation of timed LRU cache with fixed capacity.
/// Key properties:
///
/// * Whenever a new entry is inserted, the least recently accessed one is evicted.
///   The cache also keeps track of entry's insertion time (`created_at`) and TTL (`expires_at`).
///
/// * When the entry is about to be retrieved, we check its expiration timestamp.
///   If the entry has expired, we remove it from the cache; Otherwise we bump the
///   expiration timestamp (e.g. +5mins) and change its place in LRU list to prolong
///   its existence.
///
/// * There's an API for immediate invalidation (removal) of a cache entry;
///   It's useful in case we know for sure that the entry is no longer correct.
///   See [`timed_lru::LookupInfo`] & [`timed_lru::Cached`] for more information.
///
/// * Expired entries are kept in the cache, until they are evicted by the LRU policy,
///   or by a successful lookup (i.e. the entry hasn't expired yet).
///   There is no background job to reap the expired records.
///
/// * It's possible for an entry that has not yet expired to be evicted
///   before expired items. That's a bit wasteful, but probably fine in practice.
pub struct TimedLru<K, V> {
    /// Cache's name for tracing.
    name: &'static str,

    /// The underlying cache implementation.
    cache: parking_lot::Mutex<LruCache<Key<K>, Entry<V>>>,

    /// Default time-to-live of a single entry.
    ttl: Duration,
}

#[derive(RefCast, Hash, PartialEq, Eq)]
#[repr(transparent)]
struct Query<Q: ?Sized>(Q);

#[derive(Hash, PartialEq, Eq)]
#[repr(transparent)]
struct Key<T: ?Sized>(Arc<T>);

/// It's impossible to implement this without [`Key`] & [`Query`]:
/// * We can't implement std traits for [`Arc`].
/// * Even if we could, it'd conflict with `impl<T> Borrow<T> for T`.
impl<Q, T> Borrow<Query<Q>> for Key<T>
where
    Q: ?Sized,
    T: Borrow<Q>,
{
    #[inline(always)]
    fn borrow(&self) -> &Query<Q> {
        RefCast::ref_cast(self.0.as_ref().borrow())
    }
}

struct Entry<T> {
    created_at: Instant,
    expires_at: Instant,
    value: Arc<T>,
}

impl<K: Hash + Eq, V> TimedLru<K, V> {
    /// Construct a new LRU cache with timed entries.
    pub fn new(name: &'static str, capacity: usize, ttl: Duration) -> Self {
        Self {
            name,
            cache: LruCache::new(capacity).into(),
            ttl,
        }
    }

    /// Get the number of entries in the cache.
    /// Note that this method will not try to evict stale entries.
    pub fn size(&self) -> usize {
        self.cache.lock().len()
    }

    /// Try retrieving an entry by its key, then execute `extract` if it exists.
    #[tracing::instrument(level = "debug", fields(cache = self.name), skip_all)]
    fn get_raw<Q, F, R>(&self, key: &Q, extract: F) -> Option<R>
    where
        K: Borrow<Q>,
        Q: Hash + Eq + ?Sized,
        F: FnOnce(&Arc<K>, &Entry<V>) -> R,
    {
        let key: &Query<Q> = RefCast::ref_cast(key);
        let now = Instant::now();
        let deadline = now.checked_add(self.ttl).expect("time overflow");

        // Do costly things before taking the lock.
        let mut cache = self.cache.lock();
        let mut raw_entry = match cache.raw_entry_mut().from_key(key) {
            RawEntryMut::Vacant(_) => return None,
            RawEntryMut::Occupied(x) => x,
        };

        // Immediately drop the entry if it has expired.
        let entry = raw_entry.get();
        if entry.expires_at <= now {
            raw_entry.remove();
            return None;
        }

        let value = extract(&raw_entry.key().0, entry);
        let (created_at, expires_at) = (entry.created_at, entry.expires_at);

        // Update the deadline and the entry's position in the LRU list.
        raw_entry.get_mut().expires_at = deadline;
        raw_entry.to_back();

        drop(cache); // drop lock before logging
        debug!(
            ?created_at,
            old_expires_at = ?expires_at,
            new_expires_at = ?deadline,
            "accessed a cache entry"
        );

        Some(value)
    }

    /// Insert an entry to the cache. If an entry with the same key already
    /// existed, return the previous value and its creation timestamp.
    #[tracing::instrument(level = "debug", fields(cache = self.name), skip_all)]
    fn insert_raw(&self, key: Arc<K>, value: Arc<V>) -> (Option<Arc<V>>, Instant) {
        let created_at = Instant::now();
        let expires_at = created_at.checked_add(self.ttl).expect("time overflow");

        let entry = Entry {
            created_at,
            expires_at,
            value,
        };

        // Do costly things before taking the lock.
        let old = self
            .cache
            .lock()
            .insert(Key(key), entry)
            .map(|entry| entry.value);

        debug!(
            ?created_at,
            ?expires_at,
            replaced = old.is_some(),
            "created a cache entry"
        );

        (old, created_at)
    }

    #[tracing::instrument(level = "debug", fields(cache = self.name), skip_all)]
    fn remove_raw<Q>(&self, key: &Q)
    where
        K: Borrow<Q>,
        Q: Hash + Eq + ?Sized,
    {
        let key: &Query<Q> = RefCast::ref_cast(key);
        let mut cache = self.cache.lock();
        cache.remove(key);

        debug!("removed a cache entry");
    }
}

/// Convenient wrappers for raw methods.
impl<K: Hash + Eq + Sync + Send + 'static, V> TimedLru<K, V>
where
    Self: Sync + Send,
{
    pub fn get<Q>(&self, key: &Q) -> Option<Cached<V>>
    where
        K: Borrow<Q>,
        Q: Hash + Eq + ?Sized,
    {
        self.get_raw(key, |key, entry| {
            let info = LookupInfo {
                created_at: entry.created_at,
                key: Arc::clone(key),
            };

            Cached {
                token: Some(Self::invalidation_token(self, info)),
                value: entry.value.clone(),
            }
        })
    }

    pub fn insert(&self, key: Arc<K>, value: Arc<V>) -> (Option<Arc<V>>, Cached<V>) {
        let (old, created_at) = self.insert_raw(key.clone(), value.clone());

        let info = LookupInfo { created_at, key };
        let cached = Cached {
            token: Some(Self::invalidation_token(self, info)),
            value,
        };

        (old, cached)
    }

    pub fn remove<Q>(&self, key: &Q)
    where
        K: Borrow<Q>,
        Q: Hash + Eq + ?Sized,
    {
        self.remove_raw(key)
    }
}

/// Implementation details of the entry invalidation machinery.
impl<K: Hash + Eq + Sync + Send + 'static, V> TimedLru<K, V>
where
    Self: Sync + Send,
{
    /// This is a proper (safe) way to create an invalidation token for [`TimedLru`].
    fn invalidation_token(&self, info: LookupInfo<Arc<K>>) -> InvalidationToken<'_> {
        InvalidationToken {
            cache: self,
            info: LookupInfo {
                created_at: info.created_at,
                key: info.key,
            },
        }
    }
}

/// This implementation depends on [`Self::invalidation_token`].
impl<K: Hash + Eq + 'static, V> Cache for TimedLru<K, V> {
    fn invalidate_entry(&self, info: LookupInfo<&(dyn Any + Sync + Send)>) {
        let info = LookupInfo {
            created_at: info.created_at,
            // NOTE: it's important to downcast to the correct type!
            key: info.key.downcast_ref::<K>().expect("bad key type"),
        };
        self.invalidate_raw(info)
    }
}

impl<K: Hash + Eq, V> TimedLru<K, V> {
    #[tracing::instrument(level = "debug", fields(cache = self.name), skip_all)]
    fn invalidate_raw(&self, info: LookupInfo<&K>) {
        let key: &Query<K> = RefCast::ref_cast(info.key);
        let now = Instant::now();

        // Do costly things before taking the lock.
        let mut cache = self.cache.lock();
        let raw_entry = match cache.raw_entry_mut().from_key(key) {
            RawEntryMut::Vacant(_) => return,
            RawEntryMut::Occupied(x) => x,
        };

        // Remove the entry if it was created prior to lookup timestamp.
        let entry = raw_entry.get();
        let (created_at, expires_at) = (entry.created_at, entry.expires_at);
        let should_remove = created_at <= info.created_at || expires_at <= now;

        if should_remove {
            raw_entry.remove();
        }

        drop(cache); // drop lock before logging
        debug!(
            ?created_at,
            ?expires_at,
            entry_removed = should_remove,
            "processed a cache entry invalidation event"
        );
    }
}
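The `Key`/`Query` newtypes above exist so a `TimedLru<String, V>` can be queried with a plain `&str` even though the stored key is `Arc<String>`. A self-contained demonstration of the same `Borrow` bridge over a plain `HashMap` (simplified to a non-generic `Key`):

```rust
use ref_cast::RefCast;
use std::{borrow::Borrow, collections::HashMap, sync::Arc};

#[derive(RefCast, Hash, PartialEq, Eq)]
#[repr(transparent)]
struct Query<Q: ?Sized>(Q);

#[derive(Hash, PartialEq, Eq)]
#[repr(transparent)]
struct Key(Arc<String>);

// We can't write `impl Borrow<str> for Arc<String>` (foreign types), but we
// can route the lookup through our own transparent wrapper instead.
impl Borrow<Query<str>> for Key {
    fn borrow(&self) -> &Query<str> {
        Query::ref_cast(self.0.as_str())
    }
}

fn main() {
    let mut map = HashMap::new();
    map.insert(Key(Arc::new("hello".into())), 1);
    // Look up by &str with no allocation, thanks to the Borrow impl above.
    assert_eq!(map.get(Query::ref_cast("hello")), Some(&1));
}
```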
proxy/src/cache/timed_lru/tests.rs (new file, 104 lines)

@@ -0,0 +1,104 @@
use super::*;

/// Check that we can define the cache for certain types.
#[test]
fn definition() {
    // Check for trivial yet possible types.
    let cache = TimedLru::<(), ()>::new("test", 128, Duration::from_secs(0));
    let _ = cache.insert(Default::default(), Default::default());
    let _ = cache.get(&());

    // Now for something less trivial.
    let cache = TimedLru::<String, String>::new("test", 128, Duration::from_secs(0));
    let _ = cache.insert(Default::default(), Default::default());
    let _ = cache.get(&String::default());
    let _ = cache.get("str should work");

    // It should also work for non-cloneable values.
    struct NoClone;
    let cache = TimedLru::<Box<str>, NoClone>::new("test", 128, Duration::from_secs(0));
    let _ = cache.insert(Default::default(), NoClone.into());
    let _ = cache.get(&Box::<str>::from("boxed str"));
    let _ = cache.get("str should work");
}

#[test]
fn insert() {
    const CAPACITY: usize = 2;
    let cache = TimedLru::<String, u32>::new("test", CAPACITY, Duration::from_secs(10));
    assert_eq!(cache.size(), 0);

    let key = Arc::new(String::from("key"));

    let (old, cached) = cache.insert(key.clone(), 42.into());
    assert_eq!(old, None);
    assert_eq!(*cached, 42);
    assert_eq!(cache.size(), 1);

    let (old, cached) = cache.insert(key, 1.into());
    assert_eq!(old.as_deref(), Some(&42));
    assert_eq!(*cached, 1);
    assert_eq!(cache.size(), 1);

    let (old, cached) = cache.insert(Arc::new("N1".to_owned()), 10.into());
    assert_eq!(old, None);
    assert_eq!(*cached, 10);
    assert_eq!(cache.size(), 2);

    let (old, cached) = cache.insert(Arc::new("N2".to_owned()), 20.into());
    assert_eq!(old, None);
    assert_eq!(*cached, 20);
    assert_eq!(cache.size(), 2);
}

#[test]
fn get_none() {
    let cache = TimedLru::<String, u32>::new("test", 2, Duration::from_secs(10));
    let cached = cache.get("missing");
    assert!(matches!(cached, None));
}

#[test]
fn invalidation_simple() {
    let cache = TimedLru::<String, u32>::new("test", 2, Duration::from_secs(10));
    let (_, cached) = cache.insert(String::from("key").into(), 100.into());
    assert_eq!(cache.size(), 1);

    cached.invalidate();

    assert_eq!(cache.size(), 0);
    assert!(matches!(cache.get("key"), None));
}

#[test]
fn invalidation_preserve_newer() {
    let cache = TimedLru::<String, u32>::new("test", 2, Duration::from_secs(10));
    let key = Arc::new(String::from("key"));

    let (_, cached) = cache.insert(key.clone(), 100.into());
    assert_eq!(cache.size(), 1);
    let _ = cache.insert(key.clone(), 200.into());
    assert_eq!(cache.size(), 1);
    cached.invalidate();
    assert_eq!(cache.size(), 1);

    let cached = cache.get(key.as_ref());
    assert_eq!(cached.as_deref(), Some(&200));
}

#[test]
fn auto_expiry() {
    let lifetime = Duration::from_millis(300);
    let cache = TimedLru::<String, u32>::new("test", 2, lifetime);

    let key = Arc::new(String::from("key"));
    let _ = cache.insert(key.clone(), 42.into());

    let cached = cache.get(key.as_ref());
    assert_eq!(cached.as_deref(), Some(&42));

    std::thread::sleep(lifetime);

    let cached = cache.get(key.as_ref());
    assert_eq!(cached.as_deref(), None);
}
proxy/src/compute.rs

@@ -1,13 +1,28 @@
-use crate::{auth::parse_endpoint_param, cancellation::CancelClosure, error::UserFacingError};
+use crate::{
+    auth::parse_endpoint_param,
+    cancellation::CancelClosure,
+    console::messages::{DatabaseInfo, MetricsAuxInfo},
+    console::CachedNodeInfo,
+    error::UserFacingError,
+};
 use futures::{FutureExt, TryFutureExt};
 use itertools::Itertools;
 use pq_proto::StartupMessageParams;
-use std::{io, net::SocketAddr, time::Duration};
+use std::{
+    io,
+    net::SocketAddr,
+    sync::atomic::{AtomicBool, Ordering},
+    time::Duration,
+};
 use thiserror::Error;
 use tokio::net::TcpStream;
 use tokio_postgres::tls::MakeTlsConnect;
 use tracing::{error, info, warn};

+/// Should we allow self-signed certificates in TLS connections?
+/// Most definitely, this shouldn't be allowed in production.
+pub static ALLOW_SELF_SIGNED_COMPUTE: AtomicBool = AtomicBool::new(false);
+
 const COULD_NOT_CONNECT: &str = "Couldn't connect to compute node";

 #[derive(Debug, Error)]

@@ -42,6 +57,88 @@ impl UserFacingError for ConnectionError {
 /// A pair of `ClientKey` & `ServerKey` for `SCRAM-SHA-256`.
 pub type ScramKeys = tokio_postgres::config::ScramKeys<32>;

+#[derive(Clone)]
+pub enum Password {
+    /// A regular cleartext password.
+    ClearText(Vec<u8>),
+    /// A pair of `ClientKey` & `ServerKey` for `SCRAM-SHA-256`.
+    ScramKeys(ScramKeys),
+}
+
+pub enum ComputeNode {
+    /// Route via link auth.
+    Link(DatabaseInfo),
+    /// Regular compute node.
+    Static {
+        password: Password,
+        info: CachedNodeInfo,
+    },
+}
+
+impl ComputeNode {
+    /// Get metrics auxiliary info.
+    pub fn metrics_aux_info(&self) -> &MetricsAuxInfo {
+        match self {
+            Self::Link(info) => &info.aux,
+            Self::Static { info, .. } => &info.aux,
+        }
+    }
+
+    /// Invalidate compute node info if it's cached.
+    pub fn invalidate(&self) -> bool {
+        if let Self::Static { info, .. } = self {
+            warn!("invalidating compute node info cache entry");
+            info.invalidate();
+            return true;
+        }
+
+        false
+    }
+
+    /// Turn compute node info into a postgres connection config.
+    pub fn to_conn_config(&self) -> ConnCfg {
+        let mut config = ConnCfg::new();
+
+        let (host, port) = match self {
+            Self::Link(info) => {
+                // NB: use pre-supplied dbname, user and password for link auth.
+                // See `ConnCfg::set_startup_params` below.
+                config.0.dbname(&info.dbname).user(&info.user);
+                if let Some(password) = &info.password {
+                    config.0.password(password.as_bytes());
+                }
+
+                (&info.host, info.port)
+            }
+            Self::Static { info, password } => {
+                // NB: setup auth keys (for SCRAM) or plaintext password.
+                match password {
+                    Password::ClearText(text) => config.0.password(text),
+                    Password::ScramKeys(keys) => {
+                        use tokio_postgres::config::AuthKeys;
+                        config.0.auth_keys(AuthKeys::ScramSha256(keys.to_owned()))
+                    }
+                };
+
+                (&info.address.host, info.address.port)
+            }
+        };
+
+        // Backwards compatibility. pg_sni_proxy uses "--" in domain names
+        // while direct connections do not. Once we migrate to pg_sni_proxy
+        // everywhere, we can remove this.
+        config.0.ssl_mode(if host.contains("--") {
+            // We need TLS connection with SNI info to properly route it.
+            tokio_postgres::config::SslMode::Require
+        } else {
+            tokio_postgres::config::SslMode::Disable
+        });
+
+        config.0.host(host).port(port);
+        config
+    }
+}
+
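The host-based TLS switch above encodes a routing rule: hosts that go through pg_sni_proxy contain `--` and need TLS with SNI, while direct connections don't. Extracted as a tiny function for clarity:

```rust
use tokio_postgres::config::SslMode;

// Same rule as in `to_conn_config`, isolated for clarity.
fn ssl_mode_for_host(host: &str) -> SslMode {
    if host.contains("--") {
        SslMode::Require // TLS + SNI lets pg_sni_proxy route the connection
    } else {
        SslMode::Disable // direct connection to the compute node
    }
}
```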
/// A config for establishing a connection to compute node.
|
||||
/// Eventually, `tokio_postgres` will be replaced with something better.
|
||||
/// Newtype allows us to implement methods on top of it.
|
||||
@@ -51,43 +148,32 @@ pub struct ConnCfg(Box<tokio_postgres::Config>);
|
||||
|
||||
/// Creation and initialization routines.
|
||||
impl ConnCfg {
|
||||
pub fn new() -> Self {
|
||||
fn new() -> Self {
|
||||
Self(Default::default())
|
||||
}
|
||||
|
||||
/// Reuse password or auth keys from the other config.
|
||||
pub fn reuse_password(&mut self, other: &Self) {
|
||||
if let Some(password) = other.get_password() {
|
||||
self.password(password);
|
||||
}
|
||||
|
||||
if let Some(keys) = other.get_auth_keys() {
|
||||
self.auth_keys(keys);
|
||||
}
|
||||
}
|
||||
|
||||
/// Apply startup message params to the connection config.
|
||||
pub fn set_startup_params(&mut self, params: &StartupMessageParams) {
|
||||
// Only set `user` if it's not present in the config.
|
||||
// Link auth flow takes username from the console's response.
|
||||
if let (None, Some(user)) = (self.get_user(), params.get("user")) {
|
||||
self.user(user);
|
||||
if let (None, Some(user)) = (self.0.get_user(), params.get("user")) {
|
||||
self.0.user(user);
|
||||
}
|
||||
|
||||
// Only set `dbname` if it's not present in the config.
|
||||
// Link auth flow takes dbname from the console's response.
|
||||
if let (None, Some(dbname)) = (self.get_dbname(), params.get("database")) {
|
||||
self.dbname(dbname);
|
||||
if let (None, Some(dbname)) = (self.0.get_dbname(), params.get("database")) {
|
||||
self.0.dbname(dbname);
|
||||
}
|
||||
|
||||
// Don't add `options` if they were only used for specifying a project.
|
||||
// Connection pools don't support `options`, because they affect backend startup.
|
||||
if let Some(options) = filtered_options(params) {
|
||||
self.options(&options);
|
||||
self.0.options(&options);
|
||||
}
|
||||
|
||||
if let Some(app_name) = params.get("application_name") {
|
||||
self.application_name(app_name);
|
||||
self.0.application_name(app_name);
|
||||
}
|
||||
|
||||
// TODO: This is especially ugly...
|
||||
@@ -95,10 +181,10 @@ impl ConnCfg {
|
||||
use tokio_postgres::config::ReplicationMode;
|
||||
match replication {
|
||||
"true" | "on" | "yes" | "1" => {
|
||||
self.replication_mode(ReplicationMode::Physical);
|
||||
self.0.replication_mode(ReplicationMode::Physical);
|
||||
}
|
||||
"database" => {
|
||||
self.replication_mode(ReplicationMode::Logical);
|
||||
self.0.replication_mode(ReplicationMode::Logical);
|
||||
}
|
||||
_other => {}
|
||||
}
|
||||
@@ -113,27 +199,6 @@ impl ConnCfg {
|
||||
}
|
||||
}
impl std::ops::Deref for ConnCfg {
    type Target = tokio_postgres::Config;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

/// For now, let's make it easier to set up the config.
impl std::ops::DerefMut for ConnCfg {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.0
    }
}

impl Default for ConnCfg {
    fn default() -> Self {
        Self::new()
    }
}
impl ConnCfg {
    /// Establish a raw TCP connection to the compute node.
    async fn connect_raw(&self) -> io::Result<(SocketAddr, TcpStream, &str)> {
@@ -220,16 +285,13 @@ pub struct PostgresConnection {
}

impl ConnCfg {
    async fn do_connect(
        &self,
        allow_self_signed_compute: bool,
    ) -> Result<PostgresConnection, ConnectionError> {
    async fn do_connect(&self) -> Result<PostgresConnection, ConnectionError> {
        let (socket_addr, stream, host) = self.connect_raw().await?;

        let tls_connector = native_tls::TlsConnector::builder()
            .danger_accept_invalid_certs(allow_self_signed_compute)
            .build()
            .unwrap();
            .danger_accept_invalid_certs(ALLOW_SELF_SIGNED_COMPUTE.load(Ordering::Relaxed))
            .build()?;

        let mut mk_tls = postgres_native_tls::MakeTlsConnector::new(tls_connector);
        let tls = MakeTlsConnect::<tokio::net::TcpStream>::make_tls_connect(&mut mk_tls, host)?;

@@ -261,11 +323,8 @@ impl ConnCfg {
    }

    /// Connect to a corresponding compute node.
    pub async fn connect(
        &self,
        allow_self_signed_compute: bool,
    ) -> Result<PostgresConnection, ConnectionError> {
        self.do_connect(allow_self_signed_compute)
    pub async fn connect(&self) -> Result<PostgresConnection, ConnectionError> {
        self.do_connect()
            .inspect_err(|err| {
                // Immediately log the error we have at our disposal.
                error!("couldn't connect to compute node: {err}");
@@ -12,7 +12,6 @@ pub struct ProxyConfig {
    pub tls_config: Option<TlsConfig>,
    pub auth_backend: auth::BackendType<'static, ()>,
    pub metric_collection: Option<MetricCollectionConfig>,
    pub allow_self_signed_compute: bool,
}

#[derive(Debug)]
@@ -211,8 +210,11 @@ pub struct CacheOptions {
}

impl CacheOptions {
    /// Default options for [`crate::auth::caches::AuthInfoCache`].
    pub const DEFAULT_AUTH_INFO: &str = "size=4000,ttl=5s";

    /// Default options for [`crate::auth::caches::NodeInfoCache`].
    pub const DEFAULT_OPTIONS_NODE_INFO: &str = "size=4000,ttl=5m";
    pub const DEFAULT_NODE_INFO: &str = "size=4000,ttl=5m";

    /// Parse cache options passed via cmdline.
    /// Example: [`Self::DEFAULT_OPTIONS_NODE_INFO`].
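Both defaults follow a `size=<entries>,ttl=<duration>` shape; a rough sketch of a parser for that format (ours for illustration, the real parser lives outside this hunk, and only the `s`/`m` suffixes seen in the defaults above are handled):

use std::time::Duration;

// Sketch: turn "size=4000,ttl=5m" into (capacity, ttl).
fn parse_cache_options(input: &str) -> Option<(usize, Duration)> {
    let (mut size, mut ttl) = (None, None);
    for part in input.split(',') {
        match part.split_once('=')? {
            ("size", v) => size = Some(v.parse().ok()?),
            ("ttl", v) if !v.is_empty() => {
                let (num, unit) = v.split_at(v.len() - 1);
                let secs: u64 = num.parse().ok()?;
                ttl = Some(match unit {
                    "s" => Duration::from_secs(secs),
                    "m" => Duration::from_secs(secs * 60),
                    _ => return None,
                });
            }
            _ => return None,
        }
    }
    Some((size?, ttl?))
}

fn main() {
    let parsed = parse_cache_options("size=4000,ttl=5m");
    assert_eq!(parsed, Some((4000, Duration::from_secs(300))));
}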
@@ -6,12 +6,17 @@ pub mod messages;

/// Wrappers for console APIs and their mocks.
pub mod provider;
pub use provider::{errors, Api, AuthInfo, CachedNodeInfo, ConsoleReqExtra, NodeInfo};
pub use provider::{errors, Api, ConsoleReqExtra};
pub use provider::{AuthInfo, NodeInfo};
pub use provider::{CachedAuthInfo, CachedNodeInfo};

/// Various cache-related types.
pub mod caches {
    pub use super::provider::{ApiCaches, NodeInfoCache};
    pub use super::provider::{ApiCaches, AuthInfoCache, AuthInfoCacheKey, NodeInfoCache};
}

/// Console's management API.
pub mod mgmt;

/// Console's notification bus.
pub mod notifications;

@@ -22,11 +22,37 @@ impl fmt::Debug for GetRoleSecret {
    }
}

/// Represents compute node's host & port pair.
#[derive(Debug)]
pub struct PgEndpoint {
    pub host: Box<str>,
    pub port: u16,
}

impl<'de> Deserialize<'de> for PgEndpoint {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        use serde::de::Error;

        let raw = String::deserialize(deserializer)?;
        let (host, port) = raw
            .rsplit_once(':')
            .ok_or_else(|| Error::custom(format!("bad compute address: {raw}")))?;

        Ok(PgEndpoint {
            host: host.into(),
            port: port.parse().map_err(Error::custom)?,
        })
    }
}
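The `rsplit_once` above is deliberate: the port always follows the last colon, which is what lets bracketed IPv6 hosts through unharmed. A quick illustration of the std behavior (not project code):

fn main() {
    // rsplit_once cuts at the *last* ':', keeping the IPv6 host intact:
    assert_eq!("[::1]:5432".rsplit_once(':'), Some(("[::1]", "5432")));
    // split_once would cut at the first ':' inside the address instead:
    assert_eq!("[::1]:5432".split_once(':'), Some(("[", ":1]:5432")));
    // Plain host:port parses the same either way:
    assert_eq!("127.0.0.1:5432".rsplit_once(':'), Some(("127.0.0.1", "5432")));
}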
/// Response which holds compute node's `host:port` pair.
/// Returned by the `/proxy_wake_compute` API method.
#[derive(Debug, Deserialize)]
pub struct WakeCompute {
    pub address: Box<str>,
    pub address: PgEndpoint,
    pub aux: MetricsAuxInfo,
}

@@ -187,4 +213,24 @@ mod tests {

        Ok(())
    }

    #[test]
    fn parse_wake_compute() -> anyhow::Result<()> {
        let _: WakeCompute = serde_json::from_value(json!({
            "address": "127.0.0.1:5432",
            "aux": dummy_aux(),
        }))?;

        let _: WakeCompute = serde_json::from_value(json!({
            "address": "[::1]:5432",
            "aux": dummy_aux(),
        }))?;

        serde_json::from_value::<WakeCompute>(json!({
            "address": "localhost:5432",
            "aux": dummy_aux(),
        }))?;

        Ok(())
    }
}

70
proxy/src/console/notifications.rs
Normal file
@@ -0,0 +1,70 @@
use crate::console::caches::{ApiCaches, AuthInfoCacheKey};
use futures::StreamExt;
use serde::Deserialize;

const CHANNEL_NAME: &str = "proxy_notifications";

#[derive(Debug, Deserialize)]
#[serde(tag = "type")]
enum Notification<'a> {
    #[serde(rename = "password_changed")]
    PasswordChanged { project: &'a str, role: &'a str },
}

#[tracing::instrument(skip(caches))]
fn handle_message(msg: redis::Msg, caches: &ApiCaches) -> anyhow::Result<()> {
    let payload: String = msg.get_payload()?;

    use Notification::*;
    match serde_json::from_str(&payload) {
        Ok(PasswordChanged { project, role }) => {
            let key = AuthInfoCacheKey {
                project: project.into(),
                role: role.into(),
            };

            tracing::info!(key = ?key, "invalidating auth info");
            caches.auth_info.remove(&key);
        }
        Err(e) => tracing::error!("broken message: {e}"),
    }

    Ok(())
}

/// Handle console's invalidation messages.
#[tracing::instrument(name = "console_notifications", skip_all)]
pub async fn task_main(url: String, caches: &ApiCaches) -> anyhow::Result<()> {
    let client = redis::Client::open(url.as_ref())?;
    let mut conn = client.get_async_connection().await?.into_pubsub();

    tracing::info!("subscribing to channel `{CHANNEL_NAME}`");
    conn.subscribe(CHANNEL_NAME).await?;
    let mut stream = conn.on_message();

    while let Some(msg) = stream.next().await {
        handle_message(msg, caches)?;
    }

    Ok(())
}
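For reference, the producing side of this bus would be the console; a minimal sketch of publishing a compatible message (the Redis URL and the blocking connection are our choices for brevity):

use redis::Commands;

fn publish_password_changed() -> anyhow::Result<()> {
    let client = redis::Client::open("redis://127.0.0.1/")?;
    let mut conn = client.get_connection()?;

    // Payload shape matches the `Notification::PasswordChanged` variant above.
    let payload = serde_json::json!({
        "type": "password_changed",
        "project": "very-nice",
        "role": "borat",
    })
    .to_string();

    // Every proxy subscribed to `proxy_notifications` will drop the matching
    // auth info cache entry on receipt.
    let _: () = conn.publish("proxy_notifications", payload)?;
    Ok(())
}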
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn parse_notification() -> anyhow::Result<()> {
        let text = json!({
            "type": "password_changed",
            "project": "very-nice",
            "role": "borat",
        })
        .to_string();

        let _: Notification = serde_json::from_str(&text)?;

        Ok(())
    }
}
@@ -1,14 +1,13 @@
pub mod mock;
pub mod neon;

use super::messages::MetricsAuxInfo;
use super::messages::{MetricsAuxInfo, PgEndpoint};
use crate::{
    auth::ClientCredentials,
    cache::{timed_lru, TimedLru},
    compute, scram,
    cache::{types::Cached, TimedLru},
    scram,
};
use async_trait::async_trait;
use std::sync::Arc;

pub mod errors {
    use crate::{
@@ -112,11 +111,9 @@ pub mod errors {
            }
        }
    }

    #[derive(Debug, Error)]
    pub enum WakeComputeError {
        #[error("Console responded with a malformed compute address: {0}")]
        BadComputeAddress(Box<str>),

        #[error(transparent)]
        ApiError(ApiError),
    }
@@ -132,10 +129,7 @@ pub mod errors {
        fn to_string_client(&self) -> String {
            use WakeComputeError::*;
            match self {
                // We shouldn't show the user the address even if it's broken.
                // Besides, the user is unlikely to care about this detail.
                BadComputeAddress(_) => REQUEST_FAILED.to_owned(),
                // However, API might return a meaningful error.
                // API might return a meaningful error.
                ApiError(e) => e.to_string_client(),
            }
        }
@@ -160,23 +154,26 @@ pub enum AuthInfo {
}

/// Info for establishing a connection to a compute node.
/// This is what we get after auth succeeded, but not before!
#[derive(Clone)]
/// This struct is cached, so we shouldn't store any user creds here.
pub struct NodeInfo {
    /// Compute node connection params.
    /// It's sad that we have to clone this, but this will improve
    /// once we migrate to a bespoke connection logic.
    pub config: compute::ConnCfg,
    /// Address of the compute node.
    pub address: PgEndpoint,

    /// Labels for proxy's metrics.
    pub aux: Arc<MetricsAuxInfo>,

    /// Whether we should accept self-signed certificates (for testing)
    pub allow_self_signed_compute: bool,
    pub aux: MetricsAuxInfo,
}

pub type NodeInfoCache = TimedLru<Arc<str>, NodeInfo>;
pub type CachedNodeInfo = timed_lru::Cached<&'static NodeInfoCache>;
pub type NodeInfoCache = TimedLru<Box<str>, NodeInfo>;
pub type CachedNodeInfo = Cached<NodeInfo>;

#[derive(Debug, Hash, PartialEq, Eq)]
pub struct AuthInfoCacheKey {
    pub project: Box<str>,
    pub role: Box<str>,
}

pub type AuthInfoCache = TimedLru<AuthInfoCacheKey, AuthInfo>;
pub type CachedAuthInfo = Cached<AuthInfo>;
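Since `AuthInfoCacheKey` derives `Hash` and `Eq` over both fields, entries are scoped per (project, role) pair, which is exactly the granularity the password_changed notification invalidates. A tiny sketch of those lookup semantics, with a plain `HashMap` standing in for `TimedLru`:

use std::collections::HashMap;

#[derive(Debug, Hash, PartialEq, Eq)]
struct AuthInfoCacheKey {
    project: Box<str>,
    role: Box<str>,
}

fn main() {
    // HashMap stands in for TimedLru: one secret per (project, role) pair.
    let mut cache: HashMap<AuthInfoCacheKey, &str> = HashMap::new();

    let key = AuthInfoCacheKey { project: "very-nice".into(), role: "borat".into() };
    cache.insert(key, "scram-secret");

    // An equal (project, role) pair hits the same entry; removing it is how
    // the invalidation handler above evicts stale credentials.
    let probe = AuthInfoCacheKey { project: "very-nice".into(), role: "borat".into() };
    assert_eq!(cache.remove(&probe), Some("scram-secret"));
}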
/// This will allocate on each call, but the http requests alone
/// already require a few allocations, so it should be fine.
@@ -187,7 +184,7 @@ pub trait Api {
        &self,
        extra: &ConsoleReqExtra<'_>,
        creds: &ClientCredentials<'_>,
    ) -> Result<Option<AuthInfo>, errors::GetAuthInfoError>;
    ) -> Result<Option<CachedAuthInfo>, errors::GetAuthInfoError>;

    /// Wake up the compute node and return the corresponding connection info.
    async fn wake_compute(
@@ -199,6 +196,8 @@ pub trait Api {

/// Various caches for [`console`].
pub struct ApiCaches {
    /// Cache for the `wake_compute` API method.
    /// Cache for the `get_auth_info` method.
    pub auth_info: AuthInfoCache,
    /// Cache for the `wake_compute` method.
    pub node_info: NodeInfoCache,
}
@@ -2,13 +2,14 @@

use super::{
    errors::{ApiError, GetAuthInfoError, WakeComputeError},
    AuthInfo, CachedNodeInfo, ConsoleReqExtra, NodeInfo,
    ConsoleReqExtra, PgEndpoint,
};
use crate::{auth::ClientCredentials, compute, error::io_error, scram, url::ApiUrl};
use super::{AuthInfo, NodeInfo};
use super::{CachedAuthInfo, CachedNodeInfo};
use crate::{auth::ClientCredentials, error::io_error, scram, url::ApiUrl};
use async_trait::async_trait;
use futures::TryFutureExt;
use thiserror::Error;
use tokio_postgres::config::SslMode;
use tracing::{error, info, info_span, warn, Instrument};

#[derive(Debug, Error)]
@@ -84,19 +85,15 @@ impl Api {
    }

    async fn do_wake_compute(&self) -> Result<NodeInfo, WakeComputeError> {
        let mut config = compute::ConnCfg::new();
        config
            .host(self.endpoint.host_str().unwrap_or("localhost"))
            .port(self.endpoint.port().unwrap_or(5432))
            .ssl_mode(SslMode::Disable);
        let host = self.endpoint.host_str().unwrap_or("localhost").into();
        let port = self.endpoint.port().unwrap_or(5432);

        let node = NodeInfo {
            config,
        let info = NodeInfo {
            address: PgEndpoint { host, port },
            aux: Default::default(),
            allow_self_signed_compute: false,
        };

        Ok(node)
        Ok(info)
    }
}

@@ -107,8 +104,9 @@ impl super::Api for Api {
        &self,
        _extra: &ConsoleReqExtra<'_>,
        creds: &ClientCredentials<'_>,
    ) -> Result<Option<AuthInfo>, GetAuthInfoError> {
        self.do_get_auth_info(creds).await
    ) -> Result<Option<CachedAuthInfo>, GetAuthInfoError> {
        let res = self.do_get_auth_info(creds).await?;
        Ok(res.map(CachedAuthInfo::new_uncached))
    }

    #[tracing::instrument(skip_all)]
@@ -117,9 +115,8 @@ impl super::Api for Api {
        _extra: &ConsoleReqExtra<'_>,
        _creds: &ClientCredentials<'_>,
    ) -> Result<CachedNodeInfo, WakeComputeError> {
        self.do_wake_compute()
            .map_ok(CachedNodeInfo::new_uncached)
            .await
        let res = self.do_wake_compute().await?;
        Ok(CachedNodeInfo::new_uncached(res))
    }
}
@@ -3,18 +3,19 @@
use super::{
    super::messages::{ConsoleError, GetRoleSecret, WakeCompute},
    errors::{ApiError, GetAuthInfoError, WakeComputeError},
    ApiCaches, AuthInfo, CachedNodeInfo, ConsoleReqExtra, NodeInfo,
    ApiCaches, ConsoleReqExtra,
};
use crate::{auth::ClientCredentials, compute, http, scram};
use super::{AuthInfo, AuthInfoCacheKey, CachedAuthInfo};
use super::{CachedNodeInfo, NodeInfo};
use crate::{auth::ClientCredentials, http, scram};
use async_trait::async_trait;
use futures::TryFutureExt;
use tokio_postgres::config::SslMode;
use tracing::{error, info, info_span, warn, Instrument};

#[derive(Clone)]
pub struct Api {
    endpoint: http::Endpoint,
    caches: &'static ApiCaches,
    pub endpoint: http::Endpoint,
    pub caches: &'static ApiCaches,
}

impl Api {
@@ -91,22 +92,9 @@ impl Api {
        let response = self.endpoint.execute(request).await?;
        let body = parse_body::<WakeCompute>(response).await?;

        // Unfortunately, ownership won't let us use `Option::ok_or` here.
        let (host, port) = match parse_host_port(&body.address) {
            None => return Err(WakeComputeError::BadComputeAddress(body.address)),
            Some(x) => x,
        };

        // Don't set anything but host and port! This config will be cached.
        // We'll set username and such later using the startup message.
        // TODO: add more type safety (in progress).
        let mut config = compute::ConnCfg::new();
        config.host(host).port(port).ssl_mode(SslMode::Disable); // TLS is not configured on compute nodes.

        let node = NodeInfo {
            config,
            aux: body.aux.into(),
            allow_self_signed_compute: false,
        let info = NodeInfo {
            address: body.address,
            aux: body.aux,
        };

        Ok(node)
@@ -124,8 +112,28 @@ impl super::Api for Api {
        &self,
        extra: &ConsoleReqExtra<'_>,
        creds: &ClientCredentials<'_>,
    ) -> Result<Option<AuthInfo>, GetAuthInfoError> {
        self.do_get_auth_info(extra, creds).await
    ) -> Result<Option<CachedAuthInfo>, GetAuthInfoError> {
        let key = AuthInfoCacheKey {
            project: creds.project().expect("impossible").into(),
            role: creds.user.into(),
        };

        // Check if we already have cached auth info for this project + user combo.
        // Beware! We shouldn't flush this for unsuccessful auth attempts, otherwise
        // the cache makes no sense whatsoever in the presence of unfaithful clients.
        // Instead, we snoop an invalidation queue to keep the cache up-to-date.
        if let Some(cached) = self.caches.auth_info.get(&key) {
            info!(key = ?key, "found cached auth info");
            return Ok(Some(cached));
        }

        let info = self.do_get_auth_info(extra, creds).await?;

        Ok(info.map(|info| {
            info!(key = ?key, "creating a cache entry for auth info");
            let (_, cached) = self.caches.auth_info.insert(key.into(), info.into());
            cached
        }))
    }

    #[tracing::instrument(skip_all)]
@@ -134,20 +142,21 @@ impl super::Api for Api {
        extra: &ConsoleReqExtra<'_>,
        creds: &ClientCredentials<'_>,
    ) -> Result<CachedNodeInfo, WakeComputeError> {
        let key = creds.project().expect("impossible");
        let key: Box<str> = creds.project().expect("impossible").into();

        // Every time we do a wakeup http request, the compute node will stay up
        // for some time (highly depends on the console's scale-to-zero policy).
        // The connection info remains the same during that period of time,
        // which means that we might cache it to reduce the load and latency.
        if let Some(cached) = self.caches.node_info.get(key) {
            info!(key = key, "found cached compute node info");
        if let Some(cached) = self.caches.node_info.get(&key) {
            info!(key = ?key, "found cached compute node info");
            return Ok(cached);
        }

        let node = self.do_wake_compute(extra, creds).await?;
        let (_, cached) = self.caches.node_info.insert(key.into(), node);
        info!(key = key, "created a cache entry for compute node info");
        let info = self.do_wake_compute(extra, creds).await?;

        info!(key = ?key, "creating a cache entry for compute node info");
        let (_, cached) = self.caches.node_info.insert(key.into(), info.into());

        Ok(cached)
    }
@@ -177,20 +186,3 @@ async fn parse_body<T: for<'a> serde::Deserialize<'a>>(
    error!("console responded with an error ({status}): {text}");
    Err(ApiError::Console { status, text })
}

fn parse_host_port(input: &str) -> Option<(&str, u16)> {
    let (host, port) = input.split_once(':')?;
    Some((host, port.parse().ok()?))
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_host_port() {
        let (host, port) = parse_host_port("127.0.0.1:5432").expect("failed to parse");
        assert_eq!(host, "127.0.0.1");
        assert_eq!(port, 5432);
    }
}
@@ -175,14 +175,9 @@ pub async fn handle(
        application_name: Some(APP_NAME),
    };
    let node = creds.wake_compute(&extra).await?.expect("msg");
    let conf = node.value.config;
    let port = *conf.get_ports().first().expect("no port");
    let host = match conf.get_hosts().first().expect("no host") {
        tokio_postgres::config::Host::Tcp(host) => host,
        tokio_postgres::config::Host::Unix(_) => {
            return Err(anyhow::anyhow!("unix socket is not supported"));
        }
    };
    let pg_endpoint = &node.value.address;
    let port = pg_endpoint.port;
    let host = &pg_endpoint.host;

    let request_content_length = match request.body().size_hint().upper() {
        Some(v) => v,

@@ -4,7 +4,7 @@ mod tests;
use crate::{
    auth::{self, backend::AuthSuccess},
    cancellation::{self, CancelMap},
    compute::{self, PostgresConnection},
    compute::{self, ComputeNode, PostgresConnection},
    config::{ProxyConfig, TlsConfig},
    console::{self, messages::MetricsAuxInfo},
    error::io_error,
@@ -155,7 +155,7 @@ pub async fn handle_ws_client(
        async { result }.or_else(|e| stream.throw_error(e)).await?
    };

    let client = Client::new(stream, creds, &params, session_id, false);
    let client = Client::new(stream, creds, &params, session_id);
    cancel_map
        .with_session(|session| client.connect_to_db(session, true))
        .await
@@ -194,15 +194,7 @@ async fn handle_client(
        async { result }.or_else(|e| stream.throw_error(e)).await?
    };

    let allow_self_signed_compute = config.allow_self_signed_compute;

    let client = Client::new(
        stream,
        creds,
        &params,
        session_id,
        allow_self_signed_compute,
    );
    let client = Client::new(stream, creds, &params, session_id);
    cancel_map
        .with_session(|session| client.connect_to_db(session, false))
        .await
@@ -283,61 +275,38 @@ async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
    }
}

/// Try to connect to the compute node once.
#[tracing::instrument(name = "connect_once", skip_all)]
async fn connect_to_compute_once(
    node_info: &console::CachedNodeInfo,
) -> Result<PostgresConnection, compute::ConnectionError> {
    // If we couldn't connect, a cached connection info might be to blame
    // (e.g. the compute node's address might've changed at the wrong time).
    // Invalidate the cache entry (if any) to prevent subsequent errors.
    let invalidate_cache = |_: &compute::ConnectionError| {
        let is_cached = node_info.cached();
        if is_cached {
            warn!("invalidating stale compute node info cache entry");
            node_info.invalidate();
        }

        let label = match is_cached {
            true => "compute_cached",
            false => "compute_uncached",
        };
        NUM_CONNECTION_FAILURES.with_label_values(&[label]).inc();
    };

    let allow_self_signed_compute = node_info.allow_self_signed_compute;

    node_info
        .config
        .connect(allow_self_signed_compute)
        .inspect_err(invalidate_cache)
        .await
}

/// Try to connect to the compute node, retrying if necessary.
/// This function might update `node_info`, so we take it by `&mut`.
#[tracing::instrument(skip_all)]
async fn connect_to_compute(
    node_info: &mut console::CachedNodeInfo,
    node_info: &mut ComputeNode,
    params: &StartupMessageParams,
    extra: &console::ConsoleReqExtra<'_>,
    creds: &auth::BackendType<'_, auth::ClientCredentials<'_>>,
) -> Result<PostgresConnection, compute::ConnectionError> {
    let mut num_retries: usize = NUM_RETRIES_WAKE_COMPUTE;
    let mut num_retries = NUM_RETRIES_WAKE_COMPUTE;
    loop {
        // Apply startup params to the (possibly, cached) compute node info.
        node_info.config.set_startup_params(params);
        match connect_to_compute_once(node_info).await {
        let mut config = node_info.to_conn_config();
        config.set_startup_params(params);
        match config.connect().await {
            Err(e) if num_retries > 0 => {
                info!("compute node's state has changed; requesting a wake-up");
                match creds.wake_compute(extra).map_err(io_error).await? {
                let label = match node_info.invalidate() {
                    true => "compute_cached",
                    false => "compute_uncached",
                };
                NUM_CONNECTION_FAILURES.with_label_values(&[label]).inc();

                let res = creds.wake_compute(extra).map_err(io_error).await?;
                match (res, &node_info) {
                    // Update `node_info` and try one more time.
                    Some(mut new) => {
                        new.config.reuse_password(&node_info.config);
                        *node_info = new;
                    (Some(new), ComputeNode::Static { password, .. }) => {
                        *node_info = ComputeNode::Static {
                            password: password.to_owned(),
                            info: new,
                        }
                    }
                    // Link auth doesn't work that way, so we just exit.
                    None => return Err(e),
                    _ => return Err(e),
                }
            }
            other => return other,
@@ -430,8 +399,6 @@ struct Client<'a, S> {
    params: &'a StartupMessageParams,
    /// Unique connection ID.
    session_id: uuid::Uuid,
    /// Allow self-signed certificates (for testing).
    allow_self_signed_compute: bool,
}

impl<'a, S> Client<'a, S> {
@@ -441,14 +408,12 @@ impl<'a, S> Client<'a, S> {
        creds: auth::BackendType<'a, auth::ClientCredentials<'a>>,
        params: &'a StartupMessageParams,
        session_id: uuid::Uuid,
        allow_self_signed_compute: bool,
    ) -> Self {
        Self {
            stream,
            creds,
            params,
            session_id,
            allow_self_signed_compute,
        }
    }
}
@@ -468,7 +433,6 @@ impl<S: AsyncRead + AsyncWrite + Unpin> Client<'_, S> {
            mut creds,
            params,
            session_id,
            allow_self_signed_compute,
        } = self;

        let extra = console::ConsoleReqExtra {
@@ -491,19 +455,19 @@ impl<S: AsyncRead + AsyncWrite + Unpin> Client<'_, S> {
            value: mut node_info,
        } = auth_result;

        node_info.allow_self_signed_compute = allow_self_signed_compute;

        let mut node = connect_to_compute(&mut node_info, params, &extra, &creds)
            .or_else(|e| stream.throw_error(e))
            .await?;

        prepare_client_connection(&node, reported_auth_ok, session, &mut stream).await?;

        // Before proxy passing, forward to compute whatever data is left in the
        // PqStream input buffer. Normally there is none, but our serverless npm
        // driver in pipeline mode sends startup, password and first query
        // immediately after opening the connection.
        let (stream, read_buf) = stream.into_inner();
        node.stream.write_all(&read_buf).await?;
        proxy_pass(stream, node.stream, &node_info.aux).await

        proxy_pass(stream, node.stream, node_info.metrics_aux_info()).await
    }
}
@@ -1845,13 +1845,15 @@ class VanillaPostgres(PgProtocol):
            ]
        )

    def configure(self, options: List[str]):
    def configure(self, options: List[str]) -> "VanillaPostgres":
        """Append lines into postgresql.conf file."""
        assert not self.running
        with open(os.path.join(self.pgdatadir, "postgresql.conf"), "a") as conf_file:
            conf_file.write("\n".join(options))

    def start(self, log_path: Optional[str] = None):
        return self

    def start(self, log_path: Optional[str] = None) -> "VanillaPostgres":
        assert not self.running
        self.running = True

@@ -1862,11 +1864,15 @@ class VanillaPostgres(PgProtocol):
            ["pg_ctl", "-w", "-D", str(self.pgdatadir), "-l", log_path, "start"]
        )

    def stop(self):
        return self

    def stop(self) -> "VanillaPostgres":
        assert self.running
        self.running = False
        self.pg_bin.run_capture(["pg_ctl", "-w", "-D", str(self.pgdatadir), "stop"])

        return self

    def get_subdir_size(self, subdir) -> int:
        """Return size of pgdatadir subdirectory in bytes."""
        return get_dir_size(os.path.join(self.pgdatadir, subdir))
@@ -2035,6 +2041,17 @@ class NeonProxy(PgProtocol):
                *["--auth-endpoint", self.pg_conn_url],
            ]

    @dataclass(frozen=True)
    class Console(AuthBackend):
        console_url: str

        def extra_args(self) -> list[str]:
            return [
                # Postgres auth backend params
                *["--auth-backend", "console"],
                *["--auth-endpoint", self.console_url],
            ]

    def __init__(
        self,
        neon_binpath: Path,
@@ -2240,6 +2257,33 @@ def link_proxy(
        yield proxy


@pytest.fixture(scope="function")
def console_proxy(
    port_distributor: PortDistributor, neon_binpath: Path, test_output_dir: Path
) -> Iterator[NeonProxy]:
    """Neon proxy that routes through a mock console."""

    http_port = port_distributor.get_port()
    proxy_port = port_distributor.get_port()
    mgmt_port = port_distributor.get_port()
    console_port = port_distributor.get_port()
    external_http_port = port_distributor.get_port()

    console_url = f"http://127.0.0.1:{console_port}"

    with NeonProxy(
        neon_binpath=neon_binpath,
        test_output_dir=test_output_dir,
        proxy_port=proxy_port,
        http_port=http_port,
        mgmt_port=mgmt_port,
        external_http_port=external_http_port,
        auth_backend=NeonProxy.Console(console_url),
    ) as proxy:
        proxy.start()
        yield proxy


@pytest.fixture(scope="function")
def static_proxy(
    vanilla_pg: VanillaPostgres,

@@ -1,11 +1,13 @@
import json
import logging
import subprocess
from typing import Any, List
from typing import Any, List, cast

import psycopg2
import pytest
import requests
from fixtures.neon_fixtures import PSQL, NeonProxy, VanillaPostgres
from aiohttp import web
from fixtures.neon_fixtures import PSQL, NeonProxy, PortDistributor, VanillaPostgres


def test_proxy_select_1(static_proxy: NeonProxy):
@@ -225,3 +227,78 @@ def test_sql_over_http(static_proxy: NeonProxy):
    res = q("drop table t")
    assert res["command"] == "DROP"
    assert res["rowCount"] is None


@pytest.mark.asyncio
async def test_compute_cache_invalidation(
    port_distributor: PortDistributor, vanilla_pg: VanillaPostgres, console_proxy: NeonProxy
):
    console_url = cast(NeonProxy.Console, console_proxy.auth_backend).console_url
    logging.info(f"mocked console's url is {console_url}")
    console_port = int(console_url.split(":")[-1])
    logging.info(f"mocked console's port is {console_port}")

    routes = web.RouteTableDef()

    @routes.get("/proxy_get_role_secret")
    async def get_role_secret(request):
        # corresponds to password "password"
        secret = ":".join(
            [
                "SCRAM-SHA-256$4096",
                "t33UQcz/cs1D+n9INqThsw==$1NYlCbuxtK7YF2sgECBDTv1Myf8PpHJCT3RgKSXlZL0=",
                "9iLeGY91MqBQ4ez1389Smo7h+STsJJ5jvu7kNofxj08=",
            ]
        )

        return web.json_response({"role_secret": secret})

    wake_compute_called = 0
    postgres_port = vanilla_pg.default_options["port"]

    @routes.get("/proxy_wake_compute")
    async def wake_compute(request):
        nonlocal wake_compute_called
        wake_compute_called += 1

        nonlocal postgres_port
        logging.info(f"compute's port is {postgres_port}")

        return web.json_response(
            {
                "address": f"127.0.0.1:{postgres_port}",
                "aux": {
                    "endpoint_id": "",
                    "project_id": "",
                    "branch_id": "",
                },
            }
        )

    console = web.Application()
    console.add_routes(routes)

    runner = web.AppRunner(console)
    await runner.setup()
    await web.TCPSite(runner, "127.0.0.1", console_port).start()

    # Create a user we're going to use in the test sequence
    user, password = "borat", "password"
    vanilla_pg.start().safe_psql(f"create role {user} with login password '{password}'")

    async def try_connect():
        await console_proxy.connect_async(user=user, password=password, dbname="postgres")

    assert wake_compute_called == 0

    # Try connecting to compute
    await try_connect()
    assert wake_compute_called == 1

    # Change compute's port
    postgres_port = port_distributor.get_port()
    vanilla_pg.stop().configure([f"port = {postgres_port}"]).start()

    # Try connecting to compute
    await try_connect()
    assert wake_compute_called == 2
@@ -42,7 +42,8 @@ regex = { version = "1" }
regex-syntax = { version = "0.6" }
reqwest = { version = "0.11", default-features = false, features = ["blocking", "json", "multipart", "rustls-tls"] }
ring = { version = "0.16", features = ["std"] }
rustls = { version = "0.20", features = ["dangerous_configuration"] }
rustls-56bd22fc3884b12 = { package = "rustls", version = "0.20", features = ["dangerous_configuration"] }
rustls-647d43efb71741da = { package = "rustls", version = "0.21" }
scopeguard = { version = "1" }
serde = { version = "1", features = ["alloc", "derive"] }
serde_json = { version = "1", features = ["raw_value"] }