mirror of
https://github.com/neondatabase/neon.git
synced 2026-01-14 08:52:56 +00:00
[proxy] Implement compute node info cache (#3331)
This patch adds a timed LRU cache implementation and a compute node info cache on top of that. Cache entries might expire on their own (default ttl=5mins) or become invalid due to real-world events, e.g. compute node scale-to-zero event, so we add a connection retry loop with a wake-up call. Solved problems: - [x] Find a decent LRU implementation. - [x] Implement timed LRU on top of that. - [x] Cache results of `proxy_wake_compute` API call. - [x] Don't invalidate newer cache entries for the same key. - [x] Add cmdline configuration knobs (requires some refactoring). - [x] Add failed connection estab metric. - [x] Refactor auth backends to make things simpler (retries, cache placement, etc). - [x] Address review comments (add code comments + cleanup). - [x] Retry `/proxy_wake_compute` if we couldn't connect to a compute (e.g. stalled cache entry). - [x] Add high-level description for `TimedLru`. TODOs (will be addressed later): - [ ] Add cache metrics (hit, spurious hit, miss). - [ ] Synchronize http requests across concurrent per-client tasks (https://github.com/neondatabase/neon/pull/3331#issuecomment-1399216069). - [ ] Cache results of `proxy_get_role_secret` API call.
This commit is contained in:
27
Cargo.lock
generated
27
Cargo.lock
generated
@@ -17,6 +17,17 @@ version = "1.0.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe"
|
||||
|
||||
[[package]]
|
||||
name = "ahash"
|
||||
version = "0.7.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "fcb51a0695d8f838b1ee009b3fbf66bda078cd64590202a864a8f3e8c4315c47"
|
||||
dependencies = [
|
||||
"getrandom",
|
||||
"once_cell",
|
||||
"version_check",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ahash"
|
||||
version = "0.8.2"
|
||||
@@ -1515,6 +1526,9 @@ name = "hashbrown"
|
||||
version = "0.12.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888"
|
||||
dependencies = [
|
||||
"ahash 0.7.6",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "hashbrown"
|
||||
@@ -1522,7 +1536,16 @@ version = "0.13.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "43a3c133739dddd0d2990f9a4bdf8eb4b21ef50e4851ca85ab661199821d510e"
|
||||
dependencies = [
|
||||
"ahash",
|
||||
"ahash 0.8.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "hashlink"
|
||||
version = "0.8.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "69fe1fcf8b4278d860ad0548329f892a3631fb63f82574df68275f34cdbe0ffa"
|
||||
dependencies = [
|
||||
"hashbrown 0.12.3",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -2764,6 +2787,7 @@ dependencies = [
|
||||
"futures",
|
||||
"git-version",
|
||||
"hashbrown 0.13.2",
|
||||
"hashlink",
|
||||
"hex",
|
||||
"hmac",
|
||||
"hostname",
|
||||
@@ -4624,6 +4648,7 @@ dependencies = [
|
||||
"futures-executor",
|
||||
"futures-task",
|
||||
"futures-util",
|
||||
"hashbrown 0.12.3",
|
||||
"indexmap",
|
||||
"itertools",
|
||||
"libc",
|
||||
|
||||
@@ -44,6 +44,7 @@ futures-core = "0.3"
|
||||
futures-util = "0.3"
|
||||
git-version = "0.3"
|
||||
hashbrown = "0.13"
|
||||
hashlink = "0.8.1"
|
||||
hex = "0.4"
|
||||
hex-literal = "0.3"
|
||||
hmac = "0.12.1"
|
||||
|
||||
@@ -6,58 +6,59 @@ license.workspace = true
|
||||
|
||||
[dependencies]
|
||||
anyhow.workspace = true
|
||||
async-trait.workspace = true
|
||||
atty.workspace = true
|
||||
base64.workspace = true
|
||||
bstr.workspace = true
|
||||
bytes = {workspace = true, features = ['serde'] }
|
||||
clap.workspace = true
|
||||
bytes = { workspace = true, features = ["serde"] }
|
||||
chrono.workspace = true
|
||||
clap.workspace = true
|
||||
consumption_metrics.workspace = true
|
||||
futures.workspace = true
|
||||
git-version.workspace = true
|
||||
hashbrown.workspace = true
|
||||
hashlink.workspace = true
|
||||
hex.workspace = true
|
||||
hmac.workspace = true
|
||||
hyper.workspace = true
|
||||
hostname.workspace = true
|
||||
humantime.workspace = true
|
||||
hyper-tungstenite.workspace = true
|
||||
hyper.workspace = true
|
||||
itertools.workspace = true
|
||||
md5.workspace = true
|
||||
metrics.workspace = true
|
||||
once_cell.workspace = true
|
||||
parking_lot.workspace = true
|
||||
pin-project-lite.workspace = true
|
||||
pq_proto.workspace = true
|
||||
prometheus.workspace = true
|
||||
rand.workspace = true
|
||||
regex.workspace = true
|
||||
reqwest = { workspace = true, features = [ "json" ] }
|
||||
reqwest = { workspace = true, features = ["json"] }
|
||||
routerify.workspace = true
|
||||
rustls.workspace = true
|
||||
rustls-pemfile.workspace = true
|
||||
rustls.workspace = true
|
||||
scopeguard.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
sha2.workspace = true
|
||||
socket2.workspace = true
|
||||
thiserror.workspace = true
|
||||
tokio.workspace = true
|
||||
tls-listener.workspace = true
|
||||
tokio-postgres.workspace = true
|
||||
tokio-rustls.workspace = true
|
||||
tls-listener.workspace = true
|
||||
tracing.workspace = true
|
||||
tokio.workspace = true
|
||||
tracing-subscriber.workspace = true
|
||||
tracing.workspace = true
|
||||
url.workspace = true
|
||||
utils.workspace = true
|
||||
uuid.workspace = true
|
||||
webpki-roots.workspace = true
|
||||
x509-parser.workspace = true
|
||||
metrics.workspace = true
|
||||
pq_proto.workspace = true
|
||||
utils.workspace = true
|
||||
prometheus.workspace = true
|
||||
humantime.workspace = true
|
||||
hostname.workspace = true
|
||||
|
||||
workspace_hack.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
async-trait.workspace = true
|
||||
rcgen.workspace = true
|
||||
rstest.workspace = true
|
||||
tokio-postgres-rustls.workspace = true
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
//! Client authentication mechanisms.
|
||||
|
||||
pub mod backend;
|
||||
pub use backend::{BackendType, ConsoleReqExtra};
|
||||
pub use backend::BackendType;
|
||||
|
||||
mod credentials;
|
||||
pub use credentials::ClientCredentials;
|
||||
@@ -12,7 +12,7 @@ use password_hack::PasswordHackPayload;
|
||||
mod flow;
|
||||
pub use flow::*;
|
||||
|
||||
use crate::error::UserFacingError;
|
||||
use crate::{console, error::UserFacingError};
|
||||
use std::io;
|
||||
use thiserror::Error;
|
||||
|
||||
@@ -26,10 +26,10 @@ pub enum AuthErrorImpl {
|
||||
Link(#[from] backend::LinkAuthError),
|
||||
|
||||
#[error(transparent)]
|
||||
GetAuthInfo(#[from] backend::GetAuthInfoError),
|
||||
GetAuthInfo(#[from] console::errors::GetAuthInfoError),
|
||||
|
||||
#[error(transparent)]
|
||||
WakeCompute(#[from] backend::WakeComputeError),
|
||||
WakeCompute(#[from] console::errors::WakeComputeError),
|
||||
|
||||
/// SASL protocol errors (includes [SCRAM](crate::scram)).
|
||||
#[error(transparent)]
|
||||
|
||||
@@ -1,48 +1,40 @@
|
||||
mod postgres;
|
||||
mod classic;
|
||||
|
||||
mod link;
|
||||
use futures::TryFutureExt;
|
||||
pub use link::LinkAuthError;
|
||||
|
||||
mod console;
|
||||
pub use console::{GetAuthInfoError, WakeComputeError};
|
||||
|
||||
use crate::{
|
||||
auth::{self, AuthFlow, ClientCredentials},
|
||||
compute,
|
||||
console::messages::MetricsAuxInfo,
|
||||
http, mgmt, stream, url,
|
||||
waiters::{self, Waiter, Waiters},
|
||||
console::{
|
||||
self,
|
||||
provider::{CachedNodeInfo, ConsoleReqExtra},
|
||||
Api,
|
||||
},
|
||||
stream, url,
|
||||
};
|
||||
use once_cell::sync::Lazy;
|
||||
use std::borrow::Cow;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tracing::{info, warn};
|
||||
|
||||
static CPLANE_WAITERS: Lazy<Waiters<mgmt::ComputeReady>> = Lazy::new(Default::default);
|
||||
|
||||
/// Give caller an opportunity to wait for the cloud's reply.
|
||||
pub async fn with_waiter<R, T, E>(
|
||||
psql_session_id: impl Into<String>,
|
||||
action: impl FnOnce(Waiter<'static, mgmt::ComputeReady>) -> R,
|
||||
) -> Result<T, E>
|
||||
where
|
||||
R: std::future::Future<Output = Result<T, E>>,
|
||||
E: From<waiters::RegisterError>,
|
||||
{
|
||||
let waiter = CPLANE_WAITERS.register(psql_session_id.into())?;
|
||||
action(waiter).await
|
||||
/// A product of successful authentication.
|
||||
pub struct AuthSuccess<T> {
|
||||
/// Did we send [`pq_proto::BeMessage::AuthenticationOk`] to client?
|
||||
pub reported_auth_ok: bool,
|
||||
/// Something to be considered a positive result.
|
||||
pub value: T,
|
||||
}
|
||||
|
||||
pub fn notify(psql_session_id: &str, msg: mgmt::ComputeReady) -> Result<(), waiters::NotifyError> {
|
||||
CPLANE_WAITERS.notify(psql_session_id, msg)
|
||||
}
|
||||
|
||||
/// Extra query params we'd like to pass to the console.
|
||||
pub struct ConsoleReqExtra<'a> {
|
||||
/// A unique identifier for a connection.
|
||||
pub session_id: uuid::Uuid,
|
||||
/// Name of client application, if set.
|
||||
pub application_name: Option<&'a str>,
|
||||
impl<T> AuthSuccess<T> {
|
||||
/// Very similar to [`std::option::Option::map`].
|
||||
/// Maps [`AuthSuccess<T>`] to [`AuthSuccess<R>`] by applying
|
||||
/// a function to a contained value.
|
||||
pub fn map<R>(self, f: impl FnOnce(T) -> R) -> AuthSuccess<R> {
|
||||
AuthSuccess {
|
||||
reported_auth_ok: self.reported_auth_ok,
|
||||
value: f(self.value),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// This type serves two purposes:
|
||||
@@ -53,12 +45,11 @@ pub struct ConsoleReqExtra<'a> {
|
||||
/// * However, when we substitute `T` with [`ClientCredentials`],
|
||||
/// this helps us provide the credentials only to those auth
|
||||
/// backends which require them for the authentication process.
|
||||
#[derive(Debug)]
|
||||
pub enum BackendType<'a, T> {
|
||||
/// Current Cloud API (V2).
|
||||
Console(Cow<'a, http::Endpoint>, T),
|
||||
Console(Cow<'a, console::provider::neon::Api>, T),
|
||||
/// Local mock of Cloud API (V2).
|
||||
Postgres(Cow<'a, url::ApiUrl>, T),
|
||||
Postgres(Cow<'a, console::provider::mock::Api>, T),
|
||||
/// Authentication via a web browser.
|
||||
Link(Cow<'a, url::ApiUrl>),
|
||||
}
|
||||
@@ -67,14 +58,8 @@ impl std::fmt::Display for BackendType<'_, ()> {
|
||||
fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
use BackendType::*;
|
||||
match self {
|
||||
Console(endpoint, _) => fmt
|
||||
.debug_tuple("Console")
|
||||
.field(&endpoint.url().as_str())
|
||||
.finish(),
|
||||
Postgres(endpoint, _) => fmt
|
||||
.debug_tuple("Postgres")
|
||||
.field(&endpoint.as_str())
|
||||
.finish(),
|
||||
Console(endpoint, _) => fmt.debug_tuple("Console").field(&endpoint.url()).finish(),
|
||||
Postgres(endpoint, _) => fmt.debug_tuple("Postgres").field(&endpoint.url()).finish(),
|
||||
Link(url) => fmt.debug_tuple("Link").field(&url.as_str()).finish(),
|
||||
}
|
||||
}
|
||||
@@ -120,30 +105,16 @@ impl<'a, T, E> BackendType<'a, Result<T, E>> {
|
||||
}
|
||||
}
|
||||
|
||||
/// A product of successful authentication.
|
||||
pub struct AuthSuccess<T> {
|
||||
/// Did we send [`pq_proto::BeMessage::AuthenticationOk`] to client?
|
||||
pub reported_auth_ok: bool,
|
||||
/// Something to be considered a positive result.
|
||||
pub value: T,
|
||||
}
|
||||
|
||||
/// Info for establishing a connection to a compute node.
|
||||
/// This is what we get after auth succeeded, but not before!
|
||||
pub struct NodeInfo {
|
||||
/// Compute node connection params.
|
||||
pub config: compute::ConnCfg,
|
||||
/// Labels for proxy's metrics.
|
||||
pub aux: MetricsAuxInfo,
|
||||
}
|
||||
|
||||
impl BackendType<'_, ClientCredentials<'_>> {
|
||||
// TODO: get rid of explicit lifetimes in this block (there's a bug in rustc).
|
||||
// Read more: https://github.com/rust-lang/rust/issues/99190
|
||||
// Alleged fix: https://github.com/rust-lang/rust/pull/89056
|
||||
impl<'l> BackendType<'l, ClientCredentials<'_>> {
|
||||
/// Do something special if user didn't provide the `project` parameter.
|
||||
async fn try_password_hack(
|
||||
&mut self,
|
||||
extra: &ConsoleReqExtra<'_>,
|
||||
client: &mut stream::PqStream<impl AsyncRead + AsyncWrite + Unpin + Send>,
|
||||
) -> auth::Result<Option<AuthSuccess<NodeInfo>>> {
|
||||
async fn try_password_hack<'a>(
|
||||
&'a mut self,
|
||||
extra: &'a ConsoleReqExtra<'a>,
|
||||
client: &'a mut stream::PqStream<impl AsyncRead + AsyncWrite + Unpin>,
|
||||
) -> auth::Result<Option<AuthSuccess<CachedNodeInfo>>> {
|
||||
use BackendType::*;
|
||||
|
||||
// If there's no project so far, that entails that client doesn't
|
||||
@@ -179,33 +150,28 @@ impl BackendType<'_, ClientCredentials<'_>> {
|
||||
|
||||
// TODO: find a proper way to merge those very similar blocks.
|
||||
let (mut node, payload) = match self {
|
||||
Console(endpoint, creds) if creds.project.is_none() => {
|
||||
Console(api, creds) if creds.project.is_none() => {
|
||||
let payload = fetch_magic_payload(client).await?;
|
||||
|
||||
let mut creds = creds.as_ref();
|
||||
creds.project = Some(payload.project.as_str().into());
|
||||
let node = console::Api::new(endpoint, extra, &creds)
|
||||
.wake_compute()
|
||||
.await?;
|
||||
let node = api.wake_compute(extra, &creds).await?;
|
||||
|
||||
(node, payload)
|
||||
}
|
||||
Console(endpoint, creds) if creds.use_cleartext_password_flow => {
|
||||
// This is a hack to allow cleartext password in secure connections (wss).
|
||||
// This is a hack to allow cleartext password in secure connections (wss).
|
||||
Console(api, creds) if creds.use_cleartext_password_flow => {
|
||||
let payload = fetch_plaintext_password(client).await?;
|
||||
let creds = creds.as_ref();
|
||||
let node = console::Api::new(endpoint, extra, &creds)
|
||||
.wake_compute()
|
||||
.await?;
|
||||
let node = api.wake_compute(extra, creds).await?;
|
||||
|
||||
(node, payload)
|
||||
}
|
||||
Postgres(endpoint, creds) if creds.project.is_none() => {
|
||||
Postgres(api, creds) if creds.project.is_none() => {
|
||||
let payload = fetch_magic_payload(client).await?;
|
||||
|
||||
let mut creds = creds.as_ref();
|
||||
creds.project = Some(payload.project.as_str().into());
|
||||
let node = postgres::Api::new(endpoint, &creds).wake_compute().await?;
|
||||
let node = api.wake_compute(extra, &creds).await?;
|
||||
|
||||
(node, payload)
|
||||
}
|
||||
@@ -220,11 +186,11 @@ impl BackendType<'_, ClientCredentials<'_>> {
|
||||
}
|
||||
|
||||
/// Authenticate the client via the requested backend, possibly using credentials.
|
||||
pub async fn authenticate(
|
||||
mut self,
|
||||
extra: &ConsoleReqExtra<'_>,
|
||||
client: &mut stream::PqStream<impl AsyncRead + AsyncWrite + Unpin + Send>,
|
||||
) -> auth::Result<AuthSuccess<NodeInfo>> {
|
||||
pub async fn authenticate<'a>(
|
||||
&mut self,
|
||||
extra: &'a ConsoleReqExtra<'a>,
|
||||
client: &'a mut stream::PqStream<impl AsyncRead + AsyncWrite + Unpin>,
|
||||
) -> auth::Result<AuthSuccess<CachedNodeInfo>> {
|
||||
use BackendType::*;
|
||||
|
||||
// Handle cases when `project` is missing in `creds`.
|
||||
@@ -235,7 +201,7 @@ impl BackendType<'_, ClientCredentials<'_>> {
|
||||
}
|
||||
|
||||
let res = match self {
|
||||
Console(endpoint, creds) => {
|
||||
Console(api, creds) => {
|
||||
info!(
|
||||
user = creds.user,
|
||||
project = creds.project(),
|
||||
@@ -243,26 +209,40 @@ impl BackendType<'_, ClientCredentials<'_>> {
|
||||
);
|
||||
|
||||
assert!(creds.project.is_some());
|
||||
console::Api::new(&endpoint, extra, &creds)
|
||||
.handle_user(client)
|
||||
.await?
|
||||
classic::handle_user(api.as_ref(), extra, creds, client).await?
|
||||
}
|
||||
Postgres(endpoint, creds) => {
|
||||
Postgres(api, creds) => {
|
||||
info!("performing mock authentication using a local postgres instance");
|
||||
|
||||
assert!(creds.project.is_some());
|
||||
postgres::Api::new(&endpoint, &creds)
|
||||
.handle_user(client)
|
||||
.await?
|
||||
classic::handle_user(api.as_ref(), extra, creds, client).await?
|
||||
}
|
||||
// NOTE: this auth backend doesn't use client credentials.
|
||||
Link(url) => {
|
||||
info!("performing link authentication");
|
||||
link::handle_user(&url, client).await?
|
||||
|
||||
link::handle_user(url, client)
|
||||
.await?
|
||||
.map(CachedNodeInfo::new_uncached)
|
||||
}
|
||||
};
|
||||
|
||||
info!("user successfully authenticated");
|
||||
Ok(res)
|
||||
}
|
||||
|
||||
/// When applicable, wake the compute node, gaining its connection info in the process.
|
||||
/// The link auth flow doesn't support this, so we return [`None`] in that case.
|
||||
pub async fn wake_compute<'a>(
|
||||
&self,
|
||||
extra: &'a ConsoleReqExtra<'a>,
|
||||
) -> Result<Option<CachedNodeInfo>, console::errors::WakeComputeError> {
|
||||
use BackendType::*;
|
||||
|
||||
match self {
|
||||
Console(api, creds) => api.wake_compute(extra, creds).map_ok(Some).await,
|
||||
Postgres(api, creds) => api.wake_compute(extra, creds).map_ok(Some).await,
|
||||
Link(_) => Ok(None),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
61
proxy/src/auth/backend/classic.rs
Normal file
61
proxy/src/auth/backend/classic.rs
Normal file
@@ -0,0 +1,61 @@
|
||||
use super::AuthSuccess;
|
||||
use crate::{
|
||||
auth::{self, AuthFlow, ClientCredentials},
|
||||
compute,
|
||||
console::{self, AuthInfo, CachedNodeInfo, ConsoleReqExtra},
|
||||
sasl, scram,
|
||||
stream::PqStream,
|
||||
};
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tracing::info;
|
||||
|
||||
pub(super) async fn handle_user(
|
||||
api: &impl console::Api,
|
||||
extra: &ConsoleReqExtra<'_>,
|
||||
creds: &ClientCredentials<'_>,
|
||||
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
|
||||
) -> auth::Result<AuthSuccess<CachedNodeInfo>> {
|
||||
info!("fetching user's authentication info");
|
||||
let info = api.get_auth_info(extra, creds).await?.unwrap_or_else(|| {
|
||||
// If we don't have an authentication secret, we mock one to
|
||||
// prevent malicious probing (possible due to missing protocol steps).
|
||||
// This mocked secret will never lead to successful authentication.
|
||||
info!("authentication info not found, mocking it");
|
||||
AuthInfo::Scram(scram::ServerSecret::mock(creds.user, rand::random()))
|
||||
});
|
||||
|
||||
let flow = AuthFlow::new(client);
|
||||
let scram_keys = match info {
|
||||
AuthInfo::Md5(_) => {
|
||||
info!("auth endpoint chooses MD5");
|
||||
return Err(auth::AuthError::bad_auth_method("MD5"));
|
||||
}
|
||||
AuthInfo::Scram(secret) => {
|
||||
info!("auth endpoint chooses SCRAM");
|
||||
let scram = auth::Scram(&secret);
|
||||
let client_key = match flow.begin(scram).await?.authenticate().await? {
|
||||
sasl::Outcome::Success(key) => key,
|
||||
sasl::Outcome::Failure(reason) => {
|
||||
info!("auth backend failed with an error: {reason}");
|
||||
return Err(auth::AuthError::auth_failed(creds.user));
|
||||
}
|
||||
};
|
||||
|
||||
Some(compute::ScramKeys {
|
||||
client_key: client_key.as_bytes(),
|
||||
server_key: secret.server_key.as_bytes(),
|
||||
})
|
||||
}
|
||||
};
|
||||
|
||||
let mut node = api.wake_compute(extra, creds).await?;
|
||||
if let Some(keys) = scram_keys {
|
||||
use tokio_postgres::config::AuthKeys;
|
||||
node.config.auth_keys(AuthKeys::ScramSha256(keys));
|
||||
}
|
||||
|
||||
Ok(AuthSuccess {
|
||||
reported_auth_ok: false,
|
||||
value: node,
|
||||
})
|
||||
}
|
||||
@@ -1,365 +0,0 @@
|
||||
//! Cloud API V2.
|
||||
|
||||
use super::{AuthSuccess, ConsoleReqExtra, NodeInfo};
|
||||
use crate::{
|
||||
auth::{self, AuthFlow, ClientCredentials},
|
||||
compute,
|
||||
console::messages::{ConsoleError, GetRoleSecret, WakeCompute},
|
||||
error::{io_error, UserFacingError},
|
||||
http, sasl, scram,
|
||||
stream::PqStream,
|
||||
};
|
||||
use futures::TryFutureExt;
|
||||
use reqwest::StatusCode as HttpStatusCode;
|
||||
use std::future::Future;
|
||||
use thiserror::Error;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tracing::{error, info, info_span, warn, Instrument};
|
||||
|
||||
/// A go-to error message which doesn't leak any detail.
|
||||
const REQUEST_FAILED: &str = "Console request failed";
|
||||
|
||||
/// Common console API error.
|
||||
#[derive(Debug, Error)]
|
||||
pub enum ApiError {
|
||||
/// Error returned by the console itself.
|
||||
#[error("{REQUEST_FAILED} with {}: {}", .status, .text)]
|
||||
Console {
|
||||
status: HttpStatusCode,
|
||||
text: Box<str>,
|
||||
},
|
||||
|
||||
/// Various IO errors like broken pipe or malformed payload.
|
||||
#[error("{REQUEST_FAILED}: {0}")]
|
||||
Transport(#[from] std::io::Error),
|
||||
}
|
||||
|
||||
impl ApiError {
|
||||
/// Returns HTTP status code if it's the reason for failure.
|
||||
fn http_status_code(&self) -> Option<HttpStatusCode> {
|
||||
use ApiError::*;
|
||||
match self {
|
||||
Console { status, .. } => Some(*status),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl UserFacingError for ApiError {
|
||||
fn to_string_client(&self) -> String {
|
||||
use ApiError::*;
|
||||
match self {
|
||||
// To minimize risks, only select errors are forwarded to users.
|
||||
// Ask @neondatabase/control-plane for review before adding more.
|
||||
Console { status, .. } => match *status {
|
||||
HttpStatusCode::NOT_FOUND => {
|
||||
// Status 404: failed to get a project-related resource.
|
||||
format!("{REQUEST_FAILED}: endpoint cannot be found")
|
||||
}
|
||||
HttpStatusCode::NOT_ACCEPTABLE => {
|
||||
// Status 406: endpoint is disabled (we don't allow connections).
|
||||
format!("{REQUEST_FAILED}: endpoint is disabled")
|
||||
}
|
||||
HttpStatusCode::LOCKED => {
|
||||
// Status 423: project might be in maintenance mode (or bad state).
|
||||
format!("{REQUEST_FAILED}: endpoint is temporary unavailable")
|
||||
}
|
||||
_ => REQUEST_FAILED.to_owned(),
|
||||
},
|
||||
_ => REQUEST_FAILED.to_owned(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Helps eliminate graceless `.map_err` calls without introducing another ctor.
|
||||
impl From<reqwest::Error> for ApiError {
|
||||
fn from(e: reqwest::Error) -> Self {
|
||||
io_error(e).into()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum GetAuthInfoError {
|
||||
// We shouldn't include the actual secret here.
|
||||
#[error("Console responded with a malformed auth secret")]
|
||||
BadSecret,
|
||||
|
||||
#[error(transparent)]
|
||||
ApiError(ApiError),
|
||||
}
|
||||
|
||||
// This allows more useful interactions than `#[from]`.
|
||||
impl<E: Into<ApiError>> From<E> for GetAuthInfoError {
|
||||
fn from(e: E) -> Self {
|
||||
Self::ApiError(e.into())
|
||||
}
|
||||
}
|
||||
|
||||
impl UserFacingError for GetAuthInfoError {
|
||||
fn to_string_client(&self) -> String {
|
||||
use GetAuthInfoError::*;
|
||||
match self {
|
||||
// We absolutely should not leak any secrets!
|
||||
BadSecret => REQUEST_FAILED.to_owned(),
|
||||
// However, API might return a meaningful error.
|
||||
ApiError(e) => e.to_string_client(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum WakeComputeError {
|
||||
#[error("Console responded with a malformed compute address: {0}")]
|
||||
BadComputeAddress(Box<str>),
|
||||
|
||||
#[error(transparent)]
|
||||
ApiError(ApiError),
|
||||
}
|
||||
|
||||
// This allows more useful interactions than `#[from]`.
|
||||
impl<E: Into<ApiError>> From<E> for WakeComputeError {
|
||||
fn from(e: E) -> Self {
|
||||
Self::ApiError(e.into())
|
||||
}
|
||||
}
|
||||
|
||||
impl UserFacingError for WakeComputeError {
|
||||
fn to_string_client(&self) -> String {
|
||||
use WakeComputeError::*;
|
||||
match self {
|
||||
// We shouldn't show user the address even if it's broken.
|
||||
// Besides, user is unlikely to care about this detail.
|
||||
BadComputeAddress(_) => REQUEST_FAILED.to_owned(),
|
||||
// However, API might return a meaningful error.
|
||||
ApiError(e) => e.to_string_client(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Auth secret which is managed by the cloud.
|
||||
pub enum AuthInfo {
|
||||
/// Md5 hash of user's password.
|
||||
Md5([u8; 16]),
|
||||
|
||||
/// [SCRAM](crate::scram) authentication info.
|
||||
Scram(scram::ServerSecret),
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub(super) struct Api<'a> {
|
||||
endpoint: &'a http::Endpoint,
|
||||
extra: &'a ConsoleReqExtra<'a>,
|
||||
creds: &'a ClientCredentials<'a>,
|
||||
}
|
||||
|
||||
impl<'a> AsRef<ClientCredentials<'a>> for Api<'a> {
|
||||
fn as_ref(&self) -> &ClientCredentials<'a> {
|
||||
self.creds
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> Api<'a> {
|
||||
/// Construct an API object containing the auth parameters.
|
||||
pub(super) fn new(
|
||||
endpoint: &'a http::Endpoint,
|
||||
extra: &'a ConsoleReqExtra<'a>,
|
||||
creds: &'a ClientCredentials,
|
||||
) -> Self {
|
||||
Self {
|
||||
endpoint,
|
||||
extra,
|
||||
creds,
|
||||
}
|
||||
}
|
||||
|
||||
/// Authenticate the existing user or throw an error.
|
||||
pub(super) async fn handle_user(
|
||||
&'a self,
|
||||
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin + Send>,
|
||||
) -> auth::Result<AuthSuccess<NodeInfo>> {
|
||||
handle_user(client, self, Self::get_auth_info, Self::wake_compute).await
|
||||
}
|
||||
}
|
||||
|
||||
impl Api<'_> {
|
||||
async fn get_auth_info(&self) -> Result<Option<AuthInfo>, GetAuthInfoError> {
|
||||
let request_id = uuid::Uuid::new_v4().to_string();
|
||||
async {
|
||||
let request = self
|
||||
.endpoint
|
||||
.get("proxy_get_role_secret")
|
||||
.header("X-Request-ID", &request_id)
|
||||
.query(&[("session_id", self.extra.session_id)])
|
||||
.query(&[
|
||||
("application_name", self.extra.application_name),
|
||||
("project", Some(self.creds.project().expect("impossible"))),
|
||||
("role", Some(self.creds.user)),
|
||||
])
|
||||
.build()?;
|
||||
|
||||
info!(url = request.url().as_str(), "sending http request");
|
||||
let response = self.endpoint.execute(request).await?;
|
||||
let body = match parse_body::<GetRoleSecret>(response).await {
|
||||
Ok(body) => body,
|
||||
// Error 404 is special: it's ok not to have a secret.
|
||||
Err(e) => match e.http_status_code() {
|
||||
Some(HttpStatusCode::NOT_FOUND) => return Ok(None),
|
||||
_otherwise => return Err(e.into()),
|
||||
},
|
||||
};
|
||||
|
||||
let secret = scram::ServerSecret::parse(&body.role_secret)
|
||||
.map(AuthInfo::Scram)
|
||||
.ok_or(GetAuthInfoError::BadSecret)?;
|
||||
|
||||
Ok(Some(secret))
|
||||
}
|
||||
.map_err(crate::error::log_error)
|
||||
.instrument(info_span!("get_auth_info", id = request_id))
|
||||
.await
|
||||
}
|
||||
|
||||
/// Wake up the compute node and return the corresponding connection info.
|
||||
pub async fn wake_compute(&self) -> Result<NodeInfo, WakeComputeError> {
|
||||
let request_id = uuid::Uuid::new_v4().to_string();
|
||||
async {
|
||||
let request = self
|
||||
.endpoint
|
||||
.get("proxy_wake_compute")
|
||||
.header("X-Request-ID", &request_id)
|
||||
.query(&[("session_id", self.extra.session_id)])
|
||||
.query(&[
|
||||
("application_name", self.extra.application_name),
|
||||
("project", Some(self.creds.project().expect("impossible"))),
|
||||
])
|
||||
.build()?;
|
||||
|
||||
info!(url = request.url().as_str(), "sending http request");
|
||||
let response = self.endpoint.execute(request).await?;
|
||||
let body = parse_body::<WakeCompute>(response).await?;
|
||||
|
||||
// Unfortunately, ownership won't let us use `Option::ok_or` here.
|
||||
let (host, port) = match parse_host_port(&body.address) {
|
||||
None => return Err(WakeComputeError::BadComputeAddress(body.address)),
|
||||
Some(x) => x,
|
||||
};
|
||||
|
||||
let mut config = compute::ConnCfg::new();
|
||||
config
|
||||
.host(host)
|
||||
.port(port)
|
||||
.dbname(self.creds.dbname)
|
||||
.user(self.creds.user);
|
||||
|
||||
Ok(NodeInfo {
|
||||
config,
|
||||
aux: body.aux,
|
||||
})
|
||||
}
|
||||
.map_err(crate::error::log_error)
|
||||
.instrument(info_span!("wake_compute", id = request_id))
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
/// Common logic for user handling in API V2.
|
||||
/// We reuse this for a mock API implementation in [`super::postgres`].
|
||||
pub(super) async fn handle_user<'a, Endpoint, GetAuthInfo, WakeCompute>(
|
||||
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
|
||||
endpoint: &'a Endpoint,
|
||||
get_auth_info: impl FnOnce(&'a Endpoint) -> GetAuthInfo,
|
||||
wake_compute: impl FnOnce(&'a Endpoint) -> WakeCompute,
|
||||
) -> auth::Result<AuthSuccess<NodeInfo>>
|
||||
where
|
||||
Endpoint: AsRef<ClientCredentials<'a>>,
|
||||
GetAuthInfo: Future<Output = Result<Option<AuthInfo>, GetAuthInfoError>>,
|
||||
WakeCompute: Future<Output = Result<NodeInfo, WakeComputeError>>,
|
||||
{
|
||||
let creds = endpoint.as_ref();
|
||||
|
||||
info!("fetching user's authentication info");
|
||||
let info = get_auth_info(endpoint).await?.unwrap_or_else(|| {
|
||||
// If we don't have an authentication secret, we mock one to
|
||||
// prevent malicious probing (possible due to missing protocol steps).
|
||||
// This mocked secret will never lead to successful authentication.
|
||||
info!("authentication info not found, mocking it");
|
||||
AuthInfo::Scram(scram::ServerSecret::mock(creds.user, rand::random()))
|
||||
});
|
||||
|
||||
let flow = AuthFlow::new(client);
|
||||
let scram_keys = match info {
|
||||
AuthInfo::Md5(_) => {
|
||||
info!("auth endpoint chooses MD5");
|
||||
return Err(auth::AuthError::bad_auth_method("MD5"));
|
||||
}
|
||||
AuthInfo::Scram(secret) => {
|
||||
info!("auth endpoint chooses SCRAM");
|
||||
let scram = auth::Scram(&secret);
|
||||
let client_key = match flow.begin(scram).await?.authenticate().await? {
|
||||
sasl::Outcome::Success(key) => key,
|
||||
sasl::Outcome::Failure(reason) => {
|
||||
info!("auth backend failed with an error: {reason}");
|
||||
return Err(auth::AuthError::auth_failed(creds.user));
|
||||
}
|
||||
};
|
||||
|
||||
Some(compute::ScramKeys {
|
||||
client_key: client_key.as_bytes(),
|
||||
server_key: secret.server_key.as_bytes(),
|
||||
})
|
||||
}
|
||||
};
|
||||
|
||||
let mut node = wake_compute(endpoint).await?;
|
||||
if let Some(keys) = scram_keys {
|
||||
use tokio_postgres::config::AuthKeys;
|
||||
node.config.auth_keys(AuthKeys::ScramSha256(keys));
|
||||
}
|
||||
|
||||
Ok(AuthSuccess {
|
||||
reported_auth_ok: false,
|
||||
value: node,
|
||||
})
|
||||
}
|
||||
|
||||
/// Parse http response body, taking status code into account.
|
||||
async fn parse_body<T: for<'a> serde::Deserialize<'a>>(
|
||||
response: reqwest::Response,
|
||||
) -> Result<T, ApiError> {
|
||||
let status = response.status();
|
||||
if status.is_success() {
|
||||
// We shouldn't log raw body because it may contain secrets.
|
||||
info!("request succeeded, processing the body");
|
||||
return Ok(response.json().await?);
|
||||
}
|
||||
|
||||
// Don't throw an error here because it's not as important
|
||||
// as the fact that the request itself has failed.
|
||||
let body = response.json().await.unwrap_or_else(|e| {
|
||||
warn!("failed to parse error body: {e}");
|
||||
ConsoleError {
|
||||
error: "reason unclear (malformed error message)".into(),
|
||||
}
|
||||
});
|
||||
|
||||
let text = body.error;
|
||||
error!("console responded with an error ({status}): {text}");
|
||||
Err(ApiError::Console { status, text })
|
||||
}
|
||||
|
||||
fn parse_host_port(input: &str) -> Option<(&str, u16)> {
|
||||
let (host, port) = input.split_once(':')?;
|
||||
Some((host, port.parse().ok()?))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_parse_host_port() {
|
||||
let (host, port) = parse_host_port("127.0.0.1:5432").expect("failed to parse");
|
||||
assert_eq!(host, "127.0.0.1");
|
||||
assert_eq!(port, 5432);
|
||||
}
|
||||
}
|
||||
@@ -1,5 +1,11 @@
|
||||
use super::{AuthSuccess, NodeInfo};
|
||||
use crate::{auth, compute, error::UserFacingError, stream::PqStream, waiters};
|
||||
use super::AuthSuccess;
|
||||
use crate::{
|
||||
auth, compute,
|
||||
console::{self, provider::NodeInfo},
|
||||
error::UserFacingError,
|
||||
stream::PqStream,
|
||||
waiters,
|
||||
};
|
||||
use pq_proto::BeMessage as Be;
|
||||
use thiserror::Error;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
@@ -47,7 +53,7 @@ pub fn new_psql_session_id() -> String {
|
||||
hex::encode(rand::random::<[u8; 8]>())
|
||||
}
|
||||
|
||||
pub async fn handle_user(
|
||||
pub(super) async fn handle_user(
|
||||
link_uri: &reqwest::Url,
|
||||
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
|
||||
) -> auth::Result<AuthSuccess<NodeInfo>> {
|
||||
@@ -55,7 +61,7 @@ pub async fn handle_user(
|
||||
let span = info_span!("link", psql_session_id = &psql_session_id);
|
||||
let greeting = hello_message(link_uri, &psql_session_id);
|
||||
|
||||
let db_info = super::with_waiter(psql_session_id, |waiter| async {
|
||||
let db_info = console::mgmt::with_waiter(psql_session_id, |waiter| async {
|
||||
// Give user a URL to spawn a new database.
|
||||
info!(parent: &span, "sending the auth URL to the user");
|
||||
client
|
||||
@@ -80,14 +86,14 @@ pub async fn handle_user(
|
||||
.user(&db_info.user);
|
||||
|
||||
if let Some(password) = db_info.password {
|
||||
config.password(password);
|
||||
config.password(password.as_ref());
|
||||
}
|
||||
|
||||
Ok(AuthSuccess {
|
||||
reported_auth_ok: true,
|
||||
value: NodeInfo {
|
||||
config,
|
||||
aux: db_info.aux,
|
||||
aux: db_info.aux.into(),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
@@ -33,6 +33,7 @@ impl UserFacingError for ClientCredsParseError {}
|
||||
pub struct ClientCredentials<'a> {
|
||||
pub user: &'a str,
|
||||
pub dbname: &'a str,
|
||||
// TODO: this is a severe misnomer! We should think of a new name ASAP.
|
||||
pub project: Option<Cow<'a, str>>,
|
||||
/// If `True`, we'll use the old cleartext password flow. This is used for
|
||||
/// websocket connections, which want to minimize the number of round trips.
|
||||
|
||||
304
proxy/src/cache.rs
Normal file
304
proxy/src/cache.rs
Normal file
@@ -0,0 +1,304 @@
|
||||
use std::{
|
||||
borrow::Borrow,
|
||||
hash::Hash,
|
||||
ops::{Deref, DerefMut},
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
use tracing::debug;
|
||||
|
||||
// This seems to make more sense than `lru` or `cached`:
|
||||
//
|
||||
// * `near/nearcore` ditched `cached` in favor of `lru`
|
||||
// (https://github.com/near/nearcore/issues?q=is%3Aissue+lru+is%3Aclosed).
|
||||
//
|
||||
// * `lru` methods use an obscure `KeyRef` type in their contraints (which is deliberately excluded from docs).
|
||||
// This severely hinders its usage both in terms of creating wrappers and supported key types.
|
||||
//
|
||||
// On the other hand, `hashlink` has good download stats and appears to be maintained.
|
||||
use hashlink::{linked_hash_map::RawEntryMut, LruCache};
|
||||
|
||||
/// A generic trait which exposes types of cache's key and value,
|
||||
/// as well as the notion of cache entry invalidation.
|
||||
/// This is useful for [`timed_lru::Cached`].
|
||||
pub trait Cache {
|
||||
/// Entry's key.
|
||||
type Key;
|
||||
|
||||
/// Entry's value.
|
||||
type Value;
|
||||
|
||||
/// Used for entry invalidation.
|
||||
type LookupInfo<Key>;
|
||||
|
||||
/// Invalidate an entry using a lookup info.
|
||||
/// We don't have an empty default impl because it's error-prone.
|
||||
fn invalidate(&self, _: &Self::LookupInfo<Self::Key>);
|
||||
}
|
||||
|
||||
impl<C: Cache> Cache for &C {
|
||||
type Key = C::Key;
|
||||
type Value = C::Value;
|
||||
type LookupInfo<Key> = C::LookupInfo<Key>;
|
||||
|
||||
fn invalidate(&self, info: &Self::LookupInfo<Self::Key>) {
|
||||
C::invalidate(self, info)
|
||||
}
|
||||
}
|
||||
|
||||
pub use timed_lru::TimedLru;
|
||||
pub mod timed_lru {
|
||||
use super::*;
|
||||
|
||||
/// An implementation of timed LRU cache with fixed capacity.
|
||||
/// Key properties:
|
||||
///
|
||||
/// * Whenever a new entry is inserted, the least recently accessed one is evicted.
|
||||
/// The cache also keeps track of entry's insertion time (`created_at`) and TTL (`expires_at`).
|
||||
///
|
||||
/// * When the entry is about to be retrieved, we check its expiration timestamp.
|
||||
/// If the entry has expired, we remove it from the cache; Otherwise we bump the
|
||||
/// expiration timestamp (e.g. +5mins) and change its place in LRU list to prolong
|
||||
/// its existence.
|
||||
///
|
||||
/// * There's an API for immediate invalidation (removal) of a cache entry;
|
||||
/// It's useful in case we know for sure that the entry is no longer correct.
|
||||
/// See [`timed_lru::LookupInfo`] & [`timed_lru::Cached`] for more information.
|
||||
///
|
||||
/// * Expired entries are kept in the cache, until they are evicted by the LRU policy,
|
||||
/// or by a successful lookup (i.e. the entry hasn't expired yet).
|
||||
/// There is no background job to reap the expired records.
|
||||
///
|
||||
/// * It's possible for an entry that has not yet expired entry to be evicted
|
||||
/// before expired items. That's a bit wasteful, but probably fine in practice.
|
||||
pub struct TimedLru<K, V> {
|
||||
/// Cache's name for tracing.
|
||||
name: &'static str,
|
||||
|
||||
/// The underlying cache implementation.
|
||||
cache: parking_lot::Mutex<LruCache<K, Entry<V>>>,
|
||||
|
||||
/// Default time-to-live of a single entry.
|
||||
ttl: Duration,
|
||||
}
|
||||
|
||||
impl<K: Hash + Eq, V> Cache for TimedLru<K, V> {
|
||||
type Key = K;
|
||||
type Value = V;
|
||||
type LookupInfo<Key> = LookupInfo<Key>;
|
||||
|
||||
fn invalidate(&self, info: &Self::LookupInfo<K>) {
|
||||
self.invalidate_raw(info)
|
||||
}
|
||||
}
|
||||
|
||||
struct Entry<T> {
|
||||
created_at: Instant,
|
||||
expires_at: Instant,
|
||||
value: T,
|
||||
}
|
||||
|
||||
impl<K: Hash + Eq, V> TimedLru<K, V> {
|
||||
/// Construct a new LRU cache with timed entries.
|
||||
pub fn new(name: &'static str, capacity: usize, ttl: Duration) -> Self {
|
||||
Self {
|
||||
name,
|
||||
cache: LruCache::new(capacity).into(),
|
||||
ttl,
|
||||
}
|
||||
}
|
||||
|
||||
/// Drop an entry from the cache if it's outdated.
|
||||
#[tracing::instrument(level = "debug", fields(cache = self.name), skip_all)]
|
||||
fn invalidate_raw(&self, info: &LookupInfo<K>) {
|
||||
let now = Instant::now();
|
||||
|
||||
// Do costly things before taking the lock.
|
||||
let mut cache = self.cache.lock();
|
||||
let raw_entry = match cache.raw_entry_mut().from_key(&info.key) {
|
||||
RawEntryMut::Vacant(_) => return,
|
||||
RawEntryMut::Occupied(x) => x,
|
||||
};
|
||||
|
||||
// Remove the entry if it was created prior to lookup timestamp.
|
||||
let entry = raw_entry.get();
|
||||
let (created_at, expires_at) = (entry.created_at, entry.expires_at);
|
||||
let should_remove = created_at <= info.created_at || expires_at <= now;
|
||||
|
||||
if should_remove {
|
||||
raw_entry.remove();
|
||||
}
|
||||
|
||||
drop(cache); // drop lock before logging
|
||||
debug!(
|
||||
created_at = format_args!("{created_at:?}"),
|
||||
expires_at = format_args!("{expires_at:?}"),
|
||||
entry_removed = should_remove,
|
||||
"processed a cache entry invalidation event"
|
||||
);
|
||||
}
|
||||
|
||||
/// Try retrieving an entry by its key, then execute `extract` if it exists.
|
||||
#[tracing::instrument(level = "debug", fields(cache = self.name), skip_all)]
|
||||
fn get_raw<Q, R>(&self, key: &Q, extract: impl FnOnce(&K, &Entry<V>) -> R) -> Option<R>
|
||||
where
|
||||
K: Borrow<Q>,
|
||||
Q: Hash + Eq + ?Sized,
|
||||
{
|
||||
let now = Instant::now();
|
||||
let deadline = now.checked_add(self.ttl).expect("time overflow");
|
||||
|
||||
// Do costly things before taking the lock.
|
||||
let mut cache = self.cache.lock();
|
||||
let mut raw_entry = match cache.raw_entry_mut().from_key(key) {
|
||||
RawEntryMut::Vacant(_) => return None,
|
||||
RawEntryMut::Occupied(x) => x,
|
||||
};
|
||||
|
||||
// Immeditely drop the entry if it has expired.
|
||||
let entry = raw_entry.get();
|
||||
if entry.expires_at <= now {
|
||||
raw_entry.remove();
|
||||
return None;
|
||||
}
|
||||
|
||||
let value = extract(raw_entry.key(), entry);
|
||||
let (created_at, expires_at) = (entry.created_at, entry.expires_at);
|
||||
|
||||
// Update the deadline and the entry's position in the LRU list.
|
||||
raw_entry.get_mut().expires_at = deadline;
|
||||
raw_entry.to_back();
|
||||
|
||||
drop(cache); // drop lock before logging
|
||||
debug!(
|
||||
created_at = format_args!("{created_at:?}"),
|
||||
old_expires_at = format_args!("{expires_at:?}"),
|
||||
new_expires_at = format_args!("{deadline:?}"),
|
||||
"accessed a cache entry"
|
||||
);
|
||||
|
||||
Some(value)
|
||||
}
|
||||
|
||||
/// Insert an entry to the cache. If an entry with the same key already
|
||||
/// existed, return the previous value and its creation timestamp.
|
||||
#[tracing::instrument(level = "debug", fields(cache = self.name), skip_all)]
|
||||
fn insert_raw(&self, key: K, value: V) -> (Instant, Option<V>) {
|
||||
let created_at = Instant::now();
|
||||
let expires_at = created_at.checked_add(self.ttl).expect("time overflow");
|
||||
|
||||
let entry = Entry {
|
||||
created_at,
|
||||
expires_at,
|
||||
value,
|
||||
};
|
||||
|
||||
// Do costly things before taking the lock.
|
||||
let old = self
|
||||
.cache
|
||||
.lock()
|
||||
.insert(key, entry)
|
||||
.map(|entry| entry.value);
|
||||
|
||||
debug!(
|
||||
created_at = format_args!("{created_at:?}"),
|
||||
expires_at = format_args!("{expires_at:?}"),
|
||||
replaced = old.is_some(),
|
||||
"created a cache entry"
|
||||
);
|
||||
|
||||
(created_at, old)
|
||||
}
|
||||
}
|
||||
|
||||
impl<K: Hash + Eq + Clone, V: Clone> TimedLru<K, V> {
|
||||
pub fn insert(&self, key: K, value: V) -> (Option<V>, Cached<&Self>) {
|
||||
let (created_at, old) = self.insert_raw(key.clone(), value.clone());
|
||||
|
||||
let cached = Cached {
|
||||
token: Some((self, LookupInfo { created_at, key })),
|
||||
value,
|
||||
};
|
||||
|
||||
(old, cached)
|
||||
}
|
||||
}
|
||||
|
||||
impl<K: Hash + Eq, V: Clone> TimedLru<K, V> {
|
||||
/// Retrieve a cached entry in convenient wrapper.
|
||||
pub fn get<Q>(&self, key: &Q) -> Option<timed_lru::Cached<&Self>>
|
||||
where
|
||||
K: Borrow<Q> + Clone,
|
||||
Q: Hash + Eq + ?Sized,
|
||||
{
|
||||
self.get_raw(key, |key, entry| {
|
||||
let info = LookupInfo {
|
||||
created_at: entry.created_at,
|
||||
key: key.clone(),
|
||||
};
|
||||
|
||||
Cached {
|
||||
token: Some((self, info)),
|
||||
value: entry.value.clone(),
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Lookup information for key invalidation.
|
||||
pub struct LookupInfo<K> {
|
||||
/// Time of creation of a cache [`Entry`].
|
||||
/// We use this during invalidation lookups to prevent eviction of a newer
|
||||
/// entry sharing the same key (it might've been inserted by a different
|
||||
/// task after we got the entry we're trying to invalidate now).
|
||||
created_at: Instant,
|
||||
|
||||
/// Search by this key.
|
||||
key: K,
|
||||
}
|
||||
|
||||
/// Wrapper for convenient entry invalidation.
|
||||
pub struct Cached<C: Cache> {
|
||||
/// Cache + lookup info.
|
||||
token: Option<(C, C::LookupInfo<C::Key>)>,
|
||||
|
||||
/// The value itself.
|
||||
pub value: C::Value,
|
||||
}
|
||||
|
||||
impl<C: Cache> Cached<C> {
|
||||
/// Place any entry into this wrapper; invalidation will be a no-op.
|
||||
/// Unfortunately, rust doesn't let us implement [`From`] or [`Into`].
|
||||
pub fn new_uncached(value: impl Into<C::Value>) -> Self {
|
||||
Self {
|
||||
token: None,
|
||||
value: value.into(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Drop this entry from a cache if it's still there.
|
||||
pub fn invalidate(&self) {
|
||||
if let Some((cache, info)) = &self.token {
|
||||
cache.invalidate(info);
|
||||
}
|
||||
}
|
||||
|
||||
/// Tell if this entry is actually cached.
|
||||
pub fn cached(&self) -> bool {
|
||||
self.token.is_some()
|
||||
}
|
||||
}
|
||||
|
||||
impl<C: Cache> Deref for Cached<C> {
|
||||
type Target = C::Value;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.value
|
||||
}
|
||||
}
|
||||
|
||||
impl<C: Cache> DerefMut for Cached<C> {
|
||||
fn deref_mut(&mut self) -> &mut Self::Target {
|
||||
&mut self.value
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,5 @@
|
||||
use anyhow::{anyhow, Context};
|
||||
use hashbrown::HashMap;
|
||||
use parking_lot::Mutex;
|
||||
use pq_proto::CancelKeyData;
|
||||
use std::net::SocketAddr;
|
||||
use tokio::net::TcpStream;
|
||||
@@ -9,14 +8,15 @@ use tracing::info;
|
||||
|
||||
/// Enables serving `CancelRequest`s.
|
||||
#[derive(Default)]
|
||||
pub struct CancelMap(Mutex<HashMap<CancelKeyData, Option<CancelClosure>>>);
|
||||
pub struct CancelMap(parking_lot::RwLock<HashMap<CancelKeyData, Option<CancelClosure>>>);
|
||||
|
||||
impl CancelMap {
|
||||
/// Cancel a running query for the corresponding connection.
|
||||
pub async fn cancel_session(&self, key: CancelKeyData) -> anyhow::Result<()> {
|
||||
// NB: we should immediately release the lock after cloning the token.
|
||||
let cancel_closure = self
|
||||
.0
|
||||
.lock()
|
||||
.read()
|
||||
.get(&key)
|
||||
.and_then(|x| x.clone())
|
||||
.with_context(|| format!("query cancellation key not found: {key}"))?;
|
||||
@@ -41,14 +41,14 @@ impl CancelMap {
|
||||
// Random key collisions are unlikely to happen here, but they're still possible,
|
||||
// which is why we have to take care not to rewrite an existing key.
|
||||
self.0
|
||||
.lock()
|
||||
.write()
|
||||
.try_insert(key, None)
|
||||
.map_err(|_| anyhow!("query cancellation key already exists: {key}"))?;
|
||||
|
||||
// This will guarantee that the session gets dropped
|
||||
// as soon as the future is finished.
|
||||
scopeguard::defer! {
|
||||
self.0.lock().remove(&key);
|
||||
self.0.write().remove(&key);
|
||||
info!("dropped query cancellation key {key}");
|
||||
}
|
||||
|
||||
@@ -59,12 +59,12 @@ impl CancelMap {
|
||||
|
||||
#[cfg(test)]
|
||||
fn contains(&self, session: &Session) -> bool {
|
||||
self.0.lock().contains_key(&session.key)
|
||||
self.0.read().contains_key(&session.key)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
fn is_empty(&self) -> bool {
|
||||
self.0.lock().is_empty()
|
||||
self.0.read().is_empty()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -115,7 +115,7 @@ impl Session<'_> {
|
||||
info!("enabling query cancellation for this session");
|
||||
self.cancel_map
|
||||
.0
|
||||
.lock()
|
||||
.write()
|
||||
.insert(self.key, Some(cancel_closure));
|
||||
|
||||
self.key
|
||||
|
||||
@@ -42,14 +42,65 @@ pub type ScramKeys = tokio_postgres::config::ScramKeys<32>;
|
||||
/// A config for establishing a connection to compute node.
|
||||
/// Eventually, `tokio_postgres` will be replaced with something better.
|
||||
/// Newtype allows us to implement methods on top of it.
|
||||
#[derive(Clone)]
|
||||
#[repr(transparent)]
|
||||
pub struct ConnCfg(Box<tokio_postgres::Config>);
|
||||
|
||||
/// Creation and initialization routines.
|
||||
impl ConnCfg {
|
||||
/// Construct a new connection config.
|
||||
pub fn new() -> Self {
|
||||
Self(Default::default())
|
||||
}
|
||||
|
||||
/// Reuse password or auth keys from the other config.
|
||||
pub fn reuse_password(&mut self, other: &Self) {
|
||||
if let Some(password) = other.get_password() {
|
||||
self.password(password);
|
||||
}
|
||||
|
||||
if let Some(keys) = other.get_auth_keys() {
|
||||
self.auth_keys(keys);
|
||||
}
|
||||
}
|
||||
|
||||
/// Apply startup message params to the connection config.
|
||||
pub fn set_startup_params(&mut self, params: &StartupMessageParams) {
|
||||
if let Some(options) = params.options_raw() {
|
||||
// We must drop all proxy-specific parameters.
|
||||
#[allow(unstable_name_collisions)]
|
||||
let options: String = options
|
||||
.filter(|opt| !opt.starts_with("project="))
|
||||
.intersperse(" ") // TODO: use impl from std once it's stabilized
|
||||
.collect();
|
||||
|
||||
self.options(&options);
|
||||
}
|
||||
|
||||
if let Some(app_name) = params.get("application_name") {
|
||||
self.application_name(app_name);
|
||||
}
|
||||
|
||||
// TODO: This is especially ugly...
|
||||
if let Some(replication) = params.get("replication") {
|
||||
use tokio_postgres::config::ReplicationMode;
|
||||
match replication {
|
||||
"true" | "on" | "yes" | "1" => {
|
||||
self.replication_mode(ReplicationMode::Physical);
|
||||
}
|
||||
"database" => {
|
||||
self.replication_mode(ReplicationMode::Logical);
|
||||
}
|
||||
_other => {}
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: extend the list of the forwarded startup parameters.
|
||||
// Currently, tokio-postgres doesn't allow us to pass
|
||||
// arbitrary parameters, but the ones above are a good start.
|
||||
//
|
||||
// This and the reverse params problem can be better addressed
|
||||
// in a bespoke connection machinery (a new library for that sake).
|
||||
}
|
||||
}
|
||||
|
||||
impl std::ops::Deref for ConnCfg {
|
||||
@@ -132,50 +183,13 @@ pub struct PostgresConnection {
|
||||
pub stream: TcpStream,
|
||||
/// PostgreSQL connection parameters.
|
||||
pub params: std::collections::HashMap<String, String>,
|
||||
/// Query cancellation token.
|
||||
pub cancel_closure: CancelClosure,
|
||||
}
|
||||
|
||||
impl ConnCfg {
|
||||
/// Connect to a corresponding compute node.
|
||||
pub async fn connect(
|
||||
mut self,
|
||||
params: &StartupMessageParams,
|
||||
) -> Result<(PostgresConnection, CancelClosure), ConnectionError> {
|
||||
if let Some(options) = params.options_raw() {
|
||||
// We must drop all proxy-specific parameters.
|
||||
#[allow(unstable_name_collisions)]
|
||||
let options: String = options
|
||||
.filter(|opt| !opt.starts_with("project="))
|
||||
.intersperse(" ") // TODO: use impl from std once it's stabilized
|
||||
.collect();
|
||||
|
||||
self.0.options(&options);
|
||||
}
|
||||
|
||||
if let Some(app_name) = params.get("application_name") {
|
||||
self.0.application_name(app_name);
|
||||
}
|
||||
|
||||
// TODO: This is especially ugly...
|
||||
if let Some(replication) = params.get("replication") {
|
||||
use tokio_postgres::config::ReplicationMode;
|
||||
match replication {
|
||||
"true" | "on" | "yes" | "1" => {
|
||||
self.0.replication_mode(ReplicationMode::Physical);
|
||||
}
|
||||
"database" => {
|
||||
self.0.replication_mode(ReplicationMode::Logical);
|
||||
}
|
||||
_other => {}
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: extend the list of the forwarded startup parameters.
|
||||
// Currently, tokio-postgres doesn't allow us to pass
|
||||
// arbitrary parameters, but the ones above are a good start.
|
||||
//
|
||||
// This and the reverse params problem can be better addressed
|
||||
// in a bespoke connection machinery (a new library for that sake).
|
||||
|
||||
pub async fn connect(&self) -> Result<PostgresConnection, ConnectionError> {
|
||||
// TODO: establish a secure connection to the DB.
|
||||
let (socket_addr, mut stream) = self.connect_raw().await?;
|
||||
let (client, connection) = self.0.connect_raw(&mut stream, NoTls).await?;
|
||||
@@ -189,8 +203,13 @@ impl ConnCfg {
|
||||
// NB: CancelToken is supposed to hold socket_addr, but we use connect_raw.
|
||||
// Yet another reason to rework the connection establishing code.
|
||||
let cancel_closure = CancelClosure::new(socket_addr, client.cancel_token());
|
||||
let db = PostgresConnection { stream, params };
|
||||
|
||||
Ok((db, cancel_closure))
|
||||
let connection = PostgresConnection {
|
||||
stream,
|
||||
params,
|
||||
cancel_closure,
|
||||
};
|
||||
|
||||
Ok(connection)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,16 +1,16 @@
|
||||
use crate::auth;
|
||||
use anyhow::{ensure, Context};
|
||||
use std::sync::Arc;
|
||||
use anyhow::{bail, ensure, Context};
|
||||
use std::{str::FromStr, sync::Arc, time::Duration};
|
||||
|
||||
pub struct ProxyConfig {
|
||||
pub tls_config: Option<TlsConfig>,
|
||||
pub auth_backend: auth::BackendType<'static, ()>,
|
||||
pub metric_collection_config: Option<MetricCollectionConfig>,
|
||||
pub metric_collection: Option<MetricCollectionConfig>,
|
||||
}
|
||||
|
||||
pub struct MetricCollectionConfig {
|
||||
pub endpoint: reqwest::Url,
|
||||
pub interval: std::time::Duration,
|
||||
pub interval: Duration,
|
||||
}
|
||||
|
||||
pub struct TlsConfig {
|
||||
@@ -37,6 +37,7 @@ pub fn configure_tls(key_path: &str, cert_path: &str) -> anyhow::Result<TlsConfi
|
||||
|
||||
let cert_chain_bytes = std::fs::read(cert_path)
|
||||
.context(format!("Failed to read TLS cert file at '{cert_path}.'"))?;
|
||||
|
||||
let cert_chain = {
|
||||
rustls_pemfile::certs(&mut &cert_chain_bytes[..])
|
||||
.context(format!(
|
||||
@@ -73,3 +74,80 @@ pub fn configure_tls(key_path: &str, cert_path: &str) -> anyhow::Result<TlsConfi
|
||||
common_name,
|
||||
})
|
||||
}
|
||||
|
||||
/// Helper for cmdline cache options parsing.
|
||||
pub struct CacheOptions {
|
||||
/// Max number of entries.
|
||||
pub size: usize,
|
||||
/// Entry's time-to-live.
|
||||
pub ttl: Duration,
|
||||
}
|
||||
|
||||
impl CacheOptions {
|
||||
/// Default options for [`crate::auth::caches::NodeInfoCache`].
|
||||
pub const DEFAULT_OPTIONS_NODE_INFO: &str = "size=4000,ttl=5m";
|
||||
|
||||
/// Parse cache options passed via cmdline.
|
||||
/// Example: [`Self::DEFAULT_OPTIONS_NODE_INFO`].
|
||||
fn parse(options: &str) -> anyhow::Result<Self> {
|
||||
let mut size = None;
|
||||
let mut ttl = None;
|
||||
|
||||
for option in options.split(',') {
|
||||
let (key, value) = option
|
||||
.split_once('=')
|
||||
.with_context(|| format!("bad key-value pair: {option}"))?;
|
||||
|
||||
match key {
|
||||
"size" => size = Some(value.parse()?),
|
||||
"ttl" => ttl = Some(humantime::parse_duration(value)?),
|
||||
unknown => bail!("unknown key: {unknown}"),
|
||||
}
|
||||
}
|
||||
|
||||
// TTL doesn't matter if cache is always empty.
|
||||
if let Some(0) = size {
|
||||
ttl.get_or_insert(Duration::default());
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
size: size.context("missing `size`")?,
|
||||
ttl: ttl.context("missing `ttl`")?,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl FromStr for CacheOptions {
|
||||
type Err = anyhow::Error;
|
||||
|
||||
fn from_str(options: &str) -> Result<Self, Self::Err> {
|
||||
let error = || format!("failed to parse cache options '{options}'");
|
||||
Self::parse(options).with_context(error)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_parse_cache_options() -> anyhow::Result<()> {
|
||||
let CacheOptions { size, ttl } = "size=4096,ttl=5min".parse()?;
|
||||
assert_eq!(size, 4096);
|
||||
assert_eq!(ttl, Duration::from_secs(5 * 60));
|
||||
|
||||
let CacheOptions { size, ttl } = "ttl=4m,size=2".parse()?;
|
||||
assert_eq!(size, 2);
|
||||
assert_eq!(ttl, Duration::from_secs(4 * 60));
|
||||
|
||||
let CacheOptions { size, ttl } = "size=0,ttl=1s".parse()?;
|
||||
assert_eq!(size, 0);
|
||||
assert_eq!(ttl, Duration::from_secs(1));
|
||||
|
||||
let CacheOptions { size, ttl } = "size=0".parse()?;
|
||||
assert_eq!(size, 0);
|
||||
assert_eq!(ttl, Duration::default());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,3 +3,15 @@
|
||||
|
||||
/// Payloads used in the console's APIs.
|
||||
pub mod messages;
|
||||
|
||||
/// Wrappers for console APIs and their mocks.
|
||||
pub mod provider;
|
||||
pub use provider::{errors, Api, AuthInfo, CachedNodeInfo, ConsoleReqExtra, NodeInfo};
|
||||
|
||||
/// Various cache-related types.
|
||||
pub mod caches {
|
||||
pub use super::provider::{ApiCaches, NodeInfoCache};
|
||||
}
|
||||
|
||||
/// Console's management API.
|
||||
pub mod mgmt;
|
||||
|
||||
@@ -63,13 +63,13 @@ impl KickSession<'_> {
|
||||
/// Compute node connection params.
|
||||
#[derive(Deserialize)]
|
||||
pub struct DatabaseInfo {
|
||||
pub host: String,
|
||||
pub host: Box<str>,
|
||||
pub port: u16,
|
||||
pub dbname: String,
|
||||
pub user: String,
|
||||
pub dbname: Box<str>,
|
||||
pub user: Box<str>,
|
||||
/// Console always provides a password, but it might
|
||||
/// be inconvenient for debug with local PG instance.
|
||||
pub password: Option<String>,
|
||||
pub password: Option<Box<str>>,
|
||||
pub aux: MetricsAuxInfo,
|
||||
}
|
||||
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
use crate::{
|
||||
auth,
|
||||
console::messages::{DatabaseInfo, KickSession},
|
||||
waiters::{self, Waiter, Waiters},
|
||||
};
|
||||
use anyhow::Context;
|
||||
use once_cell::sync::Lazy;
|
||||
use pq_proto::{BeMessage, SINGLE_COL_ROWDESC};
|
||||
use std::{
|
||||
net::{TcpListener, TcpStream},
|
||||
@@ -14,6 +15,25 @@ use utils::{
|
||||
postgres_backend_async::QueryError,
|
||||
};
|
||||
|
||||
static CPLANE_WAITERS: Lazy<Waiters<ComputeReady>> = Lazy::new(Default::default);
|
||||
|
||||
/// Give caller an opportunity to wait for the cloud's reply.
|
||||
pub async fn with_waiter<R, T, E>(
|
||||
psql_session_id: impl Into<String>,
|
||||
action: impl FnOnce(Waiter<'static, ComputeReady>) -> R,
|
||||
) -> Result<T, E>
|
||||
where
|
||||
R: std::future::Future<Output = Result<T, E>>,
|
||||
E: From<waiters::RegisterError>,
|
||||
{
|
||||
let waiter = CPLANE_WAITERS.register(psql_session_id.into())?;
|
||||
action(waiter).await
|
||||
}
|
||||
|
||||
pub fn notify(psql_session_id: &str, msg: ComputeReady) -> Result<(), waiters::NotifyError> {
|
||||
CPLANE_WAITERS.notify(psql_session_id, msg)
|
||||
}
|
||||
|
||||
/// Console management API listener thread.
|
||||
/// It spawns console response handlers needed for the link auth.
|
||||
pub fn thread_main(listener: TcpListener) -> anyhow::Result<()> {
|
||||
@@ -76,7 +96,7 @@ fn try_process_query(pgb: &mut PostgresBackend, query: &str) -> Result<(), Query
|
||||
let _enter = span.enter();
|
||||
info!("got response: {:?}", resp.result);
|
||||
|
||||
match auth::backend::notify(resp.session_id, Ok(resp.result)) {
|
||||
match notify(resp.session_id, Ok(resp.result)) {
|
||||
Ok(()) => {
|
||||
pgb.write_message_noflush(&SINGLE_COL_ROWDESC)?
|
||||
.write_message_noflush(&BeMessage::DataRow(&[Some(b"ok")]))?
|
||||
194
proxy/src/console/provider.rs
Normal file
194
proxy/src/console/provider.rs
Normal file
@@ -0,0 +1,194 @@
|
||||
pub mod mock;
|
||||
pub mod neon;
|
||||
|
||||
use super::messages::MetricsAuxInfo;
|
||||
use crate::{
|
||||
auth::ClientCredentials,
|
||||
cache::{timed_lru, TimedLru},
|
||||
compute, scram,
|
||||
};
|
||||
use async_trait::async_trait;
|
||||
use std::sync::Arc;
|
||||
|
||||
pub mod errors {
|
||||
use crate::error::{io_error, UserFacingError};
|
||||
use reqwest::StatusCode as HttpStatusCode;
|
||||
use thiserror::Error;
|
||||
|
||||
/// A go-to error message which doesn't leak any detail.
|
||||
const REQUEST_FAILED: &str = "Console request failed";
|
||||
|
||||
/// Common console API error.
|
||||
#[derive(Debug, Error)]
|
||||
pub enum ApiError {
|
||||
/// Error returned by the console itself.
|
||||
#[error("{REQUEST_FAILED} with {}: {}", .status, .text)]
|
||||
Console {
|
||||
status: HttpStatusCode,
|
||||
text: Box<str>,
|
||||
},
|
||||
|
||||
/// Various IO errors like broken pipe or malformed payload.
|
||||
#[error("{REQUEST_FAILED}: {0}")]
|
||||
Transport(#[from] std::io::Error),
|
||||
}
|
||||
|
||||
impl ApiError {
|
||||
/// Returns HTTP status code if it's the reason for failure.
|
||||
pub fn http_status_code(&self) -> Option<HttpStatusCode> {
|
||||
use ApiError::*;
|
||||
match self {
|
||||
Console { status, .. } => Some(*status),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl UserFacingError for ApiError {
|
||||
fn to_string_client(&self) -> String {
|
||||
use ApiError::*;
|
||||
match self {
|
||||
// To minimize risks, only select errors are forwarded to users.
|
||||
// Ask @neondatabase/control-plane for review before adding more.
|
||||
Console { status, .. } => match *status {
|
||||
HttpStatusCode::NOT_FOUND => {
|
||||
// Status 404: failed to get a project-related resource.
|
||||
format!("{REQUEST_FAILED}: endpoint cannot be found")
|
||||
}
|
||||
HttpStatusCode::NOT_ACCEPTABLE => {
|
||||
// Status 406: endpoint is disabled (we don't allow connections).
|
||||
format!("{REQUEST_FAILED}: endpoint is disabled")
|
||||
}
|
||||
HttpStatusCode::LOCKED => {
|
||||
// Status 423: project might be in maintenance mode (or bad state).
|
||||
format!("{REQUEST_FAILED}: endpoint is temporary unavailable")
|
||||
}
|
||||
_ => REQUEST_FAILED.to_owned(),
|
||||
},
|
||||
_ => REQUEST_FAILED.to_owned(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Helps eliminate graceless `.map_err` calls without introducing another ctor.
|
||||
impl From<reqwest::Error> for ApiError {
|
||||
fn from(e: reqwest::Error) -> Self {
|
||||
io_error(e).into()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum GetAuthInfoError {
|
||||
// We shouldn't include the actual secret here.
|
||||
#[error("Console responded with a malformed auth secret")]
|
||||
BadSecret,
|
||||
|
||||
#[error(transparent)]
|
||||
ApiError(ApiError),
|
||||
}
|
||||
|
||||
// This allows more useful interactions than `#[from]`.
|
||||
impl<E: Into<ApiError>> From<E> for GetAuthInfoError {
|
||||
fn from(e: E) -> Self {
|
||||
Self::ApiError(e.into())
|
||||
}
|
||||
}
|
||||
|
||||
impl UserFacingError for GetAuthInfoError {
|
||||
fn to_string_client(&self) -> String {
|
||||
use GetAuthInfoError::*;
|
||||
match self {
|
||||
// We absolutely should not leak any secrets!
|
||||
BadSecret => REQUEST_FAILED.to_owned(),
|
||||
// However, API might return a meaningful error.
|
||||
ApiError(e) => e.to_string_client(),
|
||||
}
|
||||
}
|
||||
}
|
||||
#[derive(Debug, Error)]
|
||||
pub enum WakeComputeError {
|
||||
#[error("Console responded with a malformed compute address: {0}")]
|
||||
BadComputeAddress(Box<str>),
|
||||
|
||||
#[error(transparent)]
|
||||
ApiError(ApiError),
|
||||
}
|
||||
|
||||
// This allows more useful interactions than `#[from]`.
|
||||
impl<E: Into<ApiError>> From<E> for WakeComputeError {
|
||||
fn from(e: E) -> Self {
|
||||
Self::ApiError(e.into())
|
||||
}
|
||||
}
|
||||
|
||||
impl UserFacingError for WakeComputeError {
|
||||
fn to_string_client(&self) -> String {
|
||||
use WakeComputeError::*;
|
||||
match self {
|
||||
// We shouldn't show user the address even if it's broken.
|
||||
// Besides, user is unlikely to care about this detail.
|
||||
BadComputeAddress(_) => REQUEST_FAILED.to_owned(),
|
||||
// However, API might return a meaningful error.
|
||||
ApiError(e) => e.to_string_client(),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Extra query params we'd like to pass to the console.
|
||||
pub struct ConsoleReqExtra<'a> {
|
||||
/// A unique identifier for a connection.
|
||||
pub session_id: uuid::Uuid,
|
||||
/// Name of client application, if set.
|
||||
pub application_name: Option<&'a str>,
|
||||
}
|
||||
|
||||
/// Auth secret which is managed by the cloud.
|
||||
pub enum AuthInfo {
|
||||
/// Md5 hash of user's password.
|
||||
Md5([u8; 16]),
|
||||
|
||||
/// [SCRAM](crate::scram) authentication info.
|
||||
Scram(scram::ServerSecret),
|
||||
}
|
||||
|
||||
/// Info for establishing a connection to a compute node.
|
||||
/// This is what we get after auth succeeded, but not before!
|
||||
#[derive(Clone)]
|
||||
pub struct NodeInfo {
|
||||
/// Compute node connection params.
|
||||
/// It's sad that we have to clone this, but this will improve
|
||||
/// once we migrate to a bespoke connection logic.
|
||||
pub config: compute::ConnCfg,
|
||||
|
||||
/// Labels for proxy's metrics.
|
||||
pub aux: Arc<MetricsAuxInfo>,
|
||||
}
|
||||
|
||||
pub type NodeInfoCache = TimedLru<Arc<str>, NodeInfo>;
|
||||
pub type CachedNodeInfo = timed_lru::Cached<&'static NodeInfoCache>;
|
||||
|
||||
/// This will allocate per each call, but the http requests alone
|
||||
/// already require a few allocations, so it should be fine.
|
||||
#[async_trait]
|
||||
pub trait Api {
|
||||
/// Get the client's auth secret for authentication.
|
||||
async fn get_auth_info(
|
||||
&self,
|
||||
extra: &ConsoleReqExtra<'_>,
|
||||
creds: &ClientCredentials<'_>,
|
||||
) -> Result<Option<AuthInfo>, errors::GetAuthInfoError>;
|
||||
|
||||
/// Wake up the compute node and return the corresponding connection info.
|
||||
async fn wake_compute(
|
||||
&self,
|
||||
extra: &ConsoleReqExtra<'_>,
|
||||
creds: &ClientCredentials<'_>,
|
||||
) -> Result<CachedNodeInfo, errors::WakeComputeError>;
|
||||
}
|
||||
|
||||
/// Various caches for [`console`].
|
||||
pub struct ApiCaches {
|
||||
/// Cache for the `wake_compute` API method.
|
||||
pub node_info: NodeInfoCache,
|
||||
}
|
||||
@@ -1,21 +1,14 @@
|
||||
//! Local mock of Cloud API V2.
|
||||
//! Mock console backend which relies on a user-provided postgres instance.
|
||||
|
||||
use super::{
|
||||
console::{self, AuthInfo, GetAuthInfoError, WakeComputeError},
|
||||
AuthSuccess, NodeInfo,
|
||||
};
|
||||
use crate::{
|
||||
auth::{self, ClientCredentials},
|
||||
compute,
|
||||
error::io_error,
|
||||
scram,
|
||||
stream::PqStream,
|
||||
url::ApiUrl,
|
||||
errors::{ApiError, GetAuthInfoError, WakeComputeError},
|
||||
AuthInfo, CachedNodeInfo, ConsoleReqExtra, NodeInfo,
|
||||
};
|
||||
use crate::{auth::ClientCredentials, compute, error::io_error, scram, url::ApiUrl};
|
||||
use async_trait::async_trait;
|
||||
use futures::TryFutureExt;
|
||||
use thiserror::Error;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tracing::{info, info_span, warn, Instrument};
|
||||
use tracing::{error, info, info_span, warn, Instrument};
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
enum MockApiError {
|
||||
@@ -23,49 +16,36 @@ enum MockApiError {
|
||||
PasswordNotSet(tokio_postgres::Error),
|
||||
}
|
||||
|
||||
impl From<MockApiError> for console::ApiError {
|
||||
impl From<MockApiError> for ApiError {
|
||||
fn from(e: MockApiError) -> Self {
|
||||
io_error(e).into()
|
||||
}
|
||||
}
|
||||
|
||||
impl From<tokio_postgres::Error> for console::ApiError {
|
||||
impl From<tokio_postgres::Error> for ApiError {
|
||||
fn from(e: tokio_postgres::Error) -> Self {
|
||||
io_error(e).into()
|
||||
}
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub(super) struct Api<'a> {
|
||||
endpoint: &'a ApiUrl,
|
||||
creds: &'a ClientCredentials<'a>,
|
||||
#[derive(Clone)]
|
||||
pub struct Api {
|
||||
endpoint: ApiUrl,
|
||||
}
|
||||
|
||||
impl<'a> AsRef<ClientCredentials<'a>> for Api<'a> {
|
||||
fn as_ref(&self) -> &ClientCredentials<'a> {
|
||||
self.creds
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> Api<'a> {
|
||||
/// Construct an API object containing the auth parameters.
|
||||
pub(super) fn new(endpoint: &'a ApiUrl, creds: &'a ClientCredentials) -> Self {
|
||||
Self { endpoint, creds }
|
||||
impl Api {
|
||||
pub fn new(endpoint: ApiUrl) -> Self {
|
||||
Self { endpoint }
|
||||
}
|
||||
|
||||
/// Authenticate the existing user or throw an error.
|
||||
pub(super) async fn handle_user(
|
||||
&'a self,
|
||||
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin + Send>,
|
||||
) -> auth::Result<AuthSuccess<NodeInfo>> {
|
||||
// We reuse user handling logic from a production module.
|
||||
console::handle_user(client, self, Self::get_auth_info, Self::wake_compute).await
|
||||
pub fn url(&self) -> &str {
|
||||
self.endpoint.as_str()
|
||||
}
|
||||
}
|
||||
|
||||
impl Api<'_> {
|
||||
/// This implementation fetches the auth info from a local postgres instance.
|
||||
async fn get_auth_info(&self) -> Result<Option<AuthInfo>, GetAuthInfoError> {
|
||||
async fn do_get_auth_info(
|
||||
&self,
|
||||
creds: &ClientCredentials<'_>,
|
||||
) -> Result<Option<AuthInfo>, GetAuthInfoError> {
|
||||
async {
|
||||
// Perhaps we could persist this connection, but then we'd have to
|
||||
// write more code for reopening it if it got closed, which doesn't
|
||||
@@ -75,7 +55,7 @@ impl Api<'_> {
|
||||
|
||||
tokio::spawn(connection);
|
||||
let query = "select rolpassword from pg_catalog.pg_authid where rolname = $1";
|
||||
let rows = client.query(query, &[&self.creds.user]).await?;
|
||||
let rows = client.query(query, &[&creds.user]).await?;
|
||||
|
||||
// We can get at most one row, because `rolname` is unique.
|
||||
let row = match rows.get(0) {
|
||||
@@ -84,7 +64,7 @@ impl Api<'_> {
|
||||
// However, this is still a *valid* outcome which is very similar
|
||||
// to getting `404 Not found` from the Neon console.
|
||||
None => {
|
||||
warn!("user '{}' does not exist", self.creds.user);
|
||||
warn!("user '{}' does not exist", creds.user);
|
||||
return Ok(None);
|
||||
}
|
||||
};
|
||||
@@ -98,23 +78,50 @@ impl Api<'_> {
|
||||
Ok(secret.or_else(|| parse_md5(entry).map(AuthInfo::Md5)))
|
||||
}
|
||||
.map_err(crate::error::log_error)
|
||||
.instrument(info_span!("get_auth_info", mock = self.endpoint.as_str()))
|
||||
.instrument(info_span!("postgres", url = self.endpoint.as_str()))
|
||||
.await
|
||||
}
|
||||
|
||||
/// We don't need to wake anything locally, so we just return the connection info.
|
||||
pub async fn wake_compute(&self) -> Result<NodeInfo, WakeComputeError> {
|
||||
async fn do_wake_compute(
|
||||
&self,
|
||||
creds: &ClientCredentials<'_>,
|
||||
) -> Result<NodeInfo, WakeComputeError> {
|
||||
let mut config = compute::ConnCfg::new();
|
||||
config
|
||||
.host(self.endpoint.host_str().unwrap_or("localhost"))
|
||||
.port(self.endpoint.port().unwrap_or(5432))
|
||||
.dbname(self.creds.dbname)
|
||||
.user(self.creds.user);
|
||||
.dbname(creds.dbname)
|
||||
.user(creds.user);
|
||||
|
||||
Ok(NodeInfo {
|
||||
let node = NodeInfo {
|
||||
config,
|
||||
aux: Default::default(),
|
||||
})
|
||||
};
|
||||
|
||||
Ok(node)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl super::Api for Api {
|
||||
#[tracing::instrument(skip_all)]
|
||||
async fn get_auth_info(
|
||||
&self,
|
||||
_extra: &ConsoleReqExtra<'_>,
|
||||
creds: &ClientCredentials<'_>,
|
||||
) -> Result<Option<AuthInfo>, GetAuthInfoError> {
|
||||
self.do_get_auth_info(creds).await
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all)]
|
||||
async fn wake_compute(
|
||||
&self,
|
||||
_extra: &ConsoleReqExtra<'_>,
|
||||
creds: &ClientCredentials<'_>,
|
||||
) -> Result<CachedNodeInfo, WakeComputeError> {
|
||||
self.do_wake_compute(creds)
|
||||
.map_ok(CachedNodeInfo::new_uncached)
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
196
proxy/src/console/provider/neon.rs
Normal file
196
proxy/src/console/provider/neon.rs
Normal file
@@ -0,0 +1,196 @@
|
||||
//! Production console backend.
|
||||
|
||||
use super::{
|
||||
super::messages::{ConsoleError, GetRoleSecret, WakeCompute},
|
||||
errors::{ApiError, GetAuthInfoError, WakeComputeError},
|
||||
ApiCaches, AuthInfo, CachedNodeInfo, ConsoleReqExtra, NodeInfo,
|
||||
};
|
||||
use crate::{auth::ClientCredentials, compute, http, scram};
|
||||
use async_trait::async_trait;
|
||||
use futures::TryFutureExt;
|
||||
use reqwest::StatusCode as HttpStatusCode;
|
||||
use tracing::{error, info, info_span, warn, Instrument};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Api {
|
||||
endpoint: http::Endpoint,
|
||||
caches: &'static ApiCaches,
|
||||
}
|
||||
|
||||
impl Api {
|
||||
/// Construct an API object containing the auth parameters.
|
||||
pub fn new(endpoint: http::Endpoint, caches: &'static ApiCaches) -> Self {
|
||||
Self { endpoint, caches }
|
||||
}
|
||||
|
||||
pub fn url(&self) -> &str {
|
||||
self.endpoint.url().as_str()
|
||||
}
|
||||
|
||||
async fn do_get_auth_info(
|
||||
&self,
|
||||
extra: &ConsoleReqExtra<'_>,
|
||||
creds: &ClientCredentials<'_>,
|
||||
) -> Result<Option<AuthInfo>, GetAuthInfoError> {
|
||||
let request_id = uuid::Uuid::new_v4().to_string();
|
||||
async {
|
||||
let request = self
|
||||
.endpoint
|
||||
.get("proxy_get_role_secret")
|
||||
.header("X-Request-ID", &request_id)
|
||||
.query(&[("session_id", extra.session_id)])
|
||||
.query(&[
|
||||
("application_name", extra.application_name),
|
||||
("project", Some(creds.project().expect("impossible"))),
|
||||
("role", Some(creds.user)),
|
||||
])
|
||||
.build()?;
|
||||
|
||||
info!(url = request.url().as_str(), "sending http request");
|
||||
let response = self.endpoint.execute(request).await?;
|
||||
let body = match parse_body::<GetRoleSecret>(response).await {
|
||||
Ok(body) => body,
|
||||
// Error 404 is special: it's ok not to have a secret.
|
||||
Err(e) => match e.http_status_code() {
|
||||
Some(HttpStatusCode::NOT_FOUND) => return Ok(None),
|
||||
_otherwise => return Err(e.into()),
|
||||
},
|
||||
};
|
||||
|
||||
let secret = scram::ServerSecret::parse(&body.role_secret)
|
||||
.map(AuthInfo::Scram)
|
||||
.ok_or(GetAuthInfoError::BadSecret)?;
|
||||
|
||||
Ok(Some(secret))
|
||||
}
|
||||
.map_err(crate::error::log_error)
|
||||
.instrument(info_span!("http", id = request_id))
|
||||
.await
|
||||
}
|
||||
|
||||
async fn do_wake_compute(
|
||||
&self,
|
||||
extra: &ConsoleReqExtra<'_>,
|
||||
creds: &ClientCredentials<'_>,
|
||||
) -> Result<NodeInfo, WakeComputeError> {
|
||||
let project = creds.project().expect("impossible");
|
||||
let request_id = uuid::Uuid::new_v4().to_string();
|
||||
async {
|
||||
let request = self
|
||||
.endpoint
|
||||
.get("proxy_wake_compute")
|
||||
.header("X-Request-ID", &request_id)
|
||||
.query(&[("session_id", extra.session_id)])
|
||||
.query(&[
|
||||
("application_name", extra.application_name),
|
||||
("project", Some(project)),
|
||||
])
|
||||
.build()?;
|
||||
|
||||
info!(url = request.url().as_str(), "sending http request");
|
||||
let response = self.endpoint.execute(request).await?;
|
||||
let body = parse_body::<WakeCompute>(response).await?;
|
||||
|
||||
// Unfortunately, ownership won't let us use `Option::ok_or` here.
|
||||
let (host, port) = match parse_host_port(&body.address) {
|
||||
None => return Err(WakeComputeError::BadComputeAddress(body.address)),
|
||||
Some(x) => x,
|
||||
};
|
||||
|
||||
let mut config = compute::ConnCfg::new();
|
||||
config
|
||||
.host(host)
|
||||
.port(port)
|
||||
.dbname(creds.dbname)
|
||||
.user(creds.user);
|
||||
|
||||
let node = NodeInfo {
|
||||
config,
|
||||
aux: body.aux.into(),
|
||||
};
|
||||
|
||||
Ok(node)
|
||||
}
|
||||
.map_err(crate::error::log_error)
|
||||
.instrument(info_span!("http", id = request_id))
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl super::Api for Api {
|
||||
#[tracing::instrument(skip_all)]
|
||||
async fn get_auth_info(
|
||||
&self,
|
||||
extra: &ConsoleReqExtra<'_>,
|
||||
creds: &ClientCredentials<'_>,
|
||||
) -> Result<Option<AuthInfo>, GetAuthInfoError> {
|
||||
self.do_get_auth_info(extra, creds).await
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all)]
|
||||
async fn wake_compute(
|
||||
&self,
|
||||
extra: &ConsoleReqExtra<'_>,
|
||||
creds: &ClientCredentials<'_>,
|
||||
) -> Result<CachedNodeInfo, WakeComputeError> {
|
||||
let key = creds.project().expect("impossible");
|
||||
|
||||
// Every time we do a wakeup http request, the compute node will stay up
|
||||
// for some time (highly depends on the console's scale-to-zero policy);
|
||||
// The connection info remains the same during that period of time,
|
||||
// which means that we might cache it to reduce the load and latency.
|
||||
if let Some(cached) = self.caches.node_info.get(key) {
|
||||
info!(key = key, "found cached compute node info");
|
||||
return Ok(cached);
|
||||
}
|
||||
|
||||
let node = self.do_wake_compute(extra, creds).await?;
|
||||
let (_, cached) = self.caches.node_info.insert(key.into(), node);
|
||||
info!(key = key, "created a cache entry for compute node info");
|
||||
|
||||
Ok(cached)
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse http response body, taking status code into account.
|
||||
async fn parse_body<T: for<'a> serde::Deserialize<'a>>(
|
||||
response: reqwest::Response,
|
||||
) -> Result<T, ApiError> {
|
||||
let status = response.status();
|
||||
if status.is_success() {
|
||||
// We shouldn't log raw body because it may contain secrets.
|
||||
info!("request succeeded, processing the body");
|
||||
return Ok(response.json().await?);
|
||||
}
|
||||
|
||||
// Don't throw an error here because it's not as important
|
||||
// as the fact that the request itself has failed.
|
||||
let body = response.json().await.unwrap_or_else(|e| {
|
||||
warn!("failed to parse error body: {e}");
|
||||
ConsoleError {
|
||||
error: "reason unclear (malformed error message)".into(),
|
||||
}
|
||||
});
|
||||
|
||||
let text = body.error;
|
||||
error!("console responded with an error ({status}): {text}");
|
||||
Err(ApiError::Console { status, text })
|
||||
}
|
||||
|
||||
fn parse_host_port(input: &str) -> Option<(&str, u16)> {
|
||||
let (host, port) = input.split_once(':')?;
|
||||
Some((host, port.parse().ok()?))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_parse_host_port() {
|
||||
let (host, port) = parse_host_port("127.0.0.1:5432").expect("failed to parse");
|
||||
assert_eq!(host, "127.0.0.1");
|
||||
assert_eq!(port, 5432);
|
||||
}
|
||||
}
|
||||
@@ -9,8 +9,7 @@ async fn status_handler(_: Request<Body>) -> Result<Response<Body>, ApiError> {
|
||||
}
|
||||
|
||||
fn make_router() -> RouterBuilder<hyper::Body, ApiError> {
|
||||
let router = endpoint::make_router();
|
||||
router.get("/v1/status", status_handler)
|
||||
endpoint::make_router().get("/v1/status", status_handler)
|
||||
}
|
||||
|
||||
pub async fn task_main(http_listener: TcpListener) -> anyhow::Result<()> {
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use bytes::{Buf, Bytes};
|
||||
use futures::{Sink, Stream, StreamExt};
|
||||
use hyper::server::accept::{self};
|
||||
use hyper::server::accept;
|
||||
use hyper::server::conn::AddrIncoming;
|
||||
use hyper::upgrade::Upgraded;
|
||||
use hyper::{Body, Request, Response, StatusCode};
|
||||
@@ -161,7 +161,7 @@ impl AsyncBufRead for WebSocketRW {
|
||||
|
||||
async fn serve_websocket(
|
||||
websocket: HyperWebsocket,
|
||||
config: &ProxyConfig,
|
||||
config: &'static ProxyConfig,
|
||||
cancel_map: &CancelMap,
|
||||
session_id: uuid::Uuid,
|
||||
hostname: Option<String>,
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
//! in somewhat transparent manner (again via communication with control plane API).
|
||||
|
||||
mod auth;
|
||||
mod cache;
|
||||
mod cancellation;
|
||||
mod compute;
|
||||
mod config;
|
||||
@@ -12,7 +13,6 @@ mod console;
|
||||
mod error;
|
||||
mod http;
|
||||
mod metrics;
|
||||
mod mgmt;
|
||||
mod parse;
|
||||
mod proxy;
|
||||
mod sasl;
|
||||
@@ -21,7 +21,6 @@ mod stream;
|
||||
mod url;
|
||||
mod waiters;
|
||||
|
||||
use ::metrics::set_build_info_metric;
|
||||
use anyhow::{bail, Context};
|
||||
use clap::{self, Arg};
|
||||
use config::ProxyConfig;
|
||||
@@ -29,8 +28,7 @@ use futures::FutureExt;
|
||||
use std::{borrow::Cow, future::Future, net::SocketAddr};
|
||||
use tokio::{net::TcpListener, task::JoinError};
|
||||
use tracing::{info, info_span, Instrument};
|
||||
use utils::project_git_version;
|
||||
use utils::sentry_init::init_sentry;
|
||||
use utils::{project_git_version, sentry_init::init_sentry};
|
||||
|
||||
project_git_version!(GIT_VERSION);
|
||||
|
||||
@@ -51,124 +49,133 @@ async fn main() -> anyhow::Result<()> {
|
||||
// initialize sentry if SENTRY_DSN is provided
|
||||
let _sentry_guard = init_sentry(Some(GIT_VERSION.into()), &[]);
|
||||
|
||||
let arg_matches = cli().get_matches();
|
||||
|
||||
let tls_config = match (
|
||||
arg_matches.get_one::<String>("tls-key"),
|
||||
arg_matches.get_one::<String>("tls-cert"),
|
||||
) {
|
||||
(Some(key_path), Some(cert_path)) => Some(config::configure_tls(key_path, cert_path)?),
|
||||
(None, None) => None,
|
||||
_ => bail!("either both or neither tls-key and tls-cert must be specified"),
|
||||
};
|
||||
|
||||
let proxy_address: SocketAddr = arg_matches.get_one::<String>("proxy").unwrap().parse()?;
|
||||
let mgmt_address: SocketAddr = arg_matches.get_one::<String>("mgmt").unwrap().parse()?;
|
||||
let http_address: SocketAddr = arg_matches.get_one::<String>("http").unwrap().parse()?;
|
||||
|
||||
let metric_collection_config = match
|
||||
(
|
||||
arg_matches.get_one::<String>("metric-collection-endpoint"),
|
||||
arg_matches.get_one::<String>("metric-collection-interval"),
|
||||
) {
|
||||
|
||||
(Some(endpoint), Some(interval)) => {
|
||||
Some(config::MetricCollectionConfig {
|
||||
endpoint: endpoint.parse()?,
|
||||
interval: humantime::parse_duration(interval)?,
|
||||
})
|
||||
}
|
||||
(None, None) => None,
|
||||
_ => bail!("either both or neither metric-collection-endpoint and metric-collection-interval must be specified"),
|
||||
};
|
||||
|
||||
let auth_backend = match arg_matches
|
||||
.get_one::<String>("auth-backend")
|
||||
.unwrap()
|
||||
.as_str()
|
||||
{
|
||||
"console" => {
|
||||
let url = arg_matches
|
||||
.get_one::<String>("auth-endpoint")
|
||||
.unwrap()
|
||||
.parse()?;
|
||||
let endpoint = http::Endpoint::new(url, reqwest::Client::new());
|
||||
auth::BackendType::Console(Cow::Owned(endpoint), ())
|
||||
}
|
||||
"postgres" => {
|
||||
let url = arg_matches
|
||||
.get_one::<String>("auth-endpoint")
|
||||
.unwrap()
|
||||
.parse()?;
|
||||
auth::BackendType::Postgres(Cow::Owned(url), ())
|
||||
}
|
||||
"link" => {
|
||||
let url = arg_matches.get_one::<String>("uri").unwrap().parse()?;
|
||||
auth::BackendType::Link(Cow::Owned(url))
|
||||
}
|
||||
other => bail!("unsupported auth backend: {other}"),
|
||||
};
|
||||
|
||||
let config: &ProxyConfig = Box::leak(Box::new(ProxyConfig {
|
||||
tls_config,
|
||||
auth_backend,
|
||||
metric_collection_config,
|
||||
}));
|
||||
|
||||
info!("Version: {GIT_VERSION}");
|
||||
::metrics::set_build_info_metric(GIT_VERSION);
|
||||
|
||||
let args = cli().get_matches();
|
||||
let config = build_config(&args)?;
|
||||
|
||||
info!("Authentication backend: {}", config.auth_backend);
|
||||
|
||||
// Check that we can bind to address before further initialization
|
||||
let http_address: SocketAddr = args.get_one::<String>("http").unwrap().parse()?;
|
||||
info!("Starting http on {http_address}");
|
||||
let http_listener = TcpListener::bind(http_address).await?.into_std()?;
|
||||
|
||||
let mgmt_address: SocketAddr = args.get_one::<String>("mgmt").unwrap().parse()?;
|
||||
info!("Starting mgmt on {mgmt_address}");
|
||||
let mgmt_listener = TcpListener::bind(mgmt_address).await?.into_std()?;
|
||||
|
||||
let proxy_address: SocketAddr = args.get_one::<String>("proxy").unwrap().parse()?;
|
||||
info!("Starting proxy on {proxy_address}");
|
||||
let proxy_listener = TcpListener::bind(proxy_address).await?;
|
||||
|
||||
let mut tasks = vec![
|
||||
tokio::spawn(http::server::task_main(http_listener)),
|
||||
tokio::spawn(proxy::task_main(config, proxy_listener)),
|
||||
tokio::task::spawn_blocking(move || mgmt::thread_main(mgmt_listener)),
|
||||
tokio::task::spawn_blocking(move || console::mgmt::thread_main(mgmt_listener)),
|
||||
];
|
||||
|
||||
if let Some(wss_address) = arg_matches.get_one::<String>("wss") {
|
||||
if let Some(wss_address) = args.get_one::<String>("wss") {
|
||||
let wss_address: SocketAddr = wss_address.parse()?;
|
||||
info!("Starting wss on {}", wss_address);
|
||||
info!("Starting wss on {wss_address}");
|
||||
let wss_listener = TcpListener::bind(wss_address).await?;
|
||||
|
||||
tasks.push(tokio::spawn(http::websocket::task_main(
|
||||
wss_listener,
|
||||
config,
|
||||
)));
|
||||
}
|
||||
|
||||
if let Some(metric_collection_config) = &config.metric_collection_config {
|
||||
// TODO: refactor.
|
||||
if let Some(metric_collection) = &config.metric_collection {
|
||||
let hostname = hostname::get()?
|
||||
.into_string()
|
||||
.map_err(|e| anyhow::anyhow!("failed to get hostname {e:?}"))?;
|
||||
|
||||
tasks.push(tokio::spawn(
|
||||
metrics::collect_metrics(
|
||||
&metric_collection_config.endpoint,
|
||||
metric_collection_config.interval,
|
||||
&metric_collection.endpoint,
|
||||
metric_collection.interval,
|
||||
hostname,
|
||||
)
|
||||
.instrument(info_span!("collect_metrics")),
|
||||
));
|
||||
}
|
||||
|
||||
let tasks = tasks.into_iter().map(flatten_err);
|
||||
|
||||
set_build_info_metric(GIT_VERSION);
|
||||
// This will block until all tasks have completed.
|
||||
// Furthermore, the first one to fail will cancel the rest.
|
||||
let tasks = tasks.into_iter().map(flatten_err);
|
||||
let _: Vec<()> = futures::future::try_join_all(tasks).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// ProxyConfig is created at proxy startup, and lives forever.
|
||||
fn build_config(args: &clap::ArgMatches) -> anyhow::Result<&'static ProxyConfig> {
|
||||
let tls_config = match (
|
||||
args.get_one::<String>("tls-key"),
|
||||
args.get_one::<String>("tls-cert"),
|
||||
) {
|
||||
(Some(key_path), Some(cert_path)) => Some(config::configure_tls(key_path, cert_path)?),
|
||||
(None, None) => None,
|
||||
_ => bail!("either both or neither tls-key and tls-cert must be specified"),
|
||||
};
|
||||
|
||||
let metric_collection = match (
|
||||
args.get_one::<String>("metric-collection-endpoint"),
|
||||
args.get_one::<String>("metric-collection-interval"),
|
||||
) {
|
||||
(Some(endpoint), Some(interval)) => Some(config::MetricCollectionConfig {
|
||||
endpoint: endpoint.parse()?,
|
||||
interval: humantime::parse_duration(interval)?,
|
||||
}),
|
||||
(None, None) => None,
|
||||
_ => bail!(
|
||||
"either both or neither metric-collection-endpoint \
|
||||
and metric-collection-interval must be specified"
|
||||
),
|
||||
};
|
||||
|
||||
let auth_backend = match args.get_one::<String>("auth-backend").unwrap().as_str() {
|
||||
"console" => {
|
||||
let config::CacheOptions { size, ttl } = args
|
||||
.get_one::<String>("wake-compute-cache")
|
||||
.unwrap()
|
||||
.parse()?;
|
||||
|
||||
info!("Using NodeInfoCache (wake_compute) with size={size} ttl={ttl:?}");
|
||||
let caches = Box::leak(Box::new(console::caches::ApiCaches {
|
||||
node_info: console::caches::NodeInfoCache::new("node_info_cache", size, ttl),
|
||||
}));
|
||||
|
||||
let url = args.get_one::<String>("auth-endpoint").unwrap().parse()?;
|
||||
let endpoint = http::Endpoint::new(url, reqwest::Client::new());
|
||||
|
||||
let api = console::provider::neon::Api::new(endpoint, caches);
|
||||
auth::BackendType::Console(Cow::Owned(api), ())
|
||||
}
|
||||
"postgres" => {
|
||||
let url = args.get_one::<String>("auth-endpoint").unwrap().parse()?;
|
||||
let api = console::provider::mock::Api::new(url);
|
||||
auth::BackendType::Postgres(Cow::Owned(api), ())
|
||||
}
|
||||
"link" => {
|
||||
let url = args.get_one::<String>("uri").unwrap().parse()?;
|
||||
auth::BackendType::Link(Cow::Owned(url))
|
||||
}
|
||||
other => bail!("unsupported auth backend: {other}"),
|
||||
};
|
||||
|
||||
let config = Box::leak(Box::new(ProxyConfig {
|
||||
tls_config,
|
||||
auth_backend,
|
||||
metric_collection,
|
||||
}));
|
||||
|
||||
Ok(config)
|
||||
}
|
||||
|
||||
fn cli() -> clap::Command {
|
||||
clap::Command::new("Neon proxy/router")
|
||||
.disable_help_flag(true)
|
||||
@@ -235,16 +242,27 @@ fn cli() -> clap::Command {
|
||||
.arg(
|
||||
Arg::new("metric-collection-endpoint")
|
||||
.long("metric-collection-endpoint")
|
||||
.help("metric collection HTTP endpoint"),
|
||||
.help("http endpoint to receive periodic metric updates"),
|
||||
)
|
||||
.arg(
|
||||
Arg::new("metric-collection-interval")
|
||||
.long("metric-collection-interval")
|
||||
.help("metric collection interval"),
|
||||
.help("how often metrics should be sent to a collection endpoint"),
|
||||
)
|
||||
.arg(
|
||||
Arg::new("wake-compute-cache")
|
||||
.long("wake-compute-cache")
|
||||
.help("cache for `wake_compute` api method (use `size=0` to disable)")
|
||||
.default_value(config::CacheOptions::DEFAULT_OPTIONS_NODE_INFO),
|
||||
)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn verify_cli() {
|
||||
cli().debug_assert();
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn verify_cli() {
|
||||
cli().debug_assert();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,9 +2,12 @@
|
||||
mod tests;
|
||||
|
||||
use crate::{
|
||||
auth,
|
||||
auth::{self, backend::AuthSuccess},
|
||||
cancellation::{self, CancelMap},
|
||||
compute::{self, PostgresConnection},
|
||||
config::{ProxyConfig, TlsConfig},
|
||||
console::{self, messages::MetricsAuxInfo},
|
||||
error::io_error,
|
||||
stream::{MeasuredStream, PqStream, Stream},
|
||||
};
|
||||
use anyhow::{bail, Context};
|
||||
@@ -14,7 +17,10 @@ use once_cell::sync::Lazy;
|
||||
use pq_proto::{BeMessage as Be, FeStartupPacket, StartupMessageParams};
|
||||
use std::sync::Arc;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tracing::{error, info, info_span, Instrument};
|
||||
use tracing::{error, info, info_span, warn, Instrument};
|
||||
|
||||
/// Number of times we should retry the `/proxy_wake_compute` http request.
|
||||
const NUM_RETRIES_WAKE_COMPUTE: usize = 1;
|
||||
|
||||
const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmode=require`)";
|
||||
const ERR_PROTO_VIOLATION: &str = "protocol violation";
|
||||
@@ -35,6 +41,15 @@ static NUM_CONNECTIONS_CLOSED_COUNTER: Lazy<IntCounter> = Lazy::new(|| {
|
||||
.unwrap()
|
||||
});
|
||||
|
||||
static NUM_CONNECTION_FAILURES: Lazy<IntCounterVec> = Lazy::new(|| {
|
||||
register_int_counter_vec!(
|
||||
"proxy_connection_failures_total",
|
||||
"Number of connection failures (per kind).",
|
||||
&["kind"],
|
||||
)
|
||||
.unwrap()
|
||||
});
|
||||
|
||||
static NUM_BYTES_PROXIED_COUNTER: Lazy<IntCounterVec> = Lazy::new(|| {
|
||||
register_int_counter_vec!(
|
||||
"proxy_io_bytes_per_client",
|
||||
@@ -82,11 +97,12 @@ pub async fn task_main(
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(tech debt): unite this with its twin below.
|
||||
pub async fn handle_ws_client(
|
||||
config: &ProxyConfig,
|
||||
config: &'static ProxyConfig,
|
||||
cancel_map: &CancelMap,
|
||||
session_id: uuid::Uuid,
|
||||
stream: impl AsyncRead + AsyncWrite + Unpin + Send,
|
||||
stream: impl AsyncRead + AsyncWrite + Unpin,
|
||||
hostname: Option<String>,
|
||||
) -> anyhow::Result<()> {
|
||||
// The `closed` counter will increase when this future is destroyed.
|
||||
@@ -99,7 +115,7 @@ pub async fn handle_ws_client(
|
||||
let hostname = hostname.as_deref();
|
||||
|
||||
// TLS is None here, because the connection is already encrypted.
|
||||
let do_handshake = handshake(stream, None, cancel_map).instrument(info_span!("handshake"));
|
||||
let do_handshake = handshake(stream, None, cancel_map);
|
||||
let (mut stream, params) = match do_handshake.await? {
|
||||
Some(x) => x,
|
||||
None => return Ok(()), // it's a cancellation request
|
||||
@@ -124,10 +140,10 @@ pub async fn handle_ws_client(
|
||||
}
|
||||
|
||||
async fn handle_client(
|
||||
config: &ProxyConfig,
|
||||
config: &'static ProxyConfig,
|
||||
cancel_map: &CancelMap,
|
||||
session_id: uuid::Uuid,
|
||||
stream: impl AsyncRead + AsyncWrite + Unpin + Send,
|
||||
stream: impl AsyncRead + AsyncWrite + Unpin,
|
||||
) -> anyhow::Result<()> {
|
||||
// The `closed` counter will increase when this future is destroyed.
|
||||
NUM_CONNECTIONS_ACCEPTED_COUNTER.inc();
|
||||
@@ -136,7 +152,7 @@ async fn handle_client(
|
||||
}
|
||||
|
||||
let tls = config.tls_config.as_ref();
|
||||
let do_handshake = handshake(stream, tls, cancel_map).instrument(info_span!("handshake"));
|
||||
let do_handshake = handshake(stream, tls, cancel_map);
|
||||
let (mut stream, params) = match do_handshake.await? {
|
||||
Some(x) => x,
|
||||
None => return Ok(()), // it's a cancellation request
|
||||
@@ -165,6 +181,7 @@ async fn handle_client(
|
||||
/// For better testing experience, `stream` can be any object satisfying the traits.
|
||||
/// It's easier to work with owned `stream` here as we need to upgrade it to TLS;
|
||||
/// we also take an extra care of propagating only the select handshake errors to client.
|
||||
#[tracing::instrument(skip_all)]
|
||||
async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
stream: S,
|
||||
mut tls: Option<&TlsConfig>,
|
||||
@@ -226,6 +243,133 @@ async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
}
|
||||
}
|
||||
|
||||
/// Try to connect to the compute node once.
|
||||
#[tracing::instrument(name = "connect_once", skip_all)]
|
||||
async fn connect_to_compute_once(
|
||||
node_info: &console::CachedNodeInfo,
|
||||
) -> Result<PostgresConnection, compute::ConnectionError> {
|
||||
// If we couldn't connect, a cached connection info might be to blame
|
||||
// (e.g. the compute node's address might've changed at the wrong time).
|
||||
// Invalidate the cache entry (if any) to prevent subsequent errors.
|
||||
let invalidate_cache = |_: &compute::ConnectionError| {
|
||||
let is_cached = node_info.cached();
|
||||
if is_cached {
|
||||
warn!("invalidating stalled compute node info cache entry");
|
||||
node_info.invalidate();
|
||||
}
|
||||
|
||||
let label = match is_cached {
|
||||
true => "compute_cached",
|
||||
false => "compute_uncached",
|
||||
};
|
||||
NUM_CONNECTION_FAILURES.with_label_values(&[label]).inc();
|
||||
};
|
||||
|
||||
node_info
|
||||
.config
|
||||
.connect()
|
||||
.inspect_err(invalidate_cache)
|
||||
.await
|
||||
}
|
||||
|
||||
/// Try to connect to the compute node, retrying if necessary.
|
||||
/// This function might update `node_info`, so we take it by `&mut`.
|
||||
#[tracing::instrument(skip_all)]
|
||||
async fn connect_to_compute(
|
||||
node_info: &mut console::CachedNodeInfo,
|
||||
params: &StartupMessageParams,
|
||||
extra: &console::ConsoleReqExtra<'_>,
|
||||
creds: &auth::BackendType<'_, auth::ClientCredentials<'_>>,
|
||||
) -> Result<PostgresConnection, compute::ConnectionError> {
|
||||
let mut num_retries: usize = NUM_RETRIES_WAKE_COMPUTE;
|
||||
loop {
|
||||
// Apply startup params to the (possibly, cached) compute node info.
|
||||
node_info.config.set_startup_params(params);
|
||||
match connect_to_compute_once(node_info).await {
|
||||
Err(e) if num_retries > 0 => {
|
||||
info!("compute node's state has changed; requesting a wake-up");
|
||||
match creds.wake_compute(extra).map_err(io_error).await? {
|
||||
// Update `node_info` and try one more time.
|
||||
Some(mut new) => {
|
||||
new.config.reuse_password(&node_info.config);
|
||||
*node_info = new;
|
||||
}
|
||||
// Link auth doesn't work that way, so we just exit.
|
||||
None => return Err(e),
|
||||
}
|
||||
}
|
||||
other => return other,
|
||||
}
|
||||
|
||||
num_retries -= 1;
|
||||
info!("retrying after wake-up ({num_retries} attempts left)");
|
||||
}
|
||||
}
|
||||
|
||||
/// Finish client connection initialization: confirm auth success, send params, etc.
|
||||
#[tracing::instrument(skip_all)]
|
||||
async fn prepare_client_connection(
|
||||
node: &compute::PostgresConnection,
|
||||
reported_auth_ok: bool,
|
||||
session: cancellation::Session<'_>,
|
||||
stream: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
|
||||
) -> anyhow::Result<()> {
|
||||
// Register compute's query cancellation token and produce a new, unique one.
|
||||
// The new token (cancel_key_data) will be sent to the client.
|
||||
let cancel_key_data = session.enable_query_cancellation(node.cancel_closure.clone());
|
||||
|
||||
// Report authentication success if we haven't done this already.
|
||||
// Note that we do this only (for the most part) after we've connected
|
||||
// to a compute (see above) which performs its own authentication.
|
||||
if !reported_auth_ok {
|
||||
stream.write_message_noflush(&Be::AuthenticationOk)?;
|
||||
}
|
||||
|
||||
// Forward all postgres connection params to the client.
|
||||
// Right now the implementation is very hacky and inefficent (ideally,
|
||||
// we don't need an intermediate hashmap), but at least it should be correct.
|
||||
for (name, value) in &node.params {
|
||||
// TODO: Theoretically, this could result in a big pile of params...
|
||||
stream.write_message_noflush(&Be::ParameterStatus {
|
||||
name: name.as_bytes(),
|
||||
value: value.as_bytes(),
|
||||
})?;
|
||||
}
|
||||
|
||||
stream
|
||||
.write_message_noflush(&Be::BackendKeyData(cancel_key_data))?
|
||||
.write_message(&Be::ReadyForQuery)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Forward bytes in both directions (client <-> compute).
|
||||
#[tracing::instrument(skip_all)]
|
||||
async fn proxy_pass(
|
||||
client: impl AsyncRead + AsyncWrite + Unpin,
|
||||
compute: impl AsyncRead + AsyncWrite + Unpin,
|
||||
aux: &MetricsAuxInfo,
|
||||
) -> anyhow::Result<()> {
|
||||
let m_sent = NUM_BYTES_PROXIED_COUNTER.with_label_values(&aux.traffic_labels("tx"));
|
||||
let mut client = MeasuredStream::new(client, |cnt| {
|
||||
// Number of bytes we sent to the client (outbound).
|
||||
m_sent.inc_by(cnt as u64);
|
||||
});
|
||||
|
||||
let m_recv = NUM_BYTES_PROXIED_COUNTER.with_label_values(&aux.traffic_labels("rx"));
|
||||
let mut compute = MeasuredStream::new(compute, |cnt| {
|
||||
// Number of bytes the client sent to the compute node (inbound).
|
||||
m_recv.inc_by(cnt as u64);
|
||||
});
|
||||
|
||||
// Starting from here we only proxy the client's traffic.
|
||||
info!("performing the proxy pass...");
|
||||
let _ = tokio::io::copy_bidirectional(&mut client, &mut compute).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Thin connection context.
|
||||
struct Client<'a, S> {
|
||||
/// The underlying libpq protocol stream.
|
||||
@@ -255,17 +399,17 @@ impl<'a, S> Client<'a, S> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: AsyncRead + AsyncWrite + Unpin + Send> Client<'_, S> {
|
||||
impl<S: AsyncRead + AsyncWrite + Unpin> Client<'_, S> {
|
||||
/// Let the client authenticate and connect to the designated compute node.
|
||||
async fn connect_to_db(self, session: cancellation::Session<'_>) -> anyhow::Result<()> {
|
||||
let Self {
|
||||
mut stream,
|
||||
creds,
|
||||
mut creds,
|
||||
params,
|
||||
session_id,
|
||||
} = self;
|
||||
|
||||
let extra = auth::ConsoleReqExtra {
|
||||
let extra = console::ConsoleReqExtra {
|
||||
session_id, // aka this connection's id
|
||||
application_name: params.get("application_name"),
|
||||
};
|
||||
@@ -278,54 +422,16 @@ impl<S: AsyncRead + AsyncWrite + Unpin + Send> Client<'_, S> {
|
||||
.instrument(info_span!("auth"))
|
||||
.await?;
|
||||
|
||||
let node = auth_result.value;
|
||||
let (db, cancel_closure) = node
|
||||
.config
|
||||
.connect(params)
|
||||
let AuthSuccess {
|
||||
reported_auth_ok,
|
||||
value: mut node_info,
|
||||
} = auth_result;
|
||||
|
||||
let node = connect_to_compute(&mut node_info, params, &extra, &creds)
|
||||
.or_else(|e| stream.throw_error(e))
|
||||
.await?;
|
||||
|
||||
let cancel_key_data = session.enable_query_cancellation(cancel_closure);
|
||||
|
||||
// Report authentication success if we haven't done this already.
|
||||
// Note that we do this only (for the most part) after we've connected
|
||||
// to a compute (see above) which performs its own authentication.
|
||||
if !auth_result.reported_auth_ok {
|
||||
stream.write_message_noflush(&Be::AuthenticationOk)?;
|
||||
}
|
||||
|
||||
// Forward all postgres connection params to the client.
|
||||
// Right now the implementation is very hacky and inefficent (ideally,
|
||||
// we don't need an intermediate hashmap), but at least it should be correct.
|
||||
for (name, value) in &db.params {
|
||||
// TODO: Theoretically, this could result in a big pile of params...
|
||||
stream.write_message_noflush(&Be::ParameterStatus {
|
||||
name: name.as_bytes(),
|
||||
value: value.as_bytes(),
|
||||
})?;
|
||||
}
|
||||
|
||||
stream
|
||||
.write_message_noflush(&Be::BackendKeyData(cancel_key_data))?
|
||||
.write_message(&Be::ReadyForQuery)
|
||||
.await?;
|
||||
|
||||
let m_sent = NUM_BYTES_PROXIED_COUNTER.with_label_values(&node.aux.traffic_labels("tx"));
|
||||
let mut client = MeasuredStream::new(stream.into_inner(), |cnt| {
|
||||
// Number of bytes we sent to the client (outbound).
|
||||
m_sent.inc_by(cnt as u64);
|
||||
});
|
||||
|
||||
let m_recv = NUM_BYTES_PROXIED_COUNTER.with_label_values(&node.aux.traffic_labels("rx"));
|
||||
let mut db = MeasuredStream::new(db.stream, |cnt| {
|
||||
// Number of bytes the client sent to the compute node (inbound).
|
||||
m_recv.inc_by(cnt as u64);
|
||||
});
|
||||
|
||||
// Starting from here we only proxy the client's traffic.
|
||||
info!("performing the proxy pass...");
|
||||
let _ = tokio::io::copy_bidirectional(&mut client, &mut db).await?;
|
||||
|
||||
Ok(())
|
||||
prepare_client_connection(&node, reported_auth_ok, session, &mut stream).await?;
|
||||
proxy_pass(stream.into_inner(), node.stream, &node_info.aux).await
|
||||
}
|
||||
}
|
||||
|
||||
@@ -25,6 +25,7 @@ futures-channel = { version = "0.3", features = ["sink"] }
|
||||
futures-executor = { version = "0.3" }
|
||||
futures-task = { version = "0.3", default-features = false, features = ["std"] }
|
||||
futures-util = { version = "0.3", features = ["channel", "io", "sink"] }
|
||||
hashbrown = { version = "0.12", features = ["raw"] }
|
||||
indexmap = { version = "1", default-features = false, features = ["std"] }
|
||||
itertools = { version = "0.10" }
|
||||
libc = { version = "0.2", features = ["extra_traits"] }
|
||||
@@ -58,6 +59,7 @@ url = { version = "2", features = ["serde"] }
|
||||
anyhow = { version = "1", features = ["backtrace"] }
|
||||
bytes = { version = "1", features = ["serde"] }
|
||||
either = { version = "1" }
|
||||
hashbrown = { version = "0.12", features = ["raw"] }
|
||||
indexmap = { version = "1", default-features = false, features = ["std"] }
|
||||
itertools = { version = "0.10" }
|
||||
libc = { version = "0.2", features = ["extra_traits"] }
|
||||
|
||||
Reference in New Issue
Block a user