mirror of
https://github.com/neondatabase/neon.git
synced 2026-01-05 20:42:54 +00:00
proxy: add request context for observability and blocking (#6160)
## Summary of changes ### RequestMonitoring We want to add an event stream with information on each request for easier analysis than what we can do with diagnostic logs alone (https://github.com/neondatabase/cloud/issues/8807). This RequestMonitoring will keep a record of the final state of a request. On drop it will be pushed into a queue to be uploaded. Because this context is a bag of data, I don't want this information to impact logic of request handling. I personally think that weakly typed data (such as all these options) makes for spaghetti code. I will however allow for this data to impact rate-limiting and blocking of requests, as this does not _really_ change how a request is handled. ### Parquet Each `RequestMonitoring` is flushed into a channel where it is converted into `RequestData`, which is accumulated into parquet files. Each file will have a certain number of rows per row group, and several row groups will eventually fill up the file, which we then upload to S3. We will also upload smaller files if they take too long to construct.
This commit is contained in:
148
Cargo.lock
generated
148
Cargo.lock
generated
@@ -30,6 +30,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "cd7d5a2cecb58716e47d67d5703a249964b14c7be1ec3cad3affc295b2d1c35d"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"const-random",
|
||||
"getrandom 0.2.11",
|
||||
"once_cell",
|
||||
"version_check",
|
||||
"zerocopy",
|
||||
@@ -50,6 +52,12 @@ version = "0.2.16"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0942ffc6dcaadf03badf6e6a2d0228460359d5e34b57ccdc720b7382dfbd5ec5"
|
||||
|
||||
[[package]]
|
||||
name = "android-tzdata"
|
||||
version = "0.1.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e999941b234f3131b00bc13c22d06e8c5ff726d1b6318ac7eb276997bbb4fef0"
|
||||
|
||||
[[package]]
|
||||
name = "android_system_properties"
|
||||
version = "0.1.5"
|
||||
@@ -247,6 +255,12 @@ dependencies = [
|
||||
"syn 2.0.32",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "atomic"
|
||||
version = "0.5.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c59bdb34bc650a32731b31bd8f0829cc15d24a708ee31559e0bb34f2bc320cba"
|
||||
|
||||
[[package]]
|
||||
name = "atomic-polyfill"
|
||||
version = "1.0.2"
|
||||
@@ -1011,17 +1025,17 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "chrono"
|
||||
version = "0.4.24"
|
||||
version = "0.4.31"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4e3c5919066adf22df73762e50cffcde3a758f2a848b113b586d1f86728b673b"
|
||||
checksum = "7f2c685bad3eb3d45a01354cedb7d5faa66194d1d58ba6e267a8de788f79db38"
|
||||
dependencies = [
|
||||
"android-tzdata",
|
||||
"iana-time-zone",
|
||||
"js-sys",
|
||||
"num-integer",
|
||||
"num-traits",
|
||||
"serde",
|
||||
"wasm-bindgen",
|
||||
"winapi",
|
||||
"windows-targets 0.48.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -2475,6 +2489,12 @@ dependencies = [
|
||||
"web-sys",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "integer-encoding"
|
||||
version = "3.0.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8bb03732005da905c88227371639bf1ad885cc712789c011c31c5fb3ab3ccf02"
|
||||
|
||||
[[package]]
|
||||
name = "io-lifetimes"
|
||||
version = "1.0.11"
|
||||
@@ -2838,6 +2858,19 @@ dependencies = [
|
||||
"winapi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num"
|
||||
version = "0.4.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b05180d69e3da0e530ba2a1dae5110317e49e3b7f3d41be227dc5f92e49ee7af"
|
||||
dependencies = [
|
||||
"num-complex",
|
||||
"num-integer",
|
||||
"num-iter",
|
||||
"num-rational",
|
||||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num-bigint"
|
||||
version = "0.4.3"
|
||||
@@ -2849,6 +2882,15 @@ dependencies = [
|
||||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num-complex"
|
||||
version = "0.4.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1ba157ca0885411de85d6ca030ba7e2a83a28636056c7c699b07c8b6f7383214"
|
||||
dependencies = [
|
||||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num-integer"
|
||||
version = "0.1.45"
|
||||
@@ -2859,6 +2901,28 @@ dependencies = [
|
||||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num-iter"
|
||||
version = "0.1.43"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7d03e6c028c5dc5cac6e2dec0efda81fc887605bb3d884578bb6d6bf7514e252"
|
||||
dependencies = [
|
||||
"autocfg",
|
||||
"num-integer",
|
||||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num-rational"
|
||||
version = "0.4.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0638a1c9d0a3c0914158145bc76cff373a75a627e6ecbfb71cbe6f453a5a19b0"
|
||||
dependencies = [
|
||||
"autocfg",
|
||||
"num-integer",
|
||||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num-traits"
|
||||
version = "0.2.15"
|
||||
@@ -3081,6 +3145,15 @@ dependencies = [
|
||||
"tokio-stream",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ordered-float"
|
||||
version = "2.10.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "68f19d67e5a2795c94e73e0bb1cc1a7edeb2e28efd39e2e1c9b7a40c1108b11c"
|
||||
dependencies = [
|
||||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ordered-multimap"
|
||||
version = "0.7.1"
|
||||
@@ -3339,6 +3412,35 @@ dependencies = [
|
||||
"windows-targets 0.48.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "parquet"
|
||||
version = "49.0.0"
|
||||
source = "git+https://github.com/neondatabase/arrow-rs?branch=neon-fix-bugs#8a0bc58aa67b98aabbd8eee7c6ca4281967ff9e9"
|
||||
dependencies = [
|
||||
"ahash",
|
||||
"bytes",
|
||||
"chrono",
|
||||
"hashbrown 0.14.0",
|
||||
"num",
|
||||
"num-bigint",
|
||||
"paste",
|
||||
"seq-macro",
|
||||
"thrift",
|
||||
"twox-hash",
|
||||
"zstd",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "parquet_derive"
|
||||
version = "49.0.0"
|
||||
source = "git+https://github.com/neondatabase/arrow-rs?branch=neon-fix-bugs#8a0bc58aa67b98aabbd8eee7c6ca4281967ff9e9"
|
||||
dependencies = [
|
||||
"parquet",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.32",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "password-hash"
|
||||
version = "0.5.0"
|
||||
@@ -3762,6 +3864,8 @@ dependencies = [
|
||||
"base64 0.13.1",
|
||||
"bstr",
|
||||
"bytes",
|
||||
"camino",
|
||||
"camino-tempfile",
|
||||
"chrono",
|
||||
"clap",
|
||||
"consumption_metrics",
|
||||
@@ -3784,6 +3888,8 @@ dependencies = [
|
||||
"once_cell",
|
||||
"opentelemetry",
|
||||
"parking_lot 0.12.1",
|
||||
"parquet",
|
||||
"parquet_derive",
|
||||
"pbkdf2",
|
||||
"pin-project-lite",
|
||||
"postgres-native-tls",
|
||||
@@ -3794,6 +3900,7 @@ dependencies = [
|
||||
"rand 0.8.5",
|
||||
"rcgen",
|
||||
"regex",
|
||||
"remote_storage",
|
||||
"reqwest",
|
||||
"reqwest-middleware",
|
||||
"reqwest-retry",
|
||||
@@ -4682,6 +4789,12 @@ dependencies = [
|
||||
"uuid",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "seq-macro"
|
||||
version = "0.3.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a3f0bf26fd526d2a95683cd0f87bf103b8539e2ca1ef48ce002d67aad59aa0b4"
|
||||
|
||||
[[package]]
|
||||
name = "serde"
|
||||
version = "1.0.183"
|
||||
@@ -5202,6 +5315,17 @@ dependencies = [
|
||||
"once_cell",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "thrift"
|
||||
version = "0.17.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7e54bc85fc7faa8bc175c4bab5b92ba8d9a3ce893d0e9f42cc455c8ab16a9e09"
|
||||
dependencies = [
|
||||
"byteorder",
|
||||
"integer-encoding",
|
||||
"ordered-float",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "time"
|
||||
version = "0.3.21"
|
||||
@@ -5746,6 +5870,16 @@ dependencies = [
|
||||
"utf-8",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "twox-hash"
|
||||
version = "1.6.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "97fee6b57c6a41524a810daee9286c02d7752c4253064d0b05472833a438f675"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"static_assertions",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "typenum"
|
||||
version = "1.16.0"
|
||||
@@ -5923,10 +6057,11 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "uuid"
|
||||
version = "1.3.3"
|
||||
version = "1.6.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "345444e32442451b267fc254ae85a209c64be56d2890e601a0c37ff0c3c5ecd2"
|
||||
checksum = "5e395fcf16a7a3d8127ec99782007af141946b4795001f876d54fb0d55978560"
|
||||
dependencies = [
|
||||
"atomic",
|
||||
"getrandom 0.2.11",
|
||||
"serde",
|
||||
]
|
||||
@@ -6422,6 +6557,7 @@ dependencies = [
|
||||
"num-integer",
|
||||
"num-traits",
|
||||
"once_cell",
|
||||
"parquet",
|
||||
"prost",
|
||||
"rand 0.8.5",
|
||||
"regex",
|
||||
|
||||
@@ -107,6 +107,8 @@ opentelemetry = "0.19.0"
|
||||
opentelemetry-otlp = { version = "0.12.0", default_features=false, features = ["http-proto", "trace", "http", "reqwest-client"] }
|
||||
opentelemetry-semantic-conventions = "0.11.0"
|
||||
parking_lot = "0.12"
|
||||
parquet = { version = "49.0.0", default-features = false, features = ["zstd"] }
|
||||
parquet_derive = "49.0.0"
|
||||
pbkdf2 = { version = "0.12.1", features = ["simple", "std"] }
|
||||
pin-project-lite = "0.2"
|
||||
prometheus = {version = "0.13", default_features=false, features = ["process"]} # removes protobuf dependency
|
||||
@@ -161,7 +163,7 @@ tracing-error = "0.2.0"
|
||||
tracing-opentelemetry = "0.19.0"
|
||||
tracing-subscriber = { version = "0.3", default_features = false, features = ["smallvec", "fmt", "tracing-log", "std", "env-filter", "json"] }
|
||||
url = "2.2"
|
||||
uuid = { version = "1.2", features = ["v4", "serde"] }
|
||||
uuid = { version = "1.6.1", features = ["v4", "v7", "serde"] }
|
||||
walkdir = "2.3.2"
|
||||
webpki-roots = "0.25"
|
||||
x509-parser = "0.15"
|
||||
@@ -215,6 +217,10 @@ tonic-build = "0.9"
|
||||
# TODO: we should probably fork `tokio-postgres-rustls` instead.
|
||||
tokio-postgres = { git = "https://github.com/neondatabase/rust-postgres.git", branch="neon" }
|
||||
|
||||
# bug fixes for UUID
|
||||
parquet = { git = "https://github.com/neondatabase/arrow-rs", branch = "neon-fix-bugs" }
|
||||
parquet_derive = { git = "https://github.com/neondatabase/arrow-rs", branch = "neon-fix-bugs" }
|
||||
|
||||
################# Binary contents sections
|
||||
|
||||
[profile.release]
|
||||
|
||||
@@ -5,7 +5,7 @@ edition.workspace = true
|
||||
license.workspace = true
|
||||
|
||||
[features]
|
||||
default = []
|
||||
default = ["testing"]
|
||||
testing = []
|
||||
|
||||
[dependencies]
|
||||
@@ -14,6 +14,7 @@ async-trait.workspace = true
|
||||
base64.workspace = true
|
||||
bstr.workspace = true
|
||||
bytes = { workspace = true, features = ["serde"] }
|
||||
camino.workspace = true
|
||||
chrono.workspace = true
|
||||
clap.workspace = true
|
||||
consumption_metrics.workspace = true
|
||||
@@ -35,6 +36,8 @@ metrics.workspace = true
|
||||
once_cell.workspace = true
|
||||
opentelemetry.workspace = true
|
||||
parking_lot.workspace = true
|
||||
parquet.workspace = true
|
||||
parquet_derive.workspace = true
|
||||
pbkdf2 = { workspace = true, features = ["simple", "std"] }
|
||||
pin-project-lite.workspace = true
|
||||
postgres_backend.workspace = true
|
||||
@@ -42,6 +45,7 @@ pq_proto.workspace = true
|
||||
prometheus.workspace = true
|
||||
rand.workspace = true
|
||||
regex.workspace = true
|
||||
remote_storage = { version = "0.1", path = "../libs/remote_storage/" }
|
||||
reqwest = { workspace = true, features = ["json"] }
|
||||
reqwest-middleware.workspace = true
|
||||
reqwest-retry.workspace = true
|
||||
@@ -80,6 +84,7 @@ smol_str.workspace = true
|
||||
workspace_hack.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
camino-tempfile.workspace = true
|
||||
rcgen.workspace = true
|
||||
rstest.workspace = true
|
||||
tokio-postgres-rustls.workspace = true
|
||||
|
||||
@@ -10,6 +10,7 @@ use crate::auth::credentials::check_peer_addr_is_in_list;
|
||||
use crate::auth::validate_password_and_exchange;
|
||||
use crate::console::errors::GetAuthInfoError;
|
||||
use crate::console::AuthSecret;
|
||||
use crate::context::RequestMonitoring;
|
||||
use crate::proxy::connect_compute::handle_try_wake;
|
||||
use crate::proxy::retry::retry_after;
|
||||
use crate::scram;
|
||||
@@ -22,12 +23,10 @@ use crate::{
|
||||
provider::{CachedNodeInfo, ConsoleReqExtra},
|
||||
Api,
|
||||
},
|
||||
metrics::LatencyTimer,
|
||||
stream, url,
|
||||
};
|
||||
use futures::TryFutureExt;
|
||||
use std::borrow::Cow;
|
||||
use std::net::IpAddr;
|
||||
use std::ops::ControlFlow;
|
||||
use std::sync::Arc;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
@@ -129,7 +128,6 @@ pub struct ComputeCredentials<T> {
|
||||
|
||||
pub struct ComputeUserInfoNoEndpoint {
|
||||
pub user: SmolStr,
|
||||
pub peer_addr: IpAddr,
|
||||
pub cache_key: SmolStr,
|
||||
}
|
||||
|
||||
@@ -151,7 +149,6 @@ impl TryFrom<ClientCredentials> for ComputeUserInfo {
|
||||
fn try_from(creds: ClientCredentials) -> Result<Self, Self::Error> {
|
||||
let inner = ComputeUserInfoNoEndpoint {
|
||||
user: creds.user,
|
||||
peer_addr: creds.peer_addr,
|
||||
cache_key: creds.cache_key,
|
||||
};
|
||||
match creds.project {
|
||||
@@ -166,33 +163,34 @@ impl TryFrom<ClientCredentials> for ComputeUserInfo {
|
||||
///
|
||||
/// All authentication flows will emit an AuthenticationOk message if successful.
|
||||
async fn auth_quirks(
|
||||
ctx: &mut RequestMonitoring,
|
||||
api: &impl console::Api,
|
||||
extra: &ConsoleReqExtra,
|
||||
creds: ClientCredentials,
|
||||
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
|
||||
allow_cleartext: bool,
|
||||
config: &'static AuthenticationConfig,
|
||||
latency_timer: &mut LatencyTimer,
|
||||
) -> auth::Result<ComputeCredentials<ComputeCredentialKeys>> {
|
||||
// If there's no project so far, that entails that client doesn't
|
||||
// support SNI or other means of passing the endpoint (project) name.
|
||||
// We now expect to see a very specific payload in the place of password.
|
||||
let (info, unauthenticated_password) = match creds.try_into() {
|
||||
Err(info) => {
|
||||
let res = hacks::password_hack_no_authentication(info, client, latency_timer).await?;
|
||||
let res = hacks::password_hack_no_authentication(info, client, &mut ctx.latency_timer)
|
||||
.await?;
|
||||
ctx.set_endpoint_id(Some(res.info.endpoint.clone()));
|
||||
(res.info, Some(res.keys))
|
||||
}
|
||||
Ok(info) => (info, None),
|
||||
};
|
||||
|
||||
info!("fetching user's authentication info");
|
||||
let allowed_ips = api.get_allowed_ips(extra, &info).await?;
|
||||
let allowed_ips = api.get_allowed_ips(ctx, &info).await?;
|
||||
|
||||
// check allowed list
|
||||
if !check_peer_addr_is_in_list(&info.inner.peer_addr, &allowed_ips) {
|
||||
if !check_peer_addr_is_in_list(&ctx.peer_addr, &allowed_ips) {
|
||||
return Err(auth::AuthError::ip_address_not_allowed());
|
||||
}
|
||||
let cached_secret = api.get_role_secret(extra, &info).await?;
|
||||
let cached_secret = api.get_role_secret(ctx, &info).await?;
|
||||
|
||||
let secret = cached_secret.clone().unwrap_or_else(|| {
|
||||
// If we don't have an authentication secret, we mock one to
|
||||
@@ -202,13 +200,13 @@ async fn auth_quirks(
|
||||
AuthSecret::Scram(scram::ServerSecret::mock(&info.inner.user, rand::random()))
|
||||
});
|
||||
match authenticate_with_secret(
|
||||
ctx,
|
||||
secret,
|
||||
info,
|
||||
client,
|
||||
unauthenticated_password,
|
||||
allow_cleartext,
|
||||
config,
|
||||
latency_timer,
|
||||
)
|
||||
.await
|
||||
{
|
||||
@@ -224,13 +222,13 @@ async fn auth_quirks(
|
||||
}
|
||||
|
||||
async fn authenticate_with_secret(
|
||||
ctx: &mut RequestMonitoring,
|
||||
secret: AuthSecret,
|
||||
info: ComputeUserInfo,
|
||||
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
|
||||
unauthenticated_password: Option<Vec<u8>>,
|
||||
allow_cleartext: bool,
|
||||
config: &'static AuthenticationConfig,
|
||||
latency_timer: &mut LatencyTimer,
|
||||
) -> auth::Result<ComputeCredentials<ComputeCredentialKeys>> {
|
||||
if let Some(password) = unauthenticated_password {
|
||||
let auth_outcome = validate_password_and_exchange(&password, secret)?;
|
||||
@@ -253,38 +251,31 @@ async fn authenticate_with_secret(
|
||||
// Perform cleartext auth if we're allowed to do that.
|
||||
// Currently, we use it for websocket connections (latency).
|
||||
if allow_cleartext {
|
||||
return hacks::authenticate_cleartext(info, client, latency_timer, secret).await;
|
||||
return hacks::authenticate_cleartext(info, client, &mut ctx.latency_timer, secret).await;
|
||||
}
|
||||
|
||||
// Finally, proceed with the main auth flow (SCRAM-based).
|
||||
classic::authenticate(info, client, config, latency_timer, secret).await
|
||||
classic::authenticate(info, client, config, &mut ctx.latency_timer, secret).await
|
||||
}
|
||||
|
||||
/// Authenticate the user and then wake a compute (or retrieve an existing compute session from cache)
|
||||
/// only if authentication was successfuly.
|
||||
async fn auth_and_wake_compute(
|
||||
ctx: &mut RequestMonitoring,
|
||||
api: &impl console::Api,
|
||||
extra: &ConsoleReqExtra,
|
||||
creds: ClientCredentials,
|
||||
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
|
||||
allow_cleartext: bool,
|
||||
config: &'static AuthenticationConfig,
|
||||
latency_timer: &mut LatencyTimer,
|
||||
) -> auth::Result<(CachedNodeInfo, ComputeUserInfo)> {
|
||||
let compute_credentials = auth_quirks(
|
||||
api,
|
||||
extra,
|
||||
creds,
|
||||
client,
|
||||
allow_cleartext,
|
||||
config,
|
||||
latency_timer,
|
||||
)
|
||||
.await?;
|
||||
let compute_credentials = auth_quirks(ctx, api, creds, client, allow_cleartext, config).await?;
|
||||
|
||||
let mut num_retries = 0;
|
||||
let mut node = loop {
|
||||
let wake_res = api.wake_compute(extra, &compute_credentials.info).await;
|
||||
let wake_res = api
|
||||
.wake_compute(ctx, extra, &compute_credentials.info)
|
||||
.await;
|
||||
match handle_try_wake(wake_res, num_retries) {
|
||||
Err(e) => {
|
||||
error!(error = ?e, num_retries, retriable = false, "couldn't wake compute node");
|
||||
@@ -301,6 +292,8 @@ async fn auth_and_wake_compute(
|
||||
tokio::time::sleep(wait_duration).await;
|
||||
};
|
||||
|
||||
ctx.set_project(node.aux.clone());
|
||||
|
||||
match compute_credentials.keys {
|
||||
#[cfg(feature = "testing")]
|
||||
ComputeCredentialKeys::Password(password) => node.config.password(password),
|
||||
@@ -343,11 +336,11 @@ impl<'a> BackendType<'a, ClientCredentials> {
|
||||
#[tracing::instrument(fields(allow_cleartext = allow_cleartext), skip_all)]
|
||||
pub async fn authenticate(
|
||||
self,
|
||||
ctx: &mut RequestMonitoring,
|
||||
extra: &ConsoleReqExtra,
|
||||
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
|
||||
allow_cleartext: bool,
|
||||
config: &'static AuthenticationConfig,
|
||||
latency_timer: &mut LatencyTimer,
|
||||
) -> auth::Result<(CachedNodeInfo, BackendType<'a, ComputeUserInfo>)> {
|
||||
use BackendType::*;
|
||||
|
||||
@@ -360,13 +353,13 @@ impl<'a> BackendType<'a, ClientCredentials> {
|
||||
);
|
||||
|
||||
let (cache_info, user_info) = auth_and_wake_compute(
|
||||
ctx,
|
||||
&*api,
|
||||
extra,
|
||||
creds,
|
||||
client,
|
||||
allow_cleartext,
|
||||
config,
|
||||
latency_timer,
|
||||
)
|
||||
.await?;
|
||||
(cache_info, BackendType::Console(api, user_info))
|
||||
@@ -380,13 +373,13 @@ impl<'a> BackendType<'a, ClientCredentials> {
|
||||
);
|
||||
|
||||
let (cache_info, user_info) = auth_and_wake_compute(
|
||||
ctx,
|
||||
&*api,
|
||||
extra,
|
||||
creds,
|
||||
client,
|
||||
allow_cleartext,
|
||||
config,
|
||||
latency_timer,
|
||||
)
|
||||
.await?;
|
||||
(cache_info, BackendType::Postgres(api, user_info))
|
||||
@@ -416,13 +409,13 @@ impl<'a> BackendType<'a, ClientCredentials> {
|
||||
impl BackendType<'_, ComputeUserInfo> {
|
||||
pub async fn get_allowed_ips(
|
||||
&self,
|
||||
extra: &ConsoleReqExtra,
|
||||
ctx: &mut RequestMonitoring,
|
||||
) -> Result<Arc<Vec<String>>, GetAuthInfoError> {
|
||||
use BackendType::*;
|
||||
match self {
|
||||
Console(api, creds) => api.get_allowed_ips(extra, creds).await,
|
||||
Console(api, creds) => api.get_allowed_ips(ctx, creds).await,
|
||||
#[cfg(feature = "testing")]
|
||||
Postgres(api, creds) => api.get_allowed_ips(extra, creds).await,
|
||||
Postgres(api, creds) => api.get_allowed_ips(ctx, creds).await,
|
||||
Link(_) => Ok(Arc::new(vec![])),
|
||||
#[cfg(test)]
|
||||
Test(x) => x.get_allowed_ips(),
|
||||
@@ -433,14 +426,15 @@ impl BackendType<'_, ComputeUserInfo> {
|
||||
/// The link auth flow doesn't support this, so we return [`None`] in that case.
|
||||
pub async fn wake_compute(
|
||||
&self,
|
||||
ctx: &mut RequestMonitoring,
|
||||
extra: &ConsoleReqExtra,
|
||||
) -> Result<Option<CachedNodeInfo>, console::errors::WakeComputeError> {
|
||||
use BackendType::*;
|
||||
|
||||
match self {
|
||||
Console(api, creds) => api.wake_compute(extra, creds).map_ok(Some).await,
|
||||
Console(api, creds) => api.wake_compute(ctx, extra, creds).map_ok(Some).await,
|
||||
#[cfg(feature = "testing")]
|
||||
Postgres(api, creds) => api.wake_compute(extra, creds).map_ok(Some).await,
|
||||
Postgres(api, creds) => api.wake_compute(ctx, extra, creds).map_ok(Some).await,
|
||||
Link(_) => Ok(None),
|
||||
#[cfg(test)]
|
||||
Test(x) => x.wake_compute().map(Some),
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
//! User credentials used in authentication.
|
||||
|
||||
use crate::{
|
||||
auth::password_hack::parse_endpoint_param, error::UserFacingError,
|
||||
auth::password_hack::parse_endpoint_param, context::RequestMonitoring, error::UserFacingError,
|
||||
metrics::NUM_CONNECTION_ACCEPTED_BY_SNI, proxy::neon_options_str,
|
||||
};
|
||||
use itertools::Itertools;
|
||||
@@ -44,7 +44,6 @@ pub struct ClientCredentials {
|
||||
pub project: Option<SmolStr>,
|
||||
|
||||
pub cache_key: SmolStr,
|
||||
pub peer_addr: IpAddr,
|
||||
}
|
||||
|
||||
impl ClientCredentials {
|
||||
@@ -56,16 +55,21 @@ impl ClientCredentials {
|
||||
|
||||
impl ClientCredentials {
|
||||
pub fn parse(
|
||||
ctx: &mut RequestMonitoring,
|
||||
params: &StartupMessageParams,
|
||||
sni: Option<&str>,
|
||||
common_names: Option<HashSet<String>>,
|
||||
peer_addr: IpAddr,
|
||||
) -> Result<Self, ClientCredsParseError> {
|
||||
use ClientCredsParseError::*;
|
||||
|
||||
// Some parameters are stored in the startup message.
|
||||
let get_param = |key| params.get(key).ok_or(MissingKey(key));
|
||||
let user = get_param("user")?.into();
|
||||
let user: SmolStr = get_param("user")?.into();
|
||||
|
||||
// record the values if we have them
|
||||
ctx.set_application(params.get("application_name").map(SmolStr::from));
|
||||
ctx.set_user(user.clone());
|
||||
ctx.set_endpoint_id(sni.map(SmolStr::from));
|
||||
|
||||
// Project name might be passed via PG's command-line options.
|
||||
let project_option = params
|
||||
@@ -147,7 +151,6 @@ impl ClientCredentials {
|
||||
user,
|
||||
project,
|
||||
cache_key,
|
||||
peer_addr,
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -219,8 +222,8 @@ mod tests {
|
||||
fn parse_bare_minimum() -> anyhow::Result<()> {
|
||||
// According to postgresql, only `user` should be required.
|
||||
let options = StartupMessageParams::new([("user", "john_doe")]);
|
||||
let peer_addr = IpAddr::from([127, 0, 0, 1]);
|
||||
let creds = ClientCredentials::parse(&options, None, None, peer_addr)?;
|
||||
let mut ctx = RequestMonitoring::test();
|
||||
let creds = ClientCredentials::parse(&mut ctx, &options, None, None)?;
|
||||
assert_eq!(creds.user, "john_doe");
|
||||
assert_eq!(creds.project, None);
|
||||
|
||||
@@ -234,8 +237,8 @@ mod tests {
|
||||
("database", "world"), // should be ignored
|
||||
("foo", "bar"), // should be ignored
|
||||
]);
|
||||
let peer_addr = IpAddr::from([127, 0, 0, 1]);
|
||||
let creds = ClientCredentials::parse(&options, None, None, peer_addr)?;
|
||||
let mut ctx = RequestMonitoring::test();
|
||||
let creds = ClientCredentials::parse(&mut ctx, &options, None, None)?;
|
||||
assert_eq!(creds.user, "john_doe");
|
||||
assert_eq!(creds.project, None);
|
||||
|
||||
@@ -249,8 +252,8 @@ mod tests {
|
||||
let sni = Some("foo.localhost");
|
||||
let common_names = Some(["localhost".into()].into());
|
||||
|
||||
let peer_addr = IpAddr::from([127, 0, 0, 1]);
|
||||
let creds = ClientCredentials::parse(&options, sni, common_names, peer_addr)?;
|
||||
let mut ctx = RequestMonitoring::test();
|
||||
let creds = ClientCredentials::parse(&mut ctx, &options, sni, common_names)?;
|
||||
assert_eq!(creds.user, "john_doe");
|
||||
assert_eq!(creds.project.as_deref(), Some("foo"));
|
||||
assert_eq!(creds.cache_key, "foo");
|
||||
@@ -265,8 +268,8 @@ mod tests {
|
||||
("options", "-ckey=1 project=bar -c geqo=off"),
|
||||
]);
|
||||
|
||||
let peer_addr = IpAddr::from([127, 0, 0, 1]);
|
||||
let creds = ClientCredentials::parse(&options, None, None, peer_addr)?;
|
||||
let mut ctx = RequestMonitoring::test();
|
||||
let creds = ClientCredentials::parse(&mut ctx, &options, None, None)?;
|
||||
assert_eq!(creds.user, "john_doe");
|
||||
assert_eq!(creds.project.as_deref(), Some("bar"));
|
||||
|
||||
@@ -280,8 +283,8 @@ mod tests {
|
||||
("options", "-ckey=1 endpoint=bar -c geqo=off"),
|
||||
]);
|
||||
|
||||
let peer_addr = IpAddr::from([127, 0, 0, 1]);
|
||||
let creds = ClientCredentials::parse(&options, None, None, peer_addr)?;
|
||||
let mut ctx = RequestMonitoring::test();
|
||||
let creds = ClientCredentials::parse(&mut ctx, &options, None, None)?;
|
||||
assert_eq!(creds.user, "john_doe");
|
||||
assert_eq!(creds.project.as_deref(), Some("bar"));
|
||||
|
||||
@@ -298,8 +301,8 @@ mod tests {
|
||||
),
|
||||
]);
|
||||
|
||||
let peer_addr = IpAddr::from([127, 0, 0, 1]);
|
||||
let creds = ClientCredentials::parse(&options, None, None, peer_addr)?;
|
||||
let mut ctx = RequestMonitoring::test();
|
||||
let creds = ClientCredentials::parse(&mut ctx, &options, None, None)?;
|
||||
assert_eq!(creds.user, "john_doe");
|
||||
assert!(creds.project.is_none());
|
||||
|
||||
@@ -313,8 +316,8 @@ mod tests {
|
||||
("options", "-ckey=1 endpoint=bar project=foo -c geqo=off"),
|
||||
]);
|
||||
|
||||
let peer_addr = IpAddr::from([127, 0, 0, 1]);
|
||||
let creds = ClientCredentials::parse(&options, None, None, peer_addr)?;
|
||||
let mut ctx = RequestMonitoring::test();
|
||||
let creds = ClientCredentials::parse(&mut ctx, &options, None, None)?;
|
||||
assert_eq!(creds.user, "john_doe");
|
||||
assert!(creds.project.is_none());
|
||||
|
||||
@@ -328,8 +331,8 @@ mod tests {
|
||||
let sni = Some("baz.localhost");
|
||||
let common_names = Some(["localhost".into()].into());
|
||||
|
||||
let peer_addr = IpAddr::from([127, 0, 0, 1]);
|
||||
let creds = ClientCredentials::parse(&options, sni, common_names, peer_addr)?;
|
||||
let mut ctx = RequestMonitoring::test();
|
||||
let creds = ClientCredentials::parse(&mut ctx, &options, sni, common_names)?;
|
||||
assert_eq!(creds.user, "john_doe");
|
||||
assert_eq!(creds.project.as_deref(), Some("baz"));
|
||||
|
||||
@@ -342,14 +345,14 @@ mod tests {
|
||||
|
||||
let common_names = Some(["a.com".into(), "b.com".into()].into());
|
||||
let sni = Some("p1.a.com");
|
||||
let peer_addr = IpAddr::from([127, 0, 0, 1]);
|
||||
let creds = ClientCredentials::parse(&options, sni, common_names, peer_addr)?;
|
||||
let mut ctx = RequestMonitoring::test();
|
||||
let creds = ClientCredentials::parse(&mut ctx, &options, sni, common_names)?;
|
||||
assert_eq!(creds.project.as_deref(), Some("p1"));
|
||||
|
||||
let common_names = Some(["a.com".into(), "b.com".into()].into());
|
||||
let sni = Some("p1.b.com");
|
||||
let peer_addr = IpAddr::from([127, 0, 0, 1]);
|
||||
let creds = ClientCredentials::parse(&options, sni, common_names, peer_addr)?;
|
||||
let mut ctx = RequestMonitoring::test();
|
||||
let creds = ClientCredentials::parse(&mut ctx, &options, sni, common_names)?;
|
||||
assert_eq!(creds.project.as_deref(), Some("p1"));
|
||||
|
||||
Ok(())
|
||||
@@ -363,8 +366,8 @@ mod tests {
|
||||
let sni = Some("second.localhost");
|
||||
let common_names = Some(["localhost".into()].into());
|
||||
|
||||
let peer_addr = IpAddr::from([127, 0, 0, 1]);
|
||||
let err = ClientCredentials::parse(&options, sni, common_names, peer_addr)
|
||||
let mut ctx = RequestMonitoring::test();
|
||||
let err = ClientCredentials::parse(&mut ctx, &options, sni, common_names)
|
||||
.expect_err("should fail");
|
||||
match err {
|
||||
InconsistentProjectNames { domain, option } => {
|
||||
@@ -382,8 +385,8 @@ mod tests {
|
||||
let sni = Some("project.localhost");
|
||||
let common_names = Some(["example.com".into()].into());
|
||||
|
||||
let peer_addr = IpAddr::from([127, 0, 0, 1]);
|
||||
let err = ClientCredentials::parse(&options, sni, common_names, peer_addr)
|
||||
let mut ctx = RequestMonitoring::test();
|
||||
let err = ClientCredentials::parse(&mut ctx, &options, sni, common_names)
|
||||
.expect_err("should fail");
|
||||
match err {
|
||||
UnknownCommonName { cn } => {
|
||||
@@ -402,8 +405,8 @@ mod tests {
|
||||
|
||||
let sni = Some("project.localhost");
|
||||
let common_names = Some(["localhost".into()].into());
|
||||
let peer_addr = IpAddr::from([127, 0, 0, 1]);
|
||||
let creds = ClientCredentials::parse(&options, sni, common_names, peer_addr)?;
|
||||
let mut ctx = RequestMonitoring::test();
|
||||
let creds = ClientCredentials::parse(&mut ctx, &options, sni, common_names)?;
|
||||
assert_eq!(creds.project.as_deref(), Some("project"));
|
||||
assert_eq!(creds.cache_key, "projectendpoint_type:read_write lsn:0/2");
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ use std::{net::SocketAddr, sync::Arc};
|
||||
use futures::future::Either;
|
||||
use itertools::Itertools;
|
||||
use proxy::config::TlsServerEndPoint;
|
||||
use proxy::context::RequestMonitoring;
|
||||
use proxy::proxy::run_until_cancelled;
|
||||
use tokio::net::TcpListener;
|
||||
|
||||
@@ -170,7 +171,16 @@ async fn task_main(
|
||||
.context("failed to set socket option")?;
|
||||
|
||||
info!(%peer_addr, "serving");
|
||||
handle_client(dest_suffix, tls_config, tls_server_end_point, socket).await
|
||||
let mut ctx =
|
||||
RequestMonitoring::new(session_id, peer_addr.ip(), "sni_router", "sni");
|
||||
handle_client(
|
||||
&mut ctx,
|
||||
dest_suffix,
|
||||
tls_config,
|
||||
tls_server_end_point,
|
||||
socket,
|
||||
)
|
||||
.await
|
||||
}
|
||||
.unwrap_or_else(|e| {
|
||||
// Acknowledge that the task has finished with an error.
|
||||
@@ -236,6 +246,7 @@ async fn ssl_handshake<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
}
|
||||
|
||||
async fn handle_client(
|
||||
ctx: &mut RequestMonitoring,
|
||||
dest_suffix: Arc<String>,
|
||||
tls_config: Arc<rustls::ServerConfig>,
|
||||
tls_server_end_point: TlsServerEndPoint,
|
||||
@@ -261,5 +272,5 @@ async fn handle_client(
|
||||
let client = tokio::net::TcpStream::connect(destination).await?;
|
||||
|
||||
let metrics_aux: MetricsAuxInfo = Default::default();
|
||||
proxy::proxy::proxy_pass(tls_stream, client, metrics_aux).await
|
||||
proxy::proxy::proxy_pass(ctx, tls_stream, client, metrics_aux).await
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ use proxy::console;
|
||||
use proxy::console::provider::AllowedIpsCache;
|
||||
use proxy::console::provider::NodeInfoCache;
|
||||
use proxy::console::provider::RoleSecretCache;
|
||||
use proxy::context::parquet::ParquetUploadArgs;
|
||||
use proxy::http;
|
||||
use proxy::rate_limiter::EndpointRateLimiter;
|
||||
use proxy::rate_limiter::RateBucketInfo;
|
||||
@@ -44,6 +45,9 @@ enum AuthBackend {
|
||||
#[derive(Parser)]
|
||||
#[command(version = GIT_VERSION, about)]
|
||||
struct ProxyCliArgs {
|
||||
/// Name of the region this proxy is deployed in
|
||||
#[clap(long, default_value_t = String::new())]
|
||||
region: String,
|
||||
/// listen for incoming client connections on ip:port
|
||||
#[clap(short, long, default_value = "127.0.0.1:4432")]
|
||||
proxy: String,
|
||||
@@ -133,6 +137,9 @@ struct ProxyCliArgs {
|
||||
/// disable ip check for http requests. If it is too time consuming, it could be turned off.
|
||||
#[clap(long, default_value_t = false, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)]
|
||||
disable_ip_check_for_http: bool,
|
||||
|
||||
#[clap(flatten)]
|
||||
parquet_upload: ParquetUploadArgs,
|
||||
}
|
||||
|
||||
#[derive(clap::Args, Clone, Copy, Debug)]
|
||||
@@ -221,6 +228,11 @@ async fn main() -> anyhow::Result<()> {
|
||||
));
|
||||
}
|
||||
|
||||
client_tasks.spawn(proxy::context::parquet::worker(
|
||||
cancellation_token.clone(),
|
||||
args.parquet_upload,
|
||||
));
|
||||
|
||||
// maintenance tasks. these never return unless there's an error
|
||||
let mut maintenance_tasks = JoinSet::new();
|
||||
maintenance_tasks.spawn(proxy::handle_signals(cancellation_token));
|
||||
@@ -380,6 +392,8 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
|
||||
require_client_ip: args.require_client_ip,
|
||||
disable_ip_check_for_http: args.disable_ip_check_for_http,
|
||||
endpoint_rps_limit,
|
||||
// TODO: add this argument
|
||||
region: args.region.clone(),
|
||||
}));
|
||||
|
||||
Ok(config)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
use crate::{
|
||||
auth::parse_endpoint_param, cancellation::CancelClosure, console::errors::WakeComputeError,
|
||||
error::UserFacingError, metrics::NUM_DB_CONNECTIONS_GAUGE, proxy::neon_option,
|
||||
context::RequestMonitoring, error::UserFacingError, metrics::NUM_DB_CONNECTIONS_GAUGE,
|
||||
proxy::neon_option,
|
||||
};
|
||||
use futures::{FutureExt, TryFutureExt};
|
||||
use itertools::Itertools;
|
||||
@@ -232,9 +233,9 @@ impl ConnCfg {
|
||||
/// Connect to a corresponding compute node.
|
||||
pub async fn connect(
|
||||
&self,
|
||||
ctx: &mut RequestMonitoring,
|
||||
allow_self_signed_compute: bool,
|
||||
timeout: Duration,
|
||||
proto: &'static str,
|
||||
) -> Result<PostgresConnection, ConnectionError> {
|
||||
let (socket_addr, stream, host) = self.connect_raw(timeout).await?;
|
||||
|
||||
@@ -268,7 +269,9 @@ impl ConnCfg {
|
||||
stream,
|
||||
params,
|
||||
cancel_closure,
|
||||
_guage: NUM_DB_CONNECTIONS_GAUGE.with_label_values(&[proto]).guard(),
|
||||
_guage: NUM_DB_CONNECTIONS_GAUGE
|
||||
.with_label_values(&[ctx.protocol])
|
||||
.guard(),
|
||||
};
|
||||
|
||||
Ok(connection)
|
||||
|
||||
@@ -21,6 +21,7 @@ pub struct ProxyConfig {
|
||||
pub require_client_ip: bool,
|
||||
pub disable_ip_check_for_http: bool,
|
||||
pub endpoint_rps_limit: Vec<RateBucketInfo>,
|
||||
pub region: String,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
|
||||
@@ -6,7 +6,9 @@ use super::messages::MetricsAuxInfo;
|
||||
use crate::{
|
||||
auth::backend::ComputeUserInfo,
|
||||
cache::{timed_lru, TimedLru},
|
||||
compute, scram,
|
||||
compute,
|
||||
context::RequestMonitoring,
|
||||
scram,
|
||||
};
|
||||
use async_trait::async_trait;
|
||||
use dashmap::DashMap;
|
||||
@@ -198,10 +200,6 @@ pub mod errors {
|
||||
|
||||
/// Extra query params we'd like to pass to the console.
|
||||
pub struct ConsoleReqExtra {
|
||||
/// A unique identifier for a connection.
|
||||
pub session_id: uuid::Uuid,
|
||||
/// Name of client application, if set.
|
||||
pub application_name: String,
|
||||
pub options: Vec<(String, String)>,
|
||||
}
|
||||
|
||||
@@ -263,19 +261,20 @@ pub trait Api {
|
||||
/// Get the client's auth secret for authentication.
|
||||
async fn get_role_secret(
|
||||
&self,
|
||||
extra: &ConsoleReqExtra,
|
||||
ctx: &mut RequestMonitoring,
|
||||
creds: &ComputeUserInfo,
|
||||
) -> Result<CachedRoleSecret, errors::GetAuthInfoError>;
|
||||
|
||||
async fn get_allowed_ips(
|
||||
&self,
|
||||
extra: &ConsoleReqExtra,
|
||||
ctx: &mut RequestMonitoring,
|
||||
creds: &ComputeUserInfo,
|
||||
) -> Result<Arc<Vec<String>>, errors::GetAuthInfoError>;
|
||||
|
||||
/// Wake up the compute node and return the corresponding connection info.
|
||||
async fn wake_compute(
|
||||
&self,
|
||||
ctx: &mut RequestMonitoring,
|
||||
extra: &ConsoleReqExtra,
|
||||
creds: &ComputeUserInfo,
|
||||
) -> Result<CachedNodeInfo, errors::WakeComputeError>;
|
||||
|
||||
@@ -6,8 +6,8 @@ use super::{
|
||||
errors::{ApiError, GetAuthInfoError, WakeComputeError},
|
||||
AuthInfo, AuthSecret, CachedNodeInfo, ConsoleReqExtra, NodeInfo,
|
||||
};
|
||||
use crate::console::provider::CachedRoleSecret;
|
||||
use crate::{auth::backend::ComputeUserInfo, compute, error::io_error, scram, url::ApiUrl};
|
||||
use crate::{console::provider::CachedRoleSecret, context::RequestMonitoring};
|
||||
use async_trait::async_trait;
|
||||
use futures::TryFutureExt;
|
||||
use thiserror::Error;
|
||||
@@ -145,7 +145,7 @@ impl super::Api for Api {
|
||||
#[tracing::instrument(skip_all)]
|
||||
async fn get_role_secret(
|
||||
&self,
|
||||
_extra: &ConsoleReqExtra,
|
||||
_ctx: &mut RequestMonitoring,
|
||||
creds: &ComputeUserInfo,
|
||||
) -> Result<CachedRoleSecret, GetAuthInfoError> {
|
||||
Ok(CachedRoleSecret::new_uncached(
|
||||
@@ -155,7 +155,7 @@ impl super::Api for Api {
|
||||
|
||||
async fn get_allowed_ips(
|
||||
&self,
|
||||
_extra: &ConsoleReqExtra,
|
||||
_ctx: &mut RequestMonitoring,
|
||||
creds: &ComputeUserInfo,
|
||||
) -> Result<Arc<Vec<String>>, GetAuthInfoError> {
|
||||
Ok(Arc::new(self.do_get_auth_info(creds).await?.allowed_ips))
|
||||
@@ -164,6 +164,7 @@ impl super::Api for Api {
|
||||
#[tracing::instrument(skip_all)]
|
||||
async fn wake_compute(
|
||||
&self,
|
||||
_ctx: &mut RequestMonitoring,
|
||||
_extra: &ConsoleReqExtra,
|
||||
_creds: &ComputeUserInfo,
|
||||
) -> Result<CachedNodeInfo, WakeComputeError> {
|
||||
|
||||
@@ -6,8 +6,11 @@ use super::{
|
||||
ApiCaches, ApiLocks, AuthInfo, AuthSecret, CachedNodeInfo, CachedRoleSecret, ConsoleReqExtra,
|
||||
NodeInfo,
|
||||
};
|
||||
use crate::metrics::{ALLOWED_IPS_BY_CACHE_OUTCOME, ALLOWED_IPS_NUMBER};
|
||||
use crate::{auth::backend::ComputeUserInfo, compute, http, scram};
|
||||
use crate::{
|
||||
context::RequestMonitoring,
|
||||
metrics::{ALLOWED_IPS_BY_CACHE_OUTCOME, ALLOWED_IPS_NUMBER},
|
||||
};
|
||||
use async_trait::async_trait;
|
||||
use futures::TryFutureExt;
|
||||
use itertools::Itertools;
|
||||
@@ -49,19 +52,20 @@ impl Api {
|
||||
|
||||
async fn do_get_auth_info(
|
||||
&self,
|
||||
extra: &ConsoleReqExtra,
|
||||
ctx: &mut RequestMonitoring,
|
||||
creds: &ComputeUserInfo,
|
||||
) -> Result<AuthInfo, GetAuthInfoError> {
|
||||
let request_id = uuid::Uuid::new_v4().to_string();
|
||||
let application_name = ctx.console_application_name();
|
||||
async {
|
||||
let request = self
|
||||
.endpoint
|
||||
.get("proxy_get_role_secret")
|
||||
.header("X-Request-ID", &request_id)
|
||||
.header("Authorization", format!("Bearer {}", &self.jwt))
|
||||
.query(&[("session_id", extra.session_id)])
|
||||
.query(&[("session_id", ctx.session_id)])
|
||||
.query(&[
|
||||
("application_name", extra.application_name.as_str()),
|
||||
("application_name", application_name.as_str()),
|
||||
("project", creds.endpoint.as_str()),
|
||||
("role", creds.inner.user.as_str()),
|
||||
])
|
||||
@@ -102,19 +106,21 @@ impl Api {
|
||||
|
||||
async fn do_wake_compute(
|
||||
&self,
|
||||
ctx: &mut RequestMonitoring,
|
||||
extra: &ConsoleReqExtra,
|
||||
creds: &ComputeUserInfo,
|
||||
) -> Result<NodeInfo, WakeComputeError> {
|
||||
let request_id = uuid::Uuid::new_v4().to_string();
|
||||
let application_name = ctx.console_application_name();
|
||||
async {
|
||||
let mut request_builder = self
|
||||
.endpoint
|
||||
.get("proxy_wake_compute")
|
||||
.header("X-Request-ID", &request_id)
|
||||
.header("Authorization", format!("Bearer {}", &self.jwt))
|
||||
.query(&[("session_id", extra.session_id)])
|
||||
.query(&[("session_id", ctx.session_id)])
|
||||
.query(&[
|
||||
("application_name", extra.application_name.as_str()),
|
||||
("application_name", application_name.as_str()),
|
||||
("project", creds.endpoint.as_str()),
|
||||
]);
|
||||
|
||||
@@ -162,7 +168,7 @@ impl super::Api for Api {
|
||||
#[tracing::instrument(skip_all)]
|
||||
async fn get_role_secret(
|
||||
&self,
|
||||
extra: &ConsoleReqExtra,
|
||||
ctx: &mut RequestMonitoring,
|
||||
creds: &ComputeUserInfo,
|
||||
) -> Result<CachedRoleSecret, GetAuthInfoError> {
|
||||
let ep = creds.endpoint.clone();
|
||||
@@ -170,7 +176,7 @@ impl super::Api for Api {
|
||||
if let Some(role_secret) = self.caches.role_secret.get(&(ep.clone(), user.clone())) {
|
||||
return Ok(role_secret);
|
||||
}
|
||||
let auth_info = self.do_get_auth_info(extra, creds).await?;
|
||||
let auth_info = self.do_get_auth_info(ctx, creds).await?;
|
||||
let (_, secret) = self
|
||||
.caches
|
||||
.role_secret
|
||||
@@ -183,7 +189,7 @@ impl super::Api for Api {
|
||||
|
||||
async fn get_allowed_ips(
|
||||
&self,
|
||||
extra: &ConsoleReqExtra,
|
||||
ctx: &mut RequestMonitoring,
|
||||
creds: &ComputeUserInfo,
|
||||
) -> Result<Arc<Vec<String>>, GetAuthInfoError> {
|
||||
if let Some(allowed_ips) = self.caches.allowed_ips.get(&creds.endpoint) {
|
||||
@@ -195,7 +201,7 @@ impl super::Api for Api {
|
||||
ALLOWED_IPS_BY_CACHE_OUTCOME
|
||||
.with_label_values(&["miss"])
|
||||
.inc();
|
||||
let auth_info = self.do_get_auth_info(extra, creds).await?;
|
||||
let auth_info = self.do_get_auth_info(ctx, creds).await?;
|
||||
let allowed_ips = Arc::new(auth_info.allowed_ips);
|
||||
let ep = creds.endpoint.clone();
|
||||
let user = creds.inner.user.clone();
|
||||
@@ -209,6 +215,7 @@ impl super::Api for Api {
|
||||
#[tracing::instrument(skip_all)]
|
||||
async fn wake_compute(
|
||||
&self,
|
||||
ctx: &mut RequestMonitoring,
|
||||
extra: &ConsoleReqExtra,
|
||||
creds: &ComputeUserInfo,
|
||||
) -> Result<CachedNodeInfo, WakeComputeError> {
|
||||
@@ -236,7 +243,7 @@ impl super::Api for Api {
|
||||
}
|
||||
}
|
||||
|
||||
let node = self.do_wake_compute(extra, creds).await?;
|
||||
let node = self.do_wake_compute(ctx, extra, creds).await?;
|
||||
let (_, cached) = self.caches.node_info.insert(key.clone(), node);
|
||||
info!(key = &*key, "created a cache entry for compute node info");
|
||||
|
||||
|
||||
110
proxy/src/context.rs
Normal file
110
proxy/src/context.rs
Normal file
@@ -0,0 +1,110 @@
|
||||
//! Connection request monitoring contexts
|
||||
|
||||
use chrono::Utc;
|
||||
use once_cell::sync::OnceCell;
|
||||
use smol_str::SmolStr;
|
||||
use std::net::IpAddr;
|
||||
use tokio::sync::mpsc;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::{console::messages::MetricsAuxInfo, error::ErrorKind, metrics::LatencyTimer};
|
||||
|
||||
pub mod parquet;
|
||||
|
||||
static LOG_CHAN: OnceCell<mpsc::WeakUnboundedSender<RequestMonitoring>> = OnceCell::new();
|
||||
|
||||
#[derive(Clone)]
|
||||
/// Context data for a single request to connect to a database.
|
||||
///
|
||||
/// This data should **not** be used for connection logic, only for observability and limiting purposes.
|
||||
/// All connection logic should instead use strongly typed state machines, not a bunch of Options.
|
||||
pub struct RequestMonitoring {
|
||||
pub peer_addr: IpAddr,
|
||||
pub session_id: Uuid,
|
||||
pub protocol: &'static str,
|
||||
first_packet: chrono::DateTime<Utc>,
|
||||
region: &'static str,
|
||||
|
||||
// filled in as they are discovered
|
||||
project: Option<SmolStr>,
|
||||
branch: Option<SmolStr>,
|
||||
endpoint_id: Option<SmolStr>,
|
||||
user: Option<SmolStr>,
|
||||
application: Option<SmolStr>,
|
||||
error_kind: Option<ErrorKind>,
|
||||
|
||||
// extra
|
||||
// This sender is here to keep the request monitoring channel open while requests are taking place.
|
||||
sender: Option<mpsc::UnboundedSender<RequestMonitoring>>,
|
||||
pub latency_timer: LatencyTimer,
|
||||
}
|
||||
|
||||
impl RequestMonitoring {
|
||||
pub fn new(
|
||||
session_id: Uuid,
|
||||
peer_addr: IpAddr,
|
||||
protocol: &'static str,
|
||||
region: &'static str,
|
||||
) -> Self {
|
||||
Self {
|
||||
peer_addr,
|
||||
session_id,
|
||||
protocol,
|
||||
first_packet: Utc::now(),
|
||||
region,
|
||||
|
||||
project: None,
|
||||
branch: None,
|
||||
endpoint_id: None,
|
||||
user: None,
|
||||
application: None,
|
||||
error_kind: None,
|
||||
|
||||
sender: LOG_CHAN.get().and_then(|tx| tx.upgrade()),
|
||||
latency_timer: LatencyTimer::new(protocol),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub fn test() -> Self {
|
||||
RequestMonitoring::new(Uuid::now_v7(), [127, 0, 0, 1].into(), "test", "test")
|
||||
}
|
||||
|
||||
pub fn console_application_name(&self) -> String {
|
||||
format!(
|
||||
"{}/{}",
|
||||
self.application.as_deref().unwrap_or_default(),
|
||||
self.protocol
|
||||
)
|
||||
}
|
||||
|
||||
pub fn set_project(&mut self, x: MetricsAuxInfo) {
|
||||
self.branch = Some(x.branch_id);
|
||||
self.endpoint_id = Some(x.endpoint_id);
|
||||
self.project = Some(x.project_id);
|
||||
}
|
||||
|
||||
pub fn set_endpoint_id(&mut self, endpoint_id: Option<SmolStr>) {
|
||||
self.endpoint_id = endpoint_id.or_else(|| self.endpoint_id.clone());
|
||||
}
|
||||
|
||||
pub fn set_application(&mut self, app: Option<SmolStr>) {
|
||||
self.application = app.or_else(|| self.application.clone());
|
||||
}
|
||||
|
||||
pub fn set_user(&mut self, user: SmolStr) {
|
||||
self.user = Some(user);
|
||||
}
|
||||
|
||||
pub fn log(&mut self) {
|
||||
if let Some(tx) = self.sender.take() {
|
||||
let _: Result<(), _> = tx.send(self.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for RequestMonitoring {
|
||||
fn drop(&mut self) {
|
||||
self.log()
|
||||
}
|
||||
}
|
||||
641
proxy/src/context/parquet.rs
Normal file
641
proxy/src/context/parquet.rs
Normal file
@@ -0,0 +1,641 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::Context;
|
||||
use bytes::BytesMut;
|
||||
use futures::{Stream, StreamExt};
|
||||
use parquet::{
|
||||
basic::Compression,
|
||||
file::{
|
||||
metadata::RowGroupMetaDataPtr,
|
||||
properties::{WriterProperties, WriterPropertiesPtr, DEFAULT_PAGE_SIZE},
|
||||
writer::SerializedFileWriter,
|
||||
},
|
||||
record::RecordWriter,
|
||||
};
|
||||
use remote_storage::{GenericRemoteStorage, RemotePath, RemoteStorageConfig};
|
||||
use tokio::{sync::mpsc, time};
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::{debug, info, Span};
|
||||
use utils::backoff;
|
||||
|
||||
use super::{RequestMonitoring, LOG_CHAN};
|
||||
|
||||
#[derive(clap::Args, Clone, Debug)]
|
||||
pub struct ParquetUploadArgs {
|
||||
/// Storage location to upload the parquet files to.
|
||||
/// Encoded as toml (same format as pageservers), eg
|
||||
/// `{bucket_name='the-bucket',bucket_region='us-east-1',prefix_in_bucket='proxy',endpoint='http://minio:9000'}`
|
||||
#[clap(long, default_value = "{}", value_parser = remote_storage_from_toml)]
|
||||
parquet_upload_remote_storage: OptRemoteStorageConfig,
|
||||
|
||||
/// How many rows to include in a row group
|
||||
#[clap(long, default_value_t = 8192)]
|
||||
parquet_upload_row_group_size: usize,
|
||||
|
||||
/// How large each column page should be in bytes
|
||||
#[clap(long, default_value_t = DEFAULT_PAGE_SIZE)]
|
||||
parquet_upload_page_size: usize,
|
||||
|
||||
/// How large the total parquet file should be in bytes
|
||||
#[clap(long, default_value_t = 100_000_000)]
|
||||
parquet_upload_size: i64,
|
||||
|
||||
/// How long to wait before forcing a file upload
|
||||
#[clap(long, default_value = "20m", value_parser = humantime::parse_duration)]
|
||||
parquet_upload_maximum_duration: tokio::time::Duration,
|
||||
|
||||
/// What level of compression to use
|
||||
#[clap(long, default_value_t = Compression::UNCOMPRESSED)]
|
||||
parquet_upload_compression: Compression,
|
||||
}
|
||||
|
||||
/// Hack to avoid clap being smarter. If you don't use this type alias, clap assumes more about the optional state and you get
|
||||
/// runtime type errors from the value parser we use.
|
||||
type OptRemoteStorageConfig = Option<RemoteStorageConfig>;
|
||||
|
||||
fn remote_storage_from_toml(s: &str) -> anyhow::Result<OptRemoteStorageConfig> {
|
||||
RemoteStorageConfig::from_toml(&s.parse()?)
|
||||
}
|
||||
|
||||
// Occasional network issues and such can cause remote operations to fail, and
|
||||
// that's expected. If a upload fails, we log it at info-level, and retry.
|
||||
// But after FAILED_UPLOAD_WARN_THRESHOLD retries, we start to log it at WARN
|
||||
// level instead, as repeated failures can mean a more serious problem. If it
|
||||
// fails more than FAILED_UPLOAD_RETRIES times, we give up
|
||||
pub(crate) const FAILED_UPLOAD_WARN_THRESHOLD: u32 = 3;
|
||||
pub(crate) const FAILED_UPLOAD_MAX_RETRIES: u32 = 10;
|
||||
|
||||
// the parquet crate leaves a lot to be desired...
|
||||
// what follows is an attempt to write parquet files with minimal allocs.
|
||||
// complication: parquet is a columnar format, while we want to write in as rows.
|
||||
// design:
|
||||
// * we batch up to 1024 rows, then flush them into a 'row group'
|
||||
// * after each rowgroup write, we check the length of the file and upload to s3 if large enough
|
||||
|
||||
#[derive(parquet_derive::ParquetRecordWriter)]
|
||||
struct RequestData {
|
||||
region: &'static str,
|
||||
protocol: &'static str,
|
||||
/// Must be UTC. The derive macro doesn't like the timezones
|
||||
timestamp: chrono::NaiveDateTime,
|
||||
session_id: uuid::Uuid,
|
||||
peer_addr: String,
|
||||
username: Option<String>,
|
||||
application_name: Option<String>,
|
||||
endpoint_id: Option<String>,
|
||||
project: Option<String>,
|
||||
branch: Option<String>,
|
||||
error: Option<&'static str>,
|
||||
}
|
||||
|
||||
impl From<RequestMonitoring> for RequestData {
|
||||
fn from(value: RequestMonitoring) -> Self {
|
||||
Self {
|
||||
session_id: value.session_id,
|
||||
peer_addr: value.peer_addr.to_string(),
|
||||
timestamp: value.first_packet.naive_utc(),
|
||||
username: value.user.as_deref().map(String::from),
|
||||
application_name: value.application.as_deref().map(String::from),
|
||||
endpoint_id: value.endpoint_id.as_deref().map(String::from),
|
||||
project: value.project.as_deref().map(String::from),
|
||||
branch: value.branch.as_deref().map(String::from),
|
||||
protocol: value.protocol,
|
||||
region: value.region,
|
||||
error: value.error_kind.as_ref().map(|e| e.to_str()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Parquet request context worker
|
||||
///
|
||||
/// It listened on a channel for all completed requests, extracts the data and writes it into a parquet file,
|
||||
/// then uploads a completed batch to S3
|
||||
pub async fn worker(
|
||||
cancellation_token: CancellationToken,
|
||||
config: ParquetUploadArgs,
|
||||
) -> anyhow::Result<()> {
|
||||
let Some(remote_storage_config) = config.parquet_upload_remote_storage else {
|
||||
tracing::warn!("parquet request upload: no s3 bucket configured");
|
||||
return Ok(());
|
||||
};
|
||||
|
||||
let (tx, mut rx) = mpsc::unbounded_channel();
|
||||
LOG_CHAN.set(tx.downgrade()).unwrap();
|
||||
|
||||
// setup row stream that will close on cancellation
|
||||
tokio::spawn(async move {
|
||||
cancellation_token.cancelled().await;
|
||||
// dropping this sender will cause the channel to close only once
|
||||
// all the remaining inflight requests have been completed.
|
||||
drop(tx);
|
||||
});
|
||||
let rx = futures::stream::poll_fn(move |cx| rx.poll_recv(cx));
|
||||
let rx = rx.map(RequestData::from);
|
||||
|
||||
let storage =
|
||||
GenericRemoteStorage::from_config(&remote_storage_config).context("remote storage init")?;
|
||||
|
||||
let properties = WriterProperties::builder()
|
||||
.set_data_page_size_limit(config.parquet_upload_page_size)
|
||||
.set_compression(config.parquet_upload_compression);
|
||||
|
||||
let parquet_config = ParquetConfig {
|
||||
propeties: Arc::new(properties.build()),
|
||||
rows_per_group: config.parquet_upload_row_group_size,
|
||||
file_size: config.parquet_upload_size,
|
||||
max_duration: config.parquet_upload_maximum_duration,
|
||||
|
||||
#[cfg(any(test, feature = "testing"))]
|
||||
test_remote_failures: 0,
|
||||
};
|
||||
|
||||
worker_inner(storage, rx, parquet_config).await
|
||||
}
|
||||
|
||||
struct ParquetConfig {
|
||||
propeties: WriterPropertiesPtr,
|
||||
rows_per_group: usize,
|
||||
file_size: i64,
|
||||
|
||||
max_duration: tokio::time::Duration,
|
||||
|
||||
#[cfg(any(test, feature = "testing"))]
|
||||
test_remote_failures: u64,
|
||||
}
|
||||
|
||||
async fn worker_inner(
|
||||
storage: GenericRemoteStorage,
|
||||
rx: impl Stream<Item = RequestData>,
|
||||
config: ParquetConfig,
|
||||
) -> anyhow::Result<()> {
|
||||
#[cfg(any(test, feature = "testing"))]
|
||||
let storage = if config.test_remote_failures > 0 {
|
||||
GenericRemoteStorage::unreliable_wrapper(storage, config.test_remote_failures)
|
||||
} else {
|
||||
storage
|
||||
};
|
||||
|
||||
let mut rx = std::pin::pin!(rx);
|
||||
|
||||
let mut rows = Vec::with_capacity(config.rows_per_group);
|
||||
|
||||
let schema = rows.as_slice().schema()?;
|
||||
let file = BytesWriter::default();
|
||||
let mut w = SerializedFileWriter::new(file, schema.clone(), config.propeties.clone())?;
|
||||
|
||||
let mut last_upload = time::Instant::now();
|
||||
|
||||
let mut len = 0;
|
||||
while let Some(row) = rx.next().await {
|
||||
rows.push(row);
|
||||
let force = last_upload.elapsed() > config.max_duration;
|
||||
if rows.len() == config.rows_per_group || force {
|
||||
let rg_meta;
|
||||
(rows, w, rg_meta) = flush_rows(rows, w).await?;
|
||||
len += rg_meta.compressed_size();
|
||||
}
|
||||
if len > config.file_size || force {
|
||||
last_upload = time::Instant::now();
|
||||
let file = upload_parquet(w, len, &storage).await?;
|
||||
w = SerializedFileWriter::new(file, schema.clone(), config.propeties.clone())?;
|
||||
len = 0;
|
||||
}
|
||||
}
|
||||
|
||||
if !rows.is_empty() {
|
||||
let rg_meta;
|
||||
(_, w, rg_meta) = flush_rows(rows, w).await?;
|
||||
len += rg_meta.compressed_size();
|
||||
}
|
||||
|
||||
if !w.flushed_row_groups().is_empty() {
|
||||
let _: BytesWriter = upload_parquet(w, len, &storage).await?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn flush_rows(
|
||||
rows: Vec<RequestData>,
|
||||
mut w: SerializedFileWriter<BytesWriter>,
|
||||
) -> anyhow::Result<(
|
||||
Vec<RequestData>,
|
||||
SerializedFileWriter<BytesWriter>,
|
||||
RowGroupMetaDataPtr,
|
||||
)> {
|
||||
let span = Span::current();
|
||||
let (mut rows, w, rg_meta) = tokio::task::spawn_blocking(move || {
|
||||
let _enter = span.enter();
|
||||
|
||||
let mut rg = w.next_row_group()?;
|
||||
rows.as_slice().write_to_row_group(&mut rg)?;
|
||||
let rg_meta = rg.close()?;
|
||||
|
||||
let size = rg_meta.compressed_size();
|
||||
let compression = rg_meta.compressed_size() as f64 / rg_meta.total_byte_size() as f64;
|
||||
|
||||
debug!(size, compression, "flushed row group to parquet file");
|
||||
|
||||
Ok::<_, parquet::errors::ParquetError>((rows, w, rg_meta))
|
||||
})
|
||||
.await
|
||||
.unwrap()?;
|
||||
|
||||
rows.clear();
|
||||
Ok((rows, w, rg_meta))
|
||||
}
|
||||
|
||||
async fn upload_parquet(
|
||||
w: SerializedFileWriter<BytesWriter>,
|
||||
len: i64,
|
||||
storage: &GenericRemoteStorage,
|
||||
) -> anyhow::Result<BytesWriter> {
|
||||
let len_uncompressed = w
|
||||
.flushed_row_groups()
|
||||
.iter()
|
||||
.map(|rg| rg.total_byte_size())
|
||||
.sum::<i64>();
|
||||
|
||||
// I don't know how compute intensive this is, although it probably isn't much... better be safe than sorry.
|
||||
// finish method only available on the fork: https://github.com/apache/arrow-rs/issues/5253
|
||||
let (mut file, metadata) = tokio::task::spawn_blocking(move || w.finish())
|
||||
.await
|
||||
.unwrap()?;
|
||||
|
||||
let data = file.buf.split().freeze();
|
||||
|
||||
let compression = len as f64 / len_uncompressed as f64;
|
||||
let size = data.len();
|
||||
let id = uuid::Uuid::now_v7();
|
||||
|
||||
info!(
|
||||
%id,
|
||||
rows = metadata.num_rows,
|
||||
size, compression, "uploading request parquet file"
|
||||
);
|
||||
|
||||
let path = RemotePath::from_string(&format!("requests_{id}.parquet"))?;
|
||||
backoff::retry(
|
||||
|| async {
|
||||
let stream = futures::stream::once(futures::future::ready(Ok(data.clone())));
|
||||
storage.upload(stream, data.len(), &path, None).await
|
||||
},
|
||||
|_e| false,
|
||||
FAILED_UPLOAD_WARN_THRESHOLD,
|
||||
FAILED_UPLOAD_MAX_RETRIES,
|
||||
"request_data_upload",
|
||||
// we don't want cancellation to interrupt here, so we make a dummy cancel token
|
||||
backoff::Cancel::new(CancellationToken::new(), || anyhow::anyhow!("Cancelled")),
|
||||
)
|
||||
.await
|
||||
.context("request_data_upload")?;
|
||||
|
||||
Ok(file)
|
||||
}
|
||||
|
||||
// why doesn't BytesMut impl io::Write?
|
||||
#[derive(Default)]
|
||||
struct BytesWriter {
|
||||
buf: BytesMut,
|
||||
}
|
||||
|
||||
impl std::io::Write for BytesWriter {
|
||||
fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
|
||||
self.buf.extend_from_slice(buf);
|
||||
Ok(buf.len())
|
||||
}
|
||||
|
||||
fn flush(&mut self) -> std::io::Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::{net::Ipv4Addr, num::NonZeroUsize, sync::Arc};
|
||||
|
||||
use camino::Utf8Path;
|
||||
use clap::Parser;
|
||||
use futures::{Stream, StreamExt};
|
||||
use itertools::Itertools;
|
||||
use parquet::{
|
||||
basic::{Compression, ZstdLevel},
|
||||
file::{
|
||||
properties::{WriterProperties, DEFAULT_PAGE_SIZE},
|
||||
reader::FileReader,
|
||||
serialized_reader::SerializedFileReader,
|
||||
},
|
||||
};
|
||||
use rand::{rngs::StdRng, Rng, SeedableRng};
|
||||
use remote_storage::{
|
||||
GenericRemoteStorage, RemoteStorageConfig, RemoteStorageKind, S3Config,
|
||||
DEFAULT_MAX_KEYS_PER_LIST_RESPONSE, DEFAULT_REMOTE_STORAGE_S3_CONCURRENCY_LIMIT,
|
||||
};
|
||||
use tokio::{sync::mpsc, time};
|
||||
|
||||
use super::{worker_inner, ParquetConfig, ParquetUploadArgs, RequestData};
|
||||
|
||||
#[derive(Parser)]
|
||||
struct ProxyCliArgs {
|
||||
#[clap(flatten)]
|
||||
parquet_upload: ParquetUploadArgs,
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn default_parser() {
|
||||
let ProxyCliArgs { parquet_upload } = ProxyCliArgs::parse_from(["proxy"]);
|
||||
assert_eq!(parquet_upload.parquet_upload_remote_storage, None);
|
||||
assert_eq!(parquet_upload.parquet_upload_row_group_size, 8192);
|
||||
assert_eq!(parquet_upload.parquet_upload_page_size, DEFAULT_PAGE_SIZE);
|
||||
assert_eq!(parquet_upload.parquet_upload_size, 100_000_000);
|
||||
assert_eq!(
|
||||
parquet_upload.parquet_upload_maximum_duration,
|
||||
time::Duration::from_secs(20 * 60)
|
||||
);
|
||||
assert_eq!(
|
||||
parquet_upload.parquet_upload_compression,
|
||||
Compression::UNCOMPRESSED
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn full_parser() {
|
||||
let ProxyCliArgs { parquet_upload } = ProxyCliArgs::parse_from([
|
||||
"proxy",
|
||||
"--parquet-upload-remote-storage",
|
||||
"{bucket_name='default',prefix_in_bucket='proxy/',bucket_region='us-east-1',endpoint='http://minio:9000'}",
|
||||
"--parquet-upload-row-group-size",
|
||||
"100",
|
||||
"--parquet-upload-page-size",
|
||||
"10000",
|
||||
"--parquet-upload-size",
|
||||
"10000000",
|
||||
"--parquet-upload-maximum-duration",
|
||||
"10m",
|
||||
"--parquet-upload-compression",
|
||||
"zstd(5)",
|
||||
]);
|
||||
assert_eq!(
|
||||
parquet_upload.parquet_upload_remote_storage,
|
||||
Some(RemoteStorageConfig {
|
||||
storage: RemoteStorageKind::AwsS3(S3Config {
|
||||
bucket_name: "default".into(),
|
||||
bucket_region: "us-east-1".into(),
|
||||
prefix_in_bucket: Some("proxy/".into()),
|
||||
endpoint: Some("http://minio:9000".into()),
|
||||
concurrency_limit: NonZeroUsize::new(
|
||||
DEFAULT_REMOTE_STORAGE_S3_CONCURRENCY_LIMIT
|
||||
)
|
||||
.unwrap(),
|
||||
max_keys_per_list_response: DEFAULT_MAX_KEYS_PER_LIST_RESPONSE,
|
||||
})
|
||||
})
|
||||
);
|
||||
assert_eq!(parquet_upload.parquet_upload_row_group_size, 100);
|
||||
assert_eq!(parquet_upload.parquet_upload_page_size, 10000);
|
||||
assert_eq!(parquet_upload.parquet_upload_size, 10_000_000);
|
||||
assert_eq!(
|
||||
parquet_upload.parquet_upload_maximum_duration,
|
||||
time::Duration::from_secs(10 * 60)
|
||||
);
|
||||
assert_eq!(
|
||||
parquet_upload.parquet_upload_compression,
|
||||
Compression::ZSTD(ZstdLevel::try_new(5).unwrap())
|
||||
);
|
||||
}
|
||||
|
||||
fn generate_request_data(rng: &mut impl Rng) -> RequestData {
|
||||
RequestData {
|
||||
session_id: uuid::Builder::from_random_bytes(rng.gen()).into_uuid(),
|
||||
peer_addr: Ipv4Addr::from(rng.gen::<[u8; 4]>()).to_string(),
|
||||
timestamp: chrono::NaiveDateTime::from_timestamp_millis(
|
||||
rng.gen_range(1703862754..1803862754),
|
||||
)
|
||||
.unwrap(),
|
||||
application_name: Some("test".to_owned()),
|
||||
username: Some(hex::encode(rng.gen::<[u8; 4]>())),
|
||||
endpoint_id: Some(hex::encode(rng.gen::<[u8; 16]>())),
|
||||
project: Some(hex::encode(rng.gen::<[u8; 16]>())),
|
||||
branch: Some(hex::encode(rng.gen::<[u8; 16]>())),
|
||||
protocol: ["tcp", "ws", "http"][rng.gen_range(0..3)],
|
||||
region: "us-east-1",
|
||||
error: None,
|
||||
}
|
||||
}
|
||||
|
||||
fn random_stream(len: usize) -> impl Stream<Item = RequestData> + Unpin {
|
||||
let mut rng = StdRng::from_seed([0x39; 32]);
|
||||
futures::stream::iter(
|
||||
std::iter::repeat_with(move || generate_request_data(&mut rng)).take(len),
|
||||
)
|
||||
}
|
||||
|
||||
async fn run_test(
|
||||
tmpdir: &Utf8Path,
|
||||
config: ParquetConfig,
|
||||
rx: impl Stream<Item = RequestData>,
|
||||
) -> Vec<(u64, usize, i64)> {
|
||||
let remote_storage_config = RemoteStorageConfig {
|
||||
storage: RemoteStorageKind::LocalFs(tmpdir.to_path_buf()),
|
||||
};
|
||||
let storage = GenericRemoteStorage::from_config(&remote_storage_config).unwrap();
|
||||
|
||||
worker_inner(storage, rx, config).await.unwrap();
|
||||
|
||||
let mut files = std::fs::read_dir(tmpdir.as_std_path())
|
||||
.unwrap()
|
||||
.map(|entry| entry.unwrap().path())
|
||||
.collect_vec();
|
||||
files.sort();
|
||||
|
||||
files
|
||||
.into_iter()
|
||||
.map(|path| std::fs::File::open(tmpdir.as_std_path().join(path)).unwrap())
|
||||
.map(|file| {
|
||||
(
|
||||
file.metadata().unwrap(),
|
||||
SerializedFileReader::new(file).unwrap().metadata().clone(),
|
||||
)
|
||||
})
|
||||
.map(|(file_meta, parquet_meta)| {
|
||||
(
|
||||
file_meta.len(),
|
||||
parquet_meta.num_row_groups(),
|
||||
parquet_meta.file_metadata().num_rows(),
|
||||
)
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn verify_parquet_no_compression() {
|
||||
let tmpdir = camino_tempfile::tempdir().unwrap();
|
||||
|
||||
let config = ParquetConfig {
|
||||
propeties: Arc::new(WriterProperties::new()),
|
||||
rows_per_group: 2_000,
|
||||
file_size: 1_000_000,
|
||||
max_duration: time::Duration::from_secs(20 * 60),
|
||||
test_remote_failures: 0,
|
||||
};
|
||||
|
||||
let rx = random_stream(50_000);
|
||||
let file_stats = run_test(tmpdir.path(), config, rx).await;
|
||||
|
||||
assert_eq!(
|
||||
file_stats,
|
||||
[
|
||||
(1029153, 3, 6000),
|
||||
(1029075, 3, 6000),
|
||||
(1029216, 3, 6000),
|
||||
(1029129, 3, 6000),
|
||||
(1029250, 3, 6000),
|
||||
(1029017, 3, 6000),
|
||||
(1029175, 3, 6000),
|
||||
(1029247, 3, 6000),
|
||||
(343124, 1, 2000)
|
||||
],
|
||||
);
|
||||
|
||||
tmpdir.close().unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn verify_parquet_min_compression() {
|
||||
let tmpdir = camino_tempfile::tempdir().unwrap();
|
||||
|
||||
let config = ParquetConfig {
|
||||
propeties: Arc::new(
|
||||
WriterProperties::builder()
|
||||
.set_compression(parquet::basic::Compression::ZSTD(ZstdLevel::default()))
|
||||
.build(),
|
||||
),
|
||||
rows_per_group: 2_000,
|
||||
file_size: 1_000_000,
|
||||
max_duration: time::Duration::from_secs(20 * 60),
|
||||
test_remote_failures: 0,
|
||||
};
|
||||
|
||||
let rx = random_stream(50_000);
|
||||
let file_stats = run_test(tmpdir.path(), config, rx).await;
|
||||
|
||||
// with compression, there are fewer files with more rows per file
|
||||
assert_eq!(
|
||||
file_stats,
|
||||
[
|
||||
(1166201, 6, 12000),
|
||||
(1163577, 6, 12000),
|
||||
(1164641, 6, 12000),
|
||||
(1168772, 6, 12000),
|
||||
(196761, 1, 2000)
|
||||
],
|
||||
);
|
||||
|
||||
tmpdir.close().unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn verify_parquet_strong_compression() {
|
||||
let tmpdir = camino_tempfile::tempdir().unwrap();
|
||||
|
||||
let config = ParquetConfig {
|
||||
propeties: Arc::new(
|
||||
WriterProperties::builder()
|
||||
.set_compression(parquet::basic::Compression::ZSTD(
|
||||
ZstdLevel::try_new(10).unwrap(),
|
||||
))
|
||||
.build(),
|
||||
),
|
||||
rows_per_group: 2_000,
|
||||
file_size: 1_000_000,
|
||||
max_duration: time::Duration::from_secs(20 * 60),
|
||||
test_remote_failures: 0,
|
||||
};
|
||||
|
||||
let rx = random_stream(50_000);
|
||||
let file_stats = run_test(tmpdir.path(), config, rx).await;
|
||||
|
||||
// with strong compression, the files are smaller
|
||||
assert_eq!(
|
||||
file_stats,
|
||||
[
|
||||
(1144934, 6, 12000),
|
||||
(1144941, 6, 12000),
|
||||
(1144735, 6, 12000),
|
||||
(1144936, 6, 12000),
|
||||
(191035, 1, 2000)
|
||||
],
|
||||
);
|
||||
|
||||
tmpdir.close().unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn verify_parquet_unreliable_upload() {
|
||||
let tmpdir = camino_tempfile::tempdir().unwrap();
|
||||
|
||||
let config = ParquetConfig {
|
||||
propeties: Arc::new(WriterProperties::new()),
|
||||
rows_per_group: 2_000,
|
||||
file_size: 1_000_000,
|
||||
max_duration: time::Duration::from_secs(20 * 60),
|
||||
test_remote_failures: 2,
|
||||
};
|
||||
|
||||
let rx = random_stream(50_000);
|
||||
let file_stats = run_test(tmpdir.path(), config, rx).await;
|
||||
|
||||
assert_eq!(
|
||||
file_stats,
|
||||
[
|
||||
(1029153, 3, 6000),
|
||||
(1029075, 3, 6000),
|
||||
(1029216, 3, 6000),
|
||||
(1029129, 3, 6000),
|
||||
(1029250, 3, 6000),
|
||||
(1029017, 3, 6000),
|
||||
(1029175, 3, 6000),
|
||||
(1029247, 3, 6000),
|
||||
(343124, 1, 2000)
|
||||
],
|
||||
);
|
||||
|
||||
tmpdir.close().unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test(start_paused = true)]
|
||||
async fn verify_parquet_regular_upload() {
|
||||
let tmpdir = camino_tempfile::tempdir().unwrap();
|
||||
|
||||
let config = ParquetConfig {
|
||||
propeties: Arc::new(WriterProperties::new()),
|
||||
rows_per_group: 2_000,
|
||||
file_size: 1_000_000,
|
||||
max_duration: time::Duration::from_secs(60),
|
||||
test_remote_failures: 2,
|
||||
};
|
||||
|
||||
let (tx, mut rx) = mpsc::unbounded_channel();
|
||||
|
||||
tokio::spawn(async move {
|
||||
for _ in 0..3 {
|
||||
let mut s = random_stream(3000);
|
||||
while let Some(r) = s.next().await {
|
||||
tx.send(r).unwrap();
|
||||
}
|
||||
time::sleep(time::Duration::from_secs(70)).await
|
||||
}
|
||||
});
|
||||
|
||||
let rx = futures::stream::poll_fn(move |cx| rx.poll_recv(cx));
|
||||
let file_stats = run_test(tmpdir.path(), config, rx).await;
|
||||
|
||||
// files are smaller than the size threshold, but they took too long to fill so were flushed early
|
||||
assert_eq!(
|
||||
file_stats,
|
||||
[(515807, 2, 3001), (515585, 2, 3000), (515425, 2, 2999)],
|
||||
);
|
||||
|
||||
tmpdir.close().unwrap();
|
||||
}
|
||||
}
|
||||
@@ -28,3 +28,37 @@ pub trait UserFacingError: fmt::Display {
|
||||
self.to_string()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub enum ErrorKind {
|
||||
/// Wrong password, unknown endpoint, protocol violation, etc...
|
||||
User,
|
||||
|
||||
/// Network error between user and proxy. Not necessarily user error
|
||||
Disconnect,
|
||||
|
||||
/// Proxy self-imposed rate limits
|
||||
RateLimit,
|
||||
|
||||
/// internal errors
|
||||
Service,
|
||||
|
||||
/// Error communicating with control plane
|
||||
ControlPlane,
|
||||
|
||||
/// Error communicating with compute
|
||||
Compute,
|
||||
}
|
||||
|
||||
impl ErrorKind {
|
||||
pub fn to_str(&self) -> &'static str {
|
||||
match self {
|
||||
ErrorKind::User => "request failed due to user error",
|
||||
ErrorKind::Disconnect => "client disconnected",
|
||||
ErrorKind::RateLimit => "request cancelled due to rate limit",
|
||||
ErrorKind::Service => "internal service error",
|
||||
ErrorKind::ControlPlane => "non-retryable control plane error",
|
||||
ErrorKind::Compute => "non-retryable compute error (or exhausted retry capacity)",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -13,6 +13,7 @@ pub mod cancellation;
|
||||
pub mod compute;
|
||||
pub mod config;
|
||||
pub mod console;
|
||||
pub mod context;
|
||||
pub mod error;
|
||||
pub mod http;
|
||||
pub mod logging;
|
||||
|
||||
@@ -115,11 +115,12 @@ pub static ALLOWED_IPS_NUMBER: Lazy<Histogram> = Lazy::new(|| {
|
||||
.unwrap()
|
||||
});
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct LatencyTimer {
|
||||
// time since the stopwatch was started
|
||||
start: Option<time::Instant>,
|
||||
// accumulated time on the stopwatch
|
||||
accumulated: std::time::Duration,
|
||||
pub accumulated: std::time::Duration,
|
||||
// label data
|
||||
protocol: &'static str,
|
||||
cache_miss: bool,
|
||||
@@ -160,7 +161,12 @@ impl LatencyTimer {
|
||||
self.pool_miss = false;
|
||||
}
|
||||
|
||||
pub fn success(mut self) {
|
||||
pub fn success(&mut self) {
|
||||
// stop the stopwatch and record the time that we have accumulated
|
||||
let start = self.start.take().expect("latency timer should be started");
|
||||
self.accumulated += start.elapsed();
|
||||
|
||||
// success
|
||||
self.outcome = "success";
|
||||
}
|
||||
}
|
||||
|
||||
@@ -10,8 +10,9 @@ use crate::{
|
||||
compute,
|
||||
config::{AuthenticationConfig, ProxyConfig, TlsConfig},
|
||||
console::{self, messages::MetricsAuxInfo},
|
||||
context::RequestMonitoring,
|
||||
metrics::{
|
||||
LatencyTimer, NUM_BYTES_PROXIED_COUNTER, NUM_BYTES_PROXIED_PER_CLIENT_COUNTER,
|
||||
NUM_BYTES_PROXIED_COUNTER, NUM_BYTES_PROXIED_PER_CLIENT_COUNTER,
|
||||
NUM_CLIENT_CONNECTION_GAUGE, NUM_CONNECTION_REQUESTS_GAUGE,
|
||||
},
|
||||
protocol2::WithClientIp,
|
||||
@@ -25,7 +26,7 @@ use itertools::Itertools;
|
||||
use once_cell::sync::OnceCell;
|
||||
use pq_proto::{BeMessage as Be, FeStartupPacket, StartupMessageParams};
|
||||
use regex::Regex;
|
||||
use std::{net::IpAddr, sync::Arc};
|
||||
use std::sync::Arc;
|
||||
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::{error, info, info_span, Instrument};
|
||||
@@ -82,14 +83,16 @@ pub async fn task_main(
|
||||
info!("accepted postgres client connection");
|
||||
|
||||
let mut socket = WithClientIp::new(socket);
|
||||
let mut peer_addr = peer_addr;
|
||||
if let Some(ip) = socket.wait_for_addr().await? {
|
||||
peer_addr = ip;
|
||||
tracing::Span::current().record("peer_addr", &tracing::field::display(ip));
|
||||
let mut peer_addr = peer_addr.ip();
|
||||
if let Some(addr) = socket.wait_for_addr().await? {
|
||||
peer_addr = addr.ip();
|
||||
tracing::Span::current().record("peer_addr", &tracing::field::display(addr));
|
||||
} else if config.require_client_ip {
|
||||
bail!("missing required client IP");
|
||||
}
|
||||
|
||||
let mut ctx = RequestMonitoring::new(session_id, peer_addr, "tcp", &config.region);
|
||||
|
||||
socket
|
||||
.inner
|
||||
.set_nodelay(true)
|
||||
@@ -97,11 +100,10 @@ pub async fn task_main(
|
||||
|
||||
handle_client(
|
||||
config,
|
||||
&mut ctx,
|
||||
&cancel_map,
|
||||
session_id,
|
||||
socket,
|
||||
ClientMode::Tcp,
|
||||
peer_addr.ip(),
|
||||
endpoint_rate_limiter,
|
||||
)
|
||||
.await
|
||||
@@ -134,13 +136,6 @@ pub enum ClientMode {
|
||||
|
||||
/// Abstracts the logic of handling TCP vs WS clients
|
||||
impl ClientMode {
|
||||
fn protocol_label(&self) -> &'static str {
|
||||
match self {
|
||||
ClientMode::Tcp => "tcp",
|
||||
ClientMode::Websockets { .. } => "ws",
|
||||
}
|
||||
}
|
||||
|
||||
fn allow_cleartext(&self) -> bool {
|
||||
match self {
|
||||
ClientMode::Tcp => false,
|
||||
@@ -173,19 +168,18 @@ impl ClientMode {
|
||||
|
||||
pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
config: &'static ProxyConfig,
|
||||
ctx: &mut RequestMonitoring,
|
||||
cancel_map: &CancelMap,
|
||||
session_id: uuid::Uuid,
|
||||
stream: S,
|
||||
mode: ClientMode,
|
||||
peer_addr: IpAddr,
|
||||
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||
) -> anyhow::Result<()> {
|
||||
info!(
|
||||
protocol = mode.protocol_label(),
|
||||
protocol = ctx.protocol,
|
||||
"handling interactive connection from client"
|
||||
);
|
||||
|
||||
let proto = mode.protocol_label();
|
||||
let proto = ctx.protocol;
|
||||
let _client_gauge = NUM_CLIENT_CONNECTION_GAUGE
|
||||
.with_label_values(&[proto])
|
||||
.guard();
|
||||
@@ -195,20 +189,23 @@ pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
|
||||
let tls = config.tls_config.as_ref();
|
||||
|
||||
let pause = ctx.latency_timer.pause();
|
||||
let do_handshake = handshake(stream, mode.handshake_tls(tls), cancel_map);
|
||||
let (mut stream, params) = match do_handshake.await? {
|
||||
Some(x) => x,
|
||||
None => return Ok(()), // it's a cancellation request
|
||||
};
|
||||
drop(pause);
|
||||
|
||||
// Extract credentials which we're going to use for auth.
|
||||
let creds = {
|
||||
let hostname = mode.hostname(stream.get_ref());
|
||||
|
||||
let common_names = tls.and_then(|tls| tls.common_names.clone());
|
||||
let result = config
|
||||
.auth_backend
|
||||
.as_ref()
|
||||
.map(|_| auth::ClientCredentials::parse(¶ms, hostname, common_names, peer_addr))
|
||||
.map(|_| auth::ClientCredentials::parse(ctx, ¶ms, hostname, common_names))
|
||||
.transpose();
|
||||
|
||||
match result {
|
||||
@@ -217,16 +214,19 @@ pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
}
|
||||
};
|
||||
|
||||
ctx.set_endpoint_id(creds.get_endpoint());
|
||||
|
||||
let client = Client::new(
|
||||
stream,
|
||||
creds,
|
||||
¶ms,
|
||||
session_id,
|
||||
mode.allow_self_signed_compute(config),
|
||||
endpoint_rate_limiter,
|
||||
);
|
||||
cancel_map
|
||||
.with_session(|session| client.connect_to_db(session, mode, &config.authentication_config))
|
||||
.with_session(|session| {
|
||||
client.connect_to_db(ctx, session, mode, &config.authentication_config)
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
@@ -348,10 +348,13 @@ async fn prepare_client_connection(
|
||||
/// Forward bytes in both directions (client <-> compute).
|
||||
#[tracing::instrument(skip_all)]
|
||||
pub async fn proxy_pass(
|
||||
ctx: &mut RequestMonitoring,
|
||||
client: impl AsyncRead + AsyncWrite + Unpin,
|
||||
compute: impl AsyncRead + AsyncWrite + Unpin,
|
||||
aux: MetricsAuxInfo,
|
||||
) -> anyhow::Result<()> {
|
||||
ctx.log();
|
||||
|
||||
let usage = USAGE_METRICS.register(Ids {
|
||||
endpoint_id: aux.endpoint_id.clone(),
|
||||
branch_id: aux.branch_id.clone(),
|
||||
@@ -397,8 +400,6 @@ struct Client<'a, S> {
|
||||
creds: auth::BackendType<'a, auth::ClientCredentials>,
|
||||
/// KV-dictionary with PostgreSQL connection params.
|
||||
params: &'a StartupMessageParams,
|
||||
/// Unique connection ID.
|
||||
session_id: uuid::Uuid,
|
||||
/// Allow self-signed certificates (for testing).
|
||||
allow_self_signed_compute: bool,
|
||||
/// Rate limiter for endpoints
|
||||
@@ -411,7 +412,6 @@ impl<'a, S> Client<'a, S> {
|
||||
stream: PqStream<Stream<S>>,
|
||||
creds: auth::BackendType<'a, auth::ClientCredentials>,
|
||||
params: &'a StartupMessageParams,
|
||||
session_id: uuid::Uuid,
|
||||
allow_self_signed_compute: bool,
|
||||
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||
) -> Self {
|
||||
@@ -419,7 +419,6 @@ impl<'a, S> Client<'a, S> {
|
||||
stream,
|
||||
creds,
|
||||
params,
|
||||
session_id,
|
||||
allow_self_signed_compute,
|
||||
endpoint_rate_limiter,
|
||||
}
|
||||
@@ -433,6 +432,7 @@ impl<S: AsyncRead + AsyncWrite + Unpin> Client<'_, S> {
|
||||
#[tracing::instrument(name = "", fields(ep = %self.creds.get_endpoint().unwrap_or_default()), skip_all)]
|
||||
async fn connect_to_db(
|
||||
self,
|
||||
ctx: &mut RequestMonitoring,
|
||||
session: cancellation::Session<'_>,
|
||||
mode: ClientMode,
|
||||
config: &'static AuthenticationConfig,
|
||||
@@ -441,7 +441,6 @@ impl<S: AsyncRead + AsyncWrite + Unpin> Client<'_, S> {
|
||||
mut stream,
|
||||
creds,
|
||||
params,
|
||||
session_id,
|
||||
allow_self_signed_compute,
|
||||
endpoint_rate_limiter,
|
||||
} = self;
|
||||
@@ -455,27 +454,13 @@ impl<S: AsyncRead + AsyncWrite + Unpin> Client<'_, S> {
|
||||
}
|
||||
}
|
||||
|
||||
let proto = mode.protocol_label();
|
||||
let extra = console::ConsoleReqExtra {
|
||||
session_id, // aka this connection's id
|
||||
application_name: format!(
|
||||
"{}/{}",
|
||||
params.get("application_name").unwrap_or_default(),
|
||||
proto
|
||||
),
|
||||
options: neon_options(params),
|
||||
};
|
||||
let mut latency_timer = LatencyTimer::new(proto);
|
||||
|
||||
let user = creds.get_user().to_owned();
|
||||
let auth_result = match creds
|
||||
.authenticate(
|
||||
&extra,
|
||||
&mut stream,
|
||||
mode.allow_cleartext(),
|
||||
config,
|
||||
&mut latency_timer,
|
||||
)
|
||||
.authenticate(ctx, &extra, &mut stream, mode.allow_cleartext(), config)
|
||||
.await
|
||||
{
|
||||
Ok(auth_result) => auth_result,
|
||||
@@ -493,15 +478,9 @@ impl<S: AsyncRead + AsyncWrite + Unpin> Client<'_, S> {
|
||||
node_info.allow_self_signed_compute = allow_self_signed_compute;
|
||||
|
||||
let aux = node_info.aux.clone();
|
||||
let mut node = connect_to_compute(
|
||||
&TcpMechanism { params, proto },
|
||||
node_info,
|
||||
&extra,
|
||||
&creds,
|
||||
latency_timer,
|
||||
)
|
||||
.or_else(|e| stream.throw_error(e))
|
||||
.await?;
|
||||
let mut node = connect_to_compute(ctx, &TcpMechanism { params }, node_info, &extra, &creds)
|
||||
.or_else(|e| stream.throw_error(e))
|
||||
.await?;
|
||||
|
||||
prepare_client_connection(&node, session, &mut stream).await?;
|
||||
// Before proxy passing, forward to compute whatever data is left in the
|
||||
@@ -510,7 +489,7 @@ impl<S: AsyncRead + AsyncWrite + Unpin> Client<'_, S> {
|
||||
// immediately after opening the connection.
|
||||
let (stream, read_buf) = stream.into_inner();
|
||||
node.stream.write_all(&read_buf).await?;
|
||||
proxy_pass(stream, node.stream, aux).await
|
||||
proxy_pass(ctx, stream, node.stream, aux).await
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -2,7 +2,8 @@ use crate::{
|
||||
auth,
|
||||
compute::{self, PostgresConnection},
|
||||
console::{self, errors::WakeComputeError, Api},
|
||||
metrics::{bool_to_str, LatencyTimer, NUM_CONNECTION_FAILURES, NUM_WAKEUP_FAILURES},
|
||||
context::RequestMonitoring,
|
||||
metrics::{bool_to_str, NUM_CONNECTION_FAILURES, NUM_WAKEUP_FAILURES},
|
||||
proxy::retry::{retry_after, ShouldRetry},
|
||||
};
|
||||
use async_trait::async_trait;
|
||||
@@ -35,15 +36,15 @@ pub fn invalidate_cache(node_info: console::CachedNodeInfo) -> compute::ConnCfg
|
||||
/// Try to connect to the compute node once.
|
||||
#[tracing::instrument(name = "connect_once", fields(pid = tracing::field::Empty), skip_all)]
|
||||
async fn connect_to_compute_once(
|
||||
ctx: &mut RequestMonitoring,
|
||||
node_info: &console::CachedNodeInfo,
|
||||
timeout: time::Duration,
|
||||
proto: &'static str,
|
||||
) -> Result<PostgresConnection, compute::ConnectionError> {
|
||||
let allow_self_signed_compute = node_info.allow_self_signed_compute;
|
||||
|
||||
node_info
|
||||
.config
|
||||
.connect(allow_self_signed_compute, timeout, proto)
|
||||
.connect(ctx, allow_self_signed_compute, timeout)
|
||||
.await
|
||||
}
|
||||
|
||||
@@ -54,6 +55,7 @@ pub trait ConnectMechanism {
|
||||
type Error: From<Self::ConnectError>;
|
||||
async fn connect_once(
|
||||
&self,
|
||||
ctx: &mut RequestMonitoring,
|
||||
node_info: &console::CachedNodeInfo,
|
||||
timeout: time::Duration,
|
||||
) -> Result<Self::Connection, Self::ConnectError>;
|
||||
@@ -64,7 +66,6 @@ pub trait ConnectMechanism {
|
||||
pub struct TcpMechanism<'a> {
|
||||
/// KV-dictionary with PostgreSQL connection params.
|
||||
pub params: &'a StartupMessageParams,
|
||||
pub proto: &'static str,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
@@ -75,10 +76,11 @@ impl ConnectMechanism for TcpMechanism<'_> {
|
||||
|
||||
async fn connect_once(
|
||||
&self,
|
||||
ctx: &mut RequestMonitoring,
|
||||
node_info: &console::CachedNodeInfo,
|
||||
timeout: time::Duration,
|
||||
) -> Result<PostgresConnection, Self::Error> {
|
||||
connect_to_compute_once(node_info, timeout, self.proto).await
|
||||
connect_to_compute_once(ctx, node_info, timeout).await
|
||||
}
|
||||
|
||||
fn update_connect_config(&self, config: &mut compute::ConnCfg) {
|
||||
@@ -123,11 +125,11 @@ fn report_error(e: &WakeComputeError, retry: bool) {
|
||||
/// This function might update `node_info`, so we take it by `&mut`.
|
||||
#[tracing::instrument(skip_all)]
|
||||
pub async fn connect_to_compute<M: ConnectMechanism>(
|
||||
ctx: &mut RequestMonitoring,
|
||||
mechanism: &M,
|
||||
mut node_info: console::CachedNodeInfo,
|
||||
extra: &console::ConsoleReqExtra,
|
||||
creds: &auth::BackendType<'_, auth::backend::ComputeUserInfo>,
|
||||
mut latency_timer: LatencyTimer,
|
||||
) -> Result<M::Connection, M::Error>
|
||||
where
|
||||
M::ConnectError: ShouldRetry + std::fmt::Debug,
|
||||
@@ -136,9 +138,12 @@ where
|
||||
mechanism.update_connect_config(&mut node_info.config);
|
||||
|
||||
// try once
|
||||
let (config, err) = match mechanism.connect_once(&node_info, CONNECT_TIMEOUT).await {
|
||||
let (config, err) = match mechanism
|
||||
.connect_once(ctx, &node_info, CONNECT_TIMEOUT)
|
||||
.await
|
||||
{
|
||||
Ok(res) => {
|
||||
latency_timer.success();
|
||||
ctx.latency_timer.success();
|
||||
return Ok(res);
|
||||
}
|
||||
Err(e) => {
|
||||
@@ -147,7 +152,7 @@ where
|
||||
}
|
||||
};
|
||||
|
||||
latency_timer.cache_miss();
|
||||
ctx.latency_timer.cache_miss();
|
||||
|
||||
let mut num_retries = 1;
|
||||
|
||||
@@ -155,9 +160,9 @@ where
|
||||
info!("compute node's state has likely changed; requesting a wake-up");
|
||||
let node_info = loop {
|
||||
let wake_res = match creds {
|
||||
auth::BackendType::Console(api, creds) => api.wake_compute(extra, creds).await,
|
||||
auth::BackendType::Console(api, creds) => api.wake_compute(ctx, extra, creds).await,
|
||||
#[cfg(feature = "testing")]
|
||||
auth::BackendType::Postgres(api, creds) => api.wake_compute(extra, creds).await,
|
||||
auth::BackendType::Postgres(api, creds) => api.wake_compute(ctx, extra, creds).await,
|
||||
// nothing to do?
|
||||
auth::BackendType::Link(_) => return Err(err.into()),
|
||||
// test backend
|
||||
@@ -195,9 +200,12 @@ where
|
||||
// * DNS connection settings haven't quite propagated yet
|
||||
info!("wake_compute success. attempting to connect");
|
||||
loop {
|
||||
match mechanism.connect_once(&node_info, CONNECT_TIMEOUT).await {
|
||||
match mechanism
|
||||
.connect_once(ctx, &node_info, CONNECT_TIMEOUT)
|
||||
.await
|
||||
{
|
||||
Ok(res) => {
|
||||
latency_timer.success();
|
||||
ctx.latency_timer.success();
|
||||
return Ok(res);
|
||||
}
|
||||
Err(e) => {
|
||||
|
||||
@@ -425,6 +425,7 @@ impl ConnectMechanism for TestConnectMechanism {
|
||||
|
||||
async fn connect_once(
|
||||
&self,
|
||||
_ctx: &mut RequestMonitoring,
|
||||
_node_info: &console::CachedNodeInfo,
|
||||
_timeout: std::time::Duration,
|
||||
) -> Result<Self::Connection, Self::ConnectError> {
|
||||
@@ -491,11 +492,7 @@ fn helper_create_connect_info(
|
||||
auth::BackendType<'_, ComputeUserInfo>,
|
||||
) {
|
||||
let cache = helper_create_cached_node_info();
|
||||
let extra = console::ConsoleReqExtra {
|
||||
session_id: uuid::Uuid::new_v4(),
|
||||
application_name: "TEST".into(),
|
||||
options: vec![],
|
||||
};
|
||||
let extra = console::ConsoleReqExtra { options: vec![] };
|
||||
let creds = auth::BackendType::Test(mechanism);
|
||||
(cache, extra, creds)
|
||||
}
|
||||
@@ -503,9 +500,10 @@ fn helper_create_connect_info(
|
||||
#[tokio::test]
|
||||
async fn connect_to_compute_success() {
|
||||
use ConnectAction::*;
|
||||
let mut ctx = RequestMonitoring::test();
|
||||
let mechanism = TestConnectMechanism::new(vec![Connect]);
|
||||
let (cache, extra, creds) = helper_create_connect_info(&mechanism);
|
||||
connect_to_compute(&mechanism, cache, &extra, &creds, LatencyTimer::new("test"))
|
||||
connect_to_compute(&mut ctx, &mechanism, cache, &extra, &creds)
|
||||
.await
|
||||
.unwrap();
|
||||
mechanism.verify();
|
||||
@@ -514,9 +512,10 @@ async fn connect_to_compute_success() {
|
||||
#[tokio::test]
|
||||
async fn connect_to_compute_retry() {
|
||||
use ConnectAction::*;
|
||||
let mut ctx = RequestMonitoring::test();
|
||||
let mechanism = TestConnectMechanism::new(vec![Retry, Wake, Retry, Connect]);
|
||||
let (cache, extra, creds) = helper_create_connect_info(&mechanism);
|
||||
connect_to_compute(&mechanism, cache, &extra, &creds, LatencyTimer::new("test"))
|
||||
connect_to_compute(&mut ctx, &mechanism, cache, &extra, &creds)
|
||||
.await
|
||||
.unwrap();
|
||||
mechanism.verify();
|
||||
@@ -526,9 +525,10 @@ async fn connect_to_compute_retry() {
|
||||
#[tokio::test]
|
||||
async fn connect_to_compute_non_retry_1() {
|
||||
use ConnectAction::*;
|
||||
let mut ctx = RequestMonitoring::test();
|
||||
let mechanism = TestConnectMechanism::new(vec![Retry, Wake, Retry, Fail]);
|
||||
let (cache, extra, creds) = helper_create_connect_info(&mechanism);
|
||||
connect_to_compute(&mechanism, cache, &extra, &creds, LatencyTimer::new("test"))
|
||||
connect_to_compute(&mut ctx, &mechanism, cache, &extra, &creds)
|
||||
.await
|
||||
.unwrap_err();
|
||||
mechanism.verify();
|
||||
@@ -538,9 +538,10 @@ async fn connect_to_compute_non_retry_1() {
|
||||
#[tokio::test]
|
||||
async fn connect_to_compute_non_retry_2() {
|
||||
use ConnectAction::*;
|
||||
let mut ctx = RequestMonitoring::test();
|
||||
let mechanism = TestConnectMechanism::new(vec![Fail, Wake, Retry, Connect]);
|
||||
let (cache, extra, creds) = helper_create_connect_info(&mechanism);
|
||||
connect_to_compute(&mechanism, cache, &extra, &creds, LatencyTimer::new("test"))
|
||||
connect_to_compute(&mut ctx, &mechanism, cache, &extra, &creds)
|
||||
.await
|
||||
.unwrap();
|
||||
mechanism.verify();
|
||||
@@ -551,12 +552,13 @@ async fn connect_to_compute_non_retry_2() {
|
||||
async fn connect_to_compute_non_retry_3() {
|
||||
assert_eq!(NUM_RETRIES_CONNECT, 16);
|
||||
use ConnectAction::*;
|
||||
let mut ctx = RequestMonitoring::test();
|
||||
let mechanism = TestConnectMechanism::new(vec![
|
||||
Retry, Wake, Retry, Retry, Retry, Retry, Retry, Retry, Retry, Retry, Retry, Retry, Retry,
|
||||
Retry, Retry, Retry, Retry, /* the 17th time */ Retry,
|
||||
]);
|
||||
let (cache, extra, creds) = helper_create_connect_info(&mechanism);
|
||||
connect_to_compute(&mechanism, cache, &extra, &creds, LatencyTimer::new("test"))
|
||||
connect_to_compute(&mut ctx, &mechanism, cache, &extra, &creds)
|
||||
.await
|
||||
.unwrap_err();
|
||||
mechanism.verify();
|
||||
@@ -566,9 +568,10 @@ async fn connect_to_compute_non_retry_3() {
|
||||
#[tokio::test]
|
||||
async fn wake_retry() {
|
||||
use ConnectAction::*;
|
||||
let mut ctx = RequestMonitoring::test();
|
||||
let mechanism = TestConnectMechanism::new(vec![Retry, WakeRetry, Wake, Connect]);
|
||||
let (cache, extra, creds) = helper_create_connect_info(&mechanism);
|
||||
connect_to_compute(&mechanism, cache, &extra, &creds, LatencyTimer::new("test"))
|
||||
connect_to_compute(&mut ctx, &mechanism, cache, &extra, &creds)
|
||||
.await
|
||||
.unwrap();
|
||||
mechanism.verify();
|
||||
@@ -578,9 +581,10 @@ async fn wake_retry() {
|
||||
#[tokio::test]
|
||||
async fn wake_non_retry() {
|
||||
use ConnectAction::*;
|
||||
let mut ctx = RequestMonitoring::test();
|
||||
let mechanism = TestConnectMechanism::new(vec![Retry, WakeFail]);
|
||||
let (cache, extra, creds) = helper_create_connect_info(&mechanism);
|
||||
connect_to_compute(&mechanism, cache, &extra, &creds, LatencyTimer::new("test"))
|
||||
connect_to_compute(&mut ctx, &mechanism, cache, &extra, &creds)
|
||||
.await
|
||||
.unwrap_err();
|
||||
mechanism.verify();
|
||||
|
||||
@@ -17,6 +17,7 @@ pub use reqwest_middleware::{ClientWithMiddleware, Error};
|
||||
pub use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};
|
||||
use tokio_util::task::TaskTracker;
|
||||
|
||||
use crate::context::RequestMonitoring;
|
||||
use crate::metrics::NUM_CLIENT_CONNECTION_GAUGE;
|
||||
use crate::protocol2::{ProxyProtocolAccept, WithClientIp};
|
||||
use crate::rate_limiter::EndpointRateLimiter;
|
||||
@@ -218,13 +219,14 @@ async fn request_handler(
|
||||
|
||||
ws_connections.spawn(
|
||||
async move {
|
||||
let mut ctx = RequestMonitoring::new(session_id, peer_addr, "ws", &config.region);
|
||||
|
||||
if let Err(e) = websocket::serve_websocket(
|
||||
websocket,
|
||||
config,
|
||||
&mut ctx,
|
||||
websocket,
|
||||
&cancel_map,
|
||||
session_id,
|
||||
host,
|
||||
peer_addr,
|
||||
endpoint_rate_limiter,
|
||||
)
|
||||
.await
|
||||
@@ -238,13 +240,14 @@ async fn request_handler(
|
||||
// Return the response so the spawned future can continue.
|
||||
Ok(response)
|
||||
} else if request.uri().path() == "/sql" && request.method() == Method::POST {
|
||||
let mut ctx = RequestMonitoring::new(session_id, peer_addr, "http", &config.region);
|
||||
|
||||
sql_over_http::handle(
|
||||
&config.http_config,
|
||||
&mut ctx,
|
||||
request,
|
||||
sni_hostname,
|
||||
conn_pool,
|
||||
session_id,
|
||||
peer_addr,
|
||||
&config.http_config,
|
||||
)
|
||||
.await
|
||||
} else if request.uri().path() == "/sql" && request.method() == Method::OPTIONS {
|
||||
|
||||
@@ -13,7 +13,7 @@ use pq_proto::StartupMessageParams;
|
||||
use prometheus::{exponential_buckets, register_histogram, Histogram};
|
||||
use rand::Rng;
|
||||
use smol_str::SmolStr;
|
||||
use std::{collections::HashMap, net::IpAddr, pin::pin, sync::Arc, sync::Weak, time::Duration};
|
||||
use std::{collections::HashMap, pin::pin, sync::Arc, sync::Weak, time::Duration};
|
||||
use std::{
|
||||
fmt,
|
||||
task::{ready, Poll},
|
||||
@@ -28,7 +28,8 @@ use tokio_postgres::{AsyncMessage, ReadyForQueryStatus};
|
||||
use crate::{
|
||||
auth::{self, backend::ComputeUserInfo, check_peer_addr_is_in_list},
|
||||
console,
|
||||
metrics::{LatencyTimer, NUM_DB_CONNECTIONS_GAUGE},
|
||||
context::RequestMonitoring,
|
||||
metrics::NUM_DB_CONNECTIONS_GAUGE,
|
||||
proxy::{connect_compute::ConnectMechanism, neon_options},
|
||||
usage_metrics::{Ids, MetricCounter, USAGE_METRICS},
|
||||
};
|
||||
@@ -309,13 +310,11 @@ impl GlobalConnPool {
|
||||
|
||||
pub async fn get(
|
||||
self: &Arc<Self>,
|
||||
ctx: &mut RequestMonitoring,
|
||||
conn_info: ConnInfo,
|
||||
force_new: bool,
|
||||
session_id: uuid::Uuid,
|
||||
peer_addr: IpAddr,
|
||||
) -> anyhow::Result<Client> {
|
||||
let mut client: Option<ClientInner> = None;
|
||||
let mut latency_timer = LatencyTimer::new("http");
|
||||
|
||||
let mut hash_valid = false;
|
||||
let mut endpoint_pool = Weak::new();
|
||||
@@ -360,23 +359,21 @@ impl GlobalConnPool {
|
||||
info!(%conn_id, "pool: cached connection '{conn_info}' is closed, opening a new one");
|
||||
connect_to_compute(
|
||||
self.proxy_config,
|
||||
ctx,
|
||||
&conn_info,
|
||||
conn_id,
|
||||
session_id,
|
||||
latency_timer,
|
||||
peer_addr,
|
||||
endpoint_pool.clone(),
|
||||
)
|
||||
.await
|
||||
} else {
|
||||
info!("pool: reusing connection '{conn_info}'");
|
||||
client.session.send(session_id)?;
|
||||
client.session.send(ctx.session_id)?;
|
||||
tracing::Span::current().record(
|
||||
"pid",
|
||||
&tracing::field::display(client.inner.get_process_id()),
|
||||
);
|
||||
latency_timer.pool_hit();
|
||||
latency_timer.success();
|
||||
ctx.latency_timer.pool_hit();
|
||||
ctx.latency_timer.success();
|
||||
return Ok(Client::new(client, conn_info, endpoint_pool).await);
|
||||
}
|
||||
} else {
|
||||
@@ -384,11 +381,9 @@ impl GlobalConnPool {
|
||||
info!(%conn_id, "pool: opening a new connection '{conn_info}'");
|
||||
connect_to_compute(
|
||||
self.proxy_config,
|
||||
ctx,
|
||||
&conn_info,
|
||||
conn_id,
|
||||
session_id,
|
||||
latency_timer,
|
||||
peer_addr,
|
||||
endpoint_pool.clone(),
|
||||
)
|
||||
.await
|
||||
@@ -483,7 +478,6 @@ impl GlobalConnPool {
|
||||
struct TokioMechanism<'a> {
|
||||
pool: Weak<RwLock<EndpointConnPool>>,
|
||||
conn_info: &'a ConnInfo,
|
||||
session_id: uuid::Uuid,
|
||||
conn_id: uuid::Uuid,
|
||||
idle: Duration,
|
||||
}
|
||||
@@ -496,15 +490,16 @@ impl ConnectMechanism for TokioMechanism<'_> {
|
||||
|
||||
async fn connect_once(
|
||||
&self,
|
||||
ctx: &mut RequestMonitoring,
|
||||
node_info: &console::CachedNodeInfo,
|
||||
timeout: time::Duration,
|
||||
) -> Result<Self::Connection, Self::ConnectError> {
|
||||
connect_to_compute_once(
|
||||
ctx,
|
||||
node_info,
|
||||
self.conn_info,
|
||||
timeout,
|
||||
self.conn_id,
|
||||
self.session_id,
|
||||
self.pool.clone(),
|
||||
self.idle,
|
||||
)
|
||||
@@ -520,11 +515,9 @@ impl ConnectMechanism for TokioMechanism<'_> {
|
||||
#[tracing::instrument(fields(pid = tracing::field::Empty), skip_all)]
|
||||
async fn connect_to_compute(
|
||||
config: &config::ProxyConfig,
|
||||
ctx: &mut RequestMonitoring,
|
||||
conn_info: &ConnInfo,
|
||||
conn_id: uuid::Uuid,
|
||||
session_id: uuid::Uuid,
|
||||
latency_timer: LatencyTimer,
|
||||
peer_addr: IpAddr,
|
||||
pool: Weak<RwLock<EndpointConnPool>>,
|
||||
) -> anyhow::Result<ClientInner> {
|
||||
let tls = config.tls_config.as_ref();
|
||||
@@ -536,12 +529,8 @@ async fn connect_to_compute(
|
||||
("application_name", APP_NAME),
|
||||
("options", conn_info.options.as_deref().unwrap_or("")),
|
||||
]);
|
||||
let creds = auth::ClientCredentials::parse(
|
||||
¶ms,
|
||||
Some(&conn_info.hostname),
|
||||
common_names,
|
||||
peer_addr,
|
||||
)?;
|
||||
let creds =
|
||||
auth::ClientCredentials::parse(ctx, ¶ms, Some(&conn_info.hostname), common_names)?;
|
||||
|
||||
let creds =
|
||||
ComputeUserInfo::try_from(creds).map_err(|_| anyhow!("missing endpoint identifier"))?;
|
||||
@@ -549,48 +538,48 @@ async fn connect_to_compute(
|
||||
|
||||
let console_options = neon_options(¶ms);
|
||||
|
||||
let extra = console::ConsoleReqExtra {
|
||||
session_id: uuid::Uuid::new_v4(),
|
||||
application_name: APP_NAME.to_string(),
|
||||
options: console_options,
|
||||
};
|
||||
if !config.disable_ip_check_for_http {
|
||||
let allowed_ips = backend.get_allowed_ips(&extra).await?;
|
||||
if !check_peer_addr_is_in_list(&peer_addr, &allowed_ips) {
|
||||
let allowed_ips = backend.get_allowed_ips(ctx).await?;
|
||||
if !check_peer_addr_is_in_list(&ctx.peer_addr, &allowed_ips) {
|
||||
return Err(auth::AuthError::ip_address_not_allowed().into());
|
||||
}
|
||||
}
|
||||
let extra = console::ConsoleReqExtra {
|
||||
options: console_options,
|
||||
};
|
||||
let node_info = backend
|
||||
.wake_compute(&extra)
|
||||
.wake_compute(ctx, &extra)
|
||||
.await?
|
||||
.context("missing cache entry from wake_compute")?;
|
||||
|
||||
ctx.set_project(node_info.aux.clone());
|
||||
|
||||
crate::proxy::connect_compute::connect_to_compute(
|
||||
ctx,
|
||||
&TokioMechanism {
|
||||
conn_id,
|
||||
conn_info,
|
||||
session_id,
|
||||
pool,
|
||||
idle: config.http_config.pool_options.idle_timeout,
|
||||
},
|
||||
node_info,
|
||||
&extra,
|
||||
&backend,
|
||||
latency_timer,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn connect_to_compute_once(
|
||||
ctx: &mut RequestMonitoring,
|
||||
node_info: &console::CachedNodeInfo,
|
||||
conn_info: &ConnInfo,
|
||||
timeout: time::Duration,
|
||||
conn_id: uuid::Uuid,
|
||||
mut session: uuid::Uuid,
|
||||
pool: Weak<RwLock<EndpointConnPool>>,
|
||||
idle: Duration,
|
||||
) -> Result<ClientInner, tokio_postgres::Error> {
|
||||
let mut config = (*node_info.config).clone();
|
||||
let mut session = ctx.session_id;
|
||||
|
||||
let (client, mut connection) = config
|
||||
.user(&conn_info.username)
|
||||
@@ -601,7 +590,7 @@ async fn connect_to_compute_once(
|
||||
.await?;
|
||||
|
||||
let conn_gauge = NUM_DB_CONNECTIONS_GAUGE
|
||||
.with_label_values(&["http"])
|
||||
.with_label_values(&[ctx.protocol])
|
||||
.guard();
|
||||
|
||||
tracing::Span::current().record("pid", &tracing::field::display(client.get_process_id()));
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
use std::net::IpAddr;
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::bail;
|
||||
@@ -14,6 +13,7 @@ use hyper::{Body, HeaderMap, Request};
|
||||
use serde_json::json;
|
||||
use serde_json::Map;
|
||||
use serde_json::Value;
|
||||
use smol_str::SmolStr;
|
||||
use tokio_postgres::error::DbError;
|
||||
use tokio_postgres::types::Kind;
|
||||
use tokio_postgres::types::Type;
|
||||
@@ -29,6 +29,7 @@ use utils::http::error::ApiError;
|
||||
use utils::http::json::json_response;
|
||||
|
||||
use crate::config::HttpConfig;
|
||||
use crate::context::RequestMonitoring;
|
||||
use crate::metrics::NUM_CONNECTION_REQUESTS_GAUGE;
|
||||
|
||||
use super::conn_pool::ConnInfo;
|
||||
@@ -121,6 +122,7 @@ fn json_array_to_pg_array(value: &Value) -> Option<String> {
|
||||
}
|
||||
|
||||
fn get_conn_info(
|
||||
ctx: &mut RequestMonitoring,
|
||||
headers: &HeaderMap,
|
||||
sni_hostname: Option<String>,
|
||||
) -> Result<ConnInfo, anyhow::Error> {
|
||||
@@ -146,10 +148,11 @@ fn get_conn_info(
|
||||
.next()
|
||||
.ok_or(anyhow::anyhow!("invalid database name"))?;
|
||||
|
||||
let username = connection_url.username();
|
||||
let username = SmolStr::from(connection_url.username());
|
||||
if username.is_empty() {
|
||||
return Err(anyhow::anyhow!("missing username"));
|
||||
}
|
||||
ctx.set_user(username.clone());
|
||||
|
||||
let password = connection_url
|
||||
.password()
|
||||
@@ -176,6 +179,9 @@ fn get_conn_info(
|
||||
}
|
||||
}
|
||||
|
||||
let hostname: SmolStr = hostname.into();
|
||||
ctx.set_endpoint_id(Some(hostname.clone()));
|
||||
|
||||
let pairs = connection_url.query_pairs();
|
||||
|
||||
let mut options = Option::None;
|
||||
@@ -188,9 +194,9 @@ fn get_conn_info(
|
||||
}
|
||||
|
||||
Ok(ConnInfo {
|
||||
username: username.into(),
|
||||
username,
|
||||
dbname: dbname.into(),
|
||||
hostname: hostname.into(),
|
||||
hostname,
|
||||
password: password.into(),
|
||||
options,
|
||||
})
|
||||
@@ -198,23 +204,15 @@ fn get_conn_info(
|
||||
|
||||
// TODO: return different http error codes
|
||||
pub async fn handle(
|
||||
config: &'static HttpConfig,
|
||||
ctx: &mut RequestMonitoring,
|
||||
request: Request<Body>,
|
||||
sni_hostname: Option<String>,
|
||||
conn_pool: Arc<GlobalConnPool>,
|
||||
session_id: uuid::Uuid,
|
||||
peer_addr: IpAddr,
|
||||
config: &'static HttpConfig,
|
||||
) -> Result<Response<Body>, ApiError> {
|
||||
let result = tokio::time::timeout(
|
||||
config.request_timeout,
|
||||
handle_inner(
|
||||
config,
|
||||
request,
|
||||
sni_hostname,
|
||||
conn_pool,
|
||||
session_id,
|
||||
peer_addr,
|
||||
),
|
||||
handle_inner(config, ctx, request, sni_hostname, conn_pool),
|
||||
)
|
||||
.await;
|
||||
let mut response = match result {
|
||||
@@ -297,11 +295,10 @@ pub async fn handle(
|
||||
#[instrument(name = "sql-over-http", fields(pid = tracing::field::Empty), skip_all)]
|
||||
async fn handle_inner(
|
||||
config: &'static HttpConfig,
|
||||
ctx: &mut RequestMonitoring,
|
||||
request: Request<Body>,
|
||||
sni_hostname: Option<String>,
|
||||
conn_pool: Arc<GlobalConnPool>,
|
||||
session_id: uuid::Uuid,
|
||||
peer_addr: IpAddr,
|
||||
) -> anyhow::Result<Response<Body>> {
|
||||
let _request_gauge = NUM_CONNECTION_REQUESTS_GAUGE
|
||||
.with_label_values(&["http"])
|
||||
@@ -311,7 +308,7 @@ async fn handle_inner(
|
||||
// Determine the destination and connection params
|
||||
//
|
||||
let headers = request.headers();
|
||||
let conn_info = get_conn_info(headers, sni_hostname)?;
|
||||
let conn_info = get_conn_info(ctx, headers, sni_hostname)?;
|
||||
|
||||
// Determine the output options. Default behaviour is 'false'. Anything that is not
|
||||
// strictly 'true' assumed to be false.
|
||||
@@ -340,10 +337,12 @@ async fn handle_inner(
|
||||
let txn_read_only = headers.get(&TXN_READ_ONLY) == Some(&HEADER_VALUE_TRUE);
|
||||
let txn_deferrable = headers.get(&TXN_DEFERRABLE) == Some(&HEADER_VALUE_TRUE);
|
||||
|
||||
let paused = ctx.latency_timer.pause();
|
||||
let request_content_length = match request.body().size_hint().upper() {
|
||||
Some(v) => v,
|
||||
None => MAX_REQUEST_SIZE + 1,
|
||||
};
|
||||
drop(paused);
|
||||
|
||||
// we don't have a streaming request support yet so this is to prevent OOM
|
||||
// from a malicious user sending an extremely large request body
|
||||
@@ -359,9 +358,7 @@ async fn handle_inner(
|
||||
let body = hyper::body::to_bytes(request.into_body()).await?;
|
||||
let payload: Payload = serde_json::from_slice(&body)?;
|
||||
|
||||
let mut client = conn_pool
|
||||
.get(conn_info, !allow_pool, session_id, peer_addr)
|
||||
.await?;
|
||||
let mut client = conn_pool.get(ctx, conn_info, !allow_pool).await?;
|
||||
|
||||
let mut response = Response::builder()
|
||||
.status(StatusCode::OK)
|
||||
@@ -449,6 +446,7 @@ async fn handle_inner(
|
||||
}
|
||||
};
|
||||
|
||||
ctx.log();
|
||||
let metrics = client.metrics();
|
||||
|
||||
// how could this possibly fail
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
use crate::{
|
||||
cancellation::CancelMap,
|
||||
config::ProxyConfig,
|
||||
context::RequestMonitoring,
|
||||
error::io_error,
|
||||
proxy::{handle_client, ClientMode},
|
||||
rate_limiter::EndpointRateLimiter,
|
||||
@@ -12,7 +13,6 @@ use hyper_tungstenite::{tungstenite::Message, HyperWebsocket, WebSocketStream};
|
||||
use pin_project_lite::pin_project;
|
||||
|
||||
use std::{
|
||||
net::IpAddr,
|
||||
pin::Pin,
|
||||
sync::Arc,
|
||||
task::{ready, Context, Poll},
|
||||
@@ -130,22 +130,20 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AsyncBufRead for WebSocketRw<S> {
|
||||
}
|
||||
|
||||
pub async fn serve_websocket(
|
||||
websocket: HyperWebsocket,
|
||||
config: &'static ProxyConfig,
|
||||
ctx: &mut RequestMonitoring,
|
||||
websocket: HyperWebsocket,
|
||||
cancel_map: &CancelMap,
|
||||
session_id: uuid::Uuid,
|
||||
hostname: Option<String>,
|
||||
peer_addr: IpAddr,
|
||||
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||
) -> anyhow::Result<()> {
|
||||
let websocket = websocket.await?;
|
||||
handle_client(
|
||||
config,
|
||||
ctx,
|
||||
cancel_map,
|
||||
session_id,
|
||||
WebSocketRw::new(websocket),
|
||||
ClientMode::Websockets { hostname },
|
||||
peer_addr,
|
||||
endpoint_rate_limiter,
|
||||
)
|
||||
.await?;
|
||||
|
||||
@@ -52,6 +52,7 @@ num-bigint = { version = "0.4" }
|
||||
num-integer = { version = "0.1", features = ["i128"] }
|
||||
num-traits = { version = "0.2", features = ["i128"] }
|
||||
once_cell = { version = "1" }
|
||||
parquet = { git = "https://github.com/neondatabase/arrow-rs", branch = "neon-fix-bugs", default-features = false, features = ["zstd"] }
|
||||
prost = { version = "0.11" }
|
||||
rand = { version = "0.8", features = ["small_rng"] }
|
||||
regex = { version = "1" }
|
||||
@@ -76,7 +77,7 @@ tracing = { version = "0.1", features = ["log"] }
|
||||
tracing-core = { version = "0.1" }
|
||||
tungstenite = { version = "0.20" }
|
||||
url = { version = "2", features = ["serde"] }
|
||||
uuid = { version = "1", features = ["serde", "v4"] }
|
||||
uuid = { version = "1", features = ["serde", "v4", "v7"] }
|
||||
zstd = { version = "0.13" }
|
||||
zstd-safe = { version = "7", default-features = false, features = ["arrays", "legacy", "std", "zdict_builder"] }
|
||||
zstd-sys = { version = "2", default-features = false, features = ["legacy", "std", "zdict_builder"] }
|
||||
@@ -85,6 +86,7 @@ zstd-sys = { version = "2", default-features = false, features = ["legacy", "std
|
||||
anyhow = { version = "1", features = ["backtrace"] }
|
||||
bytes = { version = "1", features = ["serde"] }
|
||||
cc = { version = "1", default-features = false, features = ["parallel"] }
|
||||
chrono = { version = "0.4", default-features = false, features = ["clock", "serde", "wasmbind"] }
|
||||
either = { version = "1" }
|
||||
getrandom = { version = "0.2", default-features = false, features = ["std"] }
|
||||
itertools = { version = "0.10" }
|
||||
@@ -92,7 +94,11 @@ libc = { version = "0.2", features = ["extra_traits"] }
|
||||
log = { version = "0.4", default-features = false, features = ["std"] }
|
||||
memchr = { version = "2" }
|
||||
nom = { version = "7" }
|
||||
num-bigint = { version = "0.4" }
|
||||
num-integer = { version = "0.1", features = ["i128"] }
|
||||
num-traits = { version = "0.2", features = ["i128"] }
|
||||
once_cell = { version = "1" }
|
||||
parquet = { git = "https://github.com/neondatabase/arrow-rs", branch = "neon-fix-bugs", default-features = false, features = ["zstd"] }
|
||||
prost = { version = "0.11" }
|
||||
regex = { version = "1" }
|
||||
regex-automata = { version = "0.4", default-features = false, features = ["dfa-onepass", "hybrid", "meta", "nfa-backtrack", "perf-inline", "perf-literal", "unicode"] }
|
||||
@@ -101,5 +107,8 @@ serde = { version = "1", features = ["alloc", "derive"] }
|
||||
syn-dff4ba8e3ae991db = { package = "syn", version = "1", features = ["extra-traits", "full", "visit"] }
|
||||
syn-f595c2ba2a3f28df = { package = "syn", version = "2", features = ["extra-traits", "full", "visit", "visit-mut"] }
|
||||
time-macros = { version = "0.2", default-features = false, features = ["formatting", "parsing", "serde"] }
|
||||
zstd = { version = "0.13" }
|
||||
zstd-safe = { version = "7", default-features = false, features = ["arrays", "legacy", "std", "zdict_builder"] }
|
||||
zstd-sys = { version = "2", default-features = false, features = ["legacy", "std", "zdict_builder"] }
|
||||
|
||||
### END HAKARI SECTION
|
||||
|
||||
Reference in New Issue
Block a user