[proxy] Implement compute node info cache (#3331)

This patch adds a timed LRU cache implementation and a compute node info cache on top of it.
Cache entries may expire on their own (default TTL: 5 minutes) or become invalid due to real-world events,
e.g. a compute node scale-to-zero event, so we add a connection retry loop with a wake-up call.
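
In rough strokes, the connection path now works like the sketch below (names and error
plumbing are illustrative, not the exact code in this patch; the real loop also bounds
the number of retries):

    // Ask the configured backend for connection info; this may be a cached entry.
    let mut node = api.wake_compute(extra, creds).await?;
    let connection = loop {
        match node.config.connect().await {
            Ok(conn) => break conn,
            Err(_) if node.cached() => {
                // The entry may be stale (e.g. the compute was scaled to zero):
                // drop it and wake the compute again for fresh connection info.
                node.invalidate();
                node = api.wake_compute(extra, creds).await?;
            }
            Err(e) => return Err(e.into()),
        }
    };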

Solved problems:
- [x] Find a decent LRU implementation.
- [x] Implement timed LRU on top of that.
- [x] Cache results of `proxy_wake_compute` API call.
- [x] Don't invalidate newer cache entries for the same key.
- [x] Add cmdline configuration knobs (requires some refactoring).
- [x] Add a failed connection establishment metric.
- [x] Refactor auth backends to make things simpler (retries, cache placement, etc.).
- [x] Address review comments (add code comments + cleanup).
- [x] Retry `/proxy_wake_compute` if we couldn't connect to a compute (e.g. a stale cache entry).
- [x] Add high-level description for `TimedLru`.

TODOs (will be addressed later):
- [ ] Add cache metrics (hit, spurious hit, miss).
- [ ] Synchronize http requests across concurrent per-client tasks
(https://github.com/neondatabase/neon/pull/3331#issuecomment-1399216069).
- [ ] Cache results of `proxy_get_role_secret` API call.
Authored by Dmitry Ivanov on 2023-02-01 17:11:41 +03:00; committed by GitHub.
Parent f1aece1ba0, commit ea0278cf27.
24 changed files with 1401 additions and 736 deletions.

Cargo.lock (generated)

@@ -17,6 +17,17 @@ version = "1.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe"
[[package]]
name = "ahash"
version = "0.7.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fcb51a0695d8f838b1ee009b3fbf66bda078cd64590202a864a8f3e8c4315c47"
dependencies = [
"getrandom",
"once_cell",
"version_check",
]
[[package]]
name = "ahash"
version = "0.8.2"
@@ -1515,6 +1526,9 @@ name = "hashbrown"
version = "0.12.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888"
dependencies = [
"ahash 0.7.6",
]
[[package]]
name = "hashbrown"
@@ -1522,7 +1536,16 @@ version = "0.13.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "43a3c133739dddd0d2990f9a4bdf8eb4b21ef50e4851ca85ab661199821d510e"
dependencies = [
"ahash",
"ahash 0.8.2",
]
[[package]]
name = "hashlink"
version = "0.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "69fe1fcf8b4278d860ad0548329f892a3631fb63f82574df68275f34cdbe0ffa"
dependencies = [
"hashbrown 0.12.3",
]
[[package]]
@@ -2764,6 +2787,7 @@ dependencies = [
"futures",
"git-version",
"hashbrown 0.13.2",
"hashlink",
"hex",
"hmac",
"hostname",
@@ -4624,6 +4648,7 @@ dependencies = [
"futures-executor",
"futures-task",
"futures-util",
"hashbrown 0.12.3",
"indexmap",
"itertools",
"libc",


@@ -44,6 +44,7 @@ futures-core = "0.3"
futures-util = "0.3"
git-version = "0.3"
hashbrown = "0.13"
hashlink = "0.8.1"
hex = "0.4"
hex-literal = "0.3"
hmac = "0.12.1"


@@ -6,58 +6,59 @@ license.workspace = true
[dependencies]
anyhow.workspace = true
async-trait.workspace = true
atty.workspace = true
base64.workspace = true
bstr.workspace = true
bytes = {workspace = true, features = ['serde'] }
clap.workspace = true
bytes = { workspace = true, features = ["serde"] }
chrono.workspace = true
clap.workspace = true
consumption_metrics.workspace = true
futures.workspace = true
git-version.workspace = true
hashbrown.workspace = true
hashlink.workspace = true
hex.workspace = true
hmac.workspace = true
hyper.workspace = true
hostname.workspace = true
humantime.workspace = true
hyper-tungstenite.workspace = true
hyper.workspace = true
itertools.workspace = true
md5.workspace = true
metrics.workspace = true
once_cell.workspace = true
parking_lot.workspace = true
pin-project-lite.workspace = true
pq_proto.workspace = true
prometheus.workspace = true
rand.workspace = true
regex.workspace = true
reqwest = { workspace = true, features = [ "json" ] }
reqwest = { workspace = true, features = ["json"] }
routerify.workspace = true
rustls.workspace = true
rustls-pemfile.workspace = true
rustls.workspace = true
scopeguard.workspace = true
serde.workspace = true
serde_json.workspace = true
sha2.workspace = true
socket2.workspace = true
thiserror.workspace = true
tokio.workspace = true
tls-listener.workspace = true
tokio-postgres.workspace = true
tokio-rustls.workspace = true
tls-listener.workspace = true
tracing.workspace = true
tokio.workspace = true
tracing-subscriber.workspace = true
tracing.workspace = true
url.workspace = true
utils.workspace = true
uuid.workspace = true
webpki-roots.workspace = true
x509-parser.workspace = true
metrics.workspace = true
pq_proto.workspace = true
utils.workspace = true
prometheus.workspace = true
humantime.workspace = true
hostname.workspace = true
workspace_hack.workspace = true
[dev-dependencies]
async-trait.workspace = true
rcgen.workspace = true
rstest.workspace = true
tokio-postgres-rustls.workspace = true


@@ -1,7 +1,7 @@
//! Client authentication mechanisms.
pub mod backend;
pub use backend::{BackendType, ConsoleReqExtra};
pub use backend::BackendType;
mod credentials;
pub use credentials::ClientCredentials;
@@ -12,7 +12,7 @@ use password_hack::PasswordHackPayload;
mod flow;
pub use flow::*;
use crate::error::UserFacingError;
use crate::{console, error::UserFacingError};
use std::io;
use thiserror::Error;
@@ -26,10 +26,10 @@ pub enum AuthErrorImpl {
Link(#[from] backend::LinkAuthError),
#[error(transparent)]
GetAuthInfo(#[from] backend::GetAuthInfoError),
GetAuthInfo(#[from] console::errors::GetAuthInfoError),
#[error(transparent)]
WakeCompute(#[from] backend::WakeComputeError),
WakeCompute(#[from] console::errors::WakeComputeError),
/// SASL protocol errors (includes [SCRAM](crate::scram)).
#[error(transparent)]


@@ -1,48 +1,40 @@
mod postgres;
mod classic;
mod link;
use futures::TryFutureExt;
pub use link::LinkAuthError;
mod console;
pub use console::{GetAuthInfoError, WakeComputeError};
use crate::{
auth::{self, AuthFlow, ClientCredentials},
compute,
console::messages::MetricsAuxInfo,
http, mgmt, stream, url,
waiters::{self, Waiter, Waiters},
console::{
self,
provider::{CachedNodeInfo, ConsoleReqExtra},
Api,
},
stream, url,
};
use once_cell::sync::Lazy;
use std::borrow::Cow;
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::{info, warn};
static CPLANE_WAITERS: Lazy<Waiters<mgmt::ComputeReady>> = Lazy::new(Default::default);
/// Give caller an opportunity to wait for the cloud's reply.
pub async fn with_waiter<R, T, E>(
psql_session_id: impl Into<String>,
action: impl FnOnce(Waiter<'static, mgmt::ComputeReady>) -> R,
) -> Result<T, E>
where
R: std::future::Future<Output = Result<T, E>>,
E: From<waiters::RegisterError>,
{
let waiter = CPLANE_WAITERS.register(psql_session_id.into())?;
action(waiter).await
/// A product of successful authentication.
pub struct AuthSuccess<T> {
/// Did we send [`pq_proto::BeMessage::AuthenticationOk`] to client?
pub reported_auth_ok: bool,
/// Something to be considered a positive result.
pub value: T,
}
pub fn notify(psql_session_id: &str, msg: mgmt::ComputeReady) -> Result<(), waiters::NotifyError> {
CPLANE_WAITERS.notify(psql_session_id, msg)
}
/// Extra query params we'd like to pass to the console.
pub struct ConsoleReqExtra<'a> {
/// A unique identifier for a connection.
pub session_id: uuid::Uuid,
/// Name of client application, if set.
pub application_name: Option<&'a str>,
impl<T> AuthSuccess<T> {
/// Very similar to [`std::option::Option::map`].
/// Maps [`AuthSuccess<T>`] to [`AuthSuccess<R>`] by applying
/// a function to a contained value.
pub fn map<R>(self, f: impl FnOnce(T) -> R) -> AuthSuccess<R> {
AuthSuccess {
reported_auth_ok: self.reported_auth_ok,
value: f(self.value),
}
}
}
/// This type serves two purposes:
@@ -53,12 +45,11 @@ pub struct ConsoleReqExtra<'a> {
/// * However, when we substitute `T` with [`ClientCredentials`],
/// this helps us provide the credentials only to those auth
/// backends which require them for the authentication process.
#[derive(Debug)]
pub enum BackendType<'a, T> {
/// Current Cloud API (V2).
Console(Cow<'a, http::Endpoint>, T),
Console(Cow<'a, console::provider::neon::Api>, T),
/// Local mock of Cloud API (V2).
Postgres(Cow<'a, url::ApiUrl>, T),
Postgres(Cow<'a, console::provider::mock::Api>, T),
/// Authentication via a web browser.
Link(Cow<'a, url::ApiUrl>),
}
@@ -67,14 +58,8 @@ impl std::fmt::Display for BackendType<'_, ()> {
fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
use BackendType::*;
match self {
Console(endpoint, _) => fmt
.debug_tuple("Console")
.field(&endpoint.url().as_str())
.finish(),
Postgres(endpoint, _) => fmt
.debug_tuple("Postgres")
.field(&endpoint.as_str())
.finish(),
Console(endpoint, _) => fmt.debug_tuple("Console").field(&endpoint.url()).finish(),
Postgres(endpoint, _) => fmt.debug_tuple("Postgres").field(&endpoint.url()).finish(),
Link(url) => fmt.debug_tuple("Link").field(&url.as_str()).finish(),
}
}
@@ -120,30 +105,16 @@ impl<'a, T, E> BackendType<'a, Result<T, E>> {
}
}
/// A product of successful authentication.
pub struct AuthSuccess<T> {
/// Did we send [`pq_proto::BeMessage::AuthenticationOk`] to client?
pub reported_auth_ok: bool,
/// Something to be considered a positive result.
pub value: T,
}
/// Info for establishing a connection to a compute node.
/// This is what we get after auth succeeded, but not before!
pub struct NodeInfo {
/// Compute node connection params.
pub config: compute::ConnCfg,
/// Labels for proxy's metrics.
pub aux: MetricsAuxInfo,
}
impl BackendType<'_, ClientCredentials<'_>> {
// TODO: get rid of explicit lifetimes in this block (there's a bug in rustc).
// Read more: https://github.com/rust-lang/rust/issues/99190
// Alleged fix: https://github.com/rust-lang/rust/pull/89056
impl<'l> BackendType<'l, ClientCredentials<'_>> {
/// Do something special if user didn't provide the `project` parameter.
async fn try_password_hack(
&mut self,
extra: &ConsoleReqExtra<'_>,
client: &mut stream::PqStream<impl AsyncRead + AsyncWrite + Unpin + Send>,
) -> auth::Result<Option<AuthSuccess<NodeInfo>>> {
async fn try_password_hack<'a>(
&'a mut self,
extra: &'a ConsoleReqExtra<'a>,
client: &'a mut stream::PqStream<impl AsyncRead + AsyncWrite + Unpin>,
) -> auth::Result<Option<AuthSuccess<CachedNodeInfo>>> {
use BackendType::*;
// If there's no project so far, that entails that client doesn't
@@ -179,33 +150,28 @@ impl BackendType<'_, ClientCredentials<'_>> {
// TODO: find a proper way to merge those very similar blocks.
let (mut node, payload) = match self {
Console(endpoint, creds) if creds.project.is_none() => {
Console(api, creds) if creds.project.is_none() => {
let payload = fetch_magic_payload(client).await?;
let mut creds = creds.as_ref();
creds.project = Some(payload.project.as_str().into());
let node = console::Api::new(endpoint, extra, &creds)
.wake_compute()
.await?;
let node = api.wake_compute(extra, &creds).await?;
(node, payload)
}
Console(endpoint, creds) if creds.use_cleartext_password_flow => {
// This is a hack to allow cleartext password in secure connections (wss).
// This is a hack to allow cleartext password in secure connections (wss).
Console(api, creds) if creds.use_cleartext_password_flow => {
let payload = fetch_plaintext_password(client).await?;
let creds = creds.as_ref();
let node = console::Api::new(endpoint, extra, &creds)
.wake_compute()
.await?;
let node = api.wake_compute(extra, creds).await?;
(node, payload)
}
Postgres(endpoint, creds) if creds.project.is_none() => {
Postgres(api, creds) if creds.project.is_none() => {
let payload = fetch_magic_payload(client).await?;
let mut creds = creds.as_ref();
creds.project = Some(payload.project.as_str().into());
let node = postgres::Api::new(endpoint, &creds).wake_compute().await?;
let node = api.wake_compute(extra, &creds).await?;
(node, payload)
}
@@ -220,11 +186,11 @@ impl BackendType<'_, ClientCredentials<'_>> {
}
/// Authenticate the client via the requested backend, possibly using credentials.
pub async fn authenticate(
mut self,
extra: &ConsoleReqExtra<'_>,
client: &mut stream::PqStream<impl AsyncRead + AsyncWrite + Unpin + Send>,
) -> auth::Result<AuthSuccess<NodeInfo>> {
pub async fn authenticate<'a>(
&mut self,
extra: &'a ConsoleReqExtra<'a>,
client: &'a mut stream::PqStream<impl AsyncRead + AsyncWrite + Unpin>,
) -> auth::Result<AuthSuccess<CachedNodeInfo>> {
use BackendType::*;
// Handle cases when `project` is missing in `creds`.
@@ -235,7 +201,7 @@ impl BackendType<'_, ClientCredentials<'_>> {
}
let res = match self {
Console(endpoint, creds) => {
Console(api, creds) => {
info!(
user = creds.user,
project = creds.project(),
@@ -243,26 +209,40 @@ impl BackendType<'_, ClientCredentials<'_>> {
);
assert!(creds.project.is_some());
console::Api::new(&endpoint, extra, &creds)
.handle_user(client)
.await?
classic::handle_user(api.as_ref(), extra, creds, client).await?
}
Postgres(endpoint, creds) => {
Postgres(api, creds) => {
info!("performing mock authentication using a local postgres instance");
assert!(creds.project.is_some());
postgres::Api::new(&endpoint, &creds)
.handle_user(client)
.await?
classic::handle_user(api.as_ref(), extra, creds, client).await?
}
// NOTE: this auth backend doesn't use client credentials.
Link(url) => {
info!("performing link authentication");
link::handle_user(&url, client).await?
link::handle_user(url, client)
.await?
.map(CachedNodeInfo::new_uncached)
}
};
info!("user successfully authenticated");
Ok(res)
}
/// When applicable, wake the compute node, gaining its connection info in the process.
/// The link auth flow doesn't support this, so we return [`None`] in that case.
pub async fn wake_compute<'a>(
&self,
extra: &'a ConsoleReqExtra<'a>,
) -> Result<Option<CachedNodeInfo>, console::errors::WakeComputeError> {
use BackendType::*;
match self {
Console(api, creds) => api.wake_compute(extra, creds).map_ok(Some).await,
Postgres(api, creds) => api.wake_compute(extra, creds).map_ok(Some).await,
Link(_) => Ok(None),
}
}
}


@@ -0,0 +1,61 @@
use super::AuthSuccess;
use crate::{
auth::{self, AuthFlow, ClientCredentials},
compute,
console::{self, AuthInfo, CachedNodeInfo, ConsoleReqExtra},
sasl, scram,
stream::PqStream,
};
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::info;
pub(super) async fn handle_user(
api: &impl console::Api,
extra: &ConsoleReqExtra<'_>,
creds: &ClientCredentials<'_>,
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
) -> auth::Result<AuthSuccess<CachedNodeInfo>> {
info!("fetching user's authentication info");
let info = api.get_auth_info(extra, creds).await?.unwrap_or_else(|| {
// If we don't have an authentication secret, we mock one to
// prevent malicious probing (possible due to missing protocol steps).
// This mocked secret will never lead to successful authentication.
info!("authentication info not found, mocking it");
AuthInfo::Scram(scram::ServerSecret::mock(creds.user, rand::random()))
});
let flow = AuthFlow::new(client);
let scram_keys = match info {
AuthInfo::Md5(_) => {
info!("auth endpoint chooses MD5");
return Err(auth::AuthError::bad_auth_method("MD5"));
}
AuthInfo::Scram(secret) => {
info!("auth endpoint chooses SCRAM");
let scram = auth::Scram(&secret);
let client_key = match flow.begin(scram).await?.authenticate().await? {
sasl::Outcome::Success(key) => key,
sasl::Outcome::Failure(reason) => {
info!("auth backend failed with an error: {reason}");
return Err(auth::AuthError::auth_failed(creds.user));
}
};
Some(compute::ScramKeys {
client_key: client_key.as_bytes(),
server_key: secret.server_key.as_bytes(),
})
}
};
let mut node = api.wake_compute(extra, creds).await?;
if let Some(keys) = scram_keys {
use tokio_postgres::config::AuthKeys;
node.config.auth_keys(AuthKeys::ScramSha256(keys));
}
Ok(AuthSuccess {
reported_auth_ok: false,
value: node,
})
}


@@ -1,365 +0,0 @@
//! Cloud API V2.
use super::{AuthSuccess, ConsoleReqExtra, NodeInfo};
use crate::{
auth::{self, AuthFlow, ClientCredentials},
compute,
console::messages::{ConsoleError, GetRoleSecret, WakeCompute},
error::{io_error, UserFacingError},
http, sasl, scram,
stream::PqStream,
};
use futures::TryFutureExt;
use reqwest::StatusCode as HttpStatusCode;
use std::future::Future;
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::{error, info, info_span, warn, Instrument};
/// A go-to error message which doesn't leak any detail.
const REQUEST_FAILED: &str = "Console request failed";
/// Common console API error.
#[derive(Debug, Error)]
pub enum ApiError {
/// Error returned by the console itself.
#[error("{REQUEST_FAILED} with {}: {}", .status, .text)]
Console {
status: HttpStatusCode,
text: Box<str>,
},
/// Various IO errors like broken pipe or malformed payload.
#[error("{REQUEST_FAILED}: {0}")]
Transport(#[from] std::io::Error),
}
impl ApiError {
/// Returns HTTP status code if it's the reason for failure.
fn http_status_code(&self) -> Option<HttpStatusCode> {
use ApiError::*;
match self {
Console { status, .. } => Some(*status),
_ => None,
}
}
}
impl UserFacingError for ApiError {
fn to_string_client(&self) -> String {
use ApiError::*;
match self {
// To minimize risks, only select errors are forwarded to users.
// Ask @neondatabase/control-plane for review before adding more.
Console { status, .. } => match *status {
HttpStatusCode::NOT_FOUND => {
// Status 404: failed to get a project-related resource.
format!("{REQUEST_FAILED}: endpoint cannot be found")
}
HttpStatusCode::NOT_ACCEPTABLE => {
// Status 406: endpoint is disabled (we don't allow connections).
format!("{REQUEST_FAILED}: endpoint is disabled")
}
HttpStatusCode::LOCKED => {
// Status 423: project might be in maintenance mode (or bad state).
format!("{REQUEST_FAILED}: endpoint is temporary unavailable")
}
_ => REQUEST_FAILED.to_owned(),
},
_ => REQUEST_FAILED.to_owned(),
}
}
}
// Helps eliminate graceless `.map_err` calls without introducing another ctor.
impl From<reqwest::Error> for ApiError {
fn from(e: reqwest::Error) -> Self {
io_error(e).into()
}
}
#[derive(Debug, Error)]
pub enum GetAuthInfoError {
// We shouldn't include the actual secret here.
#[error("Console responded with a malformed auth secret")]
BadSecret,
#[error(transparent)]
ApiError(ApiError),
}
// This allows more useful interactions than `#[from]`.
impl<E: Into<ApiError>> From<E> for GetAuthInfoError {
fn from(e: E) -> Self {
Self::ApiError(e.into())
}
}
impl UserFacingError for GetAuthInfoError {
fn to_string_client(&self) -> String {
use GetAuthInfoError::*;
match self {
// We absolutely should not leak any secrets!
BadSecret => REQUEST_FAILED.to_owned(),
// However, API might return a meaningful error.
ApiError(e) => e.to_string_client(),
}
}
}
#[derive(Debug, Error)]
pub enum WakeComputeError {
#[error("Console responded with a malformed compute address: {0}")]
BadComputeAddress(Box<str>),
#[error(transparent)]
ApiError(ApiError),
}
// This allows more useful interactions than `#[from]`.
impl<E: Into<ApiError>> From<E> for WakeComputeError {
fn from(e: E) -> Self {
Self::ApiError(e.into())
}
}
impl UserFacingError for WakeComputeError {
fn to_string_client(&self) -> String {
use WakeComputeError::*;
match self {
// We shouldn't show user the address even if it's broken.
// Besides, user is unlikely to care about this detail.
BadComputeAddress(_) => REQUEST_FAILED.to_owned(),
// However, API might return a meaningful error.
ApiError(e) => e.to_string_client(),
}
}
}
/// Auth secret which is managed by the cloud.
pub enum AuthInfo {
/// Md5 hash of user's password.
Md5([u8; 16]),
/// [SCRAM](crate::scram) authentication info.
Scram(scram::ServerSecret),
}
#[must_use]
pub(super) struct Api<'a> {
endpoint: &'a http::Endpoint,
extra: &'a ConsoleReqExtra<'a>,
creds: &'a ClientCredentials<'a>,
}
impl<'a> AsRef<ClientCredentials<'a>> for Api<'a> {
fn as_ref(&self) -> &ClientCredentials<'a> {
self.creds
}
}
impl<'a> Api<'a> {
/// Construct an API object containing the auth parameters.
pub(super) fn new(
endpoint: &'a http::Endpoint,
extra: &'a ConsoleReqExtra<'a>,
creds: &'a ClientCredentials,
) -> Self {
Self {
endpoint,
extra,
creds,
}
}
/// Authenticate the existing user or throw an error.
pub(super) async fn handle_user(
&'a self,
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin + Send>,
) -> auth::Result<AuthSuccess<NodeInfo>> {
handle_user(client, self, Self::get_auth_info, Self::wake_compute).await
}
}
impl Api<'_> {
async fn get_auth_info(&self) -> Result<Option<AuthInfo>, GetAuthInfoError> {
let request_id = uuid::Uuid::new_v4().to_string();
async {
let request = self
.endpoint
.get("proxy_get_role_secret")
.header("X-Request-ID", &request_id)
.query(&[("session_id", self.extra.session_id)])
.query(&[
("application_name", self.extra.application_name),
("project", Some(self.creds.project().expect("impossible"))),
("role", Some(self.creds.user)),
])
.build()?;
info!(url = request.url().as_str(), "sending http request");
let response = self.endpoint.execute(request).await?;
let body = match parse_body::<GetRoleSecret>(response).await {
Ok(body) => body,
// Error 404 is special: it's ok not to have a secret.
Err(e) => match e.http_status_code() {
Some(HttpStatusCode::NOT_FOUND) => return Ok(None),
_otherwise => return Err(e.into()),
},
};
let secret = scram::ServerSecret::parse(&body.role_secret)
.map(AuthInfo::Scram)
.ok_or(GetAuthInfoError::BadSecret)?;
Ok(Some(secret))
}
.map_err(crate::error::log_error)
.instrument(info_span!("get_auth_info", id = request_id))
.await
}
/// Wake up the compute node and return the corresponding connection info.
pub async fn wake_compute(&self) -> Result<NodeInfo, WakeComputeError> {
let request_id = uuid::Uuid::new_v4().to_string();
async {
let request = self
.endpoint
.get("proxy_wake_compute")
.header("X-Request-ID", &request_id)
.query(&[("session_id", self.extra.session_id)])
.query(&[
("application_name", self.extra.application_name),
("project", Some(self.creds.project().expect("impossible"))),
])
.build()?;
info!(url = request.url().as_str(), "sending http request");
let response = self.endpoint.execute(request).await?;
let body = parse_body::<WakeCompute>(response).await?;
// Unfortunately, ownership won't let us use `Option::ok_or` here.
let (host, port) = match parse_host_port(&body.address) {
None => return Err(WakeComputeError::BadComputeAddress(body.address)),
Some(x) => x,
};
let mut config = compute::ConnCfg::new();
config
.host(host)
.port(port)
.dbname(self.creds.dbname)
.user(self.creds.user);
Ok(NodeInfo {
config,
aux: body.aux,
})
}
.map_err(crate::error::log_error)
.instrument(info_span!("wake_compute", id = request_id))
.await
}
}
/// Common logic for user handling in API V2.
/// We reuse this for a mock API implementation in [`super::postgres`].
pub(super) async fn handle_user<'a, Endpoint, GetAuthInfo, WakeCompute>(
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
endpoint: &'a Endpoint,
get_auth_info: impl FnOnce(&'a Endpoint) -> GetAuthInfo,
wake_compute: impl FnOnce(&'a Endpoint) -> WakeCompute,
) -> auth::Result<AuthSuccess<NodeInfo>>
where
Endpoint: AsRef<ClientCredentials<'a>>,
GetAuthInfo: Future<Output = Result<Option<AuthInfo>, GetAuthInfoError>>,
WakeCompute: Future<Output = Result<NodeInfo, WakeComputeError>>,
{
let creds = endpoint.as_ref();
info!("fetching user's authentication info");
let info = get_auth_info(endpoint).await?.unwrap_or_else(|| {
// If we don't have an authentication secret, we mock one to
// prevent malicious probing (possible due to missing protocol steps).
// This mocked secret will never lead to successful authentication.
info!("authentication info not found, mocking it");
AuthInfo::Scram(scram::ServerSecret::mock(creds.user, rand::random()))
});
let flow = AuthFlow::new(client);
let scram_keys = match info {
AuthInfo::Md5(_) => {
info!("auth endpoint chooses MD5");
return Err(auth::AuthError::bad_auth_method("MD5"));
}
AuthInfo::Scram(secret) => {
info!("auth endpoint chooses SCRAM");
let scram = auth::Scram(&secret);
let client_key = match flow.begin(scram).await?.authenticate().await? {
sasl::Outcome::Success(key) => key,
sasl::Outcome::Failure(reason) => {
info!("auth backend failed with an error: {reason}");
return Err(auth::AuthError::auth_failed(creds.user));
}
};
Some(compute::ScramKeys {
client_key: client_key.as_bytes(),
server_key: secret.server_key.as_bytes(),
})
}
};
let mut node = wake_compute(endpoint).await?;
if let Some(keys) = scram_keys {
use tokio_postgres::config::AuthKeys;
node.config.auth_keys(AuthKeys::ScramSha256(keys));
}
Ok(AuthSuccess {
reported_auth_ok: false,
value: node,
})
}
/// Parse http response body, taking status code into account.
async fn parse_body<T: for<'a> serde::Deserialize<'a>>(
response: reqwest::Response,
) -> Result<T, ApiError> {
let status = response.status();
if status.is_success() {
// We shouldn't log raw body because it may contain secrets.
info!("request succeeded, processing the body");
return Ok(response.json().await?);
}
// Don't throw an error here because it's not as important
// as the fact that the request itself has failed.
let body = response.json().await.unwrap_or_else(|e| {
warn!("failed to parse error body: {e}");
ConsoleError {
error: "reason unclear (malformed error message)".into(),
}
});
let text = body.error;
error!("console responded with an error ({status}): {text}");
Err(ApiError::Console { status, text })
}
fn parse_host_port(input: &str) -> Option<(&str, u16)> {
let (host, port) = input.split_once(':')?;
Some((host, port.parse().ok()?))
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_parse_host_port() {
let (host, port) = parse_host_port("127.0.0.1:5432").expect("failed to parse");
assert_eq!(host, "127.0.0.1");
assert_eq!(port, 5432);
}
}


@@ -1,5 +1,11 @@
use super::{AuthSuccess, NodeInfo};
use crate::{auth, compute, error::UserFacingError, stream::PqStream, waiters};
use super::AuthSuccess;
use crate::{
auth, compute,
console::{self, provider::NodeInfo},
error::UserFacingError,
stream::PqStream,
waiters,
};
use pq_proto::BeMessage as Be;
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncWrite};
@@ -47,7 +53,7 @@ pub fn new_psql_session_id() -> String {
hex::encode(rand::random::<[u8; 8]>())
}
pub async fn handle_user(
pub(super) async fn handle_user(
link_uri: &reqwest::Url,
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
) -> auth::Result<AuthSuccess<NodeInfo>> {
@@ -55,7 +61,7 @@ pub async fn handle_user(
let span = info_span!("link", psql_session_id = &psql_session_id);
let greeting = hello_message(link_uri, &psql_session_id);
let db_info = super::with_waiter(psql_session_id, |waiter| async {
let db_info = console::mgmt::with_waiter(psql_session_id, |waiter| async {
// Give user a URL to spawn a new database.
info!(parent: &span, "sending the auth URL to the user");
client
@@ -80,14 +86,14 @@ pub async fn handle_user(
.user(&db_info.user);
if let Some(password) = db_info.password {
config.password(password);
config.password(password.as_ref());
}
Ok(AuthSuccess {
reported_auth_ok: true,
value: NodeInfo {
config,
aux: db_info.aux,
aux: db_info.aux.into(),
},
})
}


@@ -33,6 +33,7 @@ impl UserFacingError for ClientCredsParseError {}
pub struct ClientCredentials<'a> {
pub user: &'a str,
pub dbname: &'a str,
// TODO: this is a severe misnomer! We should think of a new name ASAP.
pub project: Option<Cow<'a, str>>,
/// If `True`, we'll use the old cleartext password flow. This is used for
/// websocket connections, which want to minimize the number of round trips.

proxy/src/cache.rs (new file)

@@ -0,0 +1,304 @@
use std::{
borrow::Borrow,
hash::Hash,
ops::{Deref, DerefMut},
time::{Duration, Instant},
};
use tracing::debug;
// This seems to make more sense than `lru` or `cached`:
//
// * `near/nearcore` ditched `cached` in favor of `lru`
// (https://github.com/near/nearcore/issues?q=is%3Aissue+lru+is%3Aclosed).
//
// * `lru` methods use an obscure `KeyRef` type in their constraints (which is deliberately excluded from docs).
// This severely hinders its usage both in terms of creating wrappers and supported key types.
//
// On the other hand, `hashlink` has good download stats and appears to be maintained.
use hashlink::{linked_hash_map::RawEntryMut, LruCache};
/// A generic trait which exposes types of cache's key and value,
/// as well as the notion of cache entry invalidation.
/// This is useful for [`timed_lru::Cached`].
pub trait Cache {
/// Entry's key.
type Key;
/// Entry's value.
type Value;
/// Used for entry invalidation.
type LookupInfo<Key>;
/// Invalidate an entry using a lookup info.
/// We don't have an empty default impl because it's error-prone.
fn invalidate(&self, _: &Self::LookupInfo<Self::Key>);
}
impl<C: Cache> Cache for &C {
type Key = C::Key;
type Value = C::Value;
type LookupInfo<Key> = C::LookupInfo<Key>;
fn invalidate(&self, info: &Self::LookupInfo<Self::Key>) {
C::invalidate(self, info)
}
}
pub use timed_lru::TimedLru;
pub mod timed_lru {
use super::*;
/// An implementation of timed LRU cache with fixed capacity.
/// Key properties:
///
/// * Whenever a new entry is inserted, the least recently accessed one is evicted.
/// The cache also keeps track of entry's insertion time (`created_at`) and TTL (`expires_at`).
///
/// * When the entry is about to be retrieved, we check its expiration timestamp.
/// If the entry has expired, we remove it from the cache; otherwise we bump the
/// expiration timestamp (e.g. +5mins) and change its place in LRU list to prolong
/// its existence.
///
/// * There's an API for immediate invalidation (removal) of a cache entry;
/// it's useful in case we know for sure that the entry is no longer correct.
/// See [`timed_lru::LookupInfo`] & [`timed_lru::Cached`] for more information.
///
/// * Expired entries are kept in the cache, until they are evicted by the LRU policy,
/// or by a successful lookup (i.e. the entry hasn't expired yet).
/// There is no background job to reap the expired records.
///
/// * It's possible for an entry that has not yet expired to be evicted
/// before expired items. That's a bit wasteful, but probably fine in practice.
pub struct TimedLru<K, V> {
/// Cache's name for tracing.
name: &'static str,
/// The underlying cache implementation.
cache: parking_lot::Mutex<LruCache<K, Entry<V>>>,
/// Default time-to-live of a single entry.
ttl: Duration,
}
impl<K: Hash + Eq, V> Cache for TimedLru<K, V> {
type Key = K;
type Value = V;
type LookupInfo<Key> = LookupInfo<Key>;
fn invalidate(&self, info: &Self::LookupInfo<K>) {
self.invalidate_raw(info)
}
}
struct Entry<T> {
created_at: Instant,
expires_at: Instant,
value: T,
}
impl<K: Hash + Eq, V> TimedLru<K, V> {
/// Construct a new LRU cache with timed entries.
pub fn new(name: &'static str, capacity: usize, ttl: Duration) -> Self {
Self {
name,
cache: LruCache::new(capacity).into(),
ttl,
}
}
/// Drop an entry from the cache if it's outdated.
#[tracing::instrument(level = "debug", fields(cache = self.name), skip_all)]
fn invalidate_raw(&self, info: &LookupInfo<K>) {
let now = Instant::now();
// Do costly things before taking the lock.
let mut cache = self.cache.lock();
let raw_entry = match cache.raw_entry_mut().from_key(&info.key) {
RawEntryMut::Vacant(_) => return,
RawEntryMut::Occupied(x) => x,
};
// Remove the entry if it was created prior to lookup timestamp.
let entry = raw_entry.get();
let (created_at, expires_at) = (entry.created_at, entry.expires_at);
let should_remove = created_at <= info.created_at || expires_at <= now;
if should_remove {
raw_entry.remove();
}
drop(cache); // drop lock before logging
debug!(
created_at = format_args!("{created_at:?}"),
expires_at = format_args!("{expires_at:?}"),
entry_removed = should_remove,
"processed a cache entry invalidation event"
);
}
/// Try retrieving an entry by its key, then execute `extract` if it exists.
#[tracing::instrument(level = "debug", fields(cache = self.name), skip_all)]
fn get_raw<Q, R>(&self, key: &Q, extract: impl FnOnce(&K, &Entry<V>) -> R) -> Option<R>
where
K: Borrow<Q>,
Q: Hash + Eq + ?Sized,
{
let now = Instant::now();
let deadline = now.checked_add(self.ttl).expect("time overflow");
// Do costly things before taking the lock.
let mut cache = self.cache.lock();
let mut raw_entry = match cache.raw_entry_mut().from_key(key) {
RawEntryMut::Vacant(_) => return None,
RawEntryMut::Occupied(x) => x,
};
// Immediately drop the entry if it has expired.
let entry = raw_entry.get();
if entry.expires_at <= now {
raw_entry.remove();
return None;
}
let value = extract(raw_entry.key(), entry);
let (created_at, expires_at) = (entry.created_at, entry.expires_at);
// Update the deadline and the entry's position in the LRU list.
raw_entry.get_mut().expires_at = deadline;
raw_entry.to_back();
drop(cache); // drop lock before logging
debug!(
created_at = format_args!("{created_at:?}"),
old_expires_at = format_args!("{expires_at:?}"),
new_expires_at = format_args!("{deadline:?}"),
"accessed a cache entry"
);
Some(value)
}
/// Insert an entry into the cache. If an entry with the same key already
/// existed, return the previous value and its creation timestamp.
#[tracing::instrument(level = "debug", fields(cache = self.name), skip_all)]
fn insert_raw(&self, key: K, value: V) -> (Instant, Option<V>) {
let created_at = Instant::now();
let expires_at = created_at.checked_add(self.ttl).expect("time overflow");
let entry = Entry {
created_at,
expires_at,
value,
};
// Do costly things before taking the lock.
let old = self
.cache
.lock()
.insert(key, entry)
.map(|entry| entry.value);
debug!(
created_at = format_args!("{created_at:?}"),
expires_at = format_args!("{expires_at:?}"),
replaced = old.is_some(),
"created a cache entry"
);
(created_at, old)
}
}
impl<K: Hash + Eq + Clone, V: Clone> TimedLru<K, V> {
pub fn insert(&self, key: K, value: V) -> (Option<V>, Cached<&Self>) {
let (created_at, old) = self.insert_raw(key.clone(), value.clone());
let cached = Cached {
token: Some((self, LookupInfo { created_at, key })),
value,
};
(old, cached)
}
}
impl<K: Hash + Eq, V: Clone> TimedLru<K, V> {
/// Retrieve a cached entry in a convenient wrapper.
pub fn get<Q>(&self, key: &Q) -> Option<timed_lru::Cached<&Self>>
where
K: Borrow<Q> + Clone,
Q: Hash + Eq + ?Sized,
{
self.get_raw(key, |key, entry| {
let info = LookupInfo {
created_at: entry.created_at,
key: key.clone(),
};
Cached {
token: Some((self, info)),
value: entry.value.clone(),
}
})
}
}
/// Lookup information for key invalidation.
pub struct LookupInfo<K> {
/// Time of creation of a cache [`Entry`].
/// We use this during invalidation lookups to prevent eviction of a newer
/// entry sharing the same key (it might've been inserted by a different
/// task after we got the entry we're trying to invalidate now).
created_at: Instant,
/// Search by this key.
key: K,
}
/// Wrapper for convenient entry invalidation.
pub struct Cached<C: Cache> {
/// Cache + lookup info.
token: Option<(C, C::LookupInfo<C::Key>)>,
/// The value itself.
pub value: C::Value,
}
impl<C: Cache> Cached<C> {
/// Place any entry into this wrapper; invalidation will be a no-op.
/// Unfortunately, Rust doesn't let us implement [`From`] or [`Into`].
pub fn new_uncached(value: impl Into<C::Value>) -> Self {
Self {
token: None,
value: value.into(),
}
}
/// Drop this entry from a cache if it's still there.
pub fn invalidate(&self) {
if let Some((cache, info)) = &self.token {
cache.invalidate(info);
}
}
/// Tell if this entry is actually cached.
pub fn cached(&self) -> bool {
self.token.is_some()
}
}
impl<C: Cache> Deref for Cached<C> {
type Target = C::Value;
fn deref(&self) -> &Self::Target {
&self.value
}
}
impl<C: Cache> DerefMut for Cached<C> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.value
}
}
}
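
For reference, here's a minimal usage sketch of the cache above (key/value types simplified;
the real node info cache uses `Arc<str>` keys and `NodeInfo` values):

    use std::time::Duration;

    let cache: TimedLru<String, u32> = TimedLru::new("demo", 128, Duration::from_secs(300));

    // `insert` returns the previous value (if any) plus a `Cached` handle.
    let (_old, handle) = cache.insert("project".to_owned(), 42);
    assert!(handle.cached());

    // A successful lookup bumps both the TTL and the LRU position.
    let hit = cache.get("project").expect("entry is still alive");
    assert_eq!(*hit, 42);

    // If the value turns out to be wrong (e.g. the compute is gone), drop it.
    // Thanks to `LookupInfo::created_at`, a newer entry under the same key survives.
    hit.invalidate();
    assert!(cache.get("project").is_none());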


@@ -1,6 +1,5 @@
use anyhow::{anyhow, Context};
use hashbrown::HashMap;
use parking_lot::Mutex;
use pq_proto::CancelKeyData;
use std::net::SocketAddr;
use tokio::net::TcpStream;
@@ -9,14 +8,15 @@ use tracing::info;
/// Enables serving `CancelRequest`s.
#[derive(Default)]
pub struct CancelMap(Mutex<HashMap<CancelKeyData, Option<CancelClosure>>>);
pub struct CancelMap(parking_lot::RwLock<HashMap<CancelKeyData, Option<CancelClosure>>>);
impl CancelMap {
/// Cancel a running query for the corresponding connection.
pub async fn cancel_session(&self, key: CancelKeyData) -> anyhow::Result<()> {
// NB: we should immediately release the lock after cloning the token.
let cancel_closure = self
.0
.lock()
.read()
.get(&key)
.and_then(|x| x.clone())
.with_context(|| format!("query cancellation key not found: {key}"))?;
@@ -41,14 +41,14 @@ impl CancelMap {
// Random key collisions are unlikely to happen here, but they're still possible,
// which is why we have to take care not to rewrite an existing key.
self.0
.lock()
.write()
.try_insert(key, None)
.map_err(|_| anyhow!("query cancellation key already exists: {key}"))?;
// This will guarantee that the session gets dropped
// as soon as the future is finished.
scopeguard::defer! {
self.0.lock().remove(&key);
self.0.write().remove(&key);
info!("dropped query cancellation key {key}");
}
@@ -59,12 +59,12 @@ impl CancelMap {
#[cfg(test)]
fn contains(&self, session: &Session) -> bool {
self.0.lock().contains_key(&session.key)
self.0.read().contains_key(&session.key)
}
#[cfg(test)]
fn is_empty(&self) -> bool {
self.0.lock().is_empty()
self.0.read().is_empty()
}
}
@@ -115,7 +115,7 @@ impl Session<'_> {
info!("enabling query cancellation for this session");
self.cancel_map
.0
.lock()
.write()
.insert(self.key, Some(cancel_closure));
self.key


@@ -42,14 +42,65 @@ pub type ScramKeys = tokio_postgres::config::ScramKeys<32>;
/// A config for establishing a connection to compute node.
/// Eventually, `tokio_postgres` will be replaced with something better.
/// Newtype allows us to implement methods on top of it.
#[derive(Clone)]
#[repr(transparent)]
pub struct ConnCfg(Box<tokio_postgres::Config>);
/// Creation and initialization routines.
impl ConnCfg {
/// Construct a new connection config.
pub fn new() -> Self {
Self(Default::default())
}
/// Reuse password or auth keys from the other config.
pub fn reuse_password(&mut self, other: &Self) {
if let Some(password) = other.get_password() {
self.password(password);
}
if let Some(keys) = other.get_auth_keys() {
self.auth_keys(keys);
}
}
/// Apply startup message params to the connection config.
pub fn set_startup_params(&mut self, params: &StartupMessageParams) {
if let Some(options) = params.options_raw() {
// We must drop all proxy-specific parameters.
#[allow(unstable_name_collisions)]
let options: String = options
.filter(|opt| !opt.starts_with("project="))
.intersperse(" ") // TODO: use impl from std once it's stabilized
.collect();
self.options(&options);
}
if let Some(app_name) = params.get("application_name") {
self.application_name(app_name);
}
// TODO: This is especially ugly...
if let Some(replication) = params.get("replication") {
use tokio_postgres::config::ReplicationMode;
match replication {
"true" | "on" | "yes" | "1" => {
self.replication_mode(ReplicationMode::Physical);
}
"database" => {
self.replication_mode(ReplicationMode::Logical);
}
_other => {}
}
}
// TODO: extend the list of the forwarded startup parameters.
// Currently, tokio-postgres doesn't allow us to pass
// arbitrary parameters, but the ones above are a good start.
//
// This and the reverse params problem can be better addressed
// in a bespoke connection machinery (a new library for that sake).
}
}
impl std::ops::Deref for ConnCfg {
@@ -132,50 +183,13 @@ pub struct PostgresConnection {
pub stream: TcpStream,
/// PostgreSQL connection parameters.
pub params: std::collections::HashMap<String, String>,
/// Query cancellation token.
pub cancel_closure: CancelClosure,
}
impl ConnCfg {
/// Connect to a corresponding compute node.
pub async fn connect(
mut self,
params: &StartupMessageParams,
) -> Result<(PostgresConnection, CancelClosure), ConnectionError> {
if let Some(options) = params.options_raw() {
// We must drop all proxy-specific parameters.
#[allow(unstable_name_collisions)]
let options: String = options
.filter(|opt| !opt.starts_with("project="))
.intersperse(" ") // TODO: use impl from std once it's stabilized
.collect();
self.0.options(&options);
}
if let Some(app_name) = params.get("application_name") {
self.0.application_name(app_name);
}
// TODO: This is especially ugly...
if let Some(replication) = params.get("replication") {
use tokio_postgres::config::ReplicationMode;
match replication {
"true" | "on" | "yes" | "1" => {
self.0.replication_mode(ReplicationMode::Physical);
}
"database" => {
self.0.replication_mode(ReplicationMode::Logical);
}
_other => {}
}
}
// TODO: extend the list of the forwarded startup parameters.
// Currently, tokio-postgres doesn't allow us to pass
// arbitrary parameters, but the ones above are a good start.
//
// This and the reverse params problem can be better addressed
// in a bespoke connection machinery (a new library for that sake).
pub async fn connect(&self) -> Result<PostgresConnection, ConnectionError> {
// TODO: establish a secure connection to the DB.
let (socket_addr, mut stream) = self.connect_raw().await?;
let (client, connection) = self.0.connect_raw(&mut stream, NoTls).await?;
@@ -189,8 +203,13 @@ impl ConnCfg {
// NB: CancelToken is supposed to hold socket_addr, but we use connect_raw.
// Yet another reason to rework the connection establishing code.
let cancel_closure = CancelClosure::new(socket_addr, client.cancel_token());
let db = PostgresConnection { stream, params };
Ok((db, cancel_closure))
let connection = PostgresConnection {
stream,
params,
cancel_closure,
};
Ok(connection)
}
}


@@ -1,16 +1,16 @@
use crate::auth;
use anyhow::{ensure, Context};
use std::sync::Arc;
use anyhow::{bail, ensure, Context};
use std::{str::FromStr, sync::Arc, time::Duration};
pub struct ProxyConfig {
pub tls_config: Option<TlsConfig>,
pub auth_backend: auth::BackendType<'static, ()>,
pub metric_collection_config: Option<MetricCollectionConfig>,
pub metric_collection: Option<MetricCollectionConfig>,
}
pub struct MetricCollectionConfig {
pub endpoint: reqwest::Url,
pub interval: std::time::Duration,
pub interval: Duration,
}
pub struct TlsConfig {
@@ -37,6 +37,7 @@ pub fn configure_tls(key_path: &str, cert_path: &str) -> anyhow::Result<TlsConfi
let cert_chain_bytes = std::fs::read(cert_path)
.context(format!("Failed to read TLS cert file at '{cert_path}.'"))?;
let cert_chain = {
rustls_pemfile::certs(&mut &cert_chain_bytes[..])
.context(format!(
@@ -73,3 +74,80 @@ pub fn configure_tls(key_path: &str, cert_path: &str) -> anyhow::Result<TlsConfi
common_name,
})
}
/// Helper for cmdline cache options parsing.
pub struct CacheOptions {
/// Max number of entries.
pub size: usize,
/// Entry's time-to-live.
pub ttl: Duration,
}
impl CacheOptions {
/// Default options for [`crate::console::caches::NodeInfoCache`].
pub const DEFAULT_OPTIONS_NODE_INFO: &str = "size=4000,ttl=5m";
/// Parse cache options passed via cmdline.
/// Example: [`Self::DEFAULT_OPTIONS_NODE_INFO`].
fn parse(options: &str) -> anyhow::Result<Self> {
let mut size = None;
let mut ttl = None;
for option in options.split(',') {
let (key, value) = option
.split_once('=')
.with_context(|| format!("bad key-value pair: {option}"))?;
match key {
"size" => size = Some(value.parse()?),
"ttl" => ttl = Some(humantime::parse_duration(value)?),
unknown => bail!("unknown key: {unknown}"),
}
}
// TTL doesn't matter if cache is always empty.
if let Some(0) = size {
ttl.get_or_insert(Duration::default());
}
Ok(Self {
size: size.context("missing `size`")?,
ttl: ttl.context("missing `ttl`")?,
})
}
}
impl FromStr for CacheOptions {
type Err = anyhow::Error;
fn from_str(options: &str) -> Result<Self, Self::Err> {
let error = || format!("failed to parse cache options '{options}'");
Self::parse(options).with_context(error)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_parse_cache_options() -> anyhow::Result<()> {
let CacheOptions { size, ttl } = "size=4096,ttl=5min".parse()?;
assert_eq!(size, 4096);
assert_eq!(ttl, Duration::from_secs(5 * 60));
let CacheOptions { size, ttl } = "ttl=4m,size=2".parse()?;
assert_eq!(size, 2);
assert_eq!(ttl, Duration::from_secs(4 * 60));
let CacheOptions { size, ttl } = "size=0,ttl=1s".parse()?;
assert_eq!(size, 0);
assert_eq!(ttl, Duration::from_secs(1));
let CacheOptions { size, ttl } = "size=0".parse()?;
assert_eq!(size, 0);
assert_eq!(ttl, Duration::default());
Ok(())
}
}
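
Wiring this into the cache construction is then a one-liner (a sketch; the cache name and
the clap flag that supplies the option string are illustrative):

    let CacheOptions { size, ttl } = CacheOptions::DEFAULT_OPTIONS_NODE_INFO.parse()?;
    let node_info: NodeInfoCache = TimedLru::new("node_info_cache", size, ttl);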


@@ -3,3 +3,15 @@
/// Payloads used in the console's APIs.
pub mod messages;
/// Wrappers for console APIs and their mocks.
pub mod provider;
pub use provider::{errors, Api, AuthInfo, CachedNodeInfo, ConsoleReqExtra, NodeInfo};
/// Various cache-related types.
pub mod caches {
pub use super::provider::{ApiCaches, NodeInfoCache};
}
/// Console's management API.
pub mod mgmt;


@@ -63,13 +63,13 @@ impl KickSession<'_> {
/// Compute node connection params.
#[derive(Deserialize)]
pub struct DatabaseInfo {
pub host: String,
pub host: Box<str>,
pub port: u16,
pub dbname: String,
pub user: String,
pub dbname: Box<str>,
pub user: Box<str>,
/// Console always provides a password, but it might
/// be inconvenient for debugging with a local PG instance.
pub password: Option<String>,
pub password: Option<Box<str>>,
pub aux: MetricsAuxInfo,
}


@@ -1,8 +1,9 @@
use crate::{
auth,
console::messages::{DatabaseInfo, KickSession},
waiters::{self, Waiter, Waiters},
};
use anyhow::Context;
use once_cell::sync::Lazy;
use pq_proto::{BeMessage, SINGLE_COL_ROWDESC};
use std::{
net::{TcpListener, TcpStream},
@@ -14,6 +15,25 @@ use utils::{
postgres_backend_async::QueryError,
};
static CPLANE_WAITERS: Lazy<Waiters<ComputeReady>> = Lazy::new(Default::default);
/// Give the caller an opportunity to wait for the cloud's reply.
pub async fn with_waiter<R, T, E>(
psql_session_id: impl Into<String>,
action: impl FnOnce(Waiter<'static, ComputeReady>) -> R,
) -> Result<T, E>
where
R: std::future::Future<Output = Result<T, E>>,
E: From<waiters::RegisterError>,
{
let waiter = CPLANE_WAITERS.register(psql_session_id.into())?;
action(waiter).await
}
pub fn notify(psql_session_id: &str, msg: ComputeReady) -> Result<(), waiters::NotifyError> {
CPLANE_WAITERS.notify(psql_session_id, msg)
}
/// Console management API listener thread.
/// It spawns console response handlers needed for the link auth.
pub fn thread_main(listener: TcpListener) -> anyhow::Result<()> {
@@ -76,7 +96,7 @@ fn try_process_query(pgb: &mut PostgresBackend, query: &str) -> Result<(), Query
let _enter = span.enter();
info!("got response: {:?}", resp.result);
match auth::backend::notify(resp.session_id, Ok(resp.result)) {
match notify(resp.session_id, Ok(resp.result)) {
Ok(()) => {
pgb.write_message_noflush(&SINGLE_COL_ROWDESC)?
.write_message_noflush(&BeMessage::DataRow(&[Some(b"ok")]))?


@@ -0,0 +1,194 @@
pub mod mock;
pub mod neon;
use super::messages::MetricsAuxInfo;
use crate::{
auth::ClientCredentials,
cache::{timed_lru, TimedLru},
compute, scram,
};
use async_trait::async_trait;
use std::sync::Arc;
pub mod errors {
use crate::error::{io_error, UserFacingError};
use reqwest::StatusCode as HttpStatusCode;
use thiserror::Error;
/// A go-to error message which doesn't leak any detail.
const REQUEST_FAILED: &str = "Console request failed";
/// Common console API error.
#[derive(Debug, Error)]
pub enum ApiError {
/// Error returned by the console itself.
#[error("{REQUEST_FAILED} with {}: {}", .status, .text)]
Console {
status: HttpStatusCode,
text: Box<str>,
},
/// Various IO errors like broken pipe or malformed payload.
#[error("{REQUEST_FAILED}: {0}")]
Transport(#[from] std::io::Error),
}
impl ApiError {
/// Returns HTTP status code if it's the reason for failure.
pub fn http_status_code(&self) -> Option<HttpStatusCode> {
use ApiError::*;
match self {
Console { status, .. } => Some(*status),
_ => None,
}
}
}
impl UserFacingError for ApiError {
fn to_string_client(&self) -> String {
use ApiError::*;
match self {
// To minimize risks, only select errors are forwarded to users.
// Ask @neondatabase/control-plane for review before adding more.
Console { status, .. } => match *status {
HttpStatusCode::NOT_FOUND => {
// Status 404: failed to get a project-related resource.
format!("{REQUEST_FAILED}: endpoint cannot be found")
}
HttpStatusCode::NOT_ACCEPTABLE => {
// Status 406: endpoint is disabled (we don't allow connections).
format!("{REQUEST_FAILED}: endpoint is disabled")
}
HttpStatusCode::LOCKED => {
// Status 423: project might be in maintenance mode (or bad state).
format!("{REQUEST_FAILED}: endpoint is temporary unavailable")
}
_ => REQUEST_FAILED.to_owned(),
},
_ => REQUEST_FAILED.to_owned(),
}
}
}
// Helps eliminate graceless `.map_err` calls without introducing another ctor.
impl From<reqwest::Error> for ApiError {
fn from(e: reqwest::Error) -> Self {
io_error(e).into()
}
}
#[derive(Debug, Error)]
pub enum GetAuthInfoError {
// We shouldn't include the actual secret here.
#[error("Console responded with a malformed auth secret")]
BadSecret,
#[error(transparent)]
ApiError(ApiError),
}
// This allows more useful interactions than `#[from]`.
impl<E: Into<ApiError>> From<E> for GetAuthInfoError {
fn from(e: E) -> Self {
Self::ApiError(e.into())
}
}
impl UserFacingError for GetAuthInfoError {
fn to_string_client(&self) -> String {
use GetAuthInfoError::*;
match self {
// We absolutely should not leak any secrets!
BadSecret => REQUEST_FAILED.to_owned(),
// However, API might return a meaningful error.
ApiError(e) => e.to_string_client(),
}
}
}
#[derive(Debug, Error)]
pub enum WakeComputeError {
#[error("Console responded with a malformed compute address: {0}")]
BadComputeAddress(Box<str>),
#[error(transparent)]
ApiError(ApiError),
}
// This allows more useful interactions than `#[from]`.
impl<E: Into<ApiError>> From<E> for WakeComputeError {
fn from(e: E) -> Self {
Self::ApiError(e.into())
}
}
impl UserFacingError for WakeComputeError {
fn to_string_client(&self) -> String {
use WakeComputeError::*;
match self {
// We shouldn't show user the address even if it's broken.
// Besides, user is unlikely to care about this detail.
BadComputeAddress(_) => REQUEST_FAILED.to_owned(),
// However, API might return a meaningful error.
ApiError(e) => e.to_string_client(),
}
}
}
}
/// Extra query params we'd like to pass to the console.
pub struct ConsoleReqExtra<'a> {
/// A unique identifier for a connection.
pub session_id: uuid::Uuid,
/// Name of client application, if set.
pub application_name: Option<&'a str>,
}
/// Auth secret which is managed by the cloud.
pub enum AuthInfo {
/// Md5 hash of user's password.
Md5([u8; 16]),
/// [SCRAM](crate::scram) authentication info.
Scram(scram::ServerSecret),
}
/// Info for establishing a connection to a compute node.
/// This is what we get after auth succeeded, but not before!
#[derive(Clone)]
pub struct NodeInfo {
/// Compute node connection params.
/// It's sad that we have to clone this, but this will improve
/// once we migrate to a bespoke connection logic.
pub config: compute::ConnCfg,
/// Labels for proxy's metrics.
pub aux: Arc<MetricsAuxInfo>,
}
pub type NodeInfoCache = TimedLru<Arc<str>, NodeInfo>;
pub type CachedNodeInfo = timed_lru::Cached<&'static NodeInfoCache>;
/// This will allocate on each call, but the http requests alone
/// already require a few allocations, so it should be fine.
#[async_trait]
pub trait Api {
/// Get the client's auth secret for authentication.
async fn get_auth_info(
&self,
extra: &ConsoleReqExtra<'_>,
creds: &ClientCredentials<'_>,
) -> Result<Option<AuthInfo>, errors::GetAuthInfoError>;
/// Wake up the compute node and return the corresponding connection info.
async fn wake_compute(
&self,
extra: &ConsoleReqExtra<'_>,
creds: &ClientCredentials<'_>,
) -> Result<CachedNodeInfo, errors::WakeComputeError>;
}
/// Various caches for [`console`].
pub struct ApiCaches {
/// Cache for the `wake_compute` API method.
pub node_info: NodeInfoCache,
}


@@ -1,21 +1,14 @@
//! Local mock of Cloud API V2.
//! Mock console backend which relies on a user-provided postgres instance.
use super::{
console::{self, AuthInfo, GetAuthInfoError, WakeComputeError},
AuthSuccess, NodeInfo,
};
use crate::{
auth::{self, ClientCredentials},
compute,
error::io_error,
scram,
stream::PqStream,
url::ApiUrl,
errors::{ApiError, GetAuthInfoError, WakeComputeError},
AuthInfo, CachedNodeInfo, ConsoleReqExtra, NodeInfo,
};
use crate::{auth::ClientCredentials, compute, error::io_error, scram, url::ApiUrl};
use async_trait::async_trait;
use futures::TryFutureExt;
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::{info, info_span, warn, Instrument};
use tracing::{error, info, info_span, warn, Instrument};
#[derive(Debug, Error)]
enum MockApiError {
@@ -23,49 +16,36 @@ enum MockApiError {
PasswordNotSet(tokio_postgres::Error),
}
impl From<MockApiError> for console::ApiError {
impl From<MockApiError> for ApiError {
fn from(e: MockApiError) -> Self {
io_error(e).into()
}
}
impl From<tokio_postgres::Error> for console::ApiError {
impl From<tokio_postgres::Error> for ApiError {
fn from(e: tokio_postgres::Error) -> Self {
io_error(e).into()
}
}
#[must_use]
pub(super) struct Api<'a> {
endpoint: &'a ApiUrl,
creds: &'a ClientCredentials<'a>,
#[derive(Clone)]
pub struct Api {
endpoint: ApiUrl,
}
impl<'a> AsRef<ClientCredentials<'a>> for Api<'a> {
fn as_ref(&self) -> &ClientCredentials<'a> {
self.creds
}
}
impl<'a> Api<'a> {
/// Construct an API object containing the auth parameters.
pub(super) fn new(endpoint: &'a ApiUrl, creds: &'a ClientCredentials) -> Self {
Self { endpoint, creds }
impl Api {
pub fn new(endpoint: ApiUrl) -> Self {
Self { endpoint }
}
/// Authenticate the existing user or throw an error.
pub(super) async fn handle_user(
&'a self,
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin + Send>,
) -> auth::Result<AuthSuccess<NodeInfo>> {
// We reuse user handling logic from a production module.
console::handle_user(client, self, Self::get_auth_info, Self::wake_compute).await
pub fn url(&self) -> &str {
self.endpoint.as_str()
}
}
impl Api<'_> {
/// This implementation fetches the auth info from a local postgres instance.
async fn get_auth_info(&self) -> Result<Option<AuthInfo>, GetAuthInfoError> {
async fn do_get_auth_info(
&self,
creds: &ClientCredentials<'_>,
) -> Result<Option<AuthInfo>, GetAuthInfoError> {
async {
// Perhaps we could persist this connection, but then we'd have to
// write more code for reopening it if it got closed, which doesn't
@@ -75,7 +55,7 @@ impl Api<'_> {
tokio::spawn(connection);
let query = "select rolpassword from pg_catalog.pg_authid where rolname = $1";
let rows = client.query(query, &[&self.creds.user]).await?;
let rows = client.query(query, &[&creds.user]).await?;
// We can get at most one row, because `rolname` is unique.
let row = match rows.get(0) {
@@ -84,7 +64,7 @@ impl Api<'_> {
// However, this is still a *valid* outcome which is very similar
// to getting `404 Not found` from the Neon console.
None => {
warn!("user '{}' does not exist", self.creds.user);
warn!("user '{}' does not exist", creds.user);
return Ok(None);
}
};
@@ -98,23 +78,50 @@ impl Api<'_> {
Ok(secret.or_else(|| parse_md5(entry).map(AuthInfo::Md5)))
}
.map_err(crate::error::log_error)
.instrument(info_span!("get_auth_info", mock = self.endpoint.as_str()))
.instrument(info_span!("postgres", url = self.endpoint.as_str()))
.await
}
/// We don't need to wake anything locally, so we just return the connection info.
pub async fn wake_compute(&self) -> Result<NodeInfo, WakeComputeError> {
async fn do_wake_compute(
&self,
creds: &ClientCredentials<'_>,
) -> Result<NodeInfo, WakeComputeError> {
let mut config = compute::ConnCfg::new();
config
.host(self.endpoint.host_str().unwrap_or("localhost"))
.port(self.endpoint.port().unwrap_or(5432))
.dbname(self.creds.dbname)
.user(self.creds.user);
.dbname(creds.dbname)
.user(creds.user);
Ok(NodeInfo {
let node = NodeInfo {
config,
aux: Default::default(),
})
};
Ok(node)
}
}
#[async_trait]
impl super::Api for Api {
#[tracing::instrument(skip_all)]
async fn get_auth_info(
&self,
_extra: &ConsoleReqExtra<'_>,
creds: &ClientCredentials<'_>,
) -> Result<Option<AuthInfo>, GetAuthInfoError> {
self.do_get_auth_info(creds).await
}
#[tracing::instrument(skip_all)]
async fn wake_compute(
&self,
_extra: &ConsoleReqExtra<'_>,
creds: &ClientCredentials<'_>,
) -> Result<CachedNodeInfo, WakeComputeError> {
self.do_wake_compute(creds)
.map_ok(CachedNodeInfo::new_uncached)
.await
}
}


@@ -0,0 +1,196 @@
//! Production console backend.
use super::{
super::messages::{ConsoleError, GetRoleSecret, WakeCompute},
errors::{ApiError, GetAuthInfoError, WakeComputeError},
ApiCaches, AuthInfo, CachedNodeInfo, ConsoleReqExtra, NodeInfo,
};
use crate::{auth::ClientCredentials, compute, http, scram};
use async_trait::async_trait;
use futures::TryFutureExt;
use reqwest::StatusCode as HttpStatusCode;
use tracing::{error, info, info_span, warn, Instrument};
#[derive(Clone)]
pub struct Api {
endpoint: http::Endpoint,
caches: &'static ApiCaches,
}
impl Api {
/// Construct an API object containing the auth parameters.
pub fn new(endpoint: http::Endpoint, caches: &'static ApiCaches) -> Self {
Self { endpoint, caches }
}
pub fn url(&self) -> &str {
self.endpoint.url().as_str()
}
async fn do_get_auth_info(
&self,
extra: &ConsoleReqExtra<'_>,
creds: &ClientCredentials<'_>,
) -> Result<Option<AuthInfo>, GetAuthInfoError> {
let request_id = uuid::Uuid::new_v4().to_string();
async {
let request = self
.endpoint
.get("proxy_get_role_secret")
.header("X-Request-ID", &request_id)
.query(&[("session_id", extra.session_id)])
.query(&[
("application_name", extra.application_name),
("project", Some(creds.project().expect("impossible"))),
("role", Some(creds.user)),
])
.build()?;
info!(url = request.url().as_str(), "sending http request");
let response = self.endpoint.execute(request).await?;
let body = match parse_body::<GetRoleSecret>(response).await {
Ok(body) => body,
// Error 404 is special: it's ok not to have a secret.
Err(e) => match e.http_status_code() {
Some(HttpStatusCode::NOT_FOUND) => return Ok(None),
_otherwise => return Err(e.into()),
},
};
let secret = scram::ServerSecret::parse(&body.role_secret)
.map(AuthInfo::Scram)
.ok_or(GetAuthInfoError::BadSecret)?;
Ok(Some(secret))
}
.map_err(crate::error::log_error)
.instrument(info_span!("http", id = request_id))
.await
}
async fn do_wake_compute(
&self,
extra: &ConsoleReqExtra<'_>,
creds: &ClientCredentials<'_>,
) -> Result<NodeInfo, WakeComputeError> {
let project = creds.project().expect("impossible");
let request_id = uuid::Uuid::new_v4().to_string();
async {
let request = self
.endpoint
.get("proxy_wake_compute")
.header("X-Request-ID", &request_id)
.query(&[("session_id", extra.session_id)])
.query(&[
("application_name", extra.application_name),
("project", Some(project)),
])
.build()?;
info!(url = request.url().as_str(), "sending http request");
let response = self.endpoint.execute(request).await?;
let body = parse_body::<WakeCompute>(response).await?;
// Unfortunately, ownership won't let us use `Option::ok_or` here.
let (host, port) = match parse_host_port(&body.address) {
None => return Err(WakeComputeError::BadComputeAddress(body.address)),
Some(x) => x,
};
let mut config = compute::ConnCfg::new();
config
.host(host)
.port(port)
.dbname(creds.dbname)
.user(creds.user);
let node = NodeInfo {
config,
aux: body.aux.into(),
};
Ok(node)
}
.map_err(crate::error::log_error)
.instrument(info_span!("http", id = request_id))
.await
}
}
#[async_trait]
impl super::Api for Api {
#[tracing::instrument(skip_all)]
async fn get_auth_info(
&self,
extra: &ConsoleReqExtra<'_>,
creds: &ClientCredentials<'_>,
) -> Result<Option<AuthInfo>, GetAuthInfoError> {
self.do_get_auth_info(extra, creds).await
}
#[tracing::instrument(skip_all)]
async fn wake_compute(
&self,
extra: &ConsoleReqExtra<'_>,
creds: &ClientCredentials<'_>,
) -> Result<CachedNodeInfo, WakeComputeError> {
let key = creds.project().expect("impossible");
// Every time we send a wake-up http request, the compute node stays up
// for some time (how long depends on the console's scale-to-zero policy);
// the connection info remains the same during that period of time,
// which means we can cache it to reduce load and latency.
if let Some(cached) = self.caches.node_info.get(key) {
info!(key = key, "found cached compute node info");
return Ok(cached);
}
let node = self.do_wake_compute(extra, creds).await?;
let (_, cached) = self.caches.node_info.insert(key.into(), node);
info!(key = key, "created a cache entry for compute node info");
Ok(cached)
}
}
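
The `get`/`insert` calls above go through the new `TimedLru`-backed `NodeInfoCache`. As a rough illustration of the idea, here is a toy sketch built on hashlink's `LruCache` — not the commit's actual implementation, which additionally supports explicit invalidation and avoids clobbering newer entries for the same key:

use std::hash::Hash;
use std::time::{Duration, Instant};
use hashlink::LruCache;

// Toy timed LRU: each entry remembers its creation time and is
// discarded on lookup once the ttl has elapsed.
struct TimedLru<K: Hash + Eq, V> {
    cache: LruCache<K, (Instant, V)>,
    ttl: Duration,
}

impl<K: Hash + Eq, V: Clone> TimedLru<K, V> {
    fn new(capacity: usize, ttl: Duration) -> Self {
        Self { cache: LruCache::new(capacity), ttl }
    }

    fn get(&mut self, key: &K) -> Option<V> {
        let expired = match self.cache.get(key) {
            Some((created, _)) => created.elapsed() >= self.ttl,
            None => return None,
        };
        if expired {
            // Entry outlived its ttl: drop it and report a miss.
            self.cache.remove(key);
            return None;
        }
        self.cache.get(key).map(|(_, value)| value.clone())
    }

    fn insert(&mut self, key: K, value: V) {
        // Store the creation time alongside the value for ttl checks.
        self.cache.insert(key, (Instant::now(), value));
    }
}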
/// Parse http response body, taking status code into account.
async fn parse_body<T: for<'a> serde::Deserialize<'a>>(
response: reqwest::Response,
) -> Result<T, ApiError> {
let status = response.status();
if status.is_success() {
// We shouldn't log the raw body because it may contain secrets.
info!("request succeeded, processing the body");
return Ok(response.json().await?);
}
// Don't throw an error here because it's not as important
// as the fact that the request itself has failed.
let body = response.json().await.unwrap_or_else(|e| {
warn!("failed to parse error body: {e}");
ConsoleError {
error: "reason unclear (malformed error message)".into(),
}
});
let text = body.error;
error!("console responded with an error ({status}): {text}");
Err(ApiError::Console { status, text })
}
fn parse_host_port(input: &str) -> Option<(&str, u16)> {
let (host, port) = input.split_once(':')?;
Some((host, port.parse().ok()?))
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_parse_host_port() {
let (host, port) = parse_host_port("127.0.0.1:5432").expect("failed to parse");
assert_eq!(host, "127.0.0.1");
assert_eq!(port, 5432);
}
}
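
One caveat worth noting: `split_once(':')` splits at the first colon, so a bare IPv6 compute address fails to parse (returning `None`) rather than mis-parsing. A quick check along the lines of the existing test (not part of the commit):

#[test]
fn test_parse_host_port_rejects_bare_ipv6() {
    // "::1:5432" splits into host "" and "port" ":1:5432", which fails
    // to parse as u16, so the function returns None instead of garbage.
    assert!(parse_host_port("::1:5432").is_none());
}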


@@ -9,8 +9,7 @@ async fn status_handler(_: Request<Body>) -> Result<Response<Body>, ApiError> {
}
fn make_router() -> RouterBuilder<hyper::Body, ApiError> {
let router = endpoint::make_router();
router.get("/v1/status", status_handler)
endpoint::make_router().get("/v1/status", status_handler)
}
pub async fn task_main(http_listener: TcpListener) -> anyhow::Result<()> {


@@ -1,6 +1,6 @@
use bytes::{Buf, Bytes};
use futures::{Sink, Stream, StreamExt};
use hyper::server::accept::{self};
use hyper::server::accept;
use hyper::server::conn::AddrIncoming;
use hyper::upgrade::Upgraded;
use hyper::{Body, Request, Response, StatusCode};
@@ -161,7 +161,7 @@ impl AsyncBufRead for WebSocketRW {
async fn serve_websocket(
websocket: HyperWebsocket,
config: &ProxyConfig,
config: &'static ProxyConfig,
cancel_map: &CancelMap,
session_id: uuid::Uuid,
hostname: Option<String>,


@@ -5,6 +5,7 @@
//! in somewhat transparent manner (again via communication with control plane API).
mod auth;
mod cache;
mod cancellation;
mod compute;
mod config;
@@ -12,7 +13,6 @@ mod console;
mod error;
mod http;
mod metrics;
mod mgmt;
mod parse;
mod proxy;
mod sasl;
@@ -21,7 +21,6 @@ mod stream;
mod url;
mod waiters;
use ::metrics::set_build_info_metric;
use anyhow::{bail, Context};
use clap::{self, Arg};
use config::ProxyConfig;
@@ -29,8 +28,7 @@ use futures::FutureExt;
use std::{borrow::Cow, future::Future, net::SocketAddr};
use tokio::{net::TcpListener, task::JoinError};
use tracing::{info, info_span, Instrument};
use utils::project_git_version;
use utils::sentry_init::init_sentry;
use utils::{project_git_version, sentry_init::init_sentry};
project_git_version!(GIT_VERSION);
@@ -51,124 +49,133 @@ async fn main() -> anyhow::Result<()> {
// initialize sentry if SENTRY_DSN is provided
let _sentry_guard = init_sentry(Some(GIT_VERSION.into()), &[]);
let arg_matches = cli().get_matches();
let tls_config = match (
arg_matches.get_one::<String>("tls-key"),
arg_matches.get_one::<String>("tls-cert"),
) {
(Some(key_path), Some(cert_path)) => Some(config::configure_tls(key_path, cert_path)?),
(None, None) => None,
_ => bail!("either both or neither tls-key and tls-cert must be specified"),
};
let proxy_address: SocketAddr = arg_matches.get_one::<String>("proxy").unwrap().parse()?;
let mgmt_address: SocketAddr = arg_matches.get_one::<String>("mgmt").unwrap().parse()?;
let http_address: SocketAddr = arg_matches.get_one::<String>("http").unwrap().parse()?;
let metric_collection_config = match (
arg_matches.get_one::<String>("metric-collection-endpoint"),
arg_matches.get_one::<String>("metric-collection-interval"),
) {
(Some(endpoint), Some(interval)) => {
Some(config::MetricCollectionConfig {
endpoint: endpoint.parse()?,
interval: humantime::parse_duration(interval)?,
})
}
(None, None) => None,
_ => bail!("either both or neither metric-collection-endpoint and metric-collection-interval must be specified"),
};
let auth_backend = match arg_matches
.get_one::<String>("auth-backend")
.unwrap()
.as_str()
{
"console" => {
let url = arg_matches
.get_one::<String>("auth-endpoint")
.unwrap()
.parse()?;
let endpoint = http::Endpoint::new(url, reqwest::Client::new());
auth::BackendType::Console(Cow::Owned(endpoint), ())
}
"postgres" => {
let url = arg_matches
.get_one::<String>("auth-endpoint")
.unwrap()
.parse()?;
auth::BackendType::Postgres(Cow::Owned(url), ())
}
"link" => {
let url = arg_matches.get_one::<String>("uri").unwrap().parse()?;
auth::BackendType::Link(Cow::Owned(url))
}
other => bail!("unsupported auth backend: {other}"),
};
let config: &ProxyConfig = Box::leak(Box::new(ProxyConfig {
tls_config,
auth_backend,
metric_collection_config,
}));
info!("Version: {GIT_VERSION}");
::metrics::set_build_info_metric(GIT_VERSION);
let args = cli().get_matches();
let config = build_config(&args)?;
info!("Authentication backend: {}", config.auth_backend);
// Check that we can bind to address before further initialization
let http_address: SocketAddr = args.get_one::<String>("http").unwrap().parse()?;
info!("Starting http on {http_address}");
let http_listener = TcpListener::bind(http_address).await?.into_std()?;
let mgmt_address: SocketAddr = args.get_one::<String>("mgmt").unwrap().parse()?;
info!("Starting mgmt on {mgmt_address}");
let mgmt_listener = TcpListener::bind(mgmt_address).await?.into_std()?;
let proxy_address: SocketAddr = args.get_one::<String>("proxy").unwrap().parse()?;
info!("Starting proxy on {proxy_address}");
let proxy_listener = TcpListener::bind(proxy_address).await?;
let mut tasks = vec![
tokio::spawn(http::server::task_main(http_listener)),
tokio::spawn(proxy::task_main(config, proxy_listener)),
tokio::task::spawn_blocking(move || mgmt::thread_main(mgmt_listener)),
tokio::task::spawn_blocking(move || console::mgmt::thread_main(mgmt_listener)),
];
if let Some(wss_address) = arg_matches.get_one::<String>("wss") {
if let Some(wss_address) = args.get_one::<String>("wss") {
let wss_address: SocketAddr = wss_address.parse()?;
info!("Starting wss on {}", wss_address);
info!("Starting wss on {wss_address}");
let wss_listener = TcpListener::bind(wss_address).await?;
tasks.push(tokio::spawn(http::websocket::task_main(
wss_listener,
config,
)));
}
if let Some(metric_collection_config) = &config.metric_collection_config {
// TODO: refactor.
if let Some(metric_collection) = &config.metric_collection {
let hostname = hostname::get()?
.into_string()
.map_err(|e| anyhow::anyhow!("failed to get hostname {e:?}"))?;
tasks.push(tokio::spawn(
metrics::collect_metrics(
&metric_collection_config.endpoint,
metric_collection_config.interval,
&metric_collection.endpoint,
metric_collection.interval,
hostname,
)
.instrument(info_span!("collect_metrics")),
));
}
let tasks = tasks.into_iter().map(flatten_err);
set_build_info_metric(GIT_VERSION);
// This will block until all tasks have completed.
// Furthermore, the first one to fail will cancel the rest.
let tasks = tasks.into_iter().map(flatten_err);
let _: Vec<()> = futures::future::try_join_all(tasks).await?;
Ok(())
}
/// ProxyConfig is created at proxy startup, and lives forever.
fn build_config(args: &clap::ArgMatches) -> anyhow::Result<&'static ProxyConfig> {
let tls_config = match (
args.get_one::<String>("tls-key"),
args.get_one::<String>("tls-cert"),
) {
(Some(key_path), Some(cert_path)) => Some(config::configure_tls(key_path, cert_path)?),
(None, None) => None,
_ => bail!("either both or neither tls-key and tls-cert must be specified"),
};
let metric_collection = match (
args.get_one::<String>("metric-collection-endpoint"),
args.get_one::<String>("metric-collection-interval"),
) {
(Some(endpoint), Some(interval)) => Some(config::MetricCollectionConfig {
endpoint: endpoint.parse()?,
interval: humantime::parse_duration(interval)?,
}),
(None, None) => None,
_ => bail!(
"either both or neither metric-collection-endpoint \
and metric-collection-interval must be specified"
),
};
let auth_backend = match args.get_one::<String>("auth-backend").unwrap().as_str() {
"console" => {
let config::CacheOptions { size, ttl } = args
.get_one::<String>("wake-compute-cache")
.unwrap()
.parse()?;
info!("Using NodeInfoCache (wake_compute) with size={size} ttl={ttl:?}");
let caches = Box::leak(Box::new(console::caches::ApiCaches {
node_info: console::caches::NodeInfoCache::new("node_info_cache", size, ttl),
}));
let url = args.get_one::<String>("auth-endpoint").unwrap().parse()?;
let endpoint = http::Endpoint::new(url, reqwest::Client::new());
let api = console::provider::neon::Api::new(endpoint, caches);
auth::BackendType::Console(Cow::Owned(api), ())
}
"postgres" => {
let url = args.get_one::<String>("auth-endpoint").unwrap().parse()?;
let api = console::provider::mock::Api::new(url);
auth::BackendType::Postgres(Cow::Owned(api), ())
}
"link" => {
let url = args.get_one::<String>("uri").unwrap().parse()?;
auth::BackendType::Link(Cow::Owned(url))
}
other => bail!("unsupported auth backend: {other}"),
};
let config = Box::leak(Box::new(ProxyConfig {
tls_config,
auth_backend,
metric_collection,
}));
Ok(config)
}
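
As the doc comment says, the config lives forever: `Box::leak` converts the startup allocations (config, caches) into `&'static` references that spawned tasks can capture without `Arc`. A minimal illustration of the pattern, using a hypothetical `Config` type for brevity:

struct Config {
    name: String,
}

fn leak_config() -> &'static Config {
    // Allocated once at startup and intentionally never freed; the
    // resulting &'static reference can be copied into any number of tasks.
    Box::leak(Box::new(Config { name: "proxy".into() }))
}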
fn cli() -> clap::Command {
clap::Command::new("Neon proxy/router")
.disable_help_flag(true)
@@ -235,16 +242,27 @@ fn cli() -> clap::Command {
.arg(
Arg::new("metric-collection-endpoint")
.long("metric-collection-endpoint")
.help("metric collection HTTP endpoint"),
.help("http endpoint to receive periodic metric updates"),
)
.arg(
Arg::new("metric-collection-interval")
.long("metric-collection-interval")
.help("metric collection interval"),
.help("how often metrics should be sent to a collection endpoint"),
)
.arg(
Arg::new("wake-compute-cache")
.long("wake-compute-cache")
.help("cache for `wake_compute` api method (use `size=0` to disable)")
.default_value(config::CacheOptions::DEFAULT_OPTIONS_NODE_INFO),
)
}
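
The `wake-compute-cache` knob parses into `config::CacheOptions`, whose `FromStr` impl isn't shown in this hunk. A hedged sketch of what such a parser could look like, assuming a `size=...,ttl=...` format with humantime-style durations (the real format and `DEFAULT_OPTIONS_NODE_INFO` are defined in config.rs; per the commit description the default ttl is 5 minutes):

use std::time::Duration;

// Hypothetical cache-options parser: accepts strings like "size=4000,ttl=5m".
struct CacheOptions {
    size: usize,
    ttl: Duration,
}

impl std::str::FromStr for CacheOptions {
    type Err = anyhow::Error;

    fn from_str(options: &str) -> Result<Self, Self::Err> {
        let (mut size, mut ttl) = (None, None);
        for option in options.split(',') {
            let (key, value) = option
                .split_once('=')
                .ok_or_else(|| anyhow::anyhow!("bad option: {option}"))?;
            match key {
                "size" => size = Some(value.parse()?),
                "ttl" => ttl = Some(humantime::parse_duration(value)?),
                _ => anyhow::bail!("unknown key: {key}"),
            }
        }
        Ok(Self {
            size: size.ok_or_else(|| anyhow::anyhow!("size is required"))?,
            ttl: ttl.ok_or_else(|| anyhow::anyhow!("ttl is required"))?,
        })
    }
}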
#[test]
fn verify_cli() {
cli().debug_assert();
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn verify_cli() {
cli().debug_assert();
}
}


@@ -2,9 +2,12 @@
mod tests;
use crate::{
auth,
auth::{self, backend::AuthSuccess},
cancellation::{self, CancelMap},
compute::{self, PostgresConnection},
config::{ProxyConfig, TlsConfig},
console::{self, messages::MetricsAuxInfo},
error::io_error,
stream::{MeasuredStream, PqStream, Stream},
};
use anyhow::{bail, Context};
@@ -14,7 +17,10 @@ use once_cell::sync::Lazy;
use pq_proto::{BeMessage as Be, FeStartupPacket, StartupMessageParams};
use std::sync::Arc;
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::{error, info, info_span, Instrument};
use tracing::{error, info, info_span, warn, Instrument};
/// Number of times we should retry the `/proxy_wake_compute` http request.
const NUM_RETRIES_WAKE_COMPUTE: usize = 1;
const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmode=require`)";
const ERR_PROTO_VIOLATION: &str = "protocol violation";
@@ -35,6 +41,15 @@ static NUM_CONNECTIONS_CLOSED_COUNTER: Lazy<IntCounter> = Lazy::new(|| {
.unwrap()
});
static NUM_CONNECTION_FAILURES: Lazy<IntCounterVec> = Lazy::new(|| {
register_int_counter_vec!(
"proxy_connection_failures_total",
"Number of connection failures (per kind).",
&["kind"],
)
.unwrap()
});
static NUM_BYTES_PROXIED_COUNTER: Lazy<IntCounterVec> = Lazy::new(|| {
register_int_counter_vec!(
"proxy_io_bytes_per_client",
@@ -82,11 +97,12 @@ pub async fn task_main(
}
}
// TODO(tech debt): unite this with its twin below.
pub async fn handle_ws_client(
config: &ProxyConfig,
config: &'static ProxyConfig,
cancel_map: &CancelMap,
session_id: uuid::Uuid,
stream: impl AsyncRead + AsyncWrite + Unpin + Send,
stream: impl AsyncRead + AsyncWrite + Unpin,
hostname: Option<String>,
) -> anyhow::Result<()> {
// The `closed` counter will increase when this future is destroyed.
@@ -99,7 +115,7 @@ pub async fn handle_ws_client(
let hostname = hostname.as_deref();
// TLS is None here, because the connection is already encrypted.
let do_handshake = handshake(stream, None, cancel_map).instrument(info_span!("handshake"));
let do_handshake = handshake(stream, None, cancel_map);
let (mut stream, params) = match do_handshake.await? {
Some(x) => x,
None => return Ok(()), // it's a cancellation request
@@ -124,10 +140,10 @@ pub async fn handle_ws_client(
}
async fn handle_client(
config: &ProxyConfig,
config: &'static ProxyConfig,
cancel_map: &CancelMap,
session_id: uuid::Uuid,
stream: impl AsyncRead + AsyncWrite + Unpin + Send,
stream: impl AsyncRead + AsyncWrite + Unpin,
) -> anyhow::Result<()> {
// The `closed` counter will increase when this future is destroyed.
NUM_CONNECTIONS_ACCEPTED_COUNTER.inc();
@@ -136,7 +152,7 @@ async fn handle_client(
}
let tls = config.tls_config.as_ref();
let do_handshake = handshake(stream, tls, cancel_map).instrument(info_span!("handshake"));
let do_handshake = handshake(stream, tls, cancel_map);
let (mut stream, params) = match do_handshake.await? {
Some(x) => x,
None => return Ok(()), // it's a cancellation request
@@ -165,6 +181,7 @@ async fn handle_client(
/// For better testing experience, `stream` can be any object satisfying the traits.
/// It's easier to work with owned `stream` here as we need to upgrade it to TLS;
/// we also take extra care to propagate only select handshake errors to the client.
#[tracing::instrument(skip_all)]
async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
stream: S,
mut tls: Option<&TlsConfig>,
@@ -226,6 +243,133 @@ async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
}
}
/// Try to connect to the compute node once.
#[tracing::instrument(name = "connect_once", skip_all)]
async fn connect_to_compute_once(
node_info: &console::CachedNodeInfo,
) -> Result<PostgresConnection, compute::ConnectionError> {
// If we couldn't connect, the cached connection info might be to blame
// (e.g. the compute node's address might have changed at the wrong time).
// Invalidate the cache entry (if any) to prevent subsequent errors.
let invalidate_cache = |_: &compute::ConnectionError| {
let is_cached = node_info.cached();
if is_cached {
warn!("invalidating stalled compute node info cache entry");
node_info.invalidate();
}
let label = match is_cached {
true => "compute_cached",
false => "compute_uncached",
};
NUM_CONNECTION_FAILURES.with_label_values(&[label]).inc();
};
node_info
.config
.connect()
.inspect_err(invalidate_cache)
.await
}
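
`inspect_err` here is `futures::TryFutureExt::inspect_err`: it lets the closure observe the error (to invalidate the cache entry and bump the failure metric) while the original error still propagates to the caller unchanged. A tiny standalone illustration:

use futures::TryFutureExt;

async fn fallible() -> Result<(), std::io::Error> {
    Err(std::io::Error::new(std::io::ErrorKind::Other, "compute is down"))
}

async fn demo() {
    // The closure only observes the error; the Err value is then
    // returned to the caller untouched.
    let res = fallible()
        .inspect_err(|e| eprintln!("connection failed: {e}"))
        .await;
    assert!(res.is_err());
}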
/// Try to connect to the compute node, retrying if necessary.
/// This function might update `node_info`, so we take it by `&mut`.
#[tracing::instrument(skip_all)]
async fn connect_to_compute(
node_info: &mut console::CachedNodeInfo,
params: &StartupMessageParams,
extra: &console::ConsoleReqExtra<'_>,
creds: &auth::BackendType<'_, auth::ClientCredentials<'_>>,
) -> Result<PostgresConnection, compute::ConnectionError> {
let mut num_retries: usize = NUM_RETRIES_WAKE_COMPUTE;
loop {
// Apply startup params to the (possibly, cached) compute node info.
node_info.config.set_startup_params(params);
match connect_to_compute_once(node_info).await {
Err(e) if num_retries > 0 => {
info!("compute node's state has changed; requesting a wake-up");
match creds.wake_compute(extra).map_err(io_error).await? {
// Update `node_info` and try one more time.
Some(mut new) => {
new.config.reuse_password(&node_info.config);
*node_info = new;
}
// Link auth doesn't work that way, so we just exit.
None => return Err(e),
}
}
other => return other,
}
num_retries -= 1;
info!("retrying after wake-up ({num_retries} attempts left)");
}
}
/// Finish client connection initialization: confirm auth success, send params, etc.
#[tracing::instrument(skip_all)]
async fn prepare_client_connection(
node: &compute::PostgresConnection,
reported_auth_ok: bool,
session: cancellation::Session<'_>,
stream: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
) -> anyhow::Result<()> {
// Register compute's query cancellation token and produce a new, unique one.
// The new token (cancel_key_data) will be sent to the client.
let cancel_key_data = session.enable_query_cancellation(node.cancel_closure.clone());
// Report authentication success if we haven't done this already.
// Note that we do this only (for the most part) after we've connected
// to a compute (see above) which performs its own authentication.
if !reported_auth_ok {
stream.write_message_noflush(&Be::AuthenticationOk)?;
}
// Forward all postgres connection params to the client.
// Right now the implementation is very hacky and inefficient (ideally,
// we don't need an intermediate hashmap), but at least it should be correct.
for (name, value) in &node.params {
// TODO: Theoretically, this could result in a big pile of params...
stream.write_message_noflush(&Be::ParameterStatus {
name: name.as_bytes(),
value: value.as_bytes(),
})?;
}
stream
.write_message_noflush(&Be::BackendKeyData(cancel_key_data))?
.write_message(&Be::ReadyForQuery)
.await?;
Ok(())
}
/// Forward bytes in both directions (client <-> compute).
#[tracing::instrument(skip_all)]
async fn proxy_pass(
client: impl AsyncRead + AsyncWrite + Unpin,
compute: impl AsyncRead + AsyncWrite + Unpin,
aux: &MetricsAuxInfo,
) -> anyhow::Result<()> {
let m_sent = NUM_BYTES_PROXIED_COUNTER.with_label_values(&aux.traffic_labels("tx"));
let mut client = MeasuredStream::new(client, |cnt| {
// Number of bytes we sent to the client (outbound).
m_sent.inc_by(cnt as u64);
});
let m_recv = NUM_BYTES_PROXIED_COUNTER.with_label_values(&aux.traffic_labels("rx"));
let mut compute = MeasuredStream::new(compute, |cnt| {
// Number of bytes the client sent to the compute node (inbound).
m_recv.inc_by(cnt as u64);
});
// Starting from here we only proxy the client's traffic.
info!("performing the proxy pass...");
let _ = tokio::io::copy_bidirectional(&mut client, &mut compute).await?;
Ok(())
}
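
`MeasuredStream` (from the stream module) wraps a stream and reports transferred byte counts through a callback, which the labeled counters above consume. A toy sketch of the write-counting half of such a wrapper — the real type also forwards reads and may differ in details:

use std::io;
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::io::AsyncWrite;

// Toy write-measuring wrapper: invokes `on_write` with the number of
// bytes accepted by the inner stream on every successful poll_write.
struct MeasuredWrite<S, F> {
    inner: S,
    on_write: F,
}

impl<S: AsyncWrite + Unpin, F: FnMut(usize) + Unpin> AsyncWrite for MeasuredWrite<S, F> {
    fn poll_write(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        let this = &mut *self;
        match Pin::new(&mut this.inner).poll_write(cx, buf) {
            Poll::Ready(Ok(written)) => {
                (this.on_write)(written); // report the accepted byte count
                Poll::Ready(Ok(written))
            }
            other => other,
        }
    }

    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        Pin::new(&mut self.inner).poll_flush(cx)
    }

    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        Pin::new(&mut self.inner).poll_shutdown(cx)
    }
}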
/// Thin connection context.
struct Client<'a, S> {
/// The underlying libpq protocol stream.
@@ -255,17 +399,17 @@ impl<'a, S> Client<'a, S> {
}
}
impl<S: AsyncRead + AsyncWrite + Unpin + Send> Client<'_, S> {
impl<S: AsyncRead + AsyncWrite + Unpin> Client<'_, S> {
/// Let the client authenticate and connect to the designated compute node.
async fn connect_to_db(self, session: cancellation::Session<'_>) -> anyhow::Result<()> {
let Self {
mut stream,
creds,
mut creds,
params,
session_id,
} = self;
let extra = auth::ConsoleReqExtra {
let extra = console::ConsoleReqExtra {
session_id, // aka this connection's id
application_name: params.get("application_name"),
};
@@ -278,54 +422,16 @@ impl<S: AsyncRead + AsyncWrite + Unpin + Send> Client<'_, S> {
.instrument(info_span!("auth"))
.await?;
let node = auth_result.value;
let (db, cancel_closure) = node
.config
.connect(params)
let AuthSuccess {
reported_auth_ok,
value: mut node_info,
} = auth_result;
let node = connect_to_compute(&mut node_info, params, &extra, &creds)
.or_else(|e| stream.throw_error(e))
.await?;
let cancel_key_data = session.enable_query_cancellation(cancel_closure);
// Report authentication success if we haven't done this already.
// Note that we do this only (for the most part) after we've connected
// to a compute (see above) which performs its own authentication.
if !auth_result.reported_auth_ok {
stream.write_message_noflush(&Be::AuthenticationOk)?;
}
// Forward all postgres connection params to the client.
// Right now the implementation is very hacky and inefficient (ideally,
// we don't need an intermediate hashmap), but at least it should be correct.
for (name, value) in &db.params {
// TODO: Theoretically, this could result in a big pile of params...
stream.write_message_noflush(&Be::ParameterStatus {
name: name.as_bytes(),
value: value.as_bytes(),
})?;
}
stream
.write_message_noflush(&Be::BackendKeyData(cancel_key_data))?
.write_message(&Be::ReadyForQuery)
.await?;
let m_sent = NUM_BYTES_PROXIED_COUNTER.with_label_values(&node.aux.traffic_labels("tx"));
let mut client = MeasuredStream::new(stream.into_inner(), |cnt| {
// Number of bytes we sent to the client (outbound).
m_sent.inc_by(cnt as u64);
});
let m_recv = NUM_BYTES_PROXIED_COUNTER.with_label_values(&node.aux.traffic_labels("rx"));
let mut db = MeasuredStream::new(db.stream, |cnt| {
// Number of bytes the client sent to the compute node (inbound).
m_recv.inc_by(cnt as u64);
});
// Starting from here we only proxy the client's traffic.
info!("performing the proxy pass...");
let _ = tokio::io::copy_bidirectional(&mut client, &mut db).await?;
Ok(())
prepare_client_connection(&node, reported_auth_ok, session, &mut stream).await?;
proxy_pass(stream.into_inner(), node.stream, &node_info.aux).await
}
}


@@ -25,6 +25,7 @@ futures-channel = { version = "0.3", features = ["sink"] }
futures-executor = { version = "0.3" }
futures-task = { version = "0.3", default-features = false, features = ["std"] }
futures-util = { version = "0.3", features = ["channel", "io", "sink"] }
hashbrown = { version = "0.12", features = ["raw"] }
indexmap = { version = "1", default-features = false, features = ["std"] }
itertools = { version = "0.10" }
libc = { version = "0.2", features = ["extra_traits"] }
@@ -58,6 +59,7 @@ url = { version = "2", features = ["serde"] }
anyhow = { version = "1", features = ["backtrace"] }
bytes = { version = "1", features = ["serde"] }
either = { version = "1" }
hashbrown = { version = "0.12", features = ["raw"] }
indexmap = { version = "1", default-features = false, features = ["std"] }
itertools = { version = "0.10" }
libc = { version = "0.2", features = ["extra_traits"] }