mirror of
https://github.com/neondatabase/neon.git
synced 2026-02-06 04:00:37 +00:00
Compare commits
266 Commits
conrad/pro
...
release-48
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6c94269c32 | ||
|
|
edc691647d | ||
|
|
855d7b4781 | ||
|
|
c49c9707ce | ||
|
|
2227540a0d | ||
|
|
f1347f2417 | ||
|
|
30b295b017 | ||
|
|
1cef395266 | ||
|
|
78d160f76d | ||
|
|
b9238059d6 | ||
|
|
d0cb4b88c8 | ||
|
|
1ec3e39d4e | ||
|
|
a1a74eef2c | ||
|
|
90e689adda | ||
|
|
f0b2d4b053 | ||
|
|
299d9474c9 | ||
|
|
7234208b36 | ||
|
|
93450f11f5 | ||
|
|
2f0f9edf33 | ||
|
|
d424f2b7c8 | ||
|
|
21315e80bc | ||
|
|
483b66d383 | ||
|
|
aa72a22661 | ||
|
|
5c0264b591 | ||
|
|
9f13277729 | ||
|
|
54aa319805 | ||
|
|
4a227484bf | ||
|
|
2f83f85291 | ||
|
|
d6cfcb0d93 | ||
|
|
392843ad2a | ||
|
|
bd4dae8f4a | ||
|
|
b05fe53cfd | ||
|
|
c13a2f0df1 | ||
|
|
39be366fc5 | ||
|
|
6eda0a3158 | ||
|
|
306c7a1813 | ||
|
|
80be423a58 | ||
|
|
5dcfef82f2 | ||
|
|
e67b8f69c0 | ||
|
|
e546872ab4 | ||
|
|
322ea1cf7c | ||
|
|
3633742de9 | ||
|
|
079d3a37ba | ||
|
|
a46e77b476 | ||
|
|
a92702b01e | ||
|
|
8ff3253f20 | ||
|
|
04b82c92a7 | ||
|
|
e5bf423e68 | ||
|
|
60af392e45 | ||
|
|
661fc41e71 | ||
|
|
702c488f32 | ||
|
|
45c5122754 | ||
|
|
558394f710 | ||
|
|
73b0898608 | ||
|
|
e65be4c2dc | ||
|
|
40087b8164 | ||
|
|
c762b59483 | ||
|
|
5d71601ca9 | ||
|
|
a113c3e433 | ||
|
|
e81fc598f4 | ||
|
|
48b845fa76 | ||
|
|
27096858dc | ||
|
|
4430d0ae7d | ||
|
|
6e183aa0de | ||
|
|
fd6d0b7635 | ||
|
|
3710c32aae | ||
|
|
be83bee49d | ||
|
|
cf28e5922a | ||
|
|
7d384d6953 | ||
|
|
4b3b37b912 | ||
|
|
1d8d200f4d | ||
|
|
0d80d6ce18 | ||
|
|
f653ee039f | ||
|
|
e614a95853 | ||
|
|
850db4cc13 | ||
|
|
8a316b1277 | ||
|
|
4d13bae449 | ||
|
|
49377abd98 | ||
|
|
a6b2f4e54e | ||
|
|
face60d50b | ||
|
|
9768aa27f2 | ||
|
|
96b2e575e1 | ||
|
|
7222777784 | ||
|
|
5469fdede0 | ||
|
|
72aa6b9fdd | ||
|
|
ae0634b7be | ||
|
|
70711f32fa | ||
|
|
52a88af0aa | ||
|
|
b7a43bf817 | ||
|
|
dce91b33a4 | ||
|
|
23ee4f3050 | ||
|
|
46857e8282 | ||
|
|
368ab0ce54 | ||
|
|
a5987eebfd | ||
|
|
6686ede30f | ||
|
|
373c7057cc | ||
|
|
7d6ec16166 | ||
|
|
0e6fdc8a58 | ||
|
|
521438a5c6 | ||
|
|
07d7874bc8 | ||
|
|
1804111a02 | ||
|
|
cd0178efed | ||
|
|
333574be57 | ||
|
|
79a799a143 | ||
|
|
9da06af6c9 | ||
|
|
ce1753d036 | ||
|
|
67db8432b4 | ||
|
|
4e2e44e524 | ||
|
|
ed786104f3 | ||
|
|
84b74f2bd1 | ||
|
|
fec2ad6283 | ||
|
|
98eebd4682 | ||
|
|
2f74287c9b | ||
|
|
aee1bf95e3 | ||
|
|
b9de9d75ff | ||
|
|
7943b709e6 | ||
|
|
d7d066d493 | ||
|
|
e78ac22107 | ||
|
|
76a8f2bb44 | ||
|
|
8d59a8581f | ||
|
|
b1ddd01289 | ||
|
|
6eae4fc9aa | ||
|
|
765455bca2 | ||
|
|
4204960942 | ||
|
|
67345d66ea | ||
|
|
2266ee5971 | ||
|
|
b58445d855 | ||
|
|
36050e7f3d | ||
|
|
33360ed96d | ||
|
|
39a28d1108 | ||
|
|
efa6aa134f | ||
|
|
2c724e56e2 | ||
|
|
feff887c6f | ||
|
|
353d915fcf | ||
|
|
2e38098cbc | ||
|
|
a6fe5ea1ac | ||
|
|
05b0aed0c1 | ||
|
|
cd1705357d | ||
|
|
6bc7561290 | ||
|
|
fbd3ac14b5 | ||
|
|
e437787c8f | ||
|
|
3460dbf90b | ||
|
|
6b89d99677 | ||
|
|
6cc8ea86e4 | ||
|
|
e62a492d6f | ||
|
|
a475cdf642 | ||
|
|
7002c79a47 | ||
|
|
ee6cf357b4 | ||
|
|
e5c2086b5f | ||
|
|
5f1208296a | ||
|
|
88e8e473cd | ||
|
|
b0a77844f6 | ||
|
|
1baf464307 | ||
|
|
e9b8e81cea | ||
|
|
85d6194aa4 | ||
|
|
333a7a68ef | ||
|
|
6aa4e41bee | ||
|
|
840183e51f | ||
|
|
cbccc94b03 | ||
|
|
fce227df22 | ||
|
|
bd787e800f | ||
|
|
4a7704b4a3 | ||
|
|
ff1119da66 | ||
|
|
4c3ba1627b | ||
|
|
1407174fb2 | ||
|
|
ec9dcb1889 | ||
|
|
d11d781afc | ||
|
|
4e44565b71 | ||
|
|
4ed51ad33b | ||
|
|
1c1ebe5537 | ||
|
|
c19cb7f386 | ||
|
|
4b97d31b16 | ||
|
|
923ade3dd7 | ||
|
|
b04e711975 | ||
|
|
afd0a6b39a | ||
|
|
99752286d8 | ||
|
|
15df93363c | ||
|
|
bc0ab741af | ||
|
|
51d9dfeaa3 | ||
|
|
f63cb18155 | ||
|
|
0de603d88e | ||
|
|
240913912a | ||
|
|
91a4ea0de2 | ||
|
|
8608704f49 | ||
|
|
efef68ce99 | ||
|
|
8daefd24da | ||
|
|
46cc8b7982 | ||
|
|
38cd90dd0c | ||
|
|
a51b269f15 | ||
|
|
43bf6d0a0f | ||
|
|
15273a9b66 | ||
|
|
78aca668d0 | ||
|
|
acbf4148ea | ||
|
|
6508540561 | ||
|
|
a41b5244a8 | ||
|
|
2b3189be95 | ||
|
|
248563c595 | ||
|
|
14cd6ca933 | ||
|
|
eb36403e71 | ||
|
|
3c6f779698 | ||
|
|
f67f0c1c11 | ||
|
|
edb02d3299 | ||
|
|
664a69e65b | ||
|
|
478322ebf9 | ||
|
|
802f174072 | ||
|
|
47f9890bae | ||
|
|
262265daad | ||
|
|
300da5b872 | ||
|
|
7b22b5c433 | ||
|
|
ffca97bc1e | ||
|
|
cb356f3259 | ||
|
|
c85374295f | ||
|
|
4992160677 | ||
|
|
bd535b3371 | ||
|
|
d90c5a03af | ||
|
|
2d02cc9079 | ||
|
|
49ad94b99f | ||
|
|
948a217398 | ||
|
|
125381eae7 | ||
|
|
cd01bbc715 | ||
|
|
d8b5e3b88d | ||
|
|
06d25f2186 | ||
|
|
f759b561f3 | ||
|
|
ece0555600 | ||
|
|
73ea0a0b01 | ||
|
|
d8f6d6fd6f | ||
|
|
d24de169a7 | ||
|
|
0816168296 | ||
|
|
277b44d57a | ||
|
|
68c2c3880e | ||
|
|
49da498f65 | ||
|
|
2c76ba3dd7 | ||
|
|
dbe3dc69ad | ||
|
|
8e5bb3ed49 | ||
|
|
ab0be7b8da | ||
|
|
b4c55f5d24 | ||
|
|
ede70d833c | ||
|
|
70c3d18bb0 | ||
|
|
7a491f52c4 | ||
|
|
323c4ecb4f | ||
|
|
3d2466607e | ||
|
|
ed478b39f4 | ||
|
|
91585a558d | ||
|
|
93467eae1f | ||
|
|
f3aac81d19 | ||
|
|
979ad60c19 | ||
|
|
9316cb1b1f | ||
|
|
e7939a527a | ||
|
|
36d26665e1 | ||
|
|
873347f977 | ||
|
|
e814ac16f9 | ||
|
|
ad3055d386 | ||
|
|
94e03eb452 | ||
|
|
380f26ef79 | ||
|
|
3c5b7f59d7 | ||
|
|
fee89f80b5 | ||
|
|
41cce8eaf1 | ||
|
|
f88fe0218d | ||
|
|
cc856eca85 | ||
|
|
cf350c6002 | ||
|
|
0ce6b6a0a3 | ||
|
|
73f247d537 | ||
|
|
960be82183 | ||
|
|
806e5a6c19 | ||
|
|
8d5df07cce | ||
|
|
df7a9d1407 |
7
Cargo.lock
generated
7
Cargo.lock
generated
@@ -2247,11 +2247,11 @@ dependencies = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "hashlink"
|
name = "hashlink"
|
||||||
version = "0.8.2"
|
version = "0.8.4"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "0761a1b9491c4f2e3d66aa0f62d0fba0af9a0e2852e4d48ea506632a4b56e6aa"
|
checksum = "e8094feaf31ff591f651a2664fb9cfd92bba7a60ce3197265e9482ebe753c8f7"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"hashbrown 0.13.2",
|
"hashbrown 0.14.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@@ -3936,6 +3936,7 @@ dependencies = [
|
|||||||
"pin-project-lite",
|
"pin-project-lite",
|
||||||
"postgres-protocol",
|
"postgres-protocol",
|
||||||
"rand 0.8.5",
|
"rand 0.8.5",
|
||||||
|
"serde",
|
||||||
"thiserror",
|
"thiserror",
|
||||||
"tokio",
|
"tokio",
|
||||||
"tracing",
|
"tracing",
|
||||||
|
|||||||
@@ -80,7 +80,7 @@ futures-core = "0.3"
|
|||||||
futures-util = "0.3"
|
futures-util = "0.3"
|
||||||
git-version = "0.3"
|
git-version = "0.3"
|
||||||
hashbrown = "0.13"
|
hashbrown = "0.13"
|
||||||
hashlink = "0.8.1"
|
hashlink = "0.8.4"
|
||||||
hdrhistogram = "7.5.2"
|
hdrhistogram = "7.5.2"
|
||||||
hex = "0.4"
|
hex = "0.4"
|
||||||
hex-literal = "0.4"
|
hex-literal = "0.4"
|
||||||
|
|||||||
@@ -13,5 +13,6 @@ rand.workspace = true
|
|||||||
tokio.workspace = true
|
tokio.workspace = true
|
||||||
tracing.workspace = true
|
tracing.workspace = true
|
||||||
thiserror.workspace = true
|
thiserror.workspace = true
|
||||||
|
serde.workspace = true
|
||||||
|
|
||||||
workspace_hack.workspace = true
|
workspace_hack.workspace = true
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ pub mod framed;
|
|||||||
|
|
||||||
use byteorder::{BigEndian, ReadBytesExt};
|
use byteorder::{BigEndian, ReadBytesExt};
|
||||||
use bytes::{Buf, BufMut, Bytes, BytesMut};
|
use bytes::{Buf, BufMut, Bytes, BytesMut};
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
use std::{borrow::Cow, collections::HashMap, fmt, io, str};
|
use std::{borrow::Cow, collections::HashMap, fmt, io, str};
|
||||||
|
|
||||||
// re-export for use in utils pageserver_feedback.rs
|
// re-export for use in utils pageserver_feedback.rs
|
||||||
@@ -123,7 +124,7 @@ impl StartupMessageParams {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Hash, PartialEq, Eq, Clone, Copy)]
|
#[derive(Debug, Hash, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
|
||||||
pub struct CancelKeyData {
|
pub struct CancelKeyData {
|
||||||
pub backend_pid: i32,
|
pub backend_pid: i32,
|
||||||
pub cancel_key: i32,
|
pub cancel_key: i32,
|
||||||
|
|||||||
@@ -36,9 +36,6 @@ pub enum AuthErrorImpl {
|
|||||||
#[error(transparent)]
|
#[error(transparent)]
|
||||||
GetAuthInfo(#[from] console::errors::GetAuthInfoError),
|
GetAuthInfo(#[from] console::errors::GetAuthInfoError),
|
||||||
|
|
||||||
#[error(transparent)]
|
|
||||||
WakeCompute(#[from] console::errors::WakeComputeError),
|
|
||||||
|
|
||||||
/// SASL protocol errors (includes [SCRAM](crate::scram)).
|
/// SASL protocol errors (includes [SCRAM](crate::scram)).
|
||||||
#[error(transparent)]
|
#[error(transparent)]
|
||||||
Sasl(#[from] crate::sasl::Error),
|
Sasl(#[from] crate::sasl::Error),
|
||||||
@@ -119,7 +116,6 @@ impl UserFacingError for AuthError {
|
|||||||
match self.0.as_ref() {
|
match self.0.as_ref() {
|
||||||
Link(e) => e.to_string_client(),
|
Link(e) => e.to_string_client(),
|
||||||
GetAuthInfo(e) => e.to_string_client(),
|
GetAuthInfo(e) => e.to_string_client(),
|
||||||
WakeCompute(e) => e.to_string_client(),
|
|
||||||
Sasl(e) => e.to_string_client(),
|
Sasl(e) => e.to_string_client(),
|
||||||
AuthFailed(_) => self.to_string(),
|
AuthFailed(_) => self.to_string(),
|
||||||
BadAuthMethod(_) => self.to_string(),
|
BadAuthMethod(_) => self.to_string(),
|
||||||
@@ -139,7 +135,6 @@ impl ReportableError for AuthError {
|
|||||||
match self.0.as_ref() {
|
match self.0.as_ref() {
|
||||||
Link(e) => e.get_error_kind(),
|
Link(e) => e.get_error_kind(),
|
||||||
GetAuthInfo(e) => e.get_error_kind(),
|
GetAuthInfo(e) => e.get_error_kind(),
|
||||||
WakeCompute(e) => e.get_error_kind(),
|
|
||||||
Sasl(e) => e.get_error_kind(),
|
Sasl(e) => e.get_error_kind(),
|
||||||
AuthFailed(_) => crate::error::ErrorKind::User,
|
AuthFailed(_) => crate::error::ErrorKind::User,
|
||||||
BadAuthMethod(_) => crate::error::ErrorKind::User,
|
BadAuthMethod(_) => crate::error::ErrorKind::User,
|
||||||
|
|||||||
@@ -10,9 +10,9 @@ use crate::auth::validate_password_and_exchange;
|
|||||||
use crate::cache::Cached;
|
use crate::cache::Cached;
|
||||||
use crate::console::errors::GetAuthInfoError;
|
use crate::console::errors::GetAuthInfoError;
|
||||||
use crate::console::provider::{CachedRoleSecret, ConsoleBackend};
|
use crate::console::provider::{CachedRoleSecret, ConsoleBackend};
|
||||||
use crate::console::AuthSecret;
|
use crate::console::{AuthSecret, NodeInfo};
|
||||||
use crate::context::RequestMonitoring;
|
use crate::context::RequestMonitoring;
|
||||||
use crate::proxy::wake_compute::wake_compute;
|
use crate::proxy::connect_compute::ComputeConnectBackend;
|
||||||
use crate::proxy::NeonOptions;
|
use crate::proxy::NeonOptions;
|
||||||
use crate::stream::Stream;
|
use crate::stream::Stream;
|
||||||
use crate::{
|
use crate::{
|
||||||
@@ -26,7 +26,6 @@ use crate::{
|
|||||||
stream, url,
|
stream, url,
|
||||||
};
|
};
|
||||||
use crate::{scram, EndpointCacheKey, EndpointId, RoleName};
|
use crate::{scram, EndpointCacheKey, EndpointId, RoleName};
|
||||||
use futures::TryFutureExt;
|
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use tokio::io::{AsyncRead, AsyncWrite};
|
use tokio::io::{AsyncRead, AsyncWrite};
|
||||||
use tracing::info;
|
use tracing::info;
|
||||||
@@ -56,11 +55,11 @@ impl<T> std::ops::Deref for MaybeOwned<'_, T> {
|
|||||||
/// * However, when we substitute `T` with [`ComputeUserInfoMaybeEndpoint`],
|
/// * However, when we substitute `T` with [`ComputeUserInfoMaybeEndpoint`],
|
||||||
/// this helps us provide the credentials only to those auth
|
/// this helps us provide the credentials only to those auth
|
||||||
/// backends which require them for the authentication process.
|
/// backends which require them for the authentication process.
|
||||||
pub enum BackendType<'a, T> {
|
pub enum BackendType<'a, T, D> {
|
||||||
/// Cloud API (V2).
|
/// Cloud API (V2).
|
||||||
Console(MaybeOwned<'a, ConsoleBackend>, T),
|
Console(MaybeOwned<'a, ConsoleBackend>, T),
|
||||||
/// Authentication via a web browser.
|
/// Authentication via a web browser.
|
||||||
Link(MaybeOwned<'a, url::ApiUrl>),
|
Link(MaybeOwned<'a, url::ApiUrl>, D),
|
||||||
}
|
}
|
||||||
|
|
||||||
pub trait TestBackend: Send + Sync + 'static {
|
pub trait TestBackend: Send + Sync + 'static {
|
||||||
@@ -71,7 +70,7 @@ pub trait TestBackend: Send + Sync + 'static {
|
|||||||
fn get_role_secret(&self) -> Result<CachedRoleSecret, console::errors::GetAuthInfoError>;
|
fn get_role_secret(&self) -> Result<CachedRoleSecret, console::errors::GetAuthInfoError>;
|
||||||
}
|
}
|
||||||
|
|
||||||
impl std::fmt::Display for BackendType<'_, ()> {
|
impl std::fmt::Display for BackendType<'_, (), ()> {
|
||||||
fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
use BackendType::*;
|
use BackendType::*;
|
||||||
match self {
|
match self {
|
||||||
@@ -86,51 +85,50 @@ impl std::fmt::Display for BackendType<'_, ()> {
|
|||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
ConsoleBackend::Test(_) => fmt.debug_tuple("Test").finish(),
|
ConsoleBackend::Test(_) => fmt.debug_tuple("Test").finish(),
|
||||||
},
|
},
|
||||||
Link(url) => fmt.debug_tuple("Link").field(&url.as_str()).finish(),
|
Link(url, _) => fmt.debug_tuple("Link").field(&url.as_str()).finish(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T> BackendType<'_, T> {
|
impl<T, D> BackendType<'_, T, D> {
|
||||||
/// Very similar to [`std::option::Option::as_ref`].
|
/// Very similar to [`std::option::Option::as_ref`].
|
||||||
/// This helps us pass structured config to async tasks.
|
/// This helps us pass structured config to async tasks.
|
||||||
pub fn as_ref(&self) -> BackendType<'_, &T> {
|
pub fn as_ref(&self) -> BackendType<'_, &T, &D> {
|
||||||
use BackendType::*;
|
use BackendType::*;
|
||||||
match self {
|
match self {
|
||||||
Console(c, x) => Console(MaybeOwned::Borrowed(c), x),
|
Console(c, x) => Console(MaybeOwned::Borrowed(c), x),
|
||||||
Link(c) => Link(MaybeOwned::Borrowed(c)),
|
Link(c, x) => Link(MaybeOwned::Borrowed(c), x),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, T> BackendType<'a, T> {
|
impl<'a, T, D> BackendType<'a, T, D> {
|
||||||
/// Very similar to [`std::option::Option::map`].
|
/// Very similar to [`std::option::Option::map`].
|
||||||
/// Maps [`BackendType<T>`] to [`BackendType<R>`] by applying
|
/// Maps [`BackendType<T>`] to [`BackendType<R>`] by applying
|
||||||
/// a function to a contained value.
|
/// a function to a contained value.
|
||||||
pub fn map<R>(self, f: impl FnOnce(T) -> R) -> BackendType<'a, R> {
|
pub fn map<R>(self, f: impl FnOnce(T) -> R) -> BackendType<'a, R, D> {
|
||||||
use BackendType::*;
|
use BackendType::*;
|
||||||
match self {
|
match self {
|
||||||
Console(c, x) => Console(c, f(x)),
|
Console(c, x) => Console(c, f(x)),
|
||||||
Link(c) => Link(c),
|
Link(c, x) => Link(c, x),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
impl<'a, T, D, E> BackendType<'a, Result<T, E>, D> {
|
||||||
impl<'a, T, E> BackendType<'a, Result<T, E>> {
|
|
||||||
/// Very similar to [`std::option::Option::transpose`].
|
/// Very similar to [`std::option::Option::transpose`].
|
||||||
/// This is most useful for error handling.
|
/// This is most useful for error handling.
|
||||||
pub fn transpose(self) -> Result<BackendType<'a, T>, E> {
|
pub fn transpose(self) -> Result<BackendType<'a, T, D>, E> {
|
||||||
use BackendType::*;
|
use BackendType::*;
|
||||||
match self {
|
match self {
|
||||||
Console(c, x) => x.map(|x| Console(c, x)),
|
Console(c, x) => x.map(|x| Console(c, x)),
|
||||||
Link(c) => Ok(Link(c)),
|
Link(c, x) => Ok(Link(c, x)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct ComputeCredentials<T> {
|
pub struct ComputeCredentials {
|
||||||
pub info: ComputeUserInfo,
|
pub info: ComputeUserInfo,
|
||||||
pub keys: T,
|
pub keys: ComputeCredentialKeys,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
@@ -153,7 +151,6 @@ impl ComputeUserInfo {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub enum ComputeCredentialKeys {
|
pub enum ComputeCredentialKeys {
|
||||||
#[cfg(any(test, feature = "testing"))]
|
|
||||||
Password(Vec<u8>),
|
Password(Vec<u8>),
|
||||||
AuthKeys(AuthKeys),
|
AuthKeys(AuthKeys),
|
||||||
}
|
}
|
||||||
@@ -188,19 +185,21 @@ async fn auth_quirks(
|
|||||||
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
|
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
|
||||||
allow_cleartext: bool,
|
allow_cleartext: bool,
|
||||||
config: &'static AuthenticationConfig,
|
config: &'static AuthenticationConfig,
|
||||||
) -> auth::Result<ComputeCredentials<ComputeCredentialKeys>> {
|
) -> auth::Result<ComputeCredentials> {
|
||||||
// If there's no project so far, that entails that client doesn't
|
// If there's no project so far, that entails that client doesn't
|
||||||
// support SNI or other means of passing the endpoint (project) name.
|
// support SNI or other means of passing the endpoint (project) name.
|
||||||
// We now expect to see a very specific payload in the place of password.
|
// We now expect to see a very specific payload in the place of password.
|
||||||
let (info, unauthenticated_password) = match user_info.try_into() {
|
let (info, unauthenticated_password) = match user_info.try_into() {
|
||||||
Err(info) => {
|
Err(info) => {
|
||||||
let res = hacks::password_hack_no_authentication(info, client, &mut ctx.latency_timer)
|
let res = hacks::password_hack_no_authentication(ctx, info, client).await?;
|
||||||
.await?;
|
|
||||||
|
|
||||||
ctx.set_endpoint_id(res.info.endpoint.clone());
|
ctx.set_endpoint_id(res.info.endpoint.clone());
|
||||||
tracing::Span::current().record("ep", &tracing::field::display(&res.info.endpoint));
|
tracing::Span::current().record("ep", &tracing::field::display(&res.info.endpoint));
|
||||||
|
let password = match res.keys {
|
||||||
(res.info, Some(res.keys))
|
ComputeCredentialKeys::Password(p) => p,
|
||||||
|
_ => unreachable!("password hack should return a password"),
|
||||||
|
};
|
||||||
|
(res.info, Some(password))
|
||||||
}
|
}
|
||||||
Ok(info) => (info, None),
|
Ok(info) => (info, None),
|
||||||
};
|
};
|
||||||
@@ -254,7 +253,7 @@ async fn authenticate_with_secret(
|
|||||||
unauthenticated_password: Option<Vec<u8>>,
|
unauthenticated_password: Option<Vec<u8>>,
|
||||||
allow_cleartext: bool,
|
allow_cleartext: bool,
|
||||||
config: &'static AuthenticationConfig,
|
config: &'static AuthenticationConfig,
|
||||||
) -> auth::Result<ComputeCredentials<ComputeCredentialKeys>> {
|
) -> auth::Result<ComputeCredentials> {
|
||||||
if let Some(password) = unauthenticated_password {
|
if let Some(password) = unauthenticated_password {
|
||||||
let auth_outcome = validate_password_and_exchange(&password, secret)?;
|
let auth_outcome = validate_password_and_exchange(&password, secret)?;
|
||||||
let keys = match auth_outcome {
|
let keys = match auth_outcome {
|
||||||
@@ -276,21 +275,22 @@ async fn authenticate_with_secret(
|
|||||||
// Perform cleartext auth if we're allowed to do that.
|
// Perform cleartext auth if we're allowed to do that.
|
||||||
// Currently, we use it for websocket connections (latency).
|
// Currently, we use it for websocket connections (latency).
|
||||||
if allow_cleartext {
|
if allow_cleartext {
|
||||||
return hacks::authenticate_cleartext(info, client, &mut ctx.latency_timer, secret).await;
|
ctx.set_auth_method(crate::context::AuthMethod::Cleartext);
|
||||||
|
return hacks::authenticate_cleartext(ctx, info, client, secret).await;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Finally, proceed with the main auth flow (SCRAM-based).
|
// Finally, proceed with the main auth flow (SCRAM-based).
|
||||||
classic::authenticate(info, client, config, &mut ctx.latency_timer, secret).await
|
classic::authenticate(ctx, info, client, config, secret).await
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint> {
|
impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint, &()> {
|
||||||
/// Get compute endpoint name from the credentials.
|
/// Get compute endpoint name from the credentials.
|
||||||
pub fn get_endpoint(&self) -> Option<EndpointId> {
|
pub fn get_endpoint(&self) -> Option<EndpointId> {
|
||||||
use BackendType::*;
|
use BackendType::*;
|
||||||
|
|
||||||
match self {
|
match self {
|
||||||
Console(_, user_info) => user_info.endpoint_id.clone(),
|
Console(_, user_info) => user_info.endpoint_id.clone(),
|
||||||
Link(_) => Some("link".into()),
|
Link(_, _) => Some("link".into()),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -300,7 +300,7 @@ impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint> {
|
|||||||
|
|
||||||
match self {
|
match self {
|
||||||
Console(_, user_info) => &user_info.user,
|
Console(_, user_info) => &user_info.user,
|
||||||
Link(_) => "link",
|
Link(_, _) => "link",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -312,7 +312,7 @@ impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint> {
|
|||||||
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
|
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
|
||||||
allow_cleartext: bool,
|
allow_cleartext: bool,
|
||||||
config: &'static AuthenticationConfig,
|
config: &'static AuthenticationConfig,
|
||||||
) -> auth::Result<(CachedNodeInfo, BackendType<'a, ComputeUserInfo>)> {
|
) -> auth::Result<BackendType<'a, ComputeCredentials, NodeInfo>> {
|
||||||
use BackendType::*;
|
use BackendType::*;
|
||||||
|
|
||||||
let res = match self {
|
let res = match self {
|
||||||
@@ -323,33 +323,17 @@ impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint> {
|
|||||||
"performing authentication using the console"
|
"performing authentication using the console"
|
||||||
);
|
);
|
||||||
|
|
||||||
let compute_credentials =
|
let credentials =
|
||||||
auth_quirks(ctx, &*api, user_info, client, allow_cleartext, config).await?;
|
auth_quirks(ctx, &*api, user_info, client, allow_cleartext, config).await?;
|
||||||
|
BackendType::Console(api, credentials)
|
||||||
let mut num_retries = 0;
|
|
||||||
let mut node =
|
|
||||||
wake_compute(&mut num_retries, ctx, &api, &compute_credentials.info).await?;
|
|
||||||
|
|
||||||
ctx.set_project(node.aux.clone());
|
|
||||||
|
|
||||||
match compute_credentials.keys {
|
|
||||||
#[cfg(any(test, feature = "testing"))]
|
|
||||||
ComputeCredentialKeys::Password(password) => node.config.password(password),
|
|
||||||
ComputeCredentialKeys::AuthKeys(auth_keys) => node.config.auth_keys(auth_keys),
|
|
||||||
};
|
|
||||||
|
|
||||||
(node, BackendType::Console(api, compute_credentials.info))
|
|
||||||
}
|
}
|
||||||
// NOTE: this auth backend doesn't use client credentials.
|
// NOTE: this auth backend doesn't use client credentials.
|
||||||
Link(url) => {
|
Link(url, _) => {
|
||||||
info!("performing link authentication");
|
info!("performing link authentication");
|
||||||
|
|
||||||
let node_info = link::authenticate(ctx, &url, client).await?;
|
let info = link::authenticate(ctx, &url, client).await?;
|
||||||
|
|
||||||
(
|
BackendType::Link(url, info)
|
||||||
CachedNodeInfo::new_uncached(node_info),
|
|
||||||
BackendType::Link(url),
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -358,7 +342,7 @@ impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl BackendType<'_, ComputeUserInfo> {
|
impl BackendType<'_, ComputeUserInfo, &()> {
|
||||||
pub async fn get_role_secret(
|
pub async fn get_role_secret(
|
||||||
&self,
|
&self,
|
||||||
ctx: &mut RequestMonitoring,
|
ctx: &mut RequestMonitoring,
|
||||||
@@ -366,7 +350,7 @@ impl BackendType<'_, ComputeUserInfo> {
|
|||||||
use BackendType::*;
|
use BackendType::*;
|
||||||
match self {
|
match self {
|
||||||
Console(api, user_info) => api.get_role_secret(ctx, user_info).await,
|
Console(api, user_info) => api.get_role_secret(ctx, user_info).await,
|
||||||
Link(_) => Ok(Cached::new_uncached(None)),
|
Link(_, _) => Ok(Cached::new_uncached(None)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -377,21 +361,51 @@ impl BackendType<'_, ComputeUserInfo> {
|
|||||||
use BackendType::*;
|
use BackendType::*;
|
||||||
match self {
|
match self {
|
||||||
Console(api, user_info) => api.get_allowed_ips_and_secret(ctx, user_info).await,
|
Console(api, user_info) => api.get_allowed_ips_and_secret(ctx, user_info).await,
|
||||||
Link(_) => Ok((Cached::new_uncached(Arc::new(vec![])), None)),
|
Link(_, _) => Ok((Cached::new_uncached(Arc::new(vec![])), None)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
/// When applicable, wake the compute node, gaining its connection info in the process.
|
|
||||||
/// The link auth flow doesn't support this, so we return [`None`] in that case.
|
#[async_trait::async_trait]
|
||||||
pub async fn wake_compute(
|
impl ComputeConnectBackend for BackendType<'_, ComputeCredentials, NodeInfo> {
|
||||||
&self,
|
async fn wake_compute(
|
||||||
ctx: &mut RequestMonitoring,
|
&self,
|
||||||
) -> Result<Option<CachedNodeInfo>, console::errors::WakeComputeError> {
|
ctx: &mut RequestMonitoring,
|
||||||
use BackendType::*;
|
) -> Result<CachedNodeInfo, console::errors::WakeComputeError> {
|
||||||
|
use BackendType::*;
|
||||||
match self {
|
|
||||||
Console(api, user_info) => api.wake_compute(ctx, user_info).map_ok(Some).await,
|
match self {
|
||||||
Link(_) => Ok(None),
|
Console(api, creds) => api.wake_compute(ctx, &creds.info).await,
|
||||||
|
Link(_, info) => Ok(Cached::new_uncached(info.clone())),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_keys(&self) -> Option<&ComputeCredentialKeys> {
|
||||||
|
match self {
|
||||||
|
BackendType::Console(_, creds) => Some(&creds.keys),
|
||||||
|
BackendType::Link(_, _) => None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait::async_trait]
|
||||||
|
impl ComputeConnectBackend for BackendType<'_, ComputeCredentials, &()> {
|
||||||
|
async fn wake_compute(
|
||||||
|
&self,
|
||||||
|
ctx: &mut RequestMonitoring,
|
||||||
|
) -> Result<CachedNodeInfo, console::errors::WakeComputeError> {
|
||||||
|
use BackendType::*;
|
||||||
|
|
||||||
|
match self {
|
||||||
|
Console(api, creds) => api.wake_compute(ctx, &creds.info).await,
|
||||||
|
Link(_, _) => unreachable!("link auth flow doesn't support waking the compute"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_keys(&self) -> Option<&ComputeCredentialKeys> {
|
||||||
|
match self {
|
||||||
|
BackendType::Console(_, creds) => Some(&creds.keys),
|
||||||
|
BackendType::Link(_, _) => None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ use crate::{
|
|||||||
compute,
|
compute,
|
||||||
config::AuthenticationConfig,
|
config::AuthenticationConfig,
|
||||||
console::AuthSecret,
|
console::AuthSecret,
|
||||||
metrics::LatencyTimer,
|
context::RequestMonitoring,
|
||||||
sasl,
|
sasl,
|
||||||
stream::{PqStream, Stream},
|
stream::{PqStream, Stream},
|
||||||
};
|
};
|
||||||
@@ -12,12 +12,12 @@ use tokio::io::{AsyncRead, AsyncWrite};
|
|||||||
use tracing::{info, warn};
|
use tracing::{info, warn};
|
||||||
|
|
||||||
pub(super) async fn authenticate(
|
pub(super) async fn authenticate(
|
||||||
|
ctx: &mut RequestMonitoring,
|
||||||
creds: ComputeUserInfo,
|
creds: ComputeUserInfo,
|
||||||
client: &mut PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
|
client: &mut PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
|
||||||
config: &'static AuthenticationConfig,
|
config: &'static AuthenticationConfig,
|
||||||
latency_timer: &mut LatencyTimer,
|
|
||||||
secret: AuthSecret,
|
secret: AuthSecret,
|
||||||
) -> auth::Result<ComputeCredentials<ComputeCredentialKeys>> {
|
) -> auth::Result<ComputeCredentials> {
|
||||||
let flow = AuthFlow::new(client);
|
let flow = AuthFlow::new(client);
|
||||||
let scram_keys = match secret {
|
let scram_keys = match secret {
|
||||||
#[cfg(any(test, feature = "testing"))]
|
#[cfg(any(test, feature = "testing"))]
|
||||||
@@ -27,13 +27,11 @@ pub(super) async fn authenticate(
|
|||||||
}
|
}
|
||||||
AuthSecret::Scram(secret) => {
|
AuthSecret::Scram(secret) => {
|
||||||
info!("auth endpoint chooses SCRAM");
|
info!("auth endpoint chooses SCRAM");
|
||||||
let scram = auth::Scram(&secret);
|
let scram = auth::Scram(&secret, &mut *ctx);
|
||||||
|
|
||||||
let auth_outcome = tokio::time::timeout(
|
let auth_outcome = tokio::time::timeout(
|
||||||
config.scram_protocol_timeout,
|
config.scram_protocol_timeout,
|
||||||
async {
|
async {
|
||||||
// pause the timer while we communicate with the client
|
|
||||||
let _paused = latency_timer.pause();
|
|
||||||
|
|
||||||
flow.begin(scram).await.map_err(|error| {
|
flow.begin(scram).await.map_err(|error| {
|
||||||
warn!(?error, "error sending scram acknowledgement");
|
warn!(?error, "error sending scram acknowledgement");
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ use super::{
|
|||||||
use crate::{
|
use crate::{
|
||||||
auth::{self, AuthFlow},
|
auth::{self, AuthFlow},
|
||||||
console::AuthSecret,
|
console::AuthSecret,
|
||||||
metrics::LatencyTimer,
|
context::RequestMonitoring,
|
||||||
sasl,
|
sasl,
|
||||||
stream::{self, Stream},
|
stream::{self, Stream},
|
||||||
};
|
};
|
||||||
@@ -16,15 +16,16 @@ use tracing::{info, warn};
|
|||||||
/// These properties are benefical for serverless JS workers, so we
|
/// These properties are benefical for serverless JS workers, so we
|
||||||
/// use this mechanism for websocket connections.
|
/// use this mechanism for websocket connections.
|
||||||
pub async fn authenticate_cleartext(
|
pub async fn authenticate_cleartext(
|
||||||
|
ctx: &mut RequestMonitoring,
|
||||||
info: ComputeUserInfo,
|
info: ComputeUserInfo,
|
||||||
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
|
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
|
||||||
latency_timer: &mut LatencyTimer,
|
|
||||||
secret: AuthSecret,
|
secret: AuthSecret,
|
||||||
) -> auth::Result<ComputeCredentials<ComputeCredentialKeys>> {
|
) -> auth::Result<ComputeCredentials> {
|
||||||
warn!("cleartext auth flow override is enabled, proceeding");
|
warn!("cleartext auth flow override is enabled, proceeding");
|
||||||
|
ctx.set_auth_method(crate::context::AuthMethod::Cleartext);
|
||||||
|
|
||||||
// pause the timer while we communicate with the client
|
// pause the timer while we communicate with the client
|
||||||
let _paused = latency_timer.pause();
|
let _paused = ctx.latency_timer.pause();
|
||||||
|
|
||||||
let auth_outcome = AuthFlow::new(client)
|
let auth_outcome = AuthFlow::new(client)
|
||||||
.begin(auth::CleartextPassword(secret))
|
.begin(auth::CleartextPassword(secret))
|
||||||
@@ -47,14 +48,15 @@ pub async fn authenticate_cleartext(
|
|||||||
/// Similar to [`authenticate_cleartext`], but there's a specific password format,
|
/// Similar to [`authenticate_cleartext`], but there's a specific password format,
|
||||||
/// and passwords are not yet validated (we don't know how to validate them!)
|
/// and passwords are not yet validated (we don't know how to validate them!)
|
||||||
pub async fn password_hack_no_authentication(
|
pub async fn password_hack_no_authentication(
|
||||||
|
ctx: &mut RequestMonitoring,
|
||||||
info: ComputeUserInfoNoEndpoint,
|
info: ComputeUserInfoNoEndpoint,
|
||||||
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
|
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
|
||||||
latency_timer: &mut LatencyTimer,
|
) -> auth::Result<ComputeCredentials> {
|
||||||
) -> auth::Result<ComputeCredentials<Vec<u8>>> {
|
|
||||||
warn!("project not specified, resorting to the password hack auth flow");
|
warn!("project not specified, resorting to the password hack auth flow");
|
||||||
|
ctx.set_auth_method(crate::context::AuthMethod::Cleartext);
|
||||||
|
|
||||||
// pause the timer while we communicate with the client
|
// pause the timer while we communicate with the client
|
||||||
let _paused = latency_timer.pause();
|
let _paused = ctx.latency_timer.pause();
|
||||||
|
|
||||||
let payload = AuthFlow::new(client)
|
let payload = AuthFlow::new(client)
|
||||||
.begin(auth::PasswordHack)
|
.begin(auth::PasswordHack)
|
||||||
@@ -71,6 +73,6 @@ pub async fn password_hack_no_authentication(
|
|||||||
options: info.options,
|
options: info.options,
|
||||||
endpoint: payload.endpoint,
|
endpoint: payload.endpoint,
|
||||||
},
|
},
|
||||||
keys: payload.password,
|
keys: ComputeCredentialKeys::Password(payload.password),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -61,6 +61,8 @@ pub(super) async fn authenticate(
|
|||||||
link_uri: &reqwest::Url,
|
link_uri: &reqwest::Url,
|
||||||
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
|
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
|
||||||
) -> auth::Result<NodeInfo> {
|
) -> auth::Result<NodeInfo> {
|
||||||
|
ctx.set_auth_method(crate::context::AuthMethod::Web);
|
||||||
|
|
||||||
// registering waiter can fail if we get unlucky with rng.
|
// registering waiter can fail if we get unlucky with rng.
|
||||||
// just try again.
|
// just try again.
|
||||||
let (psql_session_id, waiter) = loop {
|
let (psql_session_id, waiter) = loop {
|
||||||
|
|||||||
@@ -99,6 +99,9 @@ impl ComputeUserInfoMaybeEndpoint {
|
|||||||
// record the values if we have them
|
// record the values if we have them
|
||||||
ctx.set_application(params.get("application_name").map(SmolStr::from));
|
ctx.set_application(params.get("application_name").map(SmolStr::from));
|
||||||
ctx.set_user(user.clone());
|
ctx.set_user(user.clone());
|
||||||
|
if let Some(dbname) = params.get("database") {
|
||||||
|
ctx.set_dbname(dbname.into());
|
||||||
|
}
|
||||||
|
|
||||||
// Project name might be passed via PG's command-line options.
|
// Project name might be passed via PG's command-line options.
|
||||||
let endpoint_option = params
|
let endpoint_option = params
|
||||||
|
|||||||
@@ -4,9 +4,11 @@ use super::{backend::ComputeCredentialKeys, AuthErrorImpl, PasswordHackPayload};
|
|||||||
use crate::{
|
use crate::{
|
||||||
config::TlsServerEndPoint,
|
config::TlsServerEndPoint,
|
||||||
console::AuthSecret,
|
console::AuthSecret,
|
||||||
|
context::RequestMonitoring,
|
||||||
sasl, scram,
|
sasl, scram,
|
||||||
stream::{PqStream, Stream},
|
stream::{PqStream, Stream},
|
||||||
};
|
};
|
||||||
|
use postgres_protocol::authentication::sasl::{SCRAM_SHA_256, SCRAM_SHA_256_PLUS};
|
||||||
use pq_proto::{BeAuthenticationSaslMessage, BeMessage, BeMessage as Be};
|
use pq_proto::{BeAuthenticationSaslMessage, BeMessage, BeMessage as Be};
|
||||||
use std::io;
|
use std::io;
|
||||||
use tokio::io::{AsyncRead, AsyncWrite};
|
use tokio::io::{AsyncRead, AsyncWrite};
|
||||||
@@ -23,7 +25,7 @@ pub trait AuthMethod {
|
|||||||
pub struct Begin;
|
pub struct Begin;
|
||||||
|
|
||||||
/// Use [SCRAM](crate::scram)-based auth in [`AuthFlow`].
|
/// Use [SCRAM](crate::scram)-based auth in [`AuthFlow`].
|
||||||
pub struct Scram<'a>(pub &'a scram::ServerSecret);
|
pub struct Scram<'a>(pub &'a scram::ServerSecret, pub &'a mut RequestMonitoring);
|
||||||
|
|
||||||
impl AuthMethod for Scram<'_> {
|
impl AuthMethod for Scram<'_> {
|
||||||
#[inline(always)]
|
#[inline(always)]
|
||||||
@@ -138,6 +140,11 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, CleartextPassword> {
|
|||||||
impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, Scram<'_>> {
|
impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, Scram<'_>> {
|
||||||
/// Perform user authentication. Raise an error in case authentication failed.
|
/// Perform user authentication. Raise an error in case authentication failed.
|
||||||
pub async fn authenticate(self) -> super::Result<sasl::Outcome<scram::ScramKey>> {
|
pub async fn authenticate(self) -> super::Result<sasl::Outcome<scram::ScramKey>> {
|
||||||
|
let Scram(secret, ctx) = self.state;
|
||||||
|
|
||||||
|
// pause the timer while we communicate with the client
|
||||||
|
let _paused = ctx.latency_timer.pause();
|
||||||
|
|
||||||
// Initial client message contains the chosen auth method's name.
|
// Initial client message contains the chosen auth method's name.
|
||||||
let msg = self.stream.read_password_message().await?;
|
let msg = self.stream.read_password_message().await?;
|
||||||
let sasl = sasl::FirstMessage::parse(&msg)
|
let sasl = sasl::FirstMessage::parse(&msg)
|
||||||
@@ -148,9 +155,15 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, Scram<'_>> {
|
|||||||
return Err(super::AuthError::bad_auth_method(sasl.method));
|
return Err(super::AuthError::bad_auth_method(sasl.method));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
match sasl.method {
|
||||||
|
SCRAM_SHA_256 => ctx.auth_method = Some(crate::context::AuthMethod::ScramSha256),
|
||||||
|
SCRAM_SHA_256_PLUS => {
|
||||||
|
ctx.auth_method = Some(crate::context::AuthMethod::ScramSha256Plus)
|
||||||
|
}
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
info!("client chooses {}", sasl.method);
|
info!("client chooses {}", sasl.method);
|
||||||
|
|
||||||
let secret = self.state.0;
|
|
||||||
let outcome = sasl::SaslStream::new(self.stream, sasl.message)
|
let outcome = sasl::SaslStream::new(self.stream, sasl.message)
|
||||||
.authenticate(scram::Exchange::new(
|
.authenticate(scram::Exchange::new(
|
||||||
secret,
|
secret,
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
use futures::future::Either;
|
use futures::future::Either;
|
||||||
use proxy::auth;
|
use proxy::auth;
|
||||||
use proxy::auth::backend::MaybeOwned;
|
use proxy::auth::backend::MaybeOwned;
|
||||||
|
use proxy::cancellation::CancelMap;
|
||||||
|
use proxy::cancellation::CancellationHandler;
|
||||||
use proxy::config::AuthenticationConfig;
|
use proxy::config::AuthenticationConfig;
|
||||||
use proxy::config::CacheOptions;
|
use proxy::config::CacheOptions;
|
||||||
use proxy::config::HttpConfig;
|
use proxy::config::HttpConfig;
|
||||||
@@ -12,6 +14,7 @@ use proxy::rate_limiter::EndpointRateLimiter;
|
|||||||
use proxy::rate_limiter::RateBucketInfo;
|
use proxy::rate_limiter::RateBucketInfo;
|
||||||
use proxy::rate_limiter::RateLimiterConfig;
|
use proxy::rate_limiter::RateLimiterConfig;
|
||||||
use proxy::redis::notifications;
|
use proxy::redis::notifications;
|
||||||
|
use proxy::redis::publisher::RedisPublisherClient;
|
||||||
use proxy::serverless::GlobalConnPoolOptions;
|
use proxy::serverless::GlobalConnPoolOptions;
|
||||||
use proxy::usage_metrics;
|
use proxy::usage_metrics;
|
||||||
|
|
||||||
@@ -22,6 +25,7 @@ use std::net::SocketAddr;
|
|||||||
use std::pin::pin;
|
use std::pin::pin;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use tokio::net::TcpListener;
|
use tokio::net::TcpListener;
|
||||||
|
use tokio::sync::Mutex;
|
||||||
use tokio::task::JoinSet;
|
use tokio::task::JoinSet;
|
||||||
use tokio_util::sync::CancellationToken;
|
use tokio_util::sync::CancellationToken;
|
||||||
use tracing::info;
|
use tracing::info;
|
||||||
@@ -129,6 +133,9 @@ struct ProxyCliArgs {
|
|||||||
/// Can be given multiple times for different bucket sizes.
|
/// Can be given multiple times for different bucket sizes.
|
||||||
#[clap(long, default_values_t = RateBucketInfo::DEFAULT_SET)]
|
#[clap(long, default_values_t = RateBucketInfo::DEFAULT_SET)]
|
||||||
endpoint_rps_limit: Vec<RateBucketInfo>,
|
endpoint_rps_limit: Vec<RateBucketInfo>,
|
||||||
|
/// Redis rate limiter max number of requests per second.
|
||||||
|
#[clap(long, default_values_t = RateBucketInfo::DEFAULT_SET)]
|
||||||
|
redis_rps_limit: Vec<RateBucketInfo>,
|
||||||
/// Initial limit for dynamic rate limiter. Makes sense only if `rate_limit_algorithm` is *not* `None`.
|
/// Initial limit for dynamic rate limiter. Makes sense only if `rate_limit_algorithm` is *not* `None`.
|
||||||
#[clap(long, default_value_t = 100)]
|
#[clap(long, default_value_t = 100)]
|
||||||
initial_limit: usize,
|
initial_limit: usize,
|
||||||
@@ -225,6 +232,19 @@ async fn main() -> anyhow::Result<()> {
|
|||||||
let cancellation_token = CancellationToken::new();
|
let cancellation_token = CancellationToken::new();
|
||||||
|
|
||||||
let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new(&config.endpoint_rps_limit));
|
let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new(&config.endpoint_rps_limit));
|
||||||
|
let cancel_map = CancelMap::default();
|
||||||
|
let redis_publisher = match &args.redis_notifications {
|
||||||
|
Some(url) => Some(Arc::new(Mutex::new(RedisPublisherClient::new(
|
||||||
|
url,
|
||||||
|
args.region.clone(),
|
||||||
|
&config.redis_rps_limit,
|
||||||
|
)?))),
|
||||||
|
None => None,
|
||||||
|
};
|
||||||
|
let cancellation_handler = Arc::new(CancellationHandler::new(
|
||||||
|
cancel_map.clone(),
|
||||||
|
redis_publisher,
|
||||||
|
));
|
||||||
|
|
||||||
// client facing tasks. these will exit on error or on cancellation
|
// client facing tasks. these will exit on error or on cancellation
|
||||||
// cancellation returns Ok(())
|
// cancellation returns Ok(())
|
||||||
@@ -234,6 +254,7 @@ async fn main() -> anyhow::Result<()> {
|
|||||||
proxy_listener,
|
proxy_listener,
|
||||||
cancellation_token.clone(),
|
cancellation_token.clone(),
|
||||||
endpoint_rate_limiter.clone(),
|
endpoint_rate_limiter.clone(),
|
||||||
|
cancellation_handler.clone(),
|
||||||
));
|
));
|
||||||
|
|
||||||
// TODO: rename the argument to something like serverless.
|
// TODO: rename the argument to something like serverless.
|
||||||
@@ -248,6 +269,7 @@ async fn main() -> anyhow::Result<()> {
|
|||||||
serverless_listener,
|
serverless_listener,
|
||||||
cancellation_token.clone(),
|
cancellation_token.clone(),
|
||||||
endpoint_rate_limiter.clone(),
|
endpoint_rate_limiter.clone(),
|
||||||
|
cancellation_handler.clone(),
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -271,7 +293,12 @@ async fn main() -> anyhow::Result<()> {
|
|||||||
let cache = api.caches.project_info.clone();
|
let cache = api.caches.project_info.clone();
|
||||||
if let Some(url) = args.redis_notifications {
|
if let Some(url) = args.redis_notifications {
|
||||||
info!("Starting redis notifications listener ({url})");
|
info!("Starting redis notifications listener ({url})");
|
||||||
maintenance_tasks.spawn(notifications::task_main(url.to_owned(), cache.clone()));
|
maintenance_tasks.spawn(notifications::task_main(
|
||||||
|
url.to_owned(),
|
||||||
|
cache.clone(),
|
||||||
|
cancel_map.clone(),
|
||||||
|
args.region.clone(),
|
||||||
|
));
|
||||||
}
|
}
|
||||||
maintenance_tasks.spawn(async move { cache.clone().gc_worker().await });
|
maintenance_tasks.spawn(async move { cache.clone().gc_worker().await });
|
||||||
}
|
}
|
||||||
@@ -383,7 +410,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
|
|||||||
}
|
}
|
||||||
AuthBackend::Link => {
|
AuthBackend::Link => {
|
||||||
let url = args.uri.parse()?;
|
let url = args.uri.parse()?;
|
||||||
auth::BackendType::Link(MaybeOwned::Owned(url))
|
auth::BackendType::Link(MaybeOwned::Owned(url), ())
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
let http_config = HttpConfig {
|
let http_config = HttpConfig {
|
||||||
@@ -403,6 +430,8 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
|
|||||||
|
|
||||||
let mut endpoint_rps_limit = args.endpoint_rps_limit.clone();
|
let mut endpoint_rps_limit = args.endpoint_rps_limit.clone();
|
||||||
RateBucketInfo::validate(&mut endpoint_rps_limit)?;
|
RateBucketInfo::validate(&mut endpoint_rps_limit)?;
|
||||||
|
let mut redis_rps_limit = args.redis_rps_limit.clone();
|
||||||
|
RateBucketInfo::validate(&mut redis_rps_limit)?;
|
||||||
|
|
||||||
let config = Box::leak(Box::new(ProxyConfig {
|
let config = Box::leak(Box::new(ProxyConfig {
|
||||||
tls_config,
|
tls_config,
|
||||||
@@ -414,6 +443,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
|
|||||||
require_client_ip: args.require_client_ip,
|
require_client_ip: args.require_client_ip,
|
||||||
disable_ip_check_for_http: args.disable_ip_check_for_http,
|
disable_ip_check_for_http: args.disable_ip_check_for_http,
|
||||||
endpoint_rps_limit,
|
endpoint_rps_limit,
|
||||||
|
redis_rps_limit,
|
||||||
handshake_timeout: args.handshake_timeout,
|
handshake_timeout: args.handshake_timeout,
|
||||||
// TODO: add this argument
|
// TODO: add this argument
|
||||||
region: args.region.clone(),
|
region: args.region.clone(),
|
||||||
|
|||||||
@@ -1,16 +1,28 @@
|
|||||||
|
use async_trait::async_trait;
|
||||||
use dashmap::DashMap;
|
use dashmap::DashMap;
|
||||||
use pq_proto::CancelKeyData;
|
use pq_proto::CancelKeyData;
|
||||||
use std::{net::SocketAddr, sync::Arc};
|
use std::{net::SocketAddr, sync::Arc};
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
use tokio::net::TcpStream;
|
use tokio::net::TcpStream;
|
||||||
|
use tokio::sync::Mutex;
|
||||||
use tokio_postgres::{CancelToken, NoTls};
|
use tokio_postgres::{CancelToken, NoTls};
|
||||||
use tracing::info;
|
use tracing::info;
|
||||||
|
use uuid::Uuid;
|
||||||
|
|
||||||
use crate::error::ReportableError;
|
use crate::{
|
||||||
|
error::ReportableError, metrics::NUM_CANCELLATION_REQUESTS,
|
||||||
|
redis::publisher::RedisPublisherClient,
|
||||||
|
};
|
||||||
|
|
||||||
|
pub type CancelMap = Arc<DashMap<CancelKeyData, Option<CancelClosure>>>;
|
||||||
|
|
||||||
/// Enables serving `CancelRequest`s.
|
/// Enables serving `CancelRequest`s.
|
||||||
#[derive(Default)]
|
///
|
||||||
pub struct CancelMap(DashMap<CancelKeyData, Option<CancelClosure>>);
|
/// If there is a `RedisPublisherClient` available, it will be used to publish the cancellation key to other proxy instances.
|
||||||
|
pub struct CancellationHandler {
|
||||||
|
map: CancelMap,
|
||||||
|
redis_client: Option<Arc<Mutex<RedisPublisherClient>>>,
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Error)]
|
#[derive(Debug, Error)]
|
||||||
pub enum CancelError {
|
pub enum CancelError {
|
||||||
@@ -32,15 +44,43 @@ impl ReportableError for CancelError {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl CancelMap {
|
impl CancellationHandler {
|
||||||
|
pub fn new(map: CancelMap, redis_client: Option<Arc<Mutex<RedisPublisherClient>>>) -> Self {
|
||||||
|
Self { map, redis_client }
|
||||||
|
}
|
||||||
/// Cancel a running query for the corresponding connection.
|
/// Cancel a running query for the corresponding connection.
|
||||||
pub async fn cancel_session(&self, key: CancelKeyData) -> Result<(), CancelError> {
|
pub async fn cancel_session(
|
||||||
|
&self,
|
||||||
|
key: CancelKeyData,
|
||||||
|
session_id: Uuid,
|
||||||
|
) -> Result<(), CancelError> {
|
||||||
|
let from = "from_client";
|
||||||
// NB: we should immediately release the lock after cloning the token.
|
// NB: we should immediately release the lock after cloning the token.
|
||||||
let Some(cancel_closure) = self.0.get(&key).and_then(|x| x.clone()) else {
|
let Some(cancel_closure) = self.map.get(&key).and_then(|x| x.clone()) else {
|
||||||
tracing::warn!("query cancellation key not found: {key}");
|
tracing::warn!("query cancellation key not found: {key}");
|
||||||
|
if let Some(redis_client) = &self.redis_client {
|
||||||
|
NUM_CANCELLATION_REQUESTS
|
||||||
|
.with_label_values(&[from, "not_found"])
|
||||||
|
.inc();
|
||||||
|
info!("publishing cancellation key to Redis");
|
||||||
|
match redis_client.lock().await.try_publish(key, session_id).await {
|
||||||
|
Ok(()) => {
|
||||||
|
info!("cancellation key successfuly published to Redis");
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
tracing::error!("failed to publish a message: {e}");
|
||||||
|
return Err(CancelError::IO(std::io::Error::new(
|
||||||
|
std::io::ErrorKind::Other,
|
||||||
|
e.to_string(),
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
return Ok(());
|
return Ok(());
|
||||||
};
|
};
|
||||||
|
NUM_CANCELLATION_REQUESTS
|
||||||
|
.with_label_values(&[from, "found"])
|
||||||
|
.inc();
|
||||||
info!("cancelling query per user's request using key {key}");
|
info!("cancelling query per user's request using key {key}");
|
||||||
cancel_closure.try_cancel_query().await
|
cancel_closure.try_cancel_query().await
|
||||||
}
|
}
|
||||||
@@ -57,7 +97,7 @@ impl CancelMap {
|
|||||||
|
|
||||||
// Random key collisions are unlikely to happen here, but they're still possible,
|
// Random key collisions are unlikely to happen here, but they're still possible,
|
||||||
// which is why we have to take care not to rewrite an existing key.
|
// which is why we have to take care not to rewrite an existing key.
|
||||||
match self.0.entry(key) {
|
match self.map.entry(key) {
|
||||||
dashmap::mapref::entry::Entry::Occupied(_) => continue,
|
dashmap::mapref::entry::Entry::Occupied(_) => continue,
|
||||||
dashmap::mapref::entry::Entry::Vacant(e) => {
|
dashmap::mapref::entry::Entry::Vacant(e) => {
|
||||||
e.insert(None);
|
e.insert(None);
|
||||||
@@ -69,18 +109,46 @@ impl CancelMap {
|
|||||||
info!("registered new query cancellation key {key}");
|
info!("registered new query cancellation key {key}");
|
||||||
Session {
|
Session {
|
||||||
key,
|
key,
|
||||||
cancel_map: self,
|
cancellation_handler: self,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
fn contains(&self, session: &Session) -> bool {
|
fn contains(&self, session: &Session) -> bool {
|
||||||
self.0.contains_key(&session.key)
|
self.map.contains_key(&session.key)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
fn is_empty(&self) -> bool {
|
fn is_empty(&self) -> bool {
|
||||||
self.0.is_empty()
|
self.map.is_empty()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
pub trait NotificationsCancellationHandler {
|
||||||
|
async fn cancel_session_no_publish(&self, key: CancelKeyData) -> Result<(), CancelError>;
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl NotificationsCancellationHandler for CancellationHandler {
|
||||||
|
async fn cancel_session_no_publish(&self, key: CancelKeyData) -> Result<(), CancelError> {
|
||||||
|
let from = "from_redis";
|
||||||
|
let cancel_closure = self.map.get(&key).and_then(|x| x.clone());
|
||||||
|
match cancel_closure {
|
||||||
|
Some(cancel_closure) => {
|
||||||
|
NUM_CANCELLATION_REQUESTS
|
||||||
|
.with_label_values(&[from, "found"])
|
||||||
|
.inc();
|
||||||
|
cancel_closure.try_cancel_query().await
|
||||||
|
}
|
||||||
|
None => {
|
||||||
|
NUM_CANCELLATION_REQUESTS
|
||||||
|
.with_label_values(&[from, "not_found"])
|
||||||
|
.inc();
|
||||||
|
tracing::warn!("query cancellation key not found: {key}");
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -115,7 +183,7 @@ pub struct Session {
|
|||||||
/// The user-facing key identifying this session.
|
/// The user-facing key identifying this session.
|
||||||
key: CancelKeyData,
|
key: CancelKeyData,
|
||||||
/// The [`CancelMap`] this session belongs to.
|
/// The [`CancelMap`] this session belongs to.
|
||||||
cancel_map: Arc<CancelMap>,
|
cancellation_handler: Arc<CancellationHandler>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Session {
|
impl Session {
|
||||||
@@ -123,7 +191,9 @@ impl Session {
|
|||||||
/// This enables query cancellation in `crate::proxy::prepare_client_connection`.
|
/// This enables query cancellation in `crate::proxy::prepare_client_connection`.
|
||||||
pub fn enable_query_cancellation(&self, cancel_closure: CancelClosure) -> CancelKeyData {
|
pub fn enable_query_cancellation(&self, cancel_closure: CancelClosure) -> CancelKeyData {
|
||||||
info!("enabling query cancellation for this session");
|
info!("enabling query cancellation for this session");
|
||||||
self.cancel_map.0.insert(self.key, Some(cancel_closure));
|
self.cancellation_handler
|
||||||
|
.map
|
||||||
|
.insert(self.key, Some(cancel_closure));
|
||||||
|
|
||||||
self.key
|
self.key
|
||||||
}
|
}
|
||||||
@@ -131,7 +201,7 @@ impl Session {
|
|||||||
|
|
||||||
impl Drop for Session {
|
impl Drop for Session {
|
||||||
fn drop(&mut self) {
|
fn drop(&mut self) {
|
||||||
self.cancel_map.0.remove(&self.key);
|
self.cancellation_handler.map.remove(&self.key);
|
||||||
info!("dropped query cancellation key {}", &self.key);
|
info!("dropped query cancellation key {}", &self.key);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -142,13 +212,16 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn check_session_drop() -> anyhow::Result<()> {
|
async fn check_session_drop() -> anyhow::Result<()> {
|
||||||
let cancel_map: Arc<CancelMap> = Default::default();
|
let cancellation_handler = Arc::new(CancellationHandler {
|
||||||
|
map: CancelMap::default(),
|
||||||
|
redis_client: None,
|
||||||
|
});
|
||||||
|
|
||||||
let session = cancel_map.clone().get_session();
|
let session = cancellation_handler.clone().get_session();
|
||||||
assert!(cancel_map.contains(&session));
|
assert!(cancellation_handler.contains(&session));
|
||||||
drop(session);
|
drop(session);
|
||||||
// Check that the session has been dropped.
|
// Check that the session has been dropped.
|
||||||
assert!(cancel_map.is_empty());
|
assert!(cancellation_handler.is_empty());
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
use crate::{
|
use crate::{
|
||||||
auth::parse_endpoint_param,
|
auth::parse_endpoint_param,
|
||||||
cancellation::CancelClosure,
|
cancellation::CancelClosure,
|
||||||
console::errors::WakeComputeError,
|
console::{errors::WakeComputeError, messages::MetricsAuxInfo},
|
||||||
context::RequestMonitoring,
|
context::RequestMonitoring,
|
||||||
error::{ReportableError, UserFacingError},
|
error::{ReportableError, UserFacingError},
|
||||||
metrics::NUM_DB_CONNECTIONS_GAUGE,
|
metrics::NUM_DB_CONNECTIONS_GAUGE,
|
||||||
@@ -93,7 +93,7 @@ impl ConnCfg {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Reuse password or auth keys from the other config.
|
/// Reuse password or auth keys from the other config.
|
||||||
pub fn reuse_password(&mut self, other: &Self) {
|
pub fn reuse_password(&mut self, other: Self) {
|
||||||
if let Some(password) = other.get_password() {
|
if let Some(password) = other.get_password() {
|
||||||
self.password(password);
|
self.password(password);
|
||||||
}
|
}
|
||||||
@@ -253,6 +253,8 @@ pub struct PostgresConnection {
|
|||||||
pub params: std::collections::HashMap<String, String>,
|
pub params: std::collections::HashMap<String, String>,
|
||||||
/// Query cancellation token.
|
/// Query cancellation token.
|
||||||
pub cancel_closure: CancelClosure,
|
pub cancel_closure: CancelClosure,
|
||||||
|
/// Labels for proxy's metrics.
|
||||||
|
pub aux: MetricsAuxInfo,
|
||||||
|
|
||||||
_guage: IntCounterPairGuard,
|
_guage: IntCounterPairGuard,
|
||||||
}
|
}
|
||||||
@@ -263,6 +265,7 @@ impl ConnCfg {
|
|||||||
&self,
|
&self,
|
||||||
ctx: &mut RequestMonitoring,
|
ctx: &mut RequestMonitoring,
|
||||||
allow_self_signed_compute: bool,
|
allow_self_signed_compute: bool,
|
||||||
|
aux: MetricsAuxInfo,
|
||||||
timeout: Duration,
|
timeout: Duration,
|
||||||
) -> Result<PostgresConnection, ConnectionError> {
|
) -> Result<PostgresConnection, ConnectionError> {
|
||||||
let (socket_addr, stream, host) = self.connect_raw(timeout).await?;
|
let (socket_addr, stream, host) = self.connect_raw(timeout).await?;
|
||||||
@@ -297,6 +300,7 @@ impl ConnCfg {
|
|||||||
stream,
|
stream,
|
||||||
params,
|
params,
|
||||||
cancel_closure,
|
cancel_closure,
|
||||||
|
aux,
|
||||||
_guage: NUM_DB_CONNECTIONS_GAUGE
|
_guage: NUM_DB_CONNECTIONS_GAUGE
|
||||||
.with_label_values(&[ctx.protocol])
|
.with_label_values(&[ctx.protocol])
|
||||||
.guard(),
|
.guard(),
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ use x509_parser::oid_registry;
|
|||||||
|
|
||||||
pub struct ProxyConfig {
|
pub struct ProxyConfig {
|
||||||
pub tls_config: Option<TlsConfig>,
|
pub tls_config: Option<TlsConfig>,
|
||||||
pub auth_backend: auth::BackendType<'static, ()>,
|
pub auth_backend: auth::BackendType<'static, (), ()>,
|
||||||
pub metric_collection: Option<MetricCollectionConfig>,
|
pub metric_collection: Option<MetricCollectionConfig>,
|
||||||
pub allow_self_signed_compute: bool,
|
pub allow_self_signed_compute: bool,
|
||||||
pub http_config: HttpConfig,
|
pub http_config: HttpConfig,
|
||||||
@@ -21,6 +21,7 @@ pub struct ProxyConfig {
|
|||||||
pub require_client_ip: bool,
|
pub require_client_ip: bool,
|
||||||
pub disable_ip_check_for_http: bool,
|
pub disable_ip_check_for_http: bool,
|
||||||
pub endpoint_rps_limit: Vec<RateBucketInfo>,
|
pub endpoint_rps_limit: Vec<RateBucketInfo>,
|
||||||
|
pub redis_rps_limit: Vec<RateBucketInfo>,
|
||||||
pub region: String,
|
pub region: String,
|
||||||
pub handshake_timeout: Duration,
|
pub handshake_timeout: Duration,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,7 +4,10 @@ pub mod neon;
|
|||||||
|
|
||||||
use super::messages::MetricsAuxInfo;
|
use super::messages::MetricsAuxInfo;
|
||||||
use crate::{
|
use crate::{
|
||||||
auth::{backend::ComputeUserInfo, IpPattern},
|
auth::{
|
||||||
|
backend::{ComputeCredentialKeys, ComputeUserInfo},
|
||||||
|
IpPattern,
|
||||||
|
},
|
||||||
cache::{project_info::ProjectInfoCacheImpl, Cached, TimedLru},
|
cache::{project_info::ProjectInfoCacheImpl, Cached, TimedLru},
|
||||||
compute,
|
compute,
|
||||||
config::{CacheOptions, ProjectInfoCacheOptions},
|
config::{CacheOptions, ProjectInfoCacheOptions},
|
||||||
@@ -261,6 +264,34 @@ pub struct NodeInfo {
|
|||||||
pub allow_self_signed_compute: bool,
|
pub allow_self_signed_compute: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl NodeInfo {
|
||||||
|
pub async fn connect(
|
||||||
|
&self,
|
||||||
|
ctx: &mut RequestMonitoring,
|
||||||
|
timeout: Duration,
|
||||||
|
) -> Result<compute::PostgresConnection, compute::ConnectionError> {
|
||||||
|
self.config
|
||||||
|
.connect(
|
||||||
|
ctx,
|
||||||
|
self.allow_self_signed_compute,
|
||||||
|
self.aux.clone(),
|
||||||
|
timeout,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
pub fn reuse_settings(&mut self, other: Self) {
|
||||||
|
self.allow_self_signed_compute = other.allow_self_signed_compute;
|
||||||
|
self.config.reuse_password(other.config);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn set_keys(&mut self, keys: &ComputeCredentialKeys) {
|
||||||
|
match keys {
|
||||||
|
ComputeCredentialKeys::Password(password) => self.config.password(password),
|
||||||
|
ComputeCredentialKeys::AuthKeys(auth_keys) => self.config.auth_keys(*auth_keys),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub type NodeInfoCache = TimedLru<EndpointCacheKey, NodeInfo>;
|
pub type NodeInfoCache = TimedLru<EndpointCacheKey, NodeInfo>;
|
||||||
pub type CachedNodeInfo = Cached<&'static NodeInfoCache>;
|
pub type CachedNodeInfo = Cached<&'static NodeInfoCache>;
|
||||||
pub type CachedRoleSecret = Cached<&'static ProjectInfoCacheImpl, Option<AuthSecret>>;
|
pub type CachedRoleSecret = Cached<&'static ProjectInfoCacheImpl, Option<AuthSecret>>;
|
||||||
|
|||||||
@@ -176,9 +176,7 @@ impl super::Api for Api {
|
|||||||
_ctx: &mut RequestMonitoring,
|
_ctx: &mut RequestMonitoring,
|
||||||
_user_info: &ComputeUserInfo,
|
_user_info: &ComputeUserInfo,
|
||||||
) -> Result<CachedNodeInfo, WakeComputeError> {
|
) -> Result<CachedNodeInfo, WakeComputeError> {
|
||||||
self.do_wake_compute()
|
self.do_wake_compute().map_ok(Cached::new_uncached).await
|
||||||
.map_ok(CachedNodeInfo::new_uncached)
|
|
||||||
.await
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ use crate::{
|
|||||||
console::messages::MetricsAuxInfo,
|
console::messages::MetricsAuxInfo,
|
||||||
error::ErrorKind,
|
error::ErrorKind,
|
||||||
metrics::{LatencyTimer, ENDPOINT_ERRORS_BY_KIND, ERROR_BY_KIND},
|
metrics::{LatencyTimer, ENDPOINT_ERRORS_BY_KIND, ERROR_BY_KIND},
|
||||||
BranchId, EndpointId, ProjectId, RoleName,
|
BranchId, DbName, EndpointId, ProjectId, RoleName,
|
||||||
};
|
};
|
||||||
|
|
||||||
pub mod parquet;
|
pub mod parquet;
|
||||||
@@ -34,9 +34,11 @@ pub struct RequestMonitoring {
|
|||||||
project: Option<ProjectId>,
|
project: Option<ProjectId>,
|
||||||
branch: Option<BranchId>,
|
branch: Option<BranchId>,
|
||||||
endpoint_id: Option<EndpointId>,
|
endpoint_id: Option<EndpointId>,
|
||||||
|
dbname: Option<DbName>,
|
||||||
user: Option<RoleName>,
|
user: Option<RoleName>,
|
||||||
application: Option<SmolStr>,
|
application: Option<SmolStr>,
|
||||||
error_kind: Option<ErrorKind>,
|
error_kind: Option<ErrorKind>,
|
||||||
|
pub(crate) auth_method: Option<AuthMethod>,
|
||||||
success: bool,
|
success: bool,
|
||||||
|
|
||||||
// extra
|
// extra
|
||||||
@@ -45,6 +47,15 @@ pub struct RequestMonitoring {
|
|||||||
pub latency_timer: LatencyTimer,
|
pub latency_timer: LatencyTimer,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug)]
|
||||||
|
pub enum AuthMethod {
|
||||||
|
// aka link aka passwordless
|
||||||
|
Web,
|
||||||
|
ScramSha256,
|
||||||
|
ScramSha256Plus,
|
||||||
|
Cleartext,
|
||||||
|
}
|
||||||
|
|
||||||
impl RequestMonitoring {
|
impl RequestMonitoring {
|
||||||
pub fn new(
|
pub fn new(
|
||||||
session_id: Uuid,
|
session_id: Uuid,
|
||||||
@@ -62,9 +73,11 @@ impl RequestMonitoring {
|
|||||||
project: None,
|
project: None,
|
||||||
branch: None,
|
branch: None,
|
||||||
endpoint_id: None,
|
endpoint_id: None,
|
||||||
|
dbname: None,
|
||||||
user: None,
|
user: None,
|
||||||
application: None,
|
application: None,
|
||||||
error_kind: None,
|
error_kind: None,
|
||||||
|
auth_method: None,
|
||||||
success: false,
|
success: false,
|
||||||
|
|
||||||
sender: LOG_CHAN.get().and_then(|tx| tx.upgrade()),
|
sender: LOG_CHAN.get().and_then(|tx| tx.upgrade()),
|
||||||
@@ -106,10 +119,18 @@ impl RequestMonitoring {
|
|||||||
self.application = app.or_else(|| self.application.clone());
|
self.application = app.or_else(|| self.application.clone());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn set_dbname(&mut self, dbname: DbName) {
|
||||||
|
self.dbname = Some(dbname);
|
||||||
|
}
|
||||||
|
|
||||||
pub fn set_user(&mut self, user: RoleName) {
|
pub fn set_user(&mut self, user: RoleName) {
|
||||||
self.user = Some(user);
|
self.user = Some(user);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn set_auth_method(&mut self, auth_method: AuthMethod) {
|
||||||
|
self.auth_method = Some(auth_method);
|
||||||
|
}
|
||||||
|
|
||||||
pub fn set_error_kind(&mut self, kind: ErrorKind) {
|
pub fn set_error_kind(&mut self, kind: ErrorKind) {
|
||||||
ERROR_BY_KIND
|
ERROR_BY_KIND
|
||||||
.with_label_values(&[kind.to_metric_label()])
|
.with_label_values(&[kind.to_metric_label()])
|
||||||
|
|||||||
@@ -84,8 +84,10 @@ struct RequestData {
|
|||||||
username: Option<String>,
|
username: Option<String>,
|
||||||
application_name: Option<String>,
|
application_name: Option<String>,
|
||||||
endpoint_id: Option<String>,
|
endpoint_id: Option<String>,
|
||||||
|
database: Option<String>,
|
||||||
project: Option<String>,
|
project: Option<String>,
|
||||||
branch: Option<String>,
|
branch: Option<String>,
|
||||||
|
auth_method: Option<&'static str>,
|
||||||
error: Option<&'static str>,
|
error: Option<&'static str>,
|
||||||
/// Success is counted if we form a HTTP response with sql rows inside
|
/// Success is counted if we form a HTTP response with sql rows inside
|
||||||
/// Or if we make it to proxy_pass
|
/// Or if we make it to proxy_pass
|
||||||
@@ -104,8 +106,15 @@ impl From<RequestMonitoring> for RequestData {
|
|||||||
username: value.user.as_deref().map(String::from),
|
username: value.user.as_deref().map(String::from),
|
||||||
application_name: value.application.as_deref().map(String::from),
|
application_name: value.application.as_deref().map(String::from),
|
||||||
endpoint_id: value.endpoint_id.as_deref().map(String::from),
|
endpoint_id: value.endpoint_id.as_deref().map(String::from),
|
||||||
|
database: value.dbname.as_deref().map(String::from),
|
||||||
project: value.project.as_deref().map(String::from),
|
project: value.project.as_deref().map(String::from),
|
||||||
branch: value.branch.as_deref().map(String::from),
|
branch: value.branch.as_deref().map(String::from),
|
||||||
|
auth_method: value.auth_method.as_ref().map(|x| match x {
|
||||||
|
super::AuthMethod::Web => "web",
|
||||||
|
super::AuthMethod::ScramSha256 => "scram_sha_256",
|
||||||
|
super::AuthMethod::ScramSha256Plus => "scram_sha_256_plus",
|
||||||
|
super::AuthMethod::Cleartext => "cleartext",
|
||||||
|
}),
|
||||||
protocol: value.protocol,
|
protocol: value.protocol,
|
||||||
region: value.region,
|
region: value.region,
|
||||||
error: value.error_kind.as_ref().map(|e| e.to_metric_label()),
|
error: value.error_kind.as_ref().map(|e| e.to_metric_label()),
|
||||||
@@ -431,8 +440,10 @@ mod tests {
|
|||||||
application_name: Some("test".to_owned()),
|
application_name: Some("test".to_owned()),
|
||||||
username: Some(hex::encode(rng.gen::<[u8; 4]>())),
|
username: Some(hex::encode(rng.gen::<[u8; 4]>())),
|
||||||
endpoint_id: Some(hex::encode(rng.gen::<[u8; 16]>())),
|
endpoint_id: Some(hex::encode(rng.gen::<[u8; 16]>())),
|
||||||
|
database: Some(hex::encode(rng.gen::<[u8; 16]>())),
|
||||||
project: Some(hex::encode(rng.gen::<[u8; 16]>())),
|
project: Some(hex::encode(rng.gen::<[u8; 16]>())),
|
||||||
branch: Some(hex::encode(rng.gen::<[u8; 16]>())),
|
branch: Some(hex::encode(rng.gen::<[u8; 16]>())),
|
||||||
|
auth_method: None,
|
||||||
protocol: ["tcp", "ws", "http"][rng.gen_range(0..3)],
|
protocol: ["tcp", "ws", "http"][rng.gen_range(0..3)],
|
||||||
region: "us-east-1",
|
region: "us-east-1",
|
||||||
error: None,
|
error: None,
|
||||||
@@ -505,15 +516,15 @@ mod tests {
|
|||||||
assert_eq!(
|
assert_eq!(
|
||||||
file_stats,
|
file_stats,
|
||||||
[
|
[
|
||||||
(1087635, 3, 6000),
|
(1313727, 3, 6000),
|
||||||
(1087288, 3, 6000),
|
(1313720, 3, 6000),
|
||||||
(1087444, 3, 6000),
|
(1313780, 3, 6000),
|
||||||
(1087572, 3, 6000),
|
(1313737, 3, 6000),
|
||||||
(1087468, 3, 6000),
|
(1313867, 3, 6000),
|
||||||
(1087500, 3, 6000),
|
(1313709, 3, 6000),
|
||||||
(1087533, 3, 6000),
|
(1313501, 3, 6000),
|
||||||
(1087566, 3, 6000),
|
(1313737, 3, 6000),
|
||||||
(362671, 1, 2000)
|
(438118, 1, 2000)
|
||||||
],
|
],
|
||||||
);
|
);
|
||||||
|
|
||||||
@@ -543,11 +554,11 @@ mod tests {
|
|||||||
assert_eq!(
|
assert_eq!(
|
||||||
file_stats,
|
file_stats,
|
||||||
[
|
[
|
||||||
(1028637, 5, 10000),
|
(1219459, 5, 10000),
|
||||||
(1031969, 5, 10000),
|
(1225609, 5, 10000),
|
||||||
(1019900, 5, 10000),
|
(1227403, 5, 10000),
|
||||||
(1020365, 5, 10000),
|
(1226765, 5, 10000),
|
||||||
(1025010, 5, 10000)
|
(1218043, 5, 10000)
|
||||||
],
|
],
|
||||||
);
|
);
|
||||||
|
|
||||||
@@ -579,11 +590,11 @@ mod tests {
|
|||||||
assert_eq!(
|
assert_eq!(
|
||||||
file_stats,
|
file_stats,
|
||||||
[
|
[
|
||||||
(1210770, 6, 12000),
|
(1205106, 5, 10000),
|
||||||
(1211036, 6, 12000),
|
(1204837, 5, 10000),
|
||||||
(1210990, 6, 12000),
|
(1205130, 5, 10000),
|
||||||
(1210861, 6, 12000),
|
(1205118, 5, 10000),
|
||||||
(202073, 1, 2000)
|
(1205373, 5, 10000)
|
||||||
],
|
],
|
||||||
);
|
);
|
||||||
|
|
||||||
@@ -608,15 +619,15 @@ mod tests {
|
|||||||
assert_eq!(
|
assert_eq!(
|
||||||
file_stats,
|
file_stats,
|
||||||
[
|
[
|
||||||
(1087635, 3, 6000),
|
(1313727, 3, 6000),
|
||||||
(1087288, 3, 6000),
|
(1313720, 3, 6000),
|
||||||
(1087444, 3, 6000),
|
(1313780, 3, 6000),
|
||||||
(1087572, 3, 6000),
|
(1313737, 3, 6000),
|
||||||
(1087468, 3, 6000),
|
(1313867, 3, 6000),
|
||||||
(1087500, 3, 6000),
|
(1313709, 3, 6000),
|
||||||
(1087533, 3, 6000),
|
(1313501, 3, 6000),
|
||||||
(1087566, 3, 6000),
|
(1313737, 3, 6000),
|
||||||
(362671, 1, 2000)
|
(438118, 1, 2000)
|
||||||
],
|
],
|
||||||
);
|
);
|
||||||
|
|
||||||
@@ -653,7 +664,7 @@ mod tests {
|
|||||||
// files are smaller than the size threshold, but they took too long to fill so were flushed early
|
// files are smaller than the size threshold, but they took too long to fill so were flushed early
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
file_stats,
|
file_stats,
|
||||||
[(545264, 2, 3001), (545025, 2, 3000), (544857, 2, 2999)],
|
[(658383, 2, 3001), (658097, 2, 3000), (657893, 2, 2999)],
|
||||||
);
|
);
|
||||||
|
|
||||||
tmpdir.close().unwrap();
|
tmpdir.close().unwrap();
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ pub trait UserFacingError: ReportableError {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Copy, Clone, Debug)]
|
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
|
||||||
pub enum ErrorKind {
|
pub enum ErrorKind {
|
||||||
/// Wrong password, unknown endpoint, protocol violation, etc...
|
/// Wrong password, unknown endpoint, protocol violation, etc...
|
||||||
User,
|
User,
|
||||||
@@ -90,3 +90,13 @@ impl ReportableError for tokio::time::error::Elapsed {
|
|||||||
ErrorKind::RateLimit
|
ErrorKind::RateLimit
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl ReportableError for tokio_postgres::error::Error {
|
||||||
|
fn get_error_kind(&self) -> ErrorKind {
|
||||||
|
if self.as_db_error().is_some() {
|
||||||
|
ErrorKind::Postgres
|
||||||
|
} else {
|
||||||
|
ErrorKind::Compute
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -152,6 +152,15 @@ pub static NUM_OPEN_CLIENTS_IN_HTTP_POOL: Lazy<IntGauge> = Lazy::new(|| {
|
|||||||
.unwrap()
|
.unwrap()
|
||||||
});
|
});
|
||||||
|
|
||||||
|
pub static NUM_CANCELLATION_REQUESTS: Lazy<IntCounterVec> = Lazy::new(|| {
|
||||||
|
register_int_counter_vec!(
|
||||||
|
"proxy_cancellation_requests_total",
|
||||||
|
"Number of cancellation requests (per found/not_found).",
|
||||||
|
&["source", "kind"],
|
||||||
|
)
|
||||||
|
.unwrap()
|
||||||
|
});
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub struct LatencyTimer {
|
pub struct LatencyTimer {
|
||||||
// time since the stopwatch was started
|
// time since the stopwatch was started
|
||||||
@@ -200,8 +209,9 @@ impl LatencyTimer {
|
|||||||
|
|
||||||
pub fn success(&mut self) {
|
pub fn success(&mut self) {
|
||||||
// stop the stopwatch and record the time that we have accumulated
|
// stop the stopwatch and record the time that we have accumulated
|
||||||
let start = self.start.take().expect("latency timer should be started");
|
if let Some(start) = self.start.take() {
|
||||||
self.accumulated += start.elapsed();
|
self.accumulated += start.elapsed();
|
||||||
|
}
|
||||||
|
|
||||||
// success
|
// success
|
||||||
self.outcome = "success";
|
self.outcome = "success";
|
||||||
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
mod tests;
|
mod tests;
|
||||||
|
|
||||||
pub mod connect_compute;
|
pub mod connect_compute;
|
||||||
|
mod copy_bidirectional;
|
||||||
pub mod handshake;
|
pub mod handshake;
|
||||||
pub mod passthrough;
|
pub mod passthrough;
|
||||||
pub mod retry;
|
pub mod retry;
|
||||||
@@ -9,7 +10,7 @@ pub mod wake_compute;
|
|||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
auth,
|
auth,
|
||||||
cancellation::{self, CancelMap},
|
cancellation::{self, CancellationHandler},
|
||||||
compute,
|
compute,
|
||||||
config::{ProxyConfig, TlsConfig},
|
config::{ProxyConfig, TlsConfig},
|
||||||
context::RequestMonitoring,
|
context::RequestMonitoring,
|
||||||
@@ -61,6 +62,7 @@ pub async fn task_main(
|
|||||||
listener: tokio::net::TcpListener,
|
listener: tokio::net::TcpListener,
|
||||||
cancellation_token: CancellationToken,
|
cancellation_token: CancellationToken,
|
||||||
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||||
|
cancellation_handler: Arc<CancellationHandler>,
|
||||||
) -> anyhow::Result<()> {
|
) -> anyhow::Result<()> {
|
||||||
scopeguard::defer! {
|
scopeguard::defer! {
|
||||||
info!("proxy has shut down");
|
info!("proxy has shut down");
|
||||||
@@ -71,7 +73,6 @@ pub async fn task_main(
|
|||||||
socket2::SockRef::from(&listener).set_keepalive(true)?;
|
socket2::SockRef::from(&listener).set_keepalive(true)?;
|
||||||
|
|
||||||
let connections = tokio_util::task::task_tracker::TaskTracker::new();
|
let connections = tokio_util::task::task_tracker::TaskTracker::new();
|
||||||
let cancel_map = Arc::new(CancelMap::default());
|
|
||||||
|
|
||||||
while let Some(accept_result) =
|
while let Some(accept_result) =
|
||||||
run_until_cancelled(listener.accept(), &cancellation_token).await
|
run_until_cancelled(listener.accept(), &cancellation_token).await
|
||||||
@@ -79,7 +80,7 @@ pub async fn task_main(
|
|||||||
let (socket, peer_addr) = accept_result?;
|
let (socket, peer_addr) = accept_result?;
|
||||||
|
|
||||||
let session_id = uuid::Uuid::new_v4();
|
let session_id = uuid::Uuid::new_v4();
|
||||||
let cancel_map = Arc::clone(&cancel_map);
|
let cancellation_handler = Arc::clone(&cancellation_handler);
|
||||||
let endpoint_rate_limiter = endpoint_rate_limiter.clone();
|
let endpoint_rate_limiter = endpoint_rate_limiter.clone();
|
||||||
|
|
||||||
let session_span = info_span!(
|
let session_span = info_span!(
|
||||||
@@ -112,7 +113,7 @@ pub async fn task_main(
|
|||||||
let res = handle_client(
|
let res = handle_client(
|
||||||
config,
|
config,
|
||||||
&mut ctx,
|
&mut ctx,
|
||||||
cancel_map,
|
cancellation_handler,
|
||||||
socket,
|
socket,
|
||||||
ClientMode::Tcp,
|
ClientMode::Tcp,
|
||||||
endpoint_rate_limiter,
|
endpoint_rate_limiter,
|
||||||
@@ -162,14 +163,14 @@ pub enum ClientMode {
|
|||||||
|
|
||||||
/// Abstracts the logic of handling TCP vs WS clients
|
/// Abstracts the logic of handling TCP vs WS clients
|
||||||
impl ClientMode {
|
impl ClientMode {
|
||||||
fn allow_cleartext(&self) -> bool {
|
pub fn allow_cleartext(&self) -> bool {
|
||||||
match self {
|
match self {
|
||||||
ClientMode::Tcp => false,
|
ClientMode::Tcp => false,
|
||||||
ClientMode::Websockets { .. } => true,
|
ClientMode::Websockets { .. } => true,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn allow_self_signed_compute(&self, config: &ProxyConfig) -> bool {
|
pub fn allow_self_signed_compute(&self, config: &ProxyConfig) -> bool {
|
||||||
match self {
|
match self {
|
||||||
ClientMode::Tcp => config.allow_self_signed_compute,
|
ClientMode::Tcp => config.allow_self_signed_compute,
|
||||||
ClientMode::Websockets { .. } => false,
|
ClientMode::Websockets { .. } => false,
|
||||||
@@ -226,7 +227,7 @@ impl ReportableError for ClientRequestError {
|
|||||||
pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
|
pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
|
||||||
config: &'static ProxyConfig,
|
config: &'static ProxyConfig,
|
||||||
ctx: &mut RequestMonitoring,
|
ctx: &mut RequestMonitoring,
|
||||||
cancel_map: Arc<CancelMap>,
|
cancellation_handler: Arc<CancellationHandler>,
|
||||||
stream: S,
|
stream: S,
|
||||||
mode: ClientMode,
|
mode: ClientMode,
|
||||||
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||||
@@ -252,8 +253,8 @@ pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
|
|||||||
match tokio::time::timeout(config.handshake_timeout, do_handshake).await?? {
|
match tokio::time::timeout(config.handshake_timeout, do_handshake).await?? {
|
||||||
HandshakeData::Startup(stream, params) => (stream, params),
|
HandshakeData::Startup(stream, params) => (stream, params),
|
||||||
HandshakeData::Cancel(cancel_key_data) => {
|
HandshakeData::Cancel(cancel_key_data) => {
|
||||||
return Ok(cancel_map
|
return Ok(cancellation_handler
|
||||||
.cancel_session(cancel_key_data)
|
.cancel_session(cancel_key_data, ctx.session_id)
|
||||||
.await
|
.await
|
||||||
.map(|()| None)?)
|
.map(|()| None)?)
|
||||||
}
|
}
|
||||||
@@ -286,7 +287,7 @@ pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
|
|||||||
}
|
}
|
||||||
|
|
||||||
let user = user_info.get_user().to_owned();
|
let user = user_info.get_user().to_owned();
|
||||||
let (mut node_info, user_info) = match user_info
|
let user_info = match user_info
|
||||||
.authenticate(
|
.authenticate(
|
||||||
ctx,
|
ctx,
|
||||||
&mut stream,
|
&mut stream,
|
||||||
@@ -305,19 +306,16 @@ pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
node_info.allow_self_signed_compute = mode.allow_self_signed_compute(config);
|
|
||||||
|
|
||||||
let aux = node_info.aux.clone();
|
|
||||||
let mut node = connect_to_compute(
|
let mut node = connect_to_compute(
|
||||||
ctx,
|
ctx,
|
||||||
&TcpMechanism { params: ¶ms },
|
&TcpMechanism { params: ¶ms },
|
||||||
node_info,
|
|
||||||
&user_info,
|
&user_info,
|
||||||
|
mode.allow_self_signed_compute(config),
|
||||||
)
|
)
|
||||||
.or_else(|e| stream.throw_error(e))
|
.or_else(|e| stream.throw_error(e))
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
let session = cancel_map.get_session();
|
let session = cancellation_handler.get_session();
|
||||||
prepare_client_connection(&node, &session, &mut stream).await?;
|
prepare_client_connection(&node, &session, &mut stream).await?;
|
||||||
|
|
||||||
// Before proxy passing, forward to compute whatever data is left in the
|
// Before proxy passing, forward to compute whatever data is left in the
|
||||||
@@ -329,10 +327,11 @@ pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
|
|||||||
|
|
||||||
Ok(Some(ProxyPassthrough {
|
Ok(Some(ProxyPassthrough {
|
||||||
client: stream,
|
client: stream,
|
||||||
|
aux: node.aux.clone(),
|
||||||
compute: node,
|
compute: node,
|
||||||
aux,
|
|
||||||
req: _request_gauge,
|
req: _request_gauge,
|
||||||
conn: _client_gauge,
|
conn: _client_gauge,
|
||||||
|
cancel: session,
|
||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,8 +1,9 @@
|
|||||||
use crate::{
|
use crate::{
|
||||||
auth,
|
auth::backend::ComputeCredentialKeys,
|
||||||
compute::{self, PostgresConnection},
|
compute::{self, PostgresConnection},
|
||||||
console::{self, errors::WakeComputeError},
|
console::{self, errors::WakeComputeError, CachedNodeInfo, NodeInfo},
|
||||||
context::RequestMonitoring,
|
context::RequestMonitoring,
|
||||||
|
error::ReportableError,
|
||||||
metrics::NUM_CONNECTION_FAILURES,
|
metrics::NUM_CONNECTION_FAILURES,
|
||||||
proxy::{
|
proxy::{
|
||||||
retry::{retry_after, ShouldRetry},
|
retry::{retry_after, ShouldRetry},
|
||||||
@@ -20,7 +21,7 @@ const CONNECT_TIMEOUT: time::Duration = time::Duration::from_secs(2);
|
|||||||
/// (e.g. the compute node's address might've changed at the wrong time).
|
/// (e.g. the compute node's address might've changed at the wrong time).
|
||||||
/// Invalidate the cache entry (if any) to prevent subsequent errors.
|
/// Invalidate the cache entry (if any) to prevent subsequent errors.
|
||||||
#[tracing::instrument(name = "invalidate_cache", skip_all)]
|
#[tracing::instrument(name = "invalidate_cache", skip_all)]
|
||||||
pub fn invalidate_cache(node_info: console::CachedNodeInfo) -> compute::ConnCfg {
|
pub fn invalidate_cache(node_info: console::CachedNodeInfo) -> NodeInfo {
|
||||||
let is_cached = node_info.cached();
|
let is_cached = node_info.cached();
|
||||||
if is_cached {
|
if is_cached {
|
||||||
warn!("invalidating stalled compute node info cache entry");
|
warn!("invalidating stalled compute node info cache entry");
|
||||||
@@ -31,13 +32,13 @@ pub fn invalidate_cache(node_info: console::CachedNodeInfo) -> compute::ConnCfg
|
|||||||
};
|
};
|
||||||
NUM_CONNECTION_FAILURES.with_label_values(&[label]).inc();
|
NUM_CONNECTION_FAILURES.with_label_values(&[label]).inc();
|
||||||
|
|
||||||
node_info.invalidate().config
|
node_info.invalidate()
|
||||||
}
|
}
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
pub trait ConnectMechanism {
|
pub trait ConnectMechanism {
|
||||||
type Connection;
|
type Connection;
|
||||||
type ConnectError;
|
type ConnectError: ReportableError;
|
||||||
type Error: From<Self::ConnectError>;
|
type Error: From<Self::ConnectError>;
|
||||||
async fn connect_once(
|
async fn connect_once(
|
||||||
&self,
|
&self,
|
||||||
@@ -49,6 +50,16 @@ pub trait ConnectMechanism {
|
|||||||
fn update_connect_config(&self, conf: &mut compute::ConnCfg);
|
fn update_connect_config(&self, conf: &mut compute::ConnCfg);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
pub trait ComputeConnectBackend {
|
||||||
|
async fn wake_compute(
|
||||||
|
&self,
|
||||||
|
ctx: &mut RequestMonitoring,
|
||||||
|
) -> Result<CachedNodeInfo, console::errors::WakeComputeError>;
|
||||||
|
|
||||||
|
fn get_keys(&self) -> Option<&ComputeCredentialKeys>;
|
||||||
|
}
|
||||||
|
|
||||||
pub struct TcpMechanism<'a> {
|
pub struct TcpMechanism<'a> {
|
||||||
/// KV-dictionary with PostgreSQL connection params.
|
/// KV-dictionary with PostgreSQL connection params.
|
||||||
pub params: &'a StartupMessageParams,
|
pub params: &'a StartupMessageParams,
|
||||||
@@ -67,11 +78,7 @@ impl ConnectMechanism for TcpMechanism<'_> {
|
|||||||
node_info: &console::CachedNodeInfo,
|
node_info: &console::CachedNodeInfo,
|
||||||
timeout: time::Duration,
|
timeout: time::Duration,
|
||||||
) -> Result<PostgresConnection, Self::Error> {
|
) -> Result<PostgresConnection, Self::Error> {
|
||||||
let allow_self_signed_compute = node_info.allow_self_signed_compute;
|
node_info.connect(ctx, timeout).await
|
||||||
node_info
|
|
||||||
.config
|
|
||||||
.connect(ctx, allow_self_signed_compute, timeout)
|
|
||||||
.await
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn update_connect_config(&self, config: &mut compute::ConnCfg) {
|
fn update_connect_config(&self, config: &mut compute::ConnCfg) {
|
||||||
@@ -82,16 +89,23 @@ impl ConnectMechanism for TcpMechanism<'_> {
|
|||||||
/// Try to connect to the compute node, retrying if necessary.
|
/// Try to connect to the compute node, retrying if necessary.
|
||||||
/// This function might update `node_info`, so we take it by `&mut`.
|
/// This function might update `node_info`, so we take it by `&mut`.
|
||||||
#[tracing::instrument(skip_all)]
|
#[tracing::instrument(skip_all)]
|
||||||
pub async fn connect_to_compute<M: ConnectMechanism>(
|
pub async fn connect_to_compute<M: ConnectMechanism, B: ComputeConnectBackend>(
|
||||||
ctx: &mut RequestMonitoring,
|
ctx: &mut RequestMonitoring,
|
||||||
mechanism: &M,
|
mechanism: &M,
|
||||||
mut node_info: console::CachedNodeInfo,
|
user_info: &B,
|
||||||
user_info: &auth::BackendType<'_, auth::backend::ComputeUserInfo>,
|
allow_self_signed_compute: bool,
|
||||||
) -> Result<M::Connection, M::Error>
|
) -> Result<M::Connection, M::Error>
|
||||||
where
|
where
|
||||||
M::ConnectError: ShouldRetry + std::fmt::Debug,
|
M::ConnectError: ShouldRetry + std::fmt::Debug,
|
||||||
M::Error: From<WakeComputeError>,
|
M::Error: From<WakeComputeError>,
|
||||||
{
|
{
|
||||||
|
let mut num_retries = 0;
|
||||||
|
let mut node_info = wake_compute(&mut num_retries, ctx, user_info).await?;
|
||||||
|
if let Some(keys) = user_info.get_keys() {
|
||||||
|
node_info.set_keys(keys);
|
||||||
|
}
|
||||||
|
node_info.allow_self_signed_compute = allow_self_signed_compute;
|
||||||
|
// let mut node_info = credentials.get_node_info(ctx, user_info).await?;
|
||||||
mechanism.update_connect_config(&mut node_info.config);
|
mechanism.update_connect_config(&mut node_info.config);
|
||||||
|
|
||||||
// try once
|
// try once
|
||||||
@@ -108,28 +122,30 @@ where
|
|||||||
|
|
||||||
error!(error = ?err, "could not connect to compute node");
|
error!(error = ?err, "could not connect to compute node");
|
||||||
|
|
||||||
let mut num_retries = 1;
|
let node_info = if !node_info.cached() {
|
||||||
|
// If we just recieved this from cplane and dodn't get it from cache, we shouldn't retry.
|
||||||
match user_info {
|
// Do not need to retrieve a new node_info, just return the old one.
|
||||||
auth::BackendType::Console(api, info) => {
|
if !err.should_retry(num_retries) {
|
||||||
// if we failed to connect, it's likely that the compute node was suspended, wake a new compute node
|
return Err(err.into());
|
||||||
info!("compute node's state has likely changed; requesting a wake-up");
|
|
||||||
|
|
||||||
ctx.latency_timer.cache_miss();
|
|
||||||
let config = invalidate_cache(node_info);
|
|
||||||
node_info = wake_compute(&mut num_retries, ctx, api, info).await?;
|
|
||||||
|
|
||||||
node_info.config.reuse_password(&config);
|
|
||||||
mechanism.update_connect_config(&mut node_info.config);
|
|
||||||
}
|
}
|
||||||
// nothing to do?
|
node_info
|
||||||
auth::BackendType::Link(_) => {}
|
} else {
|
||||||
|
// if we failed to connect, it's likely that the compute node was suspended, wake a new compute node
|
||||||
|
info!("compute node's state has likely changed; requesting a wake-up");
|
||||||
|
ctx.latency_timer.cache_miss();
|
||||||
|
let old_node_info = invalidate_cache(node_info);
|
||||||
|
let mut node_info = wake_compute(&mut num_retries, ctx, user_info).await?;
|
||||||
|
node_info.reuse_settings(old_node_info);
|
||||||
|
|
||||||
|
mechanism.update_connect_config(&mut node_info.config);
|
||||||
|
node_info
|
||||||
};
|
};
|
||||||
|
|
||||||
// now that we have a new node, try connect to it repeatedly.
|
// now that we have a new node, try connect to it repeatedly.
|
||||||
// this can error for a few reasons, for instance:
|
// this can error for a few reasons, for instance:
|
||||||
// * DNS connection settings haven't quite propagated yet
|
// * DNS connection settings haven't quite propagated yet
|
||||||
info!("wake_compute success. attempting to connect");
|
info!("wake_compute success. attempting to connect");
|
||||||
|
num_retries = 1;
|
||||||
loop {
|
loop {
|
||||||
match mechanism
|
match mechanism
|
||||||
.connect_once(ctx, &node_info, CONNECT_TIMEOUT)
|
.connect_once(ctx, &node_info, CONNECT_TIMEOUT)
|
||||||
|
|||||||
256
proxy/src/proxy/copy_bidirectional.rs
Normal file
256
proxy/src/proxy/copy_bidirectional.rs
Normal file
@@ -0,0 +1,256 @@
|
|||||||
|
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
|
||||||
|
|
||||||
|
use std::future::poll_fn;
|
||||||
|
use std::io;
|
||||||
|
use std::pin::Pin;
|
||||||
|
use std::task::{ready, Context, Poll};
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
enum TransferState {
|
||||||
|
Running(CopyBuffer),
|
||||||
|
ShuttingDown(u64),
|
||||||
|
Done(u64),
|
||||||
|
}
|
||||||
|
|
||||||
|
fn transfer_one_direction<A, B>(
|
||||||
|
cx: &mut Context<'_>,
|
||||||
|
state: &mut TransferState,
|
||||||
|
r: &mut A,
|
||||||
|
w: &mut B,
|
||||||
|
) -> Poll<io::Result<u64>>
|
||||||
|
where
|
||||||
|
A: AsyncRead + AsyncWrite + Unpin + ?Sized,
|
||||||
|
B: AsyncRead + AsyncWrite + Unpin + ?Sized,
|
||||||
|
{
|
||||||
|
let mut r = Pin::new(r);
|
||||||
|
let mut w = Pin::new(w);
|
||||||
|
loop {
|
||||||
|
match state {
|
||||||
|
TransferState::Running(buf) => {
|
||||||
|
let count = ready!(buf.poll_copy(cx, r.as_mut(), w.as_mut()))?;
|
||||||
|
*state = TransferState::ShuttingDown(count);
|
||||||
|
}
|
||||||
|
TransferState::ShuttingDown(count) => {
|
||||||
|
ready!(w.as_mut().poll_shutdown(cx))?;
|
||||||
|
*state = TransferState::Done(*count);
|
||||||
|
}
|
||||||
|
TransferState::Done(count) => return Poll::Ready(Ok(*count)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(super) async fn copy_bidirectional<A, B>(
|
||||||
|
a: &mut A,
|
||||||
|
b: &mut B,
|
||||||
|
) -> Result<(u64, u64), std::io::Error>
|
||||||
|
where
|
||||||
|
A: AsyncRead + AsyncWrite + Unpin + ?Sized,
|
||||||
|
B: AsyncRead + AsyncWrite + Unpin + ?Sized,
|
||||||
|
{
|
||||||
|
let mut a_to_b = TransferState::Running(CopyBuffer::new());
|
||||||
|
let mut b_to_a = TransferState::Running(CopyBuffer::new());
|
||||||
|
|
||||||
|
poll_fn(|cx| {
|
||||||
|
let mut a_to_b_result = transfer_one_direction(cx, &mut a_to_b, a, b)?;
|
||||||
|
let mut b_to_a_result = transfer_one_direction(cx, &mut b_to_a, b, a)?;
|
||||||
|
|
||||||
|
// Early termination checks
|
||||||
|
if let TransferState::Done(_) = a_to_b {
|
||||||
|
if let TransferState::Running(buf) = &b_to_a {
|
||||||
|
// Initiate shutdown
|
||||||
|
b_to_a = TransferState::ShuttingDown(buf.amt);
|
||||||
|
b_to_a_result = transfer_one_direction(cx, &mut b_to_a, b, a)?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if let TransferState::Done(_) = b_to_a {
|
||||||
|
if let TransferState::Running(buf) = &a_to_b {
|
||||||
|
// Initiate shutdown
|
||||||
|
a_to_b = TransferState::ShuttingDown(buf.amt);
|
||||||
|
a_to_b_result = transfer_one_direction(cx, &mut a_to_b, a, b)?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// It is not a problem if ready! returns early ... (comment remains the same)
|
||||||
|
let a_to_b = ready!(a_to_b_result);
|
||||||
|
let b_to_a = ready!(b_to_a_result);
|
||||||
|
|
||||||
|
Poll::Ready(Ok((a_to_b, b_to_a)))
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub(super) struct CopyBuffer {
|
||||||
|
read_done: bool,
|
||||||
|
need_flush: bool,
|
||||||
|
pos: usize,
|
||||||
|
cap: usize,
|
||||||
|
amt: u64,
|
||||||
|
buf: Box<[u8]>,
|
||||||
|
}
|
||||||
|
const DEFAULT_BUF_SIZE: usize = 8 * 1024;
|
||||||
|
|
||||||
|
impl CopyBuffer {
|
||||||
|
pub(super) fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
read_done: false,
|
||||||
|
need_flush: false,
|
||||||
|
pos: 0,
|
||||||
|
cap: 0,
|
||||||
|
amt: 0,
|
||||||
|
buf: vec![0; DEFAULT_BUF_SIZE].into_boxed_slice(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_fill_buf<R>(
|
||||||
|
&mut self,
|
||||||
|
cx: &mut Context<'_>,
|
||||||
|
reader: Pin<&mut R>,
|
||||||
|
) -> Poll<io::Result<()>>
|
||||||
|
where
|
||||||
|
R: AsyncRead + ?Sized,
|
||||||
|
{
|
||||||
|
let me = &mut *self;
|
||||||
|
let mut buf = ReadBuf::new(&mut me.buf);
|
||||||
|
buf.set_filled(me.cap);
|
||||||
|
|
||||||
|
let res = reader.poll_read(cx, &mut buf);
|
||||||
|
if let Poll::Ready(Ok(())) = res {
|
||||||
|
let filled_len = buf.filled().len();
|
||||||
|
me.read_done = me.cap == filled_len;
|
||||||
|
me.cap = filled_len;
|
||||||
|
}
|
||||||
|
res
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_write_buf<R, W>(
|
||||||
|
&mut self,
|
||||||
|
cx: &mut Context<'_>,
|
||||||
|
mut reader: Pin<&mut R>,
|
||||||
|
mut writer: Pin<&mut W>,
|
||||||
|
) -> Poll<io::Result<usize>>
|
||||||
|
where
|
||||||
|
R: AsyncRead + ?Sized,
|
||||||
|
W: AsyncWrite + ?Sized,
|
||||||
|
{
|
||||||
|
let me = &mut *self;
|
||||||
|
match writer.as_mut().poll_write(cx, &me.buf[me.pos..me.cap]) {
|
||||||
|
Poll::Pending => {
|
||||||
|
// Top up the buffer towards full if we can read a bit more
|
||||||
|
// data - this should improve the chances of a large write
|
||||||
|
if !me.read_done && me.cap < me.buf.len() {
|
||||||
|
ready!(me.poll_fill_buf(cx, reader.as_mut()))?;
|
||||||
|
}
|
||||||
|
Poll::Pending
|
||||||
|
}
|
||||||
|
res => res,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(super) fn poll_copy<R, W>(
|
||||||
|
&mut self,
|
||||||
|
cx: &mut Context<'_>,
|
||||||
|
mut reader: Pin<&mut R>,
|
||||||
|
mut writer: Pin<&mut W>,
|
||||||
|
) -> Poll<io::Result<u64>>
|
||||||
|
where
|
||||||
|
R: AsyncRead + ?Sized,
|
||||||
|
W: AsyncWrite + ?Sized,
|
||||||
|
{
|
||||||
|
loop {
|
||||||
|
// If our buffer is empty, then we need to read some data to
|
||||||
|
// continue.
|
||||||
|
if self.pos == self.cap && !self.read_done {
|
||||||
|
self.pos = 0;
|
||||||
|
self.cap = 0;
|
||||||
|
|
||||||
|
match self.poll_fill_buf(cx, reader.as_mut()) {
|
||||||
|
Poll::Ready(Ok(())) => (),
|
||||||
|
Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
|
||||||
|
Poll::Pending => {
|
||||||
|
// Try flushing when the reader has no progress to avoid deadlock
|
||||||
|
// when the reader depends on buffered writer.
|
||||||
|
if self.need_flush {
|
||||||
|
ready!(writer.as_mut().poll_flush(cx))?;
|
||||||
|
self.need_flush = false;
|
||||||
|
}
|
||||||
|
|
||||||
|
return Poll::Pending;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If our buffer has some data, let's write it out!
|
||||||
|
while self.pos < self.cap {
|
||||||
|
let i = ready!(self.poll_write_buf(cx, reader.as_mut(), writer.as_mut()))?;
|
||||||
|
if i == 0 {
|
||||||
|
return Poll::Ready(Err(io::Error::new(
|
||||||
|
io::ErrorKind::WriteZero,
|
||||||
|
"write zero byte into writer",
|
||||||
|
)));
|
||||||
|
} else {
|
||||||
|
self.pos += i;
|
||||||
|
self.amt += i as u64;
|
||||||
|
self.need_flush = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If pos larger than cap, this loop will never stop.
|
||||||
|
// In particular, user's wrong poll_write implementation returning
|
||||||
|
// incorrect written length may lead to thread blocking.
|
||||||
|
debug_assert!(
|
||||||
|
self.pos <= self.cap,
|
||||||
|
"writer returned length larger than input slice"
|
||||||
|
);
|
||||||
|
|
||||||
|
// If we've written all the data and we've seen EOF, flush out the
|
||||||
|
// data and finish the transfer.
|
||||||
|
if self.pos == self.cap && self.read_done {
|
||||||
|
ready!(writer.as_mut().poll_flush(cx))?;
|
||||||
|
return Poll::Ready(Ok(self.amt));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use tokio::io::AsyncWriteExt;
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_early_termination_a_to_d() {
|
||||||
|
let (mut a_mock, mut b_mock) = tokio::io::duplex(8); // Create a mock duplex stream
|
||||||
|
let (mut c_mock, mut d_mock) = tokio::io::duplex(32); // Create a mock duplex stream
|
||||||
|
|
||||||
|
// Simulate 'a' finishing while there's still data for 'b'
|
||||||
|
a_mock.write_all(b"hello").await.unwrap();
|
||||||
|
a_mock.shutdown().await.unwrap();
|
||||||
|
d_mock.write_all(b"Neon Serverless Postgres").await.unwrap();
|
||||||
|
|
||||||
|
let result = copy_bidirectional(&mut b_mock, &mut c_mock).await.unwrap();
|
||||||
|
|
||||||
|
// Assert correct transferred amounts
|
||||||
|
let (a_to_d_count, d_to_a_count) = result;
|
||||||
|
assert_eq!(a_to_d_count, 5); // 'hello' was transferred
|
||||||
|
assert!(d_to_a_count <= 8); // response only partially transferred or not at all
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_early_termination_d_to_a() {
|
||||||
|
let (mut a_mock, mut b_mock) = tokio::io::duplex(32); // Create a mock duplex stream
|
||||||
|
let (mut c_mock, mut d_mock) = tokio::io::duplex(8); // Create a mock duplex stream
|
||||||
|
|
||||||
|
// Simulate 'a' finishing while there's still data for 'b'
|
||||||
|
d_mock.write_all(b"hello").await.unwrap();
|
||||||
|
d_mock.shutdown().await.unwrap();
|
||||||
|
a_mock.write_all(b"Neon Serverless Postgres").await.unwrap();
|
||||||
|
|
||||||
|
let result = copy_bidirectional(&mut b_mock, &mut c_mock).await.unwrap();
|
||||||
|
|
||||||
|
// Assert correct transferred amounts
|
||||||
|
let (a_to_d_count, d_to_a_count) = result;
|
||||||
|
assert_eq!(d_to_a_count, 5); // 'hello' was transferred
|
||||||
|
assert!(a_to_d_count <= 8); // response only partially transferred or not at all
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,4 +1,5 @@
|
|||||||
use crate::{
|
use crate::{
|
||||||
|
cancellation,
|
||||||
compute::PostgresConnection,
|
compute::PostgresConnection,
|
||||||
console::messages::MetricsAuxInfo,
|
console::messages::MetricsAuxInfo,
|
||||||
metrics::NUM_BYTES_PROXIED_COUNTER,
|
metrics::NUM_BYTES_PROXIED_COUNTER,
|
||||||
@@ -45,7 +46,7 @@ pub async fn proxy_pass(
|
|||||||
|
|
||||||
// Starting from here we only proxy the client's traffic.
|
// Starting from here we only proxy the client's traffic.
|
||||||
info!("performing the proxy pass...");
|
info!("performing the proxy pass...");
|
||||||
let _ = tokio::io::copy_bidirectional(&mut client, &mut compute).await?;
|
let _ = crate::proxy::copy_bidirectional::copy_bidirectional(&mut client, &mut compute).await?;
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
@@ -57,6 +58,7 @@ pub struct ProxyPassthrough<S> {
|
|||||||
|
|
||||||
pub req: IntCounterPairGuard,
|
pub req: IntCounterPairGuard,
|
||||||
pub conn: IntCounterPairGuard,
|
pub conn: IntCounterPairGuard,
|
||||||
|
pub cancel: cancellation::Session,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<S: AsyncRead + AsyncWrite + Unpin> ProxyPassthrough<S> {
|
impl<S: AsyncRead + AsyncWrite + Unpin> ProxyPassthrough<S> {
|
||||||
|
|||||||
@@ -2,13 +2,19 @@
|
|||||||
|
|
||||||
mod mitm;
|
mod mitm;
|
||||||
|
|
||||||
|
use std::time::Duration;
|
||||||
|
|
||||||
use super::connect_compute::ConnectMechanism;
|
use super::connect_compute::ConnectMechanism;
|
||||||
use super::retry::ShouldRetry;
|
use super::retry::ShouldRetry;
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::auth::backend::{ComputeUserInfo, MaybeOwned, TestBackend};
|
use crate::auth::backend::{
|
||||||
|
ComputeCredentialKeys, ComputeCredentials, ComputeUserInfo, MaybeOwned, TestBackend,
|
||||||
|
};
|
||||||
use crate::config::CertResolver;
|
use crate::config::CertResolver;
|
||||||
|
use crate::console::caches::NodeInfoCache;
|
||||||
use crate::console::provider::{CachedAllowedIps, CachedRoleSecret, ConsoleBackend};
|
use crate::console::provider::{CachedAllowedIps, CachedRoleSecret, ConsoleBackend};
|
||||||
use crate::console::{self, CachedNodeInfo, NodeInfo};
|
use crate::console::{self, CachedNodeInfo, NodeInfo};
|
||||||
|
use crate::error::ErrorKind;
|
||||||
use crate::proxy::retry::{retry_after, NUM_RETRIES_CONNECT};
|
use crate::proxy::retry::{retry_after, NUM_RETRIES_CONNECT};
|
||||||
use crate::{auth, http, sasl, scram};
|
use crate::{auth, http, sasl, scram};
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
@@ -144,7 +150,7 @@ impl TestAuth for Scram {
|
|||||||
stream: &mut PqStream<Stream<S>>,
|
stream: &mut PqStream<Stream<S>>,
|
||||||
) -> anyhow::Result<()> {
|
) -> anyhow::Result<()> {
|
||||||
let outcome = auth::AuthFlow::new(stream)
|
let outcome = auth::AuthFlow::new(stream)
|
||||||
.begin(auth::Scram(&self.0))
|
.begin(auth::Scram(&self.0, &mut RequestMonitoring::test()))
|
||||||
.await?
|
.await?
|
||||||
.authenticate()
|
.authenticate()
|
||||||
.await?;
|
.await?;
|
||||||
@@ -375,6 +381,7 @@ enum ConnectAction {
|
|||||||
struct TestConnectMechanism {
|
struct TestConnectMechanism {
|
||||||
counter: Arc<std::sync::Mutex<usize>>,
|
counter: Arc<std::sync::Mutex<usize>>,
|
||||||
sequence: Vec<ConnectAction>,
|
sequence: Vec<ConnectAction>,
|
||||||
|
cache: &'static NodeInfoCache,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl TestConnectMechanism {
|
impl TestConnectMechanism {
|
||||||
@@ -393,6 +400,12 @@ impl TestConnectMechanism {
|
|||||||
Self {
|
Self {
|
||||||
counter: Arc::new(std::sync::Mutex::new(0)),
|
counter: Arc::new(std::sync::Mutex::new(0)),
|
||||||
sequence,
|
sequence,
|
||||||
|
cache: Box::leak(Box::new(NodeInfoCache::new(
|
||||||
|
"test",
|
||||||
|
1,
|
||||||
|
Duration::from_secs(100),
|
||||||
|
false,
|
||||||
|
))),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -403,6 +416,13 @@ struct TestConnection;
|
|||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
struct TestConnectError {
|
struct TestConnectError {
|
||||||
retryable: bool,
|
retryable: bool,
|
||||||
|
kind: crate::error::ErrorKind,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ReportableError for TestConnectError {
|
||||||
|
fn get_error_kind(&self) -> crate::error::ErrorKind {
|
||||||
|
self.kind
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl std::fmt::Display for TestConnectError {
|
impl std::fmt::Display for TestConnectError {
|
||||||
@@ -436,8 +456,14 @@ impl ConnectMechanism for TestConnectMechanism {
|
|||||||
*counter += 1;
|
*counter += 1;
|
||||||
match action {
|
match action {
|
||||||
ConnectAction::Connect => Ok(TestConnection),
|
ConnectAction::Connect => Ok(TestConnection),
|
||||||
ConnectAction::Retry => Err(TestConnectError { retryable: true }),
|
ConnectAction::Retry => Err(TestConnectError {
|
||||||
ConnectAction::Fail => Err(TestConnectError { retryable: false }),
|
retryable: true,
|
||||||
|
kind: ErrorKind::Compute,
|
||||||
|
}),
|
||||||
|
ConnectAction::Fail => Err(TestConnectError {
|
||||||
|
retryable: false,
|
||||||
|
kind: ErrorKind::Compute,
|
||||||
|
}),
|
||||||
x => panic!("expecting action {:?}, connect is called instead", x),
|
x => panic!("expecting action {:?}, connect is called instead", x),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -451,7 +477,7 @@ impl TestBackend for TestConnectMechanism {
|
|||||||
let action = self.sequence[*counter];
|
let action = self.sequence[*counter];
|
||||||
*counter += 1;
|
*counter += 1;
|
||||||
match action {
|
match action {
|
||||||
ConnectAction::Wake => Ok(helper_create_cached_node_info()),
|
ConnectAction::Wake => Ok(helper_create_cached_node_info(self.cache)),
|
||||||
ConnectAction::WakeFail => {
|
ConnectAction::WakeFail => {
|
||||||
let err = console::errors::ApiError::Console {
|
let err = console::errors::ApiError::Console {
|
||||||
status: http::StatusCode::FORBIDDEN,
|
status: http::StatusCode::FORBIDDEN,
|
||||||
@@ -483,37 +509,41 @@ impl TestBackend for TestConnectMechanism {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn helper_create_cached_node_info() -> CachedNodeInfo {
|
fn helper_create_cached_node_info(cache: &'static NodeInfoCache) -> CachedNodeInfo {
|
||||||
let node = NodeInfo {
|
let node = NodeInfo {
|
||||||
config: compute::ConnCfg::new(),
|
config: compute::ConnCfg::new(),
|
||||||
aux: Default::default(),
|
aux: Default::default(),
|
||||||
allow_self_signed_compute: false,
|
allow_self_signed_compute: false,
|
||||||
};
|
};
|
||||||
CachedNodeInfo::new_uncached(node)
|
let (_, node) = cache.insert("key".into(), node);
|
||||||
|
node
|
||||||
}
|
}
|
||||||
|
|
||||||
fn helper_create_connect_info(
|
fn helper_create_connect_info(
|
||||||
mechanism: &TestConnectMechanism,
|
mechanism: &TestConnectMechanism,
|
||||||
) -> (CachedNodeInfo, auth::BackendType<'static, ComputeUserInfo>) {
|
) -> auth::BackendType<'static, ComputeCredentials, &()> {
|
||||||
let cache = helper_create_cached_node_info();
|
|
||||||
let user_info = auth::BackendType::Console(
|
let user_info = auth::BackendType::Console(
|
||||||
MaybeOwned::Owned(ConsoleBackend::Test(Box::new(mechanism.clone()))),
|
MaybeOwned::Owned(ConsoleBackend::Test(Box::new(mechanism.clone()))),
|
||||||
ComputeUserInfo {
|
ComputeCredentials {
|
||||||
endpoint: "endpoint".into(),
|
info: ComputeUserInfo {
|
||||||
user: "user".into(),
|
endpoint: "endpoint".into(),
|
||||||
options: NeonOptions::parse_options_raw(""),
|
user: "user".into(),
|
||||||
|
options: NeonOptions::parse_options_raw(""),
|
||||||
|
},
|
||||||
|
keys: ComputeCredentialKeys::Password("password".into()),
|
||||||
},
|
},
|
||||||
);
|
);
|
||||||
(cache, user_info)
|
user_info
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn connect_to_compute_success() {
|
async fn connect_to_compute_success() {
|
||||||
|
let _ = env_logger::try_init();
|
||||||
use ConnectAction::*;
|
use ConnectAction::*;
|
||||||
let mut ctx = RequestMonitoring::test();
|
let mut ctx = RequestMonitoring::test();
|
||||||
let mechanism = TestConnectMechanism::new(vec![Connect]);
|
let mechanism = TestConnectMechanism::new(vec![Wake, Connect]);
|
||||||
let (cache, user_info) = helper_create_connect_info(&mechanism);
|
let user_info = helper_create_connect_info(&mechanism);
|
||||||
connect_to_compute(&mut ctx, &mechanism, cache, &user_info)
|
connect_to_compute(&mut ctx, &mechanism, &user_info, false)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
mechanism.verify();
|
mechanism.verify();
|
||||||
@@ -521,11 +551,12 @@ async fn connect_to_compute_success() {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn connect_to_compute_retry() {
|
async fn connect_to_compute_retry() {
|
||||||
|
let _ = env_logger::try_init();
|
||||||
use ConnectAction::*;
|
use ConnectAction::*;
|
||||||
let mut ctx = RequestMonitoring::test();
|
let mut ctx = RequestMonitoring::test();
|
||||||
let mechanism = TestConnectMechanism::new(vec![Retry, Wake, Retry, Connect]);
|
let mechanism = TestConnectMechanism::new(vec![Wake, Retry, Wake, Connect]);
|
||||||
let (cache, user_info) = helper_create_connect_info(&mechanism);
|
let user_info = helper_create_connect_info(&mechanism);
|
||||||
connect_to_compute(&mut ctx, &mechanism, cache, &user_info)
|
connect_to_compute(&mut ctx, &mechanism, &user_info, false)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
mechanism.verify();
|
mechanism.verify();
|
||||||
@@ -534,11 +565,12 @@ async fn connect_to_compute_retry() {
|
|||||||
/// Test that we don't retry if the error is not retryable.
|
/// Test that we don't retry if the error is not retryable.
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn connect_to_compute_non_retry_1() {
|
async fn connect_to_compute_non_retry_1() {
|
||||||
|
let _ = env_logger::try_init();
|
||||||
use ConnectAction::*;
|
use ConnectAction::*;
|
||||||
let mut ctx = RequestMonitoring::test();
|
let mut ctx = RequestMonitoring::test();
|
||||||
let mechanism = TestConnectMechanism::new(vec![Retry, Wake, Retry, Fail]);
|
let mechanism = TestConnectMechanism::new(vec![Wake, Retry, Wake, Fail]);
|
||||||
let (cache, user_info) = helper_create_connect_info(&mechanism);
|
let user_info = helper_create_connect_info(&mechanism);
|
||||||
connect_to_compute(&mut ctx, &mechanism, cache, &user_info)
|
connect_to_compute(&mut ctx, &mechanism, &user_info, false)
|
||||||
.await
|
.await
|
||||||
.unwrap_err();
|
.unwrap_err();
|
||||||
mechanism.verify();
|
mechanism.verify();
|
||||||
@@ -547,11 +579,12 @@ async fn connect_to_compute_non_retry_1() {
|
|||||||
/// Even for non-retryable errors, we should retry at least once.
|
/// Even for non-retryable errors, we should retry at least once.
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn connect_to_compute_non_retry_2() {
|
async fn connect_to_compute_non_retry_2() {
|
||||||
|
let _ = env_logger::try_init();
|
||||||
use ConnectAction::*;
|
use ConnectAction::*;
|
||||||
let mut ctx = RequestMonitoring::test();
|
let mut ctx = RequestMonitoring::test();
|
||||||
let mechanism = TestConnectMechanism::new(vec![Fail, Wake, Retry, Connect]);
|
let mechanism = TestConnectMechanism::new(vec![Wake, Fail, Wake, Connect]);
|
||||||
let (cache, user_info) = helper_create_connect_info(&mechanism);
|
let user_info = helper_create_connect_info(&mechanism);
|
||||||
connect_to_compute(&mut ctx, &mechanism, cache, &user_info)
|
connect_to_compute(&mut ctx, &mechanism, &user_info, false)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
mechanism.verify();
|
mechanism.verify();
|
||||||
@@ -560,15 +593,16 @@ async fn connect_to_compute_non_retry_2() {
|
|||||||
/// Retry for at most `NUM_RETRIES_CONNECT` times.
|
/// Retry for at most `NUM_RETRIES_CONNECT` times.
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn connect_to_compute_non_retry_3() {
|
async fn connect_to_compute_non_retry_3() {
|
||||||
|
let _ = env_logger::try_init();
|
||||||
assert_eq!(NUM_RETRIES_CONNECT, 16);
|
assert_eq!(NUM_RETRIES_CONNECT, 16);
|
||||||
use ConnectAction::*;
|
use ConnectAction::*;
|
||||||
let mut ctx = RequestMonitoring::test();
|
let mut ctx = RequestMonitoring::test();
|
||||||
let mechanism = TestConnectMechanism::new(vec![
|
let mechanism = TestConnectMechanism::new(vec![
|
||||||
Retry, Wake, Retry, Retry, Retry, Retry, Retry, Retry, Retry, Retry, Retry, Retry, Retry,
|
Wake, Retry, Wake, Retry, Retry, Retry, Retry, Retry, Retry, Retry, Retry, Retry, Retry,
|
||||||
Retry, Retry, Retry, Retry, /* the 17th time */ Retry,
|
Retry, Retry, Retry, Retry, Retry, /* the 17th time */ Retry,
|
||||||
]);
|
]);
|
||||||
let (cache, user_info) = helper_create_connect_info(&mechanism);
|
let user_info = helper_create_connect_info(&mechanism);
|
||||||
connect_to_compute(&mut ctx, &mechanism, cache, &user_info)
|
connect_to_compute(&mut ctx, &mechanism, &user_info, false)
|
||||||
.await
|
.await
|
||||||
.unwrap_err();
|
.unwrap_err();
|
||||||
mechanism.verify();
|
mechanism.verify();
|
||||||
@@ -577,11 +611,12 @@ async fn connect_to_compute_non_retry_3() {
|
|||||||
/// Should retry wake compute.
|
/// Should retry wake compute.
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn wake_retry() {
|
async fn wake_retry() {
|
||||||
|
let _ = env_logger::try_init();
|
||||||
use ConnectAction::*;
|
use ConnectAction::*;
|
||||||
let mut ctx = RequestMonitoring::test();
|
let mut ctx = RequestMonitoring::test();
|
||||||
let mechanism = TestConnectMechanism::new(vec![Retry, WakeRetry, Wake, Connect]);
|
let mechanism = TestConnectMechanism::new(vec![WakeRetry, Wake, Connect]);
|
||||||
let (cache, user_info) = helper_create_connect_info(&mechanism);
|
let user_info = helper_create_connect_info(&mechanism);
|
||||||
connect_to_compute(&mut ctx, &mechanism, cache, &user_info)
|
connect_to_compute(&mut ctx, &mechanism, &user_info, false)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
mechanism.verify();
|
mechanism.verify();
|
||||||
@@ -590,11 +625,12 @@ async fn wake_retry() {
|
|||||||
/// Wake failed with a non-retryable error.
|
/// Wake failed with a non-retryable error.
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn wake_non_retry() {
|
async fn wake_non_retry() {
|
||||||
|
let _ = env_logger::try_init();
|
||||||
use ConnectAction::*;
|
use ConnectAction::*;
|
||||||
let mut ctx = RequestMonitoring::test();
|
let mut ctx = RequestMonitoring::test();
|
||||||
let mechanism = TestConnectMechanism::new(vec![Retry, WakeFail]);
|
let mechanism = TestConnectMechanism::new(vec![WakeRetry, WakeFail]);
|
||||||
let (cache, user_info) = helper_create_connect_info(&mechanism);
|
let user_info = helper_create_connect_info(&mechanism);
|
||||||
connect_to_compute(&mut ctx, &mechanism, cache, &user_info)
|
connect_to_compute(&mut ctx, &mechanism, &user_info, false)
|
||||||
.await
|
.await
|
||||||
.unwrap_err();
|
.unwrap_err();
|
||||||
mechanism.verify();
|
mechanism.verify();
|
||||||
|
|||||||
@@ -1,9 +1,4 @@
|
|||||||
use crate::auth::backend::ComputeUserInfo;
|
use crate::console::{errors::WakeComputeError, provider::CachedNodeInfo};
|
||||||
use crate::console::{
|
|
||||||
errors::WakeComputeError,
|
|
||||||
provider::{CachedNodeInfo, ConsoleBackend},
|
|
||||||
Api,
|
|
||||||
};
|
|
||||||
use crate::context::RequestMonitoring;
|
use crate::context::RequestMonitoring;
|
||||||
use crate::metrics::{bool_to_str, NUM_WAKEUP_FAILURES};
|
use crate::metrics::{bool_to_str, NUM_WAKEUP_FAILURES};
|
||||||
use crate::proxy::retry::retry_after;
|
use crate::proxy::retry::retry_after;
|
||||||
@@ -11,17 +6,16 @@ use hyper::StatusCode;
|
|||||||
use std::ops::ControlFlow;
|
use std::ops::ControlFlow;
|
||||||
use tracing::{error, warn};
|
use tracing::{error, warn};
|
||||||
|
|
||||||
|
use super::connect_compute::ComputeConnectBackend;
|
||||||
use super::retry::ShouldRetry;
|
use super::retry::ShouldRetry;
|
||||||
|
|
||||||
/// wake a compute (or retrieve an existing compute session from cache)
|
pub async fn wake_compute<B: ComputeConnectBackend>(
|
||||||
pub async fn wake_compute(
|
|
||||||
num_retries: &mut u32,
|
num_retries: &mut u32,
|
||||||
ctx: &mut RequestMonitoring,
|
ctx: &mut RequestMonitoring,
|
||||||
api: &ConsoleBackend,
|
api: &B,
|
||||||
info: &ComputeUserInfo,
|
|
||||||
) -> Result<CachedNodeInfo, WakeComputeError> {
|
) -> Result<CachedNodeInfo, WakeComputeError> {
|
||||||
loop {
|
loop {
|
||||||
let wake_res = api.wake_compute(ctx, info).await;
|
let wake_res = api.wake_compute(ctx).await;
|
||||||
match handle_try_wake(wake_res, *num_retries) {
|
match handle_try_wake(wake_res, *num_retries) {
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
error!(error = ?e, num_retries, retriable = false, "couldn't wake compute node");
|
error!(error = ?e, num_retries, retriable = false, "couldn't wake compute node");
|
||||||
|
|||||||
@@ -4,4 +4,4 @@ mod limiter;
|
|||||||
pub use aimd::Aimd;
|
pub use aimd::Aimd;
|
||||||
pub use limit_algorithm::{AimdConfig, Fixed, RateLimitAlgorithm, RateLimiterConfig};
|
pub use limit_algorithm::{AimdConfig, Fixed, RateLimitAlgorithm, RateLimiterConfig};
|
||||||
pub use limiter::Limiter;
|
pub use limiter::Limiter;
|
||||||
pub use limiter::{EndpointRateLimiter, RateBucketInfo};
|
pub use limiter::{EndpointRateLimiter, RateBucketInfo, RedisRateLimiter};
|
||||||
|
|||||||
@@ -22,6 +22,44 @@ use super::{
|
|||||||
RateLimiterConfig,
|
RateLimiterConfig,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
pub struct RedisRateLimiter {
|
||||||
|
data: Vec<RateBucket>,
|
||||||
|
info: &'static [RateBucketInfo],
|
||||||
|
}
|
||||||
|
|
||||||
|
impl RedisRateLimiter {
|
||||||
|
pub fn new(info: &'static [RateBucketInfo]) -> Self {
|
||||||
|
Self {
|
||||||
|
data: vec![
|
||||||
|
RateBucket {
|
||||||
|
start: Instant::now(),
|
||||||
|
count: 0,
|
||||||
|
};
|
||||||
|
info.len()
|
||||||
|
],
|
||||||
|
info,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Check that number of connections is below `max_rps` rps.
|
||||||
|
pub fn check(&mut self) -> bool {
|
||||||
|
let now = Instant::now();
|
||||||
|
|
||||||
|
let should_allow_request = self
|
||||||
|
.data
|
||||||
|
.iter_mut()
|
||||||
|
.zip(self.info)
|
||||||
|
.all(|(bucket, info)| bucket.should_allow_request(info, now));
|
||||||
|
|
||||||
|
if should_allow_request {
|
||||||
|
// only increment the bucket counts if the request will actually be accepted
|
||||||
|
self.data.iter_mut().for_each(RateBucket::inc);
|
||||||
|
}
|
||||||
|
|
||||||
|
should_allow_request
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Simple per-endpoint rate limiter.
|
// Simple per-endpoint rate limiter.
|
||||||
//
|
//
|
||||||
// Check that number of connections to the endpoint is below `max_rps` rps.
|
// Check that number of connections to the endpoint is below `max_rps` rps.
|
||||||
|
|||||||
@@ -1 +1,2 @@
|
|||||||
pub mod notifications;
|
pub mod notifications;
|
||||||
|
pub mod publisher;
|
||||||
|
|||||||
@@ -1,38 +1,44 @@
|
|||||||
use std::{convert::Infallible, sync::Arc};
|
use std::{convert::Infallible, sync::Arc};
|
||||||
|
|
||||||
use futures::StreamExt;
|
use futures::StreamExt;
|
||||||
|
use pq_proto::CancelKeyData;
|
||||||
use redis::aio::PubSub;
|
use redis::aio::PubSub;
|
||||||
use serde::Deserialize;
|
use serde::{Deserialize, Serialize};
|
||||||
|
use uuid::Uuid;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
cache::project_info::ProjectInfoCache,
|
cache::project_info::ProjectInfoCache,
|
||||||
|
cancellation::{CancelMap, CancellationHandler, NotificationsCancellationHandler},
|
||||||
intern::{ProjectIdInt, RoleNameInt},
|
intern::{ProjectIdInt, RoleNameInt},
|
||||||
};
|
};
|
||||||
|
|
||||||
const CHANNEL_NAME: &str = "neondb-proxy-ws-updates";
|
const CPLANE_CHANNEL_NAME: &str = "neondb-proxy-ws-updates";
|
||||||
|
pub(crate) const PROXY_CHANNEL_NAME: &str = "neondb-proxy-to-proxy-updates";
|
||||||
const RECONNECT_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(20);
|
const RECONNECT_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(20);
|
||||||
const INVALIDATION_LAG: std::time::Duration = std::time::Duration::from_secs(20);
|
const INVALIDATION_LAG: std::time::Duration = std::time::Duration::from_secs(20);
|
||||||
|
|
||||||
struct ConsoleRedisClient {
|
struct RedisConsumerClient {
|
||||||
client: redis::Client,
|
client: redis::Client,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ConsoleRedisClient {
|
impl RedisConsumerClient {
|
||||||
pub fn new(url: &str) -> anyhow::Result<Self> {
|
pub fn new(url: &str) -> anyhow::Result<Self> {
|
||||||
let client = redis::Client::open(url)?;
|
let client = redis::Client::open(url)?;
|
||||||
Ok(Self { client })
|
Ok(Self { client })
|
||||||
}
|
}
|
||||||
async fn try_connect(&self) -> anyhow::Result<PubSub> {
|
async fn try_connect(&self) -> anyhow::Result<PubSub> {
|
||||||
let mut conn = self.client.get_async_connection().await?.into_pubsub();
|
let mut conn = self.client.get_async_connection().await?.into_pubsub();
|
||||||
tracing::info!("subscribing to a channel `{CHANNEL_NAME}`");
|
tracing::info!("subscribing to a channel `{CPLANE_CHANNEL_NAME}`");
|
||||||
conn.subscribe(CHANNEL_NAME).await?;
|
conn.subscribe(CPLANE_CHANNEL_NAME).await?;
|
||||||
|
tracing::info!("subscribing to a channel `{PROXY_CHANNEL_NAME}`");
|
||||||
|
conn.subscribe(PROXY_CHANNEL_NAME).await?;
|
||||||
Ok(conn)
|
Ok(conn)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Debug, Deserialize, Eq, PartialEq)]
|
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
|
||||||
#[serde(tag = "topic", content = "data")]
|
#[serde(tag = "topic", content = "data")]
|
||||||
enum Notification {
|
pub(crate) enum Notification {
|
||||||
#[serde(
|
#[serde(
|
||||||
rename = "/allowed_ips_updated",
|
rename = "/allowed_ips_updated",
|
||||||
deserialize_with = "deserialize_json_string"
|
deserialize_with = "deserialize_json_string"
|
||||||
@@ -45,16 +51,25 @@ enum Notification {
|
|||||||
deserialize_with = "deserialize_json_string"
|
deserialize_with = "deserialize_json_string"
|
||||||
)]
|
)]
|
||||||
PasswordUpdate { password_update: PasswordUpdate },
|
PasswordUpdate { password_update: PasswordUpdate },
|
||||||
|
#[serde(rename = "/cancel_session")]
|
||||||
|
Cancel(CancelSession),
|
||||||
}
|
}
|
||||||
#[derive(Clone, Debug, Deserialize, Eq, PartialEq)]
|
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
|
||||||
struct AllowedIpsUpdate {
|
pub(crate) struct AllowedIpsUpdate {
|
||||||
project_id: ProjectIdInt,
|
project_id: ProjectIdInt,
|
||||||
}
|
}
|
||||||
#[derive(Clone, Debug, Deserialize, Eq, PartialEq)]
|
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
|
||||||
struct PasswordUpdate {
|
pub(crate) struct PasswordUpdate {
|
||||||
project_id: ProjectIdInt,
|
project_id: ProjectIdInt,
|
||||||
role_name: RoleNameInt,
|
role_name: RoleNameInt,
|
||||||
}
|
}
|
||||||
|
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
|
||||||
|
pub(crate) struct CancelSession {
|
||||||
|
pub region_id: Option<String>,
|
||||||
|
pub cancel_key_data: CancelKeyData,
|
||||||
|
pub session_id: Uuid,
|
||||||
|
}
|
||||||
|
|
||||||
fn deserialize_json_string<'de, D, T>(deserializer: D) -> Result<T, D::Error>
|
fn deserialize_json_string<'de, D, T>(deserializer: D) -> Result<T, D::Error>
|
||||||
where
|
where
|
||||||
T: for<'de2> serde::Deserialize<'de2>,
|
T: for<'de2> serde::Deserialize<'de2>,
|
||||||
@@ -64,6 +79,88 @@ where
|
|||||||
serde_json::from_str(&s).map_err(<D::Error as serde::de::Error>::custom)
|
serde_json::from_str(&s).map_err(<D::Error as serde::de::Error>::custom)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct MessageHandler<
|
||||||
|
C: ProjectInfoCache + Send + Sync + 'static,
|
||||||
|
H: NotificationsCancellationHandler + Send + Sync + 'static,
|
||||||
|
> {
|
||||||
|
cache: Arc<C>,
|
||||||
|
cancellation_handler: Arc<H>,
|
||||||
|
region_id: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<
|
||||||
|
C: ProjectInfoCache + Send + Sync + 'static,
|
||||||
|
H: NotificationsCancellationHandler + Send + Sync + 'static,
|
||||||
|
> MessageHandler<C, H>
|
||||||
|
{
|
||||||
|
pub fn new(cache: Arc<C>, cancellation_handler: Arc<H>, region_id: String) -> Self {
|
||||||
|
Self {
|
||||||
|
cache,
|
||||||
|
cancellation_handler,
|
||||||
|
region_id,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
pub fn disable_ttl(&self) {
|
||||||
|
self.cache.disable_ttl();
|
||||||
|
}
|
||||||
|
pub fn enable_ttl(&self) {
|
||||||
|
self.cache.enable_ttl();
|
||||||
|
}
|
||||||
|
#[tracing::instrument(skip(self, msg), fields(session_id = tracing::field::Empty))]
|
||||||
|
async fn handle_message(&self, msg: redis::Msg) -> anyhow::Result<()> {
|
||||||
|
use Notification::*;
|
||||||
|
let payload: String = msg.get_payload()?;
|
||||||
|
tracing::debug!(?payload, "received a message payload");
|
||||||
|
|
||||||
|
let msg: Notification = match serde_json::from_str(&payload) {
|
||||||
|
Ok(msg) => msg,
|
||||||
|
Err(e) => {
|
||||||
|
tracing::error!("broken message: {e}");
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
tracing::debug!(?msg, "received a message");
|
||||||
|
match msg {
|
||||||
|
Cancel(cancel_session) => {
|
||||||
|
tracing::Span::current().record(
|
||||||
|
"session_id",
|
||||||
|
&tracing::field::display(cancel_session.session_id),
|
||||||
|
);
|
||||||
|
if let Some(cancel_region) = cancel_session.region_id {
|
||||||
|
// If the message is not for this region, ignore it.
|
||||||
|
if cancel_region != self.region_id {
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// This instance of cancellation_handler doesn't have a RedisPublisherClient so it can't publish the message.
|
||||||
|
match self
|
||||||
|
.cancellation_handler
|
||||||
|
.cancel_session_no_publish(cancel_session.cancel_key_data)
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
Ok(()) => {}
|
||||||
|
Err(e) => {
|
||||||
|
tracing::error!("failed to cancel session: {e}");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
invalidate_cache(self.cache.clone(), msg.clone());
|
||||||
|
// It might happen that the invalid entry is on the way to be cached.
|
||||||
|
// To make sure that the entry is invalidated, let's repeat the invalidation in INVALIDATION_LAG seconds.
|
||||||
|
// TODO: include the version (or the timestamp) in the message and invalidate only if the entry is cached before the message.
|
||||||
|
let cache = self.cache.clone();
|
||||||
|
tokio::spawn(async move {
|
||||||
|
tokio::time::sleep(INVALIDATION_LAG).await;
|
||||||
|
invalidate_cache(cache, msg);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn invalidate_cache<C: ProjectInfoCache>(cache: Arc<C>, msg: Notification) {
|
fn invalidate_cache<C: ProjectInfoCache>(cache: Arc<C>, msg: Notification) {
|
||||||
use Notification::*;
|
use Notification::*;
|
||||||
match msg {
|
match msg {
|
||||||
@@ -74,50 +171,33 @@ fn invalidate_cache<C: ProjectInfoCache>(cache: Arc<C>, msg: Notification) {
|
|||||||
password_update.project_id,
|
password_update.project_id,
|
||||||
password_update.role_name,
|
password_update.role_name,
|
||||||
),
|
),
|
||||||
|
Cancel(_) => unreachable!("cancel message should be handled separately"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tracing::instrument(skip(cache))]
|
|
||||||
fn handle_message<C>(msg: redis::Msg, cache: Arc<C>) -> anyhow::Result<()>
|
|
||||||
where
|
|
||||||
C: ProjectInfoCache + Send + Sync + 'static,
|
|
||||||
{
|
|
||||||
let payload: String = msg.get_payload()?;
|
|
||||||
tracing::debug!(?payload, "received a message payload");
|
|
||||||
|
|
||||||
let msg: Notification = match serde_json::from_str(&payload) {
|
|
||||||
Ok(msg) => msg,
|
|
||||||
Err(e) => {
|
|
||||||
tracing::error!("broken message: {e}");
|
|
||||||
return Ok(());
|
|
||||||
}
|
|
||||||
};
|
|
||||||
tracing::debug!(?msg, "received a message");
|
|
||||||
invalidate_cache(cache.clone(), msg.clone());
|
|
||||||
// It might happen that the invalid entry is on the way to be cached.
|
|
||||||
// To make sure that the entry is invalidated, let's repeat the invalidation in INVALIDATION_LAG seconds.
|
|
||||||
// TODO: include the version (or the timestamp) in the message and invalidate only if the entry is cached before the message.
|
|
||||||
tokio::spawn(async move {
|
|
||||||
tokio::time::sleep(INVALIDATION_LAG).await;
|
|
||||||
invalidate_cache(cache, msg.clone());
|
|
||||||
});
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Handle console's invalidation messages.
|
/// Handle console's invalidation messages.
|
||||||
#[tracing::instrument(name = "console_notifications", skip_all)]
|
#[tracing::instrument(name = "console_notifications", skip_all)]
|
||||||
pub async fn task_main<C>(url: String, cache: Arc<C>) -> anyhow::Result<Infallible>
|
pub async fn task_main<C>(
|
||||||
|
url: String,
|
||||||
|
cache: Arc<C>,
|
||||||
|
cancel_map: CancelMap,
|
||||||
|
region_id: String,
|
||||||
|
) -> anyhow::Result<Infallible>
|
||||||
where
|
where
|
||||||
C: ProjectInfoCache + Send + Sync + 'static,
|
C: ProjectInfoCache + Send + Sync + 'static,
|
||||||
{
|
{
|
||||||
cache.enable_ttl();
|
cache.enable_ttl();
|
||||||
|
let handler = MessageHandler::new(
|
||||||
|
cache,
|
||||||
|
Arc::new(CancellationHandler::new(cancel_map, None)),
|
||||||
|
region_id,
|
||||||
|
);
|
||||||
|
|
||||||
loop {
|
loop {
|
||||||
let redis = ConsoleRedisClient::new(&url)?;
|
let redis = RedisConsumerClient::new(&url)?;
|
||||||
let conn = match redis.try_connect().await {
|
let conn = match redis.try_connect().await {
|
||||||
Ok(conn) => {
|
Ok(conn) => {
|
||||||
cache.disable_ttl();
|
handler.disable_ttl();
|
||||||
conn
|
conn
|
||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
@@ -130,7 +210,7 @@ where
|
|||||||
};
|
};
|
||||||
let mut stream = conn.into_on_message();
|
let mut stream = conn.into_on_message();
|
||||||
while let Some(msg) = stream.next().await {
|
while let Some(msg) = stream.next().await {
|
||||||
match handle_message(msg, cache.clone()) {
|
match handler.handle_message(msg).await {
|
||||||
Ok(()) => {}
|
Ok(()) => {}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
tracing::error!("failed to handle message: {e}, will try to reconnect");
|
tracing::error!("failed to handle message: {e}, will try to reconnect");
|
||||||
@@ -138,7 +218,7 @@ where
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
cache.enable_ttl();
|
handler.enable_ttl();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -198,6 +278,33 @@ mod tests {
|
|||||||
}
|
}
|
||||||
);
|
);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
#[test]
|
||||||
|
fn parse_cancel_session() -> anyhow::Result<()> {
|
||||||
|
let cancel_key_data = CancelKeyData {
|
||||||
|
backend_pid: 42,
|
||||||
|
cancel_key: 41,
|
||||||
|
};
|
||||||
|
let uuid = uuid::Uuid::new_v4();
|
||||||
|
let msg = Notification::Cancel(CancelSession {
|
||||||
|
cancel_key_data,
|
||||||
|
region_id: None,
|
||||||
|
session_id: uuid,
|
||||||
|
});
|
||||||
|
let text = serde_json::to_string(&msg)?;
|
||||||
|
let result: Notification = serde_json::from_str(&text)?;
|
||||||
|
assert_eq!(msg, result);
|
||||||
|
|
||||||
|
let msg = Notification::Cancel(CancelSession {
|
||||||
|
cancel_key_data,
|
||||||
|
region_id: Some("region".to_string()),
|
||||||
|
session_id: uuid,
|
||||||
|
});
|
||||||
|
let text = serde_json::to_string(&msg)?;
|
||||||
|
let result: Notification = serde_json::from_str(&text)?;
|
||||||
|
assert_eq!(msg, result,);
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
80
proxy/src/redis/publisher.rs
Normal file
80
proxy/src/redis/publisher.rs
Normal file
@@ -0,0 +1,80 @@
|
|||||||
|
use pq_proto::CancelKeyData;
|
||||||
|
use redis::AsyncCommands;
|
||||||
|
use uuid::Uuid;
|
||||||
|
|
||||||
|
use crate::rate_limiter::{RateBucketInfo, RedisRateLimiter};
|
||||||
|
|
||||||
|
use super::notifications::{CancelSession, Notification, PROXY_CHANNEL_NAME};
|
||||||
|
|
||||||
|
pub struct RedisPublisherClient {
|
||||||
|
client: redis::Client,
|
||||||
|
publisher: Option<redis::aio::Connection>,
|
||||||
|
region_id: String,
|
||||||
|
limiter: RedisRateLimiter,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl RedisPublisherClient {
|
||||||
|
pub fn new(
|
||||||
|
url: &str,
|
||||||
|
region_id: String,
|
||||||
|
info: &'static [RateBucketInfo],
|
||||||
|
) -> anyhow::Result<Self> {
|
||||||
|
let client = redis::Client::open(url)?;
|
||||||
|
Ok(Self {
|
||||||
|
client,
|
||||||
|
publisher: None,
|
||||||
|
region_id,
|
||||||
|
limiter: RedisRateLimiter::new(info),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
pub async fn try_publish(
|
||||||
|
&mut self,
|
||||||
|
cancel_key_data: CancelKeyData,
|
||||||
|
session_id: Uuid,
|
||||||
|
) -> anyhow::Result<()> {
|
||||||
|
if !self.limiter.check() {
|
||||||
|
tracing::info!("Rate limit exceeded. Skipping cancellation message");
|
||||||
|
return Err(anyhow::anyhow!("Rate limit exceeded"));
|
||||||
|
}
|
||||||
|
match self.publish(cancel_key_data, session_id).await {
|
||||||
|
Ok(()) => return Ok(()),
|
||||||
|
Err(e) => {
|
||||||
|
tracing::error!("failed to publish a message: {e}");
|
||||||
|
self.publisher = None;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
tracing::info!("Publisher is disconnected. Reconnectiong...");
|
||||||
|
self.try_connect().await?;
|
||||||
|
self.publish(cancel_key_data, session_id).await
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn publish(
|
||||||
|
&mut self,
|
||||||
|
cancel_key_data: CancelKeyData,
|
||||||
|
session_id: Uuid,
|
||||||
|
) -> anyhow::Result<()> {
|
||||||
|
let conn = self
|
||||||
|
.publisher
|
||||||
|
.as_mut()
|
||||||
|
.ok_or_else(|| anyhow::anyhow!("not connected"))?;
|
||||||
|
let payload = serde_json::to_string(&Notification::Cancel(CancelSession {
|
||||||
|
region_id: Some(self.region_id.clone()),
|
||||||
|
cancel_key_data,
|
||||||
|
session_id,
|
||||||
|
}))?;
|
||||||
|
conn.publish(PROXY_CHANNEL_NAME, payload).await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
pub async fn try_connect(&mut self) -> anyhow::Result<()> {
|
||||||
|
match self.client.get_async_connection().await {
|
||||||
|
Ok(conn) => {
|
||||||
|
self.publisher = Some(conn);
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
tracing::error!("failed to connect to redis: {e}");
|
||||||
|
return Err(e.into());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -24,7 +24,7 @@ use crate::metrics::NUM_CLIENT_CONNECTION_GAUGE;
|
|||||||
use crate::protocol2::{ProxyProtocolAccept, WithClientIp};
|
use crate::protocol2::{ProxyProtocolAccept, WithClientIp};
|
||||||
use crate::rate_limiter::EndpointRateLimiter;
|
use crate::rate_limiter::EndpointRateLimiter;
|
||||||
use crate::serverless::backend::PoolingBackend;
|
use crate::serverless::backend::PoolingBackend;
|
||||||
use crate::{cancellation::CancelMap, config::ProxyConfig};
|
use crate::{cancellation::CancellationHandler, config::ProxyConfig};
|
||||||
use futures::StreamExt;
|
use futures::StreamExt;
|
||||||
use hyper::{
|
use hyper::{
|
||||||
server::{
|
server::{
|
||||||
@@ -50,6 +50,7 @@ pub async fn task_main(
|
|||||||
ws_listener: TcpListener,
|
ws_listener: TcpListener,
|
||||||
cancellation_token: CancellationToken,
|
cancellation_token: CancellationToken,
|
||||||
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||||
|
cancellation_handler: Arc<CancellationHandler>,
|
||||||
) -> anyhow::Result<()> {
|
) -> anyhow::Result<()> {
|
||||||
scopeguard::defer! {
|
scopeguard::defer! {
|
||||||
info!("websocket server has shut down");
|
info!("websocket server has shut down");
|
||||||
@@ -115,7 +116,7 @@ pub async fn task_main(
|
|||||||
let backend = backend.clone();
|
let backend = backend.clone();
|
||||||
let ws_connections = ws_connections.clone();
|
let ws_connections = ws_connections.clone();
|
||||||
let endpoint_rate_limiter = endpoint_rate_limiter.clone();
|
let endpoint_rate_limiter = endpoint_rate_limiter.clone();
|
||||||
|
let cancellation_handler = cancellation_handler.clone();
|
||||||
async move {
|
async move {
|
||||||
let peer_addr = match client_addr {
|
let peer_addr = match client_addr {
|
||||||
Some(addr) => addr,
|
Some(addr) => addr,
|
||||||
@@ -127,9 +128,9 @@ pub async fn task_main(
|
|||||||
let backend = backend.clone();
|
let backend = backend.clone();
|
||||||
let ws_connections = ws_connections.clone();
|
let ws_connections = ws_connections.clone();
|
||||||
let endpoint_rate_limiter = endpoint_rate_limiter.clone();
|
let endpoint_rate_limiter = endpoint_rate_limiter.clone();
|
||||||
|
let cancellation_handler = cancellation_handler.clone();
|
||||||
|
|
||||||
async move {
|
async move {
|
||||||
let cancel_map = Arc::new(CancelMap::default());
|
|
||||||
let session_id = uuid::Uuid::new_v4();
|
let session_id = uuid::Uuid::new_v4();
|
||||||
|
|
||||||
request_handler(
|
request_handler(
|
||||||
@@ -137,7 +138,7 @@ pub async fn task_main(
|
|||||||
config,
|
config,
|
||||||
backend,
|
backend,
|
||||||
ws_connections,
|
ws_connections,
|
||||||
cancel_map,
|
cancellation_handler,
|
||||||
session_id,
|
session_id,
|
||||||
peer_addr.ip(),
|
peer_addr.ip(),
|
||||||
endpoint_rate_limiter,
|
endpoint_rate_limiter,
|
||||||
@@ -205,7 +206,7 @@ async fn request_handler(
|
|||||||
config: &'static ProxyConfig,
|
config: &'static ProxyConfig,
|
||||||
backend: Arc<PoolingBackend>,
|
backend: Arc<PoolingBackend>,
|
||||||
ws_connections: TaskTracker,
|
ws_connections: TaskTracker,
|
||||||
cancel_map: Arc<CancelMap>,
|
cancellation_handler: Arc<CancellationHandler>,
|
||||||
session_id: uuid::Uuid,
|
session_id: uuid::Uuid,
|
||||||
peer_addr: IpAddr,
|
peer_addr: IpAddr,
|
||||||
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||||
@@ -232,7 +233,7 @@ async fn request_handler(
|
|||||||
config,
|
config,
|
||||||
ctx,
|
ctx,
|
||||||
websocket,
|
websocket,
|
||||||
cancel_map,
|
cancellation_handler,
|
||||||
host,
|
host,
|
||||||
endpoint_rate_limiter,
|
endpoint_rate_limiter,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,10 +1,10 @@
|
|||||||
use std::{sync::Arc, time::Duration};
|
use std::{sync::Arc, time::Duration};
|
||||||
|
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use tracing::info;
|
use tracing::{field::display, info};
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
auth::{backend::ComputeCredentialKeys, check_peer_addr_is_in_list, AuthError},
|
auth::{backend::ComputeCredentials, check_peer_addr_is_in_list, AuthError},
|
||||||
compute,
|
compute,
|
||||||
config::ProxyConfig,
|
config::ProxyConfig,
|
||||||
console::{
|
console::{
|
||||||
@@ -15,7 +15,7 @@ use crate::{
|
|||||||
proxy::connect_compute::ConnectMechanism,
|
proxy::connect_compute::ConnectMechanism,
|
||||||
};
|
};
|
||||||
|
|
||||||
use super::conn_pool::{poll_client, Client, ConnInfo, GlobalConnPool, APP_NAME};
|
use super::conn_pool::{poll_client, Client, ConnInfo, GlobalConnPool};
|
||||||
|
|
||||||
pub struct PoolingBackend {
|
pub struct PoolingBackend {
|
||||||
pub pool: Arc<GlobalConnPool<tokio_postgres::Client>>,
|
pub pool: Arc<GlobalConnPool<tokio_postgres::Client>>,
|
||||||
@@ -27,7 +27,7 @@ impl PoolingBackend {
|
|||||||
&self,
|
&self,
|
||||||
ctx: &mut RequestMonitoring,
|
ctx: &mut RequestMonitoring,
|
||||||
conn_info: &ConnInfo,
|
conn_info: &ConnInfo,
|
||||||
) -> Result<ComputeCredentialKeys, AuthError> {
|
) -> Result<ComputeCredentials, AuthError> {
|
||||||
let user_info = conn_info.user_info.clone();
|
let user_info = conn_info.user_info.clone();
|
||||||
let backend = self.config.auth_backend.as_ref().map(|_| user_info.clone());
|
let backend = self.config.auth_backend.as_ref().map(|_| user_info.clone());
|
||||||
let (allowed_ips, maybe_secret) = backend.get_allowed_ips_and_secret(ctx).await?;
|
let (allowed_ips, maybe_secret) = backend.get_allowed_ips_and_secret(ctx).await?;
|
||||||
@@ -49,13 +49,17 @@ impl PoolingBackend {
|
|||||||
};
|
};
|
||||||
let auth_outcome =
|
let auth_outcome =
|
||||||
crate::auth::validate_password_and_exchange(&conn_info.password, secret)?;
|
crate::auth::validate_password_and_exchange(&conn_info.password, secret)?;
|
||||||
match auth_outcome {
|
let res = match auth_outcome {
|
||||||
crate::sasl::Outcome::Success(key) => Ok(key),
|
crate::sasl::Outcome::Success(key) => Ok(key),
|
||||||
crate::sasl::Outcome::Failure(reason) => {
|
crate::sasl::Outcome::Failure(reason) => {
|
||||||
info!("auth backend failed with an error: {reason}");
|
info!("auth backend failed with an error: {reason}");
|
||||||
Err(AuthError::auth_failed(&*conn_info.user_info.user))
|
Err(AuthError::auth_failed(&*conn_info.user_info.user))
|
||||||
}
|
}
|
||||||
}
|
};
|
||||||
|
res.map(|key| ComputeCredentials {
|
||||||
|
info: user_info,
|
||||||
|
keys: key,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// Wake up the destination if needed. Code here is a bit involved because
|
// Wake up the destination if needed. Code here is a bit involved because
|
||||||
@@ -66,7 +70,7 @@ impl PoolingBackend {
|
|||||||
&self,
|
&self,
|
||||||
ctx: &mut RequestMonitoring,
|
ctx: &mut RequestMonitoring,
|
||||||
conn_info: ConnInfo,
|
conn_info: ConnInfo,
|
||||||
keys: ComputeCredentialKeys,
|
keys: ComputeCredentials,
|
||||||
force_new: bool,
|
force_new: bool,
|
||||||
) -> Result<Client<tokio_postgres::Client>, HttpConnError> {
|
) -> Result<Client<tokio_postgres::Client>, HttpConnError> {
|
||||||
let maybe_client = if !force_new {
|
let maybe_client = if !force_new {
|
||||||
@@ -81,27 +85,9 @@ impl PoolingBackend {
|
|||||||
return Ok(client);
|
return Ok(client);
|
||||||
}
|
}
|
||||||
let conn_id = uuid::Uuid::new_v4();
|
let conn_id = uuid::Uuid::new_v4();
|
||||||
|
tracing::Span::current().record("conn_id", display(conn_id));
|
||||||
info!(%conn_id, "pool: opening a new connection '{conn_info}'");
|
info!(%conn_id, "pool: opening a new connection '{conn_info}'");
|
||||||
ctx.set_application(Some(APP_NAME));
|
let backend = self.config.auth_backend.as_ref().map(|_| keys);
|
||||||
let backend = self
|
|
||||||
.config
|
|
||||||
.auth_backend
|
|
||||||
.as_ref()
|
|
||||||
.map(|_| conn_info.user_info.clone());
|
|
||||||
|
|
||||||
let mut node_info = backend
|
|
||||||
.wake_compute(ctx)
|
|
||||||
.await?
|
|
||||||
.ok_or(HttpConnError::NoComputeInfo)?;
|
|
||||||
|
|
||||||
match keys {
|
|
||||||
#[cfg(any(test, feature = "testing"))]
|
|
||||||
ComputeCredentialKeys::Password(password) => node_info.config.password(password),
|
|
||||||
ComputeCredentialKeys::AuthKeys(auth_keys) => node_info.config.auth_keys(auth_keys),
|
|
||||||
};
|
|
||||||
|
|
||||||
ctx.set_project(node_info.aux.clone());
|
|
||||||
|
|
||||||
crate::proxy::connect_compute::connect_to_compute(
|
crate::proxy::connect_compute::connect_to_compute(
|
||||||
ctx,
|
ctx,
|
||||||
&TokioMechanism {
|
&TokioMechanism {
|
||||||
@@ -109,8 +95,8 @@ impl PoolingBackend {
|
|||||||
conn_info,
|
conn_info,
|
||||||
pool: self.pool.clone(),
|
pool: self.pool.clone(),
|
||||||
},
|
},
|
||||||
node_info,
|
|
||||||
&backend,
|
&backend,
|
||||||
|
false, // do not allow self signed compute for http flow
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
}
|
}
|
||||||
@@ -129,8 +115,6 @@ pub enum HttpConnError {
|
|||||||
AuthError(#[from] AuthError),
|
AuthError(#[from] AuthError),
|
||||||
#[error("wake_compute returned error")]
|
#[error("wake_compute returned error")]
|
||||||
WakeCompute(#[from] WakeComputeError),
|
WakeCompute(#[from] WakeComputeError),
|
||||||
#[error("wake_compute returned nothing")]
|
|
||||||
NoComputeInfo,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
struct TokioMechanism {
|
struct TokioMechanism {
|
||||||
|
|||||||
@@ -4,7 +4,6 @@ use metrics::IntCounterPairGuard;
|
|||||||
use parking_lot::RwLock;
|
use parking_lot::RwLock;
|
||||||
use rand::Rng;
|
use rand::Rng;
|
||||||
use smallvec::SmallVec;
|
use smallvec::SmallVec;
|
||||||
use smol_str::SmolStr;
|
|
||||||
use std::{collections::HashMap, pin::pin, sync::Arc, sync::Weak, time::Duration};
|
use std::{collections::HashMap, pin::pin, sync::Arc, sync::Weak, time::Duration};
|
||||||
use std::{
|
use std::{
|
||||||
fmt,
|
fmt,
|
||||||
@@ -31,8 +30,6 @@ use tracing::{info, info_span, Instrument};
|
|||||||
|
|
||||||
use super::backend::HttpConnError;
|
use super::backend::HttpConnError;
|
||||||
|
|
||||||
pub const APP_NAME: SmolStr = SmolStr::new_inline("/sql_over_http");
|
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct ConnInfo {
|
pub struct ConnInfo {
|
||||||
pub user_info: ComputeUserInfo,
|
pub user_info: ComputeUserInfo,
|
||||||
@@ -379,12 +376,13 @@ impl<C: ClientInnerExt> GlobalConnPool<C> {
|
|||||||
info!("pool: cached connection '{conn_info}' is closed, opening a new one");
|
info!("pool: cached connection '{conn_info}' is closed, opening a new one");
|
||||||
return Ok(None);
|
return Ok(None);
|
||||||
} else {
|
} else {
|
||||||
info!("pool: reusing connection '{conn_info}'");
|
tracing::Span::current().record("conn_id", tracing::field::display(client.conn_id));
|
||||||
client.session.send(ctx.session_id)?;
|
|
||||||
tracing::Span::current().record(
|
tracing::Span::current().record(
|
||||||
"pid",
|
"pid",
|
||||||
&tracing::field::display(client.inner.get_process_id()),
|
&tracing::field::display(client.inner.get_process_id()),
|
||||||
);
|
);
|
||||||
|
info!("pool: reusing connection '{conn_info}'");
|
||||||
|
client.session.send(ctx.session_id)?;
|
||||||
ctx.latency_timer.pool_hit();
|
ctx.latency_timer.pool_hit();
|
||||||
ctx.latency_timer.success();
|
ctx.latency_timer.success();
|
||||||
return Ok(Some(Client::new(client, conn_info.clone(), endpoint_pool)));
|
return Ok(Some(Client::new(client, conn_info.clone(), endpoint_pool)));
|
||||||
@@ -577,7 +575,6 @@ pub struct Client<C: ClientInnerExt> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub struct Discard<'a, C: ClientInnerExt> {
|
pub struct Discard<'a, C: ClientInnerExt> {
|
||||||
conn_id: uuid::Uuid,
|
|
||||||
conn_info: &'a ConnInfo,
|
conn_info: &'a ConnInfo,
|
||||||
pool: &'a mut Weak<RwLock<EndpointConnPool<C>>>,
|
pool: &'a mut Weak<RwLock<EndpointConnPool<C>>>,
|
||||||
}
|
}
|
||||||
@@ -603,14 +600,7 @@ impl<C: ClientInnerExt> Client<C> {
|
|||||||
span: _,
|
span: _,
|
||||||
} = self;
|
} = self;
|
||||||
let inner = inner.as_mut().expect("client inner should not be removed");
|
let inner = inner.as_mut().expect("client inner should not be removed");
|
||||||
(
|
(&mut inner.inner, Discard { pool, conn_info })
|
||||||
&mut inner.inner,
|
|
||||||
Discard {
|
|
||||||
pool,
|
|
||||||
conn_info,
|
|
||||||
conn_id: inner.conn_id,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn check_idle(&mut self, status: ReadyForQueryStatus) {
|
pub fn check_idle(&mut self, status: ReadyForQueryStatus) {
|
||||||
@@ -625,13 +615,13 @@ impl<C: ClientInnerExt> Discard<'_, C> {
|
|||||||
pub fn check_idle(&mut self, status: ReadyForQueryStatus) {
|
pub fn check_idle(&mut self, status: ReadyForQueryStatus) {
|
||||||
let conn_info = &self.conn_info;
|
let conn_info = &self.conn_info;
|
||||||
if status != ReadyForQueryStatus::Idle && std::mem::take(self.pool).strong_count() > 0 {
|
if status != ReadyForQueryStatus::Idle && std::mem::take(self.pool).strong_count() > 0 {
|
||||||
info!(conn_id = %self.conn_id, "pool: throwing away connection '{conn_info}' because connection is not idle")
|
info!("pool: throwing away connection '{conn_info}' because connection is not idle")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
pub fn discard(&mut self) {
|
pub fn discard(&mut self) {
|
||||||
let conn_info = &self.conn_info;
|
let conn_info = &self.conn_info;
|
||||||
if std::mem::take(self.pool).strong_count() > 0 {
|
if std::mem::take(self.pool).strong_count() > 0 {
|
||||||
info!(conn_id = %self.conn_id, "pool: throwing away connection '{conn_info}' because connection is potentially in a broken state")
|
info!("pool: throwing away connection '{conn_info}' because connection is potentially in a broken state")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -36,6 +36,8 @@ use crate::error::ReportableError;
|
|||||||
use crate::metrics::HTTP_CONTENT_LENGTH;
|
use crate::metrics::HTTP_CONTENT_LENGTH;
|
||||||
use crate::metrics::NUM_CONNECTION_REQUESTS_GAUGE;
|
use crate::metrics::NUM_CONNECTION_REQUESTS_GAUGE;
|
||||||
use crate::proxy::NeonOptions;
|
use crate::proxy::NeonOptions;
|
||||||
|
use crate::serverless::backend::HttpConnError;
|
||||||
|
use crate::DbName;
|
||||||
use crate::RoleName;
|
use crate::RoleName;
|
||||||
|
|
||||||
use super::backend::PoolingBackend;
|
use super::backend::PoolingBackend;
|
||||||
@@ -117,6 +119,9 @@ fn get_conn_info(
|
|||||||
headers: &HeaderMap,
|
headers: &HeaderMap,
|
||||||
tls: &TlsConfig,
|
tls: &TlsConfig,
|
||||||
) -> Result<ConnInfo, ConnInfoError> {
|
) -> Result<ConnInfo, ConnInfoError> {
|
||||||
|
// HTTP only uses cleartext (for now and likely always)
|
||||||
|
ctx.set_auth_method(crate::context::AuthMethod::Cleartext);
|
||||||
|
|
||||||
let connection_string = headers
|
let connection_string = headers
|
||||||
.get("Neon-Connection-String")
|
.get("Neon-Connection-String")
|
||||||
.ok_or(ConnInfoError::InvalidHeader("Neon-Connection-String"))?
|
.ok_or(ConnInfoError::InvalidHeader("Neon-Connection-String"))?
|
||||||
@@ -134,7 +139,8 @@ fn get_conn_info(
|
|||||||
.path_segments()
|
.path_segments()
|
||||||
.ok_or(ConnInfoError::MissingDbName)?;
|
.ok_or(ConnInfoError::MissingDbName)?;
|
||||||
|
|
||||||
let dbname = url_path.next().ok_or(ConnInfoError::InvalidDbName)?;
|
let dbname: DbName = url_path.next().ok_or(ConnInfoError::InvalidDbName)?.into();
|
||||||
|
ctx.set_dbname(dbname.clone());
|
||||||
|
|
||||||
let username = RoleName::from(urlencoding::decode(connection_url.username())?);
|
let username = RoleName::from(urlencoding::decode(connection_url.username())?);
|
||||||
if username.is_empty() {
|
if username.is_empty() {
|
||||||
@@ -174,7 +180,7 @@ fn get_conn_info(
|
|||||||
|
|
||||||
Ok(ConnInfo {
|
Ok(ConnInfo {
|
||||||
user_info,
|
user_info,
|
||||||
dbname: dbname.into(),
|
dbname,
|
||||||
password: match password {
|
password: match password {
|
||||||
std::borrow::Cow::Borrowed(b) => b.into(),
|
std::borrow::Cow::Borrowed(b) => b.into(),
|
||||||
std::borrow::Cow::Owned(b) => b.into(),
|
std::borrow::Cow::Owned(b) => b.into(),
|
||||||
@@ -300,7 +306,14 @@ pub async fn handle(
|
|||||||
Ok(response)
|
Ok(response)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[instrument(name = "sql-over-http", fields(pid = tracing::field::Empty), skip_all)]
|
#[instrument(
|
||||||
|
name = "sql-over-http",
|
||||||
|
skip_all,
|
||||||
|
fields(
|
||||||
|
pid = tracing::field::Empty,
|
||||||
|
conn_id = tracing::field::Empty
|
||||||
|
)
|
||||||
|
)]
|
||||||
async fn handle_inner(
|
async fn handle_inner(
|
||||||
config: &'static ProxyConfig,
|
config: &'static ProxyConfig,
|
||||||
ctx: &mut RequestMonitoring,
|
ctx: &mut RequestMonitoring,
|
||||||
@@ -354,12 +367,10 @@ async fn handle_inner(
|
|||||||
let txn_read_only = headers.get(&TXN_READ_ONLY) == Some(&HEADER_VALUE_TRUE);
|
let txn_read_only = headers.get(&TXN_READ_ONLY) == Some(&HEADER_VALUE_TRUE);
|
||||||
let txn_deferrable = headers.get(&TXN_DEFERRABLE) == Some(&HEADER_VALUE_TRUE);
|
let txn_deferrable = headers.get(&TXN_DEFERRABLE) == Some(&HEADER_VALUE_TRUE);
|
||||||
|
|
||||||
let paused = ctx.latency_timer.pause();
|
|
||||||
let request_content_length = match request.body().size_hint().upper() {
|
let request_content_length = match request.body().size_hint().upper() {
|
||||||
Some(v) => v,
|
Some(v) => v,
|
||||||
None => MAX_REQUEST_SIZE + 1,
|
None => MAX_REQUEST_SIZE + 1,
|
||||||
};
|
};
|
||||||
drop(paused);
|
|
||||||
info!(request_content_length, "request size in bytes");
|
info!(request_content_length, "request size in bytes");
|
||||||
HTTP_CONTENT_LENGTH.observe(request_content_length as f64);
|
HTTP_CONTENT_LENGTH.observe(request_content_length as f64);
|
||||||
|
|
||||||
@@ -375,15 +386,20 @@ async fn handle_inner(
|
|||||||
let body = hyper::body::to_bytes(request.into_body())
|
let body = hyper::body::to_bytes(request.into_body())
|
||||||
.await
|
.await
|
||||||
.map_err(anyhow::Error::from)?;
|
.map_err(anyhow::Error::from)?;
|
||||||
|
info!(length = body.len(), "request payload read");
|
||||||
let payload: Payload = serde_json::from_slice(&body)?;
|
let payload: Payload = serde_json::from_slice(&body)?;
|
||||||
Ok::<Payload, anyhow::Error>(payload) // Adjust error type accordingly
|
Ok::<Payload, anyhow::Error>(payload) // Adjust error type accordingly
|
||||||
};
|
};
|
||||||
|
|
||||||
let authenticate_and_connect = async {
|
let authenticate_and_connect = async {
|
||||||
let keys = backend.authenticate(ctx, &conn_info).await?;
|
let keys = backend.authenticate(ctx, &conn_info).await?;
|
||||||
backend
|
let client = backend
|
||||||
.connect_to_compute(ctx, conn_info, keys, !allow_pool)
|
.connect_to_compute(ctx, conn_info, keys, !allow_pool)
|
||||||
.await
|
.await?;
|
||||||
|
// not strictly necessary to mark success here,
|
||||||
|
// but it's just insurance for if we forget it somewhere else
|
||||||
|
ctx.latency_timer.success();
|
||||||
|
Ok::<_, HttpConnError>(client)
|
||||||
};
|
};
|
||||||
|
|
||||||
// Run both operations in parallel
|
// Run both operations in parallel
|
||||||
@@ -415,6 +431,7 @@ async fn handle_inner(
|
|||||||
results
|
results
|
||||||
}
|
}
|
||||||
Payload::Batch(statements) => {
|
Payload::Batch(statements) => {
|
||||||
|
info!("starting transaction");
|
||||||
let (inner, mut discard) = client.inner();
|
let (inner, mut discard) = client.inner();
|
||||||
let mut builder = inner.build_transaction();
|
let mut builder = inner.build_transaction();
|
||||||
if let Some(isolation_level) = txn_isolation_level {
|
if let Some(isolation_level) = txn_isolation_level {
|
||||||
@@ -444,6 +461,7 @@ async fn handle_inner(
|
|||||||
.await
|
.await
|
||||||
{
|
{
|
||||||
Ok(results) => {
|
Ok(results) => {
|
||||||
|
info!("commit");
|
||||||
let status = transaction.commit().await.map_err(|e| {
|
let status = transaction.commit().await.map_err(|e| {
|
||||||
// if we cannot commit - for now don't return connection to pool
|
// if we cannot commit - for now don't return connection to pool
|
||||||
// TODO: get a query status from the error
|
// TODO: get a query status from the error
|
||||||
@@ -454,6 +472,7 @@ async fn handle_inner(
|
|||||||
results
|
results
|
||||||
}
|
}
|
||||||
Err(err) => {
|
Err(err) => {
|
||||||
|
info!("rollback");
|
||||||
let status = transaction.rollback().await.map_err(|e| {
|
let status = transaction.rollback().await.map_err(|e| {
|
||||||
// if we cannot rollback - for now don't return connection to pool
|
// if we cannot rollback - for now don't return connection to pool
|
||||||
// TODO: get a query status from the error
|
// TODO: get a query status from the error
|
||||||
@@ -528,8 +547,10 @@ async fn query_to_json<T: GenericClient>(
|
|||||||
raw_output: bool,
|
raw_output: bool,
|
||||||
default_array_mode: bool,
|
default_array_mode: bool,
|
||||||
) -> anyhow::Result<(ReadyForQueryStatus, Value)> {
|
) -> anyhow::Result<(ReadyForQueryStatus, Value)> {
|
||||||
|
info!("executing query");
|
||||||
let query_params = data.params;
|
let query_params = data.params;
|
||||||
let row_stream = client.query_raw_txt(&data.query, query_params).await?;
|
let row_stream = client.query_raw_txt(&data.query, query_params).await?;
|
||||||
|
info!("finished executing query");
|
||||||
|
|
||||||
// Manually drain the stream into a vector to leave row_stream hanging
|
// Manually drain the stream into a vector to leave row_stream hanging
|
||||||
// around to get a command tag. Also check that the response is not too
|
// around to get a command tag. Also check that the response is not too
|
||||||
@@ -564,6 +585,13 @@ async fn query_to_json<T: GenericClient>(
|
|||||||
}
|
}
|
||||||
.and_then(|s| s.parse::<i64>().ok());
|
.and_then(|s| s.parse::<i64>().ok());
|
||||||
|
|
||||||
|
info!(
|
||||||
|
rows = rows.len(),
|
||||||
|
?ready,
|
||||||
|
command_tag,
|
||||||
|
"finished reading rows"
|
||||||
|
);
|
||||||
|
|
||||||
let mut fields = vec![];
|
let mut fields = vec![];
|
||||||
let mut columns = vec![];
|
let mut columns = vec![];
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
use crate::{
|
use crate::{
|
||||||
cancellation::CancelMap,
|
cancellation::CancellationHandler,
|
||||||
config::ProxyConfig,
|
config::ProxyConfig,
|
||||||
context::RequestMonitoring,
|
context::RequestMonitoring,
|
||||||
error::{io_error, ReportableError},
|
error::{io_error, ReportableError},
|
||||||
@@ -133,7 +133,7 @@ pub async fn serve_websocket(
|
|||||||
config: &'static ProxyConfig,
|
config: &'static ProxyConfig,
|
||||||
mut ctx: RequestMonitoring,
|
mut ctx: RequestMonitoring,
|
||||||
websocket: HyperWebsocket,
|
websocket: HyperWebsocket,
|
||||||
cancel_map: Arc<CancelMap>,
|
cancellation_handler: Arc<CancellationHandler>,
|
||||||
hostname: Option<String>,
|
hostname: Option<String>,
|
||||||
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||||
) -> anyhow::Result<()> {
|
) -> anyhow::Result<()> {
|
||||||
@@ -141,7 +141,7 @@ pub async fn serve_websocket(
|
|||||||
let res = handle_client(
|
let res = handle_client(
|
||||||
config,
|
config,
|
||||||
&mut ctx,
|
&mut ctx,
|
||||||
cancel_map,
|
cancellation_handler,
|
||||||
WebSocketRw::new(websocket),
|
WebSocketRw::new(websocket),
|
||||||
ClientMode::Websockets { hostname },
|
ClientMode::Websockets { hostname },
|
||||||
endpoint_rate_limiter,
|
endpoint_rate_limiter,
|
||||||
|
|||||||
@@ -38,7 +38,7 @@ futures-io = { version = "0.3" }
|
|||||||
futures-sink = { version = "0.3" }
|
futures-sink = { version = "0.3" }
|
||||||
futures-util = { version = "0.3", features = ["channel", "io", "sink"] }
|
futures-util = { version = "0.3", features = ["channel", "io", "sink"] }
|
||||||
getrandom = { version = "0.2", default-features = false, features = ["std"] }
|
getrandom = { version = "0.2", default-features = false, features = ["std"] }
|
||||||
hashbrown-582f2526e08bb6a0 = { package = "hashbrown", version = "0.14", default-features = false, features = ["raw"] }
|
hashbrown-582f2526e08bb6a0 = { package = "hashbrown", version = "0.14", features = ["raw"] }
|
||||||
hashbrown-594e8ee84c453af0 = { package = "hashbrown", version = "0.13", features = ["raw"] }
|
hashbrown-594e8ee84c453af0 = { package = "hashbrown", version = "0.13", features = ["raw"] }
|
||||||
hex = { version = "0.4", features = ["serde"] }
|
hex = { version = "0.4", features = ["serde"] }
|
||||||
hmac = { version = "0.12", default-features = false, features = ["reset"] }
|
hmac = { version = "0.12", default-features = false, features = ["reset"] }
|
||||||
@@ -91,7 +91,7 @@ cc = { version = "1", default-features = false, features = ["parallel"] }
|
|||||||
chrono = { version = "0.4", default-features = false, features = ["clock", "serde", "wasmbind"] }
|
chrono = { version = "0.4", default-features = false, features = ["clock", "serde", "wasmbind"] }
|
||||||
either = { version = "1" }
|
either = { version = "1" }
|
||||||
getrandom = { version = "0.2", default-features = false, features = ["std"] }
|
getrandom = { version = "0.2", default-features = false, features = ["std"] }
|
||||||
hashbrown-582f2526e08bb6a0 = { package = "hashbrown", version = "0.14", default-features = false, features = ["raw"] }
|
hashbrown-582f2526e08bb6a0 = { package = "hashbrown", version = "0.14", features = ["raw"] }
|
||||||
indexmap = { version = "1", default-features = false, features = ["std"] }
|
indexmap = { version = "1", default-features = false, features = ["std"] }
|
||||||
itertools = { version = "0.10" }
|
itertools = { version = "0.10" }
|
||||||
libc = { version = "0.2", features = ["extra_traits", "use_std"] }
|
libc = { version = "0.2", features = ["extra_traits", "use_std"] }
|
||||||
|
|||||||
Reference in New Issue
Block a user