diff --git a/Cargo.lock b/Cargo.lock index 794ec25bf7..9be9af35e2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -17,6 +17,17 @@ version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" +[[package]] +name = "ahash" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fcb51a0695d8f838b1ee009b3fbf66bda078cd64590202a864a8f3e8c4315c47" +dependencies = [ + "getrandom", + "once_cell", + "version_check", +] + [[package]] name = "ahash" version = "0.8.2" @@ -1515,6 +1526,9 @@ name = "hashbrown" version = "0.12.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" +dependencies = [ + "ahash 0.7.6", +] [[package]] name = "hashbrown" @@ -1522,7 +1536,16 @@ version = "0.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "43a3c133739dddd0d2990f9a4bdf8eb4b21ef50e4851ca85ab661199821d510e" dependencies = [ - "ahash", + "ahash 0.8.2", +] + +[[package]] +name = "hashlink" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69fe1fcf8b4278d860ad0548329f892a3631fb63f82574df68275f34cdbe0ffa" +dependencies = [ + "hashbrown 0.12.3", ] [[package]] @@ -2764,6 +2787,7 @@ dependencies = [ "futures", "git-version", "hashbrown 0.13.2", + "hashlink", "hex", "hmac", "hostname", @@ -4624,6 +4648,7 @@ dependencies = [ "futures-executor", "futures-task", "futures-util", + "hashbrown 0.12.3", "indexmap", "itertools", "libc", diff --git a/Cargo.toml b/Cargo.toml index e6695c4246..9dcf1a265a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -44,6 +44,7 @@ futures-core = "0.3" futures-util = "0.3" git-version = "0.3" hashbrown = "0.13" +hashlink = "0.8.1" hex = "0.4" hex-literal = "0.3" hmac = "0.12.1" diff --git a/proxy/Cargo.toml b/proxy/Cargo.toml index 03a6ddac5d..1ff7eebd98 100644 --- a/proxy/Cargo.toml +++ b/proxy/Cargo.toml @@ -6,58 +6,59 @@ license.workspace = true [dependencies] anyhow.workspace = true +async-trait.workspace = true atty.workspace = true base64.workspace = true bstr.workspace = true -bytes = {workspace = true, features = ['serde'] } -clap.workspace = true +bytes = { workspace = true, features = ["serde"] } chrono.workspace = true +clap.workspace = true consumption_metrics.workspace = true futures.workspace = true git-version.workspace = true hashbrown.workspace = true +hashlink.workspace = true hex.workspace = true hmac.workspace = true -hyper.workspace = true +hostname.workspace = true +humantime.workspace = true hyper-tungstenite.workspace = true +hyper.workspace = true itertools.workspace = true md5.workspace = true +metrics.workspace = true once_cell.workspace = true parking_lot.workspace = true pin-project-lite.workspace = true +pq_proto.workspace = true +prometheus.workspace = true rand.workspace = true regex.workspace = true -reqwest = { workspace = true, features = [ "json" ] } +reqwest = { workspace = true, features = ["json"] } routerify.workspace = true -rustls.workspace = true rustls-pemfile.workspace = true +rustls.workspace = true scopeguard.workspace = true serde.workspace = true serde_json.workspace = true sha2.workspace = true socket2.workspace = true thiserror.workspace = true -tokio.workspace = true +tls-listener.workspace = true tokio-postgres.workspace = true tokio-rustls.workspace = true -tls-listener.workspace = true -tracing.workspace = true +tokio.workspace = true 
tracing-subscriber.workspace = true +tracing.workspace = true url.workspace = true +utils.workspace = true uuid.workspace = true webpki-roots.workspace = true x509-parser.workspace = true -metrics.workspace = true -pq_proto.workspace = true -utils.workspace = true -prometheus.workspace = true -humantime.workspace = true -hostname.workspace = true workspace_hack.workspace = true [dev-dependencies] -async-trait.workspace = true rcgen.workspace = true rstest.workspace = true tokio-postgres-rustls.workspace = true diff --git a/proxy/src/auth.rs b/proxy/src/auth.rs index 0446b53603..dfea84953b 100644 --- a/proxy/src/auth.rs +++ b/proxy/src/auth.rs @@ -1,7 +1,7 @@ //! Client authentication mechanisms. pub mod backend; -pub use backend::{BackendType, ConsoleReqExtra}; +pub use backend::BackendType; mod credentials; pub use credentials::ClientCredentials; @@ -12,7 +12,7 @@ use password_hack::PasswordHackPayload; mod flow; pub use flow::*; -use crate::error::UserFacingError; +use crate::{console, error::UserFacingError}; use std::io; use thiserror::Error; @@ -26,10 +26,10 @@ pub enum AuthErrorImpl { Link(#[from] backend::LinkAuthError), #[error(transparent)] - GetAuthInfo(#[from] backend::GetAuthInfoError), + GetAuthInfo(#[from] console::errors::GetAuthInfoError), #[error(transparent)] - WakeCompute(#[from] backend::WakeComputeError), + WakeCompute(#[from] console::errors::WakeComputeError), /// SASL protocol errors (includes [SCRAM](crate::scram)). #[error(transparent)] diff --git a/proxy/src/auth/backend.rs b/proxy/src/auth/backend.rs index e6a179a040..60460e6722 100644 --- a/proxy/src/auth/backend.rs +++ b/proxy/src/auth/backend.rs @@ -1,48 +1,40 @@ -mod postgres; +mod classic; mod link; +use futures::TryFutureExt; pub use link::LinkAuthError; -mod console; -pub use console::{GetAuthInfoError, WakeComputeError}; - use crate::{ auth::{self, AuthFlow, ClientCredentials}, - compute, - console::messages::MetricsAuxInfo, - http, mgmt, stream, url, - waiters::{self, Waiter, Waiters}, + console::{ + self, + provider::{CachedNodeInfo, ConsoleReqExtra}, + Api, + }, + stream, url, }; -use once_cell::sync::Lazy; use std::borrow::Cow; use tokio::io::{AsyncRead, AsyncWrite}; use tracing::{info, warn}; -static CPLANE_WAITERS: Lazy> = Lazy::new(Default::default); - -/// Give caller an opportunity to wait for the cloud's reply. -pub async fn with_waiter( - psql_session_id: impl Into, - action: impl FnOnce(Waiter<'static, mgmt::ComputeReady>) -> R, -) -> Result -where - R: std::future::Future>, - E: From, -{ - let waiter = CPLANE_WAITERS.register(psql_session_id.into())?; - action(waiter).await +/// A product of successful authentication. +pub struct AuthSuccess { + /// Did we send [`pq_proto::BeMessage::AuthenticationOk`] to client? + pub reported_auth_ok: bool, + /// Something to be considered a positive result. + pub value: T, } -pub fn notify(psql_session_id: &str, msg: mgmt::ComputeReady) -> Result<(), waiters::NotifyError> { - CPLANE_WAITERS.notify(psql_session_id, msg) -} - -/// Extra query params we'd like to pass to the console. -pub struct ConsoleReqExtra<'a> { - /// A unique identifier for a connection. - pub session_id: uuid::Uuid, - /// Name of client application, if set. - pub application_name: Option<&'a str>, +impl AuthSuccess { + /// Very similar to [`std::option::Option::map`]. + /// Maps [`AuthSuccess`] to [`AuthSuccess`] by applying + /// a function to a contained value. 
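+    ///
+    /// A hypothetical usage sketch (the values are made up for illustration):
+    ///
+    /// ```ignore
+    /// let ok = AuthSuccess { reported_auth_ok: false, value: 42 };
+    /// let ok = ok.map(|x| x.to_string()); // AuthSuccess<String>, same flag
+    /// ```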
+ pub fn map(self, f: impl FnOnce(T) -> R) -> AuthSuccess { + AuthSuccess { + reported_auth_ok: self.reported_auth_ok, + value: f(self.value), + } + } } /// This type serves two purposes: @@ -53,12 +45,11 @@ pub struct ConsoleReqExtra<'a> { /// * However, when we substitute `T` with [`ClientCredentials`], /// this helps us provide the credentials only to those auth /// backends which require them for the authentication process. -#[derive(Debug)] pub enum BackendType<'a, T> { /// Current Cloud API (V2). - Console(Cow<'a, http::Endpoint>, T), + Console(Cow<'a, console::provider::neon::Api>, T), /// Local mock of Cloud API (V2). - Postgres(Cow<'a, url::ApiUrl>, T), + Postgres(Cow<'a, console::provider::mock::Api>, T), /// Authentication via a web browser. Link(Cow<'a, url::ApiUrl>), } @@ -67,14 +58,8 @@ impl std::fmt::Display for BackendType<'_, ()> { fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { use BackendType::*; match self { - Console(endpoint, _) => fmt - .debug_tuple("Console") - .field(&endpoint.url().as_str()) - .finish(), - Postgres(endpoint, _) => fmt - .debug_tuple("Postgres") - .field(&endpoint.as_str()) - .finish(), + Console(endpoint, _) => fmt.debug_tuple("Console").field(&endpoint.url()).finish(), + Postgres(endpoint, _) => fmt.debug_tuple("Postgres").field(&endpoint.url()).finish(), Link(url) => fmt.debug_tuple("Link").field(&url.as_str()).finish(), } } @@ -120,30 +105,16 @@ impl<'a, T, E> BackendType<'a, Result> { } } -/// A product of successful authentication. -pub struct AuthSuccess { - /// Did we send [`pq_proto::BeMessage::AuthenticationOk`] to client? - pub reported_auth_ok: bool, - /// Something to be considered a positive result. - pub value: T, -} - -/// Info for establishing a connection to a compute node. -/// This is what we get after auth succeeded, but not before! -pub struct NodeInfo { - /// Compute node connection params. - pub config: compute::ConnCfg, - /// Labels for proxy's metrics. - pub aux: MetricsAuxInfo, -} - -impl BackendType<'_, ClientCredentials<'_>> { +// TODO: get rid of explicit lifetimes in this block (there's a bug in rustc). +// Read more: https://github.com/rust-lang/rust/issues/99190 +// Alleged fix: https://github.com/rust-lang/rust/pull/89056 +impl<'l> BackendType<'l, ClientCredentials<'_>> { /// Do something special if user didn't provide the `project` parameter. - async fn try_password_hack( - &mut self, - extra: &ConsoleReqExtra<'_>, - client: &mut stream::PqStream, - ) -> auth::Result>> { + async fn try_password_hack<'a>( + &'a mut self, + extra: &'a ConsoleReqExtra<'a>, + client: &'a mut stream::PqStream, + ) -> auth::Result>> { use BackendType::*; // If there's no project so far, that entails that client doesn't @@ -179,33 +150,28 @@ impl BackendType<'_, ClientCredentials<'_>> { // TODO: find a proper way to merge those very similar blocks. let (mut node, payload) = match self { - Console(endpoint, creds) if creds.project.is_none() => { + Console(api, creds) if creds.project.is_none() => { let payload = fetch_magic_payload(client).await?; let mut creds = creds.as_ref(); creds.project = Some(payload.project.as_str().into()); - let node = console::Api::new(endpoint, extra, &creds) - .wake_compute() - .await?; + let node = api.wake_compute(extra, &creds).await?; (node, payload) } - Console(endpoint, creds) if creds.use_cleartext_password_flow => { - // This is a hack to allow cleartext password in secure connections (wss). + // This is a hack to allow cleartext password in secure connections (wss). 
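+            // (Websocket clients set `use_cleartext_password_flow` to minimize
+            // the number of round trips; see [`ClientCredentials`] for details.)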
+ Console(api, creds) if creds.use_cleartext_password_flow => { let payload = fetch_plaintext_password(client).await?; - let creds = creds.as_ref(); - let node = console::Api::new(endpoint, extra, &creds) - .wake_compute() - .await?; + let node = api.wake_compute(extra, creds).await?; (node, payload) } - Postgres(endpoint, creds) if creds.project.is_none() => { + Postgres(api, creds) if creds.project.is_none() => { let payload = fetch_magic_payload(client).await?; let mut creds = creds.as_ref(); creds.project = Some(payload.project.as_str().into()); - let node = postgres::Api::new(endpoint, &creds).wake_compute().await?; + let node = api.wake_compute(extra, &creds).await?; (node, payload) } @@ -220,11 +186,11 @@ impl BackendType<'_, ClientCredentials<'_>> { } /// Authenticate the client via the requested backend, possibly using credentials. - pub async fn authenticate( - mut self, - extra: &ConsoleReqExtra<'_>, - client: &mut stream::PqStream, - ) -> auth::Result> { + pub async fn authenticate<'a>( + &mut self, + extra: &'a ConsoleReqExtra<'a>, + client: &'a mut stream::PqStream, + ) -> auth::Result> { use BackendType::*; // Handle cases when `project` is missing in `creds`. @@ -235,7 +201,7 @@ impl BackendType<'_, ClientCredentials<'_>> { } let res = match self { - Console(endpoint, creds) => { + Console(api, creds) => { info!( user = creds.user, project = creds.project(), @@ -243,26 +209,40 @@ impl BackendType<'_, ClientCredentials<'_>> { ); assert!(creds.project.is_some()); - console::Api::new(&endpoint, extra, &creds) - .handle_user(client) - .await? + classic::handle_user(api.as_ref(), extra, creds, client).await? } - Postgres(endpoint, creds) => { + Postgres(api, creds) => { info!("performing mock authentication using a local postgres instance"); assert!(creds.project.is_some()); - postgres::Api::new(&endpoint, &creds) - .handle_user(client) - .await? + classic::handle_user(api.as_ref(), extra, creds, client).await? } // NOTE: this auth backend doesn't use client credentials. Link(url) => { info!("performing link authentication"); - link::handle_user(&url, client).await? + + link::handle_user(url, client) + .await? + .map(CachedNodeInfo::new_uncached) } }; info!("user successfully authenticated"); Ok(res) } + + /// When applicable, wake the compute node, gaining its connection info in the process. + /// The link auth flow doesn't support this, so we return [`None`] in that case. 
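+    /// (For the console backend this may be served from the `node_info`
+    /// cache, i.e. without an actual HTTP round trip to the console.)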
+ pub async fn wake_compute<'a>( + &self, + extra: &'a ConsoleReqExtra<'a>, + ) -> Result, console::errors::WakeComputeError> { + use BackendType::*; + + match self { + Console(api, creds) => api.wake_compute(extra, creds).map_ok(Some).await, + Postgres(api, creds) => api.wake_compute(extra, creds).map_ok(Some).await, + Link(_) => Ok(None), + } + } } diff --git a/proxy/src/auth/backend/classic.rs b/proxy/src/auth/backend/classic.rs new file mode 100644 index 0000000000..eefef6e9b4 --- /dev/null +++ b/proxy/src/auth/backend/classic.rs @@ -0,0 +1,61 @@ +use super::AuthSuccess; +use crate::{ + auth::{self, AuthFlow, ClientCredentials}, + compute, + console::{self, AuthInfo, CachedNodeInfo, ConsoleReqExtra}, + sasl, scram, + stream::PqStream, +}; +use tokio::io::{AsyncRead, AsyncWrite}; +use tracing::info; + +pub(super) async fn handle_user( + api: &impl console::Api, + extra: &ConsoleReqExtra<'_>, + creds: &ClientCredentials<'_>, + client: &mut PqStream, +) -> auth::Result> { + info!("fetching user's authentication info"); + let info = api.get_auth_info(extra, creds).await?.unwrap_or_else(|| { + // If we don't have an authentication secret, we mock one to + // prevent malicious probing (possible due to missing protocol steps). + // This mocked secret will never lead to successful authentication. + info!("authentication info not found, mocking it"); + AuthInfo::Scram(scram::ServerSecret::mock(creds.user, rand::random())) + }); + + let flow = AuthFlow::new(client); + let scram_keys = match info { + AuthInfo::Md5(_) => { + info!("auth endpoint chooses MD5"); + return Err(auth::AuthError::bad_auth_method("MD5")); + } + AuthInfo::Scram(secret) => { + info!("auth endpoint chooses SCRAM"); + let scram = auth::Scram(&secret); + let client_key = match flow.begin(scram).await?.authenticate().await? { + sasl::Outcome::Success(key) => key, + sasl::Outcome::Failure(reason) => { + info!("auth backend failed with an error: {reason}"); + return Err(auth::AuthError::auth_failed(creds.user)); + } + }; + + Some(compute::ScramKeys { + client_key: client_key.as_bytes(), + server_key: secret.server_key.as_bytes(), + }) + } + }; + + let mut node = api.wake_compute(extra, creds).await?; + if let Some(keys) = scram_keys { + use tokio_postgres::config::AuthKeys; + node.config.auth_keys(AuthKeys::ScramSha256(keys)); + } + + Ok(AuthSuccess { + reported_auth_ok: false, + value: node, + }) +} diff --git a/proxy/src/auth/backend/console.rs b/proxy/src/auth/backend/console.rs deleted file mode 100644 index b3e3fd0c10..0000000000 --- a/proxy/src/auth/backend/console.rs +++ /dev/null @@ -1,365 +0,0 @@ -//! Cloud API V2. - -use super::{AuthSuccess, ConsoleReqExtra, NodeInfo}; -use crate::{ - auth::{self, AuthFlow, ClientCredentials}, - compute, - console::messages::{ConsoleError, GetRoleSecret, WakeCompute}, - error::{io_error, UserFacingError}, - http, sasl, scram, - stream::PqStream, -}; -use futures::TryFutureExt; -use reqwest::StatusCode as HttpStatusCode; -use std::future::Future; -use thiserror::Error; -use tokio::io::{AsyncRead, AsyncWrite}; -use tracing::{error, info, info_span, warn, Instrument}; - -/// A go-to error message which doesn't leak any detail. -const REQUEST_FAILED: &str = "Console request failed"; - -/// Common console API error. -#[derive(Debug, Error)] -pub enum ApiError { - /// Error returned by the console itself. - #[error("{REQUEST_FAILED} with {}: {}", .status, .text)] - Console { - status: HttpStatusCode, - text: Box, - }, - - /// Various IO errors like broken pipe or malformed payload. 
- #[error("{REQUEST_FAILED}: {0}")] - Transport(#[from] std::io::Error), -} - -impl ApiError { - /// Returns HTTP status code if it's the reason for failure. - fn http_status_code(&self) -> Option { - use ApiError::*; - match self { - Console { status, .. } => Some(*status), - _ => None, - } - } -} - -impl UserFacingError for ApiError { - fn to_string_client(&self) -> String { - use ApiError::*; - match self { - // To minimize risks, only select errors are forwarded to users. - // Ask @neondatabase/control-plane for review before adding more. - Console { status, .. } => match *status { - HttpStatusCode::NOT_FOUND => { - // Status 404: failed to get a project-related resource. - format!("{REQUEST_FAILED}: endpoint cannot be found") - } - HttpStatusCode::NOT_ACCEPTABLE => { - // Status 406: endpoint is disabled (we don't allow connections). - format!("{REQUEST_FAILED}: endpoint is disabled") - } - HttpStatusCode::LOCKED => { - // Status 423: project might be in maintenance mode (or bad state). - format!("{REQUEST_FAILED}: endpoint is temporary unavailable") - } - _ => REQUEST_FAILED.to_owned(), - }, - _ => REQUEST_FAILED.to_owned(), - } - } -} - -// Helps eliminate graceless `.map_err` calls without introducing another ctor. -impl From for ApiError { - fn from(e: reqwest::Error) -> Self { - io_error(e).into() - } -} - -#[derive(Debug, Error)] -pub enum GetAuthInfoError { - // We shouldn't include the actual secret here. - #[error("Console responded with a malformed auth secret")] - BadSecret, - - #[error(transparent)] - ApiError(ApiError), -} - -// This allows more useful interactions than `#[from]`. -impl> From for GetAuthInfoError { - fn from(e: E) -> Self { - Self::ApiError(e.into()) - } -} - -impl UserFacingError for GetAuthInfoError { - fn to_string_client(&self) -> String { - use GetAuthInfoError::*; - match self { - // We absolutely should not leak any secrets! - BadSecret => REQUEST_FAILED.to_owned(), - // However, API might return a meaningful error. - ApiError(e) => e.to_string_client(), - } - } -} - -#[derive(Debug, Error)] -pub enum WakeComputeError { - #[error("Console responded with a malformed compute address: {0}")] - BadComputeAddress(Box), - - #[error(transparent)] - ApiError(ApiError), -} - -// This allows more useful interactions than `#[from]`. -impl> From for WakeComputeError { - fn from(e: E) -> Self { - Self::ApiError(e.into()) - } -} - -impl UserFacingError for WakeComputeError { - fn to_string_client(&self) -> String { - use WakeComputeError::*; - match self { - // We shouldn't show user the address even if it's broken. - // Besides, user is unlikely to care about this detail. - BadComputeAddress(_) => REQUEST_FAILED.to_owned(), - // However, API might return a meaningful error. - ApiError(e) => e.to_string_client(), - } - } -} - -/// Auth secret which is managed by the cloud. -pub enum AuthInfo { - /// Md5 hash of user's password. - Md5([u8; 16]), - - /// [SCRAM](crate::scram) authentication info. - Scram(scram::ServerSecret), -} - -#[must_use] -pub(super) struct Api<'a> { - endpoint: &'a http::Endpoint, - extra: &'a ConsoleReqExtra<'a>, - creds: &'a ClientCredentials<'a>, -} - -impl<'a> AsRef> for Api<'a> { - fn as_ref(&self) -> &ClientCredentials<'a> { - self.creds - } -} - -impl<'a> Api<'a> { - /// Construct an API object containing the auth parameters. 
- pub(super) fn new( - endpoint: &'a http::Endpoint, - extra: &'a ConsoleReqExtra<'a>, - creds: &'a ClientCredentials, - ) -> Self { - Self { - endpoint, - extra, - creds, - } - } - - /// Authenticate the existing user or throw an error. - pub(super) async fn handle_user( - &'a self, - client: &mut PqStream, - ) -> auth::Result> { - handle_user(client, self, Self::get_auth_info, Self::wake_compute).await - } -} - -impl Api<'_> { - async fn get_auth_info(&self) -> Result, GetAuthInfoError> { - let request_id = uuid::Uuid::new_v4().to_string(); - async { - let request = self - .endpoint - .get("proxy_get_role_secret") - .header("X-Request-ID", &request_id) - .query(&[("session_id", self.extra.session_id)]) - .query(&[ - ("application_name", self.extra.application_name), - ("project", Some(self.creds.project().expect("impossible"))), - ("role", Some(self.creds.user)), - ]) - .build()?; - - info!(url = request.url().as_str(), "sending http request"); - let response = self.endpoint.execute(request).await?; - let body = match parse_body::(response).await { - Ok(body) => body, - // Error 404 is special: it's ok not to have a secret. - Err(e) => match e.http_status_code() { - Some(HttpStatusCode::NOT_FOUND) => return Ok(None), - _otherwise => return Err(e.into()), - }, - }; - - let secret = scram::ServerSecret::parse(&body.role_secret) - .map(AuthInfo::Scram) - .ok_or(GetAuthInfoError::BadSecret)?; - - Ok(Some(secret)) - } - .map_err(crate::error::log_error) - .instrument(info_span!("get_auth_info", id = request_id)) - .await - } - - /// Wake up the compute node and return the corresponding connection info. - pub async fn wake_compute(&self) -> Result { - let request_id = uuid::Uuid::new_v4().to_string(); - async { - let request = self - .endpoint - .get("proxy_wake_compute") - .header("X-Request-ID", &request_id) - .query(&[("session_id", self.extra.session_id)]) - .query(&[ - ("application_name", self.extra.application_name), - ("project", Some(self.creds.project().expect("impossible"))), - ]) - .build()?; - - info!(url = request.url().as_str(), "sending http request"); - let response = self.endpoint.execute(request).await?; - let body = parse_body::(response).await?; - - // Unfortunately, ownership won't let us use `Option::ok_or` here. - let (host, port) = match parse_host_port(&body.address) { - None => return Err(WakeComputeError::BadComputeAddress(body.address)), - Some(x) => x, - }; - - let mut config = compute::ConnCfg::new(); - config - .host(host) - .port(port) - .dbname(self.creds.dbname) - .user(self.creds.user); - - Ok(NodeInfo { - config, - aux: body.aux, - }) - } - .map_err(crate::error::log_error) - .instrument(info_span!("wake_compute", id = request_id)) - .await - } -} - -/// Common logic for user handling in API V2. -/// We reuse this for a mock API implementation in [`super::postgres`]. -pub(super) async fn handle_user<'a, Endpoint, GetAuthInfo, WakeCompute>( - client: &mut PqStream, - endpoint: &'a Endpoint, - get_auth_info: impl FnOnce(&'a Endpoint) -> GetAuthInfo, - wake_compute: impl FnOnce(&'a Endpoint) -> WakeCompute, -) -> auth::Result> -where - Endpoint: AsRef>, - GetAuthInfo: Future, GetAuthInfoError>>, - WakeCompute: Future>, -{ - let creds = endpoint.as_ref(); - - info!("fetching user's authentication info"); - let info = get_auth_info(endpoint).await?.unwrap_or_else(|| { - // If we don't have an authentication secret, we mock one to - // prevent malicious probing (possible due to missing protocol steps). 
- // This mocked secret will never lead to successful authentication. - info!("authentication info not found, mocking it"); - AuthInfo::Scram(scram::ServerSecret::mock(creds.user, rand::random())) - }); - - let flow = AuthFlow::new(client); - let scram_keys = match info { - AuthInfo::Md5(_) => { - info!("auth endpoint chooses MD5"); - return Err(auth::AuthError::bad_auth_method("MD5")); - } - AuthInfo::Scram(secret) => { - info!("auth endpoint chooses SCRAM"); - let scram = auth::Scram(&secret); - let client_key = match flow.begin(scram).await?.authenticate().await? { - sasl::Outcome::Success(key) => key, - sasl::Outcome::Failure(reason) => { - info!("auth backend failed with an error: {reason}"); - return Err(auth::AuthError::auth_failed(creds.user)); - } - }; - - Some(compute::ScramKeys { - client_key: client_key.as_bytes(), - server_key: secret.server_key.as_bytes(), - }) - } - }; - - let mut node = wake_compute(endpoint).await?; - if let Some(keys) = scram_keys { - use tokio_postgres::config::AuthKeys; - node.config.auth_keys(AuthKeys::ScramSha256(keys)); - } - - Ok(AuthSuccess { - reported_auth_ok: false, - value: node, - }) -} - -/// Parse http response body, taking status code into account. -async fn parse_body serde::Deserialize<'a>>( - response: reqwest::Response, -) -> Result { - let status = response.status(); - if status.is_success() { - // We shouldn't log raw body because it may contain secrets. - info!("request succeeded, processing the body"); - return Ok(response.json().await?); - } - - // Don't throw an error here because it's not as important - // as the fact that the request itself has failed. - let body = response.json().await.unwrap_or_else(|e| { - warn!("failed to parse error body: {e}"); - ConsoleError { - error: "reason unclear (malformed error message)".into(), - } - }); - - let text = body.error; - error!("console responded with an error ({status}): {text}"); - Err(ApiError::Console { status, text }) -} - -fn parse_host_port(input: &str) -> Option<(&str, u16)> { - let (host, port) = input.split_once(':')?; - Some((host, port.parse().ok()?)) -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_parse_host_port() { - let (host, port) = parse_host_port("127.0.0.1:5432").expect("failed to parse"); - assert_eq!(host, "127.0.0.1"); - assert_eq!(port, 5432); - } -} diff --git a/proxy/src/auth/backend/link.rs b/proxy/src/auth/backend/link.rs index e16bbc70e4..ef92b1a444 100644 --- a/proxy/src/auth/backend/link.rs +++ b/proxy/src/auth/backend/link.rs @@ -1,5 +1,11 @@ -use super::{AuthSuccess, NodeInfo}; -use crate::{auth, compute, error::UserFacingError, stream::PqStream, waiters}; +use super::AuthSuccess; +use crate::{ + auth, compute, + console::{self, provider::NodeInfo}, + error::UserFacingError, + stream::PqStream, + waiters, +}; use pq_proto::BeMessage as Be; use thiserror::Error; use tokio::io::{AsyncRead, AsyncWrite}; @@ -47,7 +53,7 @@ pub fn new_psql_session_id() -> String { hex::encode(rand::random::<[u8; 8]>()) } -pub async fn handle_user( +pub(super) async fn handle_user( link_uri: &reqwest::Url, client: &mut PqStream, ) -> auth::Result> { @@ -55,7 +61,7 @@ pub async fn handle_user( let span = info_span!("link", psql_session_id = &psql_session_id); let greeting = hello_message(link_uri, &psql_session_id); - let db_info = super::with_waiter(psql_session_id, |waiter| async { + let db_info = console::mgmt::with_waiter(psql_session_id, |waiter| async { // Give user a URL to spawn a new database. 
             info!(parent: &span, "sending the auth URL to the user");
             client
@@ -80,14 +86,14 @@ pub async fn handle_user(
             .user(&db_info.user);
 
         if let Some(password) = db_info.password {
-            config.password(password);
+            config.password(password.as_ref());
         }
 
         Ok(AuthSuccess {
             reported_auth_ok: true,
             value: NodeInfo {
                 config,
-                aux: db_info.aux,
+                aux: db_info.aux.into(),
             },
         })
     }
diff --git a/proxy/src/auth/credentials.rs b/proxy/src/auth/credentials.rs
index 3b71bef9aa..e1943fe44c 100644
--- a/proxy/src/auth/credentials.rs
+++ b/proxy/src/auth/credentials.rs
@@ -33,6 +33,7 @@ impl UserFacingError for ClientCredsParseError {}
 pub struct ClientCredentials<'a> {
     pub user: &'a str,
     pub dbname: &'a str,
+    // TODO: this is a severe misnomer! We should think of a new name ASAP.
     pub project: Option<Cow<'a, str>>,
     /// If `True`, we'll use the old cleartext password flow. This is used for
     /// websocket connections, which want to minimize the number of round trips.
diff --git a/proxy/src/cache.rs b/proxy/src/cache.rs
new file mode 100644
index 0000000000..4e16cc39ec
--- /dev/null
+++ b/proxy/src/cache.rs
@@ -0,0 +1,304 @@
+use std::{
+    borrow::Borrow,
+    hash::Hash,
+    ops::{Deref, DerefMut},
+    time::{Duration, Instant},
+};
+use tracing::debug;
+
+// This seems to make more sense than `lru` or `cached`:
+//
+// * `near/nearcore` ditched `cached` in favor of `lru`
+//   (https://github.com/near/nearcore/issues?q=is%3Aissue+lru+is%3Aclosed).
+//
+// * `lru` methods use an obscure `KeyRef` type in their constraints (which is deliberately excluded from docs).
+//   This severely hinders its usage both in terms of creating wrappers and supported key types.
+//
+// On the other hand, `hashlink` has good download stats and appears to be maintained.
+use hashlink::{linked_hash_map::RawEntryMut, LruCache};
+
+/// A generic trait which exposes types of cache's key and value,
+/// as well as the notion of cache entry invalidation.
+/// This is useful for [`timed_lru::Cached`].
+pub trait Cache {
+    /// Entry's key.
+    type Key;
+
+    /// Entry's value.
+    type Value;
+
+    /// Used for entry invalidation.
+    type LookupInfo;
+
+    /// Invalidate an entry using a lookup info.
+    /// We don't have an empty default impl because it's error-prone.
+    fn invalidate(&self, _: &Self::LookupInfo);
+}
+
+impl<C: Cache> Cache for &C {
+    type Key = C::Key;
+    type Value = C::Value;
+    type LookupInfo = C::LookupInfo;
+
+    fn invalidate(&self, info: &Self::LookupInfo) {
+        C::invalidate(self, info)
+    }
+}
+
+pub use timed_lru::TimedLru;
+pub mod timed_lru {
+    use super::*;
+
+    /// An implementation of timed LRU cache with fixed capacity.
+    /// Key properties:
+    ///
+    /// * Whenever a new entry is inserted, the least recently accessed one is evicted.
+    ///   The cache also keeps track of entry's insertion time (`created_at`) and TTL (`expires_at`).
+    ///
+    /// * When the entry is about to be retrieved, we check its expiration timestamp.
+    ///   If the entry has expired, we remove it from the cache; otherwise we bump the
+    ///   expiration timestamp (e.g. +5mins) and change its place in the LRU list to prolong
+    ///   its existence.
+    ///
+    /// * There's an API for immediate invalidation (removal) of a cache entry;
+    ///   it's useful in case we know for sure that the entry is no longer correct.
+    ///   See [`timed_lru::LookupInfo`] & [`timed_lru::Cached`] for more information.
+    ///
+    /// * Expired entries are kept in the cache, until they are evicted by the LRU policy,
+    ///   or by a successful lookup (i.e. the entry hasn't expired yet).
+    ///   There is no background job to reap the expired records.
+    ///
+    /// * It's possible for an entry that has not yet expired to be evicted
+    ///   before expired items. That's a bit wasteful, but probably fine in practice.
+    pub struct TimedLru<K, V> {
+        /// Cache's name for tracing.
+        name: &'static str,
+
+        /// The underlying cache implementation.
+        cache: parking_lot::Mutex<LruCache<K, Entry<V>>>,
+
+        /// Default time-to-live of a single entry.
+        ttl: Duration,
+    }
+
+    impl<K: Hash + Eq, V> Cache for TimedLru<K, V> {
+        type Key = K;
+        type Value = V;
+        type LookupInfo = LookupInfo<K>;
+
+        fn invalidate(&self, info: &Self::LookupInfo) {
+            self.invalidate_raw(info)
+        }
+    }
+
+    struct Entry<T> {
+        created_at: Instant,
+        expires_at: Instant,
+        value: T,
+    }
+
+    impl<K: Hash + Eq, V> TimedLru<K, V> {
+        /// Construct a new LRU cache with timed entries.
+        pub fn new(name: &'static str, capacity: usize, ttl: Duration) -> Self {
+            Self {
+                name,
+                cache: LruCache::new(capacity).into(),
+                ttl,
+            }
+        }
+
+        /// Drop an entry from the cache if it's outdated.
+        #[tracing::instrument(level = "debug", fields(cache = self.name), skip_all)]
+        fn invalidate_raw(&self, info: &LookupInfo<K>) {
+            let now = Instant::now();
+
+            // Do costly things before taking the lock.
+            let mut cache = self.cache.lock();
+            let raw_entry = match cache.raw_entry_mut().from_key(&info.key) {
+                RawEntryMut::Vacant(_) => return,
+                RawEntryMut::Occupied(x) => x,
+            };
+
+            // Remove the entry if it was created prior to lookup timestamp.
+            let entry = raw_entry.get();
+            let (created_at, expires_at) = (entry.created_at, entry.expires_at);
+            let should_remove = created_at <= info.created_at || expires_at <= now;
+
+            if should_remove {
+                raw_entry.remove();
+            }
+
+            drop(cache); // drop lock before logging
+            debug!(
+                created_at = format_args!("{created_at:?}"),
+                expires_at = format_args!("{expires_at:?}"),
+                entry_removed = should_remove,
+                "processed a cache entry invalidation event"
+            );
+        }
+
+        /// Try retrieving an entry by its key, then execute `extract` if it exists.
+        #[tracing::instrument(level = "debug", fields(cache = self.name), skip_all)]
+        fn get_raw<Q, R>(&self, key: &Q, extract: impl FnOnce(&K, &Entry<V>) -> R) -> Option<R>
+        where
+            K: Borrow<Q>,
+            Q: Hash + Eq + ?Sized,
+        {
+            let now = Instant::now();
+            let deadline = now.checked_add(self.ttl).expect("time overflow");
+
+            // Do costly things before taking the lock.
+            let mut cache = self.cache.lock();
+            let mut raw_entry = match cache.raw_entry_mut().from_key(key) {
+                RawEntryMut::Vacant(_) => return None,
+                RawEntryMut::Occupied(x) => x,
+            };
+
+            // Immediately drop the entry if it has expired.
+            let entry = raw_entry.get();
+            if entry.expires_at <= now {
+                raw_entry.remove();
+                return None;
+            }
+
+            let value = extract(raw_entry.key(), entry);
+            let (created_at, expires_at) = (entry.created_at, entry.expires_at);
+
+            // Update the deadline and the entry's position in the LRU list.
+            raw_entry.get_mut().expires_at = deadline;
+            raw_entry.to_back();
+
+            drop(cache); // drop lock before logging
+            debug!(
+                created_at = format_args!("{created_at:?}"),
+                old_expires_at = format_args!("{expires_at:?}"),
+                new_expires_at = format_args!("{deadline:?}"),
+                "accessed a cache entry"
+            );
+
+            Some(value)
+        }
+
+        /// Insert an entry into the cache. If an entry with the same key already
+        /// existed, return the previous value and its creation timestamp.
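+        /// (Used by [`TimedLru::insert`] below, which wraps the result in a [`Cached`].)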
+        #[tracing::instrument(level = "debug", fields(cache = self.name), skip_all)]
+        fn insert_raw(&self, key: K, value: V) -> (Instant, Option<V>) {
+            let created_at = Instant::now();
+            let expires_at = created_at.checked_add(self.ttl).expect("time overflow");
+
+            let entry = Entry {
+                created_at,
+                expires_at,
+                value,
+            };
+
+            // Do costly things before taking the lock.
+            let old = self
+                .cache
+                .lock()
+                .insert(key, entry)
+                .map(|entry| entry.value);
+
+            debug!(
+                created_at = format_args!("{created_at:?}"),
+                expires_at = format_args!("{expires_at:?}"),
+                replaced = old.is_some(),
+                "created a cache entry"
+            );
+
+            (created_at, old)
+        }
+    }
+
+    impl<K: Hash + Eq + Clone, V: Clone> TimedLru<K, V> {
+        pub fn insert(&self, key: K, value: V) -> (Option<V>, Cached<&Self>) {
+            let (created_at, old) = self.insert_raw(key.clone(), value.clone());
+
+            let cached = Cached {
+                token: Some((self, LookupInfo { created_at, key })),
+                value,
+            };
+
+            (old, cached)
+        }
+    }
+
+    impl<K: Hash + Eq, V: Clone> TimedLru<K, V> {
+        /// Retrieve a cached entry in a convenient wrapper.
+        pub fn get<Q>(&self, key: &Q) -> Option<Cached<&Self>>
+        where
+            K: Borrow<Q> + Clone,
+            Q: Hash + Eq + ?Sized,
+        {
+            self.get_raw(key, |key, entry| {
+                let info = LookupInfo {
+                    created_at: entry.created_at,
+                    key: key.clone(),
+                };
+
+                Cached {
+                    token: Some((self, info)),
+                    value: entry.value.clone(),
+                }
+            })
+        }
+    }
+
+    /// Lookup information for key invalidation.
+    pub struct LookupInfo<K> {
+        /// Time of creation of a cache [`Entry`].
+        /// We use this during invalidation lookups to prevent eviction of a newer
+        /// entry sharing the same key (it might've been inserted by a different
+        /// task after we got the entry we're trying to invalidate now).
+        created_at: Instant,
+
+        /// Search by this key.
+        key: K,
+    }
+
+    /// Wrapper for convenient entry invalidation.
+    pub struct Cached<C: Cache> {
+        /// Cache + lookup info.
+        token: Option<(C, C::LookupInfo)>,
+
+        /// The value itself.
+        pub value: C::Value,
+    }
+
+    impl<C: Cache> Cached<C> {
+        /// Place any entry into this wrapper; invalidation will be a no-op.
+        /// Unfortunately, rust doesn't let us implement [`From`] or [`Into`].
+        pub fn new_uncached(value: impl Into<C::Value>) -> Self {
+            Self {
+                token: None,
+                value: value.into(),
+            }
+        }
+
+        /// Drop this entry from a cache if it's still there.
+        pub fn invalidate(&self) {
+            if let Some((cache, info)) = &self.token {
+                cache.invalidate(info);
+            }
+        }
+
+        /// Tell if this entry is actually cached.
+        pub fn cached(&self) -> bool {
+            self.token.is_some()
+        }
+    }
+
+    impl<C: Cache> Deref for Cached<C> {
+        type Target = C::Value;
+
+        fn deref(&self) -> &Self::Target {
+            &self.value
+        }
+    }
+
+    impl<C: Cache> DerefMut for Cached<C> {
+        fn deref_mut(&mut self) -> &mut Self::Target {
+            &mut self.value
+        }
+    }
+}
diff --git a/proxy/src/cancellation.rs b/proxy/src/cancellation.rs
index b219cd0fa2..c8c0727471 100644
--- a/proxy/src/cancellation.rs
+++ b/proxy/src/cancellation.rs
@@ -1,6 +1,5 @@
 use anyhow::{anyhow, Context};
 use hashbrown::HashMap;
-use parking_lot::Mutex;
 use pq_proto::CancelKeyData;
 use std::net::SocketAddr;
 use tokio::net::TcpStream;
@@ -9,14 +8,15 @@ use tracing::info;
 
 /// Enables serving `CancelRequest`s.
 #[derive(Default)]
-pub struct CancelMap(Mutex<HashMap<CancelKeyData, Option<CancelClosure>>>);
+pub struct CancelMap(parking_lot::RwLock<HashMap<CancelKeyData, Option<CancelClosure>>>);
 
 impl CancelMap {
     /// Cancel a running query for the corresponding connection.
     pub async fn cancel_session(&self, key: CancelKeyData) -> anyhow::Result<()> {
+        // NB: we should immediately release the lock after cloning the token.
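+        // (Holding the read guard across the `.await` below would block
+        // writers trying to register or drop other sessions.)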
let cancel_closure = self .0 - .lock() + .read() .get(&key) .and_then(|x| x.clone()) .with_context(|| format!("query cancellation key not found: {key}"))?; @@ -41,14 +41,14 @@ impl CancelMap { // Random key collisions are unlikely to happen here, but they're still possible, // which is why we have to take care not to rewrite an existing key. self.0 - .lock() + .write() .try_insert(key, None) .map_err(|_| anyhow!("query cancellation key already exists: {key}"))?; // This will guarantee that the session gets dropped // as soon as the future is finished. scopeguard::defer! { - self.0.lock().remove(&key); + self.0.write().remove(&key); info!("dropped query cancellation key {key}"); } @@ -59,12 +59,12 @@ impl CancelMap { #[cfg(test)] fn contains(&self, session: &Session) -> bool { - self.0.lock().contains_key(&session.key) + self.0.read().contains_key(&session.key) } #[cfg(test)] fn is_empty(&self) -> bool { - self.0.lock().is_empty() + self.0.read().is_empty() } } @@ -115,7 +115,7 @@ impl Session<'_> { info!("enabling query cancellation for this session"); self.cancel_map .0 - .lock() + .write() .insert(self.key, Some(cancel_closure)); self.key diff --git a/proxy/src/compute.rs b/proxy/src/compute.rs index 094db73061..0c0cbcde20 100644 --- a/proxy/src/compute.rs +++ b/proxy/src/compute.rs @@ -42,14 +42,65 @@ pub type ScramKeys = tokio_postgres::config::ScramKeys<32>; /// A config for establishing a connection to compute node. /// Eventually, `tokio_postgres` will be replaced with something better. /// Newtype allows us to implement methods on top of it. +#[derive(Clone)] #[repr(transparent)] pub struct ConnCfg(Box); +/// Creation and initialization routines. impl ConnCfg { - /// Construct a new connection config. pub fn new() -> Self { Self(Default::default()) } + + /// Reuse password or auth keys from the other config. + pub fn reuse_password(&mut self, other: &Self) { + if let Some(password) = other.get_password() { + self.password(password); + } + + if let Some(keys) = other.get_auth_keys() { + self.auth_keys(keys); + } + } + + /// Apply startup message params to the connection config. + pub fn set_startup_params(&mut self, params: &StartupMessageParams) { + if let Some(options) = params.options_raw() { + // We must drop all proxy-specific parameters. + #[allow(unstable_name_collisions)] + let options: String = options + .filter(|opt| !opt.starts_with("project=")) + .intersperse(" ") // TODO: use impl from std once it's stabilized + .collect(); + + self.options(&options); + } + + if let Some(app_name) = params.get("application_name") { + self.application_name(app_name); + } + + // TODO: This is especially ugly... + if let Some(replication) = params.get("replication") { + use tokio_postgres::config::ReplicationMode; + match replication { + "true" | "on" | "yes" | "1" => { + self.replication_mode(ReplicationMode::Physical); + } + "database" => { + self.replication_mode(ReplicationMode::Logical); + } + _other => {} + } + } + + // TODO: extend the list of the forwarded startup parameters. + // Currently, tokio-postgres doesn't allow us to pass + // arbitrary parameters, but the ones above are a good start. + // + // This and the reverse params problem can be better addressed + // in a bespoke connection machinery (a new library for that sake). + } } impl std::ops::Deref for ConnCfg { @@ -132,50 +183,13 @@ pub struct PostgresConnection { pub stream: TcpStream, /// PostgreSQL connection parameters. pub params: std::collections::HashMap, + /// Query cancellation token. 
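+    /// (Built in [`ConnCfg::connect`] from the compute socket address and the
+    /// client's cancel token.)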
+ pub cancel_closure: CancelClosure, } impl ConnCfg { /// Connect to a corresponding compute node. - pub async fn connect( - mut self, - params: &StartupMessageParams, - ) -> Result<(PostgresConnection, CancelClosure), ConnectionError> { - if let Some(options) = params.options_raw() { - // We must drop all proxy-specific parameters. - #[allow(unstable_name_collisions)] - let options: String = options - .filter(|opt| !opt.starts_with("project=")) - .intersperse(" ") // TODO: use impl from std once it's stabilized - .collect(); - - self.0.options(&options); - } - - if let Some(app_name) = params.get("application_name") { - self.0.application_name(app_name); - } - - // TODO: This is especially ugly... - if let Some(replication) = params.get("replication") { - use tokio_postgres::config::ReplicationMode; - match replication { - "true" | "on" | "yes" | "1" => { - self.0.replication_mode(ReplicationMode::Physical); - } - "database" => { - self.0.replication_mode(ReplicationMode::Logical); - } - _other => {} - } - } - - // TODO: extend the list of the forwarded startup parameters. - // Currently, tokio-postgres doesn't allow us to pass - // arbitrary parameters, but the ones above are a good start. - // - // This and the reverse params problem can be better addressed - // in a bespoke connection machinery (a new library for that sake). - + pub async fn connect(&self) -> Result { // TODO: establish a secure connection to the DB. let (socket_addr, mut stream) = self.connect_raw().await?; let (client, connection) = self.0.connect_raw(&mut stream, NoTls).await?; @@ -189,8 +203,13 @@ impl ConnCfg { // NB: CancelToken is supposed to hold socket_addr, but we use connect_raw. // Yet another reason to rework the connection establishing code. let cancel_closure = CancelClosure::new(socket_addr, client.cancel_token()); - let db = PostgresConnection { stream, params }; - Ok((db, cancel_closure)) + let connection = PostgresConnection { + stream, + params, + cancel_closure, + }; + + Ok(connection) } } diff --git a/proxy/src/config.rs b/proxy/src/config.rs index 33a8fff847..5e285f3625 100644 --- a/proxy/src/config.rs +++ b/proxy/src/config.rs @@ -1,16 +1,16 @@ use crate::auth; -use anyhow::{ensure, Context}; -use std::sync::Arc; +use anyhow::{bail, ensure, Context}; +use std::{str::FromStr, sync::Arc, time::Duration}; pub struct ProxyConfig { pub tls_config: Option, pub auth_backend: auth::BackendType<'static, ()>, - pub metric_collection_config: Option, + pub metric_collection: Option, } pub struct MetricCollectionConfig { pub endpoint: reqwest::Url, - pub interval: std::time::Duration, + pub interval: Duration, } pub struct TlsConfig { @@ -37,6 +37,7 @@ pub fn configure_tls(key_path: &str, cert_path: &str) -> anyhow::Result anyhow::Result anyhow::Result { + let mut size = None; + let mut ttl = None; + + for option in options.split(',') { + let (key, value) = option + .split_once('=') + .with_context(|| format!("bad key-value pair: {option}"))?; + + match key { + "size" => size = Some(value.parse()?), + "ttl" => ttl = Some(humantime::parse_duration(value)?), + unknown => bail!("unknown key: {unknown}"), + } + } + + // TTL doesn't matter if cache is always empty. 
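+        // (Normalizing it here lets `size=0` parse without an explicit `ttl`;
+        // see the `"size=0"` case in the tests below.)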
+ if let Some(0) = size { + ttl.get_or_insert(Duration::default()); + } + + Ok(Self { + size: size.context("missing `size`")?, + ttl: ttl.context("missing `ttl`")?, + }) + } +} + +impl FromStr for CacheOptions { + type Err = anyhow::Error; + + fn from_str(options: &str) -> Result { + let error = || format!("failed to parse cache options '{options}'"); + Self::parse(options).with_context(error) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_parse_cache_options() -> anyhow::Result<()> { + let CacheOptions { size, ttl } = "size=4096,ttl=5min".parse()?; + assert_eq!(size, 4096); + assert_eq!(ttl, Duration::from_secs(5 * 60)); + + let CacheOptions { size, ttl } = "ttl=4m,size=2".parse()?; + assert_eq!(size, 2); + assert_eq!(ttl, Duration::from_secs(4 * 60)); + + let CacheOptions { size, ttl } = "size=0,ttl=1s".parse()?; + assert_eq!(size, 0); + assert_eq!(ttl, Duration::from_secs(1)); + + let CacheOptions { size, ttl } = "size=0".parse()?; + assert_eq!(size, 0); + assert_eq!(ttl, Duration::default()); + + Ok(()) + } +} diff --git a/proxy/src/console.rs b/proxy/src/console.rs index 78f09ac9e1..1f3ef99555 100644 --- a/proxy/src/console.rs +++ b/proxy/src/console.rs @@ -3,3 +3,15 @@ /// Payloads used in the console's APIs. pub mod messages; + +/// Wrappers for console APIs and their mocks. +pub mod provider; +pub use provider::{errors, Api, AuthInfo, CachedNodeInfo, ConsoleReqExtra, NodeInfo}; + +/// Various cache-related types. +pub mod caches { + pub use super::provider::{ApiCaches, NodeInfoCache}; +} + +/// Console's management API. +pub mod mgmt; diff --git a/proxy/src/console/messages.rs b/proxy/src/console/messages.rs index 63a97069b8..0d321c077a 100644 --- a/proxy/src/console/messages.rs +++ b/proxy/src/console/messages.rs @@ -63,13 +63,13 @@ impl KickSession<'_> { /// Compute node connection params. #[derive(Deserialize)] pub struct DatabaseInfo { - pub host: String, + pub host: Box, pub port: u16, - pub dbname: String, - pub user: String, + pub dbname: Box, + pub user: Box, /// Console always provides a password, but it might /// be inconvenient for debug with local PG instance. - pub password: Option, + pub password: Option>, pub aux: MetricsAuxInfo, } diff --git a/proxy/src/mgmt.rs b/proxy/src/console/mgmt.rs similarity index 79% rename from proxy/src/mgmt.rs rename to proxy/src/console/mgmt.rs index cf83b48ae0..51a117d3b7 100644 --- a/proxy/src/mgmt.rs +++ b/proxy/src/console/mgmt.rs @@ -1,8 +1,9 @@ use crate::{ - auth, console::messages::{DatabaseInfo, KickSession}, + waiters::{self, Waiter, Waiters}, }; use anyhow::Context; +use once_cell::sync::Lazy; use pq_proto::{BeMessage, SINGLE_COL_ROWDESC}; use std::{ net::{TcpListener, TcpStream}, @@ -14,6 +15,25 @@ use utils::{ postgres_backend_async::QueryError, }; +static CPLANE_WAITERS: Lazy> = Lazy::new(Default::default); + +/// Give caller an opportunity to wait for the cloud's reply. +pub async fn with_waiter( + psql_session_id: impl Into, + action: impl FnOnce(Waiter<'static, ComputeReady>) -> R, +) -> Result +where + R: std::future::Future>, + E: From, +{ + let waiter = CPLANE_WAITERS.register(psql_session_id.into())?; + action(waiter).await +} + +pub fn notify(psql_session_id: &str, msg: ComputeReady) -> Result<(), waiters::NotifyError> { + CPLANE_WAITERS.notify(psql_session_id, msg) +} + /// Console management API listener thread. /// It spawns console response handlers needed for the link auth. 
pub fn thread_main(listener: TcpListener) -> anyhow::Result<()> { @@ -76,7 +96,7 @@ fn try_process_query(pgb: &mut PostgresBackend, query: &str) -> Result<(), Query let _enter = span.enter(); info!("got response: {:?}", resp.result); - match auth::backend::notify(resp.session_id, Ok(resp.result)) { + match notify(resp.session_id, Ok(resp.result)) { Ok(()) => { pgb.write_message_noflush(&SINGLE_COL_ROWDESC)? .write_message_noflush(&BeMessage::DataRow(&[Some(b"ok")]))? diff --git a/proxy/src/console/provider.rs b/proxy/src/console/provider.rs new file mode 100644 index 0000000000..7621aba19b --- /dev/null +++ b/proxy/src/console/provider.rs @@ -0,0 +1,194 @@ +pub mod mock; +pub mod neon; + +use super::messages::MetricsAuxInfo; +use crate::{ + auth::ClientCredentials, + cache::{timed_lru, TimedLru}, + compute, scram, +}; +use async_trait::async_trait; +use std::sync::Arc; + +pub mod errors { + use crate::error::{io_error, UserFacingError}; + use reqwest::StatusCode as HttpStatusCode; + use thiserror::Error; + + /// A go-to error message which doesn't leak any detail. + const REQUEST_FAILED: &str = "Console request failed"; + + /// Common console API error. + #[derive(Debug, Error)] + pub enum ApiError { + /// Error returned by the console itself. + #[error("{REQUEST_FAILED} with {}: {}", .status, .text)] + Console { + status: HttpStatusCode, + text: Box, + }, + + /// Various IO errors like broken pipe or malformed payload. + #[error("{REQUEST_FAILED}: {0}")] + Transport(#[from] std::io::Error), + } + + impl ApiError { + /// Returns HTTP status code if it's the reason for failure. + pub fn http_status_code(&self) -> Option { + use ApiError::*; + match self { + Console { status, .. } => Some(*status), + _ => None, + } + } + } + + impl UserFacingError for ApiError { + fn to_string_client(&self) -> String { + use ApiError::*; + match self { + // To minimize risks, only select errors are forwarded to users. + // Ask @neondatabase/control-plane for review before adding more. + Console { status, .. } => match *status { + HttpStatusCode::NOT_FOUND => { + // Status 404: failed to get a project-related resource. + format!("{REQUEST_FAILED}: endpoint cannot be found") + } + HttpStatusCode::NOT_ACCEPTABLE => { + // Status 406: endpoint is disabled (we don't allow connections). + format!("{REQUEST_FAILED}: endpoint is disabled") + } + HttpStatusCode::LOCKED => { + // Status 423: project might be in maintenance mode (or bad state). + format!("{REQUEST_FAILED}: endpoint is temporary unavailable") + } + _ => REQUEST_FAILED.to_owned(), + }, + _ => REQUEST_FAILED.to_owned(), + } + } + } + + // Helps eliminate graceless `.map_err` calls without introducing another ctor. + impl From for ApiError { + fn from(e: reqwest::Error) -> Self { + io_error(e).into() + } + } + + #[derive(Debug, Error)] + pub enum GetAuthInfoError { + // We shouldn't include the actual secret here. + #[error("Console responded with a malformed auth secret")] + BadSecret, + + #[error(transparent)] + ApiError(ApiError), + } + + // This allows more useful interactions than `#[from]`. + impl> From for GetAuthInfoError { + fn from(e: E) -> Self { + Self::ApiError(e.into()) + } + } + + impl UserFacingError for GetAuthInfoError { + fn to_string_client(&self) -> String { + use GetAuthInfoError::*; + match self { + // We absolutely should not leak any secrets! + BadSecret => REQUEST_FAILED.to_owned(), + // However, API might return a meaningful error. 
+ ApiError(e) => e.to_string_client(), + } + } + } + #[derive(Debug, Error)] + pub enum WakeComputeError { + #[error("Console responded with a malformed compute address: {0}")] + BadComputeAddress(Box), + + #[error(transparent)] + ApiError(ApiError), + } + + // This allows more useful interactions than `#[from]`. + impl> From for WakeComputeError { + fn from(e: E) -> Self { + Self::ApiError(e.into()) + } + } + + impl UserFacingError for WakeComputeError { + fn to_string_client(&self) -> String { + use WakeComputeError::*; + match self { + // We shouldn't show user the address even if it's broken. + // Besides, user is unlikely to care about this detail. + BadComputeAddress(_) => REQUEST_FAILED.to_owned(), + // However, API might return a meaningful error. + ApiError(e) => e.to_string_client(), + } + } + } +} + +/// Extra query params we'd like to pass to the console. +pub struct ConsoleReqExtra<'a> { + /// A unique identifier for a connection. + pub session_id: uuid::Uuid, + /// Name of client application, if set. + pub application_name: Option<&'a str>, +} + +/// Auth secret which is managed by the cloud. +pub enum AuthInfo { + /// Md5 hash of user's password. + Md5([u8; 16]), + + /// [SCRAM](crate::scram) authentication info. + Scram(scram::ServerSecret), +} + +/// Info for establishing a connection to a compute node. +/// This is what we get after auth succeeded, but not before! +#[derive(Clone)] +pub struct NodeInfo { + /// Compute node connection params. + /// It's sad that we have to clone this, but this will improve + /// once we migrate to a bespoke connection logic. + pub config: compute::ConnCfg, + + /// Labels for proxy's metrics. + pub aux: Arc, +} + +pub type NodeInfoCache = TimedLru, NodeInfo>; +pub type CachedNodeInfo = timed_lru::Cached<&'static NodeInfoCache>; + +/// This will allocate per each call, but the http requests alone +/// already require a few allocations, so it should be fine. +#[async_trait] +pub trait Api { + /// Get the client's auth secret for authentication. + async fn get_auth_info( + &self, + extra: &ConsoleReqExtra<'_>, + creds: &ClientCredentials<'_>, + ) -> Result, errors::GetAuthInfoError>; + + /// Wake up the compute node and return the corresponding connection info. + async fn wake_compute( + &self, + extra: &ConsoleReqExtra<'_>, + creds: &ClientCredentials<'_>, + ) -> Result; +} + +/// Various caches for [`console`]. +pub struct ApiCaches { + /// Cache for the `wake_compute` API method. + pub node_info: NodeInfoCache, +} diff --git a/proxy/src/auth/backend/postgres.rs b/proxy/src/console/provider/mock.rs similarity index 53% rename from proxy/src/auth/backend/postgres.rs rename to proxy/src/console/provider/mock.rs index 260342f103..301c3be516 100644 --- a/proxy/src/auth/backend/postgres.rs +++ b/proxy/src/console/provider/mock.rs @@ -1,21 +1,14 @@ -//! Local mock of Cloud API V2. +//! Mock console backend which relies on a user-provided postgres instance. 
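+//!
+//! Auth secrets are read directly from `pg_catalog.pg_authid`, so any local
+//! postgres instance can stand in for the real control plane here.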
use super::{ - console::{self, AuthInfo, GetAuthInfoError, WakeComputeError}, - AuthSuccess, NodeInfo, -}; -use crate::{ - auth::{self, ClientCredentials}, - compute, - error::io_error, - scram, - stream::PqStream, - url::ApiUrl, + errors::{ApiError, GetAuthInfoError, WakeComputeError}, + AuthInfo, CachedNodeInfo, ConsoleReqExtra, NodeInfo, }; +use crate::{auth::ClientCredentials, compute, error::io_error, scram, url::ApiUrl}; +use async_trait::async_trait; use futures::TryFutureExt; use thiserror::Error; -use tokio::io::{AsyncRead, AsyncWrite}; -use tracing::{info, info_span, warn, Instrument}; +use tracing::{error, info, info_span, warn, Instrument}; #[derive(Debug, Error)] enum MockApiError { @@ -23,49 +16,36 @@ enum MockApiError { PasswordNotSet(tokio_postgres::Error), } -impl From for console::ApiError { +impl From for ApiError { fn from(e: MockApiError) -> Self { io_error(e).into() } } -impl From for console::ApiError { +impl From for ApiError { fn from(e: tokio_postgres::Error) -> Self { io_error(e).into() } } -#[must_use] -pub(super) struct Api<'a> { - endpoint: &'a ApiUrl, - creds: &'a ClientCredentials<'a>, +#[derive(Clone)] +pub struct Api { + endpoint: ApiUrl, } -impl<'a> AsRef> for Api<'a> { - fn as_ref(&self) -> &ClientCredentials<'a> { - self.creds - } -} - -impl<'a> Api<'a> { - /// Construct an API object containing the auth parameters. - pub(super) fn new(endpoint: &'a ApiUrl, creds: &'a ClientCredentials) -> Self { - Self { endpoint, creds } +impl Api { + pub fn new(endpoint: ApiUrl) -> Self { + Self { endpoint } } - /// Authenticate the existing user or throw an error. - pub(super) async fn handle_user( - &'a self, - client: &mut PqStream, - ) -> auth::Result> { - // We reuse user handling logic from a production module. - console::handle_user(client, self, Self::get_auth_info, Self::wake_compute).await + pub fn url(&self) -> &str { + self.endpoint.as_str() } -} -impl Api<'_> { - /// This implementation fetches the auth info from a local postgres instance. - async fn get_auth_info(&self) -> Result, GetAuthInfoError> { + async fn do_get_auth_info( + &self, + creds: &ClientCredentials<'_>, + ) -> Result, GetAuthInfoError> { async { // Perhaps we could persist this connection, but then we'd have to // write more code for reopening it if it got closed, which doesn't @@ -75,7 +55,7 @@ impl Api<'_> { tokio::spawn(connection); let query = "select rolpassword from pg_catalog.pg_authid where rolname = $1"; - let rows = client.query(query, &[&self.creds.user]).await?; + let rows = client.query(query, &[&creds.user]).await?; // We can get at most one row, because `rolname` is unique. let row = match rows.get(0) { @@ -84,7 +64,7 @@ impl Api<'_> { // However, this is still a *valid* outcome which is very similar // to getting `404 Not found` from the Neon console. None => { - warn!("user '{}' does not exist", self.creds.user); + warn!("user '{}' does not exist", creds.user); return Ok(None); } }; @@ -98,23 +78,50 @@ impl Api<'_> { Ok(secret.or_else(|| parse_md5(entry).map(AuthInfo::Md5))) } .map_err(crate::error::log_error) - .instrument(info_span!("get_auth_info", mock = self.endpoint.as_str())) + .instrument(info_span!("postgres", url = self.endpoint.as_str())) .await } - /// We don't need to wake anything locally, so we just return the connection info. 
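+    /// We don't need to wake anything locally, so this just builds connection
+    /// info pointing at the mock postgres instance itself.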
 
-    /// We don't need to wake anything locally, so we just return the connection info.
-    pub async fn wake_compute(&self) -> Result<NodeInfo, WakeComputeError> {
+    async fn do_wake_compute(
+        &self,
+        creds: &ClientCredentials<'_>,
+    ) -> Result<NodeInfo, WakeComputeError> {
         let mut config = compute::ConnCfg::new();
         config
             .host(self.endpoint.host_str().unwrap_or("localhost"))
             .port(self.endpoint.port().unwrap_or(5432))
-            .dbname(self.creds.dbname)
-            .user(self.creds.user);
+            .dbname(creds.dbname)
+            .user(creds.user);
 
-        Ok(NodeInfo {
+        let node = NodeInfo {
             config,
             aux: Default::default(),
-        })
+        };
+
+        Ok(node)
+    }
+}
+
+#[async_trait]
+impl super::Api for Api {
+    #[tracing::instrument(skip_all)]
+    async fn get_auth_info(
+        &self,
+        _extra: &ConsoleReqExtra<'_>,
+        creds: &ClientCredentials<'_>,
+    ) -> Result<Option<AuthInfo>, GetAuthInfoError> {
+        self.do_get_auth_info(creds).await
+    }
+
+    #[tracing::instrument(skip_all)]
+    async fn wake_compute(
+        &self,
+        _extra: &ConsoleReqExtra<'_>,
+        creds: &ClientCredentials<'_>,
+    ) -> Result<CachedNodeInfo, WakeComputeError> {
+        self.do_wake_compute(creds)
+            .map_ok(CachedNodeInfo::new_uncached)
+            .await
     }
 }
diff --git a/proxy/src/console/provider/neon.rs b/proxy/src/console/provider/neon.rs
new file mode 100644
index 0000000000..00d3ca8352
--- /dev/null
+++ b/proxy/src/console/provider/neon.rs
@@ -0,0 +1,196 @@
+//! Production console backend.
+
+use super::{
+    super::messages::{ConsoleError, GetRoleSecret, WakeCompute},
+    errors::{ApiError, GetAuthInfoError, WakeComputeError},
+    ApiCaches, AuthInfo, CachedNodeInfo, ConsoleReqExtra, NodeInfo,
+};
+use crate::{auth::ClientCredentials, compute, http, scram};
+use async_trait::async_trait;
+use futures::TryFutureExt;
+use reqwest::StatusCode as HttpStatusCode;
+use tracing::{error, info, info_span, warn, Instrument};
+
+#[derive(Clone)]
+pub struct Api {
+    endpoint: http::Endpoint,
+    caches: &'static ApiCaches,
+}
+
+impl Api {
+    /// Construct an API object containing the auth parameters.
+    pub fn new(endpoint: http::Endpoint, caches: &'static ApiCaches) -> Self {
+        Self { endpoint, caches }
+    }
+
+    pub fn url(&self) -> &str {
+        self.endpoint.url().as_str()
+    }
+
+    async fn do_get_auth_info(
+        &self,
+        extra: &ConsoleReqExtra<'_>,
+        creds: &ClientCredentials<'_>,
+    ) -> Result<Option<AuthInfo>, GetAuthInfoError> {
+        let request_id = uuid::Uuid::new_v4().to_string();
+        async {
+            let request = self
+                .endpoint
+                .get("proxy_get_role_secret")
+                .header("X-Request-ID", &request_id)
+                .query(&[("session_id", extra.session_id)])
+                .query(&[
+                    ("application_name", extra.application_name),
+                    ("project", Some(creds.project().expect("impossible"))),
+                    ("role", Some(creds.user)),
+                ])
+                .build()?;
+
+            info!(url = request.url().as_str(), "sending http request");
+            let response = self.endpoint.execute(request).await?;
+            let body = match parse_body::<GetRoleSecret>(response).await {
+                Ok(body) => body,
+                // Error 404 is special: it's ok not to have a secret.
+                Err(e) => match e.http_status_code() {
+                    Some(HttpStatusCode::NOT_FOUND) => return Ok(None),
+                    _otherwise => return Err(e.into()),
+                },
+            };
+
+            let secret = scram::ServerSecret::parse(&body.role_secret)
+                .map(AuthInfo::Scram)
+                .ok_or(GetAuthInfoError::BadSecret)?;
+
+            Ok(Some(secret))
+        }
+        .map_err(crate::error::log_error)
+        .instrument(info_span!("http", id = request_id))
+        .await
+    }
+
+    async fn do_wake_compute(
+        &self,
+        extra: &ConsoleReqExtra<'_>,
+        creds: &ClientCredentials<'_>,
+    ) -> Result<NodeInfo, WakeComputeError> {
+        let project = creds.project().expect("impossible");
+        let request_id = uuid::Uuid::new_v4().to_string();
+        async {
+            let request = self
+                .endpoint
+                .get("proxy_wake_compute")
+                .header("X-Request-ID", &request_id)
+                .query(&[("session_id", extra.session_id)])
+                .query(&[
+                    ("application_name", extra.application_name),
+                    ("project", Some(project)),
+                ])
+                .build()?;
+
+            info!(url = request.url().as_str(), "sending http request");
+            let response = self.endpoint.execute(request).await?;
+            let body = parse_body::<WakeCompute>(response).await?;
+
+            // Unfortunately, ownership won't let us use `Option::ok_or` here.
+            let (host, port) = match parse_host_port(&body.address) {
+                None => return Err(WakeComputeError::BadComputeAddress(body.address)),
+                Some(x) => x,
+            };
+
+            let mut config = compute::ConnCfg::new();
+            config
+                .host(host)
+                .port(port)
+                .dbname(creds.dbname)
+                .user(creds.user);
+
+            let node = NodeInfo {
+                config,
+                aux: body.aux.into(),
+            };
+
+            Ok(node)
+        }
+        .map_err(crate::error::log_error)
+        .instrument(info_span!("http", id = request_id))
+        .await
+    }
+}
+
+#[async_trait]
+impl super::Api for Api {
+    #[tracing::instrument(skip_all)]
+    async fn get_auth_info(
+        &self,
+        extra: &ConsoleReqExtra<'_>,
+        creds: &ClientCredentials<'_>,
+    ) -> Result<Option<AuthInfo>, GetAuthInfoError> {
+        self.do_get_auth_info(extra, creds).await
+    }
+
+    #[tracing::instrument(skip_all)]
+    async fn wake_compute(
+        &self,
+        extra: &ConsoleReqExtra<'_>,
+        creds: &ClientCredentials<'_>,
+    ) -> Result<CachedNodeInfo, WakeComputeError> {
+        let key = creds.project().expect("impossible");
+
+        // Every time we do a wake-up http request, the compute node stays up
+        // for some time (how long depends on the console's scale-to-zero policy);
+        // the connection info remains the same during that window, so we can
+        // cache it to reduce both load and latency.
+        if let Some(cached) = self.caches.node_info.get(key) {
+            info!(key = key, "found cached compute node info");
+            return Ok(cached);
+        }
+
+        let node = self.do_wake_compute(extra, creds).await?;
+        let (_, cached) = self.caches.node_info.insert(key.into(), node);
+        info!(key = key, "created a cache entry for compute node info");
+
+        Ok(cached)
+    }
+}
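The `TimedLru` behind `NodeInfoCache` lives in the new `proxy/src/cache.rs` module (added as `mod cache` further down) and is not shown in this diff. As a rough mental model of the read-through pattern in `wake_compute` above, here is a deliberately simplified TTL map, with no LRU eviction and no invalidatable `Cached` handles (both of which the real cache provides):

```rust
use std::collections::HashMap;
use std::time::{Duration, Instant};

/// Simplified stand-in for the real `TimedLru`: values expire after a TTL.
struct TtlCache<V> {
    ttl: Duration,
    entries: HashMap<String, (Instant, V)>,
}

impl<V: Clone> TtlCache<V> {
    fn new(ttl: Duration) -> Self {
        Self { ttl, entries: HashMap::new() }
    }

    /// Return the value only if its entry hasn't expired yet.
    fn get(&self, key: &str) -> Option<V> {
        let (inserted_at, value) = self.entries.get(key)?;
        (inserted_at.elapsed() < self.ttl).then(|| value.clone())
    }

    fn insert(&mut self, key: String, value: V) -> V {
        self.entries.insert(key, (Instant::now(), value.clone()));
        value
    }
}

fn main() {
    let mut cache = TtlCache::new(Duration::from_secs(300));

    // Read-through: consult the cache first, fall back to the slow path.
    let addr = match cache.get("project-x") {
        Some(addr) => addr,
        // Cache miss: this is where the real code calls `do_wake_compute`.
        None => cache.insert("project-x".to_owned(), "10.0.0.1:5432"),
    };
    println!("woke compute at {addr}");
}
```

The important property mirrored here is that a hit never blocks on the console, while a miss pays for one `wake_compute` round-trip and primes the cache for subsequent connections to the same project.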
+
+/// Parse http response body, taking status code into account.
+async fn parse_body<T: for<'a> serde::Deserialize<'a>>(
+    response: reqwest::Response,
+) -> Result<T, ApiError> {
+    let status = response.status();
+    if status.is_success() {
+        // We shouldn't log raw body because it may contain secrets.
+        info!("request succeeded, processing the body");
+        return Ok(response.json().await?);
+    }
+
+    // Don't throw an error here because it's not as important
+    // as the fact that the request itself has failed.
+    let body = response.json().await.unwrap_or_else(|e| {
+        warn!("failed to parse error body: {e}");
+        ConsoleError {
+            error: "reason unclear (malformed error message)".into(),
+        }
+    });
+
+    let text = body.error;
+    error!("console responded with an error ({status}): {text}");
+    Err(ApiError::Console { status, text })
+}
+
+fn parse_host_port(input: &str) -> Option<(&str, u16)> {
+    let (host, port) = input.split_once(':')?;
+    Some((host, port.parse().ok()?))
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_parse_host_port() {
+        let (host, port) = parse_host_port("127.0.0.1:5432").expect("failed to parse");
+        assert_eq!(host, "127.0.0.1");
+        assert_eq!(port, 5432);
+    }
+}
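One behavioral note on `parse_host_port` above: `split_once(':')` splits at the first colon, so a bracketed IPv6 literal such as `[::1]:5432` is rejected (the leftover `:1]:5432` fails the `u16` parse and the function returns `None`) rather than mis-parsed. That is presumably fine for the IPv4/hostname addresses the console hands back today; if IPv6 ever matters, splitting at the last colon is one option (a sketch, not part of this change):

```rust
// Hypothetical IPv6-tolerant variant; the version in the diff above is
// simpler and sufficient for plain "host:port" addresses.
fn parse_host_port(input: &str) -> Option<(&str, u16)> {
    let (host, port) = input.rsplit_once(':')?;
    // Accept both "host" and "[host]" notations.
    let host = host
        .strip_prefix('[')
        .and_then(|h| h.strip_suffix(']'))
        .unwrap_or(host);
    Some((host, port.parse().ok()?))
}

fn main() {
    assert_eq!(parse_host_port("127.0.0.1:5432"), Some(("127.0.0.1", 5432)));
    assert_eq!(parse_host_port("[::1]:5432"), Some(("::1", 5432)));
}
```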
diff --git a/proxy/src/http/server.rs b/proxy/src/http/server.rs
index 05f6feb307..f35f4f9a62 100644
--- a/proxy/src/http/server.rs
+++ b/proxy/src/http/server.rs
@@ -9,8 +9,7 @@ async fn status_handler(_: Request<Body>) -> Result<Response<Body>, ApiError> {
 }
 
 fn make_router() -> RouterBuilder<Body, ApiError> {
-    let router = endpoint::make_router();
-    router.get("/v1/status", status_handler)
+    endpoint::make_router().get("/v1/status", status_handler)
 }
 
 pub async fn task_main(http_listener: TcpListener) -> anyhow::Result<()> {
diff --git a/proxy/src/http/websocket.rs b/proxy/src/http/websocket.rs
index d43602e833..bedded7567 100644
--- a/proxy/src/http/websocket.rs
+++ b/proxy/src/http/websocket.rs
@@ -1,6 +1,6 @@
 use bytes::{Buf, Bytes};
 use futures::{Sink, Stream, StreamExt};
-use hyper::server::accept::{self};
+use hyper::server::accept;
 use hyper::server::conn::AddrIncoming;
 use hyper::upgrade::Upgraded;
 use hyper::{Body, Request, Response, StatusCode};
@@ -161,7 +161,7 @@ impl AsyncBufRead for WebSocketRW {
 
 async fn serve_websocket(
     websocket: HyperWebsocket,
-    config: &ProxyConfig,
+    config: &'static ProxyConfig,
     cancel_map: &CancelMap,
     session_id: uuid::Uuid,
     hostname: Option<String>,
diff --git a/proxy/src/main.rs b/proxy/src/main.rs
index 1b61ab108f..c96ca2a171 100644
--- a/proxy/src/main.rs
+++ b/proxy/src/main.rs
@@ -5,6 +5,7 @@
 //! in somewhat transparent manner (again via communication with control plane API).
 
 mod auth;
+mod cache;
 mod cancellation;
 mod compute;
 mod config;
@@ -12,7 +13,6 @@ mod console;
 mod error;
 mod http;
 mod metrics;
-mod mgmt;
 mod parse;
 mod proxy;
 mod sasl;
@@ -21,7 +21,6 @@ mod stream;
 mod url;
 mod waiters;
 
-use ::metrics::set_build_info_metric;
 use anyhow::{bail, Context};
 use clap::{self, Arg};
 use config::ProxyConfig;
@@ -29,8 +28,7 @@ use futures::FutureExt;
 use std::{borrow::Cow, future::Future, net::SocketAddr};
 use tokio::{net::TcpListener, task::JoinError};
 use tracing::{info, info_span, Instrument};
-use utils::project_git_version;
-use utils::sentry_init::init_sentry;
+use utils::{project_git_version, sentry_init::init_sentry};
 
 project_git_version!(GIT_VERSION);
 
@@ -51,124 +49,133 @@ async fn main() -> anyhow::Result<()> {
     // initialize sentry if SENTRY_DSN is provided
     let _sentry_guard = init_sentry(Some(GIT_VERSION.into()), &[]);
 
-    let arg_matches = cli().get_matches();
-
-    let tls_config = match (
-        arg_matches.get_one::<String>("tls-key"),
-        arg_matches.get_one::<String>("tls-cert"),
-    ) {
-        (Some(key_path), Some(cert_path)) => Some(config::configure_tls(key_path, cert_path)?),
-        (None, None) => None,
-        _ => bail!("either both or neither tls-key and tls-cert must be specified"),
-    };
-
-    let proxy_address: SocketAddr = arg_matches.get_one::<String>("proxy").unwrap().parse()?;
-    let mgmt_address: SocketAddr = arg_matches.get_one::<String>("mgmt").unwrap().parse()?;
-    let http_address: SocketAddr = arg_matches.get_one::<String>("http").unwrap().parse()?;
-
-    let metric_collection_config = match
-        (
-            arg_matches.get_one::<String>("metric-collection-endpoint"),
-            arg_matches.get_one::<String>("metric-collection-interval"),
-        ) {
-
-        (Some(endpoint), Some(interval)) => {
-            Some(config::MetricCollectionConfig {
-                endpoint: endpoint.parse()?,
-                interval: humantime::parse_duration(interval)?,
-            })
-        }
-        (None, None) => None,
-        _ => bail!("either both or neither metric-collection-endpoint and metric-collection-interval must be specified"),
-    };
-
-    let auth_backend = match arg_matches
-        .get_one::<String>("auth-backend")
-        .unwrap()
-        .as_str()
-    {
-        "console" => {
-            let url = arg_matches
-                .get_one::<String>("auth-endpoint")
-                .unwrap()
-                .parse()?;
-            let endpoint = http::Endpoint::new(url, reqwest::Client::new());
-            auth::BackendType::Console(Cow::Owned(endpoint), ())
-        }
-        "postgres" => {
-            let url = arg_matches
-                .get_one::<String>("auth-endpoint")
-                .unwrap()
-                .parse()?;
-            auth::BackendType::Postgres(Cow::Owned(url), ())
-        }
-        "link" => {
-            let url = arg_matches.get_one::<String>("uri").unwrap().parse()?;
-            auth::BackendType::Link(Cow::Owned(url))
-        }
-        other => bail!("unsupported auth backend: {other}"),
-    };
-
-    let config: &ProxyConfig = Box::leak(Box::new(ProxyConfig {
-        tls_config,
-        auth_backend,
-        metric_collection_config,
-    }));
-
     info!("Version: {GIT_VERSION}");
+    ::metrics::set_build_info_metric(GIT_VERSION);
+
+    let args = cli().get_matches();
+    let config = build_config(&args)?;
+
+    info!("Authentication backend: {}", config.auth_backend);
 
     // Check that we can bind to address before further initialization
+    let http_address: SocketAddr = args.get_one::<String>("http").unwrap().parse()?;
     info!("Starting http on {http_address}");
     let http_listener = TcpListener::bind(http_address).await?.into_std()?;
 
+    let mgmt_address: SocketAddr = args.get_one::<String>("mgmt").unwrap().parse()?;
     info!("Starting mgmt on {mgmt_address}");
     let mgmt_listener = TcpListener::bind(mgmt_address).await?.into_std()?;
 
+    let proxy_address: SocketAddr = args.get_one::<String>("proxy").unwrap().parse()?;
     info!("Starting proxy on {proxy_address}");
     let proxy_listener = TcpListener::bind(proxy_address).await?;
 
     let mut tasks = vec![
         tokio::spawn(http::server::task_main(http_listener)),
         tokio::spawn(proxy::task_main(config, proxy_listener)),
-        tokio::task::spawn_blocking(move || mgmt::thread_main(mgmt_listener)),
+        tokio::task::spawn_blocking(move || console::mgmt::thread_main(mgmt_listener)),
     ];
 
-    if let Some(wss_address) = arg_matches.get_one::<String>("wss") {
+    if let Some(wss_address) = args.get_one::<String>("wss") {
         let wss_address: SocketAddr = wss_address.parse()?;
-        info!("Starting wss on {}", wss_address);
+        info!("Starting wss on {wss_address}");
         let wss_listener = TcpListener::bind(wss_address).await?;
+
         tasks.push(tokio::spawn(http::websocket::task_main(
             wss_listener,
             config,
         )));
     }
 
-    if let Some(metric_collection_config) = &config.metric_collection_config {
+    // TODO: refactor.
+    if let Some(metric_collection) = &config.metric_collection {
         let hostname = hostname::get()?
             .into_string()
             .map_err(|e| anyhow::anyhow!("failed to get hostname {e:?}"))?;
 
         tasks.push(tokio::spawn(
             metrics::collect_metrics(
-                &metric_collection_config.endpoint,
-                metric_collection_config.interval,
+                &metric_collection.endpoint,
+                metric_collection.interval,
                 hostname,
             )
             .instrument(info_span!("collect_metrics")),
         ));
     }
 
-    let tasks = tasks.into_iter().map(flatten_err);
-
-    set_build_info_metric(GIT_VERSION);
     // This will block until all tasks have completed.
     // Furthermore, the first one to fail will cancel the rest.
+    let tasks = tasks.into_iter().map(flatten_err);
     let _: Vec<()> = futures::future::try_join_all(tasks).await?;
 
     Ok(())
 }
 
+/// ProxyConfig is created at proxy startup, and lives forever.
+fn build_config(args: &clap::ArgMatches) -> anyhow::Result<&'static ProxyConfig> {
+    let tls_config = match (
+        args.get_one::<String>("tls-key"),
+        args.get_one::<String>("tls-cert"),
+    ) {
+        (Some(key_path), Some(cert_path)) => Some(config::configure_tls(key_path, cert_path)?),
+        (None, None) => None,
+        _ => bail!("either both or neither tls-key and tls-cert must be specified"),
+    };
+
+    let metric_collection = match (
+        args.get_one::<String>("metric-collection-endpoint"),
+        args.get_one::<String>("metric-collection-interval"),
+    ) {
+        (Some(endpoint), Some(interval)) => Some(config::MetricCollectionConfig {
+            endpoint: endpoint.parse()?,
+            interval: humantime::parse_duration(interval)?,
+        }),
+        (None, None) => None,
+        _ => bail!(
+            "either both or neither metric-collection-endpoint \
+             and metric-collection-interval must be specified"
+        ),
+    };
+
+    let auth_backend = match args.get_one::<String>("auth-backend").unwrap().as_str() {
+        "console" => {
+            let config::CacheOptions { size, ttl } = args
+                .get_one::<String>("wake-compute-cache")
+                .unwrap()
+                .parse()?;
+
+            info!("Using NodeInfoCache (wake_compute) with size={size} ttl={ttl:?}");
+            let caches = Box::leak(Box::new(console::caches::ApiCaches {
+                node_info: console::caches::NodeInfoCache::new("node_info_cache", size, ttl),
+            }));
+
+            let url = args.get_one::<String>("auth-endpoint").unwrap().parse()?;
+            let endpoint = http::Endpoint::new(url, reqwest::Client::new());
+
+            let api = console::provider::neon::Api::new(endpoint, caches);
+            auth::BackendType::Console(Cow::Owned(api), ())
+        }
+        "postgres" => {
+            let url = args.get_one::<String>("auth-endpoint").unwrap().parse()?;
+            let api = console::provider::mock::Api::new(url);
+            auth::BackendType::Postgres(Cow::Owned(api), ())
+        }
+        "link" => {
+            let url = args.get_one::<String>("uri").unwrap().parse()?;
+            auth::BackendType::Link(Cow::Owned(url))
+        }
+        other => bail!("unsupported auth backend: {other}"),
+    };
+
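The `CacheOptions` type destructured above (and its `DEFAULT_OPTIONS_NODE_INFO` default referenced in `cli()` below) comes from `proxy/src/config.rs`, which this diff does not show. Judging by the destructured fields and the `size=0` hint in the flag's help text, a parser for a `size=N,ttl=DURATION` option string might look roughly like this (a hypothetical sketch, not the actual implementation):

```rust
use std::time::Duration;

#[derive(Debug)]
pub struct CacheOptions {
    pub size: usize,
    pub ttl: Duration,
}

impl std::str::FromStr for CacheOptions {
    type Err = anyhow::Error;

    /// Parse a comma-separated list like "size=4000,ttl=5m".
    fn from_str(options: &str) -> Result<Self, Self::Err> {
        let (mut size, mut ttl) = (None, None);
        for option in options.split(',') {
            let (key, value) = option
                .split_once('=')
                .ok_or_else(|| anyhow::anyhow!("bad option: {option}"))?;
            match key {
                "size" => size = Some(value.parse()?),
                "ttl" => ttl = Some(humantime::parse_duration(value)?),
                other => anyhow::bail!("unknown option: {other}"),
            }
        }

        Ok(Self {
            size: size.ok_or_else(|| anyhow::anyhow!("missing `size`"))?,
            ttl: ttl.ok_or_else(|| anyhow::anyhow!("missing `ttl`"))?,
        })
    }
}

fn main() -> anyhow::Result<()> {
    let opts: CacheOptions = "size=4000,ttl=5m".parse()?;
    println!("{opts:?}");
    Ok(())
}
```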
+    let config = Box::leak(Box::new(ProxyConfig {
+        tls_config,
+        auth_backend,
+        metric_collection,
+    }));
+
+    Ok(config)
+}
+
 fn cli() -> clap::Command {
     clap::Command::new("Neon proxy/router")
         .disable_help_flag(true)
@@ -235,16 +242,27 @@ fn cli() -> clap::Command {
         .arg(
             Arg::new("metric-collection-endpoint")
                 .long("metric-collection-endpoint")
-                .help("metric collection HTTP endpoint"),
+                .help("http endpoint to receive periodic metric updates"),
         )
         .arg(
             Arg::new("metric-collection-interval")
                 .long("metric-collection-interval")
-                .help("metric collection interval"),
+                .help("how often metrics should be sent to a collection endpoint"),
+        )
+        .arg(
+            Arg::new("wake-compute-cache")
+                .long("wake-compute-cache")
+                .help("cache for `wake_compute` api method (use `size=0` to disable)")
+                .default_value(config::CacheOptions::DEFAULT_OPTIONS_NODE_INFO),
         )
 }
 
-#[test]
-fn verify_cli() {
-    cli().debug_assert();
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn verify_cli() {
+        cli().debug_assert();
+    }
 }
diff --git a/proxy/src/proxy.rs b/proxy/src/proxy.rs
index 63573d49c0..a622a35e6d 100644
--- a/proxy/src/proxy.rs
+++ b/proxy/src/proxy.rs
@@ -2,9 +2,12 @@
 mod tests;
 
 use crate::{
-    auth,
+    auth::{self, backend::AuthSuccess},
     cancellation::{self, CancelMap},
+    compute::{self, PostgresConnection},
     config::{ProxyConfig, TlsConfig},
+    console::{self, messages::MetricsAuxInfo},
+    error::io_error,
     stream::{MeasuredStream, PqStream, Stream},
 };
 use anyhow::{bail, Context};
@@ -14,7 +17,10 @@ use once_cell::sync::Lazy;
 use pq_proto::{BeMessage as Be, FeStartupPacket, StartupMessageParams};
 use std::sync::Arc;
 use tokio::io::{AsyncRead, AsyncWrite};
-use tracing::{error, info, info_span, Instrument};
+use tracing::{error, info, info_span, warn, Instrument};
+
+/// Number of times we should retry the `/proxy_wake_compute` http request.
+const NUM_RETRIES_WAKE_COMPUTE: usize = 1;
 
 const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmode=require`)";
 const ERR_PROTO_VIOLATION: &str = "protocol violation";
@@ -35,6 +41,15 @@ static NUM_CONNECTIONS_CLOSED_COUNTER: Lazy<IntCounter> = Lazy::new(|| {
         .unwrap()
 });
 
+static NUM_CONNECTION_FAILURES: Lazy<IntCounterVec> = Lazy::new(|| {
+    register_int_counter_vec!(
+        "proxy_connection_failures_total",
+        "Number of connection failures (per kind).",
+        &["kind"],
+    )
+    .unwrap()
+});
+
 static NUM_BYTES_PROXIED_COUNTER: Lazy<IntCounterVec> = Lazy::new(|| {
     register_int_counter_vec!(
         "proxy_io_bytes_per_client",
@@ -82,11 +97,12 @@ pub async fn task_main(
     }
 }
 
+// TODO(tech debt): unite this with its twin below.
 pub async fn handle_ws_client(
-    config: &ProxyConfig,
+    config: &'static ProxyConfig,
     cancel_map: &CancelMap,
     session_id: uuid::Uuid,
-    stream: impl AsyncRead + AsyncWrite + Unpin + Send,
+    stream: impl AsyncRead + AsyncWrite + Unpin,
     hostname: Option<String>,
 ) -> anyhow::Result<()> {
     // The `closed` counter will increase when this future is destroyed.
@@ -99,7 +115,7 @@ pub async fn handle_ws_client(
     let hostname = hostname.as_deref();
 
     // TLS is None here, because the connection is already encrypted.
-    let do_handshake = handshake(stream, None, cancel_map).instrument(info_span!("handshake"));
+    let do_handshake = handshake(stream, None, cancel_map);
     let (mut stream, params) = match do_handshake.await? {
         Some(x) => x,
         None => return Ok(()), // it's a cancellation request
@@ -124,10 +140,10 @@ pub async fn handle_ws_client(
 }
 
 async fn handle_client(
-    config: &ProxyConfig,
+    config: &'static ProxyConfig,
     cancel_map: &CancelMap,
     session_id: uuid::Uuid,
-    stream: impl AsyncRead + AsyncWrite + Unpin + Send,
+    stream: impl AsyncRead + AsyncWrite + Unpin,
 ) -> anyhow::Result<()> {
     // The `closed` counter will increase when this future is destroyed.
     NUM_CONNECTIONS_ACCEPTED_COUNTER.inc();
@@ -136,7 +152,7 @@ async fn handle_client(
     }
 
     let tls = config.tls_config.as_ref();
-    let do_handshake = handshake(stream, tls, cancel_map).instrument(info_span!("handshake"));
+    let do_handshake = handshake(stream, tls, cancel_map);
     let (mut stream, params) = match do_handshake.await? {
         Some(x) => x,
         None => return Ok(()), // it's a cancellation request
@@ -165,6 +181,7 @@ async fn handle_client(
 /// For better testing experience, `stream` can be any object satisfying the traits.
 /// It's easier to work with owned `stream` here as we need to upgrade it to TLS;
 /// we also take an extra care of propagating only the select handshake errors to client.
+#[tracing::instrument(skip_all)]
 async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
     stream: S,
     mut tls: Option<&TlsConfig>,
     cancel_map: &CancelMap,
@@ -226,6 +243,133 @@ async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
     }
 }
 
+/// Try to connect to the compute node once.
+#[tracing::instrument(name = "connect_once", skip_all)]
+async fn connect_to_compute_once(
+    node_info: &console::CachedNodeInfo,
+) -> Result<PostgresConnection, compute::ConnectionError> {
+    // If we couldn't connect, a cached connection info might be to blame
+    // (e.g. the compute node's address might've changed at the wrong time).
+    // Invalidate the cache entry (if any) to prevent subsequent errors.
+    let invalidate_cache = |_: &compute::ConnectionError| {
+        let is_cached = node_info.cached();
+        if is_cached {
+            warn!("invalidating stale compute node info cache entry");
+            node_info.invalidate();
+        }
+
+        let label = match is_cached {
+            true => "compute_cached",
+            false => "compute_uncached",
+        };
+        NUM_CONNECTION_FAILURES.with_label_values(&[label]).inc();
+    };
+
+    node_info
+        .config
+        .connect()
+        .inspect_err(invalidate_cache)
+        .await
+}
+
+/// Try to connect to the compute node, retrying if necessary.
+/// This function might update `node_info`, so we take it by `&mut`.
+#[tracing::instrument(skip_all)]
+async fn connect_to_compute(
+    node_info: &mut console::CachedNodeInfo,
+    params: &StartupMessageParams,
+    extra: &console::ConsoleReqExtra<'_>,
+    creds: &auth::BackendType<'_, auth::ClientCredentials<'_>>,
+) -> Result<PostgresConnection, compute::ConnectionError> {
+    let mut num_retries: usize = NUM_RETRIES_WAKE_COMPUTE;
+    loop {
+        // Apply startup params to the (possibly, cached) compute node info.
+        node_info.config.set_startup_params(params);
+        match connect_to_compute_once(node_info).await {
+            Err(e) if num_retries > 0 => {
+                info!("compute node's state has changed; requesting a wake-up");
+                match creds.wake_compute(extra).map_err(io_error).await? {
+                    // Update `node_info` and try one more time.
+                    Some(mut new) => {
+                        new.config.reuse_password(&node_info.config);
+                        *node_info = new;
+                    }
+                    // Link auth doesn't work that way, so we just exit.
+                    None => return Err(e),
+                }
+            }
+            other => return other,
+        }
+
+        num_retries -= 1;
+        info!("retrying after wake-up ({num_retries} attempts left)");
+    }
+}
+
+/// Finish client connection initialization: confirm auth success, send params, etc.
+#[tracing::instrument(skip_all)]
+async fn prepare_client_connection(
+    node: &compute::PostgresConnection,
+    reported_auth_ok: bool,
+    session: cancellation::Session<'_>,
+    stream: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
+) -> anyhow::Result<()> {
+    // Register compute's query cancellation token and produce a new, unique one.
+    // The new token (cancel_key_data) will be sent to the client.
+    let cancel_key_data = session.enable_query_cancellation(node.cancel_closure.clone());
+
+    // Report authentication success if we haven't done this already.
+    // Note that we do this only (for the most part) after we've connected
+    // to a compute (see above) which performs its own authentication.
+    if !reported_auth_ok {
+        stream.write_message_noflush(&Be::AuthenticationOk)?;
+    }
+
+    // Forward all postgres connection params to the client.
+    // Right now the implementation is very hacky and inefficient (ideally,
+    // we don't need an intermediate hashmap), but at least it should be correct.
+    for (name, value) in &node.params {
+        // TODO: Theoretically, this could result in a big pile of params...
+        stream.write_message_noflush(&Be::ParameterStatus {
+            name: name.as_bytes(),
+            value: value.as_bytes(),
+        })?;
+    }
+
+    stream
+        .write_message_noflush(&Be::BackendKeyData(cancel_key_data))?
+        .write_message(&Be::ReadyForQuery)
+        .await?;
+
+    Ok(())
+}
+
+/// Forward bytes in both directions (client <-> compute).
+#[tracing::instrument(skip_all)]
+async fn proxy_pass(
+    client: impl AsyncRead + AsyncWrite + Unpin,
+    compute: impl AsyncRead + AsyncWrite + Unpin,
+    aux: &MetricsAuxInfo,
+) -> anyhow::Result<()> {
+    let m_sent = NUM_BYTES_PROXIED_COUNTER.with_label_values(&aux.traffic_labels("tx"));
+    let mut client = MeasuredStream::new(client, |cnt| {
+        // Number of bytes we sent to the client (outbound).
+        m_sent.inc_by(cnt as u64);
+    });
+
+    let m_recv = NUM_BYTES_PROXIED_COUNTER.with_label_values(&aux.traffic_labels("rx"));
+    let mut compute = MeasuredStream::new(compute, |cnt| {
+        // Number of bytes the client sent to the compute node (inbound).
+        m_recv.inc_by(cnt as u64);
+    });
+
+    // Starting from here we only proxy the client's traffic.
+    info!("performing the proxy pass...");
+    let _ = tokio::io::copy_bidirectional(&mut client, &mut compute).await?;
+
+    Ok(())
+}
+
 /// Thin connection context.
 struct Client<'a, S> {
     /// The underlying libpq protocol stream.
@@ -255,17 +399,17 @@ impl<'a, S> Client<'a, S> {
     }
 }
 
-impl<S: AsyncRead + AsyncWrite + Unpin + Send> Client<'_, S> {
+impl<S: AsyncRead + AsyncWrite + Unpin> Client<'_, S> {
     /// Let the client authenticate and connect to the designated compute node.
     async fn connect_to_db(self, session: cancellation::Session<'_>) -> anyhow::Result<()> {
         let Self {
             mut stream,
-            creds,
+            mut creds,
             params,
             session_id,
         } = self;
 
-        let extra = auth::ConsoleReqExtra {
+        let extra = console::ConsoleReqExtra {
             session_id, // aka this connection's id
             application_name: params.get("application_name"),
         };
@@ -278,54 +422,16 @@ impl<S: AsyncRead + AsyncWrite + Unpin + Send> Client<'_, S> {
             .instrument(info_span!("auth"))
             .await?;
 
-        let node = auth_result.value;
-        let (db, cancel_closure) = node
-            .config
-            .connect(params)
+        let AuthSuccess {
+            reported_auth_ok,
+            value: mut node_info,
+        } = auth_result;
+
+        let node = connect_to_compute(&mut node_info, params, &extra, &creds)
             .or_else(|e| stream.throw_error(e))
             .await?;
 
-        let cancel_key_data = session.enable_query_cancellation(cancel_closure);
-
-        // Report authentication success if we haven't done this already.
-        // Note that we do this only (for the most part) after we've connected
-        // to a compute (see above) which performs its own authentication.
- if !auth_result.reported_auth_ok { - stream.write_message_noflush(&Be::AuthenticationOk)?; - } - - // Forward all postgres connection params to the client. - // Right now the implementation is very hacky and inefficent (ideally, - // we don't need an intermediate hashmap), but at least it should be correct. - for (name, value) in &db.params { - // TODO: Theoretically, this could result in a big pile of params... - stream.write_message_noflush(&Be::ParameterStatus { - name: name.as_bytes(), - value: value.as_bytes(), - })?; - } - - stream - .write_message_noflush(&Be::BackendKeyData(cancel_key_data))? - .write_message(&Be::ReadyForQuery) - .await?; - - let m_sent = NUM_BYTES_PROXIED_COUNTER.with_label_values(&node.aux.traffic_labels("tx")); - let mut client = MeasuredStream::new(stream.into_inner(), |cnt| { - // Number of bytes we sent to the client (outbound). - m_sent.inc_by(cnt as u64); - }); - - let m_recv = NUM_BYTES_PROXIED_COUNTER.with_label_values(&node.aux.traffic_labels("rx")); - let mut db = MeasuredStream::new(db.stream, |cnt| { - // Number of bytes the client sent to the compute node (inbound). - m_recv.inc_by(cnt as u64); - }); - - // Starting from here we only proxy the client's traffic. - info!("performing the proxy pass..."); - let _ = tokio::io::copy_bidirectional(&mut client, &mut db).await?; - - Ok(()) + prepare_client_connection(&node, reported_auth_ok, session, &mut stream).await?; + proxy_pass(stream.into_inner(), node.stream, &node_info.aux).await } } diff --git a/workspace_hack/Cargo.toml b/workspace_hack/Cargo.toml index 3a852b2207..8bfc9e2770 100644 --- a/workspace_hack/Cargo.toml +++ b/workspace_hack/Cargo.toml @@ -25,6 +25,7 @@ futures-channel = { version = "0.3", features = ["sink"] } futures-executor = { version = "0.3" } futures-task = { version = "0.3", default-features = false, features = ["std"] } futures-util = { version = "0.3", features = ["channel", "io", "sink"] } +hashbrown = { version = "0.12", features = ["raw"] } indexmap = { version = "1", default-features = false, features = ["std"] } itertools = { version = "0.10" } libc = { version = "0.2", features = ["extra_traits"] } @@ -58,6 +59,7 @@ url = { version = "2", features = ["serde"] } anyhow = { version = "1", features = ["backtrace"] } bytes = { version = "1", features = ["serde"] } either = { version = "1" } +hashbrown = { version = "0.12", features = ["raw"] } indexmap = { version = "1", default-features = false, features = ["std"] } itertools = { version = "0.10" } libc = { version = "0.2", features = ["extra_traits"] }