Merge pull request #6758 from neondatabase/release-proxy-2024-02-14

2024-02-14 Proxy Release
Conrad Ludgate
2024-02-15 09:45:08 +00:00
committed by GitHub
38 changed files with 1109 additions and 358 deletions

Cargo.lock (generated), 7 changes
View File

@@ -2247,11 +2247,11 @@ dependencies = [
[[package]]
name = "hashlink"
version = "0.8.2"
version = "0.8.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0761a1b9491c4f2e3d66aa0f62d0fba0af9a0e2852e4d48ea506632a4b56e6aa"
checksum = "e8094feaf31ff591f651a2664fb9cfd92bba7a60ce3197265e9482ebe753c8f7"
dependencies = [
"hashbrown 0.13.2",
"hashbrown 0.14.0",
]
[[package]]
@@ -3936,6 +3936,7 @@ dependencies = [
"pin-project-lite",
"postgres-protocol",
"rand 0.8.5",
"serde",
"thiserror",
"tokio",
"tracing",

View File

@@ -80,7 +80,7 @@ futures-core = "0.3"
futures-util = "0.3"
git-version = "0.3"
hashbrown = "0.13"
hashlink = "0.8.1"
hashlink = "0.8.4"
hdrhistogram = "7.5.2"
hex = "0.4"
hex-literal = "0.4"

View File

@@ -13,5 +13,6 @@ rand.workspace = true
tokio.workspace = true
tracing.workspace = true
thiserror.workspace = true
serde.workspace = true
workspace_hack.workspace = true

View File

@@ -7,6 +7,7 @@ pub mod framed;
use byteorder::{BigEndian, ReadBytesExt};
use bytes::{Buf, BufMut, Bytes, BytesMut};
use serde::{Deserialize, Serialize};
use std::{borrow::Cow, collections::HashMap, fmt, io, str};
// re-export for use in utils pageserver_feedback.rs
@@ -123,7 +124,7 @@ impl StartupMessageParams {
}
}
#[derive(Debug, Hash, PartialEq, Eq, Clone, Copy)]
#[derive(Debug, Hash, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
pub struct CancelKeyData {
pub backend_pid: i32,
pub cancel_key: i32,

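With Serialize/Deserialize derived, a CancelKeyData can be encoded for transport between proxy instances, which the Redis cancellation changes further down rely on. A minimal round-trip sketch, using serde_json purely as an illustrative format (the actual wire format used by the Redis publisher is not shown in this diff):

use pq_proto::CancelKeyData;

// Hypothetical round-trip; the format and values here are illustrative only.
let key = CancelKeyData { backend_pid: 42, cancel_key: 7 };
let encoded = serde_json::to_string(&key).expect("serialize");
let decoded: CancelKeyData = serde_json::from_str(&encoded).expect("deserialize");
assert_eq!(key, decoded);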
View File

@@ -36,9 +36,6 @@ pub enum AuthErrorImpl {
#[error(transparent)]
GetAuthInfo(#[from] console::errors::GetAuthInfoError),
#[error(transparent)]
WakeCompute(#[from] console::errors::WakeComputeError),
/// SASL protocol errors (includes [SCRAM](crate::scram)).
#[error(transparent)]
Sasl(#[from] crate::sasl::Error),
@@ -119,7 +116,6 @@ impl UserFacingError for AuthError {
match self.0.as_ref() {
Link(e) => e.to_string_client(),
GetAuthInfo(e) => e.to_string_client(),
WakeCompute(e) => e.to_string_client(),
Sasl(e) => e.to_string_client(),
AuthFailed(_) => self.to_string(),
BadAuthMethod(_) => self.to_string(),
@@ -139,7 +135,6 @@ impl ReportableError for AuthError {
match self.0.as_ref() {
Link(e) => e.get_error_kind(),
GetAuthInfo(e) => e.get_error_kind(),
WakeCompute(e) => e.get_error_kind(),
Sasl(e) => e.get_error_kind(),
AuthFailed(_) => crate::error::ErrorKind::User,
BadAuthMethod(_) => crate::error::ErrorKind::User,

View File

@@ -10,9 +10,9 @@ use crate::auth::validate_password_and_exchange;
use crate::cache::Cached;
use crate::console::errors::GetAuthInfoError;
use crate::console::provider::{CachedRoleSecret, ConsoleBackend};
use crate::console::AuthSecret;
use crate::console::{AuthSecret, NodeInfo};
use crate::context::RequestMonitoring;
use crate::proxy::wake_compute::wake_compute;
use crate::proxy::connect_compute::ComputeConnectBackend;
use crate::proxy::NeonOptions;
use crate::stream::Stream;
use crate::{
@@ -26,7 +26,6 @@ use crate::{
stream, url,
};
use crate::{scram, EndpointCacheKey, EndpointId, RoleName};
use futures::TryFutureExt;
use std::sync::Arc;
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::info;
@@ -56,11 +55,11 @@ impl<T> std::ops::Deref for MaybeOwned<'_, T> {
/// * However, when we substitute `T` with [`ComputeUserInfoMaybeEndpoint`],
/// this helps us provide the credentials only to those auth
/// backends which require them for the authentication process.
pub enum BackendType<'a, T> {
pub enum BackendType<'a, T, D> {
/// Cloud API (V2).
Console(MaybeOwned<'a, ConsoleBackend>, T),
/// Authentication via a web browser.
Link(MaybeOwned<'a, url::ApiUrl>),
Link(MaybeOwned<'a, url::ApiUrl>, D),
}
pub trait TestBackend: Send + Sync + 'static {
@@ -71,7 +70,7 @@ pub trait TestBackend: Send + Sync + 'static {
fn get_role_secret(&self) -> Result<CachedRoleSecret, console::errors::GetAuthInfoError>;
}
impl std::fmt::Display for BackendType<'_, ()> {
impl std::fmt::Display for BackendType<'_, (), ()> {
fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
use BackendType::*;
match self {
@@ -86,51 +85,50 @@ impl std::fmt::Display for BackendType<'_, ()> {
#[cfg(test)]
ConsoleBackend::Test(_) => fmt.debug_tuple("Test").finish(),
},
Link(url) => fmt.debug_tuple("Link").field(&url.as_str()).finish(),
Link(url, _) => fmt.debug_tuple("Link").field(&url.as_str()).finish(),
}
}
}
impl<T> BackendType<'_, T> {
impl<T, D> BackendType<'_, T, D> {
/// Very similar to [`std::option::Option::as_ref`].
/// This helps us pass structured config to async tasks.
pub fn as_ref(&self) -> BackendType<'_, &T> {
pub fn as_ref(&self) -> BackendType<'_, &T, &D> {
use BackendType::*;
match self {
Console(c, x) => Console(MaybeOwned::Borrowed(c), x),
Link(c) => Link(MaybeOwned::Borrowed(c)),
Link(c, x) => Link(MaybeOwned::Borrowed(c), x),
}
}
}
impl<'a, T> BackendType<'a, T> {
impl<'a, T, D> BackendType<'a, T, D> {
/// Very similar to [`std::option::Option::map`].
/// Maps [`BackendType<T>`] to [`BackendType<R>`] by applying
/// a function to a contained value.
pub fn map<R>(self, f: impl FnOnce(T) -> R) -> BackendType<'a, R> {
pub fn map<R>(self, f: impl FnOnce(T) -> R) -> BackendType<'a, R, D> {
use BackendType::*;
match self {
Console(c, x) => Console(c, f(x)),
Link(c) => Link(c),
Link(c, x) => Link(c, x),
}
}
}
impl<'a, T, E> BackendType<'a, Result<T, E>> {
impl<'a, T, D, E> BackendType<'a, Result<T, E>, D> {
/// Very similar to [`std::option::Option::transpose`].
/// This is most useful for error handling.
pub fn transpose(self) -> Result<BackendType<'a, T>, E> {
pub fn transpose(self) -> Result<BackendType<'a, T, D>, E> {
use BackendType::*;
match self {
Console(c, x) => x.map(|x| Console(c, x)),
Link(c) => Ok(Link(c)),
Link(c, x) => Ok(Link(c, x)),
}
}
}
pub struct ComputeCredentials<T> {
pub struct ComputeCredentials {
pub info: ComputeUserInfo,
pub keys: T,
pub keys: ComputeCredentialKeys,
}
#[derive(Debug, Clone)]
@@ -153,7 +151,6 @@ impl ComputeUserInfo {
}
pub enum ComputeCredentialKeys {
#[cfg(any(test, feature = "testing"))]
Password(Vec<u8>),
AuthKeys(AuthKeys),
}
@@ -188,19 +185,21 @@ async fn auth_quirks(
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
allow_cleartext: bool,
config: &'static AuthenticationConfig,
) -> auth::Result<ComputeCredentials<ComputeCredentialKeys>> {
) -> auth::Result<ComputeCredentials> {
// If there's no project so far, that entails that the client doesn't
// support SNI or other means of passing the endpoint (project) name.
// We now expect to see a very specific payload in the place of password.
let (info, unauthenticated_password) = match user_info.try_into() {
Err(info) => {
let res = hacks::password_hack_no_authentication(info, client, &mut ctx.latency_timer)
.await?;
let res = hacks::password_hack_no_authentication(ctx, info, client).await?;
ctx.set_endpoint_id(res.info.endpoint.clone());
tracing::Span::current().record("ep", &tracing::field::display(&res.info.endpoint));
(res.info, Some(res.keys))
let password = match res.keys {
ComputeCredentialKeys::Password(p) => p,
_ => unreachable!("password hack should return a password"),
};
(res.info, Some(password))
}
Ok(info) => (info, None),
};
@@ -254,7 +253,7 @@ async fn authenticate_with_secret(
unauthenticated_password: Option<Vec<u8>>,
allow_cleartext: bool,
config: &'static AuthenticationConfig,
) -> auth::Result<ComputeCredentials<ComputeCredentialKeys>> {
) -> auth::Result<ComputeCredentials> {
if let Some(password) = unauthenticated_password {
let auth_outcome = validate_password_and_exchange(&password, secret)?;
let keys = match auth_outcome {
@@ -276,21 +275,22 @@ async fn authenticate_with_secret(
// Perform cleartext auth if we're allowed to do that.
// Currently, we use it for websocket connections (latency).
if allow_cleartext {
return hacks::authenticate_cleartext(info, client, &mut ctx.latency_timer, secret).await;
ctx.set_auth_method(crate::context::AuthMethod::Cleartext);
return hacks::authenticate_cleartext(ctx, info, client, secret).await;
}
// Finally, proceed with the main auth flow (SCRAM-based).
classic::authenticate(info, client, config, &mut ctx.latency_timer, secret).await
classic::authenticate(ctx, info, client, config, secret).await
}
impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint> {
impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint, &()> {
/// Get compute endpoint name from the credentials.
pub fn get_endpoint(&self) -> Option<EndpointId> {
use BackendType::*;
match self {
Console(_, user_info) => user_info.endpoint_id.clone(),
Link(_) => Some("link".into()),
Link(_, _) => Some("link".into()),
}
}
@@ -300,7 +300,7 @@ impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint> {
match self {
Console(_, user_info) => &user_info.user,
Link(_) => "link",
Link(_, _) => "link",
}
}
@@ -312,7 +312,7 @@ impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint> {
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
allow_cleartext: bool,
config: &'static AuthenticationConfig,
) -> auth::Result<(CachedNodeInfo, BackendType<'a, ComputeUserInfo>)> {
) -> auth::Result<BackendType<'a, ComputeCredentials, NodeInfo>> {
use BackendType::*;
let res = match self {
@@ -323,33 +323,17 @@ impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint> {
"performing authentication using the console"
);
let compute_credentials =
let credentials =
auth_quirks(ctx, &*api, user_info, client, allow_cleartext, config).await?;
let mut num_retries = 0;
let mut node =
wake_compute(&mut num_retries, ctx, &api, &compute_credentials.info).await?;
ctx.set_project(node.aux.clone());
match compute_credentials.keys {
#[cfg(any(test, feature = "testing"))]
ComputeCredentialKeys::Password(password) => node.config.password(password),
ComputeCredentialKeys::AuthKeys(auth_keys) => node.config.auth_keys(auth_keys),
};
(node, BackendType::Console(api, compute_credentials.info))
BackendType::Console(api, credentials)
}
// NOTE: this auth backend doesn't use client credentials.
Link(url) => {
Link(url, _) => {
info!("performing link authentication");
let node_info = link::authenticate(ctx, &url, client).await?;
let info = link::authenticate(ctx, &url, client).await?;
(
CachedNodeInfo::new_uncached(node_info),
BackendType::Link(url),
)
BackendType::Link(url, info)
}
};
@@ -358,7 +342,7 @@ impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint> {
}
}
impl BackendType<'_, ComputeUserInfo> {
impl BackendType<'_, ComputeUserInfo, &()> {
pub async fn get_role_secret(
&self,
ctx: &mut RequestMonitoring,
@@ -366,7 +350,7 @@ impl BackendType<'_, ComputeUserInfo> {
use BackendType::*;
match self {
Console(api, user_info) => api.get_role_secret(ctx, user_info).await,
Link(_) => Ok(Cached::new_uncached(None)),
Link(_, _) => Ok(Cached::new_uncached(None)),
}
}
@@ -377,21 +361,51 @@ impl BackendType<'_, ComputeUserInfo> {
use BackendType::*;
match self {
Console(api, user_info) => api.get_allowed_ips_and_secret(ctx, user_info).await,
Link(_) => Ok((Cached::new_uncached(Arc::new(vec![])), None)),
}
}
/// When applicable, wake the compute node, gaining its connection info in the process.
/// The link auth flow doesn't support this, so we return [`None`] in that case.
pub async fn wake_compute(
&self,
ctx: &mut RequestMonitoring,
) -> Result<Option<CachedNodeInfo>, console::errors::WakeComputeError> {
use BackendType::*;
match self {
Console(api, user_info) => api.wake_compute(ctx, user_info).map_ok(Some).await,
Link(_) => Ok(None),
Link(_, _) => Ok((Cached::new_uncached(Arc::new(vec![])), None)),
}
}
}
#[async_trait::async_trait]
impl ComputeConnectBackend for BackendType<'_, ComputeCredentials, NodeInfo> {
async fn wake_compute(
&self,
ctx: &mut RequestMonitoring,
) -> Result<CachedNodeInfo, console::errors::WakeComputeError> {
use BackendType::*;
match self {
Console(api, creds) => api.wake_compute(ctx, &creds.info).await,
Link(_, info) => Ok(Cached::new_uncached(info.clone())),
}
}
fn get_keys(&self) -> Option<&ComputeCredentialKeys> {
match self {
BackendType::Console(_, creds) => Some(&creds.keys),
BackendType::Link(_, _) => None,
}
}
}
#[async_trait::async_trait]
impl ComputeConnectBackend for BackendType<'_, ComputeCredentials, &()> {
async fn wake_compute(
&self,
ctx: &mut RequestMonitoring,
) -> Result<CachedNodeInfo, console::errors::WakeComputeError> {
use BackendType::*;
match self {
Console(api, creds) => api.wake_compute(ctx, &creds.info).await,
Link(_, _) => unreachable!("link auth flow doesn't support waking the compute"),
}
}
fn get_keys(&self) -> Option<&ComputeCredentialKeys> {
match self {
BackendType::Console(_, creds) => Some(&creds.keys),
BackendType::Link(_, _) => None,
}
}
}
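These two impls are what connect_to_compute now drives through the ComputeConnectBackend trait (defined in proxy::connect_compute, in a later hunk) instead of matching on a concrete BackendType. A hedged sketch of how a generic caller could use them; wake_with_keys is a hypothetical helper name, not part of this diff:

async fn wake_with_keys<B: ComputeConnectBackend>(
    ctx: &mut RequestMonitoring,
    backend: &B,
) -> Result<CachedNodeInfo, console::errors::WakeComputeError> {
    // Console backends wake the compute via the API; the link backend
    // already carries its NodeInfo and returns it uncached.
    let mut node = backend.wake_compute(ctx).await?;
    // Console credentials carry password/auth keys; copy them into the
    // compute connection config. Link auth returns None here.
    if let Some(keys) = backend.get_keys() {
        node.set_keys(keys);
    }
    Ok(node)
}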

View File

@@ -4,7 +4,7 @@ use crate::{
compute,
config::AuthenticationConfig,
console::AuthSecret,
metrics::LatencyTimer,
context::RequestMonitoring,
sasl,
stream::{PqStream, Stream},
};
@@ -12,12 +12,12 @@ use tokio::io::{AsyncRead, AsyncWrite};
use tracing::{info, warn};
pub(super) async fn authenticate(
ctx: &mut RequestMonitoring,
creds: ComputeUserInfo,
client: &mut PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
config: &'static AuthenticationConfig,
latency_timer: &mut LatencyTimer,
secret: AuthSecret,
) -> auth::Result<ComputeCredentials<ComputeCredentialKeys>> {
) -> auth::Result<ComputeCredentials> {
let flow = AuthFlow::new(client);
let scram_keys = match secret {
#[cfg(any(test, feature = "testing"))]
@@ -27,13 +27,11 @@ pub(super) async fn authenticate(
}
AuthSecret::Scram(secret) => {
info!("auth endpoint chooses SCRAM");
let scram = auth::Scram(&secret);
let scram = auth::Scram(&secret, &mut *ctx);
let auth_outcome = tokio::time::timeout(
config.scram_protocol_timeout,
async {
// pause the timer while we communicate with the client
let _paused = latency_timer.pause();
flow.begin(scram).await.map_err(|error| {
warn!(?error, "error sending scram acknowledgement");

View File

@@ -4,7 +4,7 @@ use super::{
use crate::{
auth::{self, AuthFlow},
console::AuthSecret,
metrics::LatencyTimer,
context::RequestMonitoring,
sasl,
stream::{self, Stream},
};
@@ -16,15 +16,16 @@ use tracing::{info, warn};
/// These properties are beneficial for serverless JS workers, so we
/// use this mechanism for websocket connections.
pub async fn authenticate_cleartext(
ctx: &mut RequestMonitoring,
info: ComputeUserInfo,
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
latency_timer: &mut LatencyTimer,
secret: AuthSecret,
) -> auth::Result<ComputeCredentials<ComputeCredentialKeys>> {
) -> auth::Result<ComputeCredentials> {
warn!("cleartext auth flow override is enabled, proceeding");
ctx.set_auth_method(crate::context::AuthMethod::Cleartext);
// pause the timer while we communicate with the client
let _paused = latency_timer.pause();
let _paused = ctx.latency_timer.pause();
let auth_outcome = AuthFlow::new(client)
.begin(auth::CleartextPassword(secret))
@@ -47,14 +48,15 @@ pub async fn authenticate_cleartext(
/// Similar to [`authenticate_cleartext`], but there's a specific password format,
/// and passwords are not yet validated (we don't know how to validate them!)
pub async fn password_hack_no_authentication(
ctx: &mut RequestMonitoring,
info: ComputeUserInfoNoEndpoint,
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
latency_timer: &mut LatencyTimer,
) -> auth::Result<ComputeCredentials<Vec<u8>>> {
) -> auth::Result<ComputeCredentials> {
warn!("project not specified, resorting to the password hack auth flow");
ctx.set_auth_method(crate::context::AuthMethod::Cleartext);
// pause the timer while we communicate with the client
let _paused = latency_timer.pause();
let _paused = ctx.latency_timer.pause();
let payload = AuthFlow::new(client)
.begin(auth::PasswordHack)
@@ -71,6 +73,6 @@ pub async fn password_hack_no_authentication(
options: info.options,
endpoint: payload.endpoint,
},
keys: payload.password,
keys: ComputeCredentialKeys::Password(payload.password),
})
}

View File

@@ -61,6 +61,8 @@ pub(super) async fn authenticate(
link_uri: &reqwest::Url,
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
) -> auth::Result<NodeInfo> {
ctx.set_auth_method(crate::context::AuthMethod::Web);
// registering waiter can fail if we get unlucky with rng.
// just try again.
let (psql_session_id, waiter) = loop {

View File

@@ -99,6 +99,9 @@ impl ComputeUserInfoMaybeEndpoint {
// record the values if we have them
ctx.set_application(params.get("application_name").map(SmolStr::from));
ctx.set_user(user.clone());
if let Some(dbname) = params.get("database") {
ctx.set_dbname(dbname.into());
}
// Project name might be passed via PG's command-line options.
let endpoint_option = params

View File

@@ -4,9 +4,11 @@ use super::{backend::ComputeCredentialKeys, AuthErrorImpl, PasswordHackPayload};
use crate::{
config::TlsServerEndPoint,
console::AuthSecret,
context::RequestMonitoring,
sasl, scram,
stream::{PqStream, Stream},
};
use postgres_protocol::authentication::sasl::{SCRAM_SHA_256, SCRAM_SHA_256_PLUS};
use pq_proto::{BeAuthenticationSaslMessage, BeMessage, BeMessage as Be};
use std::io;
use tokio::io::{AsyncRead, AsyncWrite};
@@ -23,7 +25,7 @@ pub trait AuthMethod {
pub struct Begin;
/// Use [SCRAM](crate::scram)-based auth in [`AuthFlow`].
pub struct Scram<'a>(pub &'a scram::ServerSecret);
pub struct Scram<'a>(pub &'a scram::ServerSecret, pub &'a mut RequestMonitoring);
impl AuthMethod for Scram<'_> {
#[inline(always)]
@@ -138,6 +140,11 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, CleartextPassword> {
impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, Scram<'_>> {
/// Perform user authentication. Raise an error in case authentication failed.
pub async fn authenticate(self) -> super::Result<sasl::Outcome<scram::ScramKey>> {
let Scram(secret, ctx) = self.state;
// pause the timer while we communicate with the client
let _paused = ctx.latency_timer.pause();
// Initial client message contains the chosen auth method's name.
let msg = self.stream.read_password_message().await?;
let sasl = sasl::FirstMessage::parse(&msg)
@@ -148,9 +155,15 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, Scram<'_>> {
return Err(super::AuthError::bad_auth_method(sasl.method));
}
match sasl.method {
SCRAM_SHA_256 => ctx.auth_method = Some(crate::context::AuthMethod::ScramSha256),
SCRAM_SHA_256_PLUS => {
ctx.auth_method = Some(crate::context::AuthMethod::ScramSha256Plus)
}
_ => {}
}
info!("client chooses {}", sasl.method);
let secret = self.state.0;
let outcome = sasl::SaslStream::new(self.stream, sasl.message)
.authenticate(scram::Exchange::new(
secret,

View File

@@ -1,6 +1,8 @@
use futures::future::Either;
use proxy::auth;
use proxy::auth::backend::MaybeOwned;
use proxy::cancellation::CancelMap;
use proxy::cancellation::CancellationHandler;
use proxy::config::AuthenticationConfig;
use proxy::config::CacheOptions;
use proxy::config::HttpConfig;
@@ -12,6 +14,7 @@ use proxy::rate_limiter::EndpointRateLimiter;
use proxy::rate_limiter::RateBucketInfo;
use proxy::rate_limiter::RateLimiterConfig;
use proxy::redis::notifications;
use proxy::redis::publisher::RedisPublisherClient;
use proxy::serverless::GlobalConnPoolOptions;
use proxy::usage_metrics;
@@ -22,6 +25,7 @@ use std::net::SocketAddr;
use std::pin::pin;
use std::sync::Arc;
use tokio::net::TcpListener;
use tokio::sync::Mutex;
use tokio::task::JoinSet;
use tokio_util::sync::CancellationToken;
use tracing::info;
@@ -129,6 +133,9 @@ struct ProxyCliArgs {
/// Can be given multiple times for different bucket sizes.
#[clap(long, default_values_t = RateBucketInfo::DEFAULT_SET)]
endpoint_rps_limit: Vec<RateBucketInfo>,
/// Redis rate limiter max number of requests per second.
#[clap(long, default_values_t = RateBucketInfo::DEFAULT_SET)]
redis_rps_limit: Vec<RateBucketInfo>,
/// Initial limit for dynamic rate limiter. Makes sense only if `rate_limit_algorithm` is *not* `None`.
#[clap(long, default_value_t = 100)]
initial_limit: usize,
@@ -225,6 +232,19 @@ async fn main() -> anyhow::Result<()> {
let cancellation_token = CancellationToken::new();
let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new(&config.endpoint_rps_limit));
let cancel_map = CancelMap::default();
let redis_publisher = match &args.redis_notifications {
Some(url) => Some(Arc::new(Mutex::new(RedisPublisherClient::new(
url,
args.region.clone(),
&config.redis_rps_limit,
)?))),
None => None,
};
let cancellation_handler = Arc::new(CancellationHandler::new(
cancel_map.clone(),
redis_publisher,
));
// client facing tasks. these will exit on error or on cancellation
// cancellation returns Ok(())
@@ -234,6 +254,7 @@ async fn main() -> anyhow::Result<()> {
proxy_listener,
cancellation_token.clone(),
endpoint_rate_limiter.clone(),
cancellation_handler.clone(),
));
// TODO: rename the argument to something like serverless.
@@ -248,6 +269,7 @@ async fn main() -> anyhow::Result<()> {
serverless_listener,
cancellation_token.clone(),
endpoint_rate_limiter.clone(),
cancellation_handler.clone(),
));
}
@@ -271,7 +293,12 @@ async fn main() -> anyhow::Result<()> {
let cache = api.caches.project_info.clone();
if let Some(url) = args.redis_notifications {
info!("Starting redis notifications listener ({url})");
maintenance_tasks.spawn(notifications::task_main(url.to_owned(), cache.clone()));
maintenance_tasks.spawn(notifications::task_main(
url.to_owned(),
cache.clone(),
cancel_map.clone(),
args.region.clone(),
));
}
maintenance_tasks.spawn(async move { cache.clone().gc_worker().await });
}
@@ -383,7 +410,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
}
AuthBackend::Link => {
let url = args.uri.parse()?;
auth::BackendType::Link(MaybeOwned::Owned(url))
auth::BackendType::Link(MaybeOwned::Owned(url), ())
}
};
let http_config = HttpConfig {
@@ -403,6 +430,8 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
let mut endpoint_rps_limit = args.endpoint_rps_limit.clone();
RateBucketInfo::validate(&mut endpoint_rps_limit)?;
let mut redis_rps_limit = args.redis_rps_limit.clone();
RateBucketInfo::validate(&mut redis_rps_limit)?;
let config = Box::leak(Box::new(ProxyConfig {
tls_config,
@@ -414,6 +443,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
require_client_ip: args.require_client_ip,
disable_ip_check_for_http: args.disable_ip_check_for_http,
endpoint_rps_limit,
redis_rps_limit,
handshake_timeout: args.handshake_timeout,
// TODO: add this argument
region: args.region.clone(),

View File

@@ -1,16 +1,28 @@
use async_trait::async_trait;
use dashmap::DashMap;
use pq_proto::CancelKeyData;
use std::{net::SocketAddr, sync::Arc};
use thiserror::Error;
use tokio::net::TcpStream;
use tokio::sync::Mutex;
use tokio_postgres::{CancelToken, NoTls};
use tracing::info;
use uuid::Uuid;
use crate::error::ReportableError;
use crate::{
error::ReportableError, metrics::NUM_CANCELLATION_REQUESTS,
redis::publisher::RedisPublisherClient,
};
pub type CancelMap = Arc<DashMap<CancelKeyData, Option<CancelClosure>>>;
/// Enables serving `CancelRequest`s.
#[derive(Default)]
pub struct CancelMap(DashMap<CancelKeyData, Option<CancelClosure>>);
///
/// If there is a `RedisPublisherClient` available, it will be used to publish the cancellation key to other proxy instances.
pub struct CancellationHandler {
map: CancelMap,
redis_client: Option<Arc<Mutex<RedisPublisherClient>>>,
}
#[derive(Debug, Error)]
pub enum CancelError {
@@ -32,15 +44,43 @@ impl ReportableError for CancelError {
}
}
impl CancelMap {
impl CancellationHandler {
pub fn new(map: CancelMap, redis_client: Option<Arc<Mutex<RedisPublisherClient>>>) -> Self {
Self { map, redis_client }
}
/// Cancel a running query for the corresponding connection.
pub async fn cancel_session(&self, key: CancelKeyData) -> Result<(), CancelError> {
pub async fn cancel_session(
&self,
key: CancelKeyData,
session_id: Uuid,
) -> Result<(), CancelError> {
let from = "from_client";
// NB: we should immediately release the lock after cloning the token.
let Some(cancel_closure) = self.0.get(&key).and_then(|x| x.clone()) else {
let Some(cancel_closure) = self.map.get(&key).and_then(|x| x.clone()) else {
tracing::warn!("query cancellation key not found: {key}");
if let Some(redis_client) = &self.redis_client {
NUM_CANCELLATION_REQUESTS
.with_label_values(&[from, "not_found"])
.inc();
info!("publishing cancellation key to Redis");
match redis_client.lock().await.try_publish(key, session_id).await {
Ok(()) => {
info!("cancellation key successfuly published to Redis");
}
Err(e) => {
tracing::error!("failed to publish a message: {e}");
return Err(CancelError::IO(std::io::Error::new(
std::io::ErrorKind::Other,
e.to_string(),
)));
}
}
}
return Ok(());
};
NUM_CANCELLATION_REQUESTS
.with_label_values(&[from, "found"])
.inc();
info!("cancelling query per user's request using key {key}");
cancel_closure.try_cancel_query().await
}
@@ -57,7 +97,7 @@ impl CancelMap {
// Random key collisions are unlikely to happen here, but they're still possible,
// which is why we have to take care not to rewrite an existing key.
match self.0.entry(key) {
match self.map.entry(key) {
dashmap::mapref::entry::Entry::Occupied(_) => continue,
dashmap::mapref::entry::Entry::Vacant(e) => {
e.insert(None);
@@ -69,18 +109,46 @@ impl CancelMap {
info!("registered new query cancellation key {key}");
Session {
key,
cancel_map: self,
cancellation_handler: self,
}
}
#[cfg(test)]
fn contains(&self, session: &Session) -> bool {
self.0.contains_key(&session.key)
self.map.contains_key(&session.key)
}
#[cfg(test)]
fn is_empty(&self) -> bool {
self.0.is_empty()
self.map.is_empty()
}
}
#[async_trait]
pub trait NotificationsCancellationHandler {
async fn cancel_session_no_publish(&self, key: CancelKeyData) -> Result<(), CancelError>;
}
#[async_trait]
impl NotificationsCancellationHandler for CancellationHandler {
async fn cancel_session_no_publish(&self, key: CancelKeyData) -> Result<(), CancelError> {
let from = "from_redis";
let cancel_closure = self.map.get(&key).and_then(|x| x.clone());
match cancel_closure {
Some(cancel_closure) => {
NUM_CANCELLATION_REQUESTS
.with_label_values(&[from, "found"])
.inc();
cancel_closure.try_cancel_query().await
}
None => {
NUM_CANCELLATION_REQUESTS
.with_label_values(&[from, "not_found"])
.inc();
tracing::warn!("query cancellation key not found: {key}");
Ok(())
}
}
}
}
@@ -115,7 +183,7 @@ pub struct Session {
/// The user-facing key identifying this session.
key: CancelKeyData,
/// The [`CancelMap`] this session belongs to.
cancel_map: Arc<CancelMap>,
cancellation_handler: Arc<CancellationHandler>,
}
impl Session {
@@ -123,7 +191,9 @@ impl Session {
/// This enables query cancellation in `crate::proxy::prepare_client_connection`.
pub fn enable_query_cancellation(&self, cancel_closure: CancelClosure) -> CancelKeyData {
info!("enabling query cancellation for this session");
self.cancel_map.0.insert(self.key, Some(cancel_closure));
self.cancellation_handler
.map
.insert(self.key, Some(cancel_closure));
self.key
}
@@ -131,7 +201,7 @@ impl Session {
impl Drop for Session {
fn drop(&mut self) {
self.cancel_map.0.remove(&self.key);
self.cancellation_handler.map.remove(&self.key);
info!("dropped query cancellation key {}", &self.key);
}
}
@@ -142,13 +212,16 @@ mod tests {
#[tokio::test]
async fn check_session_drop() -> anyhow::Result<()> {
let cancel_map: Arc<CancelMap> = Default::default();
let cancellation_handler = Arc::new(CancellationHandler {
map: CancelMap::default(),
redis_client: None,
});
let session = cancel_map.clone().get_session();
assert!(cancel_map.contains(&session));
let session = cancellation_handler.clone().get_session();
assert!(cancellation_handler.contains(&session));
drop(session);
// Check that the session has been dropped.
assert!(cancel_map.is_empty());
assert!(cancellation_handler.is_empty());
Ok(())
}
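Taken together with the wiring in the proxy binary above, a minimal usage sketch; the names mirror this diff, while url, region, limits, cancel_key_data and session_id are placeholders:

let cancel_map = CancelMap::default();
let cancellation_handler = Arc::new(CancellationHandler::new(
    cancel_map.clone(),
    None, // or Some(Arc::new(Mutex::new(RedisPublisherClient::new(url, region, &limits)?)))
));

// On a CancelRequest: try the local map first; if the key is unknown and a
// Redis publisher is configured, the key is forwarded to other proxy instances.
cancellation_handler
    .cancel_session(cancel_key_data, session_id)
    .await?;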

View File

@@ -1,7 +1,7 @@
use crate::{
auth::parse_endpoint_param,
cancellation::CancelClosure,
console::errors::WakeComputeError,
console::{errors::WakeComputeError, messages::MetricsAuxInfo},
context::RequestMonitoring,
error::{ReportableError, UserFacingError},
metrics::NUM_DB_CONNECTIONS_GAUGE,
@@ -93,7 +93,7 @@ impl ConnCfg {
}
/// Reuse password or auth keys from the other config.
pub fn reuse_password(&mut self, other: &Self) {
pub fn reuse_password(&mut self, other: Self) {
if let Some(password) = other.get_password() {
self.password(password);
}
@@ -253,6 +253,8 @@ pub struct PostgresConnection {
pub params: std::collections::HashMap<String, String>,
/// Query cancellation token.
pub cancel_closure: CancelClosure,
/// Labels for proxy's metrics.
pub aux: MetricsAuxInfo,
_guage: IntCounterPairGuard,
}
@@ -263,6 +265,7 @@ impl ConnCfg {
&self,
ctx: &mut RequestMonitoring,
allow_self_signed_compute: bool,
aux: MetricsAuxInfo,
timeout: Duration,
) -> Result<PostgresConnection, ConnectionError> {
let (socket_addr, stream, host) = self.connect_raw(timeout).await?;
@@ -297,6 +300,7 @@ impl ConnCfg {
stream,
params,
cancel_closure,
aux,
_guage: NUM_DB_CONNECTIONS_GAUGE
.with_label_values(&[ctx.protocol])
.guard(),

View File

@@ -13,7 +13,7 @@ use x509_parser::oid_registry;
pub struct ProxyConfig {
pub tls_config: Option<TlsConfig>,
pub auth_backend: auth::BackendType<'static, ()>,
pub auth_backend: auth::BackendType<'static, (), ()>,
pub metric_collection: Option<MetricCollectionConfig>,
pub allow_self_signed_compute: bool,
pub http_config: HttpConfig,
@@ -21,6 +21,7 @@ pub struct ProxyConfig {
pub require_client_ip: bool,
pub disable_ip_check_for_http: bool,
pub endpoint_rps_limit: Vec<RateBucketInfo>,
pub redis_rps_limit: Vec<RateBucketInfo>,
pub region: String,
pub handshake_timeout: Duration,
}

View File

@@ -4,7 +4,10 @@ pub mod neon;
use super::messages::MetricsAuxInfo;
use crate::{
auth::{backend::ComputeUserInfo, IpPattern},
auth::{
backend::{ComputeCredentialKeys, ComputeUserInfo},
IpPattern,
},
cache::{project_info::ProjectInfoCacheImpl, Cached, TimedLru},
compute,
config::{CacheOptions, ProjectInfoCacheOptions},
@@ -261,6 +264,34 @@ pub struct NodeInfo {
pub allow_self_signed_compute: bool,
}
impl NodeInfo {
pub async fn connect(
&self,
ctx: &mut RequestMonitoring,
timeout: Duration,
) -> Result<compute::PostgresConnection, compute::ConnectionError> {
self.config
.connect(
ctx,
self.allow_self_signed_compute,
self.aux.clone(),
timeout,
)
.await
}
pub fn reuse_settings(&mut self, other: Self) {
self.allow_self_signed_compute = other.allow_self_signed_compute;
self.config.reuse_password(other.config);
}
pub fn set_keys(&mut self, keys: &ComputeCredentialKeys) {
match keys {
ComputeCredentialKeys::Password(password) => self.config.password(password),
ComputeCredentialKeys::AuthKeys(auth_keys) => self.config.auth_keys(*auth_keys),
};
}
}
pub type NodeInfoCache = TimedLru<EndpointCacheKey, NodeInfo>;
pub type CachedNodeInfo = Cached<&'static NodeInfoCache>;
pub type CachedRoleSecret = Cached<&'static ProjectInfoCacheImpl, Option<AuthSecret>>;

View File

@@ -176,9 +176,7 @@ impl super::Api for Api {
_ctx: &mut RequestMonitoring,
_user_info: &ComputeUserInfo,
) -> Result<CachedNodeInfo, WakeComputeError> {
self.do_wake_compute()
.map_ok(CachedNodeInfo::new_uncached)
.await
self.do_wake_compute().map_ok(Cached::new_uncached).await
}
}

View File

@@ -11,7 +11,7 @@ use crate::{
console::messages::MetricsAuxInfo,
error::ErrorKind,
metrics::{LatencyTimer, ENDPOINT_ERRORS_BY_KIND, ERROR_BY_KIND},
BranchId, EndpointId, ProjectId, RoleName,
BranchId, DbName, EndpointId, ProjectId, RoleName,
};
pub mod parquet;
@@ -34,9 +34,11 @@ pub struct RequestMonitoring {
project: Option<ProjectId>,
branch: Option<BranchId>,
endpoint_id: Option<EndpointId>,
dbname: Option<DbName>,
user: Option<RoleName>,
application: Option<SmolStr>,
error_kind: Option<ErrorKind>,
pub(crate) auth_method: Option<AuthMethod>,
success: bool,
// extra
@@ -45,6 +47,15 @@ pub struct RequestMonitoring {
pub latency_timer: LatencyTimer,
}
#[derive(Clone, Debug)]
pub enum AuthMethod {
// aka link aka passwordless
Web,
ScramSha256,
ScramSha256Plus,
Cleartext,
}
impl RequestMonitoring {
pub fn new(
session_id: Uuid,
@@ -62,9 +73,11 @@ impl RequestMonitoring {
project: None,
branch: None,
endpoint_id: None,
dbname: None,
user: None,
application: None,
error_kind: None,
auth_method: None,
success: false,
sender: LOG_CHAN.get().and_then(|tx| tx.upgrade()),
@@ -106,10 +119,18 @@ impl RequestMonitoring {
self.application = app.or_else(|| self.application.clone());
}
pub fn set_dbname(&mut self, dbname: DbName) {
self.dbname = Some(dbname);
}
pub fn set_user(&mut self, user: RoleName) {
self.user = Some(user);
}
pub fn set_auth_method(&mut self, auth_method: AuthMethod) {
self.auth_method = Some(auth_method);
}
pub fn set_error_kind(&mut self, kind: ErrorKind) {
ERROR_BY_KIND
.with_label_values(&[kind.to_metric_label()])

View File

@@ -84,8 +84,10 @@ struct RequestData {
username: Option<String>,
application_name: Option<String>,
endpoint_id: Option<String>,
database: Option<String>,
project: Option<String>,
branch: Option<String>,
auth_method: Option<&'static str>,
error: Option<&'static str>,
/// Success is counted if we form a HTTP response with sql rows inside
/// Or if we make it to proxy_pass
@@ -104,8 +106,15 @@ impl From<RequestMonitoring> for RequestData {
username: value.user.as_deref().map(String::from),
application_name: value.application.as_deref().map(String::from),
endpoint_id: value.endpoint_id.as_deref().map(String::from),
database: value.dbname.as_deref().map(String::from),
project: value.project.as_deref().map(String::from),
branch: value.branch.as_deref().map(String::from),
auth_method: value.auth_method.as_ref().map(|x| match x {
super::AuthMethod::Web => "web",
super::AuthMethod::ScramSha256 => "scram_sha_256",
super::AuthMethod::ScramSha256Plus => "scram_sha_256_plus",
super::AuthMethod::Cleartext => "cleartext",
}),
protocol: value.protocol,
region: value.region,
error: value.error_kind.as_ref().map(|e| e.to_metric_label()),
@@ -431,8 +440,10 @@ mod tests {
application_name: Some("test".to_owned()),
username: Some(hex::encode(rng.gen::<[u8; 4]>())),
endpoint_id: Some(hex::encode(rng.gen::<[u8; 16]>())),
database: Some(hex::encode(rng.gen::<[u8; 16]>())),
project: Some(hex::encode(rng.gen::<[u8; 16]>())),
branch: Some(hex::encode(rng.gen::<[u8; 16]>())),
auth_method: None,
protocol: ["tcp", "ws", "http"][rng.gen_range(0..3)],
region: "us-east-1",
error: None,
@@ -505,15 +516,15 @@ mod tests {
assert_eq!(
file_stats,
[
(1087635, 3, 6000),
(1087288, 3, 6000),
(1087444, 3, 6000),
(1087572, 3, 6000),
(1087468, 3, 6000),
(1087500, 3, 6000),
(1087533, 3, 6000),
(1087566, 3, 6000),
(362671, 1, 2000)
(1313727, 3, 6000),
(1313720, 3, 6000),
(1313780, 3, 6000),
(1313737, 3, 6000),
(1313867, 3, 6000),
(1313709, 3, 6000),
(1313501, 3, 6000),
(1313737, 3, 6000),
(438118, 1, 2000)
],
);
@@ -543,11 +554,11 @@ mod tests {
assert_eq!(
file_stats,
[
(1028637, 5, 10000),
(1031969, 5, 10000),
(1019900, 5, 10000),
(1020365, 5, 10000),
(1025010, 5, 10000)
(1219459, 5, 10000),
(1225609, 5, 10000),
(1227403, 5, 10000),
(1226765, 5, 10000),
(1218043, 5, 10000)
],
);
@@ -579,11 +590,11 @@ mod tests {
assert_eq!(
file_stats,
[
(1210770, 6, 12000),
(1211036, 6, 12000),
(1210990, 6, 12000),
(1210861, 6, 12000),
(202073, 1, 2000)
(1205106, 5, 10000),
(1204837, 5, 10000),
(1205130, 5, 10000),
(1205118, 5, 10000),
(1205373, 5, 10000)
],
);
@@ -608,15 +619,15 @@ mod tests {
assert_eq!(
file_stats,
[
(1087635, 3, 6000),
(1087288, 3, 6000),
(1087444, 3, 6000),
(1087572, 3, 6000),
(1087468, 3, 6000),
(1087500, 3, 6000),
(1087533, 3, 6000),
(1087566, 3, 6000),
(362671, 1, 2000)
(1313727, 3, 6000),
(1313720, 3, 6000),
(1313780, 3, 6000),
(1313737, 3, 6000),
(1313867, 3, 6000),
(1313709, 3, 6000),
(1313501, 3, 6000),
(1313737, 3, 6000),
(438118, 1, 2000)
],
);
@@ -653,7 +664,7 @@ mod tests {
// files are smaller than the size threshold, but they took too long to fill so were flushed early
assert_eq!(
file_stats,
[(545264, 2, 3001), (545025, 2, 3000), (544857, 2, 2999)],
[(658383, 2, 3001), (658097, 2, 3000), (657893, 2, 2999)],
);
tmpdir.close().unwrap();

View File

@@ -29,7 +29,7 @@ pub trait UserFacingError: ReportableError {
}
}
#[derive(Copy, Clone, Debug)]
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum ErrorKind {
/// Wrong password, unknown endpoint, protocol violation, etc...
User,
@@ -90,3 +90,13 @@ impl ReportableError for tokio::time::error::Elapsed {
ErrorKind::RateLimit
}
}
impl ReportableError for tokio_postgres::error::Error {
fn get_error_kind(&self) -> ErrorKind {
if self.as_db_error().is_some() {
ErrorKind::Postgres
} else {
ErrorKind::Compute
}
}
}

View File

@@ -152,6 +152,15 @@ pub static NUM_OPEN_CLIENTS_IN_HTTP_POOL: Lazy<IntGauge> = Lazy::new(|| {
.unwrap()
});
pub static NUM_CANCELLATION_REQUESTS: Lazy<IntCounterVec> = Lazy::new(|| {
register_int_counter_vec!(
"proxy_cancellation_requests_total",
"Number of cancellation requests (per found/not_found).",
&["source", "kind"],
)
.unwrap()
});
#[derive(Clone)]
pub struct LatencyTimer {
// time since the stopwatch was started
@@ -200,8 +209,9 @@ impl LatencyTimer {
pub fn success(&mut self) {
// stop the stopwatch and record the time that we have accumulated
let start = self.start.take().expect("latency timer should be started");
self.accumulated += start.elapsed();
if let Some(start) = self.start.take() {
self.accumulated += start.elapsed();
}
// success
self.outcome = "success";

View File
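success() now only accumulates elapsed time when a start instant is still present, instead of panicking if the stopwatch was already stopped. The same Option::take pattern in isolation (an illustrative sketch, not the proxy's LatencyTimer):

use std::time::{Duration, Instant};

struct Stopwatch {
    start: Option<Instant>,
    accumulated: Duration,
}

impl Stopwatch {
    fn stop(&mut self) {
        // take() leaves None behind, so stopping twice is a no-op
        // rather than a panic.
        if let Some(start) = self.start.take() {
            self.accumulated += start.elapsed();
        }
    }
}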

@@ -2,6 +2,7 @@
mod tests;
pub mod connect_compute;
mod copy_bidirectional;
pub mod handshake;
pub mod passthrough;
pub mod retry;
@@ -9,7 +10,7 @@ pub mod wake_compute;
use crate::{
auth,
cancellation::{self, CancelMap},
cancellation::{self, CancellationHandler},
compute,
config::{ProxyConfig, TlsConfig},
context::RequestMonitoring,
@@ -61,6 +62,7 @@ pub async fn task_main(
listener: tokio::net::TcpListener,
cancellation_token: CancellationToken,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
cancellation_handler: Arc<CancellationHandler>,
) -> anyhow::Result<()> {
scopeguard::defer! {
info!("proxy has shut down");
@@ -71,7 +73,6 @@ pub async fn task_main(
socket2::SockRef::from(&listener).set_keepalive(true)?;
let connections = tokio_util::task::task_tracker::TaskTracker::new();
let cancel_map = Arc::new(CancelMap::default());
while let Some(accept_result) =
run_until_cancelled(listener.accept(), &cancellation_token).await
@@ -79,7 +80,7 @@ pub async fn task_main(
let (socket, peer_addr) = accept_result?;
let session_id = uuid::Uuid::new_v4();
let cancel_map = Arc::clone(&cancel_map);
let cancellation_handler = Arc::clone(&cancellation_handler);
let endpoint_rate_limiter = endpoint_rate_limiter.clone();
let session_span = info_span!(
@@ -112,7 +113,7 @@ pub async fn task_main(
let res = handle_client(
config,
&mut ctx,
cancel_map,
cancellation_handler,
socket,
ClientMode::Tcp,
endpoint_rate_limiter,
@@ -162,14 +163,14 @@ pub enum ClientMode {
/// Abstracts the logic of handling TCP vs WS clients
impl ClientMode {
fn allow_cleartext(&self) -> bool {
pub fn allow_cleartext(&self) -> bool {
match self {
ClientMode::Tcp => false,
ClientMode::Websockets { .. } => true,
}
}
fn allow_self_signed_compute(&self, config: &ProxyConfig) -> bool {
pub fn allow_self_signed_compute(&self, config: &ProxyConfig) -> bool {
match self {
ClientMode::Tcp => config.allow_self_signed_compute,
ClientMode::Websockets { .. } => false,
@@ -226,7 +227,7 @@ impl ReportableError for ClientRequestError {
pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
config: &'static ProxyConfig,
ctx: &mut RequestMonitoring,
cancel_map: Arc<CancelMap>,
cancellation_handler: Arc<CancellationHandler>,
stream: S,
mode: ClientMode,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
@@ -252,8 +253,8 @@ pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
match tokio::time::timeout(config.handshake_timeout, do_handshake).await?? {
HandshakeData::Startup(stream, params) => (stream, params),
HandshakeData::Cancel(cancel_key_data) => {
return Ok(cancel_map
.cancel_session(cancel_key_data)
return Ok(cancellation_handler
.cancel_session(cancel_key_data, ctx.session_id)
.await
.map(|()| None)?)
}
@@ -286,7 +287,7 @@ pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
}
let user = user_info.get_user().to_owned();
let (mut node_info, user_info) = match user_info
let user_info = match user_info
.authenticate(
ctx,
&mut stream,
@@ -305,19 +306,16 @@ pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
}
};
node_info.allow_self_signed_compute = mode.allow_self_signed_compute(config);
let aux = node_info.aux.clone();
let mut node = connect_to_compute(
ctx,
&TcpMechanism { params: &params },
node_info,
&user_info,
mode.allow_self_signed_compute(config),
)
.or_else(|e| stream.throw_error(e))
.await?;
let session = cancel_map.get_session();
let session = cancellation_handler.get_session();
prepare_client_connection(&node, &session, &mut stream).await?;
// Before proxy passing, forward to compute whatever data is left in the
@@ -329,10 +327,11 @@ pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
Ok(Some(ProxyPassthrough {
client: stream,
aux: node.aux.clone(),
compute: node,
aux,
req: _request_gauge,
conn: _client_gauge,
cancel: session,
}))
}

View File

@@ -1,8 +1,9 @@
use crate::{
auth,
auth::backend::ComputeCredentialKeys,
compute::{self, PostgresConnection},
console::{self, errors::WakeComputeError},
console::{self, errors::WakeComputeError, CachedNodeInfo, NodeInfo},
context::RequestMonitoring,
error::ReportableError,
metrics::NUM_CONNECTION_FAILURES,
proxy::{
retry::{retry_after, ShouldRetry},
@@ -20,7 +21,7 @@ const CONNECT_TIMEOUT: time::Duration = time::Duration::from_secs(2);
/// (e.g. the compute node's address might've changed at the wrong time).
/// Invalidate the cache entry (if any) to prevent subsequent errors.
#[tracing::instrument(name = "invalidate_cache", skip_all)]
pub fn invalidate_cache(node_info: console::CachedNodeInfo) -> compute::ConnCfg {
pub fn invalidate_cache(node_info: console::CachedNodeInfo) -> NodeInfo {
let is_cached = node_info.cached();
if is_cached {
warn!("invalidating stalled compute node info cache entry");
@@ -31,13 +32,13 @@ pub fn invalidate_cache(node_info: console::CachedNodeInfo) -> compute::ConnCfg
};
NUM_CONNECTION_FAILURES.with_label_values(&[label]).inc();
node_info.invalidate().config
node_info.invalidate()
}
#[async_trait]
pub trait ConnectMechanism {
type Connection;
type ConnectError;
type ConnectError: ReportableError;
type Error: From<Self::ConnectError>;
async fn connect_once(
&self,
@@ -49,6 +50,16 @@ pub trait ConnectMechanism {
fn update_connect_config(&self, conf: &mut compute::ConnCfg);
}
#[async_trait]
pub trait ComputeConnectBackend {
async fn wake_compute(
&self,
ctx: &mut RequestMonitoring,
) -> Result<CachedNodeInfo, console::errors::WakeComputeError>;
fn get_keys(&self) -> Option<&ComputeCredentialKeys>;
}
pub struct TcpMechanism<'a> {
/// KV-dictionary with PostgreSQL connection params.
pub params: &'a StartupMessageParams,
@@ -67,11 +78,7 @@ impl ConnectMechanism for TcpMechanism<'_> {
node_info: &console::CachedNodeInfo,
timeout: time::Duration,
) -> Result<PostgresConnection, Self::Error> {
let allow_self_signed_compute = node_info.allow_self_signed_compute;
node_info
.config
.connect(ctx, allow_self_signed_compute, timeout)
.await
node_info.connect(ctx, timeout).await
}
fn update_connect_config(&self, config: &mut compute::ConnCfg) {
@@ -82,16 +89,23 @@ impl ConnectMechanism for TcpMechanism<'_> {
/// Try to connect to the compute node, retrying if necessary.
/// This function might update `node_info`, so we take it by `&mut`.
#[tracing::instrument(skip_all)]
pub async fn connect_to_compute<M: ConnectMechanism>(
pub async fn connect_to_compute<M: ConnectMechanism, B: ComputeConnectBackend>(
ctx: &mut RequestMonitoring,
mechanism: &M,
mut node_info: console::CachedNodeInfo,
user_info: &auth::BackendType<'_, auth::backend::ComputeUserInfo>,
user_info: &B,
allow_self_signed_compute: bool,
) -> Result<M::Connection, M::Error>
where
M::ConnectError: ShouldRetry + std::fmt::Debug,
M::Error: From<WakeComputeError>,
{
let mut num_retries = 0;
let mut node_info = wake_compute(&mut num_retries, ctx, user_info).await?;
if let Some(keys) = user_info.get_keys() {
node_info.set_keys(keys);
}
node_info.allow_self_signed_compute = allow_self_signed_compute;
// let mut node_info = credentials.get_node_info(ctx, user_info).await?;
mechanism.update_connect_config(&mut node_info.config);
// try once
@@ -108,28 +122,30 @@ where
error!(error = ?err, "could not connect to compute node");
let mut num_retries = 1;
match user_info {
auth::BackendType::Console(api, info) => {
// if we failed to connect, it's likely that the compute node was suspended, wake a new compute node
info!("compute node's state has likely changed; requesting a wake-up");
ctx.latency_timer.cache_miss();
let config = invalidate_cache(node_info);
node_info = wake_compute(&mut num_retries, ctx, api, info).await?;
node_info.config.reuse_password(&config);
mechanism.update_connect_config(&mut node_info.config);
let node_info = if !node_info.cached() {
// If we just received this from cplane and didn't get it from cache, we shouldn't retry.
// Do not need to retrieve a new node_info, just return the old one.
if !err.should_retry(num_retries) {
return Err(err.into());
}
// nothing to do?
auth::BackendType::Link(_) => {}
node_info
} else {
// if we failed to connect, it's likely that the compute node was suspended, wake a new compute node
info!("compute node's state has likely changed; requesting a wake-up");
ctx.latency_timer.cache_miss();
let old_node_info = invalidate_cache(node_info);
let mut node_info = wake_compute(&mut num_retries, ctx, user_info).await?;
node_info.reuse_settings(old_node_info);
mechanism.update_connect_config(&mut node_info.config);
node_info
};
// now that we have a new node, try connect to it repeatedly.
// this can error for a few reasons, for instance:
// * DNS connection settings haven't quite propagated yet
info!("wake_compute success. attempting to connect");
num_retries = 1;
loop {
match mechanism
.connect_once(ctx, &node_info, CONNECT_TIMEOUT)

View File

@@ -0,0 +1,256 @@
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use std::future::poll_fn;
use std::io;
use std::pin::Pin;
use std::task::{ready, Context, Poll};
#[derive(Debug)]
enum TransferState {
Running(CopyBuffer),
ShuttingDown(u64),
Done(u64),
}
fn transfer_one_direction<A, B>(
cx: &mut Context<'_>,
state: &mut TransferState,
r: &mut A,
w: &mut B,
) -> Poll<io::Result<u64>>
where
A: AsyncRead + AsyncWrite + Unpin + ?Sized,
B: AsyncRead + AsyncWrite + Unpin + ?Sized,
{
let mut r = Pin::new(r);
let mut w = Pin::new(w);
loop {
match state {
TransferState::Running(buf) => {
let count = ready!(buf.poll_copy(cx, r.as_mut(), w.as_mut()))?;
*state = TransferState::ShuttingDown(count);
}
TransferState::ShuttingDown(count) => {
ready!(w.as_mut().poll_shutdown(cx))?;
*state = TransferState::Done(*count);
}
TransferState::Done(count) => return Poll::Ready(Ok(*count)),
}
}
}
pub(super) async fn copy_bidirectional<A, B>(
a: &mut A,
b: &mut B,
) -> Result<(u64, u64), std::io::Error>
where
A: AsyncRead + AsyncWrite + Unpin + ?Sized,
B: AsyncRead + AsyncWrite + Unpin + ?Sized,
{
let mut a_to_b = TransferState::Running(CopyBuffer::new());
let mut b_to_a = TransferState::Running(CopyBuffer::new());
poll_fn(|cx| {
let mut a_to_b_result = transfer_one_direction(cx, &mut a_to_b, a, b)?;
let mut b_to_a_result = transfer_one_direction(cx, &mut b_to_a, b, a)?;
// Early termination checks
if let TransferState::Done(_) = a_to_b {
if let TransferState::Running(buf) = &b_to_a {
// Initiate shutdown
b_to_a = TransferState::ShuttingDown(buf.amt);
b_to_a_result = transfer_one_direction(cx, &mut b_to_a, b, a)?;
}
}
if let TransferState::Done(_) = b_to_a {
if let TransferState::Running(buf) = &a_to_b {
// Initiate shutdown
a_to_b = TransferState::ShuttingDown(buf.amt);
a_to_b_result = transfer_one_direction(cx, &mut a_to_b, a, b)?;
}
}
// It is not a problem if ready! returns early ... (comment remains the same)
let a_to_b = ready!(a_to_b_result);
let b_to_a = ready!(b_to_a_result);
Poll::Ready(Ok((a_to_b, b_to_a)))
})
.await
}
#[derive(Debug)]
pub(super) struct CopyBuffer {
read_done: bool,
need_flush: bool,
pos: usize,
cap: usize,
amt: u64,
buf: Box<[u8]>,
}
const DEFAULT_BUF_SIZE: usize = 8 * 1024;
impl CopyBuffer {
pub(super) fn new() -> Self {
Self {
read_done: false,
need_flush: false,
pos: 0,
cap: 0,
amt: 0,
buf: vec![0; DEFAULT_BUF_SIZE].into_boxed_slice(),
}
}
fn poll_fill_buf<R>(
&mut self,
cx: &mut Context<'_>,
reader: Pin<&mut R>,
) -> Poll<io::Result<()>>
where
R: AsyncRead + ?Sized,
{
let me = &mut *self;
let mut buf = ReadBuf::new(&mut me.buf);
buf.set_filled(me.cap);
let res = reader.poll_read(cx, &mut buf);
if let Poll::Ready(Ok(())) = res {
let filled_len = buf.filled().len();
me.read_done = me.cap == filled_len;
me.cap = filled_len;
}
res
}
fn poll_write_buf<R, W>(
&mut self,
cx: &mut Context<'_>,
mut reader: Pin<&mut R>,
mut writer: Pin<&mut W>,
) -> Poll<io::Result<usize>>
where
R: AsyncRead + ?Sized,
W: AsyncWrite + ?Sized,
{
let me = &mut *self;
match writer.as_mut().poll_write(cx, &me.buf[me.pos..me.cap]) {
Poll::Pending => {
// Top up the buffer towards full if we can read a bit more
// data - this should improve the chances of a large write
if !me.read_done && me.cap < me.buf.len() {
ready!(me.poll_fill_buf(cx, reader.as_mut()))?;
}
Poll::Pending
}
res => res,
}
}
pub(super) fn poll_copy<R, W>(
&mut self,
cx: &mut Context<'_>,
mut reader: Pin<&mut R>,
mut writer: Pin<&mut W>,
) -> Poll<io::Result<u64>>
where
R: AsyncRead + ?Sized,
W: AsyncWrite + ?Sized,
{
loop {
// If our buffer is empty, then we need to read some data to
// continue.
if self.pos == self.cap && !self.read_done {
self.pos = 0;
self.cap = 0;
match self.poll_fill_buf(cx, reader.as_mut()) {
Poll::Ready(Ok(())) => (),
Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
Poll::Pending => {
// Try flushing when the reader has no progress to avoid deadlock
// when the reader depends on buffered writer.
if self.need_flush {
ready!(writer.as_mut().poll_flush(cx))?;
self.need_flush = false;
}
return Poll::Pending;
}
}
}
// If our buffer has some data, let's write it out!
while self.pos < self.cap {
let i = ready!(self.poll_write_buf(cx, reader.as_mut(), writer.as_mut()))?;
if i == 0 {
return Poll::Ready(Err(io::Error::new(
io::ErrorKind::WriteZero,
"write zero byte into writer",
)));
} else {
self.pos += i;
self.amt += i as u64;
self.need_flush = true;
}
}
// If pos larger than cap, this loop will never stop.
// In particular, user's wrong poll_write implementation returning
// incorrect written length may lead to thread blocking.
debug_assert!(
self.pos <= self.cap,
"writer returned length larger than input slice"
);
// If we've written all the data and we've seen EOF, flush out the
// data and finish the transfer.
if self.pos == self.cap && self.read_done {
ready!(writer.as_mut().poll_flush(cx))?;
return Poll::Ready(Ok(self.amt));
}
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use tokio::io::AsyncWriteExt;
#[tokio::test]
async fn test_early_termination_a_to_d() {
let (mut a_mock, mut b_mock) = tokio::io::duplex(8); // Create a mock duplex stream
let (mut c_mock, mut d_mock) = tokio::io::duplex(32); // Create a mock duplex stream
// Simulate 'a' finishing while there's still data for 'b'
a_mock.write_all(b"hello").await.unwrap();
a_mock.shutdown().await.unwrap();
d_mock.write_all(b"Neon Serverless Postgres").await.unwrap();
let result = copy_bidirectional(&mut b_mock, &mut c_mock).await.unwrap();
// Assert correct transferred amounts
let (a_to_d_count, d_to_a_count) = result;
assert_eq!(a_to_d_count, 5); // 'hello' was transferred
assert!(d_to_a_count <= 8); // response only partially transferred or not at all
}
#[tokio::test]
async fn test_early_termination_d_to_a() {
let (mut a_mock, mut b_mock) = tokio::io::duplex(32); // Create a mock duplex stream
let (mut c_mock, mut d_mock) = tokio::io::duplex(8); // Create a mock duplex stream
// Simulate 'a' finishing while there's still data for 'b'
d_mock.write_all(b"hello").await.unwrap();
d_mock.shutdown().await.unwrap();
a_mock.write_all(b"Neon Serverless Postgres").await.unwrap();
let result = copy_bidirectional(&mut b_mock, &mut c_mock).await.unwrap();
// Assert correct transferred amounts
let (a_to_d_count, d_to_a_count) = result;
assert_eq!(d_to_a_count, 5); // 'hello' was transferred
assert!(a_to_d_count <= 8); // response only partially transferred or not at all
}
}

View File

@@ -1,4 +1,5 @@
use crate::{
cancellation,
compute::PostgresConnection,
console::messages::MetricsAuxInfo,
metrics::NUM_BYTES_PROXIED_COUNTER,
@@ -45,7 +46,7 @@ pub async fn proxy_pass(
// Starting from here we only proxy the client's traffic.
info!("performing the proxy pass...");
let _ = tokio::io::copy_bidirectional(&mut client, &mut compute).await?;
let _ = crate::proxy::copy_bidirectional::copy_bidirectional(&mut client, &mut compute).await?;
Ok(())
}
@@ -57,6 +58,7 @@ pub struct ProxyPassthrough<S> {
pub req: IntCounterPairGuard,
pub conn: IntCounterPairGuard,
pub cancel: cancellation::Session,
}
impl<S: AsyncRead + AsyncWrite + Unpin> ProxyPassthrough<S> {

View File

@@ -2,13 +2,19 @@
mod mitm;
use std::time::Duration;
use super::connect_compute::ConnectMechanism;
use super::retry::ShouldRetry;
use super::*;
use crate::auth::backend::{ComputeUserInfo, MaybeOwned, TestBackend};
use crate::auth::backend::{
ComputeCredentialKeys, ComputeCredentials, ComputeUserInfo, MaybeOwned, TestBackend,
};
use crate::config::CertResolver;
use crate::console::caches::NodeInfoCache;
use crate::console::provider::{CachedAllowedIps, CachedRoleSecret, ConsoleBackend};
use crate::console::{self, CachedNodeInfo, NodeInfo};
use crate::error::ErrorKind;
use crate::proxy::retry::{retry_after, NUM_RETRIES_CONNECT};
use crate::{auth, http, sasl, scram};
use async_trait::async_trait;
@@ -144,7 +150,7 @@ impl TestAuth for Scram {
stream: &mut PqStream<Stream<S>>,
) -> anyhow::Result<()> {
let outcome = auth::AuthFlow::new(stream)
.begin(auth::Scram(&self.0))
.begin(auth::Scram(&self.0, &mut RequestMonitoring::test()))
.await?
.authenticate()
.await?;
@@ -375,6 +381,7 @@ enum ConnectAction {
struct TestConnectMechanism {
counter: Arc<std::sync::Mutex<usize>>,
sequence: Vec<ConnectAction>,
cache: &'static NodeInfoCache,
}
impl TestConnectMechanism {
@@ -393,6 +400,12 @@ impl TestConnectMechanism {
Self {
counter: Arc::new(std::sync::Mutex::new(0)),
sequence,
cache: Box::leak(Box::new(NodeInfoCache::new(
"test",
1,
Duration::from_secs(100),
false,
))),
}
}
}
@@ -403,6 +416,13 @@ struct TestConnection;
#[derive(Debug)]
struct TestConnectError {
retryable: bool,
kind: crate::error::ErrorKind,
}
impl ReportableError for TestConnectError {
fn get_error_kind(&self) -> crate::error::ErrorKind {
self.kind
}
}
impl std::fmt::Display for TestConnectError {
@@ -436,8 +456,14 @@ impl ConnectMechanism for TestConnectMechanism {
*counter += 1;
match action {
ConnectAction::Connect => Ok(TestConnection),
ConnectAction::Retry => Err(TestConnectError { retryable: true }),
ConnectAction::Fail => Err(TestConnectError { retryable: false }),
ConnectAction::Retry => Err(TestConnectError {
retryable: true,
kind: ErrorKind::Compute,
}),
ConnectAction::Fail => Err(TestConnectError {
retryable: false,
kind: ErrorKind::Compute,
}),
x => panic!("expecting action {:?}, connect is called instead", x),
}
}
@@ -451,7 +477,7 @@ impl TestBackend for TestConnectMechanism {
let action = self.sequence[*counter];
*counter += 1;
match action {
ConnectAction::Wake => Ok(helper_create_cached_node_info()),
ConnectAction::Wake => Ok(helper_create_cached_node_info(self.cache)),
ConnectAction::WakeFail => {
let err = console::errors::ApiError::Console {
status: http::StatusCode::FORBIDDEN,
@@ -483,37 +509,41 @@ impl TestBackend for TestConnectMechanism {
}
}
fn helper_create_cached_node_info() -> CachedNodeInfo {
fn helper_create_cached_node_info(cache: &'static NodeInfoCache) -> CachedNodeInfo {
let node = NodeInfo {
config: compute::ConnCfg::new(),
aux: Default::default(),
allow_self_signed_compute: false,
};
CachedNodeInfo::new_uncached(node)
let (_, node) = cache.insert("key".into(), node);
node
}
fn helper_create_connect_info(
mechanism: &TestConnectMechanism,
) -> (CachedNodeInfo, auth::BackendType<'static, ComputeUserInfo>) {
let cache = helper_create_cached_node_info();
) -> auth::BackendType<'static, ComputeCredentials, &()> {
let user_info = auth::BackendType::Console(
MaybeOwned::Owned(ConsoleBackend::Test(Box::new(mechanism.clone()))),
ComputeUserInfo {
endpoint: "endpoint".into(),
user: "user".into(),
options: NeonOptions::parse_options_raw(""),
ComputeCredentials {
info: ComputeUserInfo {
endpoint: "endpoint".into(),
user: "user".into(),
options: NeonOptions::parse_options_raw(""),
},
keys: ComputeCredentialKeys::Password("password".into()),
},
);
(cache, user_info)
user_info
}
#[tokio::test]
async fn connect_to_compute_success() {
let _ = env_logger::try_init();
use ConnectAction::*;
let mut ctx = RequestMonitoring::test();
let mechanism = TestConnectMechanism::new(vec![Connect]);
let (cache, user_info) = helper_create_connect_info(&mechanism);
connect_to_compute(&mut ctx, &mechanism, cache, &user_info)
let mechanism = TestConnectMechanism::new(vec![Wake, Connect]);
let user_info = helper_create_connect_info(&mechanism);
connect_to_compute(&mut ctx, &mechanism, &user_info, false)
.await
.unwrap();
mechanism.verify();
@@ -521,11 +551,12 @@ async fn connect_to_compute_success() {
#[tokio::test]
async fn connect_to_compute_retry() {
let _ = env_logger::try_init();
use ConnectAction::*;
let mut ctx = RequestMonitoring::test();
let mechanism = TestConnectMechanism::new(vec![Retry, Wake, Retry, Connect]);
let (cache, user_info) = helper_create_connect_info(&mechanism);
connect_to_compute(&mut ctx, &mechanism, cache, &user_info)
let mechanism = TestConnectMechanism::new(vec![Wake, Retry, Wake, Connect]);
let user_info = helper_create_connect_info(&mechanism);
connect_to_compute(&mut ctx, &mechanism, &user_info, false)
.await
.unwrap();
mechanism.verify();
@@ -534,11 +565,12 @@ async fn connect_to_compute_retry() {
/// Test that we don't retry if the error is not retryable.
#[tokio::test]
async fn connect_to_compute_non_retry_1() {
let _ = env_logger::try_init();
use ConnectAction::*;
let mut ctx = RequestMonitoring::test();
let mechanism = TestConnectMechanism::new(vec![Retry, Wake, Retry, Fail]);
let (cache, user_info) = helper_create_connect_info(&mechanism);
connect_to_compute(&mut ctx, &mechanism, cache, &user_info)
let mechanism = TestConnectMechanism::new(vec![Wake, Retry, Wake, Fail]);
let user_info = helper_create_connect_info(&mechanism);
connect_to_compute(&mut ctx, &mechanism, &user_info, false)
.await
.unwrap_err();
mechanism.verify();
@@ -547,11 +579,12 @@ async fn connect_to_compute_non_retry_1() {
/// Even for non-retryable errors, we should retry at least once.
#[tokio::test]
async fn connect_to_compute_non_retry_2() {
let _ = env_logger::try_init();
use ConnectAction::*;
let mut ctx = RequestMonitoring::test();
let mechanism = TestConnectMechanism::new(vec![Fail, Wake, Retry, Connect]);
let (cache, user_info) = helper_create_connect_info(&mechanism);
connect_to_compute(&mut ctx, &mechanism, cache, &user_info)
let mechanism = TestConnectMechanism::new(vec![Wake, Fail, Wake, Connect]);
let user_info = helper_create_connect_info(&mechanism);
connect_to_compute(&mut ctx, &mechanism, &user_info, false)
.await
.unwrap();
mechanism.verify();
@@ -560,15 +593,16 @@ async fn connect_to_compute_non_retry_2() {
/// Retry for at most `NUM_RETRIES_CONNECT` times.
#[tokio::test]
async fn connect_to_compute_non_retry_3() {
let _ = env_logger::try_init();
assert_eq!(NUM_RETRIES_CONNECT, 16);
use ConnectAction::*;
let mut ctx = RequestMonitoring::test();
let mechanism = TestConnectMechanism::new(vec![
Retry, Wake, Retry, Retry, Retry, Retry, Retry, Retry, Retry, Retry, Retry, Retry, Retry,
Retry, Retry, Retry, Retry, /* the 17th time */ Retry,
Wake, Retry, Wake, Retry, Retry, Retry, Retry, Retry, Retry, Retry, Retry, Retry, Retry,
Retry, Retry, Retry, Retry, Retry, /* the 17th time */ Retry,
]);
let (cache, user_info) = helper_create_connect_info(&mechanism);
connect_to_compute(&mut ctx, &mechanism, cache, &user_info)
let user_info = helper_create_connect_info(&mechanism);
connect_to_compute(&mut ctx, &mechanism, &user_info, false)
.await
.unwrap_err();
mechanism.verify();
@@ -577,11 +611,12 @@ async fn connect_to_compute_non_retry_3() {
/// Should retry wake compute.
#[tokio::test]
async fn wake_retry() {
let _ = env_logger::try_init();
use ConnectAction::*;
let mut ctx = RequestMonitoring::test();
let mechanism = TestConnectMechanism::new(vec![Retry, WakeRetry, Wake, Connect]);
let (cache, user_info) = helper_create_connect_info(&mechanism);
connect_to_compute(&mut ctx, &mechanism, cache, &user_info)
let mechanism = TestConnectMechanism::new(vec![WakeRetry, Wake, Connect]);
let user_info = helper_create_connect_info(&mechanism);
connect_to_compute(&mut ctx, &mechanism, &user_info, false)
.await
.unwrap();
mechanism.verify();
@@ -590,11 +625,12 @@ async fn wake_retry() {
/// Wake failed with a non-retryable error.
#[tokio::test]
async fn wake_non_retry() {
let _ = env_logger::try_init();
use ConnectAction::*;
let mut ctx = RequestMonitoring::test();
let mechanism = TestConnectMechanism::new(vec![Retry, WakeFail]);
let (cache, user_info) = helper_create_connect_info(&mechanism);
connect_to_compute(&mut ctx, &mechanism, cache, &user_info)
let mechanism = TestConnectMechanism::new(vec![WakeRetry, WakeFail]);
let user_info = helper_create_connect_info(&mechanism);
connect_to_compute(&mut ctx, &mechanism, &user_info, false)
.await
.unwrap_err();
mechanism.verify();
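The reworked tests drive connect_to_compute from credentials instead of a pre-woken node, which is why every sequence now starts with a Wake. The following is a self-contained sketch of the control flow those sequences encode, using hypothetical stand-in types rather than the crate's own (the real implementation lives in proxy::connect_compute):

const NUM_RETRIES_CONNECT: u32 = 16;

#[derive(Clone, Copy)]
enum WakeOutcome { Woken, Retryable, Failed }
#[derive(Clone, Copy)]
enum ConnectOutcome { Connected, Retryable, Failed }

// Wake the compute, retrying transient console errors (WakeRetry in the tests above).
fn wake(try_wake: &mut impl FnMut() -> WakeOutcome) -> Result<(), &'static str> {
    loop {
        match try_wake() {
            WakeOutcome::Woken => return Ok(()),
            WakeOutcome::Retryable => continue,
            WakeOutcome::Failed => return Err("wake failed"),
        }
    }
}

fn connect_to_compute(
    mut try_wake: impl FnMut() -> WakeOutcome,
    mut try_connect: impl FnMut() -> ConnectOutcome,
) -> Result<(), &'static str> {
    // the compute is always woken (or fetched from the cache) before the first attempt
    wake(&mut try_wake)?;
    if let ConnectOutcome::Connected = try_connect() {
        return Ok(());
    }
    // the first failed attempt is retried regardless of its kind: the cached node info
    // may simply be stale, so wake again and keep trying
    wake(&mut try_wake)?;
    let mut retries = 1u32;
    loop {
        match try_connect() {
            ConnectOutcome::Connected => return Ok(()),
            ConnectOutcome::Retryable if retries < NUM_RETRIES_CONNECT => retries += 1,
            _ => return Err("could not connect to compute"),
        }
    }
}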

View File

@@ -1,9 +1,4 @@
use crate::auth::backend::ComputeUserInfo;
use crate::console::{
errors::WakeComputeError,
provider::{CachedNodeInfo, ConsoleBackend},
Api,
};
use crate::console::{errors::WakeComputeError, provider::CachedNodeInfo};
use crate::context::RequestMonitoring;
use crate::metrics::{bool_to_str, NUM_WAKEUP_FAILURES};
use crate::proxy::retry::retry_after;
@@ -11,17 +6,16 @@ use hyper::StatusCode;
use std::ops::ControlFlow;
use tracing::{error, warn};
use super::connect_compute::ComputeConnectBackend;
use super::retry::ShouldRetry;
/// wake a compute (or retrieve an existing compute session from cache)
pub async fn wake_compute(
pub async fn wake_compute<B: ComputeConnectBackend>(
num_retries: &mut u32,
ctx: &mut RequestMonitoring,
api: &ConsoleBackend,
info: &ComputeUserInfo,
api: &B,
) -> Result<CachedNodeInfo, WakeComputeError> {
loop {
let wake_res = api.wake_compute(ctx, info).await;
let wake_res = api.wake_compute(ctx).await;
match handle_try_wake(wake_res, *num_retries) {
Err(e) => {
error!(error = ?e, num_retries, retriable = false, "couldn't wake compute node");
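For orientation, the backend abstraction wake_compute is now generic over can be pictured as the trait below. This is a simplified assumption for illustration, not the crate's exact definition: the point is that ComputeConnectBackend already carries the credentials the removed ComputeUserInfo argument used to provide.

use crate::console::{errors::WakeComputeError, provider::CachedNodeInfo};
use crate::context::RequestMonitoring;

// Assumed, simplified shape: the backend already knows which endpoint and role it is
// acting for, so wake_compute only needs the request context.
#[async_trait::async_trait]
trait ComputeConnectBackend {
    async fn wake_compute(
        &self,
        ctx: &mut RequestMonitoring,
    ) -> Result<CachedNodeInfo, WakeComputeError>;
}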

View File

@@ -4,4 +4,4 @@ mod limiter;
pub use aimd::Aimd;
pub use limit_algorithm::{AimdConfig, Fixed, RateLimitAlgorithm, RateLimiterConfig};
pub use limiter::Limiter;
pub use limiter::{EndpointRateLimiter, RateBucketInfo};
pub use limiter::{EndpointRateLimiter, RateBucketInfo, RedisRateLimiter};

View File

@@ -22,6 +22,44 @@ use super::{
RateLimiterConfig,
};
pub struct RedisRateLimiter {
data: Vec<RateBucket>,
info: &'static [RateBucketInfo],
}
impl RedisRateLimiter {
pub fn new(info: &'static [RateBucketInfo]) -> Self {
Self {
data: vec![
RateBucket {
start: Instant::now(),
count: 0,
};
info.len()
],
info,
}
}
/// Check that the request rate stays within the configured `RateBucketInfo` limits.
pub fn check(&mut self) -> bool {
let now = Instant::now();
let should_allow_request = self
.data
.iter_mut()
.zip(self.info)
.all(|(bucket, info)| bucket.should_allow_request(info, now));
if should_allow_request {
// only increment the bucket counts if the request will actually be accepted
self.data.iter_mut().for_each(RateBucket::inc);
}
should_allow_request
}
}
// Simple per-endpoint rate limiter.
//
// Check that the number of connections to the endpoint is below `max_rps` rps.
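The bucket logic both limiters share can be pictured with this generic, self-contained sketch (illustrative only; the Window struct below is not one of the crate's types): every window must be under its limit, and counters are only bumped for admitted requests.

use std::time::{Duration, Instant};

// A fixed window of a given length with a cap on how many requests it admits.
struct Window { len: Duration, max: u64, start: Instant, count: u64 }

impl Window {
    fn allows(&mut self, now: Instant) -> bool {
        if now.duration_since(self.start) >= self.len {
            // the window has elapsed: start a new one
            self.start = now;
            self.count = 0;
        }
        self.count < self.max
    }
}

// Admit a request only if every window agrees, then charge all of them.
fn check(windows: &mut [Window]) -> bool {
    let now = Instant::now();
    let allowed = windows.iter_mut().all(|w| w.allows(now));
    if allowed {
        for w in windows {
            w.count += 1;
        }
    }
    allowed
}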

View File

@@ -1 +1,2 @@
pub mod notifications;
pub mod publisher;

View File

@@ -1,38 +1,44 @@
use std::{convert::Infallible, sync::Arc};
use futures::StreamExt;
use pq_proto::CancelKeyData;
use redis::aio::PubSub;
use serde::Deserialize;
use serde::{Deserialize, Serialize};
use uuid::Uuid;
use crate::{
cache::project_info::ProjectInfoCache,
cancellation::{CancelMap, CancellationHandler, NotificationsCancellationHandler},
intern::{ProjectIdInt, RoleNameInt},
};
const CHANNEL_NAME: &str = "neondb-proxy-ws-updates";
const CPLANE_CHANNEL_NAME: &str = "neondb-proxy-ws-updates";
pub(crate) const PROXY_CHANNEL_NAME: &str = "neondb-proxy-to-proxy-updates";
const RECONNECT_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(20);
const INVALIDATION_LAG: std::time::Duration = std::time::Duration::from_secs(20);
struct ConsoleRedisClient {
struct RedisConsumerClient {
client: redis::Client,
}
impl ConsoleRedisClient {
impl RedisConsumerClient {
pub fn new(url: &str) -> anyhow::Result<Self> {
let client = redis::Client::open(url)?;
Ok(Self { client })
}
async fn try_connect(&self) -> anyhow::Result<PubSub> {
let mut conn = self.client.get_async_connection().await?.into_pubsub();
tracing::info!("subscribing to a channel `{CHANNEL_NAME}`");
conn.subscribe(CHANNEL_NAME).await?;
tracing::info!("subscribing to a channel `{CPLANE_CHANNEL_NAME}`");
conn.subscribe(CPLANE_CHANNEL_NAME).await?;
tracing::info!("subscribing to a channel `{PROXY_CHANNEL_NAME}`");
conn.subscribe(PROXY_CHANNEL_NAME).await?;
Ok(conn)
}
}
#[derive(Clone, Debug, Deserialize, Eq, PartialEq)]
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
#[serde(tag = "topic", content = "data")]
enum Notification {
pub(crate) enum Notification {
#[serde(
rename = "/allowed_ips_updated",
deserialize_with = "deserialize_json_string"
@@ -45,16 +51,25 @@ enum Notification {
deserialize_with = "deserialize_json_string"
)]
PasswordUpdate { password_update: PasswordUpdate },
#[serde(rename = "/cancel_session")]
Cancel(CancelSession),
}
#[derive(Clone, Debug, Deserialize, Eq, PartialEq)]
struct AllowedIpsUpdate {
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
pub(crate) struct AllowedIpsUpdate {
project_id: ProjectIdInt,
}
#[derive(Clone, Debug, Deserialize, Eq, PartialEq)]
struct PasswordUpdate {
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
pub(crate) struct PasswordUpdate {
project_id: ProjectIdInt,
role_name: RoleNameInt,
}
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
pub(crate) struct CancelSession {
pub region_id: Option<String>,
pub cancel_key_data: CancelKeyData,
pub session_id: Uuid,
}
fn deserialize_json_string<'de, D, T>(deserializer: D) -> Result<T, D::Error>
where
T: for<'de2> serde::Deserialize<'de2>,
@@ -64,6 +79,88 @@ where
serde_json::from_str(&s).map_err(<D::Error as serde::de::Error>::custom)
}
struct MessageHandler<
C: ProjectInfoCache + Send + Sync + 'static,
H: NotificationsCancellationHandler + Send + Sync + 'static,
> {
cache: Arc<C>,
cancellation_handler: Arc<H>,
region_id: String,
}
impl<
C: ProjectInfoCache + Send + Sync + 'static,
H: NotificationsCancellationHandler + Send + Sync + 'static,
> MessageHandler<C, H>
{
pub fn new(cache: Arc<C>, cancellation_handler: Arc<H>, region_id: String) -> Self {
Self {
cache,
cancellation_handler,
region_id,
}
}
pub fn disable_ttl(&self) {
self.cache.disable_ttl();
}
pub fn enable_ttl(&self) {
self.cache.enable_ttl();
}
#[tracing::instrument(skip(self, msg), fields(session_id = tracing::field::Empty))]
async fn handle_message(&self, msg: redis::Msg) -> anyhow::Result<()> {
use Notification::*;
let payload: String = msg.get_payload()?;
tracing::debug!(?payload, "received a message payload");
let msg: Notification = match serde_json::from_str(&payload) {
Ok(msg) => msg,
Err(e) => {
tracing::error!("broken message: {e}");
return Ok(());
}
};
tracing::debug!(?msg, "received a message");
match msg {
Cancel(cancel_session) => {
tracing::Span::current().record(
"session_id",
&tracing::field::display(cancel_session.session_id),
);
if let Some(cancel_region) = cancel_session.region_id {
// If the message is not for this region, ignore it.
if cancel_region != self.region_id {
return Ok(());
}
}
// This instance of cancellation_handler doesn't have a RedisPublisherClient, so it can't publish the message.
match self
.cancellation_handler
.cancel_session_no_publish(cancel_session.cancel_key_data)
.await
{
Ok(()) => {}
Err(e) => {
tracing::error!("failed to cancel session: {e}");
}
}
}
_ => {
invalidate_cache(self.cache.clone(), msg.clone());
// The stale entry may still be on its way into the cache when this message arrives.
// To make sure it ends up invalidated, repeat the invalidation after INVALIDATION_LAG.
// TODO: include the version (or the timestamp) in the message and invalidate only if the entry is cached before the message.
let cache = self.cache.clone();
tokio::spawn(async move {
tokio::time::sleep(INVALIDATION_LAG).await;
invalidate_cache(cache, msg);
});
}
}
Ok(())
}
}
fn invalidate_cache<C: ProjectInfoCache>(cache: Arc<C>, msg: Notification) {
use Notification::*;
match msg {
@@ -74,50 +171,33 @@ fn invalidate_cache<C: ProjectInfoCache>(cache: Arc<C>, msg: Notification) {
password_update.project_id,
password_update.role_name,
),
Cancel(_) => unreachable!("cancel message should be handled separately"),
}
}
#[tracing::instrument(skip(cache))]
fn handle_message<C>(msg: redis::Msg, cache: Arc<C>) -> anyhow::Result<()>
where
C: ProjectInfoCache + Send + Sync + 'static,
{
let payload: String = msg.get_payload()?;
tracing::debug!(?payload, "received a message payload");
let msg: Notification = match serde_json::from_str(&payload) {
Ok(msg) => msg,
Err(e) => {
tracing::error!("broken message: {e}");
return Ok(());
}
};
tracing::debug!(?msg, "received a message");
invalidate_cache(cache.clone(), msg.clone());
// It might happen that the invalid entry is on the way to be cached.
// To make sure that the entry is invalidated, let's repeat the invalidation in INVALIDATION_LAG seconds.
// TODO: include the version (or the timestamp) in the message and invalidate only if the entry is cached before the message.
tokio::spawn(async move {
tokio::time::sleep(INVALIDATION_LAG).await;
invalidate_cache(cache, msg.clone());
});
Ok(())
}
/// Handle console's invalidation messages.
#[tracing::instrument(name = "console_notifications", skip_all)]
pub async fn task_main<C>(url: String, cache: Arc<C>) -> anyhow::Result<Infallible>
pub async fn task_main<C>(
url: String,
cache: Arc<C>,
cancel_map: CancelMap,
region_id: String,
) -> anyhow::Result<Infallible>
where
C: ProjectInfoCache + Send + Sync + 'static,
{
cache.enable_ttl();
let handler = MessageHandler::new(
cache,
Arc::new(CancellationHandler::new(cancel_map, None)),
region_id,
);
loop {
let redis = ConsoleRedisClient::new(&url)?;
let redis = RedisConsumerClient::new(&url)?;
let conn = match redis.try_connect().await {
Ok(conn) => {
cache.disable_ttl();
handler.disable_ttl();
conn
}
Err(e) => {
@@ -130,7 +210,7 @@ where
};
let mut stream = conn.into_on_message();
while let Some(msg) = stream.next().await {
match handle_message(msg, cache.clone()) {
match handler.handle_message(msg).await {
Ok(()) => {}
Err(e) => {
tracing::error!("failed to handle message: {e}, will try to reconnect");
@@ -138,7 +218,7 @@ where
}
}
}
cache.enable_ttl();
handler.enable_ttl();
}
}
@@ -198,6 +278,33 @@ mod tests {
}
);
Ok(())
}
#[test]
fn parse_cancel_session() -> anyhow::Result<()> {
let cancel_key_data = CancelKeyData {
backend_pid: 42,
cancel_key: 41,
};
let uuid = uuid::Uuid::new_v4();
let msg = Notification::Cancel(CancelSession {
cancel_key_data,
region_id: None,
session_id: uuid,
});
let text = serde_json::to_string(&msg)?;
let result: Notification = serde_json::from_str(&text)?;
assert_eq!(msg, result);
let msg = Notification::Cancel(CancelSession {
cancel_key_data,
region_id: Some("region".to_string()),
session_id: uuid,
});
let text = serde_json::to_string(&msg)?;
let result: Notification = serde_json::from_str(&text)?;
assert_eq!(msg, result);
Ok(())
}
}
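Given the adjacent tagging above (tag = "topic", content = "data"), a cancellation message on the proxy-to-proxy channel should look roughly like {"topic":"/cancel_session","data":{"region_id":...,"cancel_key_data":{"backend_pid":...,"cancel_key":...},"session_id":"..."}}. A small hedged sketch that prints one (the field values are made up):

fn print_cancel_message() -> serde_json::Result<()> {
    let msg = Notification::Cancel(CancelSession {
        region_id: Some("us-east-2".to_string()), // hypothetical region name
        cancel_key_data: CancelKeyData { backend_pid: 42, cancel_key: 41 },
        session_id: uuid::Uuid::new_v4(),
    });
    // exact formatting is up to serde_json; the shape follows the serde attributes above
    println!("{}", serde_json::to_string(&msg)?);
    Ok(())
}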

View File

@@ -0,0 +1,80 @@
use pq_proto::CancelKeyData;
use redis::AsyncCommands;
use uuid::Uuid;
use crate::rate_limiter::{RateBucketInfo, RedisRateLimiter};
use super::notifications::{CancelSession, Notification, PROXY_CHANNEL_NAME};
pub struct RedisPublisherClient {
client: redis::Client,
publisher: Option<redis::aio::Connection>,
region_id: String,
limiter: RedisRateLimiter,
}
impl RedisPublisherClient {
pub fn new(
url: &str,
region_id: String,
info: &'static [RateBucketInfo],
) -> anyhow::Result<Self> {
let client = redis::Client::open(url)?;
Ok(Self {
client,
publisher: None,
region_id,
limiter: RedisRateLimiter::new(info),
})
}
pub async fn try_publish(
&mut self,
cancel_key_data: CancelKeyData,
session_id: Uuid,
) -> anyhow::Result<()> {
if !self.limiter.check() {
tracing::info!("Rate limit exceeded. Skipping cancellation message");
return Err(anyhow::anyhow!("Rate limit exceeded"));
}
match self.publish(cancel_key_data, session_id).await {
Ok(()) => return Ok(()),
Err(e) => {
tracing::error!("failed to publish a message: {e}");
self.publisher = None;
}
}
tracing::info!("Publisher is disconnected. Reconnectiong...");
self.try_connect().await?;
self.publish(cancel_key_data, session_id).await
}
async fn publish(
&mut self,
cancel_key_data: CancelKeyData,
session_id: Uuid,
) -> anyhow::Result<()> {
let conn = self
.publisher
.as_mut()
.ok_or_else(|| anyhow::anyhow!("not connected"))?;
let payload = serde_json::to_string(&Notification::Cancel(CancelSession {
region_id: Some(self.region_id.clone()),
cancel_key_data,
session_id,
}))?;
conn.publish(PROXY_CHANNEL_NAME, payload).await?;
Ok(())
}
pub async fn try_connect(&mut self) -> anyhow::Result<()> {
match self.client.get_async_connection().await {
Ok(conn) => {
self.publisher = Some(conn);
}
Err(e) => {
tracing::error!("failed to connect to redis: {e}");
return Err(e.into());
}
}
Ok(())
}
}
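A hedged usage sketch of the publisher; the Redis URL and region name are placeholders, and the empty RateBucketInfo slice makes the limiter a no-op, which only makes sense for illustration:

// try_publish reconnects by itself after a failed publish, so the up-front
// try_connect is optional.
async fn publish_cancel_example() -> anyhow::Result<()> {
    static NO_LIMITS: &[RateBucketInfo] = &[];
    let mut publisher =
        RedisPublisherClient::new("redis://127.0.0.1:6379", "us-east-2".into(), NO_LIMITS)?;
    publisher.try_connect().await?;
    publisher
        .try_publish(CancelKeyData { backend_pid: 42, cancel_key: 41 }, Uuid::new_v4())
        .await
}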

View File

@@ -24,7 +24,7 @@ use crate::metrics::NUM_CLIENT_CONNECTION_GAUGE;
use crate::protocol2::{ProxyProtocolAccept, WithClientIp};
use crate::rate_limiter::EndpointRateLimiter;
use crate::serverless::backend::PoolingBackend;
use crate::{cancellation::CancelMap, config::ProxyConfig};
use crate::{cancellation::CancellationHandler, config::ProxyConfig};
use futures::StreamExt;
use hyper::{
server::{
@@ -50,6 +50,7 @@ pub async fn task_main(
ws_listener: TcpListener,
cancellation_token: CancellationToken,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
cancellation_handler: Arc<CancellationHandler>,
) -> anyhow::Result<()> {
scopeguard::defer! {
info!("websocket server has shut down");
@@ -115,7 +116,7 @@ pub async fn task_main(
let backend = backend.clone();
let ws_connections = ws_connections.clone();
let endpoint_rate_limiter = endpoint_rate_limiter.clone();
let cancellation_handler = cancellation_handler.clone();
async move {
let peer_addr = match client_addr {
Some(addr) => addr,
@@ -127,9 +128,9 @@ pub async fn task_main(
let backend = backend.clone();
let ws_connections = ws_connections.clone();
let endpoint_rate_limiter = endpoint_rate_limiter.clone();
let cancellation_handler = cancellation_handler.clone();
async move {
let cancel_map = Arc::new(CancelMap::default());
let session_id = uuid::Uuid::new_v4();
request_handler(
@@ -137,7 +138,7 @@ pub async fn task_main(
config,
backend,
ws_connections,
cancel_map,
cancellation_handler,
session_id,
peer_addr.ip(),
endpoint_rate_limiter,
@@ -205,7 +206,7 @@ async fn request_handler(
config: &'static ProxyConfig,
backend: Arc<PoolingBackend>,
ws_connections: TaskTracker,
cancel_map: Arc<CancelMap>,
cancellation_handler: Arc<CancellationHandler>,
session_id: uuid::Uuid,
peer_addr: IpAddr,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
@@ -232,7 +233,7 @@ async fn request_handler(
config,
ctx,
websocket,
cancel_map,
cancellation_handler,
host,
endpoint_rate_limiter,
)

View File

@@ -1,10 +1,10 @@
use std::{sync::Arc, time::Duration};
use async_trait::async_trait;
use tracing::info;
use tracing::{field::display, info};
use crate::{
auth::{backend::ComputeCredentialKeys, check_peer_addr_is_in_list, AuthError},
auth::{backend::ComputeCredentials, check_peer_addr_is_in_list, AuthError},
compute,
config::ProxyConfig,
console::{
@@ -15,7 +15,7 @@ use crate::{
proxy::connect_compute::ConnectMechanism,
};
use super::conn_pool::{poll_client, Client, ConnInfo, GlobalConnPool, APP_NAME};
use super::conn_pool::{poll_client, Client, ConnInfo, GlobalConnPool};
pub struct PoolingBackend {
pub pool: Arc<GlobalConnPool<tokio_postgres::Client>>,
@@ -27,7 +27,7 @@ impl PoolingBackend {
&self,
ctx: &mut RequestMonitoring,
conn_info: &ConnInfo,
) -> Result<ComputeCredentialKeys, AuthError> {
) -> Result<ComputeCredentials, AuthError> {
let user_info = conn_info.user_info.clone();
let backend = self.config.auth_backend.as_ref().map(|_| user_info.clone());
let (allowed_ips, maybe_secret) = backend.get_allowed_ips_and_secret(ctx).await?;
@@ -49,13 +49,17 @@ impl PoolingBackend {
};
let auth_outcome =
crate::auth::validate_password_and_exchange(&conn_info.password, secret)?;
match auth_outcome {
let res = match auth_outcome {
crate::sasl::Outcome::Success(key) => Ok(key),
crate::sasl::Outcome::Failure(reason) => {
info!("auth backend failed with an error: {reason}");
Err(AuthError::auth_failed(&*conn_info.user_info.user))
}
}
};
res.map(|key| ComputeCredentials {
info: user_info,
keys: key,
})
}
// Wake up the destination if needed. Code here is a bit involved because
@@ -66,7 +70,7 @@ impl PoolingBackend {
&self,
ctx: &mut RequestMonitoring,
conn_info: ConnInfo,
keys: ComputeCredentialKeys,
keys: ComputeCredentials,
force_new: bool,
) -> Result<Client<tokio_postgres::Client>, HttpConnError> {
let maybe_client = if !force_new {
@@ -81,27 +85,9 @@ impl PoolingBackend {
return Ok(client);
}
let conn_id = uuid::Uuid::new_v4();
tracing::Span::current().record("conn_id", display(conn_id));
info!(%conn_id, "pool: opening a new connection '{conn_info}'");
ctx.set_application(Some(APP_NAME));
let backend = self
.config
.auth_backend
.as_ref()
.map(|_| conn_info.user_info.clone());
let mut node_info = backend
.wake_compute(ctx)
.await?
.ok_or(HttpConnError::NoComputeInfo)?;
match keys {
#[cfg(any(test, feature = "testing"))]
ComputeCredentialKeys::Password(password) => node_info.config.password(password),
ComputeCredentialKeys::AuthKeys(auth_keys) => node_info.config.auth_keys(auth_keys),
};
ctx.set_project(node_info.aux.clone());
let backend = self.config.auth_backend.as_ref().map(|_| keys);
crate::proxy::connect_compute::connect_to_compute(
ctx,
&TokioMechanism {
@@ -109,8 +95,8 @@ impl PoolingBackend {
conn_info,
pool: self.pool.clone(),
},
node_info,
&backend,
false, // do not allow self signed compute for http flow
)
.await
}
@@ -129,8 +115,6 @@ pub enum HttpConnError {
AuthError(#[from] AuthError),
#[error("wake_compute returned error")]
WakeCompute(#[from] WakeComputeError),
#[error("wake_compute returned nothing")]
NoComputeInfo,
}
struct TokioMechanism {

View File

@@ -4,7 +4,6 @@ use metrics::IntCounterPairGuard;
use parking_lot::RwLock;
use rand::Rng;
use smallvec::SmallVec;
use smol_str::SmolStr;
use std::{collections::HashMap, pin::pin, sync::Arc, sync::Weak, time::Duration};
use std::{
fmt,
@@ -31,8 +30,6 @@ use tracing::{info, info_span, Instrument};
use super::backend::HttpConnError;
pub const APP_NAME: SmolStr = SmolStr::new_inline("/sql_over_http");
#[derive(Debug, Clone)]
pub struct ConnInfo {
pub user_info: ComputeUserInfo,
@@ -379,12 +376,13 @@ impl<C: ClientInnerExt> GlobalConnPool<C> {
info!("pool: cached connection '{conn_info}' is closed, opening a new one");
return Ok(None);
} else {
info!("pool: reusing connection '{conn_info}'");
client.session.send(ctx.session_id)?;
tracing::Span::current().record("conn_id", tracing::field::display(client.conn_id));
tracing::Span::current().record(
"pid",
&tracing::field::display(client.inner.get_process_id()),
);
info!("pool: reusing connection '{conn_info}'");
client.session.send(ctx.session_id)?;
ctx.latency_timer.pool_hit();
ctx.latency_timer.success();
return Ok(Some(Client::new(client, conn_info.clone(), endpoint_pool)));
@@ -577,7 +575,6 @@ pub struct Client<C: ClientInnerExt> {
}
pub struct Discard<'a, C: ClientInnerExt> {
conn_id: uuid::Uuid,
conn_info: &'a ConnInfo,
pool: &'a mut Weak<RwLock<EndpointConnPool<C>>>,
}
@@ -603,14 +600,7 @@ impl<C: ClientInnerExt> Client<C> {
span: _,
} = self;
let inner = inner.as_mut().expect("client inner should not be removed");
(
&mut inner.inner,
Discard {
pool,
conn_info,
conn_id: inner.conn_id,
},
)
(&mut inner.inner, Discard { pool, conn_info })
}
pub fn check_idle(&mut self, status: ReadyForQueryStatus) {
@@ -625,13 +615,13 @@ impl<C: ClientInnerExt> Discard<'_, C> {
pub fn check_idle(&mut self, status: ReadyForQueryStatus) {
let conn_info = &self.conn_info;
if status != ReadyForQueryStatus::Idle && std::mem::take(self.pool).strong_count() > 0 {
info!(conn_id = %self.conn_id, "pool: throwing away connection '{conn_info}' because connection is not idle")
info!("pool: throwing away connection '{conn_info}' because connection is not idle")
}
}
pub fn discard(&mut self) {
let conn_info = &self.conn_info;
if std::mem::take(self.pool).strong_count() > 0 {
info!(conn_id = %self.conn_id, "pool: throwing away connection '{conn_info}' because connection is potentially in a broken state")
info!("pool: throwing away connection '{conn_info}' because connection is potentially in a broken state")
}
}
}

View File

@@ -36,6 +36,8 @@ use crate::error::ReportableError;
use crate::metrics::HTTP_CONTENT_LENGTH;
use crate::metrics::NUM_CONNECTION_REQUESTS_GAUGE;
use crate::proxy::NeonOptions;
use crate::serverless::backend::HttpConnError;
use crate::DbName;
use crate::RoleName;
use super::backend::PoolingBackend;
@@ -117,6 +119,9 @@ fn get_conn_info(
headers: &HeaderMap,
tls: &TlsConfig,
) -> Result<ConnInfo, ConnInfoError> {
// HTTP only uses cleartext (for now and likely always)
ctx.set_auth_method(crate::context::AuthMethod::Cleartext);
let connection_string = headers
.get("Neon-Connection-String")
.ok_or(ConnInfoError::InvalidHeader("Neon-Connection-String"))?
@@ -134,7 +139,8 @@ fn get_conn_info(
.path_segments()
.ok_or(ConnInfoError::MissingDbName)?;
let dbname = url_path.next().ok_or(ConnInfoError::InvalidDbName)?;
let dbname: DbName = url_path.next().ok_or(ConnInfoError::InvalidDbName)?.into();
ctx.set_dbname(dbname.clone());
let username = RoleName::from(urlencoding::decode(connection_url.username())?);
if username.is_empty() {
@@ -174,7 +180,7 @@ fn get_conn_info(
Ok(ConnInfo {
user_info,
dbname: dbname.into(),
dbname,
password: match password {
std::borrow::Cow::Borrowed(b) => b.into(),
std::borrow::Cow::Owned(b) => b.into(),
@@ -300,7 +306,14 @@ pub async fn handle(
Ok(response)
}
#[instrument(name = "sql-over-http", fields(pid = tracing::field::Empty), skip_all)]
#[instrument(
name = "sql-over-http",
skip_all,
fields(
pid = tracing::field::Empty,
conn_id = tracing::field::Empty
)
)]
async fn handle_inner(
config: &'static ProxyConfig,
ctx: &mut RequestMonitoring,
@@ -354,12 +367,10 @@ async fn handle_inner(
let txn_read_only = headers.get(&TXN_READ_ONLY) == Some(&HEADER_VALUE_TRUE);
let txn_deferrable = headers.get(&TXN_DEFERRABLE) == Some(&HEADER_VALUE_TRUE);
let paused = ctx.latency_timer.pause();
let request_content_length = match request.body().size_hint().upper() {
Some(v) => v,
None => MAX_REQUEST_SIZE + 1,
};
drop(paused);
info!(request_content_length, "request size in bytes");
HTTP_CONTENT_LENGTH.observe(request_content_length as f64);
@@ -375,15 +386,20 @@ async fn handle_inner(
let body = hyper::body::to_bytes(request.into_body())
.await
.map_err(anyhow::Error::from)?;
info!(length = body.len(), "request payload read");
let payload: Payload = serde_json::from_slice(&body)?;
Ok::<Payload, anyhow::Error>(payload) // Adjust error type accordingly
};
let authenticate_and_connect = async {
let keys = backend.authenticate(ctx, &conn_info).await?;
backend
let client = backend
.connect_to_compute(ctx, conn_info, keys, !allow_pool)
.await
.await?;
// not strictly necessary to mark success here,
// but it's insurance in case we forget to do it somewhere else
ctx.latency_timer.success();
Ok::<_, HttpConnError>(client)
};
// Run both operations in parallel
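A hedged sketch of how two such futures can be combined (the combinator and names here are illustrative assumptions, not the handler's exact code): the JSON body is parsed while the compute connection is being established, and the first error ends both.

async fn run_in_parallel<A, B>(
    parse_body: impl std::future::Future<Output = anyhow::Result<A>>,
    authenticate_and_connect: impl std::future::Future<Output = anyhow::Result<B>>,
) -> anyhow::Result<(A, B)> {
    // both futures are polled concurrently on the same task; the first error wins
    tokio::try_join!(parse_body, authenticate_and_connect)
}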
@@ -415,6 +431,7 @@ async fn handle_inner(
results
}
Payload::Batch(statements) => {
info!("starting transaction");
let (inner, mut discard) = client.inner();
let mut builder = inner.build_transaction();
if let Some(isolation_level) = txn_isolation_level {
@@ -444,6 +461,7 @@ async fn handle_inner(
.await
{
Ok(results) => {
info!("commit");
let status = transaction.commit().await.map_err(|e| {
// if we cannot commit - for now don't return connection to pool
// TODO: get a query status from the error
@@ -454,6 +472,7 @@ async fn handle_inner(
results
}
Err(err) => {
info!("rollback");
let status = transaction.rollback().await.map_err(|e| {
// if we cannot rollback - for now don't return connection to pool
// TODO: get a query status from the error
@@ -528,8 +547,10 @@ async fn query_to_json<T: GenericClient>(
raw_output: bool,
default_array_mode: bool,
) -> anyhow::Result<(ReadyForQueryStatus, Value)> {
info!("executing query");
let query_params = data.params;
let row_stream = client.query_raw_txt(&data.query, query_params).await?;
info!("finished executing query");
// Manually drain the stream into a vector to leave row_stream hanging
// around to get a command tag. Also check that the response is not too
@@ -564,6 +585,13 @@ async fn query_to_json<T: GenericClient>(
}
.and_then(|s| s.parse::<i64>().ok());
info!(
rows = rows.len(),
?ready,
command_tag,
"finished reading rows"
);
let mut fields = vec![];
let mut columns = vec![];

View File

@@ -1,5 +1,5 @@
use crate::{
cancellation::CancelMap,
cancellation::CancellationHandler,
config::ProxyConfig,
context::RequestMonitoring,
error::{io_error, ReportableError},
@@ -133,7 +133,7 @@ pub async fn serve_websocket(
config: &'static ProxyConfig,
mut ctx: RequestMonitoring,
websocket: HyperWebsocket,
cancel_map: Arc<CancelMap>,
cancellation_handler: Arc<CancellationHandler>,
hostname: Option<String>,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
) -> anyhow::Result<()> {
@@ -141,7 +141,7 @@ pub async fn serve_websocket(
let res = handle_client(
config,
&mut ctx,
cancel_map,
cancellation_handler,
WebSocketRw::new(websocket),
ClientMode::Websockets { hostname },
endpoint_rate_limiter,

View File

@@ -38,7 +38,7 @@ futures-io = { version = "0.3" }
futures-sink = { version = "0.3" }
futures-util = { version = "0.3", features = ["channel", "io", "sink"] }
getrandom = { version = "0.2", default-features = false, features = ["std"] }
hashbrown-582f2526e08bb6a0 = { package = "hashbrown", version = "0.14", default-features = false, features = ["raw"] }
hashbrown-582f2526e08bb6a0 = { package = "hashbrown", version = "0.14", features = ["raw"] }
hashbrown-594e8ee84c453af0 = { package = "hashbrown", version = "0.13", features = ["raw"] }
hex = { version = "0.4", features = ["serde"] }
hmac = { version = "0.12", default-features = false, features = ["reset"] }
@@ -91,7 +91,7 @@ cc = { version = "1", default-features = false, features = ["parallel"] }
chrono = { version = "0.4", default-features = false, features = ["clock", "serde", "wasmbind"] }
either = { version = "1" }
getrandom = { version = "0.2", default-features = false, features = ["std"] }
hashbrown-582f2526e08bb6a0 = { package = "hashbrown", version = "0.14", default-features = false, features = ["raw"] }
hashbrown-582f2526e08bb6a0 = { package = "hashbrown", version = "0.14", features = ["raw"] }
indexmap = { version = "1", default-features = false, features = ["std"] }
itertools = { version = "0.10" }
libc = { version = "0.2", features = ["extra_traits", "use_std"] }