Compare commits

...

14 Commits

Author             SHA1        Message                                       Date
Vadim Kharitonov   8f49af45a7  Try to fix test                               2023-05-24 15:39:39 +02:00
Vadim Kharitonov   c7872d3f6d  Make clippy happy                             2023-05-24 14:18:32 +02:00
Vadim Kharitonov   c82b00110e  Fix SQL over HTTP endpoint                    2023-05-24 14:05:03 +02:00
Vadim Kharitonov   d425f2cb14  Fix ruff after rebase                         2023-05-24 13:56:57 +02:00
Dmitry Ivanov      2b0066f67c  Minor edits                                   2023-05-24 12:54:27 +02:00
Dmitry Ivanov      01be823084  Tune console notifications handler            2023-05-24 12:54:27 +02:00
Dmitry Ivanov      69bb1eee18  Update workspace_hack                         2023-05-24 12:54:27 +02:00
Dmitry Ivanov      0c040aca63  Fix mypy warnings                             2023-05-24 12:54:25 +02:00
Dmitry Ivanov      d6c7b4d994  Fix formatting                                2023-05-24 12:53:23 +02:00
Dmitry Ivanov      251e410add  Implement console notifications listener      2023-05-24 12:51:04 +02:00
Dmitry Ivanov      ce416c8160  Implement a cache for get_auth_info method    2023-05-24 12:51:04 +02:00
Dmitry Ivanov      f7e9ec49be  Add stub for auth info cache                  2023-05-24 12:51:04 +02:00
Dmitry Ivanov      2b60ad0285  Test compute node cache invalidation          2023-05-24 12:51:00 +02:00
Dmitry Ivanov      9b99d4caa9  New cache impl                                2023-05-24 12:45:57 +02:00
24 changed files with 1166 additions and 619 deletions

66
Cargo.lock (generated)

@@ -905,6 +905,20 @@ version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "acbf1af155f9b9ef647e42cdc158db4b64a1b61f743629225fde6f3e0be2a7c7" checksum = "acbf1af155f9b9ef647e42cdc158db4b64a1b61f743629225fde6f3e0be2a7c7"
[[package]]
name = "combine"
version = "4.6.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "35ed6e9d84f0b51a7f52daf1c7d71dd136fd7a3f41a8462b8cdb8c78d920fad4"
dependencies = [
"bytes",
"futures-core",
"memchr",
"pin-project-lite",
"tokio",
"tokio-util",
]
[[package]] [[package]]
name = "comfy-table" name = "comfy-table"
version = "6.1.4" version = "6.1.4"
@@ -3104,6 +3118,8 @@ dependencies = [
"prometheus", "prometheus",
"rand", "rand",
"rcgen", "rcgen",
"redis",
"ref-cast",
"regex", "regex",
"reqwest", "reqwest",
"reqwest-middleware", "reqwest-middleware",
@@ -3210,6 +3226,29 @@ dependencies = [
"yasna", "yasna",
] ]
[[package]]
name = "redis"
version = "0.23.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3ea8c51b5dc1d8e5fd3350ec8167f464ec0995e79f2e90a075b63371500d557f"
dependencies = [
"async-trait",
"bytes",
"combine",
"futures-util",
"itoa",
"percent-encoding",
"pin-project-lite",
"rustls 0.21.0",
"rustls-native-certs",
"ryu",
"sha1_smol",
"tokio",
"tokio-rustls 0.24.0",
"tokio-util",
"url",
]
[[package]] [[package]]
name = "redox_syscall" name = "redox_syscall"
version = "0.2.16" version = "0.2.16"
@@ -3228,6 +3267,26 @@ dependencies = [
"bitflags", "bitflags",
] ]
[[package]]
name = "ref-cast"
version = "1.0.16"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f43faa91b1c8b36841ee70e97188a869d37ae21759da6846d4be66de5bf7b12c"
dependencies = [
"ref-cast-impl",
]
[[package]]
name = "ref-cast-impl"
version = "1.0.16"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8d2275aab483050ab2a7364c1a46604865ee7d6906684e08db0f090acf74f9e7"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.15",
]
[[package]] [[package]]
name = "regex" name = "regex"
version = "1.7.3" version = "1.7.3"
@@ -3844,6 +3903,12 @@ dependencies = [
"digest", "digest",
] ]
[[package]]
name = "sha1_smol"
version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ae1a47186c03a32177042e55dbc5fd5aee900b8e0069a8d70fba96a9375cd012"
[[package]] [[package]]
name = "sha2" name = "sha2"
version = "0.10.6" version = "0.10.6"
@@ -5326,6 +5391,7 @@ dependencies = [
"reqwest", "reqwest",
"ring", "ring",
"rustls 0.20.8", "rustls 0.20.8",
"rustls 0.21.0",
"scopeguard", "scopeguard",
"serde", "serde",
"serde_json", "serde_json",


@@ -77,10 +77,12 @@ pin-project-lite = "0.2"
prometheus = {version = "0.13", default_features=false, features = ["process"]} # removes protobuf dependency prometheus = {version = "0.13", default_features=false, features = ["process"]} # removes protobuf dependency
prost = "0.11" prost = "0.11"
rand = "0.8" rand = "0.8"
ref-cast = "1.0"
redis = { version = "0.23.0", features = ["tokio-rustls-comp"] }
regex = "1.4" regex = "1.4"
reqwest = { version = "0.11", default-features = false, features = ["rustls-tls"] } reqwest = { version = "0.11", default-features = false, features = ["rustls-tls"] }
reqwest-tracing = { version = "0.4.0", features = ["opentelemetry_0_18"] }
reqwest-middleware = "0.2.0" reqwest-middleware = "0.2.0"
reqwest-tracing = { version = "0.4.0", features = ["opentelemetry_0_18"] }
routerify = "3" routerify = "3"
rpds = "0.13" rpds = "0.13"
rustls = "0.20" rustls = "0.20"


@@ -35,7 +35,9 @@ postgres_backend.workspace = true
pq_proto.workspace = true pq_proto.workspace = true
prometheus.workspace = true prometheus.workspace = true
rand.workspace = true rand.workspace = true
ref-cast.workspace = true
regex.workspace = true regex.workspace = true
redis.workspace = true
reqwest = { workspace = true, features = ["json"] } reqwest = { workspace = true, features = ["json"] }
reqwest-middleware.workspace = true reqwest-middleware.workspace = true
reqwest-tracing.workspace = true reqwest-tracing.workspace = true
@@ -50,9 +52,9 @@ socket2.workspace = true
sync_wrapper.workspace = true sync_wrapper.workspace = true
thiserror.workspace = true thiserror.workspace = true
tls-listener.workspace = true tls-listener.workspace = true
tokio = { workspace = true, features = ["signal"] }
tokio-postgres.workspace = true tokio-postgres.workspace = true
tokio-rustls.workspace = true tokio-rustls.workspace = true
tokio = { workspace = true, features = ["signal"] }
tracing-opentelemetry.workspace = true tracing-opentelemetry.workspace = true
tracing-subscriber.workspace = true tracing-subscriber.workspace = true
tracing-utils.workspace = true tracing-utils.workspace = true


@@ -6,6 +6,7 @@ pub use link::LinkAuthError;
use crate::{ use crate::{
auth::{self, ClientCredentials}, auth::{self, ClientCredentials},
compute::ComputeNode,
console::{ console::{
self, self,
provider::{CachedNodeInfo, ConsoleReqExtra}, provider::{CachedNodeInfo, ConsoleReqExtra},
@@ -114,7 +115,7 @@ async fn auth_quirks(
creds: &mut ClientCredentials<'_>, creds: &mut ClientCredentials<'_>,
client: &mut stream::PqStream<impl AsyncRead + AsyncWrite + Unpin>, client: &mut stream::PqStream<impl AsyncRead + AsyncWrite + Unpin>,
allow_cleartext: bool, allow_cleartext: bool,
) -> auth::Result<AuthSuccess<CachedNodeInfo>> { ) -> auth::Result<AuthSuccess<ComputeNode>> {
// If there's no project so far, that entails that client doesn't // If there's no project so far, that entails that client doesn't
// support SNI or other means of passing the endpoint (project) name. // support SNI or other means of passing the endpoint (project) name.
// We now expect to see a very specific payload in the place of password. // We now expect to see a very specific payload in the place of password.
@@ -156,7 +157,7 @@ impl BackendType<'_, ClientCredentials<'_>> {
extra: &ConsoleReqExtra<'_>, extra: &ConsoleReqExtra<'_>,
client: &mut stream::PqStream<impl AsyncRead + AsyncWrite + Unpin>, client: &mut stream::PqStream<impl AsyncRead + AsyncWrite + Unpin>,
allow_cleartext: bool, allow_cleartext: bool,
) -> auth::Result<AuthSuccess<CachedNodeInfo>> { ) -> auth::Result<AuthSuccess<ComputeNode>> {
use BackendType::*; use BackendType::*;
let res = match self { let res = match self {
@@ -184,9 +185,7 @@ impl BackendType<'_, ClientCredentials<'_>> {
Link(url) => { Link(url) => {
info!("performing link authentication"); info!("performing link authentication");
link::authenticate(url, client) link::authenticate(url, client).await?
.await?
.map(CachedNodeInfo::new_uncached)
} }
}; };


@@ -1,8 +1,8 @@
use super::AuthSuccess; use super::AuthSuccess;
use crate::{ use crate::{
auth::{self, AuthFlow, ClientCredentials}, auth::{self, AuthFlow, ClientCredentials},
compute, compute::{self, ComputeNode, Password},
console::{self, AuthInfo, CachedNodeInfo, ConsoleReqExtra}, console::{self, AuthInfo, CachedAuthInfo, ConsoleReqExtra},
sasl, scram, sasl, scram,
stream::PqStream, stream::PqStream,
}; };
@@ -14,25 +14,26 @@ pub(super) async fn authenticate(
extra: &ConsoleReqExtra<'_>, extra: &ConsoleReqExtra<'_>,
creds: &ClientCredentials<'_>, creds: &ClientCredentials<'_>,
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>, client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
) -> auth::Result<AuthSuccess<CachedNodeInfo>> { ) -> auth::Result<AuthSuccess<ComputeNode>> {
info!("fetching user's authentication info"); info!("fetching user's authentication info");
let info = api.get_auth_info(extra, creds).await?.unwrap_or_else(|| { let info = api.get_auth_info(extra, creds).await?.unwrap_or_else(|| {
// If we don't have an authentication secret, we mock one to // If we don't have an authentication secret, we mock one to
// prevent malicious probing (possible due to missing protocol steps). // prevent malicious probing (possible due to missing protocol steps).
// This mocked secret will never lead to successful authentication. // This mocked secret will never lead to successful authentication.
info!("authentication info not found, mocking it"); info!("authentication info not found, mocking it");
AuthInfo::Scram(scram::ServerSecret::mock(creds.user, rand::random())) let info = scram::ServerSecret::mock(creds.user, rand::random());
CachedAuthInfo::new_uncached(AuthInfo::Scram(info))
}); });
let flow = AuthFlow::new(client); let flow = AuthFlow::new(client);
let scram_keys = match info { let keys = match &*info {
AuthInfo::Md5(_) => { AuthInfo::Md5(_) => {
info!("auth endpoint chooses MD5"); info!("auth endpoint chooses MD5");
return Err(auth::AuthError::bad_auth_method("MD5")); return Err(auth::AuthError::bad_auth_method("MD5"));
} }
AuthInfo::Scram(secret) => { AuthInfo::Scram(secret) => {
info!("auth endpoint chooses SCRAM"); info!("auth endpoint chooses SCRAM");
let scram = auth::Scram(&secret); let scram = auth::Scram(secret);
let client_key = match flow.begin(scram).await?.authenticate().await? { let client_key = match flow.begin(scram).await?.authenticate().await? {
sasl::Outcome::Success(key) => key, sasl::Outcome::Success(key) => key,
sasl::Outcome::Failure(reason) => { sasl::Outcome::Failure(reason) => {
@@ -41,21 +42,20 @@ pub(super) async fn authenticate(
} }
}; };
Some(compute::ScramKeys { compute::ScramKeys {
client_key: client_key.as_bytes(), client_key: client_key.as_bytes(),
server_key: secret.server_key.as_bytes(), server_key: secret.server_key.as_bytes(),
}) }
} }
}; };
let mut node = api.wake_compute(extra, creds).await?; let info = api.wake_compute(extra, creds).await?;
if let Some(keys) = scram_keys {
use tokio_postgres::config::AuthKeys;
node.config.auth_keys(AuthKeys::ScramSha256(keys));
}
Ok(AuthSuccess { Ok(AuthSuccess {
reported_auth_ok: false, reported_auth_ok: false,
value: node, value: ComputeNode::Static {
password: Password::ScramKeys(keys),
info,
},
}) })
} }


@@ -1,10 +1,8 @@
use super::AuthSuccess; use super::AuthSuccess;
use crate::{ use crate::{
auth::{self, AuthFlow, ClientCredentials}, auth::{self, AuthFlow, ClientCredentials},
console::{ compute::{ComputeNode, Password},
self, console::{self, provider::ConsoleReqExtra},
provider::{CachedNodeInfo, ConsoleReqExtra},
},
stream, stream,
}; };
use tokio::io::{AsyncRead, AsyncWrite}; use tokio::io::{AsyncRead, AsyncWrite};
@@ -19,7 +17,7 @@ pub async fn cleartext_hack(
extra: &ConsoleReqExtra<'_>, extra: &ConsoleReqExtra<'_>,
creds: &mut ClientCredentials<'_>, creds: &mut ClientCredentials<'_>,
client: &mut stream::PqStream<impl AsyncRead + AsyncWrite + Unpin>, client: &mut stream::PqStream<impl AsyncRead + AsyncWrite + Unpin>,
) -> auth::Result<AuthSuccess<CachedNodeInfo>> { ) -> auth::Result<AuthSuccess<ComputeNode>> {
warn!("cleartext auth flow override is enabled, proceeding"); warn!("cleartext auth flow override is enabled, proceeding");
let password = AuthFlow::new(client) let password = AuthFlow::new(client)
.begin(auth::CleartextPassword) .begin(auth::CleartextPassword)
@@ -27,13 +25,15 @@ pub async fn cleartext_hack(
.authenticate() .authenticate()
.await?; .await?;
let mut node = api.wake_compute(extra, creds).await?; let info = api.wake_compute(extra, creds).await?;
node.config.password(password);
// Report tentative success; compute node will check the password anyway. // Report tentative success; compute node will check the password anyway.
Ok(AuthSuccess { Ok(AuthSuccess {
reported_auth_ok: false, reported_auth_ok: false,
value: node, value: ComputeNode::Static {
password: Password::ClearText(password),
info,
},
}) })
} }
@@ -44,7 +44,7 @@ pub async fn password_hack(
extra: &ConsoleReqExtra<'_>, extra: &ConsoleReqExtra<'_>,
creds: &mut ClientCredentials<'_>, creds: &mut ClientCredentials<'_>,
client: &mut stream::PqStream<impl AsyncRead + AsyncWrite + Unpin>, client: &mut stream::PqStream<impl AsyncRead + AsyncWrite + Unpin>,
) -> auth::Result<AuthSuccess<CachedNodeInfo>> { ) -> auth::Result<AuthSuccess<ComputeNode>> {
warn!("project not specified, resorting to the password hack auth flow"); warn!("project not specified, resorting to the password hack auth flow");
let payload = AuthFlow::new(client) let payload = AuthFlow::new(client)
.begin(auth::PasswordHack) .begin(auth::PasswordHack)
@@ -55,12 +55,14 @@ pub async fn password_hack(
info!(project = &payload.endpoint, "received missing parameter"); info!(project = &payload.endpoint, "received missing parameter");
creds.project = Some(payload.endpoint); creds.project = Some(payload.endpoint);
let mut node = api.wake_compute(extra, creds).await?; let info = api.wake_compute(extra, creds).await?;
node.config.password(payload.password);
// Report tentative success; compute node will check the password anyway. // Report tentative success; compute node will check the password anyway.
Ok(AuthSuccess { Ok(AuthSuccess {
reported_auth_ok: false, reported_auth_ok: false,
value: node, value: ComputeNode::Static {
password: Password::ClearText(payload.password),
info,
},
}) })
} }


@@ -1,15 +1,10 @@
use super::AuthSuccess; use super::AuthSuccess;
use crate::{ use crate::{
auth, compute, auth, compute::ComputeNode, console, error::UserFacingError, stream::PqStream, waiters,
console::{self, provider::NodeInfo},
error::UserFacingError,
stream::PqStream,
waiters,
}; };
use pq_proto::BeMessage as Be; use pq_proto::BeMessage as Be;
use thiserror::Error; use thiserror::Error;
use tokio::io::{AsyncRead, AsyncWrite}; use tokio::io::{AsyncRead, AsyncWrite};
use tokio_postgres::config::SslMode;
use tracing::{info, info_span}; use tracing::{info, info_span};
#[derive(Debug, Error)] #[derive(Debug, Error)]
@@ -57,12 +52,12 @@ pub fn new_psql_session_id() -> String {
pub(super) async fn authenticate( pub(super) async fn authenticate(
link_uri: &reqwest::Url, link_uri: &reqwest::Url,
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>, client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
) -> auth::Result<AuthSuccess<NodeInfo>> { ) -> auth::Result<AuthSuccess<ComputeNode>> {
let psql_session_id = new_psql_session_id(); let psql_session_id = new_psql_session_id();
let span = info_span!("link", psql_session_id = &psql_session_id); let span = info_span!("link", psql_session_id);
let greeting = hello_message(link_uri, &psql_session_id); let greeting = hello_message(link_uri, &psql_session_id);
let db_info = console::mgmt::with_waiter(psql_session_id, |waiter| async { let info = console::mgmt::with_waiter(psql_session_id, |waiter| async {
// Give user a URL to spawn a new database. // Give user a URL to spawn a new database.
info!(parent: &span, "sending the auth URL to the user"); info!(parent: &span, "sending the auth URL to the user");
client client
@@ -79,35 +74,8 @@ pub(super) async fn authenticate(
client.write_message_noflush(&Be::NoticeResponse("Connecting to database."))?; client.write_message_noflush(&Be::NoticeResponse("Connecting to database."))?;
// This config should be self-contained, because we won't
// take username or dbname from client's startup message.
let mut config = compute::ConnCfg::new();
config
.host(&db_info.host)
.port(db_info.port)
.dbname(&db_info.dbname)
.user(&db_info.user);
// Backwards compatibility. pg_sni_proxy uses "--" in domain names
// while direct connections do not. Once we migrate to pg_sni_proxy
// everywhere, we can remove this.
if db_info.host.contains("--") {
// we need TLS connection with SNI info to properly route it
config.ssl_mode(SslMode::Require);
} else {
config.ssl_mode(SslMode::Disable);
}
if let Some(password) = db_info.password {
config.password(password.as_ref());
}
Ok(AuthSuccess { Ok(AuthSuccess {
reported_auth_ok: true, reported_auth_ok: true,
value: NodeInfo { value: ComputeNode::Link(info),
config,
aux: db_info.aux.into(),
allow_self_signed_compute: false, // caller may override
},
}) })
} }


@@ -1,16 +1,15 @@
use proxy::auth; use anyhow::{bail, Context};
use proxy::console;
use proxy::http;
use proxy::metrics;
use anyhow::bail;
use clap::{self, Arg}; use clap::{self, Arg};
use proxy::config::{self, ProxyConfig}; use futures::future::try_join_all;
use std::{borrow::Cow, net::SocketAddr}; use proxy::{
auth,
config::{self, MetricCollectionConfig, ProxyConfig, TlsConfig},
console, http, metrics,
};
use std::{borrow::Cow, net::SocketAddr, sync::atomic::Ordering};
use tokio::net::TcpListener; use tokio::net::TcpListener;
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
use tracing::info; use tracing::{info, warn};
use tracing::warn;
use utils::{project_git_version, sentry_init::init_sentry}; use utils::{project_git_version, sentry_init::init_sentry};
project_git_version!(GIT_VERSION); project_git_version!(GIT_VERSION);
@@ -25,8 +24,7 @@ async fn main() -> anyhow::Result<()> {
::metrics::set_build_info_metric(GIT_VERSION); ::metrics::set_build_info_metric(GIT_VERSION);
let args = cli().get_matches(); let args = cli().get_matches();
let config = build_config(&args)?; let config: &ProxyConfig = Box::leak(Box::new(build_config(&args)?));
info!("Authentication backend: {}", config.auth_backend); info!("Authentication backend: {}", config.auth_backend);
// Check that we can bind to address before further initialization // Check that we can bind to address before further initialization
@@ -71,20 +69,29 @@ async fn main() -> anyhow::Result<()> {
tasks.push(tokio::spawn(metrics::task_main(metrics_config))); tasks.push(tokio::spawn(metrics::task_main(metrics_config)));
} }
let tasks = futures::future::try_join_all(tasks.into_iter().map(proxy::flatten_err)); if let auth::BackendType::Console(api, _) = &config.auth_backend {
let client_tasks = if let Some(url) = args.get_one::<String>("redis-notifications") {
futures::future::try_join_all(client_tasks.into_iter().map(proxy::flatten_err)); info!("Starting redis notifications listener ({url})");
tasks.push(tokio::spawn(console::notifications::task_main(
url.to_owned(),
api.caches,
)));
}
}
let tasks = try_join_all(tasks.into_iter().map(proxy::flatten_err));
let client_tasks = try_join_all(client_tasks.into_iter().map(proxy::flatten_err));
tokio::select! { tokio::select! {
// We are only expecting an error from these forever tasks // We are only expecting an error from these forever tasks
res = tasks => { res?; }, res = tasks => { res?; },
res = client_tasks => { res?; }, res = client_tasks => { res?; },
} }
Ok(()) Ok(())
} }
/// ProxyConfig is created at proxy startup, and lives forever. fn build_tls_config(args: &clap::ArgMatches) -> anyhow::Result<Option<TlsConfig>> {
fn build_config(args: &clap::ArgMatches) -> anyhow::Result<&'static ProxyConfig> { let config = match (
let tls_config = match (
args.get_one::<String>("tls-key"), args.get_one::<String>("tls-key"),
args.get_one::<String>("tls-cert"), args.get_one::<String>("tls-cert"),
) { ) {
@@ -101,16 +108,22 @@ fn build_config(args: &clap::ArgMatches) -> anyhow::Result<&'static ProxyConfig>
.get_one::<String>("allow-self-signed-compute") .get_one::<String>("allow-self-signed-compute")
.unwrap() .unwrap()
.parse()?; .parse()?;
if allow_self_signed_compute { if allow_self_signed_compute {
warn!("allowing self-signed compute certificates"); warn!("allowing self-signed compute certificates");
proxy::compute::ALLOW_SELF_SIGNED_COMPUTE.store(true, Ordering::Relaxed);
} }
let metric_collection = match ( Ok(config)
args.get_one::<String>("metric-collection-endpoint"), }
args.get_one::<String>("metric-collection-interval"),
) { fn build_metrics_config(args: &clap::ArgMatches) -> anyhow::Result<Option<MetricCollectionConfig>> {
let endpoint = args.get_one::<String>("metric-collection-endpoint");
let interval = args.get_one::<String>("metric-collection-interval");
let config = match (endpoint, interval) {
(Some(endpoint), Some(interval)) => Some(config::MetricCollectionConfig { (Some(endpoint), Some(interval)) => Some(config::MetricCollectionConfig {
endpoint: endpoint.parse()?, endpoint: endpoint.parse().context("bad metrics endpoint")?,
interval: humantime::parse_duration(interval)?, interval: humantime::parse_duration(interval)?,
}), }),
(None, None) => None, (None, None) => None,
@@ -120,21 +133,40 @@ fn build_config(args: &clap::ArgMatches) -> anyhow::Result<&'static ProxyConfig>
), ),
}; };
let auth_backend = match args.get_one::<String>("auth-backend").unwrap().as_str() { Ok(config)
}
fn make_caches(args: &clap::ArgMatches) -> anyhow::Result<console::caches::ApiCaches> {
let config::CacheOptions { size, ttl } = args
.get_one::<String>("get-auth-info-cache")
.unwrap()
.parse()?;
info!("Using AuthInfoCache (get_auth_info) with size={size} ttl={ttl:?}");
let auth_info = console::caches::AuthInfoCache::new("auth_info_cache", size, ttl);
let config::CacheOptions { size, ttl } = args
.get_one::<String>("wake-compute-cache")
.unwrap()
.parse()?;
info!("Using NodeInfoCache (wake_compute) with size={size} ttl={ttl:?}");
let node_info = console::caches::NodeInfoCache::new("node_info_cache", size, ttl);
let caches = console::caches::ApiCaches {
auth_info,
node_info,
};
Ok(caches)
}
fn build_auth_config(args: &clap::ArgMatches) -> anyhow::Result<auth::BackendType<'static, ()>> {
let config = match args.get_one::<String>("auth-backend").unwrap().as_str() {
"console" => { "console" => {
let config::CacheOptions { size, ttl } = args
.get_one::<String>("wake-compute-cache")
.unwrap()
.parse()?;
info!("Using NodeInfoCache (wake_compute) with size={size} ttl={ttl:?}");
let caches = Box::leak(Box::new(console::caches::ApiCaches {
node_info: console::caches::NodeInfoCache::new("node_info_cache", size, ttl),
}));
let url = args.get_one::<String>("auth-endpoint").unwrap().parse()?; let url = args.get_one::<String>("auth-endpoint").unwrap().parse()?;
let endpoint = http::Endpoint::new(url, http::new_client()); let endpoint = http::Endpoint::new(url, http::new_client());
let caches = Box::leak(Box::new(make_caches(args)?));
let api = console::provider::neon::Api::new(endpoint, caches); let api = console::provider::neon::Api::new(endpoint, caches);
auth::BackendType::Console(Cow::Owned(api), ()) auth::BackendType::Console(Cow::Owned(api), ())
} }
@@ -150,12 +182,16 @@ fn build_config(args: &clap::ArgMatches) -> anyhow::Result<&'static ProxyConfig>
other => bail!("unsupported auth backend: {other}"), other => bail!("unsupported auth backend: {other}"),
}; };
let config = Box::leak(Box::new(ProxyConfig { Ok(config)
tls_config, }
auth_backend,
metric_collection, /// ProxyConfig is created at proxy startup, and lives forever.
allow_self_signed_compute, fn build_config(args: &clap::ArgMatches) -> anyhow::Result<ProxyConfig> {
})); let config = ProxyConfig {
tls_config: build_tls_config(args)?,
auth_backend: build_auth_config(args)?,
metric_collection: build_metrics_config(args)?,
};
Ok(config) Ok(config)
} }
@@ -239,11 +275,22 @@ fn cli() -> clap::Command {
.long("metric-collection-interval") .long("metric-collection-interval")
.help("how often metrics should be sent to a collection endpoint"), .help("how often metrics should be sent to a collection endpoint"),
) )
.arg(
Arg::new("redis-notifications")
.long("redis-notifications")
.help("for receiving notifications from console (e.g. redis://127.0.0.1:6379)"),
)
.arg(
Arg::new("get-auth-info-cache")
.long("get-auth-info-cache")
.help("cache for `get_auth_info` api method (use `size=0` to disable)")
.default_value(config::CacheOptions::DEFAULT_AUTH_INFO),
)
.arg( .arg(
Arg::new("wake-compute-cache") Arg::new("wake-compute-cache")
.long("wake-compute-cache") .long("wake-compute-cache")
.help("cache for `wake_compute` api method (use `size=0` to disable)") .help("cache for `wake_compute` api method (use `size=0` to disable)")
.default_value(config::CacheOptions::DEFAULT_OPTIONS_NODE_INFO), .default_value(config::CacheOptions::DEFAULT_NODE_INFO),
) )
.arg( .arg(
Arg::new("allow-self-signed-compute") Arg::new("allow-self-signed-compute")


@@ -1,304 +1,117 @@
use std::{ use std::{any::Any, sync::Arc, time::Instant};
borrow::Borrow,
hash::Hash,
ops::{Deref, DerefMut},
time::{Duration, Instant},
};
use tracing::debug;
// This seems to make more sense than `lru` or `cached`: /// A variant of LRU where every entry has a TTL.
// pub mod timed_lru;
// * `near/nearcore` ditched `cached` in favor of `lru` pub use timed_lru::TimedLru;
// (https://github.com/near/nearcore/issues?q=is%3Aissue+lru+is%3Aclosed).
//
// * `lru` methods use an obscure `KeyRef` type in their contraints (which is deliberately excluded from docs).
// This severely hinders its usage both in terms of creating wrappers and supported key types.
//
// On the other hand, `hashlink` has good download stats and appears to be maintained.
use hashlink::{linked_hash_map::RawEntryMut, LruCache};
/// A generic trait which exposes types of cache's key and value, /// Useful type aliases.
/// as well as the notion of cache entry invalidation. pub mod types {
/// This is useful for [`timed_lru::Cached`]. pub type Cached<T> = super::Cached<'static, T>;
pub trait Cache { }
/// Entry's key.
type Key;
/// Entry's value. /// Lookup information for cache entry invalidation.
type Value; #[derive(Clone)]
pub struct LookupInfo<K> {
/// Cache entry creation time.
/// We use this during invalidation lookups to prevent eviction of a newer
/// entry sharing the same key (it might've been inserted by a different
/// task after we got the entry we're trying to invalidate now).
created_at: Instant,
/// Used for entry invalidation. /// Search by this key.
type LookupInfo<Key>; key: K,
}
/// This type encapsulates everything needed for cache entry invalidation.
/// For convenience, we completely erase the types of a cache ref and a key.
/// This lets us store multiple tokens in a homogeneous collection (e.g. Vec).
#[derive(Clone)]
pub struct InvalidationToken<'a> {
// TODO: allow more than one type of references (e.g. Arc) if it's ever needed.
cache: &'a (dyn Cache + Sync + Send),
info: LookupInfo<Arc<dyn Any + Sync + Send>>,
}
impl InvalidationToken<'_> {
/// Invalidate a corresponding cache entry.
pub fn invalidate(&self) {
let info = LookupInfo {
created_at: self.info.created_at,
key: self.info.key.as_ref(),
};
self.cache.invalidate_entry(info);
}
}
/// A combination of a cache entry and its invalidation token.
/// Makes it easier to see how those two are connected.
#[derive(Clone)]
pub struct Cached<'a, T> {
pub token: Option<InvalidationToken<'a>>,
pub value: Arc<T>,
}
impl<T> Cached<'_, T> {
/// Place any entry into this wrapper; invalidation will be a no-op.
pub fn new_uncached(value: T) -> Self {
Self {
token: None,
value: value.into(),
}
}
/// Invalidate a corresponding cache entry.
pub fn invalidate(&self) {
if let Some(token) = &self.token {
token.invalidate();
}
}
}
impl<T> std::ops::Deref for Cached<'_, T> {
type Target = T;
fn deref(&self) -> &Self::Target {
&self.value
}
}
/// This trait captures the notion of cache entry invalidation.
/// It doesn't have any associated types because we use dyn-based type erasure.
trait Cache {
/// Invalidate an entry using a lookup info. /// Invalidate an entry using a lookup info.
/// We don't have an empty default impl because it's error-prone. /// We don't have an empty default impl because it's error-prone.
fn invalidate(&self, _: &Self::LookupInfo<Self::Key>); fn invalidate_entry(&self, info: LookupInfo<&(dyn Any + Send + Sync)>);
} }
impl<C: Cache> Cache for &C { #[cfg(test)]
type Key = C::Key; mod tests {
type Value = C::Value;
type LookupInfo<Key> = C::LookupInfo<Key>;
fn invalidate(&self, info: &Self::LookupInfo<Self::Key>) {
C::invalidate(self, info)
}
}
pub use timed_lru::TimedLru;
pub mod timed_lru {
use super::*; use super::*;
/// An implementation of timed LRU cache with fixed capacity. #[test]
/// Key properties: fn trivial_properties_of_cached() {
/// let cached = Cached::new_uncached(0);
/// * Whenever a new entry is inserted, the least recently accessed one is evicted. assert_eq!(*cached, 0);
/// The cache also keeps track of entry's insertion time (`created_at`) and TTL (`expires_at`). cached.invalidate();
///
/// * When the entry is about to be retrieved, we check its expiration timestamp.
/// If the entry has expired, we remove it from the cache; Otherwise we bump the
/// expiration timestamp (e.g. +5mins) and change its place in LRU list to prolong
/// its existence.
///
/// * There's an API for immediate invalidation (removal) of a cache entry;
/// It's useful in case we know for sure that the entry is no longer correct.
/// See [`timed_lru::LookupInfo`] & [`timed_lru::Cached`] for more information.
///
/// * Expired entries are kept in the cache, until they are evicted by the LRU policy,
/// or by a successful lookup (i.e. the entry hasn't expired yet).
/// There is no background job to reap the expired records.
///
/// * It's possible for an entry that has not yet expired entry to be evicted
/// before expired items. That's a bit wasteful, but probably fine in practice.
pub struct TimedLru<K, V> {
/// Cache's name for tracing.
name: &'static str,
/// The underlying cache implementation.
cache: parking_lot::Mutex<LruCache<K, Entry<V>>>,
/// Default time-to-live of a single entry.
ttl: Duration,
} }
impl<K: Hash + Eq, V> Cache for TimedLru<K, V> { #[test]
type Key = K; fn invalidation_token_type_erasure() {
type Value = V; let lifetime = std::time::Duration::from_secs(10);
type LookupInfo<Key> = LookupInfo<Key>; let foo = TimedLru::<u32, u32>::new("foo", 128, lifetime);
let bar = TimedLru::<String, usize>::new("bar", 128, lifetime);
fn invalidate(&self, info: &Self::LookupInfo<K>) { let (_, x) = foo.insert(100.into(), 0.into());
self.invalidate_raw(info) let (_, y) = bar.insert(String::new().into(), 404.into());
}
}
struct Entry<T> { // Invalidation tokens should be cloneable and homogeneous (same type).
created_at: Instant, let tokens = [x.token.clone().unwrap(), y.token.clone().unwrap()];
expires_at: Instant, for token in tokens {
value: T, token.invalidate();
}
impl<K: Hash + Eq, V> TimedLru<K, V> {
/// Construct a new LRU cache with timed entries.
pub fn new(name: &'static str, capacity: usize, ttl: Duration) -> Self {
Self {
name,
cache: LruCache::new(capacity).into(),
ttl,
}
} }
/// Drop an entry from the cache if it's outdated. // Values are still there.
#[tracing::instrument(level = "debug", fields(cache = self.name), skip_all)] assert_eq!(*x, 0);
fn invalidate_raw(&self, info: &LookupInfo<K>) { assert_eq!(*y, 404);
let now = Instant::now();
// Do costly things before taking the lock.
let mut cache = self.cache.lock();
let raw_entry = match cache.raw_entry_mut().from_key(&info.key) {
RawEntryMut::Vacant(_) => return,
RawEntryMut::Occupied(x) => x,
};
// Remove the entry if it was created prior to lookup timestamp.
let entry = raw_entry.get();
let (created_at, expires_at) = (entry.created_at, entry.expires_at);
let should_remove = created_at <= info.created_at || expires_at <= now;
if should_remove {
raw_entry.remove();
}
drop(cache); // drop lock before logging
debug!(
created_at = format_args!("{created_at:?}"),
expires_at = format_args!("{expires_at:?}"),
entry_removed = should_remove,
"processed a cache entry invalidation event"
);
}
/// Try retrieving an entry by its key, then execute `extract` if it exists.
#[tracing::instrument(level = "debug", fields(cache = self.name), skip_all)]
fn get_raw<Q, R>(&self, key: &Q, extract: impl FnOnce(&K, &Entry<V>) -> R) -> Option<R>
where
K: Borrow<Q>,
Q: Hash + Eq + ?Sized,
{
let now = Instant::now();
let deadline = now.checked_add(self.ttl).expect("time overflow");
// Do costly things before taking the lock.
let mut cache = self.cache.lock();
let mut raw_entry = match cache.raw_entry_mut().from_key(key) {
RawEntryMut::Vacant(_) => return None,
RawEntryMut::Occupied(x) => x,
};
// Immeditely drop the entry if it has expired.
let entry = raw_entry.get();
if entry.expires_at <= now {
raw_entry.remove();
return None;
}
let value = extract(raw_entry.key(), entry);
let (created_at, expires_at) = (entry.created_at, entry.expires_at);
// Update the deadline and the entry's position in the LRU list.
raw_entry.get_mut().expires_at = deadline;
raw_entry.to_back();
drop(cache); // drop lock before logging
debug!(
created_at = format_args!("{created_at:?}"),
old_expires_at = format_args!("{expires_at:?}"),
new_expires_at = format_args!("{deadline:?}"),
"accessed a cache entry"
);
Some(value)
}
/// Insert an entry to the cache. If an entry with the same key already
/// existed, return the previous value and its creation timestamp.
#[tracing::instrument(level = "debug", fields(cache = self.name), skip_all)]
fn insert_raw(&self, key: K, value: V) -> (Instant, Option<V>) {
let created_at = Instant::now();
let expires_at = created_at.checked_add(self.ttl).expect("time overflow");
let entry = Entry {
created_at,
expires_at,
value,
};
// Do costly things before taking the lock.
let old = self
.cache
.lock()
.insert(key, entry)
.map(|entry| entry.value);
debug!(
created_at = format_args!("{created_at:?}"),
expires_at = format_args!("{expires_at:?}"),
replaced = old.is_some(),
"created a cache entry"
);
(created_at, old)
}
}
impl<K: Hash + Eq + Clone, V: Clone> TimedLru<K, V> {
pub fn insert(&self, key: K, value: V) -> (Option<V>, Cached<&Self>) {
let (created_at, old) = self.insert_raw(key.clone(), value.clone());
let cached = Cached {
token: Some((self, LookupInfo { created_at, key })),
value,
};
(old, cached)
}
}
impl<K: Hash + Eq, V: Clone> TimedLru<K, V> {
/// Retrieve a cached entry in convenient wrapper.
pub fn get<Q>(&self, key: &Q) -> Option<timed_lru::Cached<&Self>>
where
K: Borrow<Q> + Clone,
Q: Hash + Eq + ?Sized,
{
self.get_raw(key, |key, entry| {
let info = LookupInfo {
created_at: entry.created_at,
key: key.clone(),
};
Cached {
token: Some((self, info)),
value: entry.value.clone(),
}
})
}
}
/// Lookup information for key invalidation.
pub struct LookupInfo<K> {
/// Time of creation of a cache [`Entry`].
/// We use this during invalidation lookups to prevent eviction of a newer
/// entry sharing the same key (it might've been inserted by a different
/// task after we got the entry we're trying to invalidate now).
created_at: Instant,
/// Search by this key.
key: K,
}
/// Wrapper for convenient entry invalidation.
pub struct Cached<C: Cache> {
/// Cache + lookup info.
token: Option<(C, C::LookupInfo<C::Key>)>,
/// The value itself.
pub value: C::Value,
}
impl<C: Cache> Cached<C> {
/// Place any entry into this wrapper; invalidation will be a no-op.
/// Unfortunately, rust doesn't let us implement [`From`] or [`Into`].
pub fn new_uncached(value: impl Into<C::Value>) -> Self {
Self {
token: None,
value: value.into(),
}
}
/// Drop this entry from a cache if it's still there.
pub fn invalidate(&self) {
if let Some((cache, info)) = &self.token {
cache.invalidate(info);
}
}
/// Tell if this entry is actually cached.
pub fn cached(&self) -> bool {
self.token.is_some()
}
}
impl<C: Cache> Deref for Cached<C> {
type Target = C::Value;
fn deref(&self) -> &Self::Target {
&self.value
}
}
impl<C: Cache> DerefMut for Cached<C> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.value
}
} }
} }

293
proxy/src/cache/timed_lru.rs (new file)

@@ -0,0 +1,293 @@
#[cfg(test)]
mod tests;
use super::{Cache, Cached, InvalidationToken, LookupInfo};
use ref_cast::RefCast;
use std::{
any::Any,
borrow::Borrow,
hash::Hash,
sync::Arc,
time::{Duration, Instant},
};
use tracing::debug;
// This seems to make more sense than `lru` or `cached`:
//
// * `near/nearcore` ditched `cached` in favor of `lru`
// (https://github.com/near/nearcore/issues?q=is%3Aissue+lru+is%3Aclosed).
//
// * `lru` methods use an obscure `KeyRef` type in their constraints (which is deliberately excluded from docs).
// This severely hinders its usage both in terms of creating wrappers and supported key types.
//
// On the other hand, `hashlink` has good download stats and appears to be maintained.
use hashlink::{linked_hash_map::RawEntryMut, LruCache};
/// An implementation of timed LRU cache with fixed capacity.
/// Key properties:
///
/// * Whenever a new entry is inserted, the least recently accessed one is evicted.
/// The cache also keeps track of entry's insertion time (`created_at`) and TTL (`expires_at`).
///
/// * When the entry is about to be retrieved, we check its expiration timestamp.
/// If the entry has expired, we remove it from the cache; otherwise we bump the
/// expiration timestamp (e.g. +5mins) and change its place in the LRU list to prolong
/// its existence.
///
/// * There's an API for immediate invalidation (removal) of a cache entry;
/// It's useful in case we know for sure that the entry is no longer correct.
/// See [`timed_lru::LookupInfo`] & [`timed_lru::Cached`] for more information.
///
/// * Expired entries are kept in the cache, until they are evicted by the LRU policy,
/// or by a successful lookup (i.e. the entry hasn't expired yet).
/// There is no background job to reap the expired records.
///
/// * It's possible for an entry that has not yet expired to be evicted
/// before expired items. That's a bit wasteful, but probably fine in practice.
pub struct TimedLru<K, V> {
/// Cache's name for tracing.
name: &'static str,
/// The underlying cache implementation.
cache: parking_lot::Mutex<LruCache<Key<K>, Entry<V>>>,
/// Default time-to-live of a single entry.
ttl: Duration,
}
#[derive(RefCast, Hash, PartialEq, Eq)]
#[repr(transparent)]
struct Query<Q: ?Sized>(Q);
#[derive(Hash, PartialEq, Eq)]
#[repr(transparent)]
struct Key<T: ?Sized>(Arc<T>);
/// It's impossible to implement this without [`Key`] & [`Query`]:
/// * We can't implement std traits for [`Arc`].
/// * Even if we could, it'd conflict with `impl<T> Borrow<T> for T`.
impl<Q, T> Borrow<Query<Q>> for Key<T>
where
Q: ?Sized,
T: Borrow<Q>,
{
#[inline(always)]
fn borrow(&self) -> &Query<Q> {
RefCast::ref_cast(self.0.as_ref().borrow())
}
}
struct Entry<T> {
created_at: Instant,
expires_at: Instant,
value: Arc<T>,
}
impl<K: Hash + Eq, V> TimedLru<K, V> {
/// Construct a new LRU cache with timed entries.
pub fn new(name: &'static str, capacity: usize, ttl: Duration) -> Self {
Self {
name,
cache: LruCache::new(capacity).into(),
ttl,
}
}
/// Get the number of entries in the cache.
/// Note that this method will not try to evict stale entries.
pub fn size(&self) -> usize {
self.cache.lock().len()
}
/// Try retrieving an entry by its key, then execute `extract` if it exists.
#[tracing::instrument(level = "debug", fields(cache = self.name), skip_all)]
fn get_raw<Q, F, R>(&self, key: &Q, extract: F) -> Option<R>
where
K: Borrow<Q>,
Q: Hash + Eq + ?Sized,
F: FnOnce(&Arc<K>, &Entry<V>) -> R,
{
let key: &Query<Q> = RefCast::ref_cast(key);
let now = Instant::now();
let deadline = now.checked_add(self.ttl).expect("time overflow");
// Do costly things before taking the lock.
let mut cache = self.cache.lock();
let mut raw_entry = match cache.raw_entry_mut().from_key(key) {
RawEntryMut::Vacant(_) => return None,
RawEntryMut::Occupied(x) => x,
};
// Immediately drop the entry if it has expired.
let entry = raw_entry.get();
if entry.expires_at <= now {
raw_entry.remove();
return None;
}
let value = extract(&raw_entry.key().0, entry);
let (created_at, expires_at) = (entry.created_at, entry.expires_at);
// Update the deadline and the entry's position in the LRU list.
raw_entry.get_mut().expires_at = deadline;
raw_entry.to_back();
drop(cache); // drop lock before logging
debug!(
?created_at,
old_expires_at = ?expires_at,
new_expires_at = ?deadline,
"accessed a cache entry"
);
Some(value)
}
/// Insert an entry to the cache. If an entry with the same key already
/// existed, return the previous value and its creation timestamp.
#[tracing::instrument(level = "debug", fields(cache = self.name), skip_all)]
fn insert_raw(&self, key: Arc<K>, value: Arc<V>) -> (Option<Arc<V>>, Instant) {
let created_at = Instant::now();
let expires_at = created_at.checked_add(self.ttl).expect("time overflow");
let entry = Entry {
created_at,
expires_at,
value,
};
// Do costly things before taking the lock.
let old = self
.cache
.lock()
.insert(Key(key), entry)
.map(|entry| entry.value);
debug!(
?created_at,
?expires_at,
replaced = old.is_some(),
"created a cache entry"
);
(old, created_at)
}
#[tracing::instrument(level = "debug", fields(cache = self.name), skip_all)]
fn remove_raw<Q>(&self, key: &Q)
where
K: Borrow<Q>,
Q: Hash + Eq + ?Sized,
{
let key: &Query<Q> = RefCast::ref_cast(key);
let mut cache = self.cache.lock();
cache.remove(key);
debug!("removed a cache entry");
}
}
/// Convenient wrappers for raw methods.
impl<K: Hash + Eq + Sync + Send + 'static, V> TimedLru<K, V>
where
Self: Sync + Send,
{
pub fn get<Q>(&self, key: &Q) -> Option<Cached<V>>
where
K: Borrow<Q>,
Q: Hash + Eq + ?Sized,
{
self.get_raw(key, |key, entry| {
let info = LookupInfo {
created_at: entry.created_at,
key: Arc::clone(key),
};
Cached {
token: Some(Self::invalidation_token(self, info)),
value: entry.value.clone(),
}
})
}
pub fn insert(&self, key: Arc<K>, value: Arc<V>) -> (Option<Arc<V>>, Cached<V>) {
let (old, created_at) = self.insert_raw(key.clone(), value.clone());
let info = LookupInfo { created_at, key };
let cached = Cached {
token: Some(Self::invalidation_token(self, info)),
value,
};
(old, cached)
}
pub fn remove<Q>(&self, key: &Q)
where
K: Borrow<Q>,
Q: Hash + Eq + ?Sized,
{
self.remove_raw(key)
}
}
/// Implementation details of the entry invalidation machinery.
impl<K: Hash + Eq + Sync + Send + 'static, V> TimedLru<K, V>
where
Self: Sync + Send,
{
/// This is a proper (safe) way to create an invalidation token for [`TimedLru`].
fn invalidation_token(&self, info: LookupInfo<Arc<K>>) -> InvalidationToken<'_> {
InvalidationToken {
cache: self,
info: LookupInfo {
created_at: info.created_at,
key: info.key,
},
}
}
}
/// This implementation depends on [`Self::invalidation_token`].
impl<K: Hash + Eq + 'static, V> Cache for TimedLru<K, V> {
fn invalidate_entry(&self, info: LookupInfo<&(dyn Any + Sync + Send)>) {
let info = LookupInfo {
created_at: info.created_at,
// NOTE: it's important to downcast to the correct type!
key: info.key.downcast_ref::<K>().expect("bad key type"),
};
self.invalidate_raw(info)
}
}
impl<K: Hash + Eq, V> TimedLru<K, V> {
#[tracing::instrument(level = "debug", fields(cache = self.name), skip_all)]
fn invalidate_raw(&self, info: LookupInfo<&K>) {
let key: &Query<K> = RefCast::ref_cast(info.key);
let now = Instant::now();
// Do costly things before taking the lock.
let mut cache = self.cache.lock();
let raw_entry = match cache.raw_entry_mut().from_key(key) {
RawEntryMut::Vacant(_) => return,
RawEntryMut::Occupied(x) => x,
};
// Remove the entry if it was created prior to lookup timestamp.
let entry = raw_entry.get();
let (created_at, expires_at) = (entry.created_at, entry.expires_at);
let should_remove = created_at <= info.created_at || expires_at <= now;
if should_remove {
raw_entry.remove();
}
drop(cache); // drop lock before logging
debug!(
?created_at,
?expires_at,
entry_removed = should_remove,
"processed a cache entry invalidation event"
);
}
}

104
proxy/src/cache/timed_lru/tests.rs (new file)

@@ -0,0 +1,104 @@
use super::*;
/// Check that we can define the cache for certain types.
#[test]
fn definition() {
// Check for trivial yet possible types.
let cache = TimedLru::<(), ()>::new("test", 128, Duration::from_secs(0));
let _ = cache.insert(Default::default(), Default::default());
let _ = cache.get(&());
// Now for something less trivial.
let cache = TimedLru::<String, String>::new("test", 128, Duration::from_secs(0));
let _ = cache.insert(Default::default(), Default::default());
let _ = cache.get(&String::default());
let _ = cache.get("str should work");
// It should also work for non-cloneable values.
struct NoClone;
let cache = TimedLru::<Box<str>, NoClone>::new("test", 128, Duration::from_secs(0));
let _ = cache.insert(Default::default(), NoClone.into());
let _ = cache.get(&Box::<str>::from("boxed str"));
let _ = cache.get("str should work");
}
#[test]
fn insert() {
const CAPACITY: usize = 2;
let cache = TimedLru::<String, u32>::new("test", CAPACITY, Duration::from_secs(10));
assert_eq!(cache.size(), 0);
let key = Arc::new(String::from("key"));
let (old, cached) = cache.insert(key.clone(), 42.into());
assert_eq!(old, None);
assert_eq!(*cached, 42);
assert_eq!(cache.size(), 1);
let (old, cached) = cache.insert(key, 1.into());
assert_eq!(old.as_deref(), Some(&42));
assert_eq!(*cached, 1);
assert_eq!(cache.size(), 1);
let (old, cached) = cache.insert(Arc::new("N1".to_owned()), 10.into());
assert_eq!(old, None);
assert_eq!(*cached, 10);
assert_eq!(cache.size(), 2);
let (old, cached) = cache.insert(Arc::new("N2".to_owned()), 20.into());
assert_eq!(old, None);
assert_eq!(*cached, 20);
assert_eq!(cache.size(), 2);
}
#[test]
fn get_none() {
let cache = TimedLru::<String, u32>::new("test", 2, Duration::from_secs(10));
let cached = cache.get("missing");
assert!(matches!(cached, None));
}
#[test]
fn invalidation_simple() {
let cache = TimedLru::<String, u32>::new("test", 2, Duration::from_secs(10));
let (_, cached) = cache.insert(String::from("key").into(), 100.into());
assert_eq!(cache.size(), 1);
cached.invalidate();
assert_eq!(cache.size(), 0);
assert!(matches!(cache.get("key"), None));
}
#[test]
fn invalidation_preserve_newer() {
let cache = TimedLru::<String, u32>::new("test", 2, Duration::from_secs(10));
let key = Arc::new(String::from("key"));
let (_, cached) = cache.insert(key.clone(), 100.into());
assert_eq!(cache.size(), 1);
let _ = cache.insert(key.clone(), 200.into());
assert_eq!(cache.size(), 1);
cached.invalidate();
assert_eq!(cache.size(), 1);
let cached = cache.get(key.as_ref());
assert_eq!(cached.as_deref(), Some(&200));
}
#[test]
fn auto_expiry() {
let lifetime = Duration::from_millis(300);
let cache = TimedLru::<String, u32>::new("test", 2, lifetime);
let key = Arc::new(String::from("key"));
let _ = cache.insert(key.clone(), 42.into());
let cached = cache.get(key.as_ref());
assert_eq!(cached.as_deref(), Some(&42));
std::thread::sleep(lifetime);
let cached = cache.get(key.as_ref());
assert_eq!(cached.as_deref(), None);
}


@@ -1,13 +1,28 @@
use crate::{auth::parse_endpoint_param, cancellation::CancelClosure, error::UserFacingError}; use crate::{
auth::parse_endpoint_param,
cancellation::CancelClosure,
console::messages::{DatabaseInfo, MetricsAuxInfo},
console::CachedNodeInfo,
error::UserFacingError,
};
use futures::{FutureExt, TryFutureExt}; use futures::{FutureExt, TryFutureExt};
use itertools::Itertools; use itertools::Itertools;
use pq_proto::StartupMessageParams; use pq_proto::StartupMessageParams;
use std::{io, net::SocketAddr, time::Duration}; use std::{
io,
net::SocketAddr,
sync::atomic::{AtomicBool, Ordering},
time::Duration,
};
use thiserror::Error; use thiserror::Error;
use tokio::net::TcpStream; use tokio::net::TcpStream;
use tokio_postgres::tls::MakeTlsConnect; use tokio_postgres::tls::MakeTlsConnect;
use tracing::{error, info, warn}; use tracing::{error, info, warn};
/// Should we allow self-signed certificates in TLS connections?
/// Most definitely, this shouldn't be allowed in production.
pub static ALLOW_SELF_SIGNED_COMPUTE: AtomicBool = AtomicBool::new(false);
const COULD_NOT_CONNECT: &str = "Couldn't connect to compute node"; const COULD_NOT_CONNECT: &str = "Couldn't connect to compute node";
#[derive(Debug, Error)] #[derive(Debug, Error)]
@@ -42,6 +57,88 @@ impl UserFacingError for ConnectionError {
/// A pair of `ClientKey` & `ServerKey` for `SCRAM-SHA-256`. /// A pair of `ClientKey` & `ServerKey` for `SCRAM-SHA-256`.
pub type ScramKeys = tokio_postgres::config::ScramKeys<32>; pub type ScramKeys = tokio_postgres::config::ScramKeys<32>;
#[derive(Clone)]
pub enum Password {
/// A regular cleartext password.
ClearText(Vec<u8>),
/// A pair of `ClientKey` & `ServerKey` for `SCRAM-SHA-256`.
ScramKeys(ScramKeys),
}
pub enum ComputeNode {
/// Route via link auth.
Link(DatabaseInfo),
/// Regular compute node.
Static {
password: Password,
info: CachedNodeInfo,
},
}
impl ComputeNode {
/// Get metrics auxiliary info.
pub fn metrics_aux_info(&self) -> &MetricsAuxInfo {
match self {
Self::Link(info) => &info.aux,
Self::Static { info, .. } => &info.aux,
}
}
/// Invalidate compute node info if it's cached.
pub fn invalidate(&self) -> bool {
if let Self::Static { info, .. } = self {
warn!("invalidating compute node info cache entry");
info.invalidate();
return true;
}
false
}
/// Turn compute node info into a postgres connection config.
pub fn to_conn_config(&self) -> ConnCfg {
let mut config = ConnCfg::new();
let (host, port) = match self {
Self::Link(info) => {
// NB: use pre-supplied dbname, user and password for link auth.
// See `ConnCfg::set_startup_params` below.
config.0.dbname(&info.dbname).user(&info.user);
if let Some(password) = &info.password {
config.0.password(password.as_bytes());
}
(&info.host, info.port)
}
Self::Static { info, password } => {
// NB: setup auth keys (for SCRAM) or plaintext password.
match password {
Password::ClearText(text) => config.0.password(text),
Password::ScramKeys(keys) => {
use tokio_postgres::config::AuthKeys;
config.0.auth_keys(AuthKeys::ScramSha256(keys.to_owned()))
}
};
(&info.address.host, info.address.port)
}
};
// Backwards compatibility. pg_sni_proxy uses "--" in domain names
// while direct connections do not. Once we migrate to pg_sni_proxy
// everywhere, we can remove this.
config.0.ssl_mode(if host.contains("--") {
// We need TLS connection with SNI info to properly route it.
tokio_postgres::config::SslMode::Require
} else {
tokio_postgres::config::SslMode::Disable
});
config.0.host(host).port(port);
config
}
}
/// A config for establishing a connection to compute node. /// A config for establishing a connection to compute node.
/// Eventually, `tokio_postgres` will be replaced with something better. /// Eventually, `tokio_postgres` will be replaced with something better.
/// Newtype allows us to implement methods on top of it. /// Newtype allows us to implement methods on top of it.
@@ -51,43 +148,32 @@ pub struct ConnCfg(Box<tokio_postgres::Config>);
/// Creation and initialization routines. /// Creation and initialization routines.
impl ConnCfg { impl ConnCfg {
pub fn new() -> Self { fn new() -> Self {
Self(Default::default()) Self(Default::default())
} }
/// Reuse password or auth keys from the other config.
pub fn reuse_password(&mut self, other: &Self) {
if let Some(password) = other.get_password() {
self.password(password);
}
if let Some(keys) = other.get_auth_keys() {
self.auth_keys(keys);
}
}
/// Apply startup message params to the connection config. /// Apply startup message params to the connection config.
pub fn set_startup_params(&mut self, params: &StartupMessageParams) { pub fn set_startup_params(&mut self, params: &StartupMessageParams) {
// Only set `user` if it's not present in the config. // Only set `user` if it's not present in the config.
// Link auth flow takes username from the console's response. // Link auth flow takes username from the console's response.
if let (None, Some(user)) = (self.get_user(), params.get("user")) { if let (None, Some(user)) = (self.0.get_user(), params.get("user")) {
self.user(user); self.0.user(user);
} }
// Only set `dbname` if it's not present in the config. // Only set `dbname` if it's not present in the config.
// Link auth flow takes dbname from the console's response. // Link auth flow takes dbname from the console's response.
if let (None, Some(dbname)) = (self.get_dbname(), params.get("database")) { if let (None, Some(dbname)) = (self.0.get_dbname(), params.get("database")) {
self.dbname(dbname); self.0.dbname(dbname);
} }
// Don't add `options` if they were only used for specifying a project. // Don't add `options` if they were only used for specifying a project.
// Connection pools don't support `options`, because they affect backend startup. // Connection pools don't support `options`, because they affect backend startup.
if let Some(options) = filtered_options(params) { if let Some(options) = filtered_options(params) {
self.options(&options); self.0.options(&options);
} }
if let Some(app_name) = params.get("application_name") { if let Some(app_name) = params.get("application_name") {
self.application_name(app_name); self.0.application_name(app_name);
} }
// TODO: This is especially ugly... // TODO: This is especially ugly...
@@ -95,10 +181,10 @@ impl ConnCfg {
use tokio_postgres::config::ReplicationMode; use tokio_postgres::config::ReplicationMode;
match replication { match replication {
"true" | "on" | "yes" | "1" => { "true" | "on" | "yes" | "1" => {
self.replication_mode(ReplicationMode::Physical); self.0.replication_mode(ReplicationMode::Physical);
} }
"database" => { "database" => {
self.replication_mode(ReplicationMode::Logical); self.0.replication_mode(ReplicationMode::Logical);
} }
_other => {} _other => {}
} }
@@ -113,27 +199,6 @@ impl ConnCfg {
} }
} }
impl std::ops::Deref for ConnCfg {
type Target = tokio_postgres::Config;
fn deref(&self) -> &Self::Target {
&self.0
}
}
/// For now, let's make it easier to setup the config.
impl std::ops::DerefMut for ConnCfg {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}
impl Default for ConnCfg {
fn default() -> Self {
Self::new()
}
}
impl ConnCfg { impl ConnCfg {
/// Establish a raw TCP connection to the compute node. /// Establish a raw TCP connection to the compute node.
async fn connect_raw(&self) -> io::Result<(SocketAddr, TcpStream, &str)> { async fn connect_raw(&self) -> io::Result<(SocketAddr, TcpStream, &str)> {
@@ -220,16 +285,13 @@ pub struct PostgresConnection {
} }
impl ConnCfg { impl ConnCfg {
async fn do_connect( async fn do_connect(&self) -> Result<PostgresConnection, ConnectionError> {
&self,
allow_self_signed_compute: bool,
) -> Result<PostgresConnection, ConnectionError> {
let (socket_addr, stream, host) = self.connect_raw().await?; let (socket_addr, stream, host) = self.connect_raw().await?;
let tls_connector = native_tls::TlsConnector::builder() let tls_connector = native_tls::TlsConnector::builder()
.danger_accept_invalid_certs(allow_self_signed_compute) .danger_accept_invalid_certs(ALLOW_SELF_SIGNED_COMPUTE.load(Ordering::Relaxed))
.build() .build()?;
.unwrap();
let mut mk_tls = postgres_native_tls::MakeTlsConnector::new(tls_connector); let mut mk_tls = postgres_native_tls::MakeTlsConnector::new(tls_connector);
let tls = MakeTlsConnect::<tokio::net::TcpStream>::make_tls_connect(&mut mk_tls, host)?; let tls = MakeTlsConnect::<tokio::net::TcpStream>::make_tls_connect(&mut mk_tls, host)?;
@@ -261,11 +323,8 @@ impl ConnCfg {
} }
/// Connect to a corresponding compute node. /// Connect to a corresponding compute node.
pub async fn connect( pub async fn connect(&self) -> Result<PostgresConnection, ConnectionError> {
&self, self.do_connect()
allow_self_signed_compute: bool,
) -> Result<PostgresConnection, ConnectionError> {
self.do_connect(allow_self_signed_compute)
.inspect_err(|err| { .inspect_err(|err| {
// Immediately log the error we have at our disposal. // Immediately log the error we have at our disposal.
error!("couldn't connect to compute node: {err}"); error!("couldn't connect to compute node: {err}");

View File

@@ -12,7 +12,6 @@ pub struct ProxyConfig {
pub tls_config: Option<TlsConfig>, pub tls_config: Option<TlsConfig>,
pub auth_backend: auth::BackendType<'static, ()>, pub auth_backend: auth::BackendType<'static, ()>,
pub metric_collection: Option<MetricCollectionConfig>, pub metric_collection: Option<MetricCollectionConfig>,
pub allow_self_signed_compute: bool,
} }
#[derive(Debug)] #[derive(Debug)]
@@ -211,8 +210,11 @@ pub struct CacheOptions {
}
impl CacheOptions {
+ /// Default options for [`crate::auth::caches::AuthInfoCache`].
+ pub const DEFAULT_AUTH_INFO: &str = "size=4000,ttl=5s";
/// Default options for [`crate::auth::caches::NodeInfoCache`].
- pub const DEFAULT_OPTIONS_NODE_INFO: &str = "size=4000,ttl=5m";
+ pub const DEFAULT_NODE_INFO: &str = "size=4000,ttl=5m";
/// Parse cache options passed via cmdline.
/// Example: [`Self::DEFAULT_OPTIONS_NODE_INFO`].
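The "size=…,ttl=…" strings above come straight from the command line; the parser itself is not part of these hunks. A rough sketch of how such an option string could be parsed (a guess at the shape, not the actual CacheOptions::parse; only "<n>s" and "<n>m" TTL suffixes are handled here):

use std::time::Duration;

// Hypothetical stand-in for the real cache-options type.
#[derive(Debug, PartialEq)]
struct CacheOpts {
    size: usize,
    ttl: Duration,
}

fn parse_cache_opts(s: &str) -> Result<CacheOpts, String> {
    let (mut size, mut ttl) = (None, None);
    for pair in s.split(',') {
        match pair.split_once('=') {
            Some(("size", v)) => size = Some(v.parse().map_err(|e| format!("bad size: {e}"))?),
            Some(("ttl", v)) if v.ends_with('s') => {
                let secs: u64 = v[..v.len() - 1].parse().map_err(|e| format!("bad ttl: {e}"))?;
                ttl = Some(Duration::from_secs(secs));
            }
            Some(("ttl", v)) if v.ends_with('m') => {
                let mins: u64 = v[..v.len() - 1].parse().map_err(|e| format!("bad ttl: {e}"))?;
                ttl = Some(Duration::from_secs(mins * 60));
            }
            other => return Err(format!("unknown cache option: {other:?}")),
        }
    }
    Ok(CacheOpts {
        size: size.ok_or("missing size")?,
        ttl: ttl.ok_or("missing ttl")?,
    })
}

fn main() {
    let opts = parse_cache_opts("size=4000,ttl=5s").unwrap();
    assert_eq!(opts, CacheOpts { size: 4000, ttl: Duration::from_secs(5) });
}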

View File

@@ -6,12 +6,17 @@ pub mod messages;
/// Wrappers for console APIs and their mocks.
pub mod provider;
- pub use provider::{errors, Api, AuthInfo, CachedNodeInfo, ConsoleReqExtra, NodeInfo};
+ pub use provider::{errors, Api, ConsoleReqExtra};
+ pub use provider::{AuthInfo, NodeInfo};
+ pub use provider::{CachedAuthInfo, CachedNodeInfo};
/// Various cache-related types.
pub mod caches {
- pub use super::provider::{ApiCaches, NodeInfoCache};
+ pub use super::provider::{ApiCaches, AuthInfoCache, AuthInfoCacheKey, NodeInfoCache};
}
/// Console's management API.
pub mod mgmt;
+ /// Console's notification bus.
+ pub mod notifications;

View File

@@ -22,11 +22,37 @@ impl fmt::Debug for GetRoleSecret {
}
}
/// Represents compute node's host & port pair.
#[derive(Debug)]
pub struct PgEndpoint {
pub host: Box<str>,
pub port: u16,
}
impl<'de> Deserialize<'de> for PgEndpoint {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
use serde::de::Error;
let raw = String::deserialize(deserializer)?;
let (host, port) = raw
.rsplit_once(':')
.ok_or_else(|| Error::custom(format!("bad compute address: {raw}")))?;
Ok(PgEndpoint {
host: host.into(),
port: port.parse().map_err(Error::custom)?,
})
}
}
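Note that deserialization splits on the last colon, so bracketed IPv6 literals (which contain colons of their own) keep everything before the port in the host part. A quick standard-library-only illustration of that behavior:

fn main() {
    // IPv4: the only colon separates host and port.
    assert_eq!("127.0.0.1:5432".rsplit_once(':'), Some(("127.0.0.1", "5432")));
    // IPv6: rsplit_once splits on the *last* colon, so the bracketed address
    // stays intact; a plain split_once(':') would have cut inside "[::1]".
    assert_eq!("[::1]:5432".rsplit_once(':'), Some(("[::1]", "5432")));
}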
/// Response which holds compute node's `host:port` pair.
/// Returned by the `/proxy_wake_compute` API method.
#[derive(Debug, Deserialize)]
pub struct WakeCompute {
- pub address: Box<str>,
+ pub address: PgEndpoint,
pub aux: MetricsAuxInfo,
}
@@ -187,4 +213,24 @@ mod tests {
Ok(())
}
#[test]
fn parse_wake_compute() -> anyhow::Result<()> {
let _: WakeCompute = serde_json::from_value(json!({
"address": "127.0.0.1:5432",
"aux": dummy_aux(),
}))?;
let _: WakeCompute = serde_json::from_value(json!({
"address": "[::1]:5432",
"aux": dummy_aux(),
}))?;
serde_json::from_value::<WakeCompute>(json!({
"address": "localhost:5432",
"aux": dummy_aux(),
}))?;
Ok(())
}
}

View File

@@ -0,0 +1,70 @@
use crate::console::caches::{ApiCaches, AuthInfoCacheKey};
use futures::StreamExt;
use serde::Deserialize;
const CHANNEL_NAME: &str = "proxy_notifications";
#[derive(Debug, Deserialize)]
#[serde(tag = "type")]
enum Notification<'a> {
#[serde(rename = "password_changed")]
PasswordChanged { project: &'a str, role: &'a str },
}
#[tracing::instrument(skip(caches))]
fn handle_message(msg: redis::Msg, caches: &ApiCaches) -> anyhow::Result<()> {
let payload: String = msg.get_payload()?;
use Notification::*;
match serde_json::from_str(&payload) {
Ok(PasswordChanged { project, role }) => {
let key = AuthInfoCacheKey {
project: project.into(),
role: role.into(),
};
tracing::info!(key = ?key, "invalidating auth info");
caches.auth_info.remove(&key);
}
Err(e) => tracing::error!("broken message: {e}"),
}
Ok(())
}
/// Handle console's invalidation messages.
#[tracing::instrument(name = "console_notifications", skip_all)]
pub async fn task_main(url: String, caches: &ApiCaches) -> anyhow::Result<()> {
let client = redis::Client::open(url.as_ref())?;
let mut conn = client.get_async_connection().await?.into_pubsub();
tracing::info!("subscribing to a channel `{CHANNEL_NAME}`");
conn.subscribe(CHANNEL_NAME).await?;
let mut stream = conn.on_message();
while let Some(msg) = stream.next().await {
handle_message(msg, caches)?;
}
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
#[test]
fn parse_notification() -> anyhow::Result<()> {
let text = json!({
"type": "password_changed",
"project": "very-nice",
"role": "borat",
})
.to_string();
let _: Notification = serde_json::from_str(&text)?;
Ok(())
}
}
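When poking at this listener locally, it can be handy to publish a message on the same channel and watch the corresponding auth-info cache entry get dropped. A small sketch using the redis crate (the Redis URL and a running proxy subscribed to the channel are assumptions):

use redis::Commands;

fn main() -> anyhow::Result<()> {
    // Assumes a Redis instance reachable at the usual local address.
    let client = redis::Client::open("redis://127.0.0.1:6379")?;
    let mut conn = client.get_connection()?;

    // Payload matching the Notification::PasswordChanged variant above.
    let payload = serde_json::json!({
        "type": "password_changed",
        "project": "very-nice",
        "role": "borat",
    })
    .to_string();

    // The listener subscribes to "proxy_notifications" (CHANNEL_NAME above).
    let receivers: u32 = conn.publish("proxy_notifications", payload)?;
    println!("delivered to {receivers} subscriber(s)");
    Ok(())
}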

View File

@@ -1,14 +1,13 @@
pub mod mock;
pub mod neon;
- use super::messages::MetricsAuxInfo;
+ use super::messages::{MetricsAuxInfo, PgEndpoint};
use crate::{
auth::ClientCredentials,
- cache::{timed_lru, TimedLru},
- compute, scram,
+ cache::{types::Cached, TimedLru},
+ scram,
};
use async_trait::async_trait;
- use std::sync::Arc;
pub mod errors {
use crate::{
@@ -112,11 +111,9 @@ pub mod errors {
}
}
}
#[derive(Debug, Error)]
pub enum WakeComputeError {
- #[error("Console responded with a malformed compute address: {0}")]
- BadComputeAddress(Box<str>),
#[error(transparent)]
ApiError(ApiError),
}
@@ -132,10 +129,7 @@ pub mod errors {
fn to_string_client(&self) -> String {
use WakeComputeError::*;
match self {
- // We shouldn't show user the address even if it's broken.
- // Besides, user is unlikely to care about this detail.
- BadComputeAddress(_) => REQUEST_FAILED.to_owned(),
- // However, API might return a meaningful error.
+ // API might return a meaningful error.
ApiError(e) => e.to_string_client(),
}
}
@@ -160,23 +154,26 @@ pub enum AuthInfo {
}
/// Info for establishing a connection to a compute node.
- /// This is what we get after auth succeeded, but not before!
+ /// This struct is cached, so we shouldn't store any user creds here.
+ #[derive(Clone)]
pub struct NodeInfo {
- /// Compute node connection params.
- /// It's sad that we have to clone this, but this will improve
- /// once we migrate to a bespoke connection logic.
- pub config: compute::ConnCfg,
+ /// Address of the compute node.
+ pub address: PgEndpoint,
/// Labels for proxy's metrics.
- pub aux: Arc<MetricsAuxInfo>,
+ pub aux: MetricsAuxInfo,
- /// Whether we should accept self-signed certificates (for testing)
- pub allow_self_signed_compute: bool,
}
- pub type NodeInfoCache = TimedLru<Arc<str>, NodeInfo>;
- pub type CachedNodeInfo = timed_lru::Cached<&'static NodeInfoCache>;
+ pub type NodeInfoCache = TimedLru<Box<str>, NodeInfo>;
+ pub type CachedNodeInfo = Cached<NodeInfo>;
+ #[derive(Debug, Hash, PartialEq, Eq)]
+ pub struct AuthInfoCacheKey {
+ pub project: Box<str>,
+ pub role: Box<str>,
+ }
+ pub type AuthInfoCache = TimedLru<AuthInfoCacheKey, AuthInfo>;
+ pub type CachedAuthInfo = Cached<AuthInfo>;
/// This will allocate per each call, but the http requests alone
/// already require a few allocations, so it should be fine.
@@ -187,7 +184,7 @@ pub trait Api {
&self,
extra: &ConsoleReqExtra<'_>,
creds: &ClientCredentials<'_>,
- ) -> Result<Option<AuthInfo>, errors::GetAuthInfoError>;
+ ) -> Result<Option<CachedAuthInfo>, errors::GetAuthInfoError>;
/// Wake up the compute node and return the corresponding connection info.
async fn wake_compute(
@@ -199,6 +196,8 @@ pub trait Api {
/// Various caches for [`console`].
pub struct ApiCaches {
- /// Cache for the `wake_compute` API method.
+ /// Cache for the `get_auth_info` method.
+ pub auth_info: AuthInfoCache,
+ /// Cache for the `wake_compute` method.
pub node_info: NodeInfoCache,
}

View File

@@ -2,13 +2,14 @@
use super::{
errors::{ApiError, GetAuthInfoError, WakeComputeError},
- AuthInfo, CachedNodeInfo, ConsoleReqExtra, NodeInfo,
+ ConsoleReqExtra, PgEndpoint,
};
- use crate::{auth::ClientCredentials, compute, error::io_error, scram, url::ApiUrl};
+ use super::{AuthInfo, NodeInfo};
+ use super::{CachedAuthInfo, CachedNodeInfo};
+ use crate::{auth::ClientCredentials, error::io_error, scram, url::ApiUrl};
use async_trait::async_trait;
use futures::TryFutureExt;
use thiserror::Error;
- use tokio_postgres::config::SslMode;
use tracing::{error, info, info_span, warn, Instrument};
#[derive(Debug, Error)]
@@ -84,19 +85,15 @@ impl Api {
}
async fn do_wake_compute(&self) -> Result<NodeInfo, WakeComputeError> {
- let mut config = compute::ConnCfg::new();
- config
- .host(self.endpoint.host_str().unwrap_or("localhost"))
- .port(self.endpoint.port().unwrap_or(5432))
- .ssl_mode(SslMode::Disable);
+ let host = self.endpoint.host_str().unwrap_or("localhost").into();
+ let port = self.endpoint.port().unwrap_or(5432);
- let node = NodeInfo {
- config,
+ let info = NodeInfo {
+ address: PgEndpoint { host, port },
aux: Default::default(),
- allow_self_signed_compute: false,
};
- Ok(node)
+ Ok(info)
}
}
@@ -107,8 +104,9 @@ impl super::Api for Api {
&self,
_extra: &ConsoleReqExtra<'_>,
creds: &ClientCredentials<'_>,
- ) -> Result<Option<AuthInfo>, GetAuthInfoError> {
- self.do_get_auth_info(creds).await
+ ) -> Result<Option<CachedAuthInfo>, GetAuthInfoError> {
+ let res = self.do_get_auth_info(creds).await?;
+ Ok(res.map(CachedAuthInfo::new_uncached))
}
#[tracing::instrument(skip_all)]
@@ -117,9 +115,8 @@ impl super::Api for Api {
_extra: &ConsoleReqExtra<'_>,
_creds: &ClientCredentials<'_>,
) -> Result<CachedNodeInfo, WakeComputeError> {
- self.do_wake_compute()
- .map_ok(CachedNodeInfo::new_uncached)
- .await
+ let res = self.do_wake_compute().await?;
+ Ok(CachedNodeInfo::new_uncached(res))
}
}

View File

@@ -3,18 +3,19 @@
use super::{
super::messages::{ConsoleError, GetRoleSecret, WakeCompute},
errors::{ApiError, GetAuthInfoError, WakeComputeError},
- ApiCaches, AuthInfo, CachedNodeInfo, ConsoleReqExtra, NodeInfo,
+ ApiCaches, ConsoleReqExtra,
};
- use crate::{auth::ClientCredentials, compute, http, scram};
+ use super::{AuthInfo, AuthInfoCacheKey, CachedAuthInfo};
+ use super::{CachedNodeInfo, NodeInfo};
+ use crate::{auth::ClientCredentials, http, scram};
use async_trait::async_trait;
use futures::TryFutureExt;
- use tokio_postgres::config::SslMode;
use tracing::{error, info, info_span, warn, Instrument};
#[derive(Clone)]
pub struct Api {
- endpoint: http::Endpoint,
- caches: &'static ApiCaches,
+ pub endpoint: http::Endpoint,
+ pub caches: &'static ApiCaches,
}
impl Api {
@@ -91,22 +92,9 @@ impl Api {
let response = self.endpoint.execute(request).await?;
let body = parse_body::<WakeCompute>(response).await?;
- // Unfortunately, ownership won't let us use `Option::ok_or` here.
- let (host, port) = match parse_host_port(&body.address) {
- None => return Err(WakeComputeError::BadComputeAddress(body.address)),
- Some(x) => x,
- };
- // Don't set anything but host and port! This config will be cached.
- // We'll set username and such later using the startup message.
- // TODO: add more type safety (in progress).
- let mut config = compute::ConnCfg::new();
- config.host(host).port(port).ssl_mode(SslMode::Disable); // TLS is not configured on compute nodes.
let node = NodeInfo {
- config,
- aux: body.aux.into(),
- allow_self_signed_compute: false,
+ address: body.address,
+ aux: body.aux,
};
Ok(node)
@@ -124,8 +112,28 @@ impl super::Api for Api {
&self,
extra: &ConsoleReqExtra<'_>,
creds: &ClientCredentials<'_>,
- ) -> Result<Option<AuthInfo>, GetAuthInfoError> {
- self.do_get_auth_info(extra, creds).await
+ ) -> Result<Option<CachedAuthInfo>, GetAuthInfoError> {
+ let key = AuthInfoCacheKey {
project: creds.project().expect("impossible").into(),
role: creds.user.into(),
};
// Check if we already have a cached auth info for this project + user combo.
// Beware! We shouldn't flush this for unsuccessful auth attempts, otherwise
// the cache makes no sense whatsoever in the presence of unfaithful clients.
// Instead, we snoop an invalidation queue to keep the cache up-to-date.
if let Some(cached) = self.caches.auth_info.get(&key) {
info!(key = ?key, "found cached auth info");
return Ok(Some(cached));
}
let info = self.do_get_auth_info(extra, creds).await?;
Ok(info.map(|info| {
info!(key = ?key, "creating a cache entry for auth info");
let (_, cached) = self.caches.auth_info.insert(key.into(), info.into());
cached
}))
}
#[tracing::instrument(skip_all)]
@@ -134,20 +142,21 @@ impl super::Api for Api {
extra: &ConsoleReqExtra<'_>,
creds: &ClientCredentials<'_>,
) -> Result<CachedNodeInfo, WakeComputeError> {
- let key = creds.project().expect("impossible");
+ let key: Box<str> = creds.project().expect("impossible").into();
// Every time we do a wakeup http request, the compute node will stay up
// for some time (highly depends on the console's scale-to-zero policy);
// The connection info remains the same during that period of time,
// which means that we might cache it to reduce the load and latency.
- if let Some(cached) = self.caches.node_info.get(key) {
- info!(key = key, "found cached compute node info");
+ if let Some(cached) = self.caches.node_info.get(&key) {
+ info!(key = ?key, "found cached compute node info");
return Ok(cached);
}
- let node = self.do_wake_compute(extra, creds).await?;
- let (_, cached) = self.caches.node_info.insert(key.into(), node);
- info!(key = key, "created a cache entry for compute node info");
+ let info = self.do_wake_compute(extra, creds).await?;
+ info!(key = ?key, "creating a cache entry for compute node info");
+ let (_, cached) = self.caches.node_info.insert(key.into(), info.into());
Ok(cached)
}
@@ -177,20 +186,3 @@ async fn parse_body<T: for<'a> serde::Deserialize<'a>>(
error!("console responded with an error ({status}): {text}"); error!("console responded with an error ({status}): {text}");
Err(ApiError::Console { status, text }) Err(ApiError::Console { status, text })
} }
fn parse_host_port(input: &str) -> Option<(&str, u16)> {
let (host, port) = input.split_once(':')?;
Some((host, port.parse().ok()?))
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_parse_host_port() {
let (host, port) = parse_host_port("127.0.0.1:5432").expect("failed to parse");
assert_eq!(host, "127.0.0.1");
assert_eq!(port, 5432);
}
}

View File

@@ -175,14 +175,9 @@ pub async fn handle(
application_name: Some(APP_NAME),
};
let node = creds.wake_compute(&extra).await?.expect("msg");
- let conf = node.value.config;
- let port = *conf.get_ports().first().expect("no port");
- let host = match conf.get_hosts().first().expect("no host") {
- tokio_postgres::config::Host::Tcp(host) => host,
- tokio_postgres::config::Host::Unix(_) => {
- return Err(anyhow::anyhow!("unix socket is not supported"));
- }
- };
+ let pg_endpoint = &node.value.address;
+ let port = pg_endpoint.port;
+ let host = &pg_endpoint.host;
let request_content_length = match request.body().size_hint().upper() {
Some(v) => v,

View File

@@ -4,7 +4,7 @@ mod tests;
use crate::{
auth::{self, backend::AuthSuccess},
cancellation::{self, CancelMap},
- compute::{self, PostgresConnection},
+ compute::{self, ComputeNode, PostgresConnection},
config::{ProxyConfig, TlsConfig},
console::{self, messages::MetricsAuxInfo},
error::io_error,
@@ -155,7 +155,7 @@ pub async fn handle_ws_client(
async { result }.or_else(|e| stream.throw_error(e)).await?
};
- let client = Client::new(stream, creds, &params, session_id, false);
+ let client = Client::new(stream, creds, &params, session_id);
cancel_map
.with_session(|session| client.connect_to_db(session, true))
.await
@@ -194,15 +194,7 @@ async fn handle_client(
async { result }.or_else(|e| stream.throw_error(e)).await?
};
- let allow_self_signed_compute = config.allow_self_signed_compute;
- let client = Client::new(
- stream,
- creds,
- &params,
- session_id,
- allow_self_signed_compute,
- );
+ let client = Client::new(stream, creds, &params, session_id);
cancel_map
.with_session(|session| client.connect_to_db(session, false))
.await
@@ -283,61 +275,38 @@ async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
}
}
- /// Try to connect to the compute node once.
- #[tracing::instrument(name = "connect_once", skip_all)]
- async fn connect_to_compute_once(
- node_info: &console::CachedNodeInfo,
- ) -> Result<PostgresConnection, compute::ConnectionError> {
- // If we couldn't connect, a cached connection info might be to blame
- // (e.g. the compute node's address might've changed at the wrong time).
- // Invalidate the cache entry (if any) to prevent subsequent errors.
- let invalidate_cache = |_: &compute::ConnectionError| {
- let is_cached = node_info.cached();
- if is_cached {
- warn!("invalidating stalled compute node info cache entry");
- node_info.invalidate();
- }
- let label = match is_cached {
- true => "compute_cached",
- false => "compute_uncached",
- };
- NUM_CONNECTION_FAILURES.with_label_values(&[label]).inc();
- };
- let allow_self_signed_compute = node_info.allow_self_signed_compute;
- node_info
- .config
- .connect(allow_self_signed_compute)
- .inspect_err(invalidate_cache)
- .await
- }
/// Try to connect to the compute node, retrying if necessary.
/// This function might update `node_info`, so we take it by `&mut`.
#[tracing::instrument(skip_all)]
async fn connect_to_compute(
- node_info: &mut console::CachedNodeInfo,
+ node_info: &mut ComputeNode,
params: &StartupMessageParams,
extra: &console::ConsoleReqExtra<'_>,
creds: &auth::BackendType<'_, auth::ClientCredentials<'_>>,
) -> Result<PostgresConnection, compute::ConnectionError> {
- let mut num_retries: usize = NUM_RETRIES_WAKE_COMPUTE;
+ let mut num_retries = NUM_RETRIES_WAKE_COMPUTE;
loop {
- // Apply startup params to the (possibly, cached) compute node info.
- node_info.config.set_startup_params(params);
- match connect_to_compute_once(node_info).await {
+ let mut config = node_info.to_conn_config();
+ config.set_startup_params(params);
+ match config.connect().await {
Err(e) if num_retries > 0 => {
- info!("compute node's state has changed; requesting a wake-up");
- match creds.wake_compute(extra).map_err(io_error).await? {
+ let label = match node_info.invalidate() {
+ true => "compute_cached",
+ false => "compute_uncached",
+ };
+ NUM_CONNECTION_FAILURES.with_label_values(&[label]).inc();
+ let res = creds.wake_compute(extra).map_err(io_error).await?;
+ match (res, &node_info) {
// Update `node_info` and try one more time.
- Some(mut new) => {
- new.config.reuse_password(&node_info.config);
- *node_info = new;
+ (Some(new), ComputeNode::Static { password, .. }) => {
+ *node_info = ComputeNode::Static {
+ password: password.to_owned(),
+ info: new,
+ }
}
// Link auth doesn't work that way, so we just exit.
- None => return Err(e),
+ _ => return Err(e),
}
}
other => return other,
@@ -430,8 +399,6 @@ struct Client<'a, S> {
params: &'a StartupMessageParams,
/// Unique connection ID.
session_id: uuid::Uuid,
- /// Allow self-signed certificates (for testing).
- allow_self_signed_compute: bool,
}
impl<'a, S> Client<'a, S> {
@@ -441,14 +408,12 @@ impl<'a, S> Client<'a, S> {
creds: auth::BackendType<'a, auth::ClientCredentials<'a>>,
params: &'a StartupMessageParams,
session_id: uuid::Uuid,
- allow_self_signed_compute: bool,
) -> Self {
Self {
stream,
creds,
params,
session_id,
- allow_self_signed_compute,
}
}
}
@@ -468,7 +433,6 @@ impl<S: AsyncRead + AsyncWrite + Unpin> Client<'_, S> {
mut creds,
params,
session_id,
- allow_self_signed_compute,
} = self;
let extra = console::ConsoleReqExtra {
@@ -491,19 +455,19 @@ impl<S: AsyncRead + AsyncWrite + Unpin> Client<'_, S> {
value: mut node_info,
} = auth_result;
- node_info.allow_self_signed_compute = allow_self_signed_compute;
let mut node = connect_to_compute(&mut node_info, params, &extra, &creds)
.or_else(|e| stream.throw_error(e))
.await?;
prepare_client_connection(&node, reported_auth_ok, session, &mut stream).await?;
// Before proxy passing, forward to compute whatever data is left in the
// PqStream input buffer. Normally there is none, but our serverless npm
// driver in pipeline mode sends startup, password and first query
// immediately after opening the connection.
let (stream, read_buf) = stream.into_inner();
node.stream.write_all(&read_buf).await?;
- proxy_pass(stream, node.stream, &node_info.aux).await
+ proxy_pass(stream, node.stream, node_info.metrics_aux_info()).await
}
}

View File

@@ -1845,13 +1845,15 @@ class VanillaPostgres(PgProtocol):
]
)
- def configure(self, options: List[str]):
+ def configure(self, options: List[str]) -> "VanillaPostgres":
"""Append lines into postgresql.conf file."""
assert not self.running
with open(os.path.join(self.pgdatadir, "postgresql.conf"), "a") as conf_file:
conf_file.write("\n".join(options))
+ return self
- def start(self, log_path: Optional[str] = None):
+ def start(self, log_path: Optional[str] = None) -> "VanillaPostgres":
assert not self.running
self.running = True
@@ -1862,11 +1864,15 @@ class VanillaPostgres(PgProtocol):
["pg_ctl", "-w", "-D", str(self.pgdatadir), "-l", log_path, "start"] ["pg_ctl", "-w", "-D", str(self.pgdatadir), "-l", log_path, "start"]
) )
def stop(self): return self
def stop(self) -> "VanillaPostgres":
assert self.running assert self.running
self.running = False self.running = False
self.pg_bin.run_capture(["pg_ctl", "-w", "-D", str(self.pgdatadir), "stop"]) self.pg_bin.run_capture(["pg_ctl", "-w", "-D", str(self.pgdatadir), "stop"])
return self
def get_subdir_size(self, subdir) -> int: def get_subdir_size(self, subdir) -> int:
"""Return size of pgdatadir subdirectory in bytes.""" """Return size of pgdatadir subdirectory in bytes."""
return get_dir_size(os.path.join(self.pgdatadir, subdir)) return get_dir_size(os.path.join(self.pgdatadir, subdir))
@@ -2035,6 +2041,17 @@ class NeonProxy(PgProtocol):
*["--auth-endpoint", self.pg_conn_url], *["--auth-endpoint", self.pg_conn_url],
] ]
@dataclass(frozen=True)
class Console(AuthBackend):
console_url: str
def extra_args(self) -> list[str]:
return [
# Postgres auth backend params
*["--auth-backend", "console"],
*["--auth-endpoint", self.console_url],
]
def __init__(
self,
neon_binpath: Path,
@@ -2240,6 +2257,33 @@ def link_proxy(
yield proxy
@pytest.fixture(scope="function")
def console_proxy(
port_distributor: PortDistributor, neon_binpath: Path, test_output_dir: Path
) -> Iterator[NeonProxy]:
"""Neon proxy that routes through link auth."""
http_port = port_distributor.get_port()
proxy_port = port_distributor.get_port()
mgmt_port = port_distributor.get_port()
console_port = port_distributor.get_port()
external_http_port = port_distributor.get_port()
console_url = f"http://127.0.0.1:{console_port}"
with NeonProxy(
neon_binpath=neon_binpath,
test_output_dir=test_output_dir,
proxy_port=proxy_port,
http_port=http_port,
mgmt_port=mgmt_port,
external_http_port=external_http_port,
auth_backend=NeonProxy.Console(console_url),
) as proxy:
proxy.start()
yield proxy
@pytest.fixture(scope="function")
def static_proxy(
vanilla_pg: VanillaPostgres,

View File

@@ -1,11 +1,13 @@
import json
+ import logging
import subprocess
- from typing import Any, List
+ from typing import Any, List, cast
import psycopg2
import pytest
import requests
- from fixtures.neon_fixtures import PSQL, NeonProxy, VanillaPostgres
+ from aiohttp import web
+ from fixtures.neon_fixtures import PSQL, NeonProxy, PortDistributor, VanillaPostgres
def test_proxy_select_1(static_proxy: NeonProxy):
@@ -225,3 +227,78 @@ def test_sql_over_http(static_proxy: NeonProxy):
res = q("drop table t") res = q("drop table t")
assert res["command"] == "DROP" assert res["command"] == "DROP"
assert res["rowCount"] is None assert res["rowCount"] is None
@pytest.mark.asyncio
async def test_compute_cache_invalidation(
port_distributor: PortDistributor, vanilla_pg: VanillaPostgres, console_proxy: NeonProxy
):
console_url = cast(NeonProxy.Console, console_proxy.auth_backend).console_url
logging.info(f"mocked console's url is {console_url}")
console_port = int(console_url.split(":")[-1])
logging.info(f"mocked console's port is {console_port}")
routes = web.RouteTableDef()
@routes.get("/proxy_get_role_secret")
async def get_role_secret(request):
# corresponds to password "password"
secret = ":".join(
[
"SCRAM-SHA-256$4096",
"t33UQcz/cs1D+n9INqThsw==$1NYlCbuxtK7YF2sgECBDTv1Myf8PpHJCT3RgKSXlZL0=",
"9iLeGY91MqBQ4ez1389Smo7h+STsJJ5jvu7kNofxj08=",
]
)
return web.json_response({"role_secret": secret})
wake_compute_called = 0
postgres_port = vanilla_pg.default_options["port"]
@routes.get("/proxy_wake_compute")
async def wake_compute(request):
nonlocal wake_compute_called
wake_compute_called += 1
nonlocal postgres_port
logging.info(f"compute's port is {postgres_port}")
return web.json_response(
{
"address": f"127.0.0.1:{postgres_port}",
"aux": {
"endpoint_id": "",
"project_id": "",
"branch_id": "",
},
}
)
console = web.Application()
console.add_routes(routes)
runner = web.AppRunner(console)
await runner.setup()
await web.TCPSite(runner, "127.0.0.1", console_port).start()
# Create a user we're going to use in the test sequence
user, password = "borat", "password"
vanilla_pg.start().safe_psql(f"create role {user} with login password '{password}'")
async def try_connect():
await console_proxy.connect_async(user=user, password=password, dbname="postgres")
assert wake_compute_called == 0
# Try connecting to compute
await try_connect()
assert wake_compute_called == 1
# Change compute's port
postgres_port = port_distributor.get_port()
vanilla_pg.stop().configure([f"port = {postgres_port}"]).start()
# Try connecting to compute
await try_connect()
assert wake_compute_called == 2

View File

@@ -42,7 +42,8 @@ regex = { version = "1" }
regex-syntax = { version = "0.6" }
reqwest = { version = "0.11", default-features = false, features = ["blocking", "json", "multipart", "rustls-tls"] }
ring = { version = "0.16", features = ["std"] }
- rustls = { version = "0.20", features = ["dangerous_configuration"] }
+ rustls-56bd22fc3884b12 = { package = "rustls", version = "0.20", features = ["dangerous_configuration"] }
+ rustls-647d43efb71741da = { package = "rustls", version = "0.21" }
scopeguard = { version = "1" }
serde = { version = "1", features = ["alloc", "derive"] }
serde_json = { version = "1", features = ["raw_value"] }