Compare commits

...

32 Commits

Author SHA1 Message Date
Conrad Ludgate
8918b1c872 some cleanup 2024-09-13 18:27:07 +01:00
Conrad Ludgate
fe6946e15e hack around race condition 2024-09-13 10:33:32 +01:00
Conrad Ludgate
0924267612 fixup 2024-09-13 10:12:30 +01:00
Folke Behrens
d418cf2dde Merge branch 'cloneable/pglb-compute-passthrough' into pglb 2024-09-13 10:05:06 +01:00
Folke Behrens
2171a9dc2d Add client-compute passthrough 2024-09-13 10:04:20 +01:00
Folke Behrens
33a357cfeb Merge branch 'cloneable/pglb-compute-connect' into pglb 2024-09-13 09:54:20 +01:00
Folke Behrens
e06aa4b91d impl compute connect 2024-09-13 09:53:55 +01:00
Conrad Ludgate
34696381c5 fix pg raw decoding 2024-09-13 09:40:51 +01:00
Conrad Ludgate
24c48856a2 update ssl handling and add some logs 2024-09-13 09:25:11 +01:00
Conrad Ludgate
d698a50984 abstract out auth 2024-09-13 08:44:07 +01:00
Folke Behrens
9131d0463d Merge branch 'cloneable/pglb-msg-codec' into pglb 2024-09-12 23:49:22 +01:00
Folke Behrens
214442519f Add pglb/auth passthrough 2024-09-12 23:47:55 +01:00
Folke Behrens
c4e868819c Merge branch 'cloneable/pglb-type-state-pattern' into pglb 2024-09-12 21:34:06 +01:00
Folke Behrens
8198a503f2 refactor to type state pattern 2024-09-12 21:32:47 +01:00
Conrad Ludgate
76371e8452 add auth handshake to compute 2024-09-12 21:16:27 +01:00
Folke Behrens
3f66c12280 Merge branch 'cloneable/pglb-workers' into pglb 2024-09-12 18:00:37 +01:00
Folke Behrens
411a80b494 Add worker state machine 2024-09-12 17:59:54 +01:00
Conrad Ludgate
cdcb8537f5 delete dead code 2024-09-12 17:57:59 +01:00
Conrad Ludgate
37221f3252 properly handle tls-server-end-point 2024-09-12 17:55:27 +01:00
Conrad Ludgate
f95ddef4e0 call wake compute 2024-09-12 17:42:21 +01:00
Conrad Ludgate
ce200a53e8 build out auth proxy core logic 2024-09-12 17:23:56 +01:00
Conrad Ludgate
91e8b7d22b add new auth proxy backend with new codec 2024-09-12 16:36:44 +01:00
Folke Behrens
f47401f2e9 Merge branch 'cloneable/pglb-tls' into pglb 2024-09-12 15:47:21 +01:00
Folke Behrens
469597fdb6 TLS conn accept 2024-09-12 15:47:05 +01:00
Conrad Ludgate
2af5352708 add auth proxy codec 2024-09-12 15:17:01 +01:00
Conrad Ludgate
fbc37acfdf add auth proxy client connection handling 2024-09-12 12:44:48 +01:00
Folke Behrens
b71bf47c33 Merge branch 'cloneable/pglb-passthrough' into pglb 2024-09-12 12:08:58 +01:00
Folke Behrens
d653d7c62c Frontend TCP listener 2024-09-12 12:06:23 +01:00
Conrad Ludgate
52b73185f9 rename stuff 2024-09-12 11:50:59 +01:00
Conrad Ludgate
dc41d108e8 add conn state with heartbeat system 2024-09-12 11:47:08 +01:00
Conrad Ludgate
02e15b7bbb build server config and endpoint 2024-09-12 11:33:38 +01:00
Conrad Ludgate
864bdf3528 init pglb 2024-09-12 11:13:38 +01:00
20 changed files with 2266 additions and 40 deletions

172
Cargo.lock generated
View File

