mirror of
https://github.com/neondatabase/neon.git
synced 2026-01-17 18:32:56 +00:00
Compare commits
32 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8918b1c872 | ||
|
|
fe6946e15e | ||
|
|
0924267612 | ||
|
|
d418cf2dde | ||
|
|
2171a9dc2d | ||
|
|
33a357cfeb | ||
|
|
e06aa4b91d | ||
|
|
34696381c5 | ||
|
|
24c48856a2 | ||
|
|
d698a50984 | ||
|
|
9131d0463d | ||
|
|
214442519f | ||
|
|
c4e868819c | ||
|
|
8198a503f2 | ||
|
|
76371e8452 | ||
|
|
3f66c12280 | ||
|
|
411a80b494 | ||
|
|
cdcb8537f5 | ||
|
|
37221f3252 | ||
|
|
f95ddef4e0 | ||
|
|
ce200a53e8 | ||
|
|
91e8b7d22b | ||
|
|
f47401f2e9 | ||
|
|
469597fdb6 | ||
|
|
2af5352708 | ||
|
|
fbc37acfdf | ||
|
|
b71bf47c33 | ||
|
|
d653d7c62c | ||
|
|
52b73185f9 | ||
|
|
dc41d108e8 | ||
|
|
02e15b7bbb | ||
|
|
864bdf3528 |
172
Cargo.lock
generated
172
Cargo.lock
generated
@@ -919,7 +919,7 @@ version = "0.70.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f49d8fed880d473ea71efb9bf597651e77201bdd4893efe54c9e5d65ae04ce6f"
|
||||
dependencies = [
|
||||
"bitflags 2.4.1",
|
||||
"bitflags 2.6.0",
|
||||
"cexpr",
|
||||
"clang-sys",
|
||||
"itertools 0.12.1",
|
||||
@@ -928,7 +928,7 @@ dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"regex",
|
||||
"rustc-hash",
|
||||
"rustc-hash 1.1.0",
|
||||
"shlex",
|
||||
"syn 2.0.52",
|
||||
]
|
||||
@@ -947,9 +947,9 @@ checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a"
|
||||
|
||||
[[package]]
|
||||
name = "bitflags"
|
||||
version = "2.4.1"
|
||||
version = "2.6.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "327762f6e5a765692301e5bb513e0d9fef63be86bbc14528052b1cd3e6f03e07"
|
||||
checksum = "b048fb63fd8b5923fc5aa7b340d8e156aec7ec02f0c78fa8a6ddc2613f6f71de"
|
||||
|
||||
[[package]]
|
||||
name = "block-buffer"
|
||||
@@ -1044,6 +1044,12 @@ dependencies = [
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cesu8"
|
||||
version = "1.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6d43a04d8753f35258c91f8ec639f792891f748a1edbd759cf1dcea3382ad83c"
|
||||
|
||||
[[package]]
|
||||
name = "cexpr"
|
||||
version = "0.6.0"
|
||||
@@ -1371,9 +1377,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "core-foundation"
|
||||
version = "0.9.3"
|
||||
version = "0.9.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "194a7a9e6de53fa55116934067c844d9d749312f75c6f6d0980e8c252f8c2146"
|
||||
checksum = "91e195e091a93c46f7102ec7818a2aa394e1e1771c3ab4825963fa03e45afb8f"
|
||||
dependencies = [
|
||||
"core-foundation-sys",
|
||||
"libc",
|
||||
@@ -1381,9 +1387,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "core-foundation-sys"
|
||||
version = "0.8.4"
|
||||
version = "0.8.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e496a50fda8aacccc86d7529e2c1e0892dbd0f898a6b5645b5561b89c3210efa"
|
||||
checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b"
|
||||
|
||||
[[package]]
|
||||
name = "cpufeatures"
|
||||
@@ -1489,7 +1495,7 @@ version = "0.27.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f476fe445d41c9e991fd07515a6f463074b782242ccf4a5b7b1d1012e70824df"
|
||||
dependencies = [
|
||||
"bitflags 2.4.1",
|
||||
"bitflags 2.6.0",
|
||||
"crossterm_winapi",
|
||||
"libc",
|
||||
"parking_lot 0.12.1",
|
||||
@@ -1675,7 +1681,7 @@ version = "2.2.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "65e13bab2796f412722112327f3e575601a3e9cdcbe426f0d30dbf43f3f5dc71"
|
||||
dependencies = [
|
||||
"bitflags 2.4.1",
|
||||
"bitflags 2.6.0",
|
||||
"byteorder",
|
||||
"chrono",
|
||||
"diesel_derives",
|
||||
@@ -2835,6 +2841,26 @@ version = "1.0.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "453ad9f582a441959e5f0d088b02ce04cfe8d51a8eaf077f12ac6d3e94164ca6"
|
||||
|
||||
[[package]]
|
||||
name = "jni"
|
||||
version = "0.19.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c6df18c2e3db7e453d3c6ac5b3e9d5182664d28788126d39b91f2d1e22b017ec"
|
||||
dependencies = [
|
||||
"cesu8",
|
||||
"combine",
|
||||
"jni-sys",
|
||||
"log",
|
||||
"thiserror",
|
||||
"walkdir",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "jni-sys"
|
||||
version = "0.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8eaf4bc02d17cbdd7ff4c7438cafcdf7fb9a4613313ad11b4f8fefe7d3fa0130"
|
||||
|
||||
[[package]]
|
||||
name = "jobserver"
|
||||
version = "0.1.26"
|
||||
@@ -3058,7 +3084,7 @@ dependencies = [
|
||||
"measured-derive",
|
||||
"memchr",
|
||||
"parking_lot 0.12.1",
|
||||
"rustc-hash",
|
||||
"rustc-hash 1.1.0",
|
||||
"ryu",
|
||||
]
|
||||
|
||||
@@ -3225,7 +3251,7 @@ version = "0.27.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2eb04e9c688eff1c89d72b407f168cf79bb9e867a9d3323ed6c01519eb9cc053"
|
||||
dependencies = [
|
||||
"bitflags 2.4.1",
|
||||
"bitflags 2.6.0",
|
||||
"cfg-if",
|
||||
"libc",
|
||||
"memoffset 0.9.0",
|
||||
@@ -3247,7 +3273,7 @@ version = "6.1.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6205bd8bb1e454ad2e27422015fb5e4f2bcc7e08fa8f27058670d208324a4d2d"
|
||||
dependencies = [
|
||||
"bitflags 2.4.1",
|
||||
"bitflags 2.6.0",
|
||||
"crossbeam-channel",
|
||||
"filetime",
|
||||
"fsevent-sys",
|
||||
@@ -3416,9 +3442,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "once_cell"
|
||||
version = "1.18.0"
|
||||
version = "1.19.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "dd8b5dd2ae5ed71462c540258bedcb51965123ad7e7ccf4b9a8cafaa4a63576d"
|
||||
checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92"
|
||||
|
||||
[[package]]
|
||||
name = "oorandom"
|
||||
@@ -4302,7 +4328,7 @@ version = "0.16.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "731e0d9356b0c25f16f33b5be79b1c57b562f141ebfcdb0ad8ac2c13a24293b4"
|
||||
dependencies = [
|
||||
"bitflags 2.4.1",
|
||||
"bitflags 2.6.0",
|
||||
"chrono",
|
||||
"flate2",
|
||||
"hex",
|
||||
@@ -4317,7 +4343,7 @@ version = "0.16.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2d3554923a69f4ce04c4a754260c338f505ce22642d3830e049a399fc2059a29"
|
||||
dependencies = [
|
||||
"bitflags 2.4.1",
|
||||
"bitflags 2.6.0",
|
||||
"chrono",
|
||||
"hex",
|
||||
]
|
||||
@@ -4455,6 +4481,7 @@ dependencies = [
|
||||
"postgres_backend",
|
||||
"pq_proto",
|
||||
"prometheus",
|
||||
"quinn",
|
||||
"rand 0.8.5",
|
||||
"rand_distr",
|
||||
"rcgen",
|
||||
@@ -4468,7 +4495,7 @@ dependencies = [
|
||||
"routerify",
|
||||
"rsa",
|
||||
"rstest",
|
||||
"rustc-hash",
|
||||
"rustc-hash 1.1.0",
|
||||
"rustls 0.22.4",
|
||||
"rustls-native-certs 0.7.0",
|
||||
"rustls-pemfile 2.1.1",
|
||||
@@ -4517,6 +4544,55 @@ dependencies = [
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "quinn"
|
||||
version = "0.11.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8c7c5fdde3cdae7203427dc4f0a68fe0ed09833edc525a03456b153b79828684"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"pin-project-lite",
|
||||
"quinn-proto",
|
||||
"quinn-udp",
|
||||
"rustc-hash 2.0.0",
|
||||
"rustls 0.23.7",
|
||||
"socket2 0.5.5",
|
||||
"thiserror",
|
||||
"tokio",
|
||||
"tracing",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "quinn-proto"
|
||||
version = "0.11.8"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "fadfaed2cd7f389d0161bb73eeb07b7b78f8691047a6f3e73caaeae55310a4a6"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"rand 0.8.5",
|
||||
"ring 0.17.6",
|
||||
"rustc-hash 2.0.0",
|
||||
"rustls 0.23.7",
|
||||
"rustls-platform-verifier",
|
||||
"slab",
|
||||
"thiserror",
|
||||
"tinyvec",
|
||||
"tracing",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "quinn-udp"
|
||||
version = "0.5.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8bffec3605b73c6f1754535084a85229fa8a30f86014e6c81aeec4abb68b0285"
|
||||
dependencies = [
|
||||
"libc",
|
||||
"once_cell",
|
||||
"socket2 0.5.5",
|
||||
"tracing",
|
||||
"windows-sys 0.52.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "quote"
|
||||
version = "1.0.35"
|
||||
@@ -5119,6 +5195,12 @@ version = "1.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2"
|
||||
|
||||
[[package]]
|
||||
name = "rustc-hash"
|
||||
version = "2.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "583034fd73374156e66797ed8e5b0d5690409c9226b22d87cb7f19821c05d152"
|
||||
|
||||
[[package]]
|
||||
name = "rustc_version"
|
||||
version = "0.4.0"
|
||||
@@ -5143,7 +5225,7 @@ version = "0.38.28"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "72e572a5e8ca657d7366229cdde4bd14c4eb5499a9573d4d366fe1b599daa316"
|
||||
dependencies = [
|
||||
"bitflags 2.4.1",
|
||||
"bitflags 2.6.0",
|
||||
"errno",
|
||||
"libc",
|
||||
"linux-raw-sys 0.4.13",
|
||||
@@ -5176,6 +5258,20 @@ dependencies = [
|
||||
"zeroize",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustls"
|
||||
version = "0.23.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ebbbdb961df0ad3f2652da8f3fdc4b36122f568f968f45ad3316f26c025c677b"
|
||||
dependencies = [
|
||||
"once_cell",
|
||||
"ring 0.17.6",
|
||||
"rustls-pki-types",
|
||||
"rustls-webpki 0.102.2",
|
||||
"subtle",
|
||||
"zeroize",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustls-native-certs"
|
||||
version = "0.6.2"
|
||||
@@ -5226,6 +5322,33 @@ version = "1.3.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5ede67b28608b4c60685c7d54122d4400d90f62b40caee7700e700380a390fa8"
|
||||
|
||||
[[package]]
|
||||
name = "rustls-platform-verifier"
|
||||
version = "0.3.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "afbb878bdfdf63a336a5e63561b1835e7a8c91524f51621db870169eac84b490"
|
||||
dependencies = [
|
||||
"core-foundation",
|
||||
"core-foundation-sys",
|
||||
"jni",
|
||||
"log",
|
||||
"once_cell",
|
||||
"rustls 0.23.7",
|
||||
"rustls-native-certs 0.7.0",
|
||||
"rustls-platform-verifier-android",
|
||||
"rustls-webpki 0.102.2",
|
||||
"security-framework",
|
||||
"security-framework-sys",
|
||||
"webpki-roots 0.26.1",
|
||||
"winapi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustls-platform-verifier-android"
|
||||
version = "0.1.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f87165f0995f63a9fbeea62b64d10b4d9d8e78ec6d7d51fb2125fda7bb36788f"
|
||||
|
||||
[[package]]
|
||||
name = "rustls-webpki"
|
||||
version = "0.100.2"
|
||||
@@ -5419,22 +5542,23 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "security-framework"
|
||||
version = "2.9.1"
|
||||
version = "2.11.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1fc758eb7bffce5b308734e9b0c1468893cae9ff70ebf13e7090be8dcbcc83a8"
|
||||
checksum = "c627723fd09706bacdb5cf41499e95098555af3c3c29d014dc3c458ef6be11c0"
|
||||
dependencies = [
|
||||
"bitflags 1.3.2",
|
||||
"bitflags 2.6.0",
|
||||
"core-foundation",
|
||||
"core-foundation-sys",
|
||||
"libc",
|
||||
"num-bigint",
|
||||
"security-framework-sys",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "security-framework-sys"
|
||||
version = "2.9.0"
|
||||
version = "2.11.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f51d0c0d83bec45f16480d0ce0058397a69e48fcdc52d1dc8855fb68acbd31a7"
|
||||
checksum = "75da29fe9b9b08fe9d6b22b5b4bcbc75d8db3aa31e639aa56bb62e9d46bfceaf"
|
||||
dependencies = [
|
||||
"core-foundation-sys",
|
||||
"libc",
|
||||
|
||||
@@ -112,6 +112,9 @@ ecdsa = "0.16"
|
||||
p256 = "0.13"
|
||||
rsa = "0.9"
|
||||
|
||||
quinn = { version = "0.11", features = [] }
|
||||
rcgen.workspace = true
|
||||
|
||||
workspace_hack.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
@@ -119,7 +122,6 @@ camino-tempfile.workspace = true
|
||||
fallible-iterator.workspace = true
|
||||
tokio-tungstenite.workspace = true
|
||||
pbkdf2 = { workspace = true, features = ["simple", "std"] }
|
||||
rcgen.workspace = true
|
||||
rstest.workspace = true
|
||||
tokio-postgres-rustls.workspace = true
|
||||
walkdir.workspace = true
|
||||
|
||||
@@ -4,9 +4,9 @@ pub mod backend;
|
||||
pub use backend::Backend;
|
||||
|
||||
mod credentials;
|
||||
pub use credentials::ComputeUserInfoMaybeEndpoint;
|
||||
pub(crate) use credentials::{
|
||||
check_peer_addr_is_in_list, endpoint_sni, ComputeUserInfoMaybeEndpoint,
|
||||
ComputeUserInfoParseError, IpPattern,
|
||||
check_peer_addr_is_in_list, endpoint_sni, ComputeUserInfoParseError, IpPattern,
|
||||
};
|
||||
|
||||
mod password_hack;
|
||||
@@ -77,7 +77,7 @@ pub(crate) enum AuthErrorImpl {
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
#[error(transparent)]
|
||||
pub(crate) struct AuthError(Box<AuthErrorImpl>);
|
||||
pub struct AuthError(Box<AuthErrorImpl>);
|
||||
|
||||
impl AuthError {
|
||||
pub(crate) fn bad_auth_method(name: impl Into<Box<str>>) -> Self {
|
||||
|
||||
@@ -138,7 +138,7 @@ impl<'a, T, D, E> Backend<'a, Result<T, E>, D> {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) struct ComputeCredentials {
|
||||
pub struct ComputeCredentials {
|
||||
pub(crate) info: ComputeUserInfo,
|
||||
pub(crate) keys: ComputeCredentialKeys,
|
||||
}
|
||||
|
||||
@@ -16,7 +16,7 @@ use thiserror::Error;
|
||||
use tracing::{info, warn};
|
||||
|
||||
#[derive(Debug, Error, PartialEq, Eq, Clone)]
|
||||
pub(crate) enum ComputeUserInfoParseError {
|
||||
pub enum ComputeUserInfoParseError {
|
||||
#[error("Parameter '{0}' is missing in startup packet.")]
|
||||
MissingKey(&'static str),
|
||||
|
||||
@@ -51,10 +51,10 @@ impl ReportableError for ComputeUserInfoParseError {
|
||||
/// Various client credentials which we use for authentication.
|
||||
/// Note that we don't store any kind of client key or password here.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub(crate) struct ComputeUserInfoMaybeEndpoint {
|
||||
pub(crate) user: RoleName,
|
||||
pub(crate) endpoint_id: Option<EndpointId>,
|
||||
pub(crate) options: NeonOptions,
|
||||
pub struct ComputeUserInfoMaybeEndpoint {
|
||||
pub user: RoleName,
|
||||
pub endpoint_id: Option<EndpointId>,
|
||||
pub options: NeonOptions,
|
||||
}
|
||||
|
||||
impl ComputeUserInfoMaybeEndpoint {
|
||||
@@ -83,7 +83,7 @@ pub(crate) fn endpoint_sni(
|
||||
}
|
||||
|
||||
impl ComputeUserInfoMaybeEndpoint {
|
||||
pub(crate) fn parse(
|
||||
pub fn parse(
|
||||
ctx: &RequestMonitoring,
|
||||
params: &StartupMessageParams,
|
||||
sni: Option<&str>,
|
||||
|
||||
230
proxy/src/auth_proxy/backend.rs
Normal file
230
proxy/src/auth_proxy/backend.rs
Normal file
@@ -0,0 +1,230 @@
|
||||
mod classic;
|
||||
mod hacks;
|
||||
|
||||
use tracing::info;
|
||||
|
||||
use crate::auth::backend::{
|
||||
ComputeCredentialKeys, ComputeCredentials, ComputeUserInfo, ComputeUserInfoNoEndpoint,
|
||||
};
|
||||
use crate::auth::{self, ComputeUserInfoMaybeEndpoint};
|
||||
use crate::auth_proxy::validate_password_and_exchange;
|
||||
use crate::console::provider::ConsoleBackend;
|
||||
use crate::console::AuthSecret;
|
||||
use crate::context::RequestMonitoring;
|
||||
use crate::intern::EndpointIdInt;
|
||||
use crate::proxy::connect_compute::ComputeConnectBackend;
|
||||
use crate::scram;
|
||||
use crate::stream::AuthProxyStreamExt;
|
||||
use crate::{
|
||||
config::AuthenticationConfig,
|
||||
console::{self, provider::CachedNodeInfo, Api},
|
||||
};
|
||||
|
||||
use super::AuthProxyStream;
|
||||
|
||||
/// Alternative to [`std::borrow::Cow`] but doesn't need `T: ToOwned` as we don't need that functionality
|
||||
pub enum MaybeOwned<'a, T> {
|
||||
Owned(T),
|
||||
Borrowed(&'a T),
|
||||
}
|
||||
|
||||
impl<T> std::ops::Deref for MaybeOwned<'_, T> {
|
||||
type Target = T;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
match self {
|
||||
MaybeOwned::Owned(t) => t,
|
||||
MaybeOwned::Borrowed(t) => t,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// This type serves two purposes:
|
||||
///
|
||||
/// * When `T` is `()`, it's just a regular auth backend selector
|
||||
/// which we use in [`crate::config::ProxyConfig`].
|
||||
///
|
||||
/// * However, when we substitute `T` with [`ComputeUserInfoMaybeEndpoint`],
|
||||
/// this helps us provide the credentials only to those auth
|
||||
/// backends which require them for the authentication process.
|
||||
pub enum Backend<'a, T> {
|
||||
/// Cloud API (V2).
|
||||
Console(MaybeOwned<'a, ConsoleBackend>, T),
|
||||
}
|
||||
|
||||
impl std::fmt::Display for Backend<'_, ()> {
|
||||
fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Self::Console(api, ()) => match &**api {
|
||||
ConsoleBackend::Console(endpoint) => {
|
||||
fmt.debug_tuple("Console").field(&endpoint.url()).finish()
|
||||
}
|
||||
#[cfg(any(test, feature = "testing"))]
|
||||
ConsoleBackend::Postgres(endpoint) => {
|
||||
fmt.debug_tuple("Postgres").field(&endpoint.url()).finish()
|
||||
}
|
||||
#[cfg(test)]
|
||||
ConsoleBackend::Test(_) => fmt.debug_tuple("Test").finish(),
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Backend<'_, T> {
|
||||
/// Very similar to [`std::option::Option::as_ref`].
|
||||
/// This helps us pass structured config to async tasks.
|
||||
pub fn as_ref(&self) -> Backend<'_, &T> {
|
||||
match self {
|
||||
Self::Console(c, x) => Backend::Console(MaybeOwned::Borrowed(c), x),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T> Backend<'a, T> {
|
||||
/// Very similar to [`std::option::Option::map`].
|
||||
/// Maps [`Backend<T>`] to [`Backend<R>`] by applying
|
||||
/// a function to a contained value.
|
||||
pub fn map<R>(self, f: impl FnOnce(T) -> R) -> Backend<'a, R> {
|
||||
match self {
|
||||
Self::Console(c, x) => Backend::Console(c, f(x)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// True to its name, this function encapsulates our current auth trade-offs.
|
||||
/// Here, we choose the appropriate auth flow based on circumstances.
|
||||
///
|
||||
/// All authentication flows will emit an AuthenticationOk message if successful.
|
||||
async fn auth_quirks(
|
||||
api: &impl console::Api,
|
||||
user_info: ComputeUserInfoMaybeEndpoint,
|
||||
client: &mut AuthProxyStream,
|
||||
config: &'static AuthenticationConfig,
|
||||
) -> auth::Result<ComputeCredentials> {
|
||||
// If there's no project so far, that entails that client doesn't
|
||||
// support SNI or other means of passing the endpoint (project) name.
|
||||
// We now expect to see a very specific payload in the place of password.
|
||||
let (info) = match user_info.try_into() {
|
||||
Err(info) => {
|
||||
todo!()
|
||||
// let res = hacks::password_hack_no_authentication(info, client).await?;
|
||||
|
||||
// let password = match res.keys {
|
||||
// ComputeCredentialKeys::Password(p) => p,
|
||||
// ComputeCredentialKeys::AuthKeys(_) | ComputeCredentialKeys::None => {
|
||||
// unreachable!("password hack should return a password")
|
||||
// }
|
||||
// };
|
||||
// (res.info, Some(password))
|
||||
}
|
||||
Ok(info) => info,
|
||||
};
|
||||
|
||||
dbg!("fetching user's authentication info");
|
||||
let cached_secret = api
|
||||
.get_role_secret(&RequestMonitoring::test(), &info)
|
||||
.await?;
|
||||
|
||||
let (cached_entry, secret) = cached_secret.take_value();
|
||||
|
||||
let secret = if let Some(secret) = secret {
|
||||
secret
|
||||
} else {
|
||||
// If we don't have an authentication secret, we mock one to
|
||||
// prevent malicious probing (possible due to missing protocol steps).
|
||||
// This mocked secret will never lead to successful authentication.
|
||||
dbg!("authentication info not found, mocking it");
|
||||
AuthSecret::Scram(scram::ServerSecret::mock(rand::random()))
|
||||
};
|
||||
|
||||
match authenticate_with_secret(secret, info, client, config).await {
|
||||
Ok(keys) => Ok(keys),
|
||||
Err(e) => {
|
||||
if e.is_auth_failed() {
|
||||
// The password could have been changed, so we invalidate the cache.
|
||||
cached_entry.invalidate();
|
||||
}
|
||||
Err(e)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn authenticate_with_secret(
|
||||
secret: AuthSecret,
|
||||
info: ComputeUserInfo,
|
||||
client: &mut AuthProxyStream,
|
||||
// unauthenticated_password: Option<Vec<u8>>,
|
||||
config: &'static AuthenticationConfig,
|
||||
) -> auth::Result<ComputeCredentials> {
|
||||
// if let Some(password) = unauthenticated_password {
|
||||
// let ep = EndpointIdInt::from(&info.endpoint);
|
||||
|
||||
// let auth_outcome =
|
||||
// validate_password_and_exchange(&config.thread_pool, ep, &password, secret).await?;
|
||||
// let keys = match auth_outcome {
|
||||
// crate::sasl::Outcome::Success(key) => key,
|
||||
// crate::sasl::Outcome::Failure(reason) => {
|
||||
// info!("auth backend failed with an error: {reason}");
|
||||
// return Err(auth::AuthError::auth_failed(&*info.user));
|
||||
// }
|
||||
// };
|
||||
|
||||
// // we have authenticated the password
|
||||
// client.write_message_noflush(&pq_proto::BeMessage::AuthenticationOk)?;
|
||||
|
||||
// return Ok(ComputeCredentials { info, keys });
|
||||
// }
|
||||
|
||||
// Finally, proceed with the main auth flow (SCRAM-based).
|
||||
classic::authenticate(info, client, config, secret).await
|
||||
}
|
||||
|
||||
impl<'a> Backend<'a, ComputeUserInfoMaybeEndpoint> {
|
||||
/// Get username from the credentials.
|
||||
pub fn get_user(&self) -> &str {
|
||||
match self {
|
||||
Self::Console(_, user_info) => &user_info.user,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn authenticate(
|
||||
self,
|
||||
client: &mut AuthProxyStream,
|
||||
config: &'static AuthenticationConfig,
|
||||
) -> auth::Result<Backend<'a, ComputeCredentials>> {
|
||||
let res = match self {
|
||||
Self::Console(api, user_info) => {
|
||||
dbg!("authenticating...");
|
||||
info!(
|
||||
user = &*user_info.user,
|
||||
project = user_info.endpoint(),
|
||||
"performing authentication using the console"
|
||||
);
|
||||
|
||||
let credentials = auth_quirks(&*api, user_info, client, config).await?;
|
||||
Backend::Console(api, credentials)
|
||||
}
|
||||
};
|
||||
|
||||
dbg!("user successfully authenticated");
|
||||
Ok(res)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl ComputeConnectBackend for Backend<'_, ComputeCredentials> {
|
||||
async fn wake_compute(
|
||||
&self,
|
||||
ctx: &RequestMonitoring,
|
||||
) -> Result<CachedNodeInfo, console::errors::WakeComputeError> {
|
||||
match self {
|
||||
Self::Console(api, creds) => api.wake_compute(ctx, &creds.info).await,
|
||||
}
|
||||
}
|
||||
|
||||
fn get_keys(&self) -> &ComputeCredentialKeys {
|
||||
match self {
|
||||
Self::Console(_, creds) => &creds.keys,
|
||||
}
|
||||
}
|
||||
}
|
||||
69
proxy/src/auth_proxy/backend/classic.rs
Normal file
69
proxy/src/auth_proxy/backend/classic.rs
Normal file
@@ -0,0 +1,69 @@
|
||||
use super::{ComputeCredentials, ComputeUserInfo};
|
||||
use crate::{
|
||||
auth::{self, backend::ComputeCredentialKeys},
|
||||
auth_proxy::{self, AuthFlow, AuthProxyStream},
|
||||
compute,
|
||||
config::AuthenticationConfig,
|
||||
console::AuthSecret,
|
||||
sasl,
|
||||
};
|
||||
use tracing::{info, warn};
|
||||
|
||||
pub(super) async fn authenticate(
|
||||
creds: ComputeUserInfo,
|
||||
client: &mut AuthProxyStream,
|
||||
config: &'static AuthenticationConfig,
|
||||
secret: AuthSecret,
|
||||
) -> auth::Result<ComputeCredentials> {
|
||||
let flow = AuthFlow::new(client);
|
||||
let scram_keys = match secret {
|
||||
#[cfg(any(test, feature = "testing"))]
|
||||
AuthSecret::Md5(_) => {
|
||||
info!("auth endpoint chooses MD5");
|
||||
return Err(auth::AuthError::bad_auth_method("MD5"));
|
||||
}
|
||||
AuthSecret::Scram(secret) => {
|
||||
dbg!("auth endpoint chooses SCRAM");
|
||||
let scram = auth_proxy::Scram(&secret);
|
||||
|
||||
let auth_outcome = tokio::time::timeout(
|
||||
config.scram_protocol_timeout,
|
||||
async {
|
||||
|
||||
flow.begin(scram).await.map_err(|error| {
|
||||
warn!(?error, "error sending scram acknowledgement");
|
||||
error
|
||||
})?.authenticate().await.map_err(|error| {
|
||||
warn!(?error, "error processing scram messages");
|
||||
error
|
||||
})
|
||||
}
|
||||
)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
warn!("error processing scram messages error = authentication timed out, execution time exceeded {} seconds", config.scram_protocol_timeout.as_secs());
|
||||
auth::AuthError::user_timeout(e)
|
||||
})??;
|
||||
|
||||
let client_key = match auth_outcome {
|
||||
sasl::Outcome::Success(key) => key,
|
||||
sasl::Outcome::Failure(reason) => {
|
||||
info!("auth backend failed with an error: {reason}");
|
||||
return Err(auth::AuthError::auth_failed(&*creds.user));
|
||||
}
|
||||
};
|
||||
|
||||
compute::ScramKeys {
|
||||
client_key: client_key.as_bytes(),
|
||||
server_key: secret.server_key.as_bytes(),
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
Ok(ComputeCredentials {
|
||||
info: creds,
|
||||
keys: ComputeCredentialKeys::AuthKeys(tokio_postgres::config::AuthKeys::ScramSha256(
|
||||
scram_keys,
|
||||
)),
|
||||
})
|
||||
}
|
||||
36
proxy/src/auth_proxy/backend/hacks.rs
Normal file
36
proxy/src/auth_proxy/backend/hacks.rs
Normal file
@@ -0,0 +1,36 @@
|
||||
use super::{
|
||||
ComputeCredentialKeys, ComputeCredentials, ComputeUserInfo, ComputeUserInfoNoEndpoint,
|
||||
};
|
||||
use crate::{
|
||||
auth,
|
||||
auth_proxy::{self, AuthFlow, AuthProxyStream},
|
||||
};
|
||||
use tracing::{info, warn};
|
||||
|
||||
/// Workaround for clients which don't provide an endpoint (project) name.
|
||||
/// Similar to [`authenticate_cleartext`], but there's a specific password format,
|
||||
/// and passwords are not yet validated (we don't know how to validate them!)
|
||||
pub(crate) async fn password_hack_no_authentication(
|
||||
info: ComputeUserInfoNoEndpoint,
|
||||
client: &mut AuthProxyStream,
|
||||
) -> auth::Result<ComputeCredentials> {
|
||||
warn!("project not specified, resorting to the password hack auth flow");
|
||||
|
||||
let payload = AuthFlow::new(client)
|
||||
.begin(auth_proxy::PasswordHack)
|
||||
.await?
|
||||
.get_password()
|
||||
.await?;
|
||||
|
||||
info!(project = &*payload.endpoint, "received missing parameter");
|
||||
|
||||
// Report tentative success; compute node will check the password anyway.
|
||||
Ok(ComputeCredentials {
|
||||
info: ComputeUserInfo {
|
||||
user: info.user,
|
||||
options: info.options,
|
||||
endpoint: payload.endpoint,
|
||||
},
|
||||
keys: ComputeCredentialKeys::Password(payload.password),
|
||||
})
|
||||
}
|
||||
183
proxy/src/auth_proxy/flow.rs
Normal file
183
proxy/src/auth_proxy/flow.rs
Normal file
@@ -0,0 +1,183 @@
|
||||
//! Main authentication flow.
|
||||
|
||||
use super::{AuthProxyStream, PasswordHackPayload};
|
||||
use crate::{
|
||||
auth::{self, backend::ComputeCredentialKeys, AuthErrorImpl},
|
||||
config::TlsServerEndPoint,
|
||||
console::AuthSecret,
|
||||
intern::EndpointIdInt,
|
||||
sasl,
|
||||
scram::{self, threadpool::ThreadPool},
|
||||
stream::AuthProxyStreamExt,
|
||||
};
|
||||
use pq_proto::{BeAuthenticationSaslMessage, BeMessage, BeMessage as Be};
|
||||
use std::io;
|
||||
use tokio::task_local;
|
||||
use tracing::info;
|
||||
|
||||
/// Every authentication selector is supposed to implement this trait.
|
||||
pub(crate) trait AuthMethod {
|
||||
/// Any authentication selector should provide initial backend message
|
||||
/// containing auth method name and parameters, e.g. md5 salt.
|
||||
fn first_message(&self, channel_binding: bool) -> BeMessage<'_>;
|
||||
}
|
||||
|
||||
/// Initial state of [`AuthFlow`].
|
||||
pub(crate) struct Begin;
|
||||
|
||||
/// Use [SCRAM](crate::scram)-based auth in [`AuthFlow`].
|
||||
pub(crate) struct Scram<'a>(pub(crate) &'a scram::ServerSecret);
|
||||
|
||||
impl AuthMethod for Scram<'_> {
|
||||
#[inline(always)]
|
||||
fn first_message(&self, channel_binding: bool) -> BeMessage<'_> {
|
||||
if channel_binding {
|
||||
Be::AuthenticationSasl(BeAuthenticationSaslMessage::Methods(scram::METHODS))
|
||||
} else {
|
||||
Be::AuthenticationSasl(BeAuthenticationSaslMessage::Methods(
|
||||
scram::METHODS_WITHOUT_PLUS,
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Use an ad hoc auth flow (for clients which don't support SNI) proposed in
|
||||
/// <https://github.com/neondatabase/cloud/issues/1620#issuecomment-1165332290>.
|
||||
pub(crate) struct PasswordHack;
|
||||
|
||||
impl AuthMethod for PasswordHack {
|
||||
#[inline(always)]
|
||||
fn first_message(&self, _channel_binding: bool) -> BeMessage<'_> {
|
||||
Be::AuthenticationCleartextPassword
|
||||
}
|
||||
}
|
||||
|
||||
/// This wrapper for [`PqStream`] performs client authentication.
|
||||
#[must_use]
|
||||
pub(crate) struct AuthFlow<'a, State> {
|
||||
/// The underlying stream which implements libpq's protocol.
|
||||
stream: &'a mut AuthProxyStream,
|
||||
/// State might contain ancillary data (see [`Self::begin`]).
|
||||
state: State,
|
||||
tls_server_end_point: TlsServerEndPoint,
|
||||
}
|
||||
|
||||
task_local! {
|
||||
pub(crate) static TLS_SERVER_END_POINT: TlsServerEndPoint;
|
||||
}
|
||||
|
||||
/// Initial state of the stream wrapper.
|
||||
impl<'a> AuthFlow<'a, Begin> {
|
||||
/// Create a new wrapper for client authentication.
|
||||
pub(crate) fn new(stream: &'a mut AuthProxyStream) -> Self {
|
||||
let tls_server_end_point = TLS_SERVER_END_POINT.get();
|
||||
|
||||
Self {
|
||||
stream,
|
||||
state: Begin,
|
||||
tls_server_end_point,
|
||||
}
|
||||
}
|
||||
|
||||
/// Move to the next step by sending auth method's name & params to client.
|
||||
pub(crate) async fn begin<M: AuthMethod>(self, method: M) -> io::Result<AuthFlow<'a, M>> {
|
||||
dbg!("sending auth begin message");
|
||||
self.stream
|
||||
.write_message(&method.first_message(self.tls_server_end_point.supported()))
|
||||
.await?;
|
||||
|
||||
Ok(AuthFlow {
|
||||
stream: self.stream,
|
||||
state: method,
|
||||
tls_server_end_point: self.tls_server_end_point,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl AuthFlow<'_, PasswordHack> {
|
||||
/// Perform user authentication. Raise an error in case authentication failed.
|
||||
pub(crate) async fn get_password(self) -> auth::Result<PasswordHackPayload> {
|
||||
let msg = self.stream.read_password_message().await?;
|
||||
let password = msg
|
||||
.strip_suffix(&[0])
|
||||
.ok_or(AuthErrorImpl::MalformedPassword("missing terminator"))?;
|
||||
|
||||
let payload = PasswordHackPayload::parse(password)
|
||||
// If we ended up here and the payload is malformed, it means that
|
||||
// the user neither enabled SNI nor resorted to any other method
|
||||
// for passing the project name we rely on. We should show them
|
||||
// the most helpful error message and point to the documentation.
|
||||
.ok_or(AuthErrorImpl::MissingEndpointName)?;
|
||||
|
||||
Ok(payload)
|
||||
}
|
||||
}
|
||||
|
||||
/// Stream wrapper for handling [SCRAM](crate::scram) auth.
|
||||
impl AuthFlow<'_, Scram<'_>> {
|
||||
/// Perform user authentication. Raise an error in case authentication failed.
|
||||
pub(crate) async fn authenticate(self) -> auth::Result<sasl::Outcome<scram::ScramKey>> {
|
||||
let Scram(secret) = self.state;
|
||||
|
||||
// Initial client message contains the chosen auth method's name.
|
||||
let msg = self.stream.read_password_message().await?;
|
||||
let sasl = sasl::FirstMessage::parse(&msg)
|
||||
.ok_or(AuthErrorImpl::MalformedPassword("bad sasl message"))?;
|
||||
|
||||
// Currently, the only supported SASL method is SCRAM.
|
||||
if !scram::METHODS.contains(&sasl.method) {
|
||||
return Err(auth::AuthError::bad_auth_method(sasl.method));
|
||||
}
|
||||
|
||||
info!("client chooses {}", sasl.method);
|
||||
|
||||
let outcome = sasl::SaslStream2::new(self.stream, sasl.message)
|
||||
.authenticate(scram::Exchange::new(
|
||||
secret,
|
||||
rand::random,
|
||||
self.tls_server_end_point,
|
||||
))
|
||||
.await?;
|
||||
|
||||
if let sasl::Outcome::Success(_) = &outcome {
|
||||
self.stream.write_message_noflush(&Be::AuthenticationOk)?;
|
||||
}
|
||||
|
||||
Ok(outcome)
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn validate_password_and_exchange(
|
||||
pool: &ThreadPool,
|
||||
endpoint: EndpointIdInt,
|
||||
password: &[u8],
|
||||
secret: AuthSecret,
|
||||
) -> auth::Result<sasl::Outcome<ComputeCredentialKeys>> {
|
||||
match secret {
|
||||
#[cfg(any(test, feature = "testing"))]
|
||||
AuthSecret::Md5(_) => {
|
||||
// test only
|
||||
Ok(sasl::Outcome::Success(ComputeCredentialKeys::Password(
|
||||
password.to_owned(),
|
||||
)))
|
||||
}
|
||||
// perform scram authentication as both client and server to validate the keys
|
||||
AuthSecret::Scram(scram_secret) => {
|
||||
let outcome = crate::scram::exchange(pool, endpoint, &scram_secret, password).await?;
|
||||
|
||||
let client_key = match outcome {
|
||||
sasl::Outcome::Success(client_key) => client_key,
|
||||
sasl::Outcome::Failure(reason) => return Ok(sasl::Outcome::Failure(reason)),
|
||||
};
|
||||
|
||||
let keys = crate::compute::ScramKeys {
|
||||
client_key: client_key.as_bytes(),
|
||||
server_key: scram_secret.server_key.as_bytes(),
|
||||
};
|
||||
|
||||
Ok(sasl::Outcome::Success(ComputeCredentialKeys::AuthKeys(
|
||||
tokio_postgres::config::AuthKeys::ScramSha256(keys),
|
||||
)))
|
||||
}
|
||||
}
|
||||
}
|
||||
17
proxy/src/auth_proxy/mod.rs
Normal file
17
proxy/src/auth_proxy/mod.rs
Normal file
@@ -0,0 +1,17 @@
|
||||
//! Client authentication mechanisms.
|
||||
|
||||
pub mod backend;
|
||||
pub use backend::Backend;
|
||||
|
||||
mod password_hack;
|
||||
use password_hack::PasswordHackPayload;
|
||||
|
||||
mod flow;
|
||||
pub(crate) use flow::*;
|
||||
use quinn::{RecvStream, SendStream};
|
||||
use tokio::io::Join;
|
||||
use tokio_util::codec::Framed;
|
||||
|
||||
use crate::PglbCodec;
|
||||
|
||||
pub type AuthProxyStream = Framed<Join<RecvStream, SendStream>, PglbCodec>;
|
||||
121
proxy/src/auth_proxy/password_hack.rs
Normal file
121
proxy/src/auth_proxy/password_hack.rs
Normal file
@@ -0,0 +1,121 @@
|
||||
//! Payload for ad hoc authentication method for clients that don't support SNI.
|
||||
//! See the `impl` for [`super::backend::Backend<ClientCredentials>`].
|
||||
//! Read more: <https://github.com/neondatabase/cloud/issues/1620#issuecomment-1165332290>.
|
||||
//! UPDATE (Mon Aug 8 13:20:34 UTC 2022): the payload format has been simplified.
|
||||
|
||||
use bstr::ByteSlice;
|
||||
|
||||
use crate::EndpointId;
|
||||
|
||||
pub(crate) struct PasswordHackPayload {
|
||||
pub(crate) endpoint: EndpointId,
|
||||
pub(crate) password: Vec<u8>,
|
||||
}
|
||||
|
||||
impl PasswordHackPayload {
|
||||
pub(crate) fn parse(bytes: &[u8]) -> Option<Self> {
|
||||
// The format is `project=<utf-8>;<password-bytes>` or `project=<utf-8>$<password-bytes>`.
|
||||
let separators = [";", "$"];
|
||||
for sep in separators {
|
||||
if let Some((endpoint, password)) = bytes.split_once_str(sep) {
|
||||
let endpoint = endpoint.to_str().ok()?;
|
||||
return Some(Self {
|
||||
endpoint: parse_endpoint_param(endpoint)?.into(),
|
||||
password: password.to_owned(),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn parse_endpoint_param(bytes: &str) -> Option<&str> {
|
||||
bytes
|
||||
.strip_prefix("project=")
|
||||
.or_else(|| bytes.strip_prefix("endpoint="))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn parse_endpoint_param_fn() {
|
||||
let input = "";
|
||||
assert!(parse_endpoint_param(input).is_none());
|
||||
|
||||
let input = "project=";
|
||||
assert_eq!(parse_endpoint_param(input), Some(""));
|
||||
|
||||
let input = "project=foobar";
|
||||
assert_eq!(parse_endpoint_param(input), Some("foobar"));
|
||||
|
||||
let input = "endpoint=";
|
||||
assert_eq!(parse_endpoint_param(input), Some(""));
|
||||
|
||||
let input = "endpoint=foobar";
|
||||
assert_eq!(parse_endpoint_param(input), Some("foobar"));
|
||||
|
||||
let input = "other_option=foobar";
|
||||
assert!(parse_endpoint_param(input).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_password_hack_payload_project() {
|
||||
let bytes = b"";
|
||||
assert!(PasswordHackPayload::parse(bytes).is_none());
|
||||
|
||||
let bytes = b"project=";
|
||||
assert!(PasswordHackPayload::parse(bytes).is_none());
|
||||
|
||||
let bytes = b"project=;";
|
||||
let payload: PasswordHackPayload =
|
||||
PasswordHackPayload::parse(bytes).expect("parsing failed");
|
||||
assert_eq!(payload.endpoint, "");
|
||||
assert_eq!(payload.password, b"");
|
||||
|
||||
let bytes = b"project=foobar;pass;word";
|
||||
let payload = PasswordHackPayload::parse(bytes).expect("parsing failed");
|
||||
assert_eq!(payload.endpoint, "foobar");
|
||||
assert_eq!(payload.password, b"pass;word");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_password_hack_payload_endpoint() {
|
||||
let bytes = b"";
|
||||
assert!(PasswordHackPayload::parse(bytes).is_none());
|
||||
|
||||
let bytes = b"endpoint=";
|
||||
assert!(PasswordHackPayload::parse(bytes).is_none());
|
||||
|
||||
let bytes = b"endpoint=;";
|
||||
let payload = PasswordHackPayload::parse(bytes).expect("parsing failed");
|
||||
assert_eq!(payload.endpoint, "");
|
||||
assert_eq!(payload.password, b"");
|
||||
|
||||
let bytes = b"endpoint=foobar;pass;word";
|
||||
let payload = PasswordHackPayload::parse(bytes).expect("parsing failed");
|
||||
assert_eq!(payload.endpoint, "foobar");
|
||||
assert_eq!(payload.password, b"pass;word");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_password_hack_payload_dollar() {
|
||||
let bytes = b"";
|
||||
assert!(PasswordHackPayload::parse(bytes).is_none());
|
||||
|
||||
let bytes = b"endpoint=";
|
||||
assert!(PasswordHackPayload::parse(bytes).is_none());
|
||||
|
||||
let bytes = b"endpoint=$";
|
||||
let payload = PasswordHackPayload::parse(bytes).expect("parsing failed");
|
||||
assert_eq!(payload.endpoint, "");
|
||||
assert_eq!(payload.password, b"");
|
||||
|
||||
let bytes = b"endpoint=foobar$pass$word";
|
||||
let payload = PasswordHackPayload::parse(bytes).expect("parsing failed");
|
||||
assert_eq!(payload.endpoint, "foobar");
|
||||
assert_eq!(payload.password, b"pass$word");
|
||||
}
|
||||
}
|
||||
242
proxy/src/bin/auth_proxy.rs
Normal file
242
proxy/src/bin/auth_proxy.rs
Normal file
@@ -0,0 +1,242 @@
|
||||
use std::{sync::Arc, time::Duration};
|
||||
|
||||
use clap::Parser;
|
||||
use proxy::{
|
||||
auth::backend::AuthRateLimiter,
|
||||
auth_proxy::{backend::MaybeOwned, Backend},
|
||||
config::{self, AuthenticationConfig, CacheOptions, ProjectInfoCacheOptions},
|
||||
console::{
|
||||
caches::ApiCaches,
|
||||
locks::ApiLocks,
|
||||
provider::{neon::Api, ConsoleBackend},
|
||||
},
|
||||
http,
|
||||
metrics::Metrics,
|
||||
proxy::{handle_stream, AuthProxyConfig},
|
||||
rate_limiter::{RateBucketInfo, WakeComputeRateLimiter},
|
||||
scram::threadpool::ThreadPool,
|
||||
};
|
||||
use quinn::{crypto::rustls::QuicClientConfig, rustls::client::danger, Endpoint, VarInt};
|
||||
use tokio::{
|
||||
io::AsyncWriteExt,
|
||||
select,
|
||||
signal::unix::{signal, SignalKind},
|
||||
time::interval,
|
||||
};
|
||||
use tokio_util::task::TaskTracker;
|
||||
|
||||
/// Neon proxy/router
|
||||
#[derive(Parser)]
|
||||
#[command(about)]
|
||||
struct ProxyCliArgs {
|
||||
/// cloud API endpoint for authenticating users
|
||||
#[clap(
|
||||
short,
|
||||
long,
|
||||
default_value = "http://localhost:3000/authenticate_proxy_request"
|
||||
)]
|
||||
auth_endpoint: String,
|
||||
/// timeout for the TLS handshake
|
||||
#[clap(long, default_value = "15s", value_parser = humantime::parse_duration)]
|
||||
handshake_timeout: tokio::time::Duration,
|
||||
/// cache for `wake_compute` api method (use `size=0` to disable)
|
||||
#[clap(long, default_value = config::CacheOptions::CACHE_DEFAULT_OPTIONS)]
|
||||
wake_compute_cache: String,
|
||||
/// lock for `wake_compute` api method. example: "shards=32,permits=4,epoch=10m,timeout=1s". (use `permits=0` to disable).
|
||||
#[clap(long, default_value = config::ConcurrencyLockOptions::DEFAULT_OPTIONS_WAKE_COMPUTE_LOCK)]
|
||||
wake_compute_lock: String,
|
||||
/// timeout for scram authentication protocol
|
||||
#[clap(long, default_value = "15s", value_parser = humantime::parse_duration)]
|
||||
scram_protocol_timeout: tokio::time::Duration,
|
||||
/// size of the threadpool for password hashing
|
||||
#[clap(long, default_value_t = 4)]
|
||||
scram_thread_pool_size: u8,
|
||||
/// Disable dynamic rate limiter and store the metrics to ensure its production behaviour.
|
||||
#[clap(long, default_value_t = true, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)]
|
||||
disable_dynamic_rate_limiter: bool,
|
||||
/// Endpoint rate limiter max number of requests per second.
|
||||
///
|
||||
/// Provided in the form `<Requests Per Second>@<Bucket Duration Size>`.
|
||||
/// Can be given multiple times for different bucket sizes.
|
||||
#[clap(long, default_values_t = RateBucketInfo::DEFAULT_ENDPOINT_SET)]
|
||||
endpoint_rps_limit: Vec<RateBucketInfo>,
|
||||
/// Wake compute rate limiter max number of requests per second.
|
||||
#[clap(long, default_values_t = RateBucketInfo::DEFAULT_SET)]
|
||||
wake_compute_limit: Vec<RateBucketInfo>,
|
||||
/// Whether the auth rate limiter actually takes effect (for testing)
|
||||
#[clap(long, default_value_t = false, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)]
|
||||
auth_rate_limit_enabled: bool,
|
||||
/// Authentication rate limiter max number of hashes per second.
|
||||
#[clap(long, default_values_t = RateBucketInfo::DEFAULT_AUTH_SET)]
|
||||
auth_rate_limit: Vec<RateBucketInfo>,
|
||||
/// The IP subnet to use when considering whether two IP addresses are considered the same.
|
||||
#[clap(long, default_value_t = 64)]
|
||||
auth_rate_limit_ip_subnet: u8,
|
||||
/// cache for `allowed_ips` (use `size=0` to disable)
|
||||
#[clap(long, default_value = config::CacheOptions::CACHE_DEFAULT_OPTIONS)]
|
||||
allowed_ips_cache: String,
|
||||
/// cache for `role_secret` (use `size=0` to disable)
|
||||
#[clap(long, default_value = config::CacheOptions::CACHE_DEFAULT_OPTIONS)]
|
||||
role_secret_cache: String,
|
||||
/// cache for `project_info` (use `size=0` to disable)
|
||||
#[clap(long, default_value = config::ProjectInfoCacheOptions::CACHE_DEFAULT_OPTIONS)]
|
||||
project_info_cache: String,
|
||||
/// cache for all valid endpoints
|
||||
#[clap(long, default_value = config::EndpointCacheConfig::CACHE_DEFAULT_OPTIONS)]
|
||||
endpoint_cache_config: String,
|
||||
|
||||
/// Whether to retry the wake_compute request
|
||||
#[clap(long, default_value = config::RetryConfig::WAKE_COMPUTE_DEFAULT_VALUES)]
|
||||
wake_compute_retry: String,
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() {
|
||||
let args = ProxyCliArgs::parse();
|
||||
|
||||
let server = "127.0.0.1:5634".parse().unwrap();
|
||||
let mut endpoint = Endpoint::client("0.0.0.0:0".parse().unwrap()).unwrap();
|
||||
|
||||
let crypto = quinn::rustls::ClientConfig::builder()
|
||||
.dangerous()
|
||||
.with_custom_certificate_verifier(Arc::new(NoVerify))
|
||||
.with_no_client_auth();
|
||||
|
||||
let crypto = QuicClientConfig::try_from(crypto).unwrap();
|
||||
|
||||
let config = quinn::ClientConfig::new(Arc::new(crypto));
|
||||
endpoint.set_default_client_config(config);
|
||||
|
||||
let mut int = signal(SignalKind::interrupt()).unwrap();
|
||||
let mut term = signal(SignalKind::terminate()).unwrap();
|
||||
|
||||
let conn = endpoint.connect(server, "pglb").unwrap().await.unwrap();
|
||||
let mut interval = interval(Duration::from_secs(2));
|
||||
|
||||
let tasks = TaskTracker::new();
|
||||
|
||||
let thread_pool = ThreadPool::new(args.scram_thread_pool_size);
|
||||
Metrics::install(thread_pool.metrics.clone());
|
||||
|
||||
let backend = {
|
||||
let wake_compute_cache_config: CacheOptions = args.wake_compute_cache.parse().unwrap();
|
||||
let project_info_cache_config: ProjectInfoCacheOptions =
|
||||
args.project_info_cache.parse().unwrap();
|
||||
let endpoint_cache_config: config::EndpointCacheConfig =
|
||||
args.endpoint_cache_config.parse().unwrap();
|
||||
|
||||
let caches = Box::leak(Box::new(ApiCaches::new(
|
||||
wake_compute_cache_config,
|
||||
project_info_cache_config,
|
||||
endpoint_cache_config,
|
||||
)));
|
||||
|
||||
let config::ConcurrencyLockOptions {
|
||||
shards,
|
||||
limiter,
|
||||
epoch,
|
||||
timeout,
|
||||
} = args.wake_compute_lock.parse().unwrap();
|
||||
let locks = Box::leak(Box::new(
|
||||
ApiLocks::new(
|
||||
"wake_compute_lock",
|
||||
limiter,
|
||||
shards,
|
||||
timeout,
|
||||
epoch,
|
||||
&Metrics::get().wake_compute_lock,
|
||||
)
|
||||
.unwrap(),
|
||||
));
|
||||
tokio::spawn(locks.garbage_collect_worker());
|
||||
|
||||
let url = args.auth_endpoint.parse().unwrap();
|
||||
let endpoint = http::Endpoint::new(url, http::new_client());
|
||||
|
||||
let mut wake_compute_rps_limit = args.wake_compute_limit.clone();
|
||||
RateBucketInfo::validate(&mut wake_compute_rps_limit).unwrap();
|
||||
let wake_compute_endpoint_rate_limiter =
|
||||
Arc::new(WakeComputeRateLimiter::new(wake_compute_rps_limit));
|
||||
let api = Api::new(endpoint, caches, locks, wake_compute_endpoint_rate_limiter);
|
||||
let api = ConsoleBackend::Console(api);
|
||||
Backend::Console(MaybeOwned::Owned(api), ())
|
||||
};
|
||||
|
||||
let auth = AuthenticationConfig {
|
||||
thread_pool,
|
||||
scram_protocol_timeout: args.scram_protocol_timeout,
|
||||
rate_limiter_enabled: args.auth_rate_limit_enabled,
|
||||
rate_limiter: AuthRateLimiter::new(args.auth_rate_limit.clone()),
|
||||
rate_limit_ip_subnet: args.auth_rate_limit_ip_subnet,
|
||||
};
|
||||
|
||||
let config = &*Box::leak(Box::new(AuthProxyConfig { backend, auth }));
|
||||
|
||||
loop {
|
||||
select! {
|
||||
_ = int.recv() => break,
|
||||
_ = term.recv() => break,
|
||||
_ = interval.tick() => {
|
||||
let mut stream = conn.open_uni().await.unwrap();
|
||||
stream.flush().await.unwrap();
|
||||
stream.finish().unwrap();
|
||||
}
|
||||
stream = conn.accept_bi() => {
|
||||
let (send, recv) = stream.unwrap();
|
||||
tasks.spawn(async move {
|
||||
handle_stream(config, send, recv).await.inspect_err(|e| println!("err {e:?}"))
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// graceful shutdown
|
||||
{
|
||||
let mut stream = conn.open_uni().await.unwrap();
|
||||
stream.write_all(b"shutdown").await.unwrap();
|
||||
stream.flush().await.unwrap();
|
||||
stream.finish().unwrap();
|
||||
}
|
||||
|
||||
tasks.close();
|
||||
tasks.wait().await;
|
||||
conn.close(VarInt::from_u32(1), b"graceful shutdown");
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug)]
|
||||
struct NoVerify;
|
||||
|
||||
impl danger::ServerCertVerifier for NoVerify {
|
||||
fn verify_server_cert(
|
||||
&self,
|
||||
_end_entity: &rustls::pki_types::CertificateDer<'_>,
|
||||
_intermediates: &[rustls::pki_types::CertificateDer<'_>],
|
||||
_server_name: &rustls::pki_types::ServerName<'_>,
|
||||
_ocsp_response: &[u8],
|
||||
_now: rustls::pki_types::UnixTime,
|
||||
) -> Result<danger::ServerCertVerified, quinn::rustls::Error> {
|
||||
Ok(danger::ServerCertVerified::assertion())
|
||||
}
|
||||
|
||||
fn verify_tls12_signature(
|
||||
&self,
|
||||
_message: &[u8],
|
||||
_cert: &rustls::pki_types::CertificateDer<'_>,
|
||||
_dss: &quinn::rustls::DigitallySignedStruct,
|
||||
) -> Result<danger::HandshakeSignatureValid, quinn::rustls::Error> {
|
||||
Ok(danger::HandshakeSignatureValid::assertion())
|
||||
}
|
||||
|
||||
fn verify_tls13_signature(
|
||||
&self,
|
||||
_message: &[u8],
|
||||
_cert: &rustls::pki_types::CertificateDer<'_>,
|
||||
_dss: &quinn::rustls::DigitallySignedStruct,
|
||||
) -> Result<danger::HandshakeSignatureValid, quinn::rustls::Error> {
|
||||
Ok(danger::HandshakeSignatureValid::assertion())
|
||||
}
|
||||
|
||||
fn supported_verify_schemes(&self) -> Vec<quinn::rustls::SignatureScheme> {
|
||||
vec![quinn::rustls::SignatureScheme::ECDSA_NISTP256_SHA256]
|
||||
}
|
||||
}
|
||||
653
proxy/src/bin/pglb.rs
Normal file
653
proxy/src/bin/pglb.rs
Normal file
@@ -0,0 +1,653 @@
|
||||
use std::{
|
||||
convert::Infallible,
|
||||
net::SocketAddr,
|
||||
sync::{Arc, Mutex},
|
||||
time::Duration,
|
||||
};
|
||||
|
||||
use anyhow::{anyhow, bail, Context, Result};
|
||||
use bytes::{Buf, BufMut, BytesMut};
|
||||
use futures::{SinkExt, StreamExt};
|
||||
use indexmap::IndexMap;
|
||||
use itertools::Itertools;
|
||||
use pq_proto::BeMessage;
|
||||
use proxy::{
|
||||
config::{CertResolver, TlsServerEndPoint, PG_ALPN_PROTOCOL},
|
||||
ConnectionInitiatedPayload, PglbCodec, PglbControlMessage, PglbMessage,
|
||||
};
|
||||
use quinn::{Connection, Endpoint, RecvStream, SendStream};
|
||||
use rand::Rng;
|
||||
use rustls::pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer};
|
||||
use tokio::{
|
||||
io::{copy_bidirectional, join, AsyncReadExt, AsyncWriteExt, Join},
|
||||
net::{TcpListener, TcpStream},
|
||||
select,
|
||||
time::timeout,
|
||||
};
|
||||
use tokio_rustls::server::TlsStream;
|
||||
use tokio_util::codec::Framed;
|
||||
use tracing::{error, warn};
|
||||
|
||||
type AuthConnId = usize;
|
||||
|
||||
#[derive(Debug)]
|
||||
struct AuthConnState {
|
||||
conns: Mutex<IndexMap<AuthConnId, AuthConn>>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
struct AuthConn {
|
||||
conn: Connection,
|
||||
// latency info...
|
||||
}
|
||||
|
||||
#[global_allocator]
|
||||
static GLOBAL: tikv_jemallocator::Jemalloc = tikv_jemallocator::Jemalloc;
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<()> {
|
||||
let _logging_guard = proxy::logging::init().await?;
|
||||
|
||||
let auth_endpoint: Endpoint = endpoint_config("0.0.0.0:5634".parse()?).await?;
|
||||
|
||||
let auth_connections = Arc::new(AuthConnState {
|
||||
conns: Mutex::new(IndexMap::new()),
|
||||
});
|
||||
|
||||
let quinn_handle = tokio::spawn(quinn_server(auth_endpoint, auth_connections.clone()));
|
||||
|
||||
let frontend_config = frontent_tls_config()?;
|
||||
|
||||
let _frontend_handle = tokio::spawn(start_frontend(
|
||||
"0.0.0.0:5432".parse()?,
|
||||
frontend_config,
|
||||
auth_connections.clone(),
|
||||
));
|
||||
|
||||
quinn_handle.await.unwrap();
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn endpoint_config(addr: SocketAddr) -> Result<Endpoint> {
|
||||
let mut params = rcgen::CertificateParams::new(vec!["pglb".to_string()]);
|
||||
params
|
||||
.distinguished_name
|
||||
.push(rcgen::DnType::CommonName, "pglb");
|
||||
let key = rcgen::KeyPair::generate(&rcgen::PKCS_ECDSA_P256_SHA256).context("keygen")?;
|
||||
params.key_pair = Some(key);
|
||||
|
||||
let cert = rcgen::Certificate::from_params(params).context("cert")?;
|
||||
let cert_der = cert.serialize_der().context("serialize")?;
|
||||
let key_der = cert.serialize_private_key_der();
|
||||
let cert = CertificateDer::from(cert_der);
|
||||
let key = PrivateKeyDer::Pkcs8(PrivatePkcs8KeyDer::from(key_der));
|
||||
|
||||
let config = quinn::ServerConfig::with_single_cert(vec![cert], key).context("server config")?;
|
||||
Endpoint::server(config, addr).context("endpoint")
|
||||
}
|
||||
|
||||
async fn quinn_server(ep: Endpoint, state: Arc<AuthConnState>) {
|
||||
loop {
|
||||
let incoming = ep.accept().await.expect("quinn server should not crash");
|
||||
let state = state.clone();
|
||||
tokio::spawn(async move {
|
||||
let conn = incoming.await.unwrap();
|
||||
|
||||
let conn_id = conn.stable_id();
|
||||
println!("[{conn_id:?}] new conn");
|
||||
|
||||
state
|
||||
.conns
|
||||
.lock()
|
||||
.unwrap()
|
||||
.insert(conn_id, AuthConn { conn: conn.clone() });
|
||||
|
||||
// heartbeat loop
|
||||
loop {
|
||||
match timeout(Duration::from_secs(10), conn.accept_uni()).await {
|
||||
Ok(Ok(mut heartbeat_stream)) => {
|
||||
let data = heartbeat_stream.read_to_end(128).await.unwrap();
|
||||
if data.starts_with(b"shutdown") {
|
||||
println!("[{conn_id:?}] conn shutdown");
|
||||
break;
|
||||
}
|
||||
// else update latency info
|
||||
}
|
||||
Ok(Err(conn_err)) => {
|
||||
println!("[{conn_id:?}] conn err {conn_err:?}");
|
||||
break;
|
||||
}
|
||||
Err(_) => {
|
||||
println!("[{conn_id:?}] conn timeout err");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
state.conns.lock().unwrap().remove(&conn_id);
|
||||
let conn_closed = conn.closed().await;
|
||||
println!("[{conn_id:?}] conn closed {conn_closed:?}");
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
fn frontent_tls_config() -> Result<TlsConfig> {
|
||||
let (cert, key) = (
|
||||
rustls_pemfile::certs(&mut &*std::fs::read("proxy.crt").unwrap())
|
||||
.collect_vec()
|
||||
.remove(0)
|
||||
.unwrap(),
|
||||
PrivateKeyDer::Pkcs8(
|
||||
rustls_pemfile::pkcs8_private_keys(&mut &*std::fs::read("proxy.key").unwrap())
|
||||
.collect_vec()
|
||||
.remove(0)
|
||||
.unwrap(),
|
||||
),
|
||||
);
|
||||
|
||||
let config = rustls::ServerConfig::builder()
|
||||
.with_no_client_auth()
|
||||
.with_single_cert(vec![cert.clone()], key.clone_key())?
|
||||
.into();
|
||||
|
||||
let mut cert_resolver = CertResolver::new();
|
||||
cert_resolver.add_cert(key, vec![cert], true)?;
|
||||
|
||||
Ok(TlsConfig {
|
||||
config,
|
||||
cert_resolver: Arc::new(cert_resolver),
|
||||
})
|
||||
}
|
||||
|
||||
async fn start_frontend(
|
||||
addr: SocketAddr,
|
||||
tls: TlsConfig,
|
||||
state: Arc<AuthConnState>,
|
||||
) -> Result<Infallible> {
|
||||
let listener = TcpListener::bind(addr).await?;
|
||||
socket2::SockRef::from(&listener).set_keepalive(true)?;
|
||||
|
||||
println!("starting");
|
||||
|
||||
let connections = tokio_util::task::task_tracker::TaskTracker::new();
|
||||
|
||||
loop {
|
||||
match listener.accept().await {
|
||||
Ok((socket, client_addr)) => {
|
||||
println!("accepted");
|
||||
let conn = PglbConn::new(&state, &tls)?;
|
||||
connections.spawn(conn.handle(socket, client_addr));
|
||||
}
|
||||
Err(e) => {
|
||||
error!("connection accept error: {e}");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
struct TlsConfig {
|
||||
config: Arc<rustls::ServerConfig>,
|
||||
cert_resolver: Arc<CertResolver>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct PglbConn<S: PglbConnState> {
|
||||
inner: PglbConnInner,
|
||||
state: S,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct PglbConnInner {
|
||||
tls_config: TlsConfig,
|
||||
auth_conns: Arc<AuthConnState>,
|
||||
}
|
||||
|
||||
trait PglbConnState: std::fmt::Debug {}
|
||||
impl PglbConnState for Start {}
|
||||
impl PglbConnState for ClientConnect {}
|
||||
impl PglbConnState for AuthPassthrough {}
|
||||
impl PglbConnState for ComputeConnect {}
|
||||
impl PglbConnState for ComputePassthrough {}
|
||||
impl PglbConnState for End {}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct Start;
|
||||
|
||||
impl PglbConn<Start> {
|
||||
fn new(auth_conns: &Arc<AuthConnState>, tls_config: &TlsConfig) -> Result<Self> {
|
||||
Ok(PglbConn {
|
||||
inner: PglbConnInner {
|
||||
auth_conns: Arc::clone(auth_conns),
|
||||
tls_config: tls_config.clone(),
|
||||
},
|
||||
state: Start,
|
||||
})
|
||||
}
|
||||
|
||||
async fn handle(
|
||||
self,
|
||||
client_stream: TcpStream,
|
||||
client_addr: SocketAddr,
|
||||
) -> Result<PglbConn<End>> {
|
||||
self.handle_start(client_stream, client_addr)
|
||||
.await?
|
||||
.handle_client_connect()
|
||||
.await?
|
||||
.handle_auth_passthrough()
|
||||
.await?
|
||||
.handle_compute_connect()
|
||||
.await?
|
||||
.handle_compute_passthrough()
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
impl PglbConn<Start> {
|
||||
async fn handle_start(
|
||||
self,
|
||||
mut client_stream: TcpStream,
|
||||
client_addr: SocketAddr,
|
||||
) -> Result<PglbConn<ClientConnect>> {
|
||||
match client_stream.set_nodelay(true) {
|
||||
Ok(()) => {}
|
||||
Err(e) => {
|
||||
bail!("socket option error: {e}");
|
||||
}
|
||||
};
|
||||
|
||||
// TODO: HAProxy protocol
|
||||
|
||||
let tls_requested = match Self::handle_ssl_request_message(&mut client_stream).await {
|
||||
Ok(tls_requested) => tls_requested,
|
||||
Err(e) => {
|
||||
bail!("check_for_ssl_request: {e}");
|
||||
}
|
||||
};
|
||||
|
||||
let (client_stream, payload) = if tls_requested {
|
||||
println!("starting tls upgrade");
|
||||
|
||||
let mut buf = BytesMut::new();
|
||||
BeMessage::write(&mut buf, &BeMessage::EncryptionResponse(true)).unwrap();
|
||||
client_stream.write_all(&buf).await?;
|
||||
|
||||
let (stream, tls_server_end_point, server_name) =
|
||||
match Self::tls_upgrade(client_stream, self.inner.tls_config.clone()).await {
|
||||
Ok((stream, ep, sn)) => (stream, ep, sn),
|
||||
Err(e) => {
|
||||
bail!("tls_upgrade: {e}");
|
||||
}
|
||||
};
|
||||
|
||||
(
|
||||
stream,
|
||||
ConnectionInitiatedPayload {
|
||||
tls_server_end_point,
|
||||
server_name,
|
||||
ip_addr: client_addr.ip(),
|
||||
},
|
||||
)
|
||||
} else {
|
||||
// TODO: support unsecured connections?
|
||||
bail!("closing non-TLS connection");
|
||||
};
|
||||
|
||||
println!("tls done");
|
||||
|
||||
Ok(PglbConn {
|
||||
inner: self.inner,
|
||||
state: ClientConnect {
|
||||
client_stream,
|
||||
payload,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
async fn handle_ssl_request_message(stream: &mut TcpStream) -> Result<bool> {
|
||||
println!("checking for ssl request");
|
||||
let mut buf = vec![0u8; 8];
|
||||
|
||||
let n_peek = stream.peek(&mut buf).await?;
|
||||
if n_peek == 0 {
|
||||
bail!("EOF");
|
||||
}
|
||||
|
||||
assert_eq!(buf.len(), 8); // TODO: loop, read more
|
||||
|
||||
if buf.len() != 8
|
||||
|| buf[0..4] != 8u32.to_be_bytes()
|
||||
|| buf[4..8] != 80877103u32.to_be_bytes()
|
||||
{
|
||||
return Ok(false);
|
||||
}
|
||||
stream.read_exact(&mut buf).await?;
|
||||
|
||||
Ok(true)
|
||||
}
|
||||
|
||||
async fn tls_upgrade(
|
||||
stream: TcpStream,
|
||||
tls: TlsConfig,
|
||||
) -> Result<(TlsStream<TcpStream>, TlsServerEndPoint, Option<String>)> {
|
||||
let tls_stream = tokio_rustls::TlsAcceptor::from(tls.config)
|
||||
.accept(stream)
|
||||
.await?;
|
||||
|
||||
let conn_info = tls_stream.get_ref().1;
|
||||
let server_name = conn_info.server_name().map(|s| s.to_string());
|
||||
|
||||
match conn_info.alpn_protocol() {
|
||||
None | Some(PG_ALPN_PROTOCOL) => {}
|
||||
Some(other) => {
|
||||
let alpn = String::from_utf8_lossy(other);
|
||||
warn!(%alpn, "unexpected ALPN");
|
||||
bail!("protocol violation");
|
||||
}
|
||||
}
|
||||
|
||||
let (_, tls_server_end_point) = tls
|
||||
.cert_resolver
|
||||
.resolve(server_name.as_deref())
|
||||
.ok_or(anyhow!("missing cert"))?;
|
||||
|
||||
Ok((tls_stream, tls_server_end_point, server_name))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct ClientConnect {
|
||||
client_stream: TlsStream<TcpStream>,
|
||||
payload: ConnectionInitiatedPayload,
|
||||
}
|
||||
|
||||
impl PglbConn<ClientConnect> {
|
||||
async fn handle_client_connect(self) -> Result<PglbConn<AuthPassthrough>> {
|
||||
let auth_conn = {
|
||||
let conns = self.inner.auth_conns.conns.lock().unwrap();
|
||||
if conns.is_empty() {
|
||||
bail!("no auth proxies avaiable");
|
||||
}
|
||||
|
||||
let mut rng = rand::thread_rng();
|
||||
conns
|
||||
.get_index(rng.gen_range(0..conns.len()))
|
||||
.unwrap()
|
||||
.1
|
||||
.clone()
|
||||
|
||||
// TODO: check closed?
|
||||
};
|
||||
println!("connecting to {}", auth_conn.conn.stable_id());
|
||||
|
||||
let (send, recv) = auth_conn.conn.open_bi().await?;
|
||||
let mut auth_stream = Framed::new(join(recv, send), PglbCodec);
|
||||
|
||||
auth_stream
|
||||
.send(proxy::PglbMessage::Control(
|
||||
proxy::PglbControlMessage::ConnectionInitiated(self.state.payload),
|
||||
))
|
||||
.await?;
|
||||
|
||||
Ok(PglbConn {
|
||||
inner: self.inner,
|
||||
state: AuthPassthrough {
|
||||
client_stream: Framed::new(
|
||||
self.state.client_stream,
|
||||
PgRawCodec {
|
||||
start_or_ssl_request: true,
|
||||
},
|
||||
),
|
||||
auth_stream,
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct AuthPassthrough {
|
||||
client_stream: Framed<TlsStream<TcpStream>, PgRawCodec>,
|
||||
auth_stream: Framed<Join<RecvStream, SendStream>, PglbCodec>,
|
||||
}
|
||||
|
||||
impl PglbConn<AuthPassthrough> {
|
||||
async fn handle_auth_passthrough(self) -> Result<PglbConn<ComputeConnect>> {
|
||||
let mut client_stream = self.state.client_stream;
|
||||
let mut auth_stream = self.state.auth_stream;
|
||||
|
||||
loop {
|
||||
select! {
|
||||
biased;
|
||||
|
||||
msg = auth_stream.next() => {
|
||||
match msg.context("auth proxy disconnected")?? {
|
||||
PglbMessage::Postgres(payload) => {
|
||||
println!("msg {payload:?}");
|
||||
client_stream.send(PgRawMessage::Generic { payload }).await?;
|
||||
}
|
||||
PglbMessage::Control(PglbControlMessage::ConnectionInitiated(_)) => {
|
||||
bail!("auth proxy sent unexpected message");
|
||||
}
|
||||
PglbMessage::Control(PglbControlMessage::ConnectToCompute { socket }) => {
|
||||
println!("socket");
|
||||
return Ok(PglbConn {
|
||||
inner: self.inner,
|
||||
state: ComputeConnect {
|
||||
client_stream,
|
||||
auth_stream,
|
||||
compute_socket:socket,
|
||||
},
|
||||
});
|
||||
}
|
||||
PglbMessage::Control(PglbControlMessage::ComputeEstablish) => {
|
||||
bail!("auth proxy sent unexpected message");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
msg = client_stream.next() => {
|
||||
match msg.context("client disconnected")?? {
|
||||
PgRawMessage::SslRequest => bail!("protocol violation"),
|
||||
PgRawMessage::Start(payload) | PgRawMessage::Generic { payload } => {
|
||||
auth_stream.send(proxy::PglbMessage::Postgres(
|
||||
payload
|
||||
)).await?;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct ComputeConnect {
|
||||
client_stream: Framed<TlsStream<TcpStream>, PgRawCodec>,
|
||||
auth_stream: Framed<Join<RecvStream, SendStream>, PglbCodec>,
|
||||
compute_socket: SocketAddr,
|
||||
}
|
||||
|
||||
impl PglbConn<ComputeConnect> {
|
||||
async fn handle_compute_connect(self) -> Result<PglbConn<ComputePassthrough>> {
|
||||
let ComputeConnect {
|
||||
client_stream,
|
||||
mut auth_stream,
|
||||
compute_socket,
|
||||
} = self.state;
|
||||
let compute_stream = TcpStream::connect(compute_socket).await?;
|
||||
compute_stream
|
||||
.set_nodelay(true)
|
||||
.context("socket option error")?;
|
||||
|
||||
let mut compute_stream = Framed::new(
|
||||
compute_stream,
|
||||
PgRawCodec {
|
||||
start_or_ssl_request: false,
|
||||
},
|
||||
);
|
||||
|
||||
let mut resps = 4;
|
||||
loop {
|
||||
select! {
|
||||
msg = auth_stream.next() => {
|
||||
match msg.context("auth proxy disconnected")?? {
|
||||
PglbMessage::Postgres(payload) => {
|
||||
println!("msg {payload:?}");
|
||||
compute_stream.send(PgRawMessage::Generic { payload } ).await?;
|
||||
}
|
||||
PglbMessage::Control(PglbControlMessage::ComputeEstablish) => {
|
||||
println!("establish");
|
||||
return Ok(PglbConn {
|
||||
inner: self.inner,
|
||||
state: ComputePassthrough {
|
||||
client_stream,
|
||||
compute_stream,
|
||||
},
|
||||
});
|
||||
}
|
||||
PglbMessage::Control(PglbControlMessage::ConnectionInitiated(_)) |
|
||||
PglbMessage::Control(PglbControlMessage::ConnectToCompute { .. }) => {
|
||||
bail!("auth proxy sent unexpected message");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
msg = compute_stream.next(), if resps > 0 => {
|
||||
match msg.context("compute disconnected")?? {
|
||||
PgRawMessage::SslRequest => bail!("protocol violation"),
|
||||
PgRawMessage::Start(payload) | PgRawMessage::Generic { payload } => {
|
||||
resps -= 1;
|
||||
auth_stream.send(proxy::PglbMessage::Postgres(
|
||||
payload
|
||||
)).await?;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct ComputePassthrough {
|
||||
client_stream: Framed<TlsStream<TcpStream>, PgRawCodec>,
|
||||
compute_stream: Framed<TcpStream, PgRawCodec>,
|
||||
}
|
||||
|
||||
impl PglbConn<ComputePassthrough> {
|
||||
async fn handle_compute_passthrough(self) -> Result<PglbConn<End>> {
|
||||
let ComputePassthrough {
|
||||
client_stream,
|
||||
compute_stream,
|
||||
} = self.state;
|
||||
|
||||
let mut client_parts = client_stream.into_parts();
|
||||
let mut compute_parts = compute_stream.into_parts();
|
||||
|
||||
assert!(compute_parts.write_buf.is_empty());
|
||||
assert!(client_parts.write_buf.is_empty());
|
||||
|
||||
client_parts.io.write_all(&compute_parts.read_buf).await?;
|
||||
compute_parts.io.write_all(&client_parts.read_buf).await?;
|
||||
|
||||
drop(client_parts.read_buf);
|
||||
drop(client_parts.write_buf);
|
||||
drop(compute_parts.read_buf);
|
||||
drop(compute_parts.write_buf);
|
||||
|
||||
copy_bidirectional(&mut client_parts.io, &mut compute_parts.io).await?;
|
||||
Ok(PglbConn {
|
||||
inner: self.inner,
|
||||
state: End,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct End;
|
||||
|
||||
#[derive(Debug)]
|
||||
struct PgRawCodec {
|
||||
start_or_ssl_request: bool,
|
||||
}
|
||||
|
||||
impl tokio_util::codec::Encoder<PgRawMessage> for PgRawCodec {
|
||||
type Error = anyhow::Error;
|
||||
|
||||
fn encode(&mut self, item: PgRawMessage, dst: &mut bytes::BytesMut) -> Result<(), Self::Error> {
|
||||
item.encode(dst)
|
||||
}
|
||||
}
|
||||
|
||||
impl tokio_util::codec::Decoder for PgRawCodec {
|
||||
type Item = PgRawMessage;
|
||||
type Error = anyhow::Error;
|
||||
|
||||
fn decode(&mut self, src: &mut bytes::BytesMut) -> Result<Option<Self::Item>, Self::Error> {
|
||||
if self.start_or_ssl_request {
|
||||
match PgRawMessage::decode(src, true)? {
|
||||
msg @ Some(PgRawMessage::Start(..)) => {
|
||||
self.start_or_ssl_request = false;
|
||||
Ok(msg)
|
||||
}
|
||||
msg @ Some(PgRawMessage::SslRequest) => Ok(msg),
|
||||
Some(PgRawMessage::Generic { .. }) => unreachable!(),
|
||||
None => Ok(None),
|
||||
}
|
||||
} else {
|
||||
match PgRawMessage::decode(src, false)? {
|
||||
Some(PgRawMessage::Start(..)) => unreachable!(),
|
||||
Some(PgRawMessage::SslRequest) => unreachable!(),
|
||||
msg @ Some(PgRawMessage::Generic { .. }) => Ok(msg),
|
||||
None => Ok(None),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum PgRawMessage {
|
||||
SslRequest,
|
||||
Start(BytesMut),
|
||||
Generic { payload: BytesMut },
|
||||
}
|
||||
|
||||
impl PgRawMessage {
|
||||
fn encode(&self, dst: &mut bytes::BytesMut) -> Result<()> {
|
||||
match self {
|
||||
Self::SslRequest => {
|
||||
dst.put_u32(8);
|
||||
dst.put_u32(80877103);
|
||||
}
|
||||
Self::Start(payload) | Self::Generic { payload } => {
|
||||
dst.put_slice(payload);
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn decode(src: &mut bytes::BytesMut, start: bool) -> Result<Option<Self>> {
|
||||
let extra = if start { 0 } else { 1 };
|
||||
|
||||
if src.remaining() < 4 + extra {
|
||||
src.reserve(4 + extra);
|
||||
return Ok(None);
|
||||
}
|
||||
let length = u32::from_be_bytes(src[extra..4 + extra].try_into().unwrap()) as usize + extra;
|
||||
if src.remaining() < length {
|
||||
src.reserve(length - src.remaining());
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
if start && length == 8 && src[4..8] == 80877103u32.to_be_bytes() {
|
||||
Ok(Some(PgRawMessage::SslRequest))
|
||||
} else if start {
|
||||
Ok(Some(PgRawMessage::Start(src.split_to(length))))
|
||||
} else {
|
||||
Ok(Some(PgRawMessage::Generic {
|
||||
payload: src.split_to(length),
|
||||
}))
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -13,6 +13,7 @@ use rustls::{
|
||||
crypto::ring::sign,
|
||||
pki_types::{CertificateDer, PrivateKeyDer},
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sha2::{Digest, Sha256};
|
||||
use std::{
|
||||
collections::{HashMap, HashSet},
|
||||
@@ -149,7 +150,7 @@ pub fn configure_tls(
|
||||
/// uses multiple hash functions, then this channel binding type's
|
||||
/// channel bindings are undefined at this time (updates to is channel
|
||||
/// binding type may occur to address this issue if it ever arises).
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
|
||||
pub enum TlsServerEndPoint {
|
||||
Sha256([u8; 32]),
|
||||
Undefined,
|
||||
|
||||
@@ -125,7 +125,7 @@ impl RequestMonitoring {
|
||||
Self(TryLock::new(inner))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
// #[cfg(test)]
|
||||
pub(crate) fn test() -> Self {
|
||||
RequestMonitoring::new(Uuid::now_v7(), [127, 0, 0, 1].into(), Protocol::Tcp, "test")
|
||||
}
|
||||
|
||||
139
proxy/src/lib.rs
139
proxy/src/lib.rs
@@ -82,15 +82,23 @@
|
||||
impl_trait_overcaptures,
|
||||
)]
|
||||
|
||||
use std::{convert::Infallible, future::Future};
|
||||
use std::{
|
||||
convert::Infallible,
|
||||
future::Future,
|
||||
net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
|
||||
};
|
||||
|
||||
use anyhow::{bail, Context};
|
||||
use bytes::{Buf, BufMut};
|
||||
use config::TlsServerEndPoint;
|
||||
use intern::{EndpointIdInt, EndpointIdTag, InternId};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tokio::task::JoinError;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::warn;
|
||||
|
||||
pub mod auth;
|
||||
pub mod auth_proxy;
|
||||
pub mod cache;
|
||||
pub mod cancellation;
|
||||
pub mod compute;
|
||||
@@ -274,3 +282,132 @@ impl EndpointId {
|
||||
ProjectId(self.0.clone())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct PglbCodec;
|
||||
|
||||
impl tokio_util::codec::Encoder<PglbMessage> for PglbCodec {
|
||||
type Error = anyhow::Error;
|
||||
|
||||
fn encode(&mut self, item: PglbMessage, dst: &mut bytes::BytesMut) -> Result<(), Self::Error> {
|
||||
match item {
|
||||
PglbMessage::Control(ctrl) => {
|
||||
dst.put_u8(1);
|
||||
match ctrl {
|
||||
PglbControlMessage::ConnectionInitiated(msg) => {
|
||||
let encode = serde_json::to_string(&msg).context("ser")?;
|
||||
dst.put_u32(1 + encode.len() as u32);
|
||||
dst.put_u8(0);
|
||||
dst.put(encode.as_bytes());
|
||||
}
|
||||
PglbControlMessage::ConnectToCompute { socket } => match socket {
|
||||
SocketAddr::V4(v4) => {
|
||||
dst.put_u32(1 + 4 + 2);
|
||||
dst.put_u8(1);
|
||||
dst.put_u32(v4.ip().to_bits());
|
||||
dst.put_u16(v4.port());
|
||||
}
|
||||
SocketAddr::V6(v6) => {
|
||||
dst.put_u32(1 + 16 + 2);
|
||||
dst.put_u8(1);
|
||||
dst.put_u128(v6.ip().to_bits());
|
||||
dst.put_u16(v6.port());
|
||||
}
|
||||
},
|
||||
PglbControlMessage::ComputeEstablish => {
|
||||
dst.put_u32(1);
|
||||
dst.put_u8(2);
|
||||
}
|
||||
}
|
||||
}
|
||||
PglbMessage::Postgres(pg) => {
|
||||
dst.put_u8(0);
|
||||
dst.put_u32(pg.len() as u32);
|
||||
dst.put(pg);
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl tokio_util::codec::Decoder for PglbCodec {
|
||||
type Item = PglbMessage;
|
||||
type Error = anyhow::Error;
|
||||
|
||||
fn decode(&mut self, dst: &mut bytes::BytesMut) -> Result<Option<Self::Item>, Self::Error> {
|
||||
if dst.remaining() < 5 {
|
||||
dst.reserve(5);
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let msg = dst[0];
|
||||
let len = u32::from_be_bytes(dst[1..5].try_into().unwrap()) as usize;
|
||||
|
||||
if len + 5 > dst.remaining() {
|
||||
dst.reserve(len + 5);
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
dst.advance(5);
|
||||
let mut payload = dst.split_to(len);
|
||||
|
||||
match msg {
|
||||
// postgres
|
||||
0 => Ok(Some(PglbMessage::Postgres(payload))),
|
||||
// control
|
||||
1 => {
|
||||
if payload.is_empty() {
|
||||
bail!("invalid ctrl message")
|
||||
}
|
||||
let ctrl_msg = payload.split_to(1)[0];
|
||||
let ctrl_msg = match ctrl_msg {
|
||||
0 => PglbControlMessage::ConnectionInitiated(
|
||||
serde_json::from_slice(&payload).context("deser")?,
|
||||
),
|
||||
// ipv4 socket
|
||||
1 if len == 7 => PglbControlMessage::ConnectToCompute {
|
||||
socket: SocketAddr::new(
|
||||
IpAddr::V4(Ipv4Addr::from_bits(payload.get_u32())),
|
||||
payload.get_u16(),
|
||||
),
|
||||
},
|
||||
|
||||
// ipv6 socket
|
||||
1 if len == 19 => PglbControlMessage::ConnectToCompute {
|
||||
socket: SocketAddr::new(
|
||||
IpAddr::V6(Ipv6Addr::from_bits(payload.get_u128())),
|
||||
payload.get_u16(),
|
||||
),
|
||||
},
|
||||
|
||||
2 if len == 1 => PglbControlMessage::ComputeEstablish,
|
||||
|
||||
_ => bail!("invalid ctrl message"),
|
||||
};
|
||||
Ok(Some(PglbMessage::Control(ctrl_msg)))
|
||||
}
|
||||
_ => bail!("invalid message"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub enum PglbMessage {
|
||||
Control(PglbControlMessage),
|
||||
Postgres(bytes::BytesMut),
|
||||
}
|
||||
|
||||
pub enum PglbControlMessage {
|
||||
// from pglb to auth proxy
|
||||
ConnectionInitiated(ConnectionInitiatedPayload),
|
||||
// from auth proxy to pglb
|
||||
ConnectToCompute { socket: SocketAddr },
|
||||
// from auth proxy to pglb
|
||||
ComputeEstablish,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct ConnectionInitiatedPayload {
|
||||
pub tls_server_end_point: TlsServerEndPoint,
|
||||
pub server_name: Option<String>,
|
||||
pub ip_addr: IpAddr,
|
||||
}
|
||||
|
||||
@@ -7,9 +7,35 @@ pub(crate) mod handshake;
|
||||
pub(crate) mod passthrough;
|
||||
pub(crate) mod retry;
|
||||
pub(crate) mod wake_compute;
|
||||
use anyhow::bail;
|
||||
use anyhow::Context;
|
||||
use bytes::BytesMut;
|
||||
use connect_compute::ComputeConnectBackend;
|
||||
pub use copy_bidirectional::copy_bidirectional_client_compute;
|
||||
pub use copy_bidirectional::ErrorSource;
|
||||
use futures::SinkExt;
|
||||
use futures::TryStreamExt;
|
||||
use postgres_protocol::authentication::sasl;
|
||||
use postgres_protocol::authentication::sasl::ChannelBinding;
|
||||
use postgres_protocol::authentication::sasl::ScramSha256;
|
||||
use postgres_protocol::message::backend::Message;
|
||||
use postgres_protocol::message::frontend;
|
||||
use pq_proto::FeStartupPacket;
|
||||
use quinn::RecvStream;
|
||||
use quinn::SendStream;
|
||||
use tokio::io::join;
|
||||
use tokio_postgres::config::AuthKeys;
|
||||
use tokio_util::codec::Framed;
|
||||
|
||||
use crate::auth::backend::ComputeCredentialKeys;
|
||||
use crate::auth::backend::ComputeCredentials;
|
||||
use crate::auth_proxy::AuthProxyStream;
|
||||
use crate::auth_proxy::TLS_SERVER_END_POINT;
|
||||
use crate::console::NodeInfo;
|
||||
use crate::stream::AuthProxyStreamExt;
|
||||
use crate::ConnectionInitiatedPayload;
|
||||
use crate::PglbControlMessage;
|
||||
use crate::PglbMessage;
|
||||
use crate::{
|
||||
auth,
|
||||
cancellation::{self, CancellationHandlerMain, CancellationHandlerMainInternal},
|
||||
@@ -30,6 +56,8 @@ use once_cell::sync::OnceCell;
|
||||
use pq_proto::{BeMessage as Be, StartupMessageParams};
|
||||
use regex::Regex;
|
||||
use smol_str::{format_smolstr, SmolStr};
|
||||
use std::net::IpAddr;
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
use thiserror::Error;
|
||||
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
|
||||
@@ -377,10 +405,10 @@ async fn prepare_client_connection<P>(
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Default)]
|
||||
pub(crate) struct NeonOptions(Vec<(SmolStr, SmolStr)>);
|
||||
pub struct NeonOptions(Vec<(SmolStr, SmolStr)>);
|
||||
|
||||
impl NeonOptions {
|
||||
pub(crate) fn parse_params(params: &StartupMessageParams) -> Self {
|
||||
pub fn parse_params(params: &StartupMessageParams) -> Self {
|
||||
params
|
||||
.options_raw()
|
||||
.map(Self::parse_from_iter)
|
||||
@@ -431,3 +459,157 @@ pub(crate) fn neon_option(bytes: &str) -> Option<(&str, &str)> {
|
||||
let (_, [k, v]) = cap.extract();
|
||||
Some((k, v))
|
||||
}
|
||||
|
||||
pub struct AuthProxyConfig {
|
||||
pub backend: crate::auth_proxy::Backend<'static, ()>,
|
||||
pub auth: crate::config::AuthenticationConfig,
|
||||
}
|
||||
|
||||
pub async fn handle_stream(
|
||||
config: &'static AuthProxyConfig,
|
||||
send: SendStream,
|
||||
recv: RecvStream,
|
||||
) -> anyhow::Result<()> {
|
||||
let mut stream: AuthProxyStream = Framed::new(join(recv, send), crate::PglbCodec);
|
||||
|
||||
// recv connection metadata
|
||||
let first_msg = stream.try_next().await?;
|
||||
let Some(PglbMessage::Control(PglbControlMessage::ConnectionInitiated(conn_info))) = first_msg
|
||||
else {
|
||||
panic!("invalid first msg")
|
||||
};
|
||||
|
||||
println!("new conn: {conn_info:?}");
|
||||
|
||||
// read startup packet
|
||||
let startup = stream.read_startup_packet().await?;
|
||||
let FeStartupPacket::StartupMessage { version: _, params } = startup else {
|
||||
panic!("invalid startup message")
|
||||
};
|
||||
|
||||
println!("params: {params:?}");
|
||||
|
||||
let user_info = auth_with_user(&mut stream, config, &conn_info, ¶ms).await?;
|
||||
|
||||
println!("authenticated");
|
||||
|
||||
// wake the compute
|
||||
let node_info = user_info.wake_compute(&RequestMonitoring::test()).await?;
|
||||
|
||||
println!("woke compute");
|
||||
|
||||
let addr: IpAddr = node_info.config.get_host()?.parse()?;
|
||||
let socket = SocketAddr::new(addr, node_info.config.get_ports()[0]);
|
||||
|
||||
// tell pglb that the compute is up
|
||||
stream
|
||||
.send(PglbMessage::Control(PglbControlMessage::ConnectToCompute {
|
||||
socket,
|
||||
}))
|
||||
.await?;
|
||||
|
||||
// send startup message to compute
|
||||
let mut buf = BytesMut::new();
|
||||
frontend::startup_message(params.iter(), &mut buf)?;
|
||||
stream.send(PglbMessage::Postgres(buf.split())).await?;
|
||||
|
||||
auth_with_compute(&mut stream, user_info.get_keys()).await?;
|
||||
|
||||
stream
|
||||
.send(PglbMessage::Control(PglbControlMessage::ComputeEstablish))
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn auth_with_user(
|
||||
stream: &mut AuthProxyStream,
|
||||
config: &'static AuthProxyConfig,
|
||||
conn_info: &ConnectionInitiatedPayload,
|
||||
params: &StartupMessageParams,
|
||||
) -> anyhow::Result<crate::auth_proxy::Backend<'static, ComputeCredentials>> {
|
||||
dbg!("auth...");
|
||||
|
||||
// Extract credentials which we're going to use for auth.
|
||||
let user_info = auth::ComputeUserInfoMaybeEndpoint {
|
||||
user: params.get("user").context("missing user")?.into(),
|
||||
endpoint_id: conn_info
|
||||
.server_name
|
||||
.as_deref()
|
||||
.map(|h| h.split_once('.').map_or(h, |(ep, _)| ep).into()),
|
||||
options: NeonOptions::parse_params(params),
|
||||
};
|
||||
|
||||
dbg!("parsed used info");
|
||||
|
||||
// authenticate the user
|
||||
let user_info = config.backend.as_ref().map(|()| user_info);
|
||||
let res = TLS_SERVER_END_POINT
|
||||
.scope(
|
||||
conn_info.tls_server_end_point,
|
||||
user_info.authenticate(stream, &config.auth),
|
||||
)
|
||||
.await;
|
||||
|
||||
let user_info = match res {
|
||||
Ok(auth_result) => auth_result,
|
||||
Err(e) => {
|
||||
return stream.throw_error(e).await?;
|
||||
}
|
||||
};
|
||||
|
||||
Ok(user_info)
|
||||
}
|
||||
|
||||
async fn auth_with_compute(
|
||||
stream: &mut AuthProxyStream,
|
||||
keys: &ComputeCredentialKeys,
|
||||
) -> anyhow::Result<()> {
|
||||
let ComputeCredentialKeys::AuthKeys(AuthKeys::ScramSha256(scram_keys)) = keys else {
|
||||
bail!("missing keys");
|
||||
};
|
||||
|
||||
// compute offers sasl
|
||||
stream
|
||||
.read_backend_message(|m| match m {
|
||||
Message::AuthenticationSasl(_body) => Ok(()),
|
||||
_ => bail!("invalid message"),
|
||||
})
|
||||
.await?;
|
||||
|
||||
let mut buf = BytesMut::new();
|
||||
|
||||
// send auth message
|
||||
let mut scram = ScramSha256::new_with_keys(*scram_keys, ChannelBinding::unsupported());
|
||||
frontend::sasl_initial_response(sasl::SCRAM_SHA_256, scram.message(), &mut buf)?;
|
||||
stream.send(PglbMessage::Postgres(buf.split())).await?;
|
||||
|
||||
let cont_body = stream
|
||||
.read_backend_message(|m| match m {
|
||||
Message::AuthenticationSaslContinue(body) => Ok(body),
|
||||
_ => bail!("invalid message"),
|
||||
})
|
||||
.await?;
|
||||
scram.update(cont_body.data()).await?;
|
||||
|
||||
frontend::sasl_response(scram.message(), &mut buf)?;
|
||||
stream.send(PglbMessage::Postgres(buf.split())).await?;
|
||||
|
||||
let final_body = stream
|
||||
.read_backend_message(|m| match m {
|
||||
Message::AuthenticationSaslFinal(body) => Ok(body),
|
||||
_ => bail!("invalid message"),
|
||||
})
|
||||
.await?;
|
||||
scram.finish(final_body.data())?;
|
||||
|
||||
// wait for ok.
|
||||
stream
|
||||
.read_backend_message(|m| match m {
|
||||
Message::AuthenticationOk => Ok(()),
|
||||
_ => bail!("invalid message"),
|
||||
})
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@
|
||||
mod channel_binding;
|
||||
mod messages;
|
||||
mod stream;
|
||||
mod stream2;
|
||||
|
||||
use crate::error::{ReportableError, UserFacingError};
|
||||
use std::io;
|
||||
@@ -17,6 +18,7 @@ use thiserror::Error;
|
||||
pub(crate) use channel_binding::ChannelBinding;
|
||||
pub(crate) use messages::FirstMessage;
|
||||
pub(crate) use stream::{Outcome, SaslStream};
|
||||
pub(crate) use stream2::SaslStream2;
|
||||
|
||||
/// Fine-grained auth errors help in writing tests.
|
||||
#[derive(Error, Debug)]
|
||||
|
||||
85
proxy/src/sasl/stream2.rs
Normal file
85
proxy/src/sasl/stream2.rs
Normal file
@@ -0,0 +1,85 @@
|
||||
//! Abstraction for the string-oriented SASL protocols.
|
||||
|
||||
use crate::{
|
||||
auth_proxy::AuthProxyStream,
|
||||
sasl::{messages::ServerMessage, Mechanism},
|
||||
stream::AuthProxyStreamExt,
|
||||
};
|
||||
use std::io;
|
||||
use tracing::info;
|
||||
|
||||
use super::Outcome;
|
||||
|
||||
/// Abstracts away all peculiarities of the libpq's protocol.
|
||||
pub(crate) struct SaslStream2<'a> {
|
||||
/// The underlying stream.
|
||||
stream: &'a mut AuthProxyStream,
|
||||
/// Current password message we received from client.
|
||||
current: bytes::Bytes,
|
||||
/// First SASL message produced by client.
|
||||
first: Option<&'a str>,
|
||||
}
|
||||
|
||||
impl<'a> SaslStream2<'a> {
|
||||
pub(crate) fn new(stream: &'a mut AuthProxyStream, first: &'a str) -> Self {
|
||||
Self {
|
||||
stream,
|
||||
current: bytes::Bytes::new(),
|
||||
first: Some(first),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl SaslStream2<'_> {
|
||||
// Receive a new SASL message from the client.
|
||||
async fn recv(&mut self) -> io::Result<&str> {
|
||||
if let Some(first) = self.first.take() {
|
||||
return Ok(first);
|
||||
}
|
||||
|
||||
self.current = self.stream.read_password_message().await?;
|
||||
let s = std::str::from_utf8(&self.current)
|
||||
.map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "bad encoding"))?;
|
||||
|
||||
Ok(s)
|
||||
}
|
||||
}
|
||||
|
||||
impl SaslStream2<'_> {
|
||||
// Send a SASL message to the client.
|
||||
async fn send(&mut self, msg: &ServerMessage<&str>) -> io::Result<()> {
|
||||
self.stream.write_message(&msg.to_reply()).await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl SaslStream2<'_> {
|
||||
/// Perform SASL message exchange according to the underlying algorithm
|
||||
/// until user is either authenticated or denied access.
|
||||
pub(crate) async fn authenticate<M: Mechanism>(
|
||||
mut self,
|
||||
mut mechanism: M,
|
||||
) -> crate::sasl::Result<Outcome<M::Output>> {
|
||||
loop {
|
||||
let input = self.recv().await?;
|
||||
let step = mechanism.exchange(input).map_err(|error| {
|
||||
info!(?error, "error during SASL exchange");
|
||||
error
|
||||
})?;
|
||||
|
||||
use crate::sasl::Step;
|
||||
return Ok(match step {
|
||||
Step::Continue(moved_mechanism, reply) => {
|
||||
self.send(&ServerMessage::Continue(&reply)).await?;
|
||||
mechanism = moved_mechanism;
|
||||
continue;
|
||||
}
|
||||
Step::Success(result, reply) => {
|
||||
self.send(&ServerMessage::Final(&reply)).await?;
|
||||
Outcome::Success(result)
|
||||
}
|
||||
Step::Failure(reason) => Outcome::Failure(reason),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,8 +1,13 @@
|
||||
use crate::auth_proxy::AuthProxyStream;
|
||||
use crate::config::TlsServerEndPoint;
|
||||
use crate::error::{ErrorKind, ReportableError, UserFacingError};
|
||||
use crate::metrics::Metrics;
|
||||
use crate::PglbMessage;
|
||||
use anyhow::{bail, Context};
|
||||
use bytes::BytesMut;
|
||||
|
||||
use futures::{SinkExt, TryStreamExt};
|
||||
use postgres_protocol::message::backend;
|
||||
use pq_proto::framed::{ConnectionError, Framed};
|
||||
use pq_proto::{BeMessage, FeMessage, FeStartupPacket, ProtocolError};
|
||||
use rustls::ServerConfig;
|
||||
@@ -294,3 +299,140 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite for Stream<S> {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(async_fn_in_trait)]
|
||||
pub(crate) trait AuthProxyStreamExt {
|
||||
/// Write the message into an internal buffer, but don't flush the underlying stream.
|
||||
fn write_message_noflush(&mut self, message: &BeMessage<'_>) -> io::Result<&mut Self>;
|
||||
|
||||
/// Write the message into an internal buffer and flush it.
|
||||
async fn write_message(&mut self, message: &BeMessage<'_>) -> io::Result<&mut Self>;
|
||||
|
||||
/// Write the error message using [`Self::write_message`], then re-throw it.
|
||||
/// Trait [`UserFacingError`] acts as an allowlist for error types.
|
||||
async fn throw_error<T, E>(&mut self, error: E) -> Result<T, ReportedError>
|
||||
where
|
||||
E: UserFacingError + Into<anyhow::Error>;
|
||||
|
||||
/// Receive [`FeStartupPacket`], which is a first packet sent by a client.
|
||||
async fn read_startup_packet(&mut self) -> io::Result<FeStartupPacket>;
|
||||
async fn read_message(&mut self) -> io::Result<FeMessage>;
|
||||
|
||||
async fn read_password_message(&mut self) -> io::Result<bytes::Bytes>;
|
||||
|
||||
async fn read_backend_message<T>(
|
||||
&mut self,
|
||||
f: impl FnOnce(backend::Message) -> anyhow::Result<T>,
|
||||
) -> anyhow::Result<T>;
|
||||
}
|
||||
|
||||
impl AuthProxyStreamExt for AuthProxyStream {
|
||||
/// Write the message into an internal buffer, but don't flush the underlying stream.
|
||||
fn write_message_noflush(&mut self, message: &BeMessage<'_>) -> io::Result<&mut Self> {
|
||||
let mut b = BytesMut::new();
|
||||
BeMessage::write(&mut b, message).map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
|
||||
self.start_send_unpin(PglbMessage::Postgres(b))
|
||||
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
|
||||
Ok(self)
|
||||
}
|
||||
|
||||
/// Write the message into an internal buffer and flush it.
|
||||
async fn write_message(&mut self, message: &BeMessage<'_>) -> io::Result<&mut Self> {
|
||||
self.write_message_noflush(message)?;
|
||||
self.flush()
|
||||
.await
|
||||
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
|
||||
Ok(self)
|
||||
}
|
||||
|
||||
/// Write the error message using [`Self::write_message`], then re-throw it.
|
||||
/// Trait [`UserFacingError`] acts as an allowlist for error types.
|
||||
async fn throw_error<T, E>(&mut self, error: E) -> Result<T, ReportedError>
|
||||
where
|
||||
E: UserFacingError + Into<anyhow::Error>,
|
||||
{
|
||||
let error_kind = error.get_error_kind();
|
||||
let msg = error.to_string_client();
|
||||
tracing::info!(
|
||||
kind=error_kind.to_metric_label(),
|
||||
error=%error,
|
||||
msg,
|
||||
"forwarding error to user"
|
||||
);
|
||||
|
||||
// already error case, ignore client IO error
|
||||
self.write_message(&BeMessage::ErrorResponse(&msg, None))
|
||||
.await
|
||||
.inspect_err(|e| debug!("write_message failed: {e}"))
|
||||
.ok();
|
||||
|
||||
Err(ReportedError {
|
||||
source: anyhow::anyhow!(error),
|
||||
error_kind,
|
||||
})
|
||||
}
|
||||
|
||||
/// Receive [`FeStartupPacket`], which is a first packet sent by a client.
|
||||
async fn read_startup_packet(&mut self) -> io::Result<FeStartupPacket> {
|
||||
let msg = self
|
||||
.try_next()
|
||||
.await
|
||||
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))?
|
||||
.ok_or_else(err_connection)?;
|
||||
|
||||
match msg {
|
||||
PglbMessage::Control(_) => Err(io::Error::new(
|
||||
io::ErrorKind::Other,
|
||||
"unexpected control message",
|
||||
)),
|
||||
PglbMessage::Postgres(pg) => {
|
||||
let mut buf = BytesMut::from(&*pg);
|
||||
FeStartupPacket::parse(&mut buf)
|
||||
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))?
|
||||
.ok_or_else(err_connection)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn read_message(&mut self) -> io::Result<FeMessage> {
|
||||
let msg = self
|
||||
.try_next()
|
||||
.await
|
||||
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))?
|
||||
.ok_or_else(err_connection)?;
|
||||
|
||||
match msg {
|
||||
PglbMessage::Control(_) => Err(io::Error::new(
|
||||
io::ErrorKind::Other,
|
||||
"unexpected control message",
|
||||
)),
|
||||
PglbMessage::Postgres(pg) => {
|
||||
let mut buf = BytesMut::from(&*pg);
|
||||
FeMessage::parse(&mut buf)
|
||||
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))?
|
||||
.ok_or_else(err_connection)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn read_password_message(&mut self) -> io::Result<bytes::Bytes> {
|
||||
match self.read_message().await? {
|
||||
FeMessage::PasswordMessage(msg) => Ok(msg),
|
||||
bad => Err(io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
format!("unexpected message type: {bad:?}"),
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
async fn read_backend_message<T>(
|
||||
&mut self,
|
||||
f: impl FnOnce(backend::Message) -> anyhow::Result<T>,
|
||||
) -> anyhow::Result<T> {
|
||||
let PglbMessage::Postgres(mut buf) = self.try_next().await?.context("missing")? else {
|
||||
bail!("invalid message");
|
||||
};
|
||||
let message = backend::Message::parse(&mut buf)?.context("missing")?;
|
||||
f(message)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user