diff --git a/Cargo.lock b/Cargo.lock index 83afdaf66f..f5f32e8491 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2247,11 +2247,11 @@ dependencies = [ [[package]] name = "hashlink" -version = "0.8.2" +version = "0.8.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0761a1b9491c4f2e3d66aa0f62d0fba0af9a0e2852e4d48ea506632a4b56e6aa" +checksum = "e8094feaf31ff591f651a2664fb9cfd92bba7a60ce3197265e9482ebe753c8f7" dependencies = [ - "hashbrown 0.13.2", + "hashbrown 0.14.0", ] [[package]] @@ -3936,6 +3936,7 @@ dependencies = [ "pin-project-lite", "postgres-protocol", "rand 0.8.5", + "serde", "thiserror", "tokio", "tracing", diff --git a/Cargo.toml b/Cargo.toml index ebc3dfa7b1..74dbaf1853 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -80,7 +80,7 @@ futures-core = "0.3" futures-util = "0.3" git-version = "0.3" hashbrown = "0.13" -hashlink = "0.8.1" +hashlink = "0.8.4" hdrhistogram = "7.5.2" hex = "0.4" hex-literal = "0.4" diff --git a/libs/pq_proto/Cargo.toml b/libs/pq_proto/Cargo.toml index b286eb0358..6eeb3bafef 100644 --- a/libs/pq_proto/Cargo.toml +++ b/libs/pq_proto/Cargo.toml @@ -13,5 +13,6 @@ rand.workspace = true tokio.workspace = true tracing.workspace = true thiserror.workspace = true +serde.workspace = true workspace_hack.workspace = true diff --git a/libs/pq_proto/src/lib.rs b/libs/pq_proto/src/lib.rs index c52a21bcd3..522b65f5d1 100644 --- a/libs/pq_proto/src/lib.rs +++ b/libs/pq_proto/src/lib.rs @@ -7,6 +7,7 @@ pub mod framed; use byteorder::{BigEndian, ReadBytesExt}; use bytes::{Buf, BufMut, Bytes, BytesMut}; +use serde::{Deserialize, Serialize}; use std::{borrow::Cow, collections::HashMap, fmt, io, str}; // re-export for use in utils pageserver_feedback.rs @@ -123,7 +124,7 @@ impl StartupMessageParams { } } -#[derive(Debug, Hash, PartialEq, Eq, Clone, Copy)] +#[derive(Debug, Hash, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)] pub struct CancelKeyData { pub backend_pid: i32, pub cancel_key: i32, diff --git a/proxy/src/auth.rs b/proxy/src/auth.rs index 48de4e2353..c8028d1bf0 100644 --- a/proxy/src/auth.rs +++ b/proxy/src/auth.rs @@ -36,9 +36,6 @@ pub enum AuthErrorImpl { #[error(transparent)] GetAuthInfo(#[from] console::errors::GetAuthInfoError), - #[error(transparent)] - WakeCompute(#[from] console::errors::WakeComputeError), - /// SASL protocol errors (includes [SCRAM](crate::scram)). 
#[error(transparent)] Sasl(#[from] crate::sasl::Error), @@ -119,7 +116,6 @@ impl UserFacingError for AuthError { match self.0.as_ref() { Link(e) => e.to_string_client(), GetAuthInfo(e) => e.to_string_client(), - WakeCompute(e) => e.to_string_client(), Sasl(e) => e.to_string_client(), AuthFailed(_) => self.to_string(), BadAuthMethod(_) => self.to_string(), @@ -139,7 +135,6 @@ impl ReportableError for AuthError { match self.0.as_ref() { Link(e) => e.get_error_kind(), GetAuthInfo(e) => e.get_error_kind(), - WakeCompute(e) => e.get_error_kind(), Sasl(e) => e.get_error_kind(), AuthFailed(_) => crate::error::ErrorKind::User, BadAuthMethod(_) => crate::error::ErrorKind::User, diff --git a/proxy/src/auth/backend.rs b/proxy/src/auth/backend.rs index fa2782bee3..47c1dc4e92 100644 --- a/proxy/src/auth/backend.rs +++ b/proxy/src/auth/backend.rs @@ -10,9 +10,9 @@ use crate::auth::validate_password_and_exchange; use crate::cache::Cached; use crate::console::errors::GetAuthInfoError; use crate::console::provider::{CachedRoleSecret, ConsoleBackend}; -use crate::console::AuthSecret; +use crate::console::{AuthSecret, NodeInfo}; use crate::context::RequestMonitoring; -use crate::proxy::wake_compute::wake_compute; +use crate::proxy::connect_compute::ComputeConnectBackend; use crate::proxy::NeonOptions; use crate::stream::Stream; use crate::{ @@ -26,7 +26,6 @@ use crate::{ stream, url, }; use crate::{scram, EndpointCacheKey, EndpointId, RoleName}; -use futures::TryFutureExt; use std::sync::Arc; use tokio::io::{AsyncRead, AsyncWrite}; use tracing::info; @@ -56,11 +55,11 @@ impl std::ops::Deref for MaybeOwned<'_, T> { /// * However, when we substitute `T` with [`ComputeUserInfoMaybeEndpoint`], /// this helps us provide the credentials only to those auth /// backends which require them for the authentication process. -pub enum BackendType<'a, T> { +pub enum BackendType<'a, T, D> { /// Cloud API (V2). Console(MaybeOwned<'a, ConsoleBackend>, T), /// Authentication via a web browser. - Link(MaybeOwned<'a, url::ApiUrl>), + Link(MaybeOwned<'a, url::ApiUrl>, D), } pub trait TestBackend: Send + Sync + 'static { @@ -71,7 +70,7 @@ pub trait TestBackend: Send + Sync + 'static { fn get_role_secret(&self) -> Result; } -impl std::fmt::Display for BackendType<'_, ()> { +impl std::fmt::Display for BackendType<'_, (), ()> { fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { use BackendType::*; match self { @@ -86,51 +85,50 @@ impl std::fmt::Display for BackendType<'_, ()> { #[cfg(test)] ConsoleBackend::Test(_) => fmt.debug_tuple("Test").finish(), }, - Link(url) => fmt.debug_tuple("Link").field(&url.as_str()).finish(), + Link(url, _) => fmt.debug_tuple("Link").field(&url.as_str()).finish(), } } } -impl BackendType<'_, T> { +impl BackendType<'_, T, D> { /// Very similar to [`std::option::Option::as_ref`]. /// This helps us pass structured config to async tasks. - pub fn as_ref(&self) -> BackendType<'_, &T> { + pub fn as_ref(&self) -> BackendType<'_, &T, &D> { use BackendType::*; match self { Console(c, x) => Console(MaybeOwned::Borrowed(c), x), - Link(c) => Link(MaybeOwned::Borrowed(c)), + Link(c, x) => Link(MaybeOwned::Borrowed(c), x), } } } -impl<'a, T> BackendType<'a, T> { +impl<'a, T, D> BackendType<'a, T, D> { /// Very similar to [`std::option::Option::map`]. /// Maps [`BackendType`] to [`BackendType`] by applying /// a function to a contained value. 
- pub fn map(self, f: impl FnOnce(T) -> R) -> BackendType<'a, R> { + pub fn map(self, f: impl FnOnce(T) -> R) -> BackendType<'a, R, D> { use BackendType::*; match self { Console(c, x) => Console(c, f(x)), - Link(c) => Link(c), + Link(c, x) => Link(c, x), } } } - -impl<'a, T, E> BackendType<'a, Result> { +impl<'a, T, D, E> BackendType<'a, Result, D> { /// Very similar to [`std::option::Option::transpose`]. /// This is most useful for error handling. - pub fn transpose(self) -> Result, E> { + pub fn transpose(self) -> Result, E> { use BackendType::*; match self { Console(c, x) => x.map(|x| Console(c, x)), - Link(c) => Ok(Link(c)), + Link(c, x) => Ok(Link(c, x)), } } } -pub struct ComputeCredentials { +pub struct ComputeCredentials { pub info: ComputeUserInfo, - pub keys: T, + pub keys: ComputeCredentialKeys, } #[derive(Debug, Clone)] @@ -153,7 +151,6 @@ impl ComputeUserInfo { } pub enum ComputeCredentialKeys { - #[cfg(any(test, feature = "testing"))] Password(Vec), AuthKeys(AuthKeys), } @@ -188,19 +185,21 @@ async fn auth_quirks( client: &mut stream::PqStream>, allow_cleartext: bool, config: &'static AuthenticationConfig, -) -> auth::Result> { +) -> auth::Result { // If there's no project so far, that entails that client doesn't // support SNI or other means of passing the endpoint (project) name. // We now expect to see a very specific payload in the place of password. let (info, unauthenticated_password) = match user_info.try_into() { Err(info) => { - let res = hacks::password_hack_no_authentication(info, client, &mut ctx.latency_timer) - .await?; + let res = hacks::password_hack_no_authentication(ctx, info, client).await?; ctx.set_endpoint_id(res.info.endpoint.clone()); tracing::Span::current().record("ep", &tracing::field::display(&res.info.endpoint)); - - (res.info, Some(res.keys)) + let password = match res.keys { + ComputeCredentialKeys::Password(p) => p, + _ => unreachable!("password hack should return a password"), + }; + (res.info, Some(password)) } Ok(info) => (info, None), }; @@ -254,7 +253,7 @@ async fn authenticate_with_secret( unauthenticated_password: Option>, allow_cleartext: bool, config: &'static AuthenticationConfig, -) -> auth::Result> { +) -> auth::Result { if let Some(password) = unauthenticated_password { let auth_outcome = validate_password_and_exchange(&password, secret)?; let keys = match auth_outcome { @@ -276,21 +275,22 @@ async fn authenticate_with_secret( // Perform cleartext auth if we're allowed to do that. // Currently, we use it for websocket connections (latency). if allow_cleartext { - return hacks::authenticate_cleartext(info, client, &mut ctx.latency_timer, secret).await; + ctx.set_auth_method(crate::context::AuthMethod::Cleartext); + return hacks::authenticate_cleartext(ctx, info, client, secret).await; } // Finally, proceed with the main auth flow (SCRAM-based). - classic::authenticate(info, client, config, &mut ctx.latency_timer, secret).await + classic::authenticate(ctx, info, client, config, secret).await } -impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint> { +impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint, &()> { /// Get compute endpoint name from the credentials. 
pub fn get_endpoint(&self) -> Option { use BackendType::*; match self { Console(_, user_info) => user_info.endpoint_id.clone(), - Link(_) => Some("link".into()), + Link(_, _) => Some("link".into()), } } @@ -300,7 +300,7 @@ impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint> { match self { Console(_, user_info) => &user_info.user, - Link(_) => "link", + Link(_, _) => "link", } } @@ -312,7 +312,7 @@ impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint> { client: &mut stream::PqStream>, allow_cleartext: bool, config: &'static AuthenticationConfig, - ) -> auth::Result<(CachedNodeInfo, BackendType<'a, ComputeUserInfo>)> { + ) -> auth::Result> { use BackendType::*; let res = match self { @@ -323,33 +323,17 @@ impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint> { "performing authentication using the console" ); - let compute_credentials = + let credentials = auth_quirks(ctx, &*api, user_info, client, allow_cleartext, config).await?; - - let mut num_retries = 0; - let mut node = - wake_compute(&mut num_retries, ctx, &api, &compute_credentials.info).await?; - - ctx.set_project(node.aux.clone()); - - match compute_credentials.keys { - #[cfg(any(test, feature = "testing"))] - ComputeCredentialKeys::Password(password) => node.config.password(password), - ComputeCredentialKeys::AuthKeys(auth_keys) => node.config.auth_keys(auth_keys), - }; - - (node, BackendType::Console(api, compute_credentials.info)) + BackendType::Console(api, credentials) } // NOTE: this auth backend doesn't use client credentials. - Link(url) => { + Link(url, _) => { info!("performing link authentication"); - let node_info = link::authenticate(ctx, &url, client).await?; + let info = link::authenticate(ctx, &url, client).await?; - ( - CachedNodeInfo::new_uncached(node_info), - BackendType::Link(url), - ) + BackendType::Link(url, info) } }; @@ -358,7 +342,7 @@ impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint> { } } -impl BackendType<'_, ComputeUserInfo> { +impl BackendType<'_, ComputeUserInfo, &()> { pub async fn get_role_secret( &self, ctx: &mut RequestMonitoring, @@ -366,7 +350,7 @@ impl BackendType<'_, ComputeUserInfo> { use BackendType::*; match self { Console(api, user_info) => api.get_role_secret(ctx, user_info).await, - Link(_) => Ok(Cached::new_uncached(None)), + Link(_, _) => Ok(Cached::new_uncached(None)), } } @@ -377,21 +361,51 @@ impl BackendType<'_, ComputeUserInfo> { use BackendType::*; match self { Console(api, user_info) => api.get_allowed_ips_and_secret(ctx, user_info).await, - Link(_) => Ok((Cached::new_uncached(Arc::new(vec![])), None)), - } - } - - /// When applicable, wake the compute node, gaining its connection info in the process. - /// The link auth flow doesn't support this, so we return [`None`] in that case. 
- pub async fn wake_compute( - &self, - ctx: &mut RequestMonitoring, - ) -> Result, console::errors::WakeComputeError> { - use BackendType::*; - - match self { - Console(api, user_info) => api.wake_compute(ctx, user_info).map_ok(Some).await, - Link(_) => Ok(None), + Link(_, _) => Ok((Cached::new_uncached(Arc::new(vec![])), None)), + } + } +} + +#[async_trait::async_trait] +impl ComputeConnectBackend for BackendType<'_, ComputeCredentials, NodeInfo> { + async fn wake_compute( + &self, + ctx: &mut RequestMonitoring, + ) -> Result { + use BackendType::*; + + match self { + Console(api, creds) => api.wake_compute(ctx, &creds.info).await, + Link(_, info) => Ok(Cached::new_uncached(info.clone())), + } + } + + fn get_keys(&self) -> Option<&ComputeCredentialKeys> { + match self { + BackendType::Console(_, creds) => Some(&creds.keys), + BackendType::Link(_, _) => None, + } + } +} + +#[async_trait::async_trait] +impl ComputeConnectBackend for BackendType<'_, ComputeCredentials, &()> { + async fn wake_compute( + &self, + ctx: &mut RequestMonitoring, + ) -> Result { + use BackendType::*; + + match self { + Console(api, creds) => api.wake_compute(ctx, &creds.info).await, + Link(_, _) => unreachable!("link auth flow doesn't support waking the compute"), + } + } + + fn get_keys(&self) -> Option<&ComputeCredentialKeys> { + match self { + BackendType::Console(_, creds) => Some(&creds.keys), + BackendType::Link(_, _) => None, } } } diff --git a/proxy/src/auth/backend/classic.rs b/proxy/src/auth/backend/classic.rs index 745dd75107..d075331846 100644 --- a/proxy/src/auth/backend/classic.rs +++ b/proxy/src/auth/backend/classic.rs @@ -4,7 +4,7 @@ use crate::{ compute, config::AuthenticationConfig, console::AuthSecret, - metrics::LatencyTimer, + context::RequestMonitoring, sasl, stream::{PqStream, Stream}, }; @@ -12,12 +12,12 @@ use tokio::io::{AsyncRead, AsyncWrite}; use tracing::{info, warn}; pub(super) async fn authenticate( + ctx: &mut RequestMonitoring, creds: ComputeUserInfo, client: &mut PqStream>, config: &'static AuthenticationConfig, - latency_timer: &mut LatencyTimer, secret: AuthSecret, -) -> auth::Result> { +) -> auth::Result { let flow = AuthFlow::new(client); let scram_keys = match secret { #[cfg(any(test, feature = "testing"))] @@ -27,13 +27,11 @@ pub(super) async fn authenticate( } AuthSecret::Scram(secret) => { info!("auth endpoint chooses SCRAM"); - let scram = auth::Scram(&secret); + let scram = auth::Scram(&secret, &mut *ctx); let auth_outcome = tokio::time::timeout( config.scram_protocol_timeout, async { - // pause the timer while we communicate with the client - let _paused = latency_timer.pause(); flow.begin(scram).await.map_err(|error| { warn!(?error, "error sending scram acknowledgement"); diff --git a/proxy/src/auth/backend/hacks.rs b/proxy/src/auth/backend/hacks.rs index b6c1a92d3c..26cf7a01f2 100644 --- a/proxy/src/auth/backend/hacks.rs +++ b/proxy/src/auth/backend/hacks.rs @@ -4,7 +4,7 @@ use super::{ use crate::{ auth::{self, AuthFlow}, console::AuthSecret, - metrics::LatencyTimer, + context::RequestMonitoring, sasl, stream::{self, Stream}, }; @@ -16,15 +16,16 @@ use tracing::{info, warn}; /// These properties are benefical for serverless JS workers, so we /// use this mechanism for websocket connections. 
pub async fn authenticate_cleartext( + ctx: &mut RequestMonitoring, info: ComputeUserInfo, client: &mut stream::PqStream>, - latency_timer: &mut LatencyTimer, secret: AuthSecret, -) -> auth::Result> { +) -> auth::Result { warn!("cleartext auth flow override is enabled, proceeding"); + ctx.set_auth_method(crate::context::AuthMethod::Cleartext); // pause the timer while we communicate with the client - let _paused = latency_timer.pause(); + let _paused = ctx.latency_timer.pause(); let auth_outcome = AuthFlow::new(client) .begin(auth::CleartextPassword(secret)) @@ -47,14 +48,15 @@ pub async fn authenticate_cleartext( /// Similar to [`authenticate_cleartext`], but there's a specific password format, /// and passwords are not yet validated (we don't know how to validate them!) pub async fn password_hack_no_authentication( + ctx: &mut RequestMonitoring, info: ComputeUserInfoNoEndpoint, client: &mut stream::PqStream>, - latency_timer: &mut LatencyTimer, -) -> auth::Result>> { +) -> auth::Result { warn!("project not specified, resorting to the password hack auth flow"); + ctx.set_auth_method(crate::context::AuthMethod::Cleartext); // pause the timer while we communicate with the client - let _paused = latency_timer.pause(); + let _paused = ctx.latency_timer.pause(); let payload = AuthFlow::new(client) .begin(auth::PasswordHack) @@ -71,6 +73,6 @@ pub async fn password_hack_no_authentication( options: info.options, endpoint: payload.endpoint, }, - keys: payload.password, + keys: ComputeCredentialKeys::Password(payload.password), }) } diff --git a/proxy/src/auth/backend/link.rs b/proxy/src/auth/backend/link.rs index c71637dd1a..bf9ebf4c18 100644 --- a/proxy/src/auth/backend/link.rs +++ b/proxy/src/auth/backend/link.rs @@ -61,6 +61,8 @@ pub(super) async fn authenticate( link_uri: &reqwest::Url, client: &mut PqStream, ) -> auth::Result { + ctx.set_auth_method(crate::context::AuthMethod::Web); + // registering waiter can fail if we get unlucky with rng. // just try again. let (psql_session_id, waiter) = loop { diff --git a/proxy/src/auth/credentials.rs b/proxy/src/auth/credentials.rs index d32609e44c..d318b3be54 100644 --- a/proxy/src/auth/credentials.rs +++ b/proxy/src/auth/credentials.rs @@ -99,6 +99,9 @@ impl ComputeUserInfoMaybeEndpoint { // record the values if we have them ctx.set_application(params.get("application_name").map(SmolStr::from)); ctx.set_user(user.clone()); + if let Some(dbname) = params.get("database") { + ctx.set_dbname(dbname.into()); + } // Project name might be passed via PG's command-line options. let endpoint_option = params diff --git a/proxy/src/auth/flow.rs b/proxy/src/auth/flow.rs index c2783e236c..dce73138c6 100644 --- a/proxy/src/auth/flow.rs +++ b/proxy/src/auth/flow.rs @@ -4,9 +4,11 @@ use super::{backend::ComputeCredentialKeys, AuthErrorImpl, PasswordHackPayload}; use crate::{ config::TlsServerEndPoint, console::AuthSecret, + context::RequestMonitoring, sasl, scram, stream::{PqStream, Stream}, }; +use postgres_protocol::authentication::sasl::{SCRAM_SHA_256, SCRAM_SHA_256_PLUS}; use pq_proto::{BeAuthenticationSaslMessage, BeMessage, BeMessage as Be}; use std::io; use tokio::io::{AsyncRead, AsyncWrite}; @@ -23,7 +25,7 @@ pub trait AuthMethod { pub struct Begin; /// Use [SCRAM](crate::scram)-based auth in [`AuthFlow`]. 
-pub struct Scram<'a>(pub &'a scram::ServerSecret); +pub struct Scram<'a>(pub &'a scram::ServerSecret, pub &'a mut RequestMonitoring); impl AuthMethod for Scram<'_> { #[inline(always)] @@ -138,6 +140,11 @@ impl AuthFlow<'_, S, CleartextPassword> { impl AuthFlow<'_, S, Scram<'_>> { /// Perform user authentication. Raise an error in case authentication failed. pub async fn authenticate(self) -> super::Result> { + let Scram(secret, ctx) = self.state; + + // pause the timer while we communicate with the client + let _paused = ctx.latency_timer.pause(); + // Initial client message contains the chosen auth method's name. let msg = self.stream.read_password_message().await?; let sasl = sasl::FirstMessage::parse(&msg) @@ -148,9 +155,15 @@ impl AuthFlow<'_, S, Scram<'_>> { return Err(super::AuthError::bad_auth_method(sasl.method)); } + match sasl.method { + SCRAM_SHA_256 => ctx.auth_method = Some(crate::context::AuthMethod::ScramSha256), + SCRAM_SHA_256_PLUS => { + ctx.auth_method = Some(crate::context::AuthMethod::ScramSha256Plus) + } + _ => {} + } info!("client chooses {}", sasl.method); - let secret = self.state.0; let outcome = sasl::SaslStream::new(self.stream, sasl.message) .authenticate(scram::Exchange::new( secret, diff --git a/proxy/src/bin/proxy.rs b/proxy/src/bin/proxy.rs index 8fbcb56758..b3d4fc0411 100644 --- a/proxy/src/bin/proxy.rs +++ b/proxy/src/bin/proxy.rs @@ -1,6 +1,8 @@ use futures::future::Either; use proxy::auth; use proxy::auth::backend::MaybeOwned; +use proxy::cancellation::CancelMap; +use proxy::cancellation::CancellationHandler; use proxy::config::AuthenticationConfig; use proxy::config::CacheOptions; use proxy::config::HttpConfig; @@ -12,6 +14,7 @@ use proxy::rate_limiter::EndpointRateLimiter; use proxy::rate_limiter::RateBucketInfo; use proxy::rate_limiter::RateLimiterConfig; use proxy::redis::notifications; +use proxy::redis::publisher::RedisPublisherClient; use proxy::serverless::GlobalConnPoolOptions; use proxy::usage_metrics; @@ -22,6 +25,7 @@ use std::net::SocketAddr; use std::pin::pin; use std::sync::Arc; use tokio::net::TcpListener; +use tokio::sync::Mutex; use tokio::task::JoinSet; use tokio_util::sync::CancellationToken; use tracing::info; @@ -129,6 +133,9 @@ struct ProxyCliArgs { /// Can be given multiple times for different bucket sizes. #[clap(long, default_values_t = RateBucketInfo::DEFAULT_SET)] endpoint_rps_limit: Vec, + /// Redis rate limiter max number of requests per second. + #[clap(long, default_values_t = RateBucketInfo::DEFAULT_SET)] + redis_rps_limit: Vec, /// Initial limit for dynamic rate limiter. Makes sense only if `rate_limit_algorithm` is *not* `None`. #[clap(long, default_value_t = 100)] initial_limit: usize, @@ -225,6 +232,19 @@ async fn main() -> anyhow::Result<()> { let cancellation_token = CancellationToken::new(); let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new(&config.endpoint_rps_limit)); + let cancel_map = CancelMap::default(); + let redis_publisher = match &args.redis_notifications { + Some(url) => Some(Arc::new(Mutex::new(RedisPublisherClient::new( + url, + args.region.clone(), + &config.redis_rps_limit, + )?))), + None => None, + }; + let cancellation_handler = Arc::new(CancellationHandler::new( + cancel_map.clone(), + redis_publisher, + )); // client facing tasks. 
these will exit on error or on cancellation // cancellation returns Ok(()) @@ -234,6 +254,7 @@ async fn main() -> anyhow::Result<()> { proxy_listener, cancellation_token.clone(), endpoint_rate_limiter.clone(), + cancellation_handler.clone(), )); // TODO: rename the argument to something like serverless. @@ -248,6 +269,7 @@ async fn main() -> anyhow::Result<()> { serverless_listener, cancellation_token.clone(), endpoint_rate_limiter.clone(), + cancellation_handler.clone(), )); } @@ -271,7 +293,12 @@ async fn main() -> anyhow::Result<()> { let cache = api.caches.project_info.clone(); if let Some(url) = args.redis_notifications { info!("Starting redis notifications listener ({url})"); - maintenance_tasks.spawn(notifications::task_main(url.to_owned(), cache.clone())); + maintenance_tasks.spawn(notifications::task_main( + url.to_owned(), + cache.clone(), + cancel_map.clone(), + args.region.clone(), + )); } maintenance_tasks.spawn(async move { cache.clone().gc_worker().await }); } @@ -383,7 +410,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> { } AuthBackend::Link => { let url = args.uri.parse()?; - auth::BackendType::Link(MaybeOwned::Owned(url)) + auth::BackendType::Link(MaybeOwned::Owned(url), ()) } }; let http_config = HttpConfig { @@ -403,6 +430,8 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> { let mut endpoint_rps_limit = args.endpoint_rps_limit.clone(); RateBucketInfo::validate(&mut endpoint_rps_limit)?; + let mut redis_rps_limit = args.redis_rps_limit.clone(); + RateBucketInfo::validate(&mut redis_rps_limit)?; let config = Box::leak(Box::new(ProxyConfig { tls_config, @@ -414,6 +443,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> { require_client_ip: args.require_client_ip, disable_ip_check_for_http: args.disable_ip_check_for_http, endpoint_rps_limit, + redis_rps_limit, handshake_timeout: args.handshake_timeout, // TODO: add this argument region: args.region.clone(), diff --git a/proxy/src/cancellation.rs b/proxy/src/cancellation.rs index fe614628d8..93a77bc4ae 100644 --- a/proxy/src/cancellation.rs +++ b/proxy/src/cancellation.rs @@ -1,16 +1,28 @@ +use async_trait::async_trait; use dashmap::DashMap; use pq_proto::CancelKeyData; use std::{net::SocketAddr, sync::Arc}; use thiserror::Error; use tokio::net::TcpStream; +use tokio::sync::Mutex; use tokio_postgres::{CancelToken, NoTls}; use tracing::info; +use uuid::Uuid; -use crate::error::ReportableError; +use crate::{ + error::ReportableError, metrics::NUM_CANCELLATION_REQUESTS, + redis::publisher::RedisPublisherClient, +}; + +pub type CancelMap = Arc>>; /// Enables serving `CancelRequest`s. -#[derive(Default)] -pub struct CancelMap(DashMap>); +/// +/// If there is a `RedisPublisherClient` available, it will be used to publish the cancellation key to other proxy instances. +pub struct CancellationHandler { + map: CancelMap, + redis_client: Option>>, +} #[derive(Debug, Error)] pub enum CancelError { @@ -32,15 +44,43 @@ impl ReportableError for CancelError { } } -impl CancelMap { +impl CancellationHandler { + pub fn new(map: CancelMap, redis_client: Option>>) -> Self { + Self { map, redis_client } + } /// Cancel a running query for the corresponding connection. 
- pub async fn cancel_session(&self, key: CancelKeyData) -> Result<(), CancelError> { + pub async fn cancel_session( + &self, + key: CancelKeyData, + session_id: Uuid, + ) -> Result<(), CancelError> { + let from = "from_client"; // NB: we should immediately release the lock after cloning the token. - let Some(cancel_closure) = self.0.get(&key).and_then(|x| x.clone()) else { + let Some(cancel_closure) = self.map.get(&key).and_then(|x| x.clone()) else { tracing::warn!("query cancellation key not found: {key}"); + if let Some(redis_client) = &self.redis_client { + NUM_CANCELLATION_REQUESTS + .with_label_values(&[from, "not_found"]) + .inc(); + info!("publishing cancellation key to Redis"); + match redis_client.lock().await.try_publish(key, session_id).await { + Ok(()) => { + info!("cancellation key successfuly published to Redis"); + } + Err(e) => { + tracing::error!("failed to publish a message: {e}"); + return Err(CancelError::IO(std::io::Error::new( + std::io::ErrorKind::Other, + e.to_string(), + ))); + } + } + } return Ok(()); }; - + NUM_CANCELLATION_REQUESTS + .with_label_values(&[from, "found"]) + .inc(); info!("cancelling query per user's request using key {key}"); cancel_closure.try_cancel_query().await } @@ -57,7 +97,7 @@ impl CancelMap { // Random key collisions are unlikely to happen here, but they're still possible, // which is why we have to take care not to rewrite an existing key. - match self.0.entry(key) { + match self.map.entry(key) { dashmap::mapref::entry::Entry::Occupied(_) => continue, dashmap::mapref::entry::Entry::Vacant(e) => { e.insert(None); @@ -69,18 +109,46 @@ impl CancelMap { info!("registered new query cancellation key {key}"); Session { key, - cancel_map: self, + cancellation_handler: self, } } #[cfg(test)] fn contains(&self, session: &Session) -> bool { - self.0.contains_key(&session.key) + self.map.contains_key(&session.key) } #[cfg(test)] fn is_empty(&self) -> bool { - self.0.is_empty() + self.map.is_empty() + } +} + +#[async_trait] +pub trait NotificationsCancellationHandler { + async fn cancel_session_no_publish(&self, key: CancelKeyData) -> Result<(), CancelError>; +} + +#[async_trait] +impl NotificationsCancellationHandler for CancellationHandler { + async fn cancel_session_no_publish(&self, key: CancelKeyData) -> Result<(), CancelError> { + let from = "from_redis"; + let cancel_closure = self.map.get(&key).and_then(|x| x.clone()); + match cancel_closure { + Some(cancel_closure) => { + NUM_CANCELLATION_REQUESTS + .with_label_values(&[from, "found"]) + .inc(); + cancel_closure.try_cancel_query().await + } + None => { + NUM_CANCELLATION_REQUESTS + .with_label_values(&[from, "not_found"]) + .inc(); + tracing::warn!("query cancellation key not found: {key}"); + Ok(()) + } + } } } @@ -115,7 +183,7 @@ pub struct Session { /// The user-facing key identifying this session. key: CancelKeyData, /// The [`CancelMap`] this session belongs to. - cancel_map: Arc, + cancellation_handler: Arc, } impl Session { @@ -123,7 +191,9 @@ impl Session { /// This enables query cancellation in `crate::proxy::prepare_client_connection`. 
pub fn enable_query_cancellation(&self, cancel_closure: CancelClosure) -> CancelKeyData { info!("enabling query cancellation for this session"); - self.cancel_map.0.insert(self.key, Some(cancel_closure)); + self.cancellation_handler + .map + .insert(self.key, Some(cancel_closure)); self.key } @@ -131,7 +201,7 @@ impl Session { impl Drop for Session { fn drop(&mut self) { - self.cancel_map.0.remove(&self.key); + self.cancellation_handler.map.remove(&self.key); info!("dropped query cancellation key {}", &self.key); } } @@ -142,13 +212,16 @@ mod tests { #[tokio::test] async fn check_session_drop() -> anyhow::Result<()> { - let cancel_map: Arc = Default::default(); + let cancellation_handler = Arc::new(CancellationHandler { + map: CancelMap::default(), + redis_client: None, + }); - let session = cancel_map.clone().get_session(); - assert!(cancel_map.contains(&session)); + let session = cancellation_handler.clone().get_session(); + assert!(cancellation_handler.contains(&session)); drop(session); // Check that the session has been dropped. - assert!(cancel_map.is_empty()); + assert!(cancellation_handler.is_empty()); Ok(()) } diff --git a/proxy/src/compute.rs b/proxy/src/compute.rs index 83940d80ec..b61c1fb9ef 100644 --- a/proxy/src/compute.rs +++ b/proxy/src/compute.rs @@ -1,7 +1,7 @@ use crate::{ auth::parse_endpoint_param, cancellation::CancelClosure, - console::errors::WakeComputeError, + console::{errors::WakeComputeError, messages::MetricsAuxInfo}, context::RequestMonitoring, error::{ReportableError, UserFacingError}, metrics::NUM_DB_CONNECTIONS_GAUGE, @@ -93,7 +93,7 @@ impl ConnCfg { } /// Reuse password or auth keys from the other config. - pub fn reuse_password(&mut self, other: &Self) { + pub fn reuse_password(&mut self, other: Self) { if let Some(password) = other.get_password() { self.password(password); } @@ -253,6 +253,8 @@ pub struct PostgresConnection { pub params: std::collections::HashMap, /// Query cancellation token. pub cancel_closure: CancelClosure, + /// Labels for proxy's metrics. 
+ pub aux: MetricsAuxInfo, _guage: IntCounterPairGuard, } @@ -263,6 +265,7 @@ impl ConnCfg { &self, ctx: &mut RequestMonitoring, allow_self_signed_compute: bool, + aux: MetricsAuxInfo, timeout: Duration, ) -> Result { let (socket_addr, stream, host) = self.connect_raw(timeout).await?; @@ -297,6 +300,7 @@ impl ConnCfg { stream, params, cancel_closure, + aux, _guage: NUM_DB_CONNECTIONS_GAUGE .with_label_values(&[ctx.protocol]) .guard(), diff --git a/proxy/src/config.rs b/proxy/src/config.rs index 31c9228b35..9f276c3c24 100644 --- a/proxy/src/config.rs +++ b/proxy/src/config.rs @@ -13,7 +13,7 @@ use x509_parser::oid_registry; pub struct ProxyConfig { pub tls_config: Option, - pub auth_backend: auth::BackendType<'static, ()>, + pub auth_backend: auth::BackendType<'static, (), ()>, pub metric_collection: Option, pub allow_self_signed_compute: bool, pub http_config: HttpConfig, @@ -21,6 +21,7 @@ pub struct ProxyConfig { pub require_client_ip: bool, pub disable_ip_check_for_http: bool, pub endpoint_rps_limit: Vec, + pub redis_rps_limit: Vec, pub region: String, pub handshake_timeout: Duration, } diff --git a/proxy/src/console/provider.rs b/proxy/src/console/provider.rs index e5cad42753..640444d14e 100644 --- a/proxy/src/console/provider.rs +++ b/proxy/src/console/provider.rs @@ -4,7 +4,10 @@ pub mod neon; use super::messages::MetricsAuxInfo; use crate::{ - auth::{backend::ComputeUserInfo, IpPattern}, + auth::{ + backend::{ComputeCredentialKeys, ComputeUserInfo}, + IpPattern, + }, cache::{project_info::ProjectInfoCacheImpl, Cached, TimedLru}, compute, config::{CacheOptions, ProjectInfoCacheOptions}, @@ -261,6 +264,34 @@ pub struct NodeInfo { pub allow_self_signed_compute: bool, } +impl NodeInfo { + pub async fn connect( + &self, + ctx: &mut RequestMonitoring, + timeout: Duration, + ) -> Result { + self.config + .connect( + ctx, + self.allow_self_signed_compute, + self.aux.clone(), + timeout, + ) + .await + } + pub fn reuse_settings(&mut self, other: Self) { + self.allow_self_signed_compute = other.allow_self_signed_compute; + self.config.reuse_password(other.config); + } + + pub fn set_keys(&mut self, keys: &ComputeCredentialKeys) { + match keys { + ComputeCredentialKeys::Password(password) => self.config.password(password), + ComputeCredentialKeys::AuthKeys(auth_keys) => self.config.auth_keys(*auth_keys), + }; + } +} + pub type NodeInfoCache = TimedLru; pub type CachedNodeInfo = Cached<&'static NodeInfoCache>; pub type CachedRoleSecret = Cached<&'static ProjectInfoCacheImpl, Option>; diff --git a/proxy/src/console/provider/mock.rs b/proxy/src/console/provider/mock.rs index 79a04f255d..0579ef6fc4 100644 --- a/proxy/src/console/provider/mock.rs +++ b/proxy/src/console/provider/mock.rs @@ -176,9 +176,7 @@ impl super::Api for Api { _ctx: &mut RequestMonitoring, _user_info: &ComputeUserInfo, ) -> Result { - self.do_wake_compute() - .map_ok(CachedNodeInfo::new_uncached) - .await + self.do_wake_compute().map_ok(Cached::new_uncached).await } } diff --git a/proxy/src/context.rs b/proxy/src/context.rs index d2bf3f68d3..0cea53ae63 100644 --- a/proxy/src/context.rs +++ b/proxy/src/context.rs @@ -11,7 +11,7 @@ use crate::{ console::messages::MetricsAuxInfo, error::ErrorKind, metrics::{LatencyTimer, ENDPOINT_ERRORS_BY_KIND, ERROR_BY_KIND}, - BranchId, EndpointId, ProjectId, RoleName, + BranchId, DbName, EndpointId, ProjectId, RoleName, }; pub mod parquet; @@ -34,9 +34,11 @@ pub struct RequestMonitoring { project: Option, branch: Option, endpoint_id: Option, + dbname: Option, user: Option, application: Option, 
error_kind: Option, + pub(crate) auth_method: Option, success: bool, // extra @@ -45,6 +47,15 @@ pub struct RequestMonitoring { pub latency_timer: LatencyTimer, } +#[derive(Clone, Debug)] +pub enum AuthMethod { + // aka link aka passwordless + Web, + ScramSha256, + ScramSha256Plus, + Cleartext, +} + impl RequestMonitoring { pub fn new( session_id: Uuid, @@ -62,9 +73,11 @@ impl RequestMonitoring { project: None, branch: None, endpoint_id: None, + dbname: None, user: None, application: None, error_kind: None, + auth_method: None, success: false, sender: LOG_CHAN.get().and_then(|tx| tx.upgrade()), @@ -106,10 +119,18 @@ impl RequestMonitoring { self.application = app.or_else(|| self.application.clone()); } + pub fn set_dbname(&mut self, dbname: DbName) { + self.dbname = Some(dbname); + } + pub fn set_user(&mut self, user: RoleName) { self.user = Some(user); } + pub fn set_auth_method(&mut self, auth_method: AuthMethod) { + self.auth_method = Some(auth_method); + } + pub fn set_error_kind(&mut self, kind: ErrorKind) { ERROR_BY_KIND .with_label_values(&[kind.to_metric_label()]) diff --git a/proxy/src/context/parquet.rs b/proxy/src/context/parquet.rs index 0fe46915bc..ad22829183 100644 --- a/proxy/src/context/parquet.rs +++ b/proxy/src/context/parquet.rs @@ -84,8 +84,10 @@ struct RequestData { username: Option, application_name: Option, endpoint_id: Option, + database: Option, project: Option, branch: Option, + auth_method: Option<&'static str>, error: Option<&'static str>, /// Success is counted if we form a HTTP response with sql rows inside /// Or if we make it to proxy_pass @@ -104,8 +106,15 @@ impl From for RequestData { username: value.user.as_deref().map(String::from), application_name: value.application.as_deref().map(String::from), endpoint_id: value.endpoint_id.as_deref().map(String::from), + database: value.dbname.as_deref().map(String::from), project: value.project.as_deref().map(String::from), branch: value.branch.as_deref().map(String::from), + auth_method: value.auth_method.as_ref().map(|x| match x { + super::AuthMethod::Web => "web", + super::AuthMethod::ScramSha256 => "scram_sha_256", + super::AuthMethod::ScramSha256Plus => "scram_sha_256_plus", + super::AuthMethod::Cleartext => "cleartext", + }), protocol: value.protocol, region: value.region, error: value.error_kind.as_ref().map(|e| e.to_metric_label()), @@ -431,8 +440,10 @@ mod tests { application_name: Some("test".to_owned()), username: Some(hex::encode(rng.gen::<[u8; 4]>())), endpoint_id: Some(hex::encode(rng.gen::<[u8; 16]>())), + database: Some(hex::encode(rng.gen::<[u8; 16]>())), project: Some(hex::encode(rng.gen::<[u8; 16]>())), branch: Some(hex::encode(rng.gen::<[u8; 16]>())), + auth_method: None, protocol: ["tcp", "ws", "http"][rng.gen_range(0..3)], region: "us-east-1", error: None, @@ -505,15 +516,15 @@ mod tests { assert_eq!( file_stats, [ - (1087635, 3, 6000), - (1087288, 3, 6000), - (1087444, 3, 6000), - (1087572, 3, 6000), - (1087468, 3, 6000), - (1087500, 3, 6000), - (1087533, 3, 6000), - (1087566, 3, 6000), - (362671, 1, 2000) + (1313727, 3, 6000), + (1313720, 3, 6000), + (1313780, 3, 6000), + (1313737, 3, 6000), + (1313867, 3, 6000), + (1313709, 3, 6000), + (1313501, 3, 6000), + (1313737, 3, 6000), + (438118, 1, 2000) ], ); @@ -543,11 +554,11 @@ mod tests { assert_eq!( file_stats, [ - (1028637, 5, 10000), - (1031969, 5, 10000), - (1019900, 5, 10000), - (1020365, 5, 10000), - (1025010, 5, 10000) + (1219459, 5, 10000), + (1225609, 5, 10000), + (1227403, 5, 10000), + (1226765, 5, 10000), + (1218043, 5, 10000) ], 
); @@ -579,11 +590,11 @@ mod tests { assert_eq!( file_stats, [ - (1210770, 6, 12000), - (1211036, 6, 12000), - (1210990, 6, 12000), - (1210861, 6, 12000), - (202073, 1, 2000) + (1205106, 5, 10000), + (1204837, 5, 10000), + (1205130, 5, 10000), + (1205118, 5, 10000), + (1205373, 5, 10000) ], ); @@ -608,15 +619,15 @@ mod tests { assert_eq!( file_stats, [ - (1087635, 3, 6000), - (1087288, 3, 6000), - (1087444, 3, 6000), - (1087572, 3, 6000), - (1087468, 3, 6000), - (1087500, 3, 6000), - (1087533, 3, 6000), - (1087566, 3, 6000), - (362671, 1, 2000) + (1313727, 3, 6000), + (1313720, 3, 6000), + (1313780, 3, 6000), + (1313737, 3, 6000), + (1313867, 3, 6000), + (1313709, 3, 6000), + (1313501, 3, 6000), + (1313737, 3, 6000), + (438118, 1, 2000) ], ); @@ -653,7 +664,7 @@ mod tests { // files are smaller than the size threshold, but they took too long to fill so were flushed early assert_eq!( file_stats, - [(545264, 2, 3001), (545025, 2, 3000), (544857, 2, 2999)], + [(658383, 2, 3001), (658097, 2, 3000), (657893, 2, 2999)], ); tmpdir.close().unwrap(); diff --git a/proxy/src/error.rs b/proxy/src/error.rs index eafe92bf48..69fe1ebc12 100644 --- a/proxy/src/error.rs +++ b/proxy/src/error.rs @@ -29,7 +29,7 @@ pub trait UserFacingError: ReportableError { } } -#[derive(Copy, Clone, Debug)] +#[derive(Copy, Clone, Debug, Eq, PartialEq)] pub enum ErrorKind { /// Wrong password, unknown endpoint, protocol violation, etc... User, @@ -90,3 +90,13 @@ impl ReportableError for tokio::time::error::Elapsed { ErrorKind::RateLimit } } + +impl ReportableError for tokio_postgres::error::Error { + fn get_error_kind(&self) -> ErrorKind { + if self.as_db_error().is_some() { + ErrorKind::Postgres + } else { + ErrorKind::Compute + } + } +} diff --git a/proxy/src/metrics.rs b/proxy/src/metrics.rs index ccf89f9b05..66031f5eb2 100644 --- a/proxy/src/metrics.rs +++ b/proxy/src/metrics.rs @@ -152,6 +152,15 @@ pub static NUM_OPEN_CLIENTS_IN_HTTP_POOL: Lazy = Lazy::new(|| { .unwrap() }); +pub static NUM_CANCELLATION_REQUESTS: Lazy = Lazy::new(|| { + register_int_counter_vec!( + "proxy_cancellation_requests_total", + "Number of cancellation requests (per found/not_found).", + &["source", "kind"], + ) + .unwrap() +}); + #[derive(Clone)] pub struct LatencyTimer { // time since the stopwatch was started @@ -200,8 +209,9 @@ impl LatencyTimer { pub fn success(&mut self) { // stop the stopwatch and record the time that we have accumulated - let start = self.start.take().expect("latency timer should be started"); - self.accumulated += start.elapsed(); + if let Some(start) = self.start.take() { + self.accumulated += start.elapsed(); + } // success self.outcome = "success"; diff --git a/proxy/src/proxy.rs b/proxy/src/proxy.rs index 50e22ec72a..8a9445303a 100644 --- a/proxy/src/proxy.rs +++ b/proxy/src/proxy.rs @@ -2,6 +2,7 @@ mod tests; pub mod connect_compute; +mod copy_bidirectional; pub mod handshake; pub mod passthrough; pub mod retry; @@ -9,7 +10,7 @@ pub mod wake_compute; use crate::{ auth, - cancellation::{self, CancelMap}, + cancellation::{self, CancellationHandler}, compute, config::{ProxyConfig, TlsConfig}, context::RequestMonitoring, @@ -61,6 +62,7 @@ pub async fn task_main( listener: tokio::net::TcpListener, cancellation_token: CancellationToken, endpoint_rate_limiter: Arc, + cancellation_handler: Arc, ) -> anyhow::Result<()> { scopeguard::defer! 
{ info!("proxy has shut down"); @@ -71,7 +73,6 @@ pub async fn task_main( socket2::SockRef::from(&listener).set_keepalive(true)?; let connections = tokio_util::task::task_tracker::TaskTracker::new(); - let cancel_map = Arc::new(CancelMap::default()); while let Some(accept_result) = run_until_cancelled(listener.accept(), &cancellation_token).await @@ -79,7 +80,7 @@ pub async fn task_main( let (socket, peer_addr) = accept_result?; let session_id = uuid::Uuid::new_v4(); - let cancel_map = Arc::clone(&cancel_map); + let cancellation_handler = Arc::clone(&cancellation_handler); let endpoint_rate_limiter = endpoint_rate_limiter.clone(); let session_span = info_span!( @@ -112,7 +113,7 @@ pub async fn task_main( let res = handle_client( config, &mut ctx, - cancel_map, + cancellation_handler, socket, ClientMode::Tcp, endpoint_rate_limiter, @@ -162,14 +163,14 @@ pub enum ClientMode { /// Abstracts the logic of handling TCP vs WS clients impl ClientMode { - fn allow_cleartext(&self) -> bool { + pub fn allow_cleartext(&self) -> bool { match self { ClientMode::Tcp => false, ClientMode::Websockets { .. } => true, } } - fn allow_self_signed_compute(&self, config: &ProxyConfig) -> bool { + pub fn allow_self_signed_compute(&self, config: &ProxyConfig) -> bool { match self { ClientMode::Tcp => config.allow_self_signed_compute, ClientMode::Websockets { .. } => false, @@ -226,7 +227,7 @@ impl ReportableError for ClientRequestError { pub async fn handle_client( config: &'static ProxyConfig, ctx: &mut RequestMonitoring, - cancel_map: Arc, + cancellation_handler: Arc, stream: S, mode: ClientMode, endpoint_rate_limiter: Arc, @@ -252,8 +253,8 @@ pub async fn handle_client( match tokio::time::timeout(config.handshake_timeout, do_handshake).await?? { HandshakeData::Startup(stream, params) => (stream, params), HandshakeData::Cancel(cancel_key_data) => { - return Ok(cancel_map - .cancel_session(cancel_key_data) + return Ok(cancellation_handler + .cancel_session(cancel_key_data, ctx.session_id) .await .map(|()| None)?) 
} @@ -286,7 +287,7 @@ pub async fn handle_client( } let user = user_info.get_user().to_owned(); - let (mut node_info, user_info) = match user_info + let user_info = match user_info .authenticate( ctx, &mut stream, @@ -305,19 +306,16 @@ pub async fn handle_client( } }; - node_info.allow_self_signed_compute = mode.allow_self_signed_compute(config); - - let aux = node_info.aux.clone(); let mut node = connect_to_compute( ctx, &TcpMechanism { params: ¶ms }, - node_info, &user_info, + mode.allow_self_signed_compute(config), ) .or_else(|e| stream.throw_error(e)) .await?; - let session = cancel_map.get_session(); + let session = cancellation_handler.get_session(); prepare_client_connection(&node, &session, &mut stream).await?; // Before proxy passing, forward to compute whatever data is left in the @@ -329,10 +327,11 @@ pub async fn handle_client( Ok(Some(ProxyPassthrough { client: stream, + aux: node.aux.clone(), compute: node, - aux, req: _request_gauge, conn: _client_gauge, + cancel: session, })) } diff --git a/proxy/src/proxy/connect_compute.rs b/proxy/src/proxy/connect_compute.rs index b9346aa743..c76e2ff6d9 100644 --- a/proxy/src/proxy/connect_compute.rs +++ b/proxy/src/proxy/connect_compute.rs @@ -1,8 +1,9 @@ use crate::{ - auth, + auth::backend::ComputeCredentialKeys, compute::{self, PostgresConnection}, - console::{self, errors::WakeComputeError}, + console::{self, errors::WakeComputeError, CachedNodeInfo, NodeInfo}, context::RequestMonitoring, + error::ReportableError, metrics::NUM_CONNECTION_FAILURES, proxy::{ retry::{retry_after, ShouldRetry}, @@ -20,7 +21,7 @@ const CONNECT_TIMEOUT: time::Duration = time::Duration::from_secs(2); /// (e.g. the compute node's address might've changed at the wrong time). /// Invalidate the cache entry (if any) to prevent subsequent errors. #[tracing::instrument(name = "invalidate_cache", skip_all)] -pub fn invalidate_cache(node_info: console::CachedNodeInfo) -> compute::ConnCfg { +pub fn invalidate_cache(node_info: console::CachedNodeInfo) -> NodeInfo { let is_cached = node_info.cached(); if is_cached { warn!("invalidating stalled compute node info cache entry"); @@ -31,13 +32,13 @@ pub fn invalidate_cache(node_info: console::CachedNodeInfo) -> compute::ConnCfg }; NUM_CONNECTION_FAILURES.with_label_values(&[label]).inc(); - node_info.invalidate().config + node_info.invalidate() } #[async_trait] pub trait ConnectMechanism { type Connection; - type ConnectError; + type ConnectError: ReportableError; type Error: From; async fn connect_once( &self, @@ -49,6 +50,16 @@ pub trait ConnectMechanism { fn update_connect_config(&self, conf: &mut compute::ConnCfg); } +#[async_trait] +pub trait ComputeConnectBackend { + async fn wake_compute( + &self, + ctx: &mut RequestMonitoring, + ) -> Result; + + fn get_keys(&self) -> Option<&ComputeCredentialKeys>; +} + pub struct TcpMechanism<'a> { /// KV-dictionary with PostgreSQL connection params. pub params: &'a StartupMessageParams, @@ -67,11 +78,7 @@ impl ConnectMechanism for TcpMechanism<'_> { node_info: &console::CachedNodeInfo, timeout: time::Duration, ) -> Result { - let allow_self_signed_compute = node_info.allow_self_signed_compute; - node_info - .config - .connect(ctx, allow_self_signed_compute, timeout) - .await + node_info.connect(ctx, timeout).await } fn update_connect_config(&self, config: &mut compute::ConnCfg) { @@ -82,16 +89,23 @@ impl ConnectMechanism for TcpMechanism<'_> { /// Try to connect to the compute node, retrying if necessary. /// This function might update `node_info`, so we take it by `&mut`. 
#[tracing::instrument(skip_all)] -pub async fn connect_to_compute( +pub async fn connect_to_compute( ctx: &mut RequestMonitoring, mechanism: &M, - mut node_info: console::CachedNodeInfo, - user_info: &auth::BackendType<'_, auth::backend::ComputeUserInfo>, + user_info: &B, + allow_self_signed_compute: bool, ) -> Result where M::ConnectError: ShouldRetry + std::fmt::Debug, M::Error: From, { + let mut num_retries = 0; + let mut node_info = wake_compute(&mut num_retries, ctx, user_info).await?; + if let Some(keys) = user_info.get_keys() { + node_info.set_keys(keys); + } + node_info.allow_self_signed_compute = allow_self_signed_compute; + // let mut node_info = credentials.get_node_info(ctx, user_info).await?; mechanism.update_connect_config(&mut node_info.config); // try once @@ -108,28 +122,30 @@ where error!(error = ?err, "could not connect to compute node"); - let mut num_retries = 1; - - match user_info { - auth::BackendType::Console(api, info) => { - // if we failed to connect, it's likely that the compute node was suspended, wake a new compute node - info!("compute node's state has likely changed; requesting a wake-up"); - - ctx.latency_timer.cache_miss(); - let config = invalidate_cache(node_info); - node_info = wake_compute(&mut num_retries, ctx, api, info).await?; - - node_info.config.reuse_password(&config); - mechanism.update_connect_config(&mut node_info.config); + let node_info = if !node_info.cached() { + // If we just received this from cplane and didn't get it from cache, we shouldn't retry. + // No need to retrieve a new node_info, just return the old one. + if !err.should_retry(num_retries) { + return Err(err.into()); } - // nothing to do? - auth::BackendType::Link(_) => {} + node_info + } else { + // if we failed to connect, it's likely that the compute node was suspended, wake a new compute node + info!("compute node's state has likely changed; requesting a wake-up"); + ctx.latency_timer.cache_miss(); + let old_node_info = invalidate_cache(node_info); + let mut node_info = wake_compute(&mut num_retries, ctx, user_info).await?; + node_info.reuse_settings(old_node_info); + + mechanism.update_connect_config(&mut node_info.config); + node_info }; // now that we have a new node, try connect to it repeatedly. // this can error for a few reasons, for instance: // * DNS connection settings haven't quite propagated yet info!("wake_compute success.
attempting to connect"); + num_retries = 1; loop { match mechanism .connect_once(ctx, &node_info, CONNECT_TIMEOUT) diff --git a/proxy/src/proxy/copy_bidirectional.rs b/proxy/src/proxy/copy_bidirectional.rs new file mode 100644 index 0000000000..2ecc1151da --- /dev/null +++ b/proxy/src/proxy/copy_bidirectional.rs @@ -0,0 +1,256 @@ +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; + +use std::future::poll_fn; +use std::io; +use std::pin::Pin; +use std::task::{ready, Context, Poll}; + +#[derive(Debug)] +enum TransferState { + Running(CopyBuffer), + ShuttingDown(u64), + Done(u64), +} + +fn transfer_one_direction( + cx: &mut Context<'_>, + state: &mut TransferState, + r: &mut A, + w: &mut B, +) -> Poll> +where + A: AsyncRead + AsyncWrite + Unpin + ?Sized, + B: AsyncRead + AsyncWrite + Unpin + ?Sized, +{ + let mut r = Pin::new(r); + let mut w = Pin::new(w); + loop { + match state { + TransferState::Running(buf) => { + let count = ready!(buf.poll_copy(cx, r.as_mut(), w.as_mut()))?; + *state = TransferState::ShuttingDown(count); + } + TransferState::ShuttingDown(count) => { + ready!(w.as_mut().poll_shutdown(cx))?; + *state = TransferState::Done(*count); + } + TransferState::Done(count) => return Poll::Ready(Ok(*count)), + } + } +} + +pub(super) async fn copy_bidirectional( + a: &mut A, + b: &mut B, +) -> Result<(u64, u64), std::io::Error> +where + A: AsyncRead + AsyncWrite + Unpin + ?Sized, + B: AsyncRead + AsyncWrite + Unpin + ?Sized, +{ + let mut a_to_b = TransferState::Running(CopyBuffer::new()); + let mut b_to_a = TransferState::Running(CopyBuffer::new()); + + poll_fn(|cx| { + let mut a_to_b_result = transfer_one_direction(cx, &mut a_to_b, a, b)?; + let mut b_to_a_result = transfer_one_direction(cx, &mut b_to_a, b, a)?; + + // Early termination checks + if let TransferState::Done(_) = a_to_b { + if let TransferState::Running(buf) = &b_to_a { + // Initiate shutdown + b_to_a = TransferState::ShuttingDown(buf.amt); + b_to_a_result = transfer_one_direction(cx, &mut b_to_a, b, a)?; + } + } + if let TransferState::Done(_) = b_to_a { + if let TransferState::Running(buf) = &a_to_b { + // Initiate shutdown + a_to_b = TransferState::ShuttingDown(buf.amt); + a_to_b_result = transfer_one_direction(cx, &mut a_to_b, a, b)?; + } + } + + // It is not a problem if ready! returns early ... 
(comment remains the same) + let a_to_b = ready!(a_to_b_result); + let b_to_a = ready!(b_to_a_result); + + Poll::Ready(Ok((a_to_b, b_to_a))) + }) + .await +} + +#[derive(Debug)] +pub(super) struct CopyBuffer { + read_done: bool, + need_flush: bool, + pos: usize, + cap: usize, + amt: u64, + buf: Box<[u8]>, +} +const DEFAULT_BUF_SIZE: usize = 8 * 1024; + +impl CopyBuffer { + pub(super) fn new() -> Self { + Self { + read_done: false, + need_flush: false, + pos: 0, + cap: 0, + amt: 0, + buf: vec![0; DEFAULT_BUF_SIZE].into_boxed_slice(), + } + } + + fn poll_fill_buf( + &mut self, + cx: &mut Context<'_>, + reader: Pin<&mut R>, + ) -> Poll> + where + R: AsyncRead + ?Sized, + { + let me = &mut *self; + let mut buf = ReadBuf::new(&mut me.buf); + buf.set_filled(me.cap); + + let res = reader.poll_read(cx, &mut buf); + if let Poll::Ready(Ok(())) = res { + let filled_len = buf.filled().len(); + me.read_done = me.cap == filled_len; + me.cap = filled_len; + } + res + } + + fn poll_write_buf( + &mut self, + cx: &mut Context<'_>, + mut reader: Pin<&mut R>, + mut writer: Pin<&mut W>, + ) -> Poll> + where + R: AsyncRead + ?Sized, + W: AsyncWrite + ?Sized, + { + let me = &mut *self; + match writer.as_mut().poll_write(cx, &me.buf[me.pos..me.cap]) { + Poll::Pending => { + // Top up the buffer towards full if we can read a bit more + // data - this should improve the chances of a large write + if !me.read_done && me.cap < me.buf.len() { + ready!(me.poll_fill_buf(cx, reader.as_mut()))?; + } + Poll::Pending + } + res => res, + } + } + + pub(super) fn poll_copy( + &mut self, + cx: &mut Context<'_>, + mut reader: Pin<&mut R>, + mut writer: Pin<&mut W>, + ) -> Poll> + where + R: AsyncRead + ?Sized, + W: AsyncWrite + ?Sized, + { + loop { + // If our buffer is empty, then we need to read some data to + // continue. + if self.pos == self.cap && !self.read_done { + self.pos = 0; + self.cap = 0; + + match self.poll_fill_buf(cx, reader.as_mut()) { + Poll::Ready(Ok(())) => (), + Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), + Poll::Pending => { + // Try flushing when the reader has no progress to avoid deadlock + // when the reader depends on buffered writer. + if self.need_flush { + ready!(writer.as_mut().poll_flush(cx))?; + self.need_flush = false; + } + + return Poll::Pending; + } + } + } + + // If our buffer has some data, let's write it out! + while self.pos < self.cap { + let i = ready!(self.poll_write_buf(cx, reader.as_mut(), writer.as_mut()))?; + if i == 0 { + return Poll::Ready(Err(io::Error::new( + io::ErrorKind::WriteZero, + "write zero byte into writer", + ))); + } else { + self.pos += i; + self.amt += i as u64; + self.need_flush = true; + } + } + + // If pos larger than cap, this loop will never stop. + // In particular, user's wrong poll_write implementation returning + // incorrect written length may lead to thread blocking. + debug_assert!( + self.pos <= self.cap, + "writer returned length larger than input slice" + ); + + // If we've written all the data and we've seen EOF, flush out the + // data and finish the transfer. 
+ if self.pos == self.cap && self.read_done { + ready!(writer.as_mut().poll_flush(cx))?; + return Poll::Ready(Ok(self.amt)); + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use tokio::io::AsyncWriteExt; + + #[tokio::test] + async fn test_early_termination_a_to_d() { + let (mut a_mock, mut b_mock) = tokio::io::duplex(8); // Create a mock duplex stream + let (mut c_mock, mut d_mock) = tokio::io::duplex(32); // Create a mock duplex stream + + // Simulate 'a' finishing while there's still data for 'b' + a_mock.write_all(b"hello").await.unwrap(); + a_mock.shutdown().await.unwrap(); + d_mock.write_all(b"Neon Serverless Postgres").await.unwrap(); + + let result = copy_bidirectional(&mut b_mock, &mut c_mock).await.unwrap(); + + // Assert correct transferred amounts + let (a_to_d_count, d_to_a_count) = result; + assert_eq!(a_to_d_count, 5); // 'hello' was transferred + assert!(d_to_a_count <= 8); // response only partially transferred or not at all + } + + #[tokio::test] + async fn test_early_termination_d_to_a() { + let (mut a_mock, mut b_mock) = tokio::io::duplex(32); // Create a mock duplex stream + let (mut c_mock, mut d_mock) = tokio::io::duplex(8); // Create a mock duplex stream + + // Simulate 'a' finishing while there's still data for 'b' + d_mock.write_all(b"hello").await.unwrap(); + d_mock.shutdown().await.unwrap(); + a_mock.write_all(b"Neon Serverless Postgres").await.unwrap(); + + let result = copy_bidirectional(&mut b_mock, &mut c_mock).await.unwrap(); + + // Assert correct transferred amounts + let (a_to_d_count, d_to_a_count) = result; + assert_eq!(d_to_a_count, 5); // 'hello' was transferred + assert!(a_to_d_count <= 8); // response only partially transferred or not at all + } +} diff --git a/proxy/src/proxy/passthrough.rs b/proxy/src/proxy/passthrough.rs index b7018c6fb5..73c170fc0b 100644 --- a/proxy/src/proxy/passthrough.rs +++ b/proxy/src/proxy/passthrough.rs @@ -1,4 +1,5 @@ use crate::{ + cancellation, compute::PostgresConnection, console::messages::MetricsAuxInfo, metrics::NUM_BYTES_PROXIED_COUNTER, @@ -45,7 +46,7 @@ pub async fn proxy_pass( // Starting from here we only proxy the client's traffic. 
info!("performing the proxy pass..."); - let _ = tokio::io::copy_bidirectional(&mut client, &mut compute).await?; + let _ = crate::proxy::copy_bidirectional::copy_bidirectional(&mut client, &mut compute).await?; Ok(()) } @@ -57,6 +58,7 @@ pub struct ProxyPassthrough { pub req: IntCounterPairGuard, pub conn: IntCounterPairGuard, + pub cancel: cancellation::Session, } impl ProxyPassthrough { diff --git a/proxy/src/proxy/tests.rs b/proxy/src/proxy/tests.rs index 3e961afb41..1a01f32339 100644 --- a/proxy/src/proxy/tests.rs +++ b/proxy/src/proxy/tests.rs @@ -2,13 +2,19 @@ mod mitm; +use std::time::Duration; + use super::connect_compute::ConnectMechanism; use super::retry::ShouldRetry; use super::*; -use crate::auth::backend::{ComputeUserInfo, MaybeOwned, TestBackend}; +use crate::auth::backend::{ + ComputeCredentialKeys, ComputeCredentials, ComputeUserInfo, MaybeOwned, TestBackend, +}; use crate::config::CertResolver; +use crate::console::caches::NodeInfoCache; use crate::console::provider::{CachedAllowedIps, CachedRoleSecret, ConsoleBackend}; use crate::console::{self, CachedNodeInfo, NodeInfo}; +use crate::error::ErrorKind; use crate::proxy::retry::{retry_after, NUM_RETRIES_CONNECT}; use crate::{auth, http, sasl, scram}; use async_trait::async_trait; @@ -144,7 +150,7 @@ impl TestAuth for Scram { stream: &mut PqStream>, ) -> anyhow::Result<()> { let outcome = auth::AuthFlow::new(stream) - .begin(auth::Scram(&self.0)) + .begin(auth::Scram(&self.0, &mut RequestMonitoring::test())) .await? .authenticate() .await?; @@ -375,6 +381,7 @@ enum ConnectAction { struct TestConnectMechanism { counter: Arc>, sequence: Vec, + cache: &'static NodeInfoCache, } impl TestConnectMechanism { @@ -393,6 +400,12 @@ impl TestConnectMechanism { Self { counter: Arc::new(std::sync::Mutex::new(0)), sequence, + cache: Box::leak(Box::new(NodeInfoCache::new( + "test", + 1, + Duration::from_secs(100), + false, + ))), } } } @@ -403,6 +416,13 @@ struct TestConnection; #[derive(Debug)] struct TestConnectError { retryable: bool, + kind: crate::error::ErrorKind, +} + +impl ReportableError for TestConnectError { + fn get_error_kind(&self) -> crate::error::ErrorKind { + self.kind + } } impl std::fmt::Display for TestConnectError { @@ -436,8 +456,14 @@ impl ConnectMechanism for TestConnectMechanism { *counter += 1; match action { ConnectAction::Connect => Ok(TestConnection), - ConnectAction::Retry => Err(TestConnectError { retryable: true }), - ConnectAction::Fail => Err(TestConnectError { retryable: false }), + ConnectAction::Retry => Err(TestConnectError { + retryable: true, + kind: ErrorKind::Compute, + }), + ConnectAction::Fail => Err(TestConnectError { + retryable: false, + kind: ErrorKind::Compute, + }), x => panic!("expecting action {:?}, connect is called instead", x), } } @@ -451,7 +477,7 @@ impl TestBackend for TestConnectMechanism { let action = self.sequence[*counter]; *counter += 1; match action { - ConnectAction::Wake => Ok(helper_create_cached_node_info()), + ConnectAction::Wake => Ok(helper_create_cached_node_info(self.cache)), ConnectAction::WakeFail => { let err = console::errors::ApiError::Console { status: http::StatusCode::FORBIDDEN, @@ -483,37 +509,41 @@ impl TestBackend for TestConnectMechanism { } } -fn helper_create_cached_node_info() -> CachedNodeInfo { +fn helper_create_cached_node_info(cache: &'static NodeInfoCache) -> CachedNodeInfo { let node = NodeInfo { config: compute::ConnCfg::new(), aux: Default::default(), allow_self_signed_compute: false, }; - CachedNodeInfo::new_uncached(node) + let (_, node) = 
cache.insert("key".into(), node); + node } fn helper_create_connect_info( mechanism: &TestConnectMechanism, -) -> (CachedNodeInfo, auth::BackendType<'static, ComputeUserInfo>) { - let cache = helper_create_cached_node_info(); +) -> auth::BackendType<'static, ComputeCredentials, &()> { let user_info = auth::BackendType::Console( MaybeOwned::Owned(ConsoleBackend::Test(Box::new(mechanism.clone()))), - ComputeUserInfo { - endpoint: "endpoint".into(), - user: "user".into(), - options: NeonOptions::parse_options_raw(""), + ComputeCredentials { + info: ComputeUserInfo { + endpoint: "endpoint".into(), + user: "user".into(), + options: NeonOptions::parse_options_raw(""), + }, + keys: ComputeCredentialKeys::Password("password".into()), }, ); - (cache, user_info) + user_info } #[tokio::test] async fn connect_to_compute_success() { + let _ = env_logger::try_init(); use ConnectAction::*; let mut ctx = RequestMonitoring::test(); - let mechanism = TestConnectMechanism::new(vec![Connect]); - let (cache, user_info) = helper_create_connect_info(&mechanism); - connect_to_compute(&mut ctx, &mechanism, cache, &user_info) + let mechanism = TestConnectMechanism::new(vec![Wake, Connect]); + let user_info = helper_create_connect_info(&mechanism); + connect_to_compute(&mut ctx, &mechanism, &user_info, false) .await .unwrap(); mechanism.verify(); @@ -521,11 +551,12 @@ async fn connect_to_compute_success() { #[tokio::test] async fn connect_to_compute_retry() { + let _ = env_logger::try_init(); use ConnectAction::*; let mut ctx = RequestMonitoring::test(); - let mechanism = TestConnectMechanism::new(vec![Retry, Wake, Retry, Connect]); - let (cache, user_info) = helper_create_connect_info(&mechanism); - connect_to_compute(&mut ctx, &mechanism, cache, &user_info) + let mechanism = TestConnectMechanism::new(vec![Wake, Retry, Wake, Connect]); + let user_info = helper_create_connect_info(&mechanism); + connect_to_compute(&mut ctx, &mechanism, &user_info, false) .await .unwrap(); mechanism.verify(); @@ -534,11 +565,12 @@ async fn connect_to_compute_retry() { /// Test that we don't retry if the error is not retryable. #[tokio::test] async fn connect_to_compute_non_retry_1() { + let _ = env_logger::try_init(); use ConnectAction::*; let mut ctx = RequestMonitoring::test(); - let mechanism = TestConnectMechanism::new(vec![Retry, Wake, Retry, Fail]); - let (cache, user_info) = helper_create_connect_info(&mechanism); - connect_to_compute(&mut ctx, &mechanism, cache, &user_info) + let mechanism = TestConnectMechanism::new(vec![Wake, Retry, Wake, Fail]); + let user_info = helper_create_connect_info(&mechanism); + connect_to_compute(&mut ctx, &mechanism, &user_info, false) .await .unwrap_err(); mechanism.verify(); @@ -547,11 +579,12 @@ async fn connect_to_compute_non_retry_1() { /// Even for non-retryable errors, we should retry at least once. 
#[tokio::test] async fn connect_to_compute_non_retry_2() { + let _ = env_logger::try_init(); use ConnectAction::*; let mut ctx = RequestMonitoring::test(); - let mechanism = TestConnectMechanism::new(vec![Fail, Wake, Retry, Connect]); - let (cache, user_info) = helper_create_connect_info(&mechanism); - connect_to_compute(&mut ctx, &mechanism, cache, &user_info) + let mechanism = TestConnectMechanism::new(vec![Wake, Fail, Wake, Connect]); + let user_info = helper_create_connect_info(&mechanism); + connect_to_compute(&mut ctx, &mechanism, &user_info, false) .await .unwrap(); mechanism.verify(); @@ -560,15 +593,16 @@ async fn connect_to_compute_non_retry_2() { /// Retry for at most `NUM_RETRIES_CONNECT` times. #[tokio::test] async fn connect_to_compute_non_retry_3() { + let _ = env_logger::try_init(); assert_eq!(NUM_RETRIES_CONNECT, 16); use ConnectAction::*; let mut ctx = RequestMonitoring::test(); let mechanism = TestConnectMechanism::new(vec![ - Retry, Wake, Retry, Retry, Retry, Retry, Retry, Retry, Retry, Retry, Retry, Retry, Retry, - Retry, Retry, Retry, Retry, /* the 17th time */ Retry, + Wake, Retry, Wake, Retry, Retry, Retry, Retry, Retry, Retry, Retry, Retry, Retry, Retry, + Retry, Retry, Retry, Retry, Retry, /* the 17th time */ Retry, ]); - let (cache, user_info) = helper_create_connect_info(&mechanism); - connect_to_compute(&mut ctx, &mechanism, cache, &user_info) + let user_info = helper_create_connect_info(&mechanism); + connect_to_compute(&mut ctx, &mechanism, &user_info, false) .await .unwrap_err(); mechanism.verify(); @@ -577,11 +611,12 @@ async fn connect_to_compute_non_retry_3() { /// Should retry wake compute. #[tokio::test] async fn wake_retry() { + let _ = env_logger::try_init(); use ConnectAction::*; let mut ctx = RequestMonitoring::test(); - let mechanism = TestConnectMechanism::new(vec![Retry, WakeRetry, Wake, Connect]); - let (cache, user_info) = helper_create_connect_info(&mechanism); - connect_to_compute(&mut ctx, &mechanism, cache, &user_info) + let mechanism = TestConnectMechanism::new(vec![WakeRetry, Wake, Connect]); + let user_info = helper_create_connect_info(&mechanism); + connect_to_compute(&mut ctx, &mechanism, &user_info, false) .await .unwrap(); mechanism.verify(); @@ -590,11 +625,12 @@ async fn wake_retry() { /// Wake failed with a non-retryable error. 
#[tokio::test] async fn wake_non_retry() { + let _ = env_logger::try_init(); use ConnectAction::*; let mut ctx = RequestMonitoring::test(); - let mechanism = TestConnectMechanism::new(vec![Retry, WakeFail]); - let (cache, user_info) = helper_create_connect_info(&mechanism); - connect_to_compute(&mut ctx, &mechanism, cache, &user_info) + let mechanism = TestConnectMechanism::new(vec![WakeRetry, WakeFail]); + let user_info = helper_create_connect_info(&mechanism); + connect_to_compute(&mut ctx, &mechanism, &user_info, false) .await .unwrap_err(); mechanism.verify(); } diff --git a/proxy/src/proxy/wake_compute.rs b/proxy/src/proxy/wake_compute.rs index 925727bdab..2c593451b4 100644 --- a/proxy/src/proxy/wake_compute.rs +++ b/proxy/src/proxy/wake_compute.rs @@ -1,9 +1,4 @@ -use crate::auth::backend::ComputeUserInfo; -use crate::console::{ - errors::WakeComputeError, - provider::{CachedNodeInfo, ConsoleBackend}, - Api, -}; +use crate::console::{errors::WakeComputeError, provider::CachedNodeInfo}; use crate::context::RequestMonitoring; use crate::metrics::{bool_to_str, NUM_WAKEUP_FAILURES}; use crate::proxy::retry::retry_after; @@ -11,17 +6,16 @@ use hyper::StatusCode; use std::ops::ControlFlow; use tracing::{error, warn}; +use super::connect_compute::ComputeConnectBackend; use super::retry::ShouldRetry; -/// wake a compute (or retrieve an existing compute session from cache) -pub async fn wake_compute( +pub async fn wake_compute<B: ComputeConnectBackend>( num_retries: &mut u32, ctx: &mut RequestMonitoring, - api: &ConsoleBackend, - info: &ComputeUserInfo, + api: &B, ) -> Result<CachedNodeInfo, WakeComputeError> { loop { - let wake_res = api.wake_compute(ctx, info).await; + let wake_res = api.wake_compute(ctx).await; match handle_try_wake(wake_res, *num_retries) { Err(e) => { error!(error = ?e, num_retries, retriable = false, "couldn't wake compute node"); diff --git a/proxy/src/rate_limiter.rs b/proxy/src/rate_limiter.rs index b26386d159..f0da4ead23 100644 --- a/proxy/src/rate_limiter.rs +++ b/proxy/src/rate_limiter.rs @@ -4,4 +4,4 @@ mod limiter; pub use aimd::Aimd; pub use limit_algorithm::{AimdConfig, Fixed, RateLimitAlgorithm, RateLimiterConfig}; pub use limiter::Limiter; -pub use limiter::{EndpointRateLimiter, RateBucketInfo}; +pub use limiter::{EndpointRateLimiter, RateBucketInfo, RedisRateLimiter}; diff --git a/proxy/src/rate_limiter/limiter.rs b/proxy/src/rate_limiter/limiter.rs index cbae72711c..3181060e2f 100644 --- a/proxy/src/rate_limiter/limiter.rs +++ b/proxy/src/rate_limiter/limiter.rs @@ -22,6 +22,44 @@ use super::{ RateLimiterConfig, }; +pub struct RedisRateLimiter { + data: Vec<RateBucket>, + info: &'static [RateBucketInfo], +} + +impl RedisRateLimiter { + pub fn new(info: &'static [RateBucketInfo]) -> Self { + Self { + data: vec![ + RateBucket { + start: Instant::now(), + count: 0, + }; + info.len() + ], + info, + } + } + + /// Check that number of connections is below `max_rps` rps. + pub fn check(&mut self) -> bool { + let now = Instant::now(); + + let should_allow_request = self + .data + .iter_mut() + .zip(self.info) + .all(|(bucket, info)| bucket.should_allow_request(info, now)); + + if should_allow_request { + // only increment the bucket counts if the request will actually be accepted + self.data.iter_mut().for_each(RateBucket::inc); + } + + should_allow_request + } +} + // Simple per-endpoint rate limiter. // // Check that number of connections to the endpoint is below `max_rps` rps.
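Side note on the new `RedisRateLimiter` above: it reuses the same fixed-window buckets as the per-endpoint limiter, admitting a publish only when every configured window still has headroom and bumping the counters only for requests that are actually admitted. A minimal, self-contained sketch of that check logic (the real `RateBucket` / `RateBucketInfo` types live elsewhere in `limiter.rs` and are not shown in this diff, so the stand-in types and field names below are assumptions):

```rust
use std::time::{Duration, Instant};

// Stand-ins for the real RateBucket / RateBucketInfo (assumed shapes, for illustration only).
#[derive(Clone)]
struct Bucket { start: Instant, count: u32 }

struct BucketInfo { interval: Duration, max_rpi: u32 } // max requests per interval

impl Bucket {
    // Reset the window once it has elapsed, then test the counter against the limit.
    fn should_allow_request(&mut self, info: &BucketInfo, now: Instant) -> bool {
        if now.duration_since(self.start) > info.interval {
            self.start = now;
            self.count = 0;
        }
        self.count < info.max_rpi
    }
    fn inc(&mut self) { self.count += 1; }
}

struct Limiter { data: Vec<Bucket>, info: Vec<BucketInfo> }

impl Limiter {
    // Mirrors RedisRateLimiter::check: allow only if *all* windows have headroom,
    // and bump the counters only for requests that are actually admitted.
    fn check(&mut self) -> bool {
        let now = Instant::now();
        let allow = self
            .data
            .iter_mut()
            .zip(self.info.iter())
            .all(|(bucket, info)| bucket.should_allow_request(info, now));
        if allow {
            self.data.iter_mut().for_each(Bucket::inc);
        }
        allow
    }
}

fn main() {
    let mut limiter = Limiter {
        data: vec![Bucket { start: Instant::now(), count: 0 }; 1],
        info: vec![BucketInfo { interval: Duration::from_secs(1), max_rpi: 100 }],
    };
    assert!(limiter.check());
}
```

Incrementing only on admission keeps a rejected burst from also eating into the budget of the longer windows.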
diff --git a/proxy/src/redis.rs b/proxy/src/redis.rs index c2a91bed97..35d6db074e 100644 --- a/proxy/src/redis.rs +++ b/proxy/src/redis.rs @@ -1 +1,2 @@ pub mod notifications; +pub mod publisher; diff --git a/proxy/src/redis/notifications.rs b/proxy/src/redis/notifications.rs index 158884aa17..b8297a206c 100644 --- a/proxy/src/redis/notifications.rs +++ b/proxy/src/redis/notifications.rs @@ -1,38 +1,44 @@ use std::{convert::Infallible, sync::Arc}; use futures::StreamExt; +use pq_proto::CancelKeyData; use redis::aio::PubSub; -use serde::Deserialize; +use serde::{Deserialize, Serialize}; +use uuid::Uuid; use crate::{ cache::project_info::ProjectInfoCache, + cancellation::{CancelMap, CancellationHandler, NotificationsCancellationHandler}, intern::{ProjectIdInt, RoleNameInt}, }; -const CHANNEL_NAME: &str = "neondb-proxy-ws-updates"; +const CPLANE_CHANNEL_NAME: &str = "neondb-proxy-ws-updates"; +pub(crate) const PROXY_CHANNEL_NAME: &str = "neondb-proxy-to-proxy-updates"; const RECONNECT_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(20); const INVALIDATION_LAG: std::time::Duration = std::time::Duration::from_secs(20); -struct ConsoleRedisClient { +struct RedisConsumerClient { client: redis::Client, } -impl ConsoleRedisClient { +impl RedisConsumerClient { pub fn new(url: &str) -> anyhow::Result { let client = redis::Client::open(url)?; Ok(Self { client }) } async fn try_connect(&self) -> anyhow::Result { let mut conn = self.client.get_async_connection().await?.into_pubsub(); - tracing::info!("subscribing to a channel `{CHANNEL_NAME}`"); - conn.subscribe(CHANNEL_NAME).await?; + tracing::info!("subscribing to a channel `{CPLANE_CHANNEL_NAME}`"); + conn.subscribe(CPLANE_CHANNEL_NAME).await?; + tracing::info!("subscribing to a channel `{PROXY_CHANNEL_NAME}`"); + conn.subscribe(PROXY_CHANNEL_NAME).await?; Ok(conn) } } -#[derive(Clone, Debug, Deserialize, Eq, PartialEq)] +#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)] #[serde(tag = "topic", content = "data")] -enum Notification { +pub(crate) enum Notification { #[serde( rename = "/allowed_ips_updated", deserialize_with = "deserialize_json_string" @@ -45,16 +51,25 @@ enum Notification { deserialize_with = "deserialize_json_string" )] PasswordUpdate { password_update: PasswordUpdate }, + #[serde(rename = "/cancel_session")] + Cancel(CancelSession), } -#[derive(Clone, Debug, Deserialize, Eq, PartialEq)] -struct AllowedIpsUpdate { +#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)] +pub(crate) struct AllowedIpsUpdate { project_id: ProjectIdInt, } -#[derive(Clone, Debug, Deserialize, Eq, PartialEq)] -struct PasswordUpdate { +#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)] +pub(crate) struct PasswordUpdate { project_id: ProjectIdInt, role_name: RoleNameInt, } +#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)] +pub(crate) struct CancelSession { + pub region_id: Option, + pub cancel_key_data: CancelKeyData, + pub session_id: Uuid, +} + fn deserialize_json_string<'de, D, T>(deserializer: D) -> Result where T: for<'de2> serde::Deserialize<'de2>, @@ -64,6 +79,88 @@ where serde_json::from_str(&s).map_err(::custom) } +struct MessageHandler< + C: ProjectInfoCache + Send + Sync + 'static, + H: NotificationsCancellationHandler + Send + Sync + 'static, +> { + cache: Arc, + cancellation_handler: Arc, + region_id: String, +} + +impl< + C: ProjectInfoCache + Send + Sync + 'static, + H: NotificationsCancellationHandler + Send + Sync + 'static, + > MessageHandler +{ + pub fn new(cache: Arc, 
cancellation_handler: Arc, region_id: String) -> Self { + Self { + cache, + cancellation_handler, + region_id, + } + } + pub fn disable_ttl(&self) { + self.cache.disable_ttl(); + } + pub fn enable_ttl(&self) { + self.cache.enable_ttl(); + } + #[tracing::instrument(skip(self, msg), fields(session_id = tracing::field::Empty))] + async fn handle_message(&self, msg: redis::Msg) -> anyhow::Result<()> { + use Notification::*; + let payload: String = msg.get_payload()?; + tracing::debug!(?payload, "received a message payload"); + + let msg: Notification = match serde_json::from_str(&payload) { + Ok(msg) => msg, + Err(e) => { + tracing::error!("broken message: {e}"); + return Ok(()); + } + }; + tracing::debug!(?msg, "received a message"); + match msg { + Cancel(cancel_session) => { + tracing::Span::current().record( + "session_id", + &tracing::field::display(cancel_session.session_id), + ); + if let Some(cancel_region) = cancel_session.region_id { + // If the message is not for this region, ignore it. + if cancel_region != self.region_id { + return Ok(()); + } + } + // This instance of cancellation_handler doesn't have a RedisPublisherClient so it can't publish the message. + match self + .cancellation_handler + .cancel_session_no_publish(cancel_session.cancel_key_data) + .await + { + Ok(()) => {} + Err(e) => { + tracing::error!("failed to cancel session: {e}"); + } + } + } + _ => { + invalidate_cache(self.cache.clone(), msg.clone()); + // It might happen that the invalid entry is on the way to be cached. + // To make sure that the entry is invalidated, let's repeat the invalidation in INVALIDATION_LAG seconds. + // TODO: include the version (or the timestamp) in the message and invalidate only if the entry is cached before the message. + let cache = self.cache.clone(); + tokio::spawn(async move { + tokio::time::sleep(INVALIDATION_LAG).await; + invalidate_cache(cache, msg); + }); + } + } + + Ok(()) + } +} + fn invalidate_cache(cache: Arc, msg: Notification) { use Notification::*; match msg { @@ -74,50 +171,33 @@ fn invalidate_cache(cache: Arc, msg: Notification) { password_update.project_id, password_update.role_name, ), + Cancel(_) => unreachable!("cancel message should be handled separately"), } } -#[tracing::instrument(skip(cache))] -fn handle_message(msg: redis::Msg, cache: Arc) -> anyhow::Result<()> -where - C: ProjectInfoCache + Send + Sync + 'static, -{ - let payload: String = msg.get_payload()?; - tracing::debug!(?payload, "received a message payload"); - - let msg: Notification = match serde_json::from_str(&payload) { - Ok(msg) => msg, - Err(e) => { - tracing::error!("broken message: {e}"); - return Ok(()); - } - }; - tracing::debug!(?msg, "received a message"); - invalidate_cache(cache.clone(), msg.clone()); - // It might happen that the invalid entry is on the way to be cached. - // To make sure that the entry is invalidated, let's repeat the invalidation in INVALIDATION_LAG seconds. - // TODO: include the version (or the timestamp) in the message and invalidate only if the entry is cached before the message. - tokio::spawn(async move { - tokio::time::sleep(INVALIDATION_LAG).await; - invalidate_cache(cache, msg.clone()); - }); - - Ok(()) -} - /// Handle console's invalidation messages. 
#[tracing::instrument(name = "console_notifications", skip_all)] -pub async fn task_main(url: String, cache: Arc) -> anyhow::Result +pub async fn task_main( + url: String, + cache: Arc, + cancel_map: CancelMap, + region_id: String, +) -> anyhow::Result where C: ProjectInfoCache + Send + Sync + 'static, { cache.enable_ttl(); + let handler = MessageHandler::new( + cache, + Arc::new(CancellationHandler::new(cancel_map, None)), + region_id, + ); loop { - let redis = ConsoleRedisClient::new(&url)?; + let redis = RedisConsumerClient::new(&url)?; let conn = match redis.try_connect().await { Ok(conn) => { - cache.disable_ttl(); + handler.disable_ttl(); conn } Err(e) => { @@ -130,7 +210,7 @@ where }; let mut stream = conn.into_on_message(); while let Some(msg) = stream.next().await { - match handle_message(msg, cache.clone()) { + match handler.handle_message(msg).await { Ok(()) => {} Err(e) => { tracing::error!("failed to handle message: {e}, will try to reconnect"); @@ -138,7 +218,7 @@ where } } } - cache.enable_ttl(); + handler.enable_ttl(); } } @@ -198,6 +278,33 @@ mod tests { } ); + Ok(()) + } + #[test] + fn parse_cancel_session() -> anyhow::Result<()> { + let cancel_key_data = CancelKeyData { + backend_pid: 42, + cancel_key: 41, + }; + let uuid = uuid::Uuid::new_v4(); + let msg = Notification::Cancel(CancelSession { + cancel_key_data, + region_id: None, + session_id: uuid, + }); + let text = serde_json::to_string(&msg)?; + let result: Notification = serde_json::from_str(&text)?; + assert_eq!(msg, result); + + let msg = Notification::Cancel(CancelSession { + cancel_key_data, + region_id: Some("region".to_string()), + session_id: uuid, + }); + let text = serde_json::to_string(&msg)?; + let result: Notification = serde_json::from_str(&text)?; + assert_eq!(msg, result,); + Ok(()) } } diff --git a/proxy/src/redis/publisher.rs b/proxy/src/redis/publisher.rs new file mode 100644 index 0000000000..f85593afdd --- /dev/null +++ b/proxy/src/redis/publisher.rs @@ -0,0 +1,80 @@ +use pq_proto::CancelKeyData; +use redis::AsyncCommands; +use uuid::Uuid; + +use crate::rate_limiter::{RateBucketInfo, RedisRateLimiter}; + +use super::notifications::{CancelSession, Notification, PROXY_CHANNEL_NAME}; + +pub struct RedisPublisherClient { + client: redis::Client, + publisher: Option, + region_id: String, + limiter: RedisRateLimiter, +} + +impl RedisPublisherClient { + pub fn new( + url: &str, + region_id: String, + info: &'static [RateBucketInfo], + ) -> anyhow::Result { + let client = redis::Client::open(url)?; + Ok(Self { + client, + publisher: None, + region_id, + limiter: RedisRateLimiter::new(info), + }) + } + pub async fn try_publish( + &mut self, + cancel_key_data: CancelKeyData, + session_id: Uuid, + ) -> anyhow::Result<()> { + if !self.limiter.check() { + tracing::info!("Rate limit exceeded. Skipping cancellation message"); + return Err(anyhow::anyhow!("Rate limit exceeded")); + } + match self.publish(cancel_key_data, session_id).await { + Ok(()) => return Ok(()), + Err(e) => { + tracing::error!("failed to publish a message: {e}"); + self.publisher = None; + } + } + tracing::info!("Publisher is disconnected. 
Reconnecting..."); + self.try_connect().await?; + self.publish(cancel_key_data, session_id).await + } + + async fn publish( + &mut self, + cancel_key_data: CancelKeyData, + session_id: Uuid, + ) -> anyhow::Result<()> { + let conn = self + .publisher + .as_mut() + .ok_or_else(|| anyhow::anyhow!("not connected"))?; + let payload = serde_json::to_string(&Notification::Cancel(CancelSession { + region_id: Some(self.region_id.clone()), + cancel_key_data, + session_id, + }))?; + conn.publish(PROXY_CHANNEL_NAME, payload).await?; + Ok(()) + } + pub async fn try_connect(&mut self) -> anyhow::Result<()> { + match self.client.get_async_connection().await { + Ok(conn) => { + self.publisher = Some(conn); + } + Err(e) => { + tracing::error!("failed to connect to redis: {e}"); + return Err(e.into()); + } + } + Ok(()) + } +} diff --git a/proxy/src/serverless.rs b/proxy/src/serverless.rs index a20600b94a..ee3e91495b 100644 --- a/proxy/src/serverless.rs +++ b/proxy/src/serverless.rs @@ -24,7 +24,7 @@ use crate::metrics::NUM_CLIENT_CONNECTION_GAUGE; use crate::protocol2::{ProxyProtocolAccept, WithClientIp}; use crate::rate_limiter::EndpointRateLimiter; use crate::serverless::backend::PoolingBackend; -use crate::{cancellation::CancelMap, config::ProxyConfig}; +use crate::{cancellation::CancellationHandler, config::ProxyConfig}; use futures::StreamExt; use hyper::{ server::{ @@ -50,6 +50,7 @@ pub async fn task_main( ws_listener: TcpListener, cancellation_token: CancellationToken, endpoint_rate_limiter: Arc<EndpointRateLimiter>, + cancellation_handler: Arc<CancellationHandler>, ) -> anyhow::Result<()> { scopeguard::defer! { info!("websocket server has shut down"); @@ -115,7 +116,7 @@ pub async fn task_main( let backend = backend.clone(); let ws_connections = ws_connections.clone(); let endpoint_rate_limiter = endpoint_rate_limiter.clone(); - + let cancellation_handler = cancellation_handler.clone(); async move { let peer_addr = match client_addr { Some(addr) => addr, @@ -127,9 +128,9 @@ pub async fn task_main( let backend = backend.clone(); let ws_connections = ws_connections.clone(); let endpoint_rate_limiter = endpoint_rate_limiter.clone(); + let cancellation_handler = cancellation_handler.clone(); async move { - let cancel_map = Arc::new(CancelMap::default()); let session_id = uuid::Uuid::new_v4(); request_handler( @@ -137,7 +138,7 @@ pub async fn task_main( config, backend, ws_connections, - cancel_map, + cancellation_handler, session_id, peer_addr.ip(), endpoint_rate_limiter, @@ -205,7 +206,7 @@ async fn request_handler( config: &'static ProxyConfig, backend: Arc<PoolingBackend>, ws_connections: TaskTracker, - cancel_map: Arc<CancelMap>, + cancellation_handler: Arc<CancellationHandler>, session_id: uuid::Uuid, peer_addr: IpAddr, endpoint_rate_limiter: Arc<EndpointRateLimiter>, @@ -232,7 +233,7 @@ async fn request_handler( config, ctx, websocket, - cancel_map, + cancellation_handler, host, endpoint_rate_limiter, ) diff --git a/proxy/src/serverless/backend.rs b/proxy/src/serverless/backend.rs index 8285da68d7..6f93f86d5f 100644 --- a/proxy/src/serverless/backend.rs +++ b/proxy/src/serverless/backend.rs @@ -1,10 +1,10 @@ use std::{sync::Arc, time::Duration}; use async_trait::async_trait; -use tracing::info; +use tracing::{field::display, info}; use crate::{ - auth::{backend::ComputeCredentialKeys, check_peer_addr_is_in_list, AuthError}, + auth::{backend::ComputeCredentials, check_peer_addr_is_in_list, AuthError}, compute, config::ProxyConfig, console::{ @@ -15,7 +15,7 @@ use crate::{ proxy::connect_compute::ConnectMechanism, }; -use super::conn_pool::{poll_client, Client, ConnInfo, GlobalConnPool, APP_NAME}; +use
super::conn_pool::{poll_client, Client, ConnInfo, GlobalConnPool}; pub struct PoolingBackend { pub pool: Arc>, @@ -27,7 +27,7 @@ impl PoolingBackend { &self, ctx: &mut RequestMonitoring, conn_info: &ConnInfo, - ) -> Result { + ) -> Result { let user_info = conn_info.user_info.clone(); let backend = self.config.auth_backend.as_ref().map(|_| user_info.clone()); let (allowed_ips, maybe_secret) = backend.get_allowed_ips_and_secret(ctx).await?; @@ -49,13 +49,17 @@ impl PoolingBackend { }; let auth_outcome = crate::auth::validate_password_and_exchange(&conn_info.password, secret)?; - match auth_outcome { + let res = match auth_outcome { crate::sasl::Outcome::Success(key) => Ok(key), crate::sasl::Outcome::Failure(reason) => { info!("auth backend failed with an error: {reason}"); Err(AuthError::auth_failed(&*conn_info.user_info.user)) } - } + }; + res.map(|key| ComputeCredentials { + info: user_info, + keys: key, + }) } // Wake up the destination if needed. Code here is a bit involved because @@ -66,7 +70,7 @@ impl PoolingBackend { &self, ctx: &mut RequestMonitoring, conn_info: ConnInfo, - keys: ComputeCredentialKeys, + keys: ComputeCredentials, force_new: bool, ) -> Result, HttpConnError> { let maybe_client = if !force_new { @@ -81,27 +85,9 @@ impl PoolingBackend { return Ok(client); } let conn_id = uuid::Uuid::new_v4(); + tracing::Span::current().record("conn_id", display(conn_id)); info!(%conn_id, "pool: opening a new connection '{conn_info}'"); - ctx.set_application(Some(APP_NAME)); - let backend = self - .config - .auth_backend - .as_ref() - .map(|_| conn_info.user_info.clone()); - - let mut node_info = backend - .wake_compute(ctx) - .await? - .ok_or(HttpConnError::NoComputeInfo)?; - - match keys { - #[cfg(any(test, feature = "testing"))] - ComputeCredentialKeys::Password(password) => node_info.config.password(password), - ComputeCredentialKeys::AuthKeys(auth_keys) => node_info.config.auth_keys(auth_keys), - }; - - ctx.set_project(node_info.aux.clone()); - + let backend = self.config.auth_backend.as_ref().map(|_| keys); crate::proxy::connect_compute::connect_to_compute( ctx, &TokioMechanism { @@ -109,8 +95,8 @@ impl PoolingBackend { conn_info, pool: self.pool.clone(), }, - node_info, &backend, + false, // do not allow self signed compute for http flow ) .await } @@ -129,8 +115,6 @@ pub enum HttpConnError { AuthError(#[from] AuthError), #[error("wake_compute returned error")] WakeCompute(#[from] WakeComputeError), - #[error("wake_compute returned nothing")] - NoComputeInfo, } struct TokioMechanism { diff --git a/proxy/src/serverless/conn_pool.rs b/proxy/src/serverless/conn_pool.rs index f4e5b145c5..53e7c1c2ee 100644 --- a/proxy/src/serverless/conn_pool.rs +++ b/proxy/src/serverless/conn_pool.rs @@ -4,7 +4,6 @@ use metrics::IntCounterPairGuard; use parking_lot::RwLock; use rand::Rng; use smallvec::SmallVec; -use smol_str::SmolStr; use std::{collections::HashMap, pin::pin, sync::Arc, sync::Weak, time::Duration}; use std::{ fmt, @@ -31,8 +30,6 @@ use tracing::{info, info_span, Instrument}; use super::backend::HttpConnError; -pub const APP_NAME: SmolStr = SmolStr::new_inline("/sql_over_http"); - #[derive(Debug, Clone)] pub struct ConnInfo { pub user_info: ComputeUserInfo, @@ -379,12 +376,13 @@ impl GlobalConnPool { info!("pool: cached connection '{conn_info}' is closed, opening a new one"); return Ok(None); } else { - info!("pool: reusing connection '{conn_info}'"); - client.session.send(ctx.session_id)?; + tracing::Span::current().record("conn_id", tracing::field::display(client.conn_id)); 
tracing::Span::current().record( "pid", &tracing::field::display(client.inner.get_process_id()), ); + info!("pool: reusing connection '{conn_info}'"); + client.session.send(ctx.session_id)?; ctx.latency_timer.pool_hit(); ctx.latency_timer.success(); return Ok(Some(Client::new(client, conn_info.clone(), endpoint_pool))); @@ -577,7 +575,6 @@ pub struct Client { } pub struct Discard<'a, C: ClientInnerExt> { - conn_id: uuid::Uuid, conn_info: &'a ConnInfo, pool: &'a mut Weak>>, } @@ -603,14 +600,7 @@ impl Client { span: _, } = self; let inner = inner.as_mut().expect("client inner should not be removed"); - ( - &mut inner.inner, - Discard { - pool, - conn_info, - conn_id: inner.conn_id, - }, - ) + (&mut inner.inner, Discard { pool, conn_info }) } pub fn check_idle(&mut self, status: ReadyForQueryStatus) { @@ -625,13 +615,13 @@ impl Discard<'_, C> { pub fn check_idle(&mut self, status: ReadyForQueryStatus) { let conn_info = &self.conn_info; if status != ReadyForQueryStatus::Idle && std::mem::take(self.pool).strong_count() > 0 { - info!(conn_id = %self.conn_id, "pool: throwing away connection '{conn_info}' because connection is not idle") + info!("pool: throwing away connection '{conn_info}' because connection is not idle") } } pub fn discard(&mut self) { let conn_info = &self.conn_info; if std::mem::take(self.pool).strong_count() > 0 { - info!(conn_id = %self.conn_id, "pool: throwing away connection '{conn_info}' because connection is potentially in a broken state") + info!("pool: throwing away connection '{conn_info}' because connection is potentially in a broken state") } } } diff --git a/proxy/src/serverless/sql_over_http.rs b/proxy/src/serverless/sql_over_http.rs index 54424360c4..ecb72abe73 100644 --- a/proxy/src/serverless/sql_over_http.rs +++ b/proxy/src/serverless/sql_over_http.rs @@ -36,6 +36,8 @@ use crate::error::ReportableError; use crate::metrics::HTTP_CONTENT_LENGTH; use crate::metrics::NUM_CONNECTION_REQUESTS_GAUGE; use crate::proxy::NeonOptions; +use crate::serverless::backend::HttpConnError; +use crate::DbName; use crate::RoleName; use super::backend::PoolingBackend; @@ -117,6 +119,9 @@ fn get_conn_info( headers: &HeaderMap, tls: &TlsConfig, ) -> Result { + // HTTP only uses cleartext (for now and likely always) + ctx.set_auth_method(crate::context::AuthMethod::Cleartext); + let connection_string = headers .get("Neon-Connection-String") .ok_or(ConnInfoError::InvalidHeader("Neon-Connection-String"))? 
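For context, every connection parameter in the HTTP flow comes from the `Neon-Connection-String` header parsed as a URL; the hunk below then records the database name and user on the request context. A rough, hypothetical sketch of that extraction using the `url` and `urlencoding` crates (the real `get_conn_info` also validates the hostname, password and options; the helper and example values here are made up):

```rust
// Hypothetical illustration only; not the actual get_conn_info helper.
fn parse_conn_string(raw: &str) -> Option<(String, String)> {
    let u = url::Url::parse(raw).ok()?;
    // The first path segment is the database name.
    let dbname = u.path_segments()?.next()?.to_string();
    // The username may be percent-encoded in the URL.
    let user = urlencoding::decode(u.username()).ok()?.into_owned();
    if user.is_empty() {
        return None;
    }
    Some((user, dbname))
}

fn main() {
    let (user, db) =
        parse_conn_string("postgres://alice:secret@ep-example.neon.tech/neondb").unwrap();
    assert_eq!((user.as_str(), db.as_str()), ("alice", "neondb"));
}
```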
@@ -134,7 +139,8 @@ fn get_conn_info( .path_segments() .ok_or(ConnInfoError::MissingDbName)?; - let dbname = url_path.next().ok_or(ConnInfoError::InvalidDbName)?; + let dbname: DbName = url_path.next().ok_or(ConnInfoError::InvalidDbName)?.into(); + ctx.set_dbname(dbname.clone()); let username = RoleName::from(urlencoding::decode(connection_url.username())?); if username.is_empty() { @@ -174,7 +180,7 @@ fn get_conn_info( Ok(ConnInfo { user_info, - dbname: dbname.into(), + dbname, password: match password { std::borrow::Cow::Borrowed(b) => b.into(), std::borrow::Cow::Owned(b) => b.into(), @@ -300,7 +306,14 @@ pub async fn handle( Ok(response) } -#[instrument(name = "sql-over-http", fields(pid = tracing::field::Empty), skip_all)] +#[instrument( + name = "sql-over-http", + skip_all, + fields( + pid = tracing::field::Empty, + conn_id = tracing::field::Empty + ) +)] async fn handle_inner( config: &'static ProxyConfig, ctx: &mut RequestMonitoring, @@ -354,12 +367,10 @@ async fn handle_inner( let txn_read_only = headers.get(&TXN_READ_ONLY) == Some(&HEADER_VALUE_TRUE); let txn_deferrable = headers.get(&TXN_DEFERRABLE) == Some(&HEADER_VALUE_TRUE); - let paused = ctx.latency_timer.pause(); let request_content_length = match request.body().size_hint().upper() { Some(v) => v, None => MAX_REQUEST_SIZE + 1, }; - drop(paused); info!(request_content_length, "request size in bytes"); HTTP_CONTENT_LENGTH.observe(request_content_length as f64); @@ -375,15 +386,20 @@ async fn handle_inner( let body = hyper::body::to_bytes(request.into_body()) .await .map_err(anyhow::Error::from)?; + info!(length = body.len(), "request payload read"); let payload: Payload = serde_json::from_slice(&body)?; Ok::(payload) // Adjust error type accordingly }; let authenticate_and_connect = async { let keys = backend.authenticate(ctx, &conn_info).await?; - backend + let client = backend .connect_to_compute(ctx, conn_info, keys, !allow_pool) - .await + .await?; + // not strictly necessary to mark success here, + // but it's just insurance for if we forget it somewhere else + ctx.latency_timer.success(); + Ok::<_, HttpConnError>(client) }; // Run both operations in parallel @@ -415,6 +431,7 @@ async fn handle_inner( results } Payload::Batch(statements) => { + info!("starting transaction"); let (inner, mut discard) = client.inner(); let mut builder = inner.build_transaction(); if let Some(isolation_level) = txn_isolation_level { @@ -444,6 +461,7 @@ async fn handle_inner( .await { Ok(results) => { + info!("commit"); let status = transaction.commit().await.map_err(|e| { // if we cannot commit - for now don't return connection to pool // TODO: get a query status from the error @@ -454,6 +472,7 @@ async fn handle_inner( results } Err(err) => { + info!("rollback"); let status = transaction.rollback().await.map_err(|e| { // if we cannot rollback - for now don't return connection to pool // TODO: get a query status from the error @@ -528,8 +547,10 @@ async fn query_to_json( raw_output: bool, default_array_mode: bool, ) -> anyhow::Result<(ReadyForQueryStatus, Value)> { + info!("executing query"); let query_params = data.params; let row_stream = client.query_raw_txt(&data.query, query_params).await?; + info!("finished executing query"); // Manually drain the stream into a vector to leave row_stream hanging // around to get a command tag. 
Also check that the response is not too @@ -564,6 +585,13 @@ async fn query_to_json( } .and_then(|s| s.parse::().ok()); + info!( + rows = rows.len(), + ?ready, + command_tag, + "finished reading rows" + ); + let mut fields = vec![]; let mut columns = vec![]; diff --git a/proxy/src/serverless/websocket.rs b/proxy/src/serverless/websocket.rs index 062dd440b2..24f2bb7e8c 100644 --- a/proxy/src/serverless/websocket.rs +++ b/proxy/src/serverless/websocket.rs @@ -1,5 +1,5 @@ use crate::{ - cancellation::CancelMap, + cancellation::CancellationHandler, config::ProxyConfig, context::RequestMonitoring, error::{io_error, ReportableError}, @@ -133,7 +133,7 @@ pub async fn serve_websocket( config: &'static ProxyConfig, mut ctx: RequestMonitoring, websocket: HyperWebsocket, - cancel_map: Arc, + cancellation_handler: Arc, hostname: Option, endpoint_rate_limiter: Arc, ) -> anyhow::Result<()> { @@ -141,7 +141,7 @@ pub async fn serve_websocket( let res = handle_client( config, &mut ctx, - cancel_map, + cancellation_handler, WebSocketRw::new(websocket), ClientMode::Websockets { hostname }, endpoint_rate_limiter, diff --git a/workspace_hack/Cargo.toml b/workspace_hack/Cargo.toml index 8e9cc43152..e808fabbe7 100644 --- a/workspace_hack/Cargo.toml +++ b/workspace_hack/Cargo.toml @@ -38,7 +38,7 @@ futures-io = { version = "0.3" } futures-sink = { version = "0.3" } futures-util = { version = "0.3", features = ["channel", "io", "sink"] } getrandom = { version = "0.2", default-features = false, features = ["std"] } -hashbrown-582f2526e08bb6a0 = { package = "hashbrown", version = "0.14", default-features = false, features = ["raw"] } +hashbrown-582f2526e08bb6a0 = { package = "hashbrown", version = "0.14", features = ["raw"] } hashbrown-594e8ee84c453af0 = { package = "hashbrown", version = "0.13", features = ["raw"] } hex = { version = "0.4", features = ["serde"] } hmac = { version = "0.12", default-features = false, features = ["reset"] } @@ -91,7 +91,7 @@ cc = { version = "1", default-features = false, features = ["parallel"] } chrono = { version = "0.4", default-features = false, features = ["clock", "serde", "wasmbind"] } either = { version = "1" } getrandom = { version = "0.2", default-features = false, features = ["std"] } -hashbrown-582f2526e08bb6a0 = { package = "hashbrown", version = "0.14", default-features = false, features = ["raw"] } +hashbrown-582f2526e08bb6a0 = { package = "hashbrown", version = "0.14", features = ["raw"] } indexmap = { version = "1", default-features = false, features = ["std"] } itertools = { version = "0.10" } libc = { version = "0.2", features = ["extra_traits", "use_std"] }
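One last note on the cross-proxy cancellation path introduced in this change: the message published to `neondb-proxy-to-proxy-updates` is the adjacently tagged `Notification::Cancel` JSON exercised by the `parse_cancel_session` test above. A standalone sketch of the wire shape, using simplified local copies of the types (the real `CancelKeyData` comes from `pq_proto`, `session_id` is a `Uuid`, and this assumes `serde` with the derive feature plus `serde_json`):

```rust
use serde::{Deserialize, Serialize};

// Simplified local copies of the proxy types, for illustration only.
#[derive(Serialize, Deserialize, Debug, PartialEq)]
struct CancelKeyData {
    backend_pid: i32,
    cancel_key: i32,
}

#[derive(Serialize, Deserialize, Debug, PartialEq)]
struct CancelSession {
    region_id: Option<String>,
    cancel_key_data: CancelKeyData,
    session_id: String, // a Uuid in the real code
}

#[derive(Serialize, Deserialize, Debug, PartialEq)]
#[serde(tag = "topic", content = "data")]
enum Notification {
    #[serde(rename = "/cancel_session")]
    Cancel(CancelSession),
}

fn main() -> serde_json::Result<()> {
    // Values mirror the parse_cancel_session test: backend_pid 42, cancel_key 41.
    let msg = Notification::Cancel(CancelSession {
        region_id: Some("region".to_string()),
        cancel_key_data: CancelKeyData { backend_pid: 42, cancel_key: 41 },
        session_id: "00000000-0000-0000-0000-000000000000".to_string(),
    });
    let text = serde_json::to_string(&msg)?;
    // Serializes roughly as:
    // {"topic":"/cancel_session","data":{"region_id":"region",
    //  "cancel_key_data":{"backend_pid":42,"cancel_key":41},
    //  "session_id":"00000000-0000-0000-0000-000000000000"}}
    let roundtrip: Notification = serde_json::from_str(&text)?;
    assert_eq!(msg, roundtrip);
    Ok(())
}
```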