@@ -919,7 +919,7 @@ version = "0.70.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f49d8fed880d473ea71efb9bf597651e77201bdd4893efe54c9e5d65ae04ce6f"
dependencies = [
"bitflags 2.4.1",
"bitflags 2.6.0",
"cexpr",
"clang-sys",
"itertools 0.12.1",
@@ -928,7 +928,7 @@ dependencies = [
"proc-macro2",
"quote",
"regex",
"rustc-hash",
"rustc-hash 1.1.0",
"shlex",
"syn 2.0.52",
]
@@ -947,9 +947,9 @@ checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a"
[[package]]
name = "bitflags"
version = "2.4.1"
version = "2.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "327762f6e5a765692301e5bb513e0d9fef63be86bbc14528052b1cd3e6f03e07"
checksum = "b048fb63fd8b5923fc5aa7b340d8e156aec7ec02f0c78fa8a6ddc2613f6f71de"
[[package]]
name = "block-buffer"
@@ -1044,6 +1044,12 @@ dependencies = [
"libc",
]
[[package]]
name = "cesu8"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6d43a04d8753f35258c91f8ec639f792891f748a1edbd759cf1dcea3382ad83c"
[[package]]
name = "cexpr"
version = "0.6.0"
@@ -1371,9 +1377,9 @@ dependencies = [
[[package]]
name = "core-foundation"
version = "0.9.3"
version = "0.9.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "194a7a9e6de53fa55116934067c844d9d749312f75c6f6d0980e8c252f8c2146"
checksum = "91e195e091a93c46f7102ec7818a2aa394e1e1771c3ab4825963fa03e45afb8f"
dependencies = [
"core-foundation-sys",
"libc",
@@ -1381,9 +1387,9 @@ dependencies = [
[[package]]
name = "core-foundation-sys"
version = "0.8.4"
version = "0.8.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e496a50fda8aacccc86d7529e2c1e0892dbd0f898a6b5645b5561b89c3210efa"
checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b"
[[package]]
name = "cpufeatures"
@@ -1489,7 +1495,7 @@ version = "0.27.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f476fe445d41c9e991fd07515a6f463074b782242ccf4a5b7b1d1012e70824df"
dependencies = [
"bitflags 2.4.1",
"bitflags 2.6.0",
"crossterm_winapi",
"libc",
"parking_lot 0.12.1",
@@ -1675,7 +1681,7 @@ version = "2.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "65e13bab2796f412722112327f3e575601a3e9cdcbe426f0d30dbf43f3f5dc71"
dependencies = [
"bitflags 2.4.1",
"bitflags 2.6.0",
"byteorder",
"chrono",
"diesel_derives",
@@ -2835,6 +2841,26 @@ version = "1.0.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "453ad9f582a441959e5f0d088b02ce04cfe8d51a8eaf077f12ac6d3e94164ca6"
[[package]]
name = "jni"
version = "0.19.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c6df18c2e3db7e453d3c6ac5b3e9d5182664d28788126d39b91f2d1e22b017ec"
dependencies = [
"cesu8",
"combine",
"jni-sys",
"log",
"thiserror",
"walkdir",
]
[[package]]
name = "jni-sys"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8eaf4bc02d17cbdd7ff4c7438cafcdf7fb9a4613313ad11b4f8fefe7d3fa0130"
[[package]]
name = "jobserver"
version = "0.1.26"
@@ -3058,7 +3084,7 @@ dependencies = [
"measured-derive",
"memchr",
"parking_lot 0.12.1",
"rustc-hash",
"rustc-hash 1.1.0",
"ryu",
]
@@ -3225,7 +3251,7 @@ version = "0.27.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2eb04e9c688eff1c89d72b407f168cf79bb9e867a9d3323ed6c01519eb9cc053"
dependencies = [
"bitflags 2.4.1",
"bitflags 2.6.0",
"cfg-if",
"libc",
"memoffset 0.9.0",
@@ -3247,7 +3273,7 @@ version = "6.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6205bd8bb1e454ad2e27422015fb5e4f2bcc7e08fa8f27058670d208324a4d2d"
dependencies = [
"bitflags 2.4.1",
"bitflags 2.6.0",
"crossbeam-channel",
"filetime",
"fsevent-sys",
@@ -3416,9 +3442,9 @@ dependencies = [
[[package]]
name = "once_cell"
version = "1.18.0"
version = "1.19.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dd8b5dd2ae5ed71462c540258bedcb51965123ad7e7ccf4b9a8cafaa4a63576d"
checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92"
[[package]]
name = "oorandom"
@@ -4302,7 +4328,7 @@ version = "0.16.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "731e0d9356b0c25f16f33b5be79b1c57b562f141ebfcdb0ad8ac2c13a24293b4"
dependencies = [
"bitflags 2.4.1",
"bitflags 2.6.0",
"chrono",
"flate2",
"hex",
@@ -4317,7 +4343,7 @@ version = "0.16.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2d3554923a69f4ce04c4a754260c338f505ce22642d3830e049a399fc2059a29"
dependencies = [
"bitflags 2.4.1",
"bitflags 2.6.0",
"chrono",
"hex",
]
@@ -4455,6 +4481,7 @@ dependencies = [
"postgres_backend",
"pq_proto",
"prometheus",
"quinn",
"rand 0.8.5",
"rand_distr",
"rcgen",
@@ -4468,7 +4495,7 @@ dependencies = [
"routerify",
"rsa",
"rstest",
"rustc-hash",
"rustc-hash 1.1.0",
"rustls 0.22.4",
"rustls-native-certs 0.7.0",
"rustls-pemfile 2.1.1",
@@ -4517,6 +4544,55 @@ dependencies = [
"serde",
]
[[package]]
name = "quinn"
version = "0.11.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8c7c5fdde3cdae7203427dc4f0a68fe0ed09833edc525a03456b153b79828684"
dependencies = [
"bytes",
"pin-project-lite",
"quinn-proto",
"quinn-udp",
"rustc-hash 2.0.0",
"rustls 0.23.7",
"socket2 0.5.5",
"thiserror",
"tokio",
"tracing",
]
[[package]]
name = "quinn-proto"
version = "0.11.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fadfaed2cd7f389d0161bb73eeb07b7b78f8691047a6f3e73caaeae55310a4a6"
dependencies = [
"bytes",
"rand 0.8.5",
"ring 0.17.6",
"rustc-hash 2.0.0",
"rustls 0.23.7",
"rustls-platform-verifier",
"slab",
"thiserror",
"tinyvec",
"tracing",
]
[[package]]
name = "quinn-udp"
version = "0.5.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8bffec3605b73c6f1754535084a85229fa8a30f86014e6c81aeec4abb68b0285"
dependencies = [
"libc",
"once_cell",
"socket2 0.5.5",
"tracing",
"windows-sys 0.52.0",
]
[[package]]
name = "quote"
version = "1.0.35"
@@ -5119,6 +5195,12 @@ version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2"
[[package]]
name = "rustc-hash"
version = "2.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "583034fd73374156e66797ed8e5b0d5690409c9226b22d87cb7f19821c05d152"
[[package]]
name = "rustc_version"
version = "0.4.0"
@@ -5143,7 +5225,7 @@ version = "0.38.28"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "72e572a5e8ca657d7366229cdde4bd14c4eb5499a9573d4d366fe1b599daa316"
dependencies = [
"bitflags 2.4.1",
"bitflags 2.6.0",
"errno",
"libc",
"linux-raw-sys 0.4.13",
@@ -5176,6 +5258,20 @@ dependencies = [
"zeroize",
]
[[package]]
name = "rustls"
version = "0.23.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ebbbdb961df0ad3f2652da8f3fdc4b36122f568f968f45ad3316f26c025c677b"
dependencies = [
"once_cell",
"ring 0.17.6",
"rustls-pki-types",
"rustls-webpki 0.102.2",
"subtle",
"zeroize",
]
[[package]]
name = "rustls-native-certs"
version = "0.6.2"
@@ -5226,6 +5322,33 @@ version = "1.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5ede67b28608b4c60685c7d54122d4400d90f62b40caee7700e700380a390fa8"
[[package]]
name = "rustls-platform-verifier"
version = "0.3.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "afbb878bdfdf63a336a5e63561b1835e7a8c91524f51621db870169eac84b490"
dependencies = [
"core-foundation",
"core-foundation-sys",
"jni",
"log",
"once_cell",
"rustls 0.23.7",
"rustls-native-certs 0.7.0",
"rustls-platform-verifier-android",
"rustls-webpki 0.102.2",
"security-framework",
"security-framework-sys",
"webpki-roots 0.26.1",
"winapi",
]
[[package]]
name = "rustls-platform-verifier-android"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f87165f0995f63a9fbeea62b64d10b4d9d8e78ec6d7d51fb2125fda7bb36788f"
[[package]]
name = "rustls-webpki"
version = "0.100.2"
@@ -5419,22 +5542,23 @@ dependencies = [
[[package]]
name = "security-framework"
version = "2.9.1"
version = "2.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1fc758eb7bffce5b308734e9b0c1468893cae9ff70ebf13e7090be8dcbcc83a8"
checksum = "c627723fd09706bacdb5cf41499e95098555af3c3c29d014dc3c458ef6be11c0"
dependencies = [
"bitflags 1.3.2",
"bitflags 2.6.0",
"core-foundation",
"core-foundation-sys",
"libc",
"num-bigint",
"security-framework-sys",
]
[[package]]
name = "security-framework-sys"
version = "2.9.0"
version = "2.11.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f51d0c0d83bec45f16480d0ce0058397a69e48fcdc52d1dc8855fb68acbd31a7"
checksum = "75da29fe9b9b08fe9d6b22b5b4bcbc75d8db3aa31e639aa56bb62e9d46bfceaf"
dependencies = [
"core-foundation-sys",
"libc",

View File

@@ -112,6 +112,9 @@ ecdsa = "0.16"
p256 = "0.13"
rsa = "0.9"
quinn = { version = "0.11", features = [] }
rcgen.workspace = true
workspace_hack.workspace = true
[dev-dependencies]
@@ -119,7 +122,6 @@ camino-tempfile.workspace = true
fallible-iterator.workspace = true
tokio-tungstenite.workspace = true
pbkdf2 = { workspace = true, features = ["simple", "std"] }
rcgen.workspace = true
rstest.workspace = true
tokio-postgres-rustls.workspace = true
walkdir.workspace = true

View File

@@ -4,9 +4,9 @@ pub mod backend;
pub use backend::Backend;
mod credentials;
pub use credentials::ComputeUserInfoMaybeEndpoint;
pub(crate) use credentials::{
check_peer_addr_is_in_list, endpoint_sni, ComputeUserInfoMaybeEndpoint,
ComputeUserInfoParseError, IpPattern,
check_peer_addr_is_in_list, endpoint_sni, ComputeUserInfoParseError, IpPattern,
};
mod password_hack;
@@ -77,7 +77,7 @@ pub(crate) enum AuthErrorImpl {
#[derive(Debug, Error)]
#[error(transparent)]
pub(crate) struct AuthError(Box<AuthErrorImpl>);
pub struct AuthError(Box<AuthErrorImpl>);
impl AuthError {
pub(crate) fn bad_auth_method(name: impl Into<Box<str>>) -> Self {

View File

@@ -138,7 +138,7 @@ impl<'a, T, D, E> Backend<'a, Result<T, E>, D> {
}
}
pub(crate) struct ComputeCredentials {
pub struct ComputeCredentials {
pub(crate) info: ComputeUserInfo,
pub(crate) keys: ComputeCredentialKeys,
}

View File

@@ -16,7 +16,7 @@ use thiserror::Error;
use tracing::{info, warn};
#[derive(Debug, Error, PartialEq, Eq, Clone)]
pub(crate) enum ComputeUserInfoParseError {
pub enum ComputeUserInfoParseError {
#[error("Parameter '{0}' is missing in startup packet.")]
MissingKey(&'static str),
@@ -51,10 +51,10 @@ impl ReportableError for ComputeUserInfoParseError {
/// Various client credentials which we use for authentication.
/// Note that we don't store any kind of client key or password here.
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) struct ComputeUserInfoMaybeEndpoint {
pub(crate) user: RoleName,
pub(crate) endpoint_id: Option<EndpointId>,
pub(crate) options: NeonOptions,
pub struct ComputeUserInfoMaybeEndpoint {
pub user: RoleName,
pub endpoint_id: Option<EndpointId>,
pub options: NeonOptions,
}
impl ComputeUserInfoMaybeEndpoint {
@@ -83,7 +83,7 @@ pub(crate) fn endpoint_sni(
}
impl ComputeUserInfoMaybeEndpoint {
pub(crate) fn parse(
pub fn parse(
ctx: &RequestMonitoring,
params: &StartupMessageParams,
sni: Option<&str>,

View File

@@ -0,0 +1,230 @@
mod classic;
mod hacks;
use tracing::info;
use crate::auth::backend::{
ComputeCredentialKeys, ComputeCredentials, ComputeUserInfo, ComputeUserInfoNoEndpoint,
};
use crate::auth::{self, ComputeUserInfoMaybeEndpoint};
use crate::auth_proxy::validate_password_and_exchange;
use crate::console::provider::ConsoleBackend;
use crate::console::AuthSecret;
use crate::context::RequestMonitoring;
use crate::intern::EndpointIdInt;
use crate::proxy::connect_compute::ComputeConnectBackend;
use crate::scram;
use crate::stream::AuthProxyStreamExt;
use crate::{
config::AuthenticationConfig,
console::{self, provider::CachedNodeInfo, Api},
};
use super::AuthProxyStream;
/// Alternative to [`std::borrow::Cow`] but doesn't need `T: ToOwned` as we don't need that functionality
pub enum MaybeOwned<'a, T> {
Owned(T),
Borrowed(&'a T),
}
impl<T> std::ops::Deref for MaybeOwned<'_, T> {
type Target = T;
fn deref(&self) -> &Self::Target {
match self {
MaybeOwned::Owned(t) => t,
MaybeOwned::Borrowed(t) => t,
}
}
}
/// This type serves two purposes:
///
/// * When `T` is `()`, it's just a regular auth backend selector
/// which we use in [`crate::config::ProxyConfig`].
///
/// * However, when we substitute `T` with [`ComputeUserInfoMaybeEndpoint`],
/// this helps us provide the credentials only to those auth
/// backends which require them for the authentication process.
pub enum Backend<'a, T> {
/// Cloud API (V2).
Console(MaybeOwned<'a, ConsoleBackend>, T),
}
impl std::fmt::Display for Backend<'_, ()> {
fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Console(api, ()) => match &**api {
ConsoleBackend::Console(endpoint) => {
fmt.debug_tuple("Console").field(&endpoint.url()).finish()
}
#[cfg(any(test, feature = "testing"))]
ConsoleBackend::Postgres(endpoint) => {
fmt.debug_tuple("Postgres").field(&endpoint.url()).finish()
}
#[cfg(test)]
ConsoleBackend::Test(_) => fmt.debug_tuple("Test").finish(),
},
}
}
}
impl<T> Backend<'_, T> {
/// Very similar to [`std::option::Option::as_ref`].
/// This helps us pass structured config to async tasks.
pub fn as_ref(&self) -> Backend<'_, &T> {
match self {
Self::Console(c, x) => Backend::Console(MaybeOwned::Borrowed(c), x),
}
}
}
impl<'a, T> Backend<'a, T> {
/// Very similar to [`std::option::Option::map`].
/// Maps [`Backend<T>`] to [`Backend<R>`] by applying
/// a function to a contained value.
pub fn map<R>(self, f: impl FnOnce(T) -> R) -> Backend<'a, R> {
match self {
Self::Console(c, x) => Backend::Console(c, f(x)),
}
}
}
/// True to its name, this function encapsulates our current auth trade-offs.
/// Here, we choose the appropriate auth flow based on circumstances.
///
/// All authentication flows will emit an AuthenticationOk message if successful.
async fn auth_quirks(
api: &impl console::Api,
user_info: ComputeUserInfoMaybeEndpoint,
client: &mut AuthProxyStream,
config: &'static AuthenticationConfig,
) -> auth::Result<ComputeCredentials> {
// If there's no project so far, that entails that client doesn't
// support SNI or other means of passing the endpoint (project) name.
// We now expect to see a very specific payload in the place of password.
let (info) = match user_info.try_into() {
Err(info) => {
todo!()
// let res = hacks::password_hack_no_authentication(info, client).await?;
// let password = match res.keys {
// ComputeCredentialKeys::Password(p) => p,
// ComputeCredentialKeys::AuthKeys(_) | ComputeCredentialKeys::None => {
// unreachable!("password hack should return a password")
// }
// };
// (res.info, Some(password))
}
Ok(info) => info,
};
dbg!("fetching user's authentication info");
let cached_secret = api
.get_role_secret(&RequestMonitoring::test(), &info)
.await?;
let (cached_entry, secret) = cached_secret.take_value();
let secret = if let Some(secret) = secret {
secret
} else {
// If we don't have an authentication secret, we mock one to
// prevent malicious probing (possible due to missing protocol steps).
// This mocked secret will never lead to successful authentication.
dbg!("authentication info not found, mocking it");
AuthSecret::Scram(scram::ServerSecret::mock(rand::random()))
};
match authenticate_with_secret(secret, info, client, config).await {
Ok(keys) => Ok(keys),
Err(e) => {
if e.is_auth_failed() {
// The password could have been changed, so we invalidate the cache.
cached_entry.invalidate();
}
Err(e)
}
}
}
async fn authenticate_with_secret(
secret: AuthSecret,
info: ComputeUserInfo,
client: &mut AuthProxyStream,
// unauthenticated_password: Option<Vec<u8>>,
config: &'static AuthenticationConfig,
) -> auth::Result<ComputeCredentials> {
// if let Some(password) = unauthenticated_password {
// let ep = EndpointIdInt::from(&info.endpoint);
// let auth_outcome =
// validate_password_and_exchange(&config.thread_pool, ep, &password, secret).await?;
// let keys = match auth_outcome {
// crate::sasl::Outcome::Success(key) => key,
// crate::sasl::Outcome::Failure(reason) => {
// info!("auth backend failed with an error: {reason}");
// return Err(auth::AuthError::auth_failed(&*info.user));
// }
// };
// // we have authenticated the password
// client.write_message_noflush(&pq_proto::BeMessage::AuthenticationOk)?;
// return Ok(ComputeCredentials { info, keys });
// }
// Finally, proceed with the main auth flow (SCRAM-based).
classic::authenticate(info, client, config, secret).await
}
impl<'a> Backend<'a, ComputeUserInfoMaybeEndpoint> {
/// Get username from the credentials.
pub fn get_user(&self) -> &str {
match self {
Self::Console(_, user_info) => &user_info.user,
}
}
pub async fn authenticate(
self,
client: &mut AuthProxyStream,
config: &'static AuthenticationConfig,
) -> auth::Result<Backend<'a, ComputeCredentials>> {
let res = match self {
Self::Console(api, user_info) => {
dbg!("authenticating...");
info!(
user = &*user_info.user,
project = user_info.endpoint(),
"performing authentication using the console"
);
let credentials = auth_quirks(&*api, user_info, client, config).await?;
Backend::Console(api, credentials)
}
};
dbg!("user successfully authenticated");
Ok(res)
}
}
#[async_trait::async_trait]
impl ComputeConnectBackend for Backend<'_, ComputeCredentials> {
async fn wake_compute(
&self,
ctx: &RequestMonitoring,
) -> Result<CachedNodeInfo, console::errors::WakeComputeError> {
match self {
Self::Console(api, creds) => api.wake_compute(ctx, &creds.info).await,
}
}
fn get_keys(&self) -> &ComputeCredentialKeys {
match self {
Self::Console(_, creds) => &creds.keys,
}
}
}

View File

@@ -0,0 +1,69 @@
use super::{ComputeCredentials, ComputeUserInfo};
use crate::{
auth::{self, backend::ComputeCredentialKeys},
auth_proxy::{self, AuthFlow, AuthProxyStream},
compute,
config::AuthenticationConfig,
console::AuthSecret,
sasl,
};
use tracing::{info, warn};
pub(super) async fn authenticate(
creds: ComputeUserInfo,
client: &mut AuthProxyStream,
config: &'static AuthenticationConfig,
secret: AuthSecret,
) -> auth::Result<ComputeCredentials> {
let flow = AuthFlow::new(client);
let scram_keys = match secret {
#[cfg(any(test, feature = "testing"))]
AuthSecret::Md5(_) => {
info!("auth endpoint chooses MD5");
return Err(auth::AuthError::bad_auth_method("MD5"));
}
AuthSecret::Scram(secret) => {
dbg!("auth endpoint chooses SCRAM");
let scram = auth_proxy::Scram(&secret);
let auth_outcome = tokio::time::timeout(
config.scram_protocol_timeout,
async {
flow.begin(scram).await.map_err(|error| {
warn!(?error, "error sending scram acknowledgement");
error
})?.authenticate().await.map_err(|error| {
warn!(?error, "error processing scram messages");
error
})
}
)
.await
.map_err(|e| {
warn!("error processing scram messages error = authentication timed out, execution time exceeded {} seconds", config.scram_protocol_timeout.as_secs());
auth::AuthError::user_timeout(e)
})??;
let client_key = match auth_outcome {
sasl::Outcome::Success(key) => key,
sasl::Outcome::Failure(reason) => {
info!("auth backend failed with an error: {reason}");
return Err(auth::AuthError::auth_failed(&*creds.user));
}
};
compute::ScramKeys {
client_key: client_key.as_bytes(),
server_key: secret.server_key.as_bytes(),
}
}
};
Ok(ComputeCredentials {
info: creds,
keys: ComputeCredentialKeys::AuthKeys(tokio_postgres::config::AuthKeys::ScramSha256(
scram_keys,
)),
})
}

View File

@@ -0,0 +1,36 @@
use super::{
ComputeCredentialKeys, ComputeCredentials, ComputeUserInfo, ComputeUserInfoNoEndpoint,
};
use crate::{
auth,
auth_proxy::{self, AuthFlow, AuthProxyStream},
};
use tracing::{info, warn};
/// Workaround for clients which don't provide an endpoint (project) name.
/// Similar to [`authenticate_cleartext`], but there's a specific password format,
/// and passwords are not yet validated (we don't know how to validate them!)
pub(crate) async fn password_hack_no_authentication(
info: ComputeUserInfoNoEndpoint,
client: &mut AuthProxyStream,
) -> auth::Result<ComputeCredentials> {
warn!("project not specified, resorting to the password hack auth flow");
let payload = AuthFlow::new(client)
.begin(auth_proxy::PasswordHack)
.await?
.get_password()
.await?;
info!(project = &*payload.endpoint, "received missing parameter");
// Report tentative success; compute node will check the password anyway.
Ok(ComputeCredentials {
info: ComputeUserInfo {
user: info.user,
options: info.options,
endpoint: payload.endpoint,
},
keys: ComputeCredentialKeys::Password(payload.password),
})
}

View File

@@ -0,0 +1,183 @@
//! Main authentication flow.
use super::{AuthProxyStream, PasswordHackPayload};
use crate::{
auth::{self, backend::ComputeCredentialKeys, AuthErrorImpl},
config::TlsServerEndPoint,
console::AuthSecret,
intern::EndpointIdInt,
sasl,
scram::{self, threadpool::ThreadPool},
stream::AuthProxyStreamExt,
};
use pq_proto::{BeAuthenticationSaslMessage, BeMessage, BeMessage as Be};
use std::io;
use tokio::task_local;
use tracing::info;
/// Every authentication selector is supposed to implement this trait.
pub(crate) trait AuthMethod {
/// Any authentication selector should provide initial backend message
/// containing auth method name and parameters, e.g. md5 salt.
fn first_message(&self, channel_binding: bool) -> BeMessage<'_>;
}
/// Initial state of [`AuthFlow`].
pub(crate) struct Begin;
/// Use [SCRAM](crate::scram)-based auth in [`AuthFlow`].
pub(crate) struct Scram<'a>(pub(crate) &'a scram::ServerSecret);
impl AuthMethod for Scram<'_> {
#[inline(always)]
fn first_message(&self, channel_binding: bool) -> BeMessage<'_> {
if channel_binding {
Be::AuthenticationSasl(BeAuthenticationSaslMessage::Methods(scram::METHODS))
} else {
Be::AuthenticationSasl(BeAuthenticationSaslMessage::Methods(
scram::METHODS_WITHOUT_PLUS,
))
}
}
}
/// Use an ad hoc auth flow (for clients which don't support SNI) proposed in
/// <https://github.com/neondatabase/cloud/issues/1620#issuecomment-1165332290>.
pub(crate) struct PasswordHack;
impl AuthMethod for PasswordHack {
#[inline(always)]
fn first_message(&self, _channel_binding: bool) -> BeMessage<'_> {
Be::AuthenticationCleartextPassword
}
}
/// This wrapper for [`PqStream`] performs client authentication.
#[must_use]
pub(crate) struct AuthFlow<'a, State> {
/// The underlying stream which implements libpq's protocol.
stream: &'a mut AuthProxyStream,
/// State might contain ancillary data (see [`Self::begin`]).
state: State,
tls_server_end_point: TlsServerEndPoint,
}
task_local! {
pub(crate) static TLS_SERVER_END_POINT: TlsServerEndPoint;
}
/// Initial state of the stream wrapper.
impl<'a> AuthFlow<'a, Begin> {
/// Create a new wrapper for client authentication.
pub(crate) fn new(stream: &'a mut AuthProxyStream) -> Self {
let tls_server_end_point = TLS_SERVER_END_POINT.get();
Self {
stream,
state: Begin,
tls_server_end_point,
}
}
/// Move to the next step by sending auth method's name & params to client.
pub(crate) async fn begin<M: AuthMethod>(self, method: M) -> io::Result<AuthFlow<'a, M>> {
dbg!("sending auth begin message");
self.stream
.write_message(&method.first_message(self.tls_server_end_point.supported()))
.await?;
Ok(AuthFlow {
stream: self.stream,
state: method,
tls_server_end_point: self.tls_server_end_point,
})
}
}
impl AuthFlow<'_, PasswordHack> {
/// Perform user authentication. Raise an error in case authentication failed.
pub(crate) async fn get_password(self) -> auth::Result<PasswordHackPayload> {
let msg = self.stream.read_password_message().await?;
let password = msg
.strip_suffix(&[0])
.ok_or(AuthErrorImpl::MalformedPassword("missing terminator"))?;
let payload = PasswordHackPayload::parse(password)
// If we ended up here and the payload is malformed, it means that
// the user neither enabled SNI nor resorted to any other method
// for passing the project name we rely on. We should show them
// the most helpful error message and point to the documentation.
.ok_or(AuthErrorImpl::MissingEndpointName)?;
Ok(payload)
}
}
/// Stream wrapper for handling [SCRAM](crate::scram) auth.
impl AuthFlow<'_, Scram<'_>> {
/// Perform user authentication. Raise an error in case authentication failed.
pub(crate) async fn authenticate(self) -> auth::Result<sasl::Outcome<scram::ScramKey>> {
let Scram(secret) = self.state;
// Initial client message contains the chosen auth method's name.
let msg = self.stream.read_password_message().await?;
let sasl = sasl::FirstMessage::parse(&msg)
.ok_or(AuthErrorImpl::MalformedPassword("bad sasl message"))?;
// Currently, the only supported SASL method is SCRAM.
if !scram::METHODS.contains(&sasl.method) {
return Err(auth::AuthError::bad_auth_method(sasl.method));
}
info!("client chooses {}", sasl.method);
let outcome = sasl::SaslStream2::new(self.stream, sasl.message)
.authenticate(scram::Exchange::new(
secret,
rand::random,
self.tls_server_end_point,
))
.await?;
if let sasl::Outcome::Success(_) = &outcome {
self.stream.write_message_noflush(&Be::AuthenticationOk)?;
}
Ok(outcome)
}
}
pub(crate) async fn validate_password_and_exchange(
pool: &ThreadPool,
endpoint: EndpointIdInt,
password: &[u8],
secret: AuthSecret,
) -> auth::Result<sasl::Outcome<ComputeCredentialKeys>> {
match secret {
#[cfg(any(test, feature = "testing"))]
AuthSecret::Md5(_) => {
// test only
Ok(sasl::Outcome::Success(ComputeCredentialKeys::Password(
password.to_owned(),
)))
}
// perform scram authentication as both client and server to validate the keys
AuthSecret::Scram(scram_secret) => {
let outcome = crate::scram::exchange(pool, endpoint, &scram_secret, password).await?;
let client_key = match outcome {
sasl::Outcome::Success(client_key) => client_key,
sasl::Outcome::Failure(reason) => return Ok(sasl::Outcome::Failure(reason)),
};
let keys = crate::compute::ScramKeys {
client_key: client_key.as_bytes(),
server_key: scram_secret.server_key.as_bytes(),
};
Ok(sasl::Outcome::Success(ComputeCredentialKeys::AuthKeys(
tokio_postgres::config::AuthKeys::ScramSha256(keys),
)))
}
}
}

View File

@@ -0,0 +1,17 @@
//! Client authentication mechanisms.
pub mod backend;
pub use backend::Backend;
mod password_hack;
use password_hack::PasswordHackPayload;
mod flow;
pub(crate) use flow::*;
use quinn::{RecvStream, SendStream};
use tokio::io::Join;
use tokio_util::codec::Framed;
use crate::PglbCodec;
pub type AuthProxyStream = Framed<Join<RecvStream, SendStream>, PglbCodec>;

View File

@@ -0,0 +1,121 @@
//! Payload for ad hoc authentication method for clients that don't support SNI.
//! See the `impl` for [`super::backend::Backend<ClientCredentials>`].
//! Read more: <https://github.com/neondatabase/cloud/issues/1620#issuecomment-1165332290>.
//! UPDATE (Mon Aug 8 13:20:34 UTC 2022): the payload format has been simplified.
use bstr::ByteSlice;
use crate::EndpointId;
pub(crate) struct PasswordHackPayload {
pub(crate) endpoint: EndpointId,
pub(crate) password: Vec<u8>,
}
impl PasswordHackPayload {
pub(crate) fn parse(bytes: &[u8]) -> Option<Self> {
// The format is `project=<utf-8>;<password-bytes>` or `project=<utf-8>$<password-bytes>`.
let separators = [";", "$"];
for sep in separators {
if let Some((endpoint, password)) = bytes.split_once_str(sep) {
let endpoint = endpoint.to_str().ok()?;
return Some(Self {
endpoint: parse_endpoint_param(endpoint)?.into(),
password: password.to_owned(),
});
}
}
None
}
}
pub(crate) fn parse_endpoint_param(bytes: &str) -> Option<&str> {
bytes
.strip_prefix("project=")
.or_else(|| bytes.strip_prefix("endpoint="))
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn parse_endpoint_param_fn() {
let input = "";
assert!(parse_endpoint_param(input).is_none());
let input = "project=";
assert_eq!(parse_endpoint_param(input), Some(""));
let input = "project=foobar";
assert_eq!(parse_endpoint_param(input), Some("foobar"));
let input = "endpoint=";
assert_eq!(parse_endpoint_param(input), Some(""));
let input = "endpoint=foobar";
assert_eq!(parse_endpoint_param(input), Some("foobar"));
let input = "other_option=foobar";
assert!(parse_endpoint_param(input).is_none());
}
#[test]
fn parse_password_hack_payload_project() {
let bytes = b"";
assert!(PasswordHackPayload::parse(bytes).is_none());
let bytes = b"project=";
assert!(PasswordHackPayload::parse(bytes).is_none());
let bytes = b"project=;";
let payload: PasswordHackPayload =
PasswordHackPayload::parse(bytes).expect("parsing failed");
assert_eq!(payload.endpoint, "");
assert_eq!(payload.password, b"");
let bytes = b"project=foobar;pass;word";
let payload = PasswordHackPayload::parse(bytes).expect("parsing failed");
assert_eq!(payload.endpoint, "foobar");
assert_eq!(payload.password, b"pass;word");
}
#[test]
fn parse_password_hack_payload_endpoint() {
let bytes = b"";
assert!(PasswordHackPayload::parse(bytes).is_none());
let bytes = b"endpoint=";
assert!(PasswordHackPayload::parse(bytes).is_none());
let bytes = b"endpoint=;";
let payload = PasswordHackPayload::parse(bytes).expect("parsing failed");
assert_eq!(payload.endpoint, "");
assert_eq!(payload.password, b"");
let bytes = b"endpoint=foobar;pass;word";
let payload = PasswordHackPayload::parse(bytes).expect("parsing failed");
assert_eq!(payload.endpoint, "foobar");
assert_eq!(payload.password, b"pass;word");
}
#[test]
fn parse_password_hack_payload_dollar() {
let bytes = b"";
assert!(PasswordHackPayload::parse(bytes).is_none());
let bytes = b"endpoint=";
assert!(PasswordHackPayload::parse(bytes).is_none());
let bytes = b"endpoint=$";
let payload = PasswordHackPayload::parse(bytes).expect("parsing failed");
assert_eq!(payload.endpoint, "");
assert_eq!(payload.password, b"");
let bytes = b"endpoint=foobar$pass$word";
let payload = PasswordHackPayload::parse(bytes).expect("parsing failed");
assert_eq!(payload.endpoint, "foobar");
assert_eq!(payload.password, b"pass$word");
}
}

242
proxy/src/bin/auth_proxy.rs Normal file
View File

@@ -0,0 +1,242 @@
use std::{sync::Arc, time::Duration};
use clap::Parser;
use proxy::{
auth::backend::AuthRateLimiter,
auth_proxy::{backend::MaybeOwned, Backend},
config::{self, AuthenticationConfig, CacheOptions, ProjectInfoCacheOptions},
console::{
caches::ApiCaches,
locks::ApiLocks,
provider::{neon::Api, ConsoleBackend},
},
http,
metrics::Metrics,
proxy::{handle_stream, AuthProxyConfig},
rate_limiter::{RateBucketInfo, WakeComputeRateLimiter},
scram::threadpool::ThreadPool,
};
use quinn::{crypto::rustls::QuicClientConfig, rustls::client::danger, Endpoint, VarInt};
use tokio::{
io::AsyncWriteExt,
select,
signal::unix::{signal, SignalKind},
time::interval,
};
use tokio_util::task::TaskTracker;
/// Neon proxy/router
#[derive(Parser)]
#[command(about)]
struct ProxyCliArgs {
/// cloud API endpoint for authenticating users
#[clap(
short,
long,
default_value = "http://localhost:3000/authenticate_proxy_request"
)]
auth_endpoint: String,
/// timeout for the TLS handshake
#[clap(long, default_value = "15s", value_parser = humantime::parse_duration)]
handshake_timeout: tokio::time::Duration,
/// cache for `wake_compute` api method (use `size=0` to disable)
#[clap(long, default_value = config::CacheOptions::CACHE_DEFAULT_OPTIONS)]
wake_compute_cache: String,
/// lock for `wake_compute` api method. example: "shards=32,permits=4,epoch=10m,timeout=1s". (use `permits=0` to disable).
#[clap(long, default_value = config::ConcurrencyLockOptions::DEFAULT_OPTIONS_WAKE_COMPUTE_LOCK)]
wake_compute_lock: String,
/// timeout for scram authentication protocol
#[clap(long, default_value = "15s", value_parser = humantime::parse_duration)]
scram_protocol_timeout: tokio::time::Duration,
/// size of the threadpool for password hashing
#[clap(long, default_value_t = 4)]
scram_thread_pool_size: u8,
/// Disable dynamic rate limiter and store the metrics to ensure its production behaviour.
#[clap(long, default_value_t = true, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)]
disable_dynamic_rate_limiter: bool,
/// Endpoint rate limiter max number of requests per second.
///
/// Provided in the form `<Requests Per Second>@<Bucket Duration Size>`.
/// Can be given multiple times for different bucket sizes.
#[clap(long, default_values_t = RateBucketInfo::DEFAULT_ENDPOINT_SET)]
endpoint_rps_limit: Vec<RateBucketInfo>,
/// Wake compute rate limiter max number of requests per second.
#[clap(long, default_values_t = RateBucketInfo::DEFAULT_SET)]
wake_compute_limit: Vec<RateBucketInfo>,
/// Whether the auth rate limiter actually takes effect (for testing)
#[clap(long, default_value_t = false, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)]
auth_rate_limit_enabled: bool,
/// Authentication rate limiter max number of hashes per second.
#[clap(long, default_values_t = RateBucketInfo::DEFAULT_AUTH_SET)]
auth_rate_limit: Vec<RateBucketInfo>,
/// The IP subnet to use when considering whether two IP addresses are considered the same.
#[clap(long, default_value_t = 64)]
auth_rate_limit_ip_subnet: u8,
/// cache for `allowed_ips` (use `size=0` to disable)
#[clap(long, default_value = config::CacheOptions::CACHE_DEFAULT_OPTIONS)]
allowed_ips_cache: String,
/// cache for `role_secret` (use `size=0` to disable)
#[clap(long, default_value = config::CacheOptions::CACHE_DEFAULT_OPTIONS)]
role_secret_cache: String,
/// cache for `project_info` (use `size=0` to disable)
#[clap(long, default_value = config::ProjectInfoCacheOptions::CACHE_DEFAULT_OPTIONS)]
project_info_cache: String,
/// cache for all valid endpoints
#[clap(long, default_value = config::EndpointCacheConfig::CACHE_DEFAULT_OPTIONS)]
endpoint_cache_config: String,
/// Whether to retry the wake_compute request
#[clap(long, default_value = config::RetryConfig::WAKE_COMPUTE_DEFAULT_VALUES)]
wake_compute_retry: String,
}
#[tokio::main]
async fn main() {
let args = ProxyCliArgs::parse();
let server = "127.0.0.1:5634".parse().unwrap();
let mut endpoint = Endpoint::client("0.0.0.0:0".parse().unwrap()).unwrap();
let crypto = quinn::rustls::ClientConfig::builder()
.dangerous()
.with_custom_certificate_verifier(Arc::new(NoVerify))
.with_no_client_auth();
let crypto = QuicClientConfig::try_from(crypto).unwrap();
let config = quinn::ClientConfig::new(Arc::new(crypto));
endpoint.set_default_client_config(config);
let mut int = signal(SignalKind::interrupt()).unwrap();
let mut term = signal(SignalKind::terminate()).unwrap();
let conn = endpoint.connect(server, "pglb").unwrap().await.unwrap();
let mut interval = interval(Duration::from_secs(2));
let tasks = TaskTracker::new();
let thread_pool = ThreadPool::new(args.scram_thread_pool_size);
Metrics::install(thread_pool.metrics.clone());
let backend = {
let wake_compute_cache_config: CacheOptions = args.wake_compute_cache.parse().unwrap();
let project_info_cache_config: ProjectInfoCacheOptions =
args.project_info_cache.parse().unwrap();
let endpoint_cache_config: config::EndpointCacheConfig =
args.endpoint_cache_config.parse().unwrap();
let caches = Box::leak(Box::new(ApiCaches::new(
wake_compute_cache_config,
project_info_cache_config,
endpoint_cache_config,
)));
let config::ConcurrencyLockOptions {
shards,
limiter,
epoch,
timeout,
} = args.wake_compute_lock.parse().unwrap();
let locks = Box::leak(Box::new(
ApiLocks::new(
"wake_compute_lock",
limiter,
shards,
timeout,
epoch,
&Metrics::get().wake_compute_lock,
)
.unwrap(),
));
tokio::spawn(locks.garbage_collect_worker());
let url = args.auth_endpoint.parse().unwrap();
let endpoint = http::Endpoint::new(url, http::new_client());
let mut wake_compute_rps_limit = args.wake_compute_limit.clone();
RateBucketInfo::validate(&mut wake_compute_rps_limit).unwrap();
let wake_compute_endpoint_rate_limiter =
Arc::new(WakeComputeRateLimiter::new(wake_compute_rps_limit));
let api = Api::new(endpoint, caches, locks, wake_compute_endpoint_rate_limiter);
let api = ConsoleBackend::Console(api);
Backend::Console(MaybeOwned::Owned(api), ())
};
let auth = AuthenticationConfig {
thread_pool,
scram_protocol_timeout: args.scram_protocol_timeout,
rate_limiter_enabled: args.auth_rate_limit_enabled,
rate_limiter: AuthRateLimiter::new(args.auth_rate_limit.clone()),
rate_limit_ip_subnet: args.auth_rate_limit_ip_subnet,
};
let config = &*Box::leak(Box::new(AuthProxyConfig { backend, auth }));
loop {
select! {
_ = int.recv() => break,
_ = term.recv() => break,
_ = interval.tick() => {
let mut stream = conn.open_uni().await.unwrap();
stream.flush().await.unwrap();
stream.finish().unwrap();
}
stream = conn.accept_bi() => {
let (send, recv) = stream.unwrap();
tasks.spawn(async move {
handle_stream(config, send, recv).await.inspect_err(|e| println!("err {e:?}"))
});
}
}
}
// graceful shutdown
{
let mut stream = conn.open_uni().await.unwrap();
stream.write_all(b"shutdown").await.unwrap();
stream.flush().await.unwrap();
stream.finish().unwrap();
}
tasks.close();
tasks.wait().await;
conn.close(VarInt::from_u32(1), b"graceful shutdown");
}
#[derive(Copy, Clone, Debug)]
struct NoVerify;
impl danger::ServerCertVerifier for NoVerify {
fn verify_server_cert(
&self,
_end_entity: &rustls::pki_types::CertificateDer<'_>,
_intermediates: &[rustls::pki_types::CertificateDer<'_>],
_server_name: &rustls::pki_types::ServerName<'_>,
_ocsp_response: &[u8],
_now: rustls::pki_types::UnixTime,
) -> Result<danger::ServerCertVerified, quinn::rustls::Error> {
Ok(danger::ServerCertVerified::assertion())
}
fn verify_tls12_signature(
&self,
_message: &[u8],
_cert: &rustls::pki_types::CertificateDer<'_>,
_dss: &quinn::rustls::DigitallySignedStruct,
) -> Result<danger::HandshakeSignatureValid, quinn::rustls::Error> {
Ok(danger::HandshakeSignatureValid::assertion())
}
fn verify_tls13_signature(
&self,
_message: &[u8],
_cert: &rustls::pki_types::CertificateDer<'_>,
_dss: &quinn::rustls::DigitallySignedStruct,
) -> Result<danger::HandshakeSignatureValid, quinn::rustls::Error> {
Ok(danger::HandshakeSignatureValid::assertion())
}
fn supported_verify_schemes(&self) -> Vec<quinn::rustls::SignatureScheme> {
vec![quinn::rustls::SignatureScheme::ECDSA_NISTP256_SHA256]
}
}

653
proxy/src/bin/pglb.rs Normal file
View File

@@ -0,0 +1,653 @@
use std::{
convert::Infallible,
net::SocketAddr,
sync::{Arc, Mutex},
time::Duration,
};
use anyhow::{anyhow, bail, Context, Result};
use bytes::{Buf, BufMut, BytesMut};
use futures::{SinkExt, StreamExt};
use indexmap::IndexMap;
use itertools::Itertools;
use pq_proto::BeMessage;
use proxy::{
config::{CertResolver, TlsServerEndPoint, PG_ALPN_PROTOCOL},
ConnectionInitiatedPayload, PglbCodec, PglbControlMessage, PglbMessage,
};
use quinn::{Connection, Endpoint, RecvStream, SendStream};
use rand::Rng;
use rustls::pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer};
use tokio::{
io::{copy_bidirectional, join, AsyncReadExt, AsyncWriteExt, Join},
net::{TcpListener, TcpStream},
select,
time::timeout,
};
use tokio_rustls::server::TlsStream;
use tokio_util::codec::Framed;
use tracing::{error, warn};
type AuthConnId = usize;
#[derive(Debug)]
struct AuthConnState {
conns: Mutex<IndexMap<AuthConnId, AuthConn>>,
}
#[derive(Clone, Debug)]
struct AuthConn {
conn: Connection,
// latency info...
}
#[global_allocator]
static GLOBAL: tikv_jemallocator::Jemalloc = tikv_jemallocator::Jemalloc;
#[tokio::main]
async fn main() -> Result<()> {
let _logging_guard = proxy::logging::init().await?;
let auth_endpoint: Endpoint = endpoint_config("0.0.0.0:5634".parse()?).await?;
let auth_connections = Arc::new(AuthConnState {
conns: Mutex::new(IndexMap::new()),
});
let quinn_handle = tokio::spawn(quinn_server(auth_endpoint, auth_connections.clone()));
let frontend_config = frontent_tls_config()?;
let _frontend_handle = tokio::spawn(start_frontend(
"0.0.0.0:5432".parse()?,
frontend_config,
auth_connections.clone(),
));
quinn_handle.await.unwrap();
Ok(())
}
async fn endpoint_config(addr: SocketAddr) -> Result<Endpoint> {
let mut params = rcgen::CertificateParams::new(vec!["pglb".to_string()]);
params
.distinguished_name
.push(rcgen::DnType::CommonName, "pglb");
let key = rcgen::KeyPair::generate(&rcgen::PKCS_ECDSA_P256_SHA256).context("keygen")?;
params.key_pair = Some(key);
let cert = rcgen::Certificate::from_params(params).context("cert")?;
let cert_der = cert.serialize_der().context("serialize")?;
let key_der = cert.serialize_private_key_der();
let cert = CertificateDer::from(cert_der);
let key = PrivateKeyDer::Pkcs8(PrivatePkcs8KeyDer::from(key_der));
let config = quinn::ServerConfig::with_single_cert(vec![cert], key).context("server config")?;
Endpoint::server(config, addr).context("endpoint")
}
async fn quinn_server(ep: Endpoint, state: Arc<AuthConnState>) {
loop {
let incoming = ep.accept().await.expect("quinn server should not crash");
let state = state.clone();
tokio::spawn(async move {
let conn = incoming.await.unwrap();
let conn_id = conn.stable_id();
println!("[{conn_id:?}] new conn");
state
.conns
.lock()
.unwrap()
.insert(conn_id, AuthConn { conn: conn.clone() });
// heartbeat loop
loop {
match timeout(Duration::from_secs(10), conn.accept_uni()).await {
Ok(Ok(mut heartbeat_stream)) => {
let data = heartbeat_stream.read_to_end(128).await.unwrap();
if data.starts_with(b"shutdown") {
println!("[{conn_id:?}] conn shutdown");
break;
}
// else update latency info
}
Ok(Err(conn_err)) => {
println!("[{conn_id:?}] conn err {conn_err:?}");
break;
}
Err(_) => {
println!("[{conn_id:?}] conn timeout err");
break;
}
}
}
state.conns.lock().unwrap().remove(&conn_id);
let conn_closed = conn.closed().await;
println!("[{conn_id:?}] conn closed {conn_closed:?}");
});
}
}
fn frontent_tls_config() -> Result<TlsConfig> {
let (cert, key) = (
rustls_pemfile::certs(&mut &*std::fs::read("proxy.crt").unwrap())
.collect_vec()
.remove(0)
.unwrap(),
PrivateKeyDer::Pkcs8(
rustls_pemfile::pkcs8_private_keys(&mut &*std::fs::read("proxy.key").unwrap())
.collect_vec()
.remove(0)
.unwrap(),
),
);
let config = rustls::ServerConfig::builder()
.with_no_client_auth()
.with_single_cert(vec![cert.clone()], key.clone_key())?
.into();
let mut cert_resolver = CertResolver::new();
cert_resolver.add_cert(key, vec![cert], true)?;
Ok(TlsConfig {
config,
cert_resolver: Arc::new(cert_resolver),
})
}
async fn start_frontend(
addr: SocketAddr,
tls: TlsConfig,
state: Arc<AuthConnState>,
) -> Result<Infallible> {
let listener = TcpListener::bind(addr).await?;
socket2::SockRef::from(&listener).set_keepalive(true)?;
println!("starting");
let connections = tokio_util::task::task_tracker::TaskTracker::new();
loop {
match listener.accept().await {
Ok((socket, client_addr)) => {
println!("accepted");
let conn = PglbConn::new(&state, &tls)?;
connections.spawn(conn.handle(socket, client_addr));
}
Err(e) => {
error!("connection accept error: {e}");
}
}
}
}
#[derive(Clone, Debug)]
struct TlsConfig {
config: Arc<rustls::ServerConfig>,
cert_resolver: Arc<CertResolver>,
}
#[derive(Debug)]
struct PglbConn<S: PglbConnState> {
inner: PglbConnInner,
state: S,
}
#[derive(Debug)]
struct PglbConnInner {
tls_config: TlsConfig,
auth_conns: Arc<AuthConnState>,
}
trait PglbConnState: std::fmt::Debug {}
impl PglbConnState for Start {}
impl PglbConnState for ClientConnect {}
impl PglbConnState for AuthPassthrough {}
impl PglbConnState for ComputeConnect {}
impl PglbConnState for ComputePassthrough {}
impl PglbConnState for End {}
#[derive(Debug)]
struct Start;
impl PglbConn<Start> {
fn new(auth_conns: &Arc<AuthConnState>, tls_config: &TlsConfig) -> Result<Self> {
Ok(PglbConn {
inner: PglbConnInner {
auth_conns: Arc::clone(auth_conns),
tls_config: tls_config.clone(),
},
state: Start,
})
}
async fn handle(
self,
client_stream: TcpStream,
client_addr: SocketAddr,
) -> Result<PglbConn<End>> {
self.handle_start(client_stream, client_addr)
.await?
.handle_client_connect()
.await?
.handle_auth_passthrough()
.await?
.handle_compute_connect()
.await?
.handle_compute_passthrough()
.await
}
}
impl PglbConn<Start> {
async fn handle_start(
self,
mut client_stream: TcpStream,
client_addr: SocketAddr,
) -> Result<PglbConn<ClientConnect>> {
match client_stream.set_nodelay(true) {
Ok(()) => {}
Err(e) => {
bail!("socket option error: {e}");
}
};
// TODO: HAProxy protocol
let tls_requested = match Self::handle_ssl_request_message(&mut client_stream).await {
Ok(tls_requested) => tls_requested,
Err(e) => {
bail!("check_for_ssl_request: {e}");
}
};
let (client_stream, payload) = if tls_requested {
println!("starting tls upgrade");
let mut buf = BytesMut::new();
BeMessage::write(&mut buf, &BeMessage::EncryptionResponse(true)).unwrap();
client_stream.write_all(&buf).await?;
let (stream, tls_server_end_point, server_name) =
match Self::tls_upgrade(client_stream, self.inner.tls_config.clone()).await {
Ok((stream, ep, sn)) => (stream, ep, sn),
Err(e) => {
bail!("tls_upgrade: {e}");
}
};
(
stream,
ConnectionInitiatedPayload {
tls_server_end_point,
server_name,
ip_addr: client_addr.ip(),
},
)
} else {
// TODO: support unsecured connections?
bail!("closing non-TLS connection");
};
println!("tls done");
Ok(PglbConn {
inner: self.inner,
state: ClientConnect {
client_stream,
payload,
},
})
}
async fn handle_ssl_request_message(stream: &mut TcpStream) -> Result<bool> {
println!("checking for ssl request");
let mut buf = vec![0u8; 8];
let n_peek = stream.peek(&mut buf).await?;
if n_peek == 0 {
bail!("EOF");
}
assert_eq!(buf.len(), 8); // TODO: loop, read more
if buf.len() != 8
|| buf[0..4] != 8u32.to_be_bytes()
|| buf[4..8] != 80877103u32.to_be_bytes()
{
return Ok(false);
}
stream.read_exact(&mut buf).await?;
Ok(true)
}
async fn tls_upgrade(
stream: TcpStream,
tls: TlsConfig,
) -> Result<(TlsStream<TcpStream>, TlsServerEndPoint, Option<String>)> {
let tls_stream = tokio_rustls::TlsAcceptor::from(tls.config)
.accept(stream)
.await?;
let conn_info = tls_stream.get_ref().1;
let server_name = conn_info.server_name().map(|s| s.to_string());
match conn_info.alpn_protocol() {
None | Some(PG_ALPN_PROTOCOL) => {}
Some(other) => {
let alpn = String::from_utf8_lossy(other);
warn!(%alpn, "unexpected ALPN");
bail!("protocol violation");
}
}
let (_, tls_server_end_point) = tls
.cert_resolver
.resolve(server_name.as_deref())
.ok_or(anyhow!("missing cert"))?;
Ok((tls_stream, tls_server_end_point, server_name))
}
}
#[derive(Debug)]
struct ClientConnect {
client_stream: TlsStream<TcpStream>,
payload: ConnectionInitiatedPayload,
}
impl PglbConn<ClientConnect> {
async fn handle_client_connect(self) -> Result<PglbConn<AuthPassthrough>> {
let auth_conn = {
let conns = self.inner.auth_conns.conns.lock().unwrap();
if conns.is_empty() {
bail!("no auth proxies avaiable");
}
let mut rng = rand::thread_rng();
conns
.get_index(rng.gen_range(0..conns.len()))
.unwrap()
.1
.clone()
// TODO: check closed?
};
println!("connecting to {}", auth_conn.conn.stable_id());
let (send, recv) = auth_conn.conn.open_bi().await?;
let mut auth_stream = Framed::new(join(recv, send), PglbCodec);
auth_stream
.send(proxy::PglbMessage::Control(
proxy::PglbControlMessage::ConnectionInitiated(self.state.payload),
))
.await?;
Ok(PglbConn {
inner: self.inner,
state: AuthPassthrough {
client_stream: Framed::new(
self.state.client_stream,
PgRawCodec {
start_or_ssl_request: true,
},
),
auth_stream,
},
})
}
}
#[derive(Debug)]
struct AuthPassthrough {
client_stream: Framed<TlsStream<TcpStream>, PgRawCodec>,
auth_stream: Framed<Join<RecvStream, SendStream>, PglbCodec>,
}
impl PglbConn<AuthPassthrough> {
async fn handle_auth_passthrough(self) -> Result<PglbConn<ComputeConnect>> {
let mut client_stream = self.state.client_stream;
let mut auth_stream = self.state.auth_stream;
loop {
select! {
biased;
msg = auth_stream.next() => {
match msg.context("auth proxy disconnected")?? {
PglbMessage::Postgres(payload) => {
println!("msg {payload:?}");
client_stream.send(PgRawMessage::Generic { payload }).await?;
}
PglbMessage::Control(PglbControlMessage::ConnectionInitiated(_)) => {
bail!("auth proxy sent unexpected message");
}
PglbMessage::Control(PglbControlMessage::ConnectToCompute { socket }) => {
println!("socket");
return Ok(PglbConn {
inner: self.inner,
state: ComputeConnect {
client_stream,
auth_stream,
compute_socket:socket,
},
});
}
PglbMessage::Control(PglbControlMessage::ComputeEstablish) => {
bail!("auth proxy sent unexpected message");
}
}
}
msg = client_stream.next() => {
match msg.context("client disconnected")?? {
PgRawMessage::SslRequest => bail!("protocol violation"),
PgRawMessage::Start(payload) | PgRawMessage::Generic { payload } => {
auth_stream.send(proxy::PglbMessage::Postgres(
payload
)).await?;
}
}
}
}
}
}
}
#[derive(Debug)]
struct ComputeConnect {
client_stream: Framed<TlsStream<TcpStream>, PgRawCodec>,
auth_stream: Framed<Join<RecvStream, SendStream>, PglbCodec>,
compute_socket: SocketAddr,
}
impl PglbConn<ComputeConnect> {
async fn handle_compute_connect(self) -> Result<PglbConn<ComputePassthrough>> {
let ComputeConnect {
client_stream,
mut auth_stream,
compute_socket,
} = self.state;
let compute_stream = TcpStream::connect(compute_socket).await?;
compute_stream
.set_nodelay(true)
.context("socket option error")?;
let mut compute_stream = Framed::new(
compute_stream,
PgRawCodec {
start_or_ssl_request: false,
},
);
let mut resps = 4;
loop {
select! {
msg = auth_stream.next() => {
match msg.context("auth proxy disconnected")?? {
PglbMessage::Postgres(payload) => {
println!("msg {payload:?}");
compute_stream.send(PgRawMessage::Generic { payload } ).await?;
}
PglbMessage::Control(PglbControlMessage::ComputeEstablish) => {
println!("establish");
return Ok(PglbConn {
inner: self.inner,
state: ComputePassthrough {
client_stream,
compute_stream,
},
});
}
PglbMessage::Control(PglbControlMessage::ConnectionInitiated(_)) |
PglbMessage::Control(PglbControlMessage::ConnectToCompute { .. }) => {
bail!("auth proxy sent unexpected message");
}
}
}
msg = compute_stream.next(), if resps > 0 => {
match msg.context("compute disconnected")?? {
PgRawMessage::SslRequest => bail!("protocol violation"),
PgRawMessage::Start(payload) | PgRawMessage::Generic { payload } => {
resps -= 1;
auth_stream.send(proxy::PglbMessage::Postgres(
payload
)).await?;
}
}
}
}
}
}
}
#[derive(Debug)]
struct ComputePassthrough {
client_stream: Framed<TlsStream<TcpStream>, PgRawCodec>,
compute_stream: Framed<TcpStream, PgRawCodec>,
}
impl PglbConn<ComputePassthrough> {
async fn handle_compute_passthrough(self) -> Result<PglbConn<End>> {
let ComputePassthrough {
client_stream,
compute_stream,
} = self.state;
let mut client_parts = client_stream.into_parts();
let mut compute_parts = compute_stream.into_parts();
assert!(compute_parts.write_buf.is_empty());
assert!(client_parts.write_buf.is_empty());
client_parts.io.write_all(&compute_parts.read_buf).await?;
compute_parts.io.write_all(&client_parts.read_buf).await?;
drop(client_parts.read_buf);
drop(client_parts.write_buf);
drop(compute_parts.read_buf);
drop(compute_parts.write_buf);
copy_bidirectional(&mut client_parts.io, &mut compute_parts.io).await?;
Ok(PglbConn {
inner: self.inner,
state: End,
})
}
}
#[derive(Debug)]
struct End;
#[derive(Debug)]
struct PgRawCodec {
start_or_ssl_request: bool,
}
impl tokio_util::codec::Encoder<PgRawMessage> for PgRawCodec {
type Error = anyhow::Error;
fn encode(&mut self, item: PgRawMessage, dst: &mut bytes::BytesMut) -> Result<(), Self::Error> {
item.encode(dst)
}
}
impl tokio_util::codec::Decoder for PgRawCodec {
type Item = PgRawMessage;
type Error = anyhow::Error;
fn decode(&mut self, src: &mut bytes::BytesMut) -> Result<Option<Self::Item>, Self::Error> {
if self.start_or_ssl_request {
match PgRawMessage::decode(src, true)? {
msg @ Some(PgRawMessage::Start(..)) => {
self.start_or_ssl_request = false;
Ok(msg)
}
msg @ Some(PgRawMessage::SslRequest) => Ok(msg),
Some(PgRawMessage::Generic { .. }) => unreachable!(),
None => Ok(None),
}
} else {
match PgRawMessage::decode(src, false)? {
Some(PgRawMessage::Start(..)) => unreachable!(),
Some(PgRawMessage::SslRequest) => unreachable!(),
msg @ Some(PgRawMessage::Generic { .. }) => Ok(msg),
None => Ok(None),
}
}
}
}
#[derive(Debug)]
pub enum PgRawMessage {
SslRequest,
Start(BytesMut),
Generic { payload: BytesMut },
}
impl PgRawMessage {
fn encode(&self, dst: &mut bytes::BytesMut) -> Result<()> {
match self {
Self::SslRequest => {
dst.put_u32(8);
dst.put_u32(80877103);
}
Self::Start(payload) | Self::Generic { payload } => {
dst.put_slice(payload);
}
}
Ok(())
}
fn decode(src: &mut bytes::BytesMut, start: bool) -> Result<Option<Self>> {
let extra = if start { 0 } else { 1 };
if src.remaining() < 4 + extra {
src.reserve(4 + extra);
return Ok(None);
}
let length = u32::from_be_bytes(src[extra..4 + extra].try_into().unwrap()) as usize + extra;
if src.remaining() < length {
src.reserve(length - src.remaining());
return Ok(None);
}
if start && length == 8 && src[4..8] == 80877103u32.to_be_bytes() {
Ok(Some(PgRawMessage::SslRequest))
} else if start {
Ok(Some(PgRawMessage::Start(src.split_to(length))))
} else {
Ok(Some(PgRawMessage::Generic {
payload: src.split_to(length),
}))
}
}
}

View File

@@ -13,6 +13,7 @@ use rustls::{
crypto::ring::sign,
pki_types::{CertificateDer, PrivateKeyDer},
};
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use std::{
collections::{HashMap, HashSet},
@@ -149,7 +150,7 @@ pub fn configure_tls(
/// uses multiple hash functions, then this channel binding type's
/// channel bindings are undefined at this time (updates to is channel
/// binding type may occur to address this issue if it ever arises).
#[derive(Debug, Clone, Copy)]
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum TlsServerEndPoint {
Sha256([u8; 32]),
Undefined,

View File

@@ -125,7 +125,7 @@ impl RequestMonitoring {
Self(TryLock::new(inner))
}
#[cfg(test)]
// #[cfg(test)]
pub(crate) fn test() -> Self {
RequestMonitoring::new(Uuid::now_v7(), [127, 0, 0, 1].into(), Protocol::Tcp, "test")
}

View File

@@ -82,15 +82,23 @@
impl_trait_overcaptures,
)]
use std::{convert::Infallible, future::Future};
use std::{
convert::Infallible,
future::Future,
net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
};
use anyhow::{bail, Context};
use bytes::{Buf, BufMut};
use config::TlsServerEndPoint;
use intern::{EndpointIdInt, EndpointIdTag, InternId};
use serde::{Deserialize, Serialize};
use tokio::task::JoinError;
use tokio_util::sync::CancellationToken;
use tracing::warn;
pub mod auth;
pub mod auth_proxy;
pub mod cache;
pub mod cancellation;
pub mod compute;
@@ -274,3 +282,132 @@ impl EndpointId {
ProjectId(self.0.clone())
}
}
#[derive(Debug)]
pub struct PglbCodec;
impl tokio_util::codec::Encoder<PglbMessage> for PglbCodec {
type Error = anyhow::Error;
fn encode(&mut self, item: PglbMessage, dst: &mut bytes::BytesMut) -> Result<(), Self::Error> {
match item {
PglbMessage::Control(ctrl) => {
dst.put_u8(1);
match ctrl {
PglbControlMessage::ConnectionInitiated(msg) => {
let encode = serde_json::to_string(&msg).context("ser")?;
dst.put_u32(1 + encode.len() as u32);
dst.put_u8(0);
dst.put(encode.as_bytes());
}
PglbControlMessage::ConnectToCompute { socket } => match socket {
SocketAddr::V4(v4) => {
dst.put_u32(1 + 4 + 2);
dst.put_u8(1);
dst.put_u32(v4.ip().to_bits());
dst.put_u16(v4.port());
}
SocketAddr::V6(v6) => {
dst.put_u32(1 + 16 + 2);
dst.put_u8(1);
dst.put_u128(v6.ip().to_bits());
dst.put_u16(v6.port());
}
},
PglbControlMessage::ComputeEstablish => {
dst.put_u32(1);
dst.put_u8(2);
}
}
}
PglbMessage::Postgres(pg) => {
dst.put_u8(0);
dst.put_u32(pg.len() as u32);
dst.put(pg);
}
}
Ok(())
}
}
impl tokio_util::codec::Decoder for PglbCodec {
type Item = PglbMessage;
type Error = anyhow::Error;
fn decode(&mut self, dst: &mut bytes::BytesMut) -> Result<Option<Self::Item>, Self::Error> {
if dst.remaining() < 5 {
dst.reserve(5);
return Ok(None);
}
let msg = dst[0];
let len = u32::from_be_bytes(dst[1..5].try_into().unwrap()) as usize;
if len + 5 > dst.remaining() {
dst.reserve(len + 5);
return Ok(None);
}
dst.advance(5);
let mut payload = dst.split_to(len);
match msg {
// postgres
0 => Ok(Some(PglbMessage::Postgres(payload))),
// control
1 => {
if payload.is_empty() {
bail!("invalid ctrl message")
}
let ctrl_msg = payload.split_to(1)[0];
let ctrl_msg = match ctrl_msg {
0 => PglbControlMessage::ConnectionInitiated(
serde_json::from_slice(&payload).context("deser")?,
),
// ipv4 socket
1 if len == 7 => PglbControlMessage::ConnectToCompute {
socket: SocketAddr::new(
IpAddr::V4(Ipv4Addr::from_bits(payload.get_u32())),
payload.get_u16(),
),
},
// ipv6 socket
1 if len == 19 => PglbControlMessage::ConnectToCompute {
socket: SocketAddr::new(
IpAddr::V6(Ipv6Addr::from_bits(payload.get_u128())),
payload.get_u16(),
),
},
2 if len == 1 => PglbControlMessage::ComputeEstablish,
_ => bail!("invalid ctrl message"),
};
Ok(Some(PglbMessage::Control(ctrl_msg)))
}
_ => bail!("invalid message"),
}
}
}
pub enum PglbMessage {
Control(PglbControlMessage),
Postgres(bytes::BytesMut),
}
pub enum PglbControlMessage {
// from pglb to auth proxy
ConnectionInitiated(ConnectionInitiatedPayload),
// from auth proxy to pglb
ConnectToCompute { socket: SocketAddr },
// from auth proxy to pglb
ComputeEstablish,
}
#[derive(Serialize, Deserialize, Debug)]
pub struct ConnectionInitiatedPayload {
pub tls_server_end_point: TlsServerEndPoint,
pub server_name: Option<String>,
pub ip_addr: IpAddr,
}

View File

@@ -7,9 +7,35 @@ pub(crate) mod handshake;
pub(crate) mod passthrough;
pub(crate) mod retry;
pub(crate) mod wake_compute;
use anyhow::bail;
use anyhow::Context;
use bytes::BytesMut;
use connect_compute::ComputeConnectBackend;
pub use copy_bidirectional::copy_bidirectional_client_compute;
pub use copy_bidirectional::ErrorSource;
use futures::SinkExt;
use futures::TryStreamExt;
use postgres_protocol::authentication::sasl;
use postgres_protocol::authentication::sasl::ChannelBinding;
use postgres_protocol::authentication::sasl::ScramSha256;
use postgres_protocol::message::backend::Message;
use postgres_protocol::message::frontend;
use pq_proto::FeStartupPacket;
use quinn::RecvStream;
use quinn::SendStream;
use tokio::io::join;
use tokio_postgres::config::AuthKeys;
use tokio_util::codec::Framed;
use crate::auth::backend::ComputeCredentialKeys;
use crate::auth::backend::ComputeCredentials;
use crate::auth_proxy::AuthProxyStream;
use crate::auth_proxy::TLS_SERVER_END_POINT;
use crate::console::NodeInfo;
use crate::stream::AuthProxyStreamExt;
use crate::ConnectionInitiatedPayload;
use crate::PglbControlMessage;
use crate::PglbMessage;
use crate::{
auth,
cancellation::{self, CancellationHandlerMain, CancellationHandlerMainInternal},
@@ -30,6 +56,8 @@ use once_cell::sync::OnceCell;
use pq_proto::{BeMessage as Be, StartupMessageParams};
use regex::Regex;
use smol_str::{format_smolstr, SmolStr};
use std::net::IpAddr;
use std::net::SocketAddr;
use std::sync::Arc;
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
@@ -377,10 +405,10 @@ async fn prepare_client_connection<P>(
}
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub(crate) struct NeonOptions(Vec<(SmolStr, SmolStr)>);
pub struct NeonOptions(Vec<(SmolStr, SmolStr)>);
impl NeonOptions {
pub(crate) fn parse_params(params: &StartupMessageParams) -> Self {
pub fn parse_params(params: &StartupMessageParams) -> Self {
params
.options_raw()
.map(Self::parse_from_iter)
@@ -431,3 +459,157 @@ pub(crate) fn neon_option(bytes: &str) -> Option<(&str, &str)> {
let (_, [k, v]) = cap.extract();
Some((k, v))
}
pub struct AuthProxyConfig {
pub backend: crate::auth_proxy::Backend<'static, ()>,
pub auth: crate::config::AuthenticationConfig,
}
pub async fn handle_stream(
config: &'static AuthProxyConfig,
send: SendStream,
recv: RecvStream,
) -> anyhow::Result<()> {
let mut stream: AuthProxyStream = Framed::new(join(recv, send), crate::PglbCodec);
// recv connection metadata
let first_msg = stream.try_next().await?;
let Some(PglbMessage::Control(PglbControlMessage::ConnectionInitiated(conn_info))) = first_msg
else {
panic!("invalid first msg")
};
println!("new conn: {conn_info:?}");
// read startup packet
let startup = stream.read_startup_packet().await?;
let FeStartupPacket::StartupMessage { version: _, params } = startup else {
panic!("invalid startup message")
};
println!("params: {params:?}");
let user_info = auth_with_user(&mut stream, config, &conn_info, &params).await?;
println!("authenticated");
// wake the compute
let node_info = user_info.wake_compute(&RequestMonitoring::test()).await?;
println!("woke compute");
let addr: IpAddr = node_info.config.get_host()?.parse()?;
let socket = SocketAddr::new(addr, node_info.config.get_ports()[0]);
// tell pglb that the compute is up
stream
.send(PglbMessage::Control(PglbControlMessage::ConnectToCompute {
socket,
}))
.await?;
// send startup message to compute
let mut buf = BytesMut::new();
frontend::startup_message(params.iter(), &mut buf)?;
stream.send(PglbMessage::Postgres(buf.split())).await?;
auth_with_compute(&mut stream, user_info.get_keys()).await?;
stream
.send(PglbMessage::Control(PglbControlMessage::ComputeEstablish))
.await?;
Ok(())
}
async fn auth_with_user(
stream: &mut AuthProxyStream,
config: &'static AuthProxyConfig,
conn_info: &ConnectionInitiatedPayload,
params: &StartupMessageParams,
) -> anyhow::Result<crate::auth_proxy::Backend<'static, ComputeCredentials>> {
dbg!("auth...");
// Extract credentials which we're going to use for auth.
let user_info = auth::ComputeUserInfoMaybeEndpoint {
user: params.get("user").context("missing user")?.into(),
endpoint_id: conn_info
.server_name
.as_deref()
.map(|h| h.split_once('.').map_or(h, |(ep, _)| ep).into()),
options: NeonOptions::parse_params(params),
};
dbg!("parsed used info");
// authenticate the user
let user_info = config.backend.as_ref().map(|()| user_info);
let res = TLS_SERVER_END_POINT
.scope(
conn_info.tls_server_end_point,
user_info.authenticate(stream, &config.auth),
)
.await;
let user_info = match res {
Ok(auth_result) => auth_result,
Err(e) => {
return stream.throw_error(e).await?;
}
};
Ok(user_info)
}
async fn auth_with_compute(
stream: &mut AuthProxyStream,
keys: &ComputeCredentialKeys,
) -> anyhow::Result<()> {
let ComputeCredentialKeys::AuthKeys(AuthKeys::ScramSha256(scram_keys)) = keys else {
bail!("missing keys");
};
// compute offers sasl
stream
.read_backend_message(|m| match m {
Message::AuthenticationSasl(_body) => Ok(()),
_ => bail!("invalid message"),
})
.await?;
let mut buf = BytesMut::new();
// send auth message
let mut scram = ScramSha256::new_with_keys(*scram_keys, ChannelBinding::unsupported());
frontend::sasl_initial_response(sasl::SCRAM_SHA_256, scram.message(), &mut buf)?;
stream.send(PglbMessage::Postgres(buf.split())).await?;
let cont_body = stream
.read_backend_message(|m| match m {
Message::AuthenticationSaslContinue(body) => Ok(body),
_ => bail!("invalid message"),
})
.await?;
scram.update(cont_body.data()).await?;
frontend::sasl_response(scram.message(), &mut buf)?;
stream.send(PglbMessage::Postgres(buf.split())).await?;
let final_body = stream
.read_backend_message(|m| match m {
Message::AuthenticationSaslFinal(body) => Ok(body),
_ => bail!("invalid message"),
})
.await?;
scram.finish(final_body.data())?;
// wait for ok.
stream
.read_backend_message(|m| match m {
Message::AuthenticationOk => Ok(()),
_ => bail!("invalid message"),
})
.await?;
Ok(())
}

View File

@@ -9,6 +9,7 @@
mod channel_binding;
mod messages;
mod stream;
mod stream2;
use crate::error::{ReportableError, UserFacingError};
use std::io;
@@ -17,6 +18,7 @@ use thiserror::Error;
pub(crate) use channel_binding::ChannelBinding;
pub(crate) use messages::FirstMessage;
pub(crate) use stream::{Outcome, SaslStream};
pub(crate) use stream2::SaslStream2;
/// Fine-grained auth errors help in writing tests.
#[derive(Error, Debug)]

85
proxy/src/sasl/stream2.rs Normal file
View File

@@ -0,0 +1,85 @@
//! Abstraction for the string-oriented SASL protocols.
use crate::{
auth_proxy::AuthProxyStream,
sasl::{messages::ServerMessage, Mechanism},
stream::AuthProxyStreamExt,
};
use std::io;
use tracing::info;
use super::Outcome;
/// Abstracts away all peculiarities of the libpq's protocol.
pub(crate) struct SaslStream2<'a> {
/// The underlying stream.
stream: &'a mut AuthProxyStream,
/// Current password message we received from client.
current: bytes::Bytes,
/// First SASL message produced by client.
first: Option<&'a str>,
}
impl<'a> SaslStream2<'a> {
pub(crate) fn new(stream: &'a mut AuthProxyStream, first: &'a str) -> Self {
Self {
stream,
current: bytes::Bytes::new(),
first: Some(first),
}
}
}
impl SaslStream2<'_> {
// Receive a new SASL message from the client.
async fn recv(&mut self) -> io::Result<&str> {
if let Some(first) = self.first.take() {
return Ok(first);
}
self.current = self.stream.read_password_message().await?;
let s = std::str::from_utf8(&self.current)
.map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "bad encoding"))?;
Ok(s)
}
}
impl SaslStream2<'_> {
// Send a SASL message to the client.
async fn send(&mut self, msg: &ServerMessage<&str>) -> io::Result<()> {
self.stream.write_message(&msg.to_reply()).await?;
Ok(())
}
}
impl SaslStream2<'_> {
/// Perform SASL message exchange according to the underlying algorithm
/// until user is either authenticated or denied access.
pub(crate) async fn authenticate<M: Mechanism>(
mut self,
mut mechanism: M,
) -> crate::sasl::Result<Outcome<M::Output>> {
loop {
let input = self.recv().await?;
let step = mechanism.exchange(input).map_err(|error| {
info!(?error, "error during SASL exchange");
error
})?;
use crate::sasl::Step;
return Ok(match step {
Step::Continue(moved_mechanism, reply) => {
self.send(&ServerMessage::Continue(&reply)).await?;
mechanism = moved_mechanism;
continue;
}
Step::Success(result, reply) => {
self.send(&ServerMessage::Final(&reply)).await?;
Outcome::Success(result)
}
Step::Failure(reason) => Outcome::Failure(reason),
});
}
}
}

View File

@@ -1,8 +1,13 @@
use crate::auth_proxy::AuthProxyStream;
use crate::config::TlsServerEndPoint;
use crate::error::{ErrorKind, ReportableError, UserFacingError};
use crate::metrics::Metrics;
use crate::PglbMessage;
use anyhow::{bail, Context};
use bytes::BytesMut;
use futures::{SinkExt, TryStreamExt};
use postgres_protocol::message::backend;
use pq_proto::framed::{ConnectionError, Framed};
use pq_proto::{BeMessage, FeMessage, FeStartupPacket, ProtocolError};
use rustls::ServerConfig;
@@ -294,3 +299,140 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite for Stream<S> {
}
}
}
#[allow(async_fn_in_trait)]
pub(crate) trait AuthProxyStreamExt {
/// Write the message into an internal buffer, but don't flush the underlying stream.
fn write_message_noflush(&mut self, message: &BeMessage<'_>) -> io::Result<&mut Self>;
/// Write the message into an internal buffer and flush it.
async fn write_message(&mut self, message: &BeMessage<'_>) -> io::Result<&mut Self>;
/// Write the error message using [`Self::write_message`], then re-throw it.
/// Trait [`UserFacingError`] acts as an allowlist for error types.
async fn throw_error<T, E>(&mut self, error: E) -> Result<T, ReportedError>
where
E: UserFacingError + Into<anyhow::Error>;
/// Receive [`FeStartupPacket`], which is a first packet sent by a client.
async fn read_startup_packet(&mut self) -> io::Result<FeStartupPacket>;
async fn read_message(&mut self) -> io::Result<FeMessage>;
async fn read_password_message(&mut self) -> io::Result<bytes::Bytes>;
async fn read_backend_message<T>(
&mut self,
f: impl FnOnce(backend::Message) -> anyhow::Result<T>,
) -> anyhow::Result<T>;
}
impl AuthProxyStreamExt for AuthProxyStream {
/// Write the message into an internal buffer, but don't flush the underlying stream.
fn write_message_noflush(&mut self, message: &BeMessage<'_>) -> io::Result<&mut Self> {
let mut b = BytesMut::new();
BeMessage::write(&mut b, message).map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
self.start_send_unpin(PglbMessage::Postgres(b))
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
Ok(self)
}
/// Write the message into an internal buffer and flush it.
async fn write_message(&mut self, message: &BeMessage<'_>) -> io::Result<&mut Self> {
self.write_message_noflush(message)?;
self.flush()
.await
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
Ok(self)
}
/// Write the error message using [`Self::write_message`], then re-throw it.
/// Trait [`UserFacingError`] acts as an allowlist for error types.
async fn throw_error<T, E>(&mut self, error: E) -> Result<T, ReportedError>
where
E: UserFacingError + Into<anyhow::Error>,
{
let error_kind = error.get_error_kind();
let msg = error.to_string_client();
tracing::info!(
kind=error_kind.to_metric_label(),
error=%error,
msg,
"forwarding error to user"
);
// already error case, ignore client IO error
self.write_message(&BeMessage::ErrorResponse(&msg, None))
.await
.inspect_err(|e| debug!("write_message failed: {e}"))
.ok();
Err(ReportedError {
source: anyhow::anyhow!(error),
error_kind,
})
}
/// Receive [`FeStartupPacket`], which is a first packet sent by a client.
async fn read_startup_packet(&mut self) -> io::Result<FeStartupPacket> {
let msg = self
.try_next()
.await
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))?
.ok_or_else(err_connection)?;
match msg {
PglbMessage::Control(_) => Err(io::Error::new(
io::ErrorKind::Other,
"unexpected control message",
)),
PglbMessage::Postgres(pg) => {
let mut buf = BytesMut::from(&*pg);
FeStartupPacket::parse(&mut buf)
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))?
.ok_or_else(err_connection)
}
}
}
async fn read_message(&mut self) -> io::Result<FeMessage> {
let msg = self
.try_next()
.await
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))?
.ok_or_else(err_connection)?;
match msg {
PglbMessage::Control(_) => Err(io::Error::new(
io::ErrorKind::Other,
"unexpected control message",
)),
PglbMessage::Postgres(pg) => {
let mut buf = BytesMut::from(&*pg);
FeMessage::parse(&mut buf)
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))?
.ok_or_else(err_connection)
}
}
}
async fn read_password_message(&mut self) -> io::Result<bytes::Bytes> {
match self.read_message().await? {
FeMessage::PasswordMessage(msg) => Ok(msg),
bad => Err(io::Error::new(
io::ErrorKind::InvalidData,
format!("unexpected message type: {bad:?}"),
)),
}
}
async fn read_backend_message<T>(
&mut self,
f: impl FnOnce(backend::Message) -> anyhow::Result<T>,
) -> anyhow::Result<T> {
let PglbMessage::Postgres(mut buf) = self.try_next().await?.context("missing")? else {
bail!("invalid message");
};
let message = backend::Message::parse(&mut buf)?.context("missing")?;
f(message)
}
}