Compare commits

..

3 Commits

Author SHA1 Message Date
Alexander Bayandin
05a8ec269a wip 2024-09-12 21:12:44 +01:00
Stefan Radig
fcab61bdcd Prototype implementation for private access poc (#8976)
## Problem
For the Private Access POC we want users to be able to disable access
from the public proxy. To limit the number of changes, this can be done
by configuring an IP allowlist [ "255.255.255.255" ]. For the Private
Access proxy, a new command-line flag allows disabling the IP allowlist
completely.

See
https://www.notion.so/neondatabase/Neon-Private-Access-POC-Proposal-8f707754e1ab4190ad5709da7832f020?d=887495c15e884aa4973f973a8a0a582a#7ac6ec249b524a74adbeddc4b84b8f5f
for details about the POC.

## Summary of changes
- Adding the command-line flag is_private_access_proxy=true will disable the
IP allowlist
2024-09-12 15:55:12 +01:00
Tristan Partin
9e3ead3689 Collect the last of on-demand WAL download in CreateReplicationSlot reverts
Signed-off-by: Tristan Partin <tristan@neon.tech>
2024-09-12 11:31:38 +01:00
28 changed files with 95 additions and 2272 deletions

View File

@@ -0,0 +1 @@
FROM neondatabase/build-tools:pinned

View File

@@ -0,0 +1,23 @@
// https://containers.dev/implementors/json_reference/
{
"name": "Neon",
"build": {
"context": "..",
"dockerfile": "Dockerfile.devcontainer"
},
"postCreateCommand": {
"build neon": "BUILD_TYPE=debug CARGO_BUILD_FLAGS='--features=testing' mold -run make -s -j`nproc`",
"install python deps": "./scripts/pysync"
},
"customizations": {
"vscode": {
"extensions": [
"charliermarsh.ruff",
"github.vscode-github-actions",
"rust-lang.rust-analyzer"
]
}
}
}

172
Cargo.lock generated
View File

@@ -919,7 +919,7 @@ version = "0.70.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f49d8fed880d473ea71efb9bf597651e77201bdd4893efe54c9e5d65ae04ce6f"
dependencies = [
"bitflags 2.6.0",
"bitflags 2.4.1",
"cexpr",
"clang-sys",
"itertools 0.12.1",
@@ -928,7 +928,7 @@ dependencies = [
"proc-macro2",
"quote",
"regex",
"rustc-hash 1.1.0",
"rustc-hash",
"shlex",
"syn 2.0.52",
]
@@ -947,9 +947,9 @@ checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a"
[[package]]
name = "bitflags"
version = "2.6.0"
version = "2.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b048fb63fd8b5923fc5aa7b340d8e156aec7ec02f0c78fa8a6ddc2613f6f71de"
checksum = "327762f6e5a765692301e5bb513e0d9fef63be86bbc14528052b1cd3e6f03e07"
[[package]]
name = "block-buffer"
@@ -1044,12 +1044,6 @@ dependencies = [
"libc",
]
[[package]]
name = "cesu8"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6d43a04d8753f35258c91f8ec639f792891f748a1edbd759cf1dcea3382ad83c"
[[package]]
name = "cexpr"
version = "0.6.0"
@@ -1377,9 +1371,9 @@ dependencies = [
[[package]]
name = "core-foundation"
version = "0.9.4"
version = "0.9.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "91e195e091a93c46f7102ec7818a2aa394e1e1771c3ab4825963fa03e45afb8f"
checksum = "194a7a9e6de53fa55116934067c844d9d749312f75c6f6d0980e8c252f8c2146"
dependencies = [
"core-foundation-sys",
"libc",
@@ -1387,9 +1381,9 @@ dependencies = [
[[package]]
name = "core-foundation-sys"
version = "0.8.7"
version = "0.8.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b"
checksum = "e496a50fda8aacccc86d7529e2c1e0892dbd0f898a6b5645b5561b89c3210efa"
[[package]]
name = "cpufeatures"
@@ -1495,7 +1489,7 @@ version = "0.27.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f476fe445d41c9e991fd07515a6f463074b782242ccf4a5b7b1d1012e70824df"
dependencies = [
"bitflags 2.6.0",
"bitflags 2.4.1",
"crossterm_winapi",
"libc",
"parking_lot 0.12.1",
@@ -1681,7 +1675,7 @@ version = "2.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "65e13bab2796f412722112327f3e575601a3e9cdcbe426f0d30dbf43f3f5dc71"
dependencies = [
"bitflags 2.6.0",
"bitflags 2.4.1",
"byteorder",
"chrono",
"diesel_derives",
@@ -2841,26 +2835,6 @@ version = "1.0.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "453ad9f582a441959e5f0d088b02ce04cfe8d51a8eaf077f12ac6d3e94164ca6"
[[package]]
name = "jni"
version = "0.19.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c6df18c2e3db7e453d3c6ac5b3e9d5182664d28788126d39b91f2d1e22b017ec"
dependencies = [
"cesu8",
"combine",
"jni-sys",
"log",
"thiserror",
"walkdir",
]
[[package]]
name = "jni-sys"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8eaf4bc02d17cbdd7ff4c7438cafcdf7fb9a4613313ad11b4f8fefe7d3fa0130"
[[package]]
name = "jobserver"
version = "0.1.26"
@@ -3084,7 +3058,7 @@ dependencies = [
"measured-derive",
"memchr",
"parking_lot 0.12.1",
"rustc-hash 1.1.0",
"rustc-hash",
"ryu",
]
@@ -3251,7 +3225,7 @@ version = "0.27.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2eb04e9c688eff1c89d72b407f168cf79bb9e867a9d3323ed6c01519eb9cc053"
dependencies = [
"bitflags 2.6.0",
"bitflags 2.4.1",
"cfg-if",
"libc",
"memoffset 0.9.0",
@@ -3273,7 +3247,7 @@ version = "6.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6205bd8bb1e454ad2e27422015fb5e4f2bcc7e08fa8f27058670d208324a4d2d"
dependencies = [
"bitflags 2.6.0",
"bitflags 2.4.1",
"crossbeam-channel",
"filetime",
"fsevent-sys",
@@ -3442,9 +3416,9 @@ dependencies = [
[[package]]
name = "once_cell"
version = "1.19.0"
version = "1.18.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92"
checksum = "dd8b5dd2ae5ed71462c540258bedcb51965123ad7e7ccf4b9a8cafaa4a63576d"
[[package]]
name = "oorandom"
@@ -4328,7 +4302,7 @@ version = "0.16.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "731e0d9356b0c25f16f33b5be79b1c57b562f141ebfcdb0ad8ac2c13a24293b4"
dependencies = [
"bitflags 2.6.0",
"bitflags 2.4.1",
"chrono",
"flate2",
"hex",
@@ -4343,7 +4317,7 @@ version = "0.16.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2d3554923a69f4ce04c4a754260c338f505ce22642d3830e049a399fc2059a29"
dependencies = [
"bitflags 2.6.0",
"bitflags 2.4.1",
"chrono",
"hex",
]
@@ -4481,7 +4455,6 @@ dependencies = [
"postgres_backend",
"pq_proto",
"prometheus",
"quinn",
"rand 0.8.5",
"rand_distr",
"rcgen",
@@ -4495,7 +4468,7 @@ dependencies = [
"routerify",
"rsa",
"rstest",
"rustc-hash 1.1.0",
"rustc-hash",
"rustls 0.22.4",
"rustls-native-certs 0.7.0",
"rustls-pemfile 2.1.1",
@@ -4544,55 +4517,6 @@ dependencies = [
"serde",
]
[[package]]
name = "quinn"
version = "0.11.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8c7c5fdde3cdae7203427dc4f0a68fe0ed09833edc525a03456b153b79828684"
dependencies = [
"bytes",
"pin-project-lite",
"quinn-proto",
"quinn-udp",
"rustc-hash 2.0.0",
"rustls 0.23.7",
"socket2 0.5.5",
"thiserror",
"tokio",
"tracing",
]
[[package]]
name = "quinn-proto"
version = "0.11.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fadfaed2cd7f389d0161bb73eeb07b7b78f8691047a6f3e73caaeae55310a4a6"
dependencies = [
"bytes",
"rand 0.8.5",
"ring 0.17.6",
"rustc-hash 2.0.0",
"rustls 0.23.7",
"rustls-platform-verifier",
"slab",
"thiserror",
"tinyvec",
"tracing",
]
[[package]]
name = "quinn-udp"
version = "0.5.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8bffec3605b73c6f1754535084a85229fa8a30f86014e6c81aeec4abb68b0285"
dependencies = [
"libc",
"once_cell",
"socket2 0.5.5",
"tracing",
"windows-sys 0.52.0",
]
[[package]]
name = "quote"
version = "1.0.35"
@@ -5195,12 +5119,6 @@ version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2"
[[package]]
name = "rustc-hash"
version = "2.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "583034fd73374156e66797ed8e5b0d5690409c9226b22d87cb7f19821c05d152"
[[package]]
name = "rustc_version"
version = "0.4.0"
@@ -5225,7 +5143,7 @@ version = "0.38.28"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "72e572a5e8ca657d7366229cdde4bd14c4eb5499a9573d4d366fe1b599daa316"
dependencies = [
"bitflags 2.6.0",
"bitflags 2.4.1",
"errno",
"libc",
"linux-raw-sys 0.4.13",
@@ -5258,20 +5176,6 @@ dependencies = [
"zeroize",
]
[[package]]
name = "rustls"
version = "0.23.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ebbbdb961df0ad3f2652da8f3fdc4b36122f568f968f45ad3316f26c025c677b"
dependencies = [
"once_cell",
"ring 0.17.6",
"rustls-pki-types",
"rustls-webpki 0.102.2",
"subtle",
"zeroize",
]
[[package]]
name = "rustls-native-certs"
version = "0.6.2"
@@ -5322,33 +5226,6 @@ version = "1.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5ede67b28608b4c60685c7d54122d4400d90f62b40caee7700e700380a390fa8"
[[package]]
name = "rustls-platform-verifier"
version = "0.3.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "afbb878bdfdf63a336a5e63561b1835e7a8c91524f51621db870169eac84b490"
dependencies = [
"core-foundation",
"core-foundation-sys",
"jni",
"log",
"once_cell",
"rustls 0.23.7",
"rustls-native-certs 0.7.0",
"rustls-platform-verifier-android",
"rustls-webpki 0.102.2",
"security-framework",
"security-framework-sys",
"webpki-roots 0.26.1",
"winapi",
]
[[package]]
name = "rustls-platform-verifier-android"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f87165f0995f63a9fbeea62b64d10b4d9d8e78ec6d7d51fb2125fda7bb36788f"
[[package]]
name = "rustls-webpki"
version = "0.100.2"
@@ -5542,23 +5419,22 @@ dependencies = [
[[package]]
name = "security-framework"
version = "2.11.0"
version = "2.9.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c627723fd09706bacdb5cf41499e95098555af3c3c29d014dc3c458ef6be11c0"
checksum = "1fc758eb7bffce5b308734e9b0c1468893cae9ff70ebf13e7090be8dcbcc83a8"
dependencies = [
"bitflags 2.6.0",
"bitflags 1.3.2",
"core-foundation",
"core-foundation-sys",
"libc",
"num-bigint",
"security-framework-sys",
]
[[package]]
name = "security-framework-sys"
version = "2.11.1"
version = "2.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "75da29fe9b9b08fe9d6b22b5b4bcbc75d8db3aa31e639aa56bb62e9d46bfceaf"
checksum = "f51d0c0d83bec45f16480d0ce0058397a69e48fcdc52d1dc8855fb68acbd31a7"
dependencies = [
"core-foundation-sys",
"libc",

View File

@@ -112,9 +112,6 @@ ecdsa = "0.16"
p256 = "0.13"
rsa = "0.9"
quinn = { version = "0.11", features = [] }
rcgen.workspace = true
workspace_hack.workspace = true
[dev-dependencies]
@@ -122,6 +119,7 @@ camino-tempfile.workspace = true
fallible-iterator.workspace = true
tokio-tungstenite.workspace = true
pbkdf2 = { workspace = true, features = ["simple", "std"] }
rcgen.workspace = true
rstest.workspace = true
tokio-postgres-rustls.workspace = true
walkdir.workspace = true

View File

@@ -4,9 +4,9 @@ pub mod backend;
pub use backend::Backend;
mod credentials;
pub use credentials::ComputeUserInfoMaybeEndpoint;
pub(crate) use credentials::{
check_peer_addr_is_in_list, endpoint_sni, ComputeUserInfoParseError, IpPattern,
check_peer_addr_is_in_list, endpoint_sni, ComputeUserInfoMaybeEndpoint,
ComputeUserInfoParseError, IpPattern,
};
mod password_hack;
@@ -77,7 +77,7 @@ pub(crate) enum AuthErrorImpl {
#[derive(Debug, Error)]
#[error(transparent)]
pub struct AuthError(Box<AuthErrorImpl>);
pub(crate) struct AuthError(Box<AuthErrorImpl>);
impl AuthError {
pub(crate) fn bad_auth_method(name: impl Into<Box<str>>) -> Self {

View File

@@ -138,7 +138,7 @@ impl<'a, T, D, E> Backend<'a, Result<T, E>, D> {
}
}
pub struct ComputeCredentials {
pub(crate) struct ComputeCredentials {
pub(crate) info: ComputeUserInfo,
pub(crate) keys: ComputeCredentialKeys,
}
@@ -311,7 +311,9 @@ async fn auth_quirks(
let (allowed_ips, maybe_secret) = api.get_allowed_ips_and_secret(ctx, &info).await?;
// check allowed list
if !check_peer_addr_is_in_list(&ctx.peer_addr(), &allowed_ips) {
if config.ip_allowlist_check_enabled
&& !check_peer_addr_is_in_list(&ctx.peer_addr(), &allowed_ips)
{
return Err(auth::AuthError::ip_address_not_allowed(ctx.peer_addr()));
}
@@ -603,6 +605,7 @@ mod tests {
rate_limiter_enabled: true,
rate_limiter: AuthRateLimiter::new(&RateBucketInfo::DEFAULT_AUTH_SET),
rate_limit_ip_subnet: 64,
ip_allowlist_check_enabled: true,
});
async fn read_message(r: &mut (impl AsyncRead + Unpin), b: &mut BytesMut) -> PgMessage {

View File

@@ -16,7 +16,7 @@ use thiserror::Error;
use tracing::{info, warn};
#[derive(Debug, Error, PartialEq, Eq, Clone)]
pub enum ComputeUserInfoParseError {
pub(crate) enum ComputeUserInfoParseError {
#[error("Parameter '{0}' is missing in startup packet.")]
MissingKey(&'static str),
@@ -51,10 +51,10 @@ impl ReportableError for ComputeUserInfoParseError {
/// Various client credentials which we use for authentication.
/// Note that we don't store any kind of client key or password here.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ComputeUserInfoMaybeEndpoint {
pub user: RoleName,
pub endpoint_id: Option<EndpointId>,
pub options: NeonOptions,
pub(crate) struct ComputeUserInfoMaybeEndpoint {
pub(crate) user: RoleName,
pub(crate) endpoint_id: Option<EndpointId>,
pub(crate) options: NeonOptions,
}
impl ComputeUserInfoMaybeEndpoint {
@@ -83,7 +83,7 @@ pub(crate) fn endpoint_sni(
}
impl ComputeUserInfoMaybeEndpoint {
pub fn parse(
pub(crate) fn parse(
ctx: &RequestMonitoring,
params: &StartupMessageParams,
sni: Option<&str>,
@@ -538,4 +538,17 @@ mod tests {
));
Ok(())
}
#[test]
fn test_connection_blocker() {
fn check(v: serde_json::Value) -> bool {
let peer_addr = IpAddr::from([127, 0, 0, 1]);
let ip_list: Vec<IpPattern> = serde_json::from_value(v).unwrap();
check_peer_addr_is_in_list(&peer_addr, &ip_list)
}
assert!(check(json!([])));
assert!(check(json!(["127.0.0.1"])));
assert!(!check(json!(["255.255.255.255"])));
}
}

View File

@@ -1,230 +0,0 @@
mod classic;
mod hacks;
use tracing::info;
use crate::auth::backend::{
ComputeCredentialKeys, ComputeCredentials, ComputeUserInfo, ComputeUserInfoNoEndpoint,
};
use crate::auth::{self, ComputeUserInfoMaybeEndpoint};
use crate::auth_proxy::validate_password_and_exchange;
use crate::console::provider::ConsoleBackend;
use crate::console::AuthSecret;
use crate::context::RequestMonitoring;
use crate::intern::EndpointIdInt;
use crate::proxy::connect_compute::ComputeConnectBackend;
use crate::scram;
use crate::stream::AuthProxyStreamExt;
use crate::{
config::AuthenticationConfig,
console::{self, provider::CachedNodeInfo, Api},
};
use super::AuthProxyStream;
/// Alternative to [`std::borrow::Cow`] but doesn't need `T: ToOwned` as we don't need that functionality
pub enum MaybeOwned<'a, T> {
Owned(T),
Borrowed(&'a T),
}
impl<T> std::ops::Deref for MaybeOwned<'_, T> {
type Target = T;
fn deref(&self) -> &Self::Target {
match self {
MaybeOwned::Owned(t) => t,
MaybeOwned::Borrowed(t) => t,
}
}
}
/// This type serves two purposes:
///
/// * When `T` is `()`, it's just a regular auth backend selector
/// which we use in [`crate::config::ProxyConfig`].
///
/// * However, when we substitute `T` with [`ComputeUserInfoMaybeEndpoint`],
/// this helps us provide the credentials only to those auth
/// backends which require them for the authentication process.
pub enum Backend<'a, T> {
/// Cloud API (V2).
Console(MaybeOwned<'a, ConsoleBackend>, T),
}
impl std::fmt::Display for Backend<'_, ()> {
fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Console(api, ()) => match &**api {
ConsoleBackend::Console(endpoint) => {
fmt.debug_tuple("Console").field(&endpoint.url()).finish()
}
#[cfg(any(test, feature = "testing"))]
ConsoleBackend::Postgres(endpoint) => {
fmt.debug_tuple("Postgres").field(&endpoint.url()).finish()
}
#[cfg(test)]
ConsoleBackend::Test(_) => fmt.debug_tuple("Test").finish(),
},
}
}
}
impl<T> Backend<'_, T> {
/// Very similar to [`std::option::Option::as_ref`].
/// This helps us pass structured config to async tasks.
pub fn as_ref(&self) -> Backend<'_, &T> {
match self {
Self::Console(c, x) => Backend::Console(MaybeOwned::Borrowed(c), x),
}
}
}
impl<'a, T> Backend<'a, T> {
/// Very similar to [`std::option::Option::map`].
/// Maps [`Backend<T>`] to [`Backend<R>`] by applying
/// a function to a contained value.
pub fn map<R>(self, f: impl FnOnce(T) -> R) -> Backend<'a, R> {
match self {
Self::Console(c, x) => Backend::Console(c, f(x)),
}
}
}
/// True to its name, this function encapsulates our current auth trade-offs.
/// Here, we choose the appropriate auth flow based on circumstances.
///
/// All authentication flows will emit an AuthenticationOk message if successful.
async fn auth_quirks(
api: &impl console::Api,
user_info: ComputeUserInfoMaybeEndpoint,
client: &mut AuthProxyStream,
config: &'static AuthenticationConfig,
) -> auth::Result<ComputeCredentials> {
// If there's no project so far, that entails that client doesn't
// support SNI or other means of passing the endpoint (project) name.
// We now expect to see a very specific payload in the place of password.
let (info) = match user_info.try_into() {
Err(info) => {
todo!()
// let res = hacks::password_hack_no_authentication(info, client).await?;
// let password = match res.keys {
// ComputeCredentialKeys::Password(p) => p,
// ComputeCredentialKeys::AuthKeys(_) | ComputeCredentialKeys::None => {
// unreachable!("password hack should return a password")
// }
// };
// (res.info, Some(password))
}
Ok(info) => info,
};
dbg!("fetching user's authentication info");
let cached_secret = api
.get_role_secret(&RequestMonitoring::test(), &info)
.await?;
let (cached_entry, secret) = cached_secret.take_value();
let secret = if let Some(secret) = secret {
secret
} else {
// If we don't have an authentication secret, we mock one to
// prevent malicious probing (possible due to missing protocol steps).
// This mocked secret will never lead to successful authentication.
dbg!("authentication info not found, mocking it");
AuthSecret::Scram(scram::ServerSecret::mock(rand::random()))
};
match authenticate_with_secret(secret, info, client, config).await {
Ok(keys) => Ok(keys),
Err(e) => {
if e.is_auth_failed() {
// The password could have been changed, so we invalidate the cache.
cached_entry.invalidate();
}
Err(e)
}
}
}
async fn authenticate_with_secret(
secret: AuthSecret,
info: ComputeUserInfo,
client: &mut AuthProxyStream,
// unauthenticated_password: Option<Vec<u8>>,
config: &'static AuthenticationConfig,
) -> auth::Result<ComputeCredentials> {
// if let Some(password) = unauthenticated_password {
// let ep = EndpointIdInt::from(&info.endpoint);
// let auth_outcome =
// validate_password_and_exchange(&config.thread_pool, ep, &password, secret).await?;
// let keys = match auth_outcome {
// crate::sasl::Outcome::Success(key) => key,
// crate::sasl::Outcome::Failure(reason) => {
// info!("auth backend failed with an error: {reason}");
// return Err(auth::AuthError::auth_failed(&*info.user));
// }
// };
// // we have authenticated the password
// client.write_message_noflush(&pq_proto::BeMessage::AuthenticationOk)?;
// return Ok(ComputeCredentials { info, keys });
// }
// Finally, proceed with the main auth flow (SCRAM-based).
classic::authenticate(info, client, config, secret).await
}
impl<'a> Backend<'a, ComputeUserInfoMaybeEndpoint> {
/// Get username from the credentials.
pub fn get_user(&self) -> &str {
match self {
Self::Console(_, user_info) => &user_info.user,
}
}
pub async fn authenticate(
self,
client: &mut AuthProxyStream,
config: &'static AuthenticationConfig,
) -> auth::Result<Backend<'a, ComputeCredentials>> {
let res = match self {
Self::Console(api, user_info) => {
dbg!("authenticating...");
info!(
user = &*user_info.user,
project = user_info.endpoint(),
"performing authentication using the console"
);
let credentials = auth_quirks(&*api, user_info, client, config).await?;
Backend::Console(api, credentials)
}
};
dbg!("user successfully authenticated");
Ok(res)
}
}
#[async_trait::async_trait]
impl ComputeConnectBackend for Backend<'_, ComputeCredentials> {
async fn wake_compute(
&self,
ctx: &RequestMonitoring,
) -> Result<CachedNodeInfo, console::errors::WakeComputeError> {
match self {
Self::Console(api, creds) => api.wake_compute(ctx, &creds.info).await,
}
}
fn get_keys(&self) -> &ComputeCredentialKeys {
match self {
Self::Console(_, creds) => &creds.keys,
}
}
}

View File

@@ -1,69 +0,0 @@
use super::{ComputeCredentials, ComputeUserInfo};
use crate::{
auth::{self, backend::ComputeCredentialKeys},
auth_proxy::{self, AuthFlow, AuthProxyStream},
compute,
config::AuthenticationConfig,
console::AuthSecret,
sasl,
};
use tracing::{info, warn};
pub(super) async fn authenticate(
creds: ComputeUserInfo,
client: &mut AuthProxyStream,
config: &'static AuthenticationConfig,
secret: AuthSecret,
) -> auth::Result<ComputeCredentials> {
let flow = AuthFlow::new(client);
let scram_keys = match secret {
#[cfg(any(test, feature = "testing"))]
AuthSecret::Md5(_) => {
info!("auth endpoint chooses MD5");
return Err(auth::AuthError::bad_auth_method("MD5"));
}
AuthSecret::Scram(secret) => {
dbg!("auth endpoint chooses SCRAM");
let scram = auth_proxy::Scram(&secret);
let auth_outcome = tokio::time::timeout(
config.scram_protocol_timeout,
async {
flow.begin(scram).await.map_err(|error| {
warn!(?error, "error sending scram acknowledgement");
error
})?.authenticate().await.map_err(|error| {
warn!(?error, "error processing scram messages");
error
})
}
)
.await
.map_err(|e| {
warn!("error processing scram messages error = authentication timed out, execution time exceeded {} seconds", config.scram_protocol_timeout.as_secs());
auth::AuthError::user_timeout(e)
})??;
let client_key = match auth_outcome {
sasl::Outcome::Success(key) => key,
sasl::Outcome::Failure(reason) => {
info!("auth backend failed with an error: {reason}");
return Err(auth::AuthError::auth_failed(&*creds.user));
}
};
compute::ScramKeys {
client_key: client_key.as_bytes(),
server_key: secret.server_key.as_bytes(),
}
}
};
Ok(ComputeCredentials {
info: creds,
keys: ComputeCredentialKeys::AuthKeys(tokio_postgres::config::AuthKeys::ScramSha256(
scram_keys,
)),
})
}

View File

@@ -1,36 +0,0 @@
use super::{
ComputeCredentialKeys, ComputeCredentials, ComputeUserInfo, ComputeUserInfoNoEndpoint,
};
use crate::{
auth,
auth_proxy::{self, AuthFlow, AuthProxyStream},
};
use tracing::{info, warn};
/// Workaround for clients which don't provide an endpoint (project) name.
/// Similar to [`authenticate_cleartext`], but there's a specific password format,
/// and passwords are not yet validated (we don't know how to validate them!)
pub(crate) async fn password_hack_no_authentication(
info: ComputeUserInfoNoEndpoint,
client: &mut AuthProxyStream,
) -> auth::Result<ComputeCredentials> {
warn!("project not specified, resorting to the password hack auth flow");
let payload = AuthFlow::new(client)
.begin(auth_proxy::PasswordHack)
.await?
.get_password()
.await?;
info!(project = &*payload.endpoint, "received missing parameter");
// Report tentative success; compute node will check the password anyway.
Ok(ComputeCredentials {
info: ComputeUserInfo {
user: info.user,
options: info.options,
endpoint: payload.endpoint,
},
keys: ComputeCredentialKeys::Password(payload.password),
})
}

View File

@@ -1,183 +0,0 @@
//! Main authentication flow.
use super::{AuthProxyStream, PasswordHackPayload};
use crate::{
auth::{self, backend::ComputeCredentialKeys, AuthErrorImpl},
config::TlsServerEndPoint,
console::AuthSecret,
intern::EndpointIdInt,
sasl,
scram::{self, threadpool::ThreadPool},
stream::AuthProxyStreamExt,
};
use pq_proto::{BeAuthenticationSaslMessage, BeMessage, BeMessage as Be};
use std::io;
use tokio::task_local;
use tracing::info;
/// Every authentication selector is supposed to implement this trait.
pub(crate) trait AuthMethod {
/// Any authentication selector should provide initial backend message
/// containing auth method name and parameters, e.g. md5 salt.
fn first_message(&self, channel_binding: bool) -> BeMessage<'_>;
}
/// Initial state of [`AuthFlow`].
pub(crate) struct Begin;
/// Use [SCRAM](crate::scram)-based auth in [`AuthFlow`].
pub(crate) struct Scram<'a>(pub(crate) &'a scram::ServerSecret);
impl AuthMethod for Scram<'_> {
#[inline(always)]
fn first_message(&self, channel_binding: bool) -> BeMessage<'_> {
if channel_binding {
Be::AuthenticationSasl(BeAuthenticationSaslMessage::Methods(scram::METHODS))
} else {
Be::AuthenticationSasl(BeAuthenticationSaslMessage::Methods(
scram::METHODS_WITHOUT_PLUS,
))
}
}
}
/// Use an ad hoc auth flow (for clients which don't support SNI) proposed in
/// <https://github.com/neondatabase/cloud/issues/1620#issuecomment-1165332290>.
pub(crate) struct PasswordHack;
impl AuthMethod for PasswordHack {
#[inline(always)]
fn first_message(&self, _channel_binding: bool) -> BeMessage<'_> {
Be::AuthenticationCleartextPassword
}
}
/// This wrapper for [`PqStream`] performs client authentication.
#[must_use]
pub(crate) struct AuthFlow<'a, State> {
/// The underlying stream which implements libpq's protocol.
stream: &'a mut AuthProxyStream,
/// State might contain ancillary data (see [`Self::begin`]).
state: State,
tls_server_end_point: TlsServerEndPoint,
}
task_local! {
pub(crate) static TLS_SERVER_END_POINT: TlsServerEndPoint;
}
/// Initial state of the stream wrapper.
impl<'a> AuthFlow<'a, Begin> {
/// Create a new wrapper for client authentication.
pub(crate) fn new(stream: &'a mut AuthProxyStream) -> Self {
let tls_server_end_point = TLS_SERVER_END_POINT.get();
Self {
stream,
state: Begin,
tls_server_end_point,
}
}
/// Move to the next step by sending auth method's name & params to client.
pub(crate) async fn begin<M: AuthMethod>(self, method: M) -> io::Result<AuthFlow<'a, M>> {
dbg!("sending auth begin message");
self.stream
.write_message(&method.first_message(self.tls_server_end_point.supported()))
.await?;
Ok(AuthFlow {
stream: self.stream,
state: method,
tls_server_end_point: self.tls_server_end_point,
})
}
}
impl AuthFlow<'_, PasswordHack> {
/// Perform user authentication. Raise an error in case authentication failed.
pub(crate) async fn get_password(self) -> auth::Result<PasswordHackPayload> {
let msg = self.stream.read_password_message().await?;
let password = msg
.strip_suffix(&[0])
.ok_or(AuthErrorImpl::MalformedPassword("missing terminator"))?;
let payload = PasswordHackPayload::parse(password)
// If we ended up here and the payload is malformed, it means that
// the user neither enabled SNI nor resorted to any other method
// for passing the project name we rely on. We should show them
// the most helpful error message and point to the documentation.
.ok_or(AuthErrorImpl::MissingEndpointName)?;
Ok(payload)
}
}
/// Stream wrapper for handling [SCRAM](crate::scram) auth.
impl AuthFlow<'_, Scram<'_>> {
/// Perform user authentication. Raise an error in case authentication failed.
pub(crate) async fn authenticate(self) -> auth::Result<sasl::Outcome<scram::ScramKey>> {
let Scram(secret) = self.state;
// Initial client message contains the chosen auth method's name.
let msg = self.stream.read_password_message().await?;
let sasl = sasl::FirstMessage::parse(&msg)
.ok_or(AuthErrorImpl::MalformedPassword("bad sasl message"))?;
// Currently, the only supported SASL method is SCRAM.
if !scram::METHODS.contains(&sasl.method) {
return Err(auth::AuthError::bad_auth_method(sasl.method));
}
info!("client chooses {}", sasl.method);
let outcome = sasl::SaslStream2::new(self.stream, sasl.message)
.authenticate(scram::Exchange::new(
secret,
rand::random,
self.tls_server_end_point,
))
.await?;
if let sasl::Outcome::Success(_) = &outcome {
self.stream.write_message_noflush(&Be::AuthenticationOk)?;
}
Ok(outcome)
}
}
pub(crate) async fn validate_password_and_exchange(
pool: &ThreadPool,
endpoint: EndpointIdInt,
password: &[u8],
secret: AuthSecret,
) -> auth::Result<sasl::Outcome<ComputeCredentialKeys>> {
match secret {
#[cfg(any(test, feature = "testing"))]
AuthSecret::Md5(_) => {
// test only
Ok(sasl::Outcome::Success(ComputeCredentialKeys::Password(
password.to_owned(),
)))
}
// perform scram authentication as both client and server to validate the keys
AuthSecret::Scram(scram_secret) => {
let outcome = crate::scram::exchange(pool, endpoint, &scram_secret, password).await?;
let client_key = match outcome {
sasl::Outcome::Success(client_key) => client_key,
sasl::Outcome::Failure(reason) => return Ok(sasl::Outcome::Failure(reason)),
};
let keys = crate::compute::ScramKeys {
client_key: client_key.as_bytes(),
server_key: scram_secret.server_key.as_bytes(),
};
Ok(sasl::Outcome::Success(ComputeCredentialKeys::AuthKeys(
tokio_postgres::config::AuthKeys::ScramSha256(keys),
)))
}
}
}

View File

@@ -1,17 +0,0 @@
//! Client authentication mechanisms.
pub mod backend;
pub use backend::Backend;
mod password_hack;
use password_hack::PasswordHackPayload;
mod flow;
pub(crate) use flow::*;
use quinn::{RecvStream, SendStream};
use tokio::io::Join;
use tokio_util::codec::Framed;
use crate::PglbCodec;
pub type AuthProxyStream = Framed<Join<RecvStream, SendStream>, PglbCodec>;

View File

@@ -1,121 +0,0 @@
//! Payload for ad hoc authentication method for clients that don't support SNI.
//! See the `impl` for [`super::backend::Backend<ClientCredentials>`].
//! Read more: <https://github.com/neondatabase/cloud/issues/1620#issuecomment-1165332290>.
//! UPDATE (Mon Aug 8 13:20:34 UTC 2022): the payload format has been simplified.
use bstr::ByteSlice;
use crate::EndpointId;
pub(crate) struct PasswordHackPayload {
pub(crate) endpoint: EndpointId,
pub(crate) password: Vec<u8>,
}
impl PasswordHackPayload {
pub(crate) fn parse(bytes: &[u8]) -> Option<Self> {
// The format is `project=<utf-8>;<password-bytes>` or `project=<utf-8>$<password-bytes>`.
let separators = [";", "$"];
for sep in separators {
if let Some((endpoint, password)) = bytes.split_once_str(sep) {
let endpoint = endpoint.to_str().ok()?;
return Some(Self {
endpoint: parse_endpoint_param(endpoint)?.into(),
password: password.to_owned(),
});
}
}
None
}
}
pub(crate) fn parse_endpoint_param(bytes: &str) -> Option<&str> {
bytes
.strip_prefix("project=")
.or_else(|| bytes.strip_prefix("endpoint="))
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_endpoint_param_fn() {
        // Unrecognized inputs yield no endpoint.
        assert!(parse_endpoint_param("").is_none());
        assert!(parse_endpoint_param("other_option=foobar").is_none());

        // Both prefixes are accepted, with empty or non-empty values.
        assert_eq!(parse_endpoint_param("project="), Some(""));
        assert_eq!(parse_endpoint_param("project=foobar"), Some("foobar"));
        assert_eq!(parse_endpoint_param("endpoint="), Some(""));
        assert_eq!(parse_endpoint_param("endpoint=foobar"), Some("foobar"));
    }

    #[test]
    fn parse_password_hack_payload_project() {
        // Empty input or a bare `project=` (no separator) is rejected.
        assert!(PasswordHackPayload::parse(b"").is_none());
        assert!(PasswordHackPayload::parse(b"project=").is_none());

        // Empty endpoint and password are fine once a separator is present.
        let payload = PasswordHackPayload::parse(b"project=;").expect("parsing failed");
        assert_eq!(payload.endpoint, "");
        assert_eq!(payload.password, b"");

        // Only the first separator splits; later ones stay in the password.
        let payload =
            PasswordHackPayload::parse(b"project=foobar;pass;word").expect("parsing failed");
        assert_eq!(payload.endpoint, "foobar");
        assert_eq!(payload.password, b"pass;word");
    }

    #[test]
    fn parse_password_hack_payload_endpoint() {
        // Same behavior via the `endpoint=` spelling.
        assert!(PasswordHackPayload::parse(b"").is_none());
        assert!(PasswordHackPayload::parse(b"endpoint=").is_none());

        let payload = PasswordHackPayload::parse(b"endpoint=;").expect("parsing failed");
        assert_eq!(payload.endpoint, "");
        assert_eq!(payload.password, b"");

        let payload =
            PasswordHackPayload::parse(b"endpoint=foobar;pass;word").expect("parsing failed");
        assert_eq!(payload.endpoint, "foobar");
        assert_eq!(payload.password, b"pass;word");
    }

    #[test]
    fn parse_password_hack_payload_dollar() {
        // `$` is accepted as an alternative separator.
        assert!(PasswordHackPayload::parse(b"").is_none());
        assert!(PasswordHackPayload::parse(b"endpoint=").is_none());

        let payload = PasswordHackPayload::parse(b"endpoint=$").expect("parsing failed");
        assert_eq!(payload.endpoint, "");
        assert_eq!(payload.password, b"");

        let payload =
            PasswordHackPayload::parse(b"endpoint=foobar$pass$word").expect("parsing failed");
        assert_eq!(payload.endpoint, "foobar");
        assert_eq!(payload.password, b"pass$word");
    }
}

View File

@@ -1,242 +0,0 @@
use std::{sync::Arc, time::Duration};
use clap::Parser;
use proxy::{
auth::backend::AuthRateLimiter,
auth_proxy::{backend::MaybeOwned, Backend},
config::{self, AuthenticationConfig, CacheOptions, ProjectInfoCacheOptions},
console::{
caches::ApiCaches,
locks::ApiLocks,
provider::{neon::Api, ConsoleBackend},
},
http,
metrics::Metrics,
proxy::{handle_stream, AuthProxyConfig},
rate_limiter::{RateBucketInfo, WakeComputeRateLimiter},
scram::threadpool::ThreadPool,
};
use quinn::{crypto::rustls::QuicClientConfig, rustls::client::danger, Endpoint, VarInt};
use tokio::{
io::AsyncWriteExt,
select,
signal::unix::{signal, SignalKind},
time::interval,
};
use tokio_util::task::TaskTracker;
// NOTE: the `///` comments on this struct and its fields are consumed by
// clap's derive macro as the CLI help text — they are user-facing strings,
// not just internal documentation.
/// Neon proxy/router
#[derive(Parser)]
#[command(about)]
struct ProxyCliArgs {
    /// cloud API endpoint for authenticating users
    #[clap(
        short,
        long,
        default_value = "http://localhost:3000/authenticate_proxy_request"
    )]
    auth_endpoint: String,
    /// timeout for the TLS handshake
    #[clap(long, default_value = "15s", value_parser = humantime::parse_duration)]
    handshake_timeout: tokio::time::Duration,
    /// cache for `wake_compute` api method (use `size=0` to disable)
    #[clap(long, default_value = config::CacheOptions::CACHE_DEFAULT_OPTIONS)]
    wake_compute_cache: String,
    /// lock for `wake_compute` api method. example: "shards=32,permits=4,epoch=10m,timeout=1s". (use `permits=0` to disable).
    #[clap(long, default_value = config::ConcurrencyLockOptions::DEFAULT_OPTIONS_WAKE_COMPUTE_LOCK)]
    wake_compute_lock: String,
    /// timeout for scram authentication protocol
    #[clap(long, default_value = "15s", value_parser = humantime::parse_duration)]
    scram_protocol_timeout: tokio::time::Duration,
    /// size of the threadpool for password hashing
    #[clap(long, default_value_t = 4)]
    scram_thread_pool_size: u8,
    /// Disable dynamic rate limiter and store the metrics to ensure its production behaviour.
    // BoolishValueParser + ArgAction::Set lets users pass explicit
    // true/false-like values rather than a bare flag.
    #[clap(long, default_value_t = true, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)]
    disable_dynamic_rate_limiter: bool,
    /// Endpoint rate limiter max number of requests per second.
    ///
    /// Provided in the form `<Requests Per Second>@<Bucket Duration Size>`.
    /// Can be given multiple times for different bucket sizes.
    #[clap(long, default_values_t = RateBucketInfo::DEFAULT_ENDPOINT_SET)]
    endpoint_rps_limit: Vec<RateBucketInfo>,
    /// Wake compute rate limiter max number of requests per second.
    #[clap(long, default_values_t = RateBucketInfo::DEFAULT_SET)]
    wake_compute_limit: Vec<RateBucketInfo>,
    /// Whether the auth rate limiter actually takes effect (for testing)
    #[clap(long, default_value_t = false, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)]
    auth_rate_limit_enabled: bool,
    /// Authentication rate limiter max number of hashes per second.
    #[clap(long, default_values_t = RateBucketInfo::DEFAULT_AUTH_SET)]
    auth_rate_limit: Vec<RateBucketInfo>,
    /// The IP subnet to use when considering whether two IP addresses are considered the same.
    #[clap(long, default_value_t = 64)]
    auth_rate_limit_ip_subnet: u8,
    /// cache for `allowed_ips` (use `size=0` to disable)
    #[clap(long, default_value = config::CacheOptions::CACHE_DEFAULT_OPTIONS)]
    allowed_ips_cache: String,
    /// cache for `role_secret` (use `size=0` to disable)
    #[clap(long, default_value = config::CacheOptions::CACHE_DEFAULT_OPTIONS)]
    role_secret_cache: String,
    /// cache for `project_info` (use `size=0` to disable)
    #[clap(long, default_value = config::ProjectInfoCacheOptions::CACHE_DEFAULT_OPTIONS)]
    project_info_cache: String,
    /// cache for all valid endpoints
    #[clap(long, default_value = config::EndpointCacheConfig::CACHE_DEFAULT_OPTIONS)]
    endpoint_cache_config: String,
    /// Whether to retry the wake_compute request
    #[clap(long, default_value = config::RetryConfig::WAKE_COMPUTE_DEFAULT_VALUES)]
    wake_compute_retry: String,
}
/// Standalone auth-proxy entry point for the PoC: dials the pglb load
/// balancer over QUIC at a fixed local address, sends a keep-alive every 2s,
/// and serves one authentication session per incoming bi-directional stream
/// until SIGINT/SIGTERM.
#[tokio::main]
async fn main() {
    let args = ProxyCliArgs::parse();

    // pglb's QUIC address — hard-coded for the PoC.
    let server = "127.0.0.1:5634".parse().unwrap();
    let mut endpoint = Endpoint::client("0.0.0.0:0".parse().unwrap()).unwrap();
    // Client TLS config that skips server-certificate verification via
    // `NoVerify` (pglb presents a self-signed throwaway certificate).
    let crypto = quinn::rustls::ClientConfig::builder()
        .dangerous()
        .with_custom_certificate_verifier(Arc::new(NoVerify))
        .with_no_client_auth();
    let crypto = QuicClientConfig::try_from(crypto).unwrap();
    let config = quinn::ClientConfig::new(Arc::new(crypto));
    endpoint.set_default_client_config(config);

    let mut int = signal(SignalKind::interrupt()).unwrap();
    let mut term = signal(SignalKind::terminate()).unwrap();

    let conn = endpoint.connect(server, "pglb").unwrap().await.unwrap();
    let mut interval = interval(Duration::from_secs(2));
    let tasks = TaskTracker::new();

    let thread_pool = ThreadPool::new(args.scram_thread_pool_size);
    Metrics::install(thread_pool.metrics.clone());

    // Build the control-plane (console) backend with its caches, the
    // wake_compute lock and the wake_compute rate limiter. Boxes are leaked
    // to obtain the 'static lifetimes required further down.
    let backend = {
        let wake_compute_cache_config: CacheOptions = args.wake_compute_cache.parse().unwrap();
        let project_info_cache_config: ProjectInfoCacheOptions =
            args.project_info_cache.parse().unwrap();
        let endpoint_cache_config: config::EndpointCacheConfig =
            args.endpoint_cache_config.parse().unwrap();
        let caches = Box::leak(Box::new(ApiCaches::new(
            wake_compute_cache_config,
            project_info_cache_config,
            endpoint_cache_config,
        )));
        let config::ConcurrencyLockOptions {
            shards,
            limiter,
            epoch,
            timeout,
        } = args.wake_compute_lock.parse().unwrap();
        let locks = Box::leak(Box::new(
            ApiLocks::new(
                "wake_compute_lock",
                limiter,
                shards,
                timeout,
                epoch,
                &Metrics::get().wake_compute_lock,
            )
            .unwrap(),
        ));
        tokio::spawn(locks.garbage_collect_worker());
        let url = args.auth_endpoint.parse().unwrap();
        let endpoint = http::Endpoint::new(url, http::new_client());
        let mut wake_compute_rps_limit = args.wake_compute_limit.clone();
        RateBucketInfo::validate(&mut wake_compute_rps_limit).unwrap();
        let wake_compute_endpoint_rate_limiter =
            Arc::new(WakeComputeRateLimiter::new(wake_compute_rps_limit));
        let api = Api::new(endpoint, caches, locks, wake_compute_endpoint_rate_limiter);
        let api = ConsoleBackend::Console(api);
        Backend::Console(MaybeOwned::Owned(api), ())
    };
    let auth = AuthenticationConfig {
        thread_pool,
        scram_protocol_timeout: args.scram_protocol_timeout,
        rate_limiter_enabled: args.auth_rate_limit_enabled,
        rate_limiter: AuthRateLimiter::new(args.auth_rate_limit.clone()),
        rate_limit_ip_subnet: args.auth_rate_limit_ip_subnet,
    };
    let config = &*Box::leak(Box::new(AuthProxyConfig { backend, auth }));

    // Main loop: heartbeat to pglb on every tick, spawn a handler task per
    // incoming bi-stream, exit on SIGINT/SIGTERM.
    loop {
        select! {
            _ = int.recv() => break,
            _ = term.recv() => break,
            _ = interval.tick() => {
                // An empty uni-stream serves as the heartbeat.
                let mut stream = conn.open_uni().await.unwrap();
                stream.flush().await.unwrap();
                stream.finish().unwrap();
            }
            stream = conn.accept_bi() => {
                let (send, recv) = stream.unwrap();
                tasks.spawn(async move {
                    handle_stream(config, send, recv).await.inspect_err(|e| println!("err {e:?}"))
                });
            }
        }
    }

    // graceful shutdown: tell pglb we are going away, then drain in-flight
    // sessions before closing the QUIC connection.
    {
        let mut stream = conn.open_uni().await.unwrap();
        stream.write_all(b"shutdown").await.unwrap();
        stream.flush().await.unwrap();
        stream.finish().unwrap();
    }
    tasks.close();
    tasks.wait().await;
    conn.close(VarInt::from_u32(1), b"graceful shutdown");
}
/// rustls server-certificate verifier that accepts ANY certificate.
///
/// SECURITY: this disables TLS server authentication entirely. It is only
/// tolerable here because the QUIC peer is the PoC-local pglb instance
/// presenting a self-signed throwaway certificate — never use in production.
#[derive(Copy, Clone, Debug)]
struct NoVerify;

impl danger::ServerCertVerifier for NoVerify {
    fn verify_server_cert(
        &self,
        _end_entity: &rustls::pki_types::CertificateDer<'_>,
        _intermediates: &[rustls::pki_types::CertificateDer<'_>],
        _server_name: &rustls::pki_types::ServerName<'_>,
        _ocsp_response: &[u8],
        _now: rustls::pki_types::UnixTime,
    ) -> Result<danger::ServerCertVerified, quinn::rustls::Error> {
        // Accept unconditionally.
        Ok(danger::ServerCertVerified::assertion())
    }

    fn verify_tls12_signature(
        &self,
        _message: &[u8],
        _cert: &rustls::pki_types::CertificateDer<'_>,
        _dss: &quinn::rustls::DigitallySignedStruct,
    ) -> Result<danger::HandshakeSignatureValid, quinn::rustls::Error> {
        Ok(danger::HandshakeSignatureValid::assertion())
    }

    fn verify_tls13_signature(
        &self,
        _message: &[u8],
        _cert: &rustls::pki_types::CertificateDer<'_>,
        _dss: &quinn::rustls::DigitallySignedStruct,
    ) -> Result<danger::HandshakeSignatureValid, quinn::rustls::Error> {
        Ok(danger::HandshakeSignatureValid::assertion())
    }

    fn supported_verify_schemes(&self) -> Vec<quinn::rustls::SignatureScheme> {
        // NOTE(review): only ECDSA P-256 is advertised — presumably matching
        // the key pglb generates; confirm if other certificates are ever used.
        vec![quinn::rustls::SignatureScheme::ECDSA_NISTP256_SHA256]
    }
}

View File

@@ -224,6 +224,7 @@ fn build_config(args: &LocalProxyCliArgs) -> anyhow::Result<&'static ProxyConfig
rate_limiter_enabled: false,
rate_limiter: BucketRateLimiter::new(vec![]),
rate_limit_ip_subnet: 64,
ip_allowlist_check_enabled: true,
},
require_client_ip: false,
handshake_timeout: Duration::from_secs(10),

View File

@@ -1,653 +0,0 @@
use std::{
convert::Infallible,
net::SocketAddr,
sync::{Arc, Mutex},
time::Duration,
};
use anyhow::{anyhow, bail, Context, Result};
use bytes::{Buf, BufMut, BytesMut};
use futures::{SinkExt, StreamExt};
use indexmap::IndexMap;
use itertools::Itertools;
use pq_proto::BeMessage;
use proxy::{
config::{CertResolver, TlsServerEndPoint, PG_ALPN_PROTOCOL},
ConnectionInitiatedPayload, PglbCodec, PglbControlMessage, PglbMessage,
};
use quinn::{Connection, Endpoint, RecvStream, SendStream};
use rand::Rng;
use rustls::pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer};
use tokio::{
io::{copy_bidirectional, join, AsyncReadExt, AsyncWriteExt, Join},
net::{TcpListener, TcpStream},
select,
time::timeout,
};
use tokio_rustls::server::TlsStream;
use tokio_util::codec::Framed;
use tracing::{error, warn};
/// Key identifying an auth-proxy QUIC connection (quinn's `stable_id`).
type AuthConnId = usize;

/// Registry of currently-connected auth proxies, shared between the QUIC
/// accept loop (which inserts/removes entries) and the frontend (which picks
/// one per client connection).
#[derive(Debug)]
struct AuthConnState {
    conns: Mutex<IndexMap<AuthConnId, AuthConn>>,
}

/// A single live auth-proxy connection.
#[derive(Clone, Debug)]
struct AuthConn {
    conn: Connection,
    // latency info...
}
// Use jemalloc as the process-wide allocator.
#[global_allocator]
static GLOBAL: tikv_jemallocator::Jemalloc = tikv_jemallocator::Jemalloc;
/// pglb entry point: accepts auth-proxy registrations over QUIC on :5634 and
/// client Postgres connections over TLS on :5432.
#[tokio::main]
async fn main() -> Result<()> {
    let _logging_guard = proxy::logging::init().await?;
    let auth_endpoint: Endpoint = endpoint_config("0.0.0.0:5634".parse()?).await?;
    let auth_connections = Arc::new(AuthConnState {
        conns: Mutex::new(IndexMap::new()),
    });
    let quinn_handle = tokio::spawn(quinn_server(auth_endpoint, auth_connections.clone()));
    let frontend_config = frontent_tls_config()?;
    let _frontend_handle = tokio::spawn(start_frontend(
        "0.0.0.0:5432".parse()?,
        frontend_config,
        auth_connections.clone(),
    ));
    // Block until the QUIC server task exits (its accept loop runs forever
    // in normal operation).
    quinn_handle.await.unwrap();
    Ok(())
}
/// Builds the QUIC server endpoint pglb listens on, bound to `addr`, with a
/// freshly generated self-signed ECDSA P-256 certificate for the name "pglb".
async fn endpoint_config(addr: SocketAddr) -> Result<Endpoint> {
    let mut params = rcgen::CertificateParams::new(vec!["pglb".to_string()]);
    params
        .distinguished_name
        .push(rcgen::DnType::CommonName, "pglb");
    let key = rcgen::KeyPair::generate(&rcgen::PKCS_ECDSA_P256_SHA256).context("keygen")?;
    params.key_pair = Some(key);
    let cert = rcgen::Certificate::from_params(params).context("cert")?;
    let cert_der = cert.serialize_der().context("serialize")?;
    let key_der = cert.serialize_private_key_der();
    let cert = CertificateDer::from(cert_der);
    let key = PrivateKeyDer::Pkcs8(PrivatePkcs8KeyDer::from(key_der));
    let config = quinn::ServerConfig::with_single_cert(vec![cert], key).context("server config")?;
    Endpoint::server(config, addr).context("endpoint")
}
/// QUIC accept loop: registers each connecting auth proxy in `state` and
/// keeps the registration alive for as long as heartbeats arrive.
async fn quinn_server(ep: Endpoint, state: Arc<AuthConnState>) {
    loop {
        let incoming = ep.accept().await.expect("quinn server should not crash");
        let state = state.clone();
        tokio::spawn(async move {
            let conn = incoming.await.unwrap();
            let conn_id = conn.stable_id();
            println!("[{conn_id:?}] new conn");
            state
                .conns
                .lock()
                .unwrap()
                .insert(conn_id, AuthConn { conn: conn.clone() });
            // heartbeat loop: the auth proxy periodically opens a uni-stream;
            // a stream starting with "shutdown", a connection error, or 10s
            // of silence ends the registration.
            loop {
                match timeout(Duration::from_secs(10), conn.accept_uni()).await {
                    Ok(Ok(mut heartbeat_stream)) => {
                        let data = heartbeat_stream.read_to_end(128).await.unwrap();
                        if data.starts_with(b"shutdown") {
                            println!("[{conn_id:?}] conn shutdown");
                            break;
                        }
                        // else update latency info
                    }
                    Ok(Err(conn_err)) => {
                        println!("[{conn_id:?}] conn err {conn_err:?}");
                        break;
                    }
                    Err(_) => {
                        // No heartbeat within the timeout window.
                        println!("[{conn_id:?}] conn timeout err");
                        break;
                    }
                }
            }
            // Deregister so the frontend stops routing new clients here.
            state.conns.lock().unwrap().remove(&conn_id);
            let conn_closed = conn.closed().await;
            println!("[{conn_id:?}] conn closed {conn_closed:?}");
        });
    }
}
/// Loads the client-facing TLS config from `proxy.crt` / `proxy.key` in the
/// working directory. (The "frontent" typo in the name is kept: renaming
/// would break the call site in `main`.)
fn frontent_tls_config() -> Result<TlsConfig> {
    // First certificate / first PKCS#8 key found in each PEM file is used.
    let (cert, key) = (
        rustls_pemfile::certs(&mut &*std::fs::read("proxy.crt").unwrap())
            .collect_vec()
            .remove(0)
            .unwrap(),
        PrivateKeyDer::Pkcs8(
            rustls_pemfile::pkcs8_private_keys(&mut &*std::fs::read("proxy.key").unwrap())
                .collect_vec()
                .remove(0)
                .unwrap(),
        ),
    );
    let config = rustls::ServerConfig::builder()
        .with_no_client_auth()
        .with_single_cert(vec![cert.clone()], key.clone_key())?
        .into();
    let mut cert_resolver = CertResolver::new();
    cert_resolver.add_cert(key, vec![cert], true)?;
    Ok(TlsConfig {
        config,
        cert_resolver: Arc::new(cert_resolver),
    })
}
/// TCP accept loop for client Postgres connections; each accepted socket is
/// driven through the `PglbConn` state machine on its own task.
async fn start_frontend(
    addr: SocketAddr,
    tls: TlsConfig,
    state: Arc<AuthConnState>,
) -> Result<Infallible> {
    let listener = TcpListener::bind(addr).await?;
    socket2::SockRef::from(&listener).set_keepalive(true)?;
    println!("starting");
    let connections = tokio_util::task::task_tracker::TaskTracker::new();
    loop {
        match listener.accept().await {
            Ok((socket, client_addr)) => {
                println!("accepted");
                let conn = PglbConn::new(&state, &tls)?;
                connections.spawn(conn.handle(socket, client_addr));
            }
            Err(e) => {
                // Accept failures are logged; the loop keeps serving.
                error!("connection accept error: {e}");
            }
        }
    }
}
/// TLS materials shared by all frontend connections.
#[derive(Clone, Debug)]
struct TlsConfig {
    config: Arc<rustls::ServerConfig>,
    cert_resolver: Arc<CertResolver>,
}

/// One client connection, with its lifecycle phase encoded in the type
/// parameter (typestate): Start -> ClientConnect -> AuthPassthrough ->
/// ComputeConnect -> ComputePassthrough -> End.
#[derive(Debug)]
struct PglbConn<S: PglbConnState> {
    inner: PglbConnInner,
    state: S,
}

/// Phase-independent data carried through every state.
#[derive(Debug)]
struct PglbConnInner {
    tls_config: TlsConfig,
    auth_conns: Arc<AuthConnState>,
}

/// Marker trait for the states a `PglbConn` can be in.
trait PglbConnState: std::fmt::Debug {}
impl PglbConnState for Start {}
impl PglbConnState for ClientConnect {}
impl PglbConnState for AuthPassthrough {}
impl PglbConnState for ComputeConnect {}
impl PglbConnState for ComputePassthrough {}
impl PglbConnState for End {}
/// Initial state: nothing has happened on the socket yet.
#[derive(Debug)]
struct Start;

impl PglbConn<Start> {
    /// Creates a connection handler bound to the shared TLS config and the
    /// auth-proxy registry.
    fn new(auth_conns: &Arc<AuthConnState>, tls_config: &TlsConfig) -> Result<Self> {
        Ok(PglbConn {
            inner: PglbConnInner {
                auth_conns: Arc::clone(auth_conns),
                tls_config: tls_config.clone(),
            },
            state: Start,
        })
    }

    /// Drives the connection through every phase in order; each phase
    /// consumes the previous state and returns the next one.
    async fn handle(
        self,
        client_stream: TcpStream,
        client_addr: SocketAddr,
    ) -> Result<PglbConn<End>> {
        self.handle_start(client_stream, client_addr)
            .await?
            .handle_client_connect()
            .await?
            .handle_auth_passthrough()
            .await?
            .handle_compute_connect()
            .await?
            .handle_compute_passthrough()
            .await
    }
}
impl PglbConn<Start> {
    /// Phase 1: accepts the raw TCP client, answers the SSLRequest and
    /// upgrades the connection to TLS. Plaintext connections are rejected.
    async fn handle_start(
        self,
        mut client_stream: TcpStream,
        client_addr: SocketAddr,
    ) -> Result<PglbConn<ClientConnect>> {
        match client_stream.set_nodelay(true) {
            Ok(()) => {}
            Err(e) => {
                bail!("socket option error: {e}");
            }
        };

        // TODO: HAProxy protocol
        let tls_requested = match Self::handle_ssl_request_message(&mut client_stream).await {
            Ok(tls_requested) => tls_requested,
            Err(e) => {
                bail!("check_for_ssl_request: {e}");
            }
        };

        let (client_stream, payload) = if tls_requested {
            println!("starting tls upgrade");
            // Confirm SSL support to the client before starting the handshake.
            let mut buf = BytesMut::new();
            BeMessage::write(&mut buf, &BeMessage::EncryptionResponse(true)).unwrap();
            client_stream.write_all(&buf).await?;
            let (stream, tls_server_end_point, server_name) =
                match Self::tls_upgrade(client_stream, self.inner.tls_config.clone()).await {
                    Ok((stream, ep, sn)) => (stream, ep, sn),
                    Err(e) => {
                        bail!("tls_upgrade: {e}");
                    }
                };
            (
                stream,
                ConnectionInitiatedPayload {
                    tls_server_end_point,
                    server_name,
                    ip_addr: client_addr.ip(),
                },
            )
        } else {
            // TODO: support unsecured connections?
            bail!("closing non-TLS connection");
        };
        println!("tls done");

        Ok(PglbConn {
            inner: self.inner,
            state: ClientConnect {
                client_stream,
                payload,
            },
        })
    }

    /// Peeks at the first bytes of the stream and consumes them iff they form
    /// a Postgres SSLRequest (length 8, code 80877103). Returns whether TLS
    /// was requested.
    async fn handle_ssl_request_message(stream: &mut TcpStream) -> Result<bool> {
        println!("checking for ssl request");
        let mut buf = vec![0u8; 8];
        let n_peek = stream.peek(&mut buf).await?;
        if n_peek == 0 {
            bail!("EOF");
        }
        // BUGFIX: compare against the number of bytes actually peeked
        // (`n_peek`); the previous check used `buf.len()`, which is always 8,
        // so a short peek was judged on zero-padded garbage.
        // TODO: loop and peek again instead of treating a short peek as
        // "not an SSLRequest".
        if n_peek != 8
            || buf[0..4] != 8u32.to_be_bytes()
            || buf[4..8] != 80877103u32.to_be_bytes()
        {
            return Ok(false);
        }
        // Consume the 8 SSLRequest bytes we just peeked.
        stream.read_exact(&mut buf).await?;
        Ok(true)
    }

    /// Performs the server-side TLS handshake and extracts the SNI hostname,
    /// validates the ALPN, and resolves the tls-server-end-point channel
    /// binding for the selected certificate.
    async fn tls_upgrade(
        stream: TcpStream,
        tls: TlsConfig,
    ) -> Result<(TlsStream<TcpStream>, TlsServerEndPoint, Option<String>)> {
        let tls_stream = tokio_rustls::TlsAcceptor::from(tls.config)
            .accept(stream)
            .await?;
        let conn_info = tls_stream.get_ref().1;
        let server_name = conn_info.server_name().map(|s| s.to_string());
        match conn_info.alpn_protocol() {
            None | Some(PG_ALPN_PROTOCOL) => {}
            Some(other) => {
                let alpn = String::from_utf8_lossy(other);
                warn!(%alpn, "unexpected ALPN");
                bail!("protocol violation");
            }
        }
        let (_, tls_server_end_point) = tls
            .cert_resolver
            .resolve(server_name.as_deref())
            .ok_or(anyhow!("missing cert"))?;
        Ok((tls_stream, tls_server_end_point, server_name))
    }
}
/// State: the client's TLS session is established but it is not yet paired
/// with an auth proxy.
#[derive(Debug)]
struct ClientConnect {
    client_stream: TlsStream<TcpStream>,
    payload: ConnectionInitiatedPayload,
}

impl PglbConn<ClientConnect> {
    /// Phase 2: picks a random registered auth proxy, opens a bi-directional
    /// QUIC stream to it and announces the new client connection.
    async fn handle_client_connect(self) -> Result<PglbConn<AuthPassthrough>> {
        let auth_conn = {
            let conns = self.inner.auth_conns.conns.lock().unwrap();
            if conns.is_empty() {
                // BUGFIX: message typo corrected ("avaiable" -> "available").
                bail!("no auth proxies available");
            }
            // Naive load balancing: choose uniformly at random.
            let mut rng = rand::thread_rng();
            conns
                .get_index(rng.gen_range(0..conns.len()))
                .unwrap()
                .1
                .clone()
            // TODO: check closed?
        };
        println!("connecting to {}", auth_conn.conn.stable_id());
        let (send, recv) = auth_conn.conn.open_bi().await?;
        let mut auth_stream = Framed::new(join(recv, send), PglbCodec);
        // The auth proxy learns the client's TLS channel binding, SNI and IP
        // before any Postgres bytes flow.
        auth_stream
            .send(proxy::PglbMessage::Control(
                proxy::PglbControlMessage::ConnectionInitiated(self.state.payload),
            ))
            .await?;
        Ok(PglbConn {
            inner: self.inner,
            state: AuthPassthrough {
                client_stream: Framed::new(
                    self.state.client_stream,
                    PgRawCodec {
                        start_or_ssl_request: true,
                    },
                ),
                auth_stream,
            },
        })
    }
}
/// State: relaying between client and auth proxy while authentication runs.
#[derive(Debug)]
struct AuthPassthrough {
    client_stream: Framed<TlsStream<TcpStream>, PgRawCodec>,
    auth_stream: Framed<Join<RecvStream, SendStream>, PglbCodec>,
}

impl PglbConn<AuthPassthrough> {
    /// Phase 3: shuttles Postgres messages between client and auth proxy
    /// until the auth proxy sends ConnectToCompute, which carries the
    /// compute's socket address and advances the state machine.
    async fn handle_auth_passthrough(self) -> Result<PglbConn<ComputeConnect>> {
        let mut client_stream = self.state.client_stream;
        let mut auth_stream = self.state.auth_stream;
        loop {
            select! {
                // `biased` drains pending auth-proxy messages before reading
                // more from the client.
                biased;
                msg = auth_stream.next() => {
                    match msg.context("auth proxy disconnected")?? {
                        PglbMessage::Postgres(payload) => {
                            println!("msg {payload:?}");
                            client_stream.send(PgRawMessage::Generic { payload }).await?;
                        }
                        PglbMessage::Control(PglbControlMessage::ConnectionInitiated(_)) => {
                            bail!("auth proxy sent unexpected message");
                        }
                        PglbMessage::Control(PglbControlMessage::ConnectToCompute { socket }) => {
                            println!("socket");
                            return Ok(PglbConn {
                                inner: self.inner,
                                state: ComputeConnect {
                                    client_stream,
                                    auth_stream,
                                    compute_socket: socket,
                                },
                            });
                        }
                        PglbMessage::Control(PglbControlMessage::ComputeEstablish) => {
                            bail!("auth proxy sent unexpected message");
                        }
                    }
                }
                msg = client_stream.next() => {
                    match msg.context("client disconnected")?? {
                        // A second SSLRequest after TLS is up is a violation.
                        PgRawMessage::SslRequest => bail!("protocol violation"),
                        PgRawMessage::Start(payload) | PgRawMessage::Generic { payload } => {
                            auth_stream.send(proxy::PglbMessage::Postgres(
                                payload
                            )).await?;
                        }
                    }
                }
            }
        }
    }
}
/// State: the auth proxy has told us which compute node to dial.
#[derive(Debug)]
struct ComputeConnect {
    client_stream: Framed<TlsStream<TcpStream>, PgRawCodec>,
    auth_stream: Framed<Join<RecvStream, SendStream>, PglbCodec>,
    compute_socket: SocketAddr,
}

impl PglbConn<ComputeConnect> {
    /// Phase 4: connects to the compute and relays the auth proxy's handshake
    /// with it until the auth proxy signals ComputeEstablish.
    async fn handle_compute_connect(self) -> Result<PglbConn<ComputePassthrough>> {
        let ComputeConnect {
            client_stream,
            mut auth_stream,
            compute_socket,
        } = self.state;
        let compute_stream = TcpStream::connect(compute_socket).await?;
        compute_stream
            .set_nodelay(true)
            .context("socket option error")?;
        let mut compute_stream = Framed::new(
            compute_stream,
            PgRawCodec {
                start_or_ssl_request: false,
            },
        );
        // NOTE(review): at most 4 compute responses are forwarded during this
        // phase — presumably the expected length of the auth exchange with
        // the compute; confirm before changing.
        let mut resps = 4;
        loop {
            select! {
                msg = auth_stream.next() => {
                    match msg.context("auth proxy disconnected")?? {
                        PglbMessage::Postgres(payload) => {
                            println!("msg {payload:?}");
                            compute_stream.send(PgRawMessage::Generic { payload } ).await?;
                        }
                        PglbMessage::Control(PglbControlMessage::ComputeEstablish) => {
                            println!("establish");
                            return Ok(PglbConn {
                                inner: self.inner,
                                state: ComputePassthrough {
                                    client_stream,
                                    compute_stream,
                                },
                            });
                        }
                        PglbMessage::Control(PglbControlMessage::ConnectionInitiated(_)) |
                        PglbMessage::Control(PglbControlMessage::ConnectToCompute { .. }) => {
                            bail!("auth proxy sent unexpected message");
                        }
                    }
                }
                // Stop polling the compute once the response budget is spent.
                msg = compute_stream.next(), if resps > 0 => {
                    match msg.context("compute disconnected")?? {
                        PgRawMessage::SslRequest => bail!("protocol violation"),
                        PgRawMessage::Start(payload) | PgRawMessage::Generic { payload } => {
                            resps -= 1;
                            auth_stream.send(proxy::PglbMessage::Postgres(
                                payload
                            )).await?;
                        }
                    }
                }
            }
        }
    }
}
/// State: authentication is done; only raw byte relaying remains.
#[derive(Debug)]
struct ComputePassthrough {
    client_stream: Framed<TlsStream<TcpStream>, PgRawCodec>,
    compute_stream: Framed<TcpStream, PgRawCodec>,
}

impl PglbConn<ComputePassthrough> {
    /// Phase 5: dismantles the framed codecs, delivers any bytes they had
    /// already read to the opposite side, then copies raw bytes in both
    /// directions until either side closes.
    async fn handle_compute_passthrough(self) -> Result<PglbConn<End>> {
        let ComputePassthrough {
            client_stream,
            compute_stream,
        } = self.state;
        let mut client_parts = client_stream.into_parts();
        let mut compute_parts = compute_stream.into_parts();
        // Nothing should be pending in the write direction at this point.
        assert!(compute_parts.write_buf.is_empty());
        assert!(client_parts.write_buf.is_empty());
        // Flush bytes the codecs read but never consumed to the other peer.
        client_parts.io.write_all(&compute_parts.read_buf).await?;
        compute_parts.io.write_all(&client_parts.read_buf).await?;
        drop(client_parts.read_buf);
        drop(client_parts.write_buf);
        drop(compute_parts.read_buf);
        drop(compute_parts.write_buf);
        copy_bidirectional(&mut client_parts.io, &mut compute_parts.io).await?;
        Ok(PglbConn {
            inner: self.inner,
            state: End,
        })
    }
}

/// Terminal state: the connection has been fully relayed and torn down.
#[derive(Debug)]
struct End;
/// Framing codec for raw Postgres wire messages.
///
/// `start_or_ssl_request` is true until the startup packet has been decoded:
/// in that mode only the untagged startup-phase messages (SSLRequest /
/// StartupMessage) are expected; afterwards only tagged "generic" messages.
#[derive(Debug)]
struct PgRawCodec {
    start_or_ssl_request: bool,
}

impl tokio_util::codec::Encoder<PgRawMessage> for PgRawCodec {
    type Error = anyhow::Error;

    fn encode(&mut self, item: PgRawMessage, dst: &mut bytes::BytesMut) -> Result<(), Self::Error> {
        item.encode(dst)
    }
}

impl tokio_util::codec::Decoder for PgRawCodec {
    type Item = PgRawMessage;
    type Error = anyhow::Error;

    fn decode(&mut self, src: &mut bytes::BytesMut) -> Result<Option<Self::Item>, Self::Error> {
        if self.start_or_ssl_request {
            match PgRawMessage::decode(src, true)? {
                // The first Start message flips the codec into generic mode;
                // an SSLRequest keeps it in startup mode.
                msg @ Some(PgRawMessage::Start(..)) => {
                    self.start_or_ssl_request = false;
                    Ok(msg)
                }
                msg @ Some(PgRawMessage::SslRequest) => Ok(msg),
                Some(PgRawMessage::Generic { .. }) => unreachable!(),
                None => Ok(None),
            }
        } else {
            match PgRawMessage::decode(src, false)? {
                Some(PgRawMessage::Start(..)) => unreachable!(),
                Some(PgRawMessage::SslRequest) => unreachable!(),
                msg @ Some(PgRawMessage::Generic { .. }) => Ok(msg),
                None => Ok(None),
            }
        }
    }
}
/// A minimally-parsed Postgres wire message: payload bytes are kept verbatim,
/// only enough framing is understood to find message boundaries.
#[derive(Debug)]
pub enum PgRawMessage {
    /// An SSLRequest (length 8, code 80877103); its bytes are not retained.
    SslRequest,
    /// The untagged startup packet, length prefix included.
    Start(BytesMut),
    /// Any tagged message (tag byte + length + body), bytes kept verbatim.
    Generic { payload: BytesMut },
}

impl PgRawMessage {
    /// Writes the message back out exactly as it was read (SslRequest is
    /// re-synthesized from its fixed 8-byte form).
    fn encode(&self, dst: &mut bytes::BytesMut) -> Result<()> {
        match self {
            Self::SslRequest => {
                dst.put_u32(8);
                dst.put_u32(80877103);
            }
            Self::Start(payload) | Self::Generic { payload } => {
                dst.put_slice(payload);
            }
        }
        Ok(())
    }

    /// Tries to split one complete message off the front of `src`.
    ///
    /// `start` selects startup framing (no tag byte; the 4-byte length
    /// includes itself) vs normal framing (1 tag byte, then a 4-byte length
    /// covering itself and the body but not the tag). Returns `Ok(None)` when
    /// more bytes are needed.
    fn decode(src: &mut bytes::BytesMut, start: bool) -> Result<Option<Self>> {
        // `extra` accounts for the tag byte present on non-startup messages.
        let extra = if start { 0 } else { 1 };
        if src.remaining() < 4 + extra {
            src.reserve(4 + extra);
            return Ok(None);
        }
        let length = u32::from_be_bytes(src[extra..4 + extra].try_into().unwrap()) as usize + extra;
        if src.remaining() < length {
            src.reserve(length - src.remaining());
            return Ok(None);
        }
        if start && length == 8 && src[4..8] == 80877103u32.to_be_bytes() {
            // NOTE(review): the SSLRequest's 8 bytes are deliberately NOT
            // consumed from `src` here; callers treat a codec-level
            // SslRequest as a protocol violation anyway — confirm.
            Ok(Some(PgRawMessage::SslRequest))
        } else if start {
            Ok(Some(PgRawMessage::Start(src.split_to(length))))
        } else {
            Ok(Some(PgRawMessage::Generic {
                payload: src.split_to(length),
            }))
        }
    }
}

View File

@@ -224,6 +224,10 @@ struct ProxyCliArgs {
/// Whether to retry the wake_compute request
#[clap(long, default_value = config::RetryConfig::WAKE_COMPUTE_DEFAULT_VALUES)]
wake_compute_retry: String,
/// Configure if this is a private access proxy for the POC: In that case the proxy will ignore the IP allowlist
#[clap(long, default_value_t = false, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)]
is_private_access_proxy: bool,
}
#[derive(clap::Args, Clone, Copy, Debug)]
@@ -682,6 +686,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
rate_limiter_enabled: args.auth_rate_limit_enabled,
rate_limiter: AuthRateLimiter::new(args.auth_rate_limit.clone()),
rate_limit_ip_subnet: args.auth_rate_limit_ip_subnet,
ip_allowlist_check_enabled: !args.is_private_access_proxy,
};
let config = Box::leak(Box::new(ProxyConfig {

View File

@@ -13,7 +13,6 @@ use rustls::{
crypto::ring::sign,
pki_types::{CertificateDer, PrivateKeyDer},
};
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use std::{
collections::{HashMap, HashSet},
@@ -65,6 +64,7 @@ pub struct AuthenticationConfig {
pub rate_limiter_enabled: bool,
pub rate_limiter: AuthRateLimiter,
pub rate_limit_ip_subnet: u8,
pub ip_allowlist_check_enabled: bool,
}
impl TlsConfig {
@@ -150,7 +150,7 @@ pub fn configure_tls(
/// uses multiple hash functions, then this channel binding type's
/// channel bindings are undefined at this time (updates to is channel
/// binding type may occur to address this issue if it ever arises).
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
#[derive(Debug, Clone, Copy)]
pub enum TlsServerEndPoint {
Sha256([u8; 32]),
Undefined,

View File

@@ -125,7 +125,7 @@ impl RequestMonitoring {
Self(TryLock::new(inner))
}
// #[cfg(test)]
#[cfg(test)]
pub(crate) fn test() -> Self {
RequestMonitoring::new(Uuid::now_v7(), [127, 0, 0, 1].into(), Protocol::Tcp, "test")
}

View File

@@ -82,23 +82,15 @@
impl_trait_overcaptures,
)]
use std::{
convert::Infallible,
future::Future,
net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
};
use std::{convert::Infallible, future::Future};
use anyhow::{bail, Context};
use bytes::{Buf, BufMut};
use config::TlsServerEndPoint;
use intern::{EndpointIdInt, EndpointIdTag, InternId};
use serde::{Deserialize, Serialize};
use tokio::task::JoinError;
use tokio_util::sync::CancellationToken;
use tracing::warn;
pub mod auth;
pub mod auth_proxy;
pub mod cache;
pub mod cancellation;
pub mod compute;
@@ -282,132 +274,3 @@ impl EndpointId {
ProjectId(self.0.clone())
}
}
/// Codec for the pglb <-> auth-proxy QUIC stream protocol.
///
/// Wire format: a 1-byte message kind (0 = Postgres passthrough, 1 = control)
/// followed by a big-endian u32 payload length and the payload itself.
/// Control payloads begin with one further tag byte: 0 = ConnectionInitiated
/// (JSON body), 1 = ConnectToCompute (IPv4/IPv6 address + port,
/// distinguished by payload length), 2 = ComputeEstablish (no body).
#[derive(Debug)]
pub struct PglbCodec;

impl tokio_util::codec::Encoder<PglbMessage> for PglbCodec {
    type Error = anyhow::Error;

    fn encode(&mut self, item: PglbMessage, dst: &mut bytes::BytesMut) -> Result<(), Self::Error> {
        match item {
            PglbMessage::Control(ctrl) => {
                dst.put_u8(1);
                match ctrl {
                    PglbControlMessage::ConnectionInitiated(msg) => {
                        let encode = serde_json::to_string(&msg).context("ser")?;
                        // Length covers the control tag byte plus the body.
                        dst.put_u32(1 + encode.len() as u32);
                        dst.put_u8(0);
                        dst.put(encode.as_bytes());
                    }
                    PglbControlMessage::ConnectToCompute { socket } => match socket {
                        SocketAddr::V4(v4) => {
                            dst.put_u32(1 + 4 + 2);
                            dst.put_u8(1);
                            dst.put_u32(v4.ip().to_bits());
                            dst.put_u16(v4.port());
                        }
                        SocketAddr::V6(v6) => {
                            dst.put_u32(1 + 16 + 2);
                            dst.put_u8(1);
                            dst.put_u128(v6.ip().to_bits());
                            dst.put_u16(v6.port());
                        }
                    },
                    PglbControlMessage::ComputeEstablish => {
                        dst.put_u32(1);
                        dst.put_u8(2);
                    }
                }
            }
            PglbMessage::Postgres(pg) => {
                dst.put_u8(0);
                dst.put_u32(pg.len() as u32);
                dst.put(pg);
            }
        }
        Ok(())
    }
}

impl tokio_util::codec::Decoder for PglbCodec {
    type Item = PglbMessage;
    type Error = anyhow::Error;

    fn decode(&mut self, dst: &mut bytes::BytesMut) -> Result<Option<Self::Item>, Self::Error> {
        // Need at least the kind byte and the 4-byte length word.
        if dst.remaining() < 5 {
            dst.reserve(5);
            return Ok(None);
        }
        let msg = dst[0];
        let len = u32::from_be_bytes(dst[1..5].try_into().unwrap()) as usize;
        if len + 5 > dst.remaining() {
            dst.reserve(len + 5);
            return Ok(None);
        }
        dst.advance(5);
        let mut payload = dst.split_to(len);
        match msg {
            // postgres
            0 => Ok(Some(PglbMessage::Postgres(payload))),
            // control
            1 => {
                if payload.is_empty() {
                    bail!("invalid ctrl message")
                }
                let ctrl_msg = payload.split_to(1)[0];
                let ctrl_msg = match ctrl_msg {
                    0 => PglbControlMessage::ConnectionInitiated(
                        serde_json::from_slice(&payload).context("deser")?,
                    ),
                    // ipv4 socket (1 tag + 4 addr + 2 port)
                    1 if len == 7 => PglbControlMessage::ConnectToCompute {
                        socket: SocketAddr::new(
                            IpAddr::V4(Ipv4Addr::from_bits(payload.get_u32())),
                            payload.get_u16(),
                        ),
                    },
                    // ipv6 socket (1 tag + 16 addr + 2 port)
                    1 if len == 19 => PglbControlMessage::ConnectToCompute {
                        socket: SocketAddr::new(
                            IpAddr::V6(Ipv6Addr::from_bits(payload.get_u128())),
                            payload.get_u16(),
                        ),
                    },
                    2 if len == 1 => PglbControlMessage::ComputeEstablish,
                    _ => bail!("invalid ctrl message"),
                };
                Ok(Some(PglbMessage::Control(ctrl_msg)))
            }
            _ => bail!("invalid message"),
        }
    }
}
/// A message on the pglb <-> auth-proxy stream: either an opaque chunk of
/// Postgres wire bytes or a control message.
pub enum PglbMessage {
    Control(PglbControlMessage),
    Postgres(bytes::BytesMut),
}

/// Control messages exchanged between pglb and the auth proxy.
pub enum PglbControlMessage {
    // from pglb to auth proxy
    ConnectionInitiated(ConnectionInitiatedPayload),
    // from auth proxy to pglb
    ConnectToCompute { socket: SocketAddr },
    // from auth proxy to pglb
    ComputeEstablish,
}

/// Connection metadata pglb forwards when a new client arrives: the TLS
/// channel-binding value, the SNI hostname (if any) and the client IP.
#[derive(Serialize, Deserialize, Debug)]
pub struct ConnectionInitiatedPayload {
    pub tls_server_end_point: TlsServerEndPoint,
    pub server_name: Option<String>,
    pub ip_addr: IpAddr,
}

View File

@@ -7,35 +7,9 @@ pub(crate) mod handshake;
pub(crate) mod passthrough;
pub(crate) mod retry;
pub(crate) mod wake_compute;
use anyhow::bail;
use anyhow::Context;
use bytes::BytesMut;
use connect_compute::ComputeConnectBackend;
pub use copy_bidirectional::copy_bidirectional_client_compute;
pub use copy_bidirectional::ErrorSource;
use futures::SinkExt;
use futures::TryStreamExt;
use postgres_protocol::authentication::sasl;
use postgres_protocol::authentication::sasl::ChannelBinding;
use postgres_protocol::authentication::sasl::ScramSha256;
use postgres_protocol::message::backend::Message;
use postgres_protocol::message::frontend;
use pq_proto::FeStartupPacket;
use quinn::RecvStream;
use quinn::SendStream;
use tokio::io::join;
use tokio_postgres::config::AuthKeys;
use tokio_util::codec::Framed;
use crate::auth::backend::ComputeCredentialKeys;
use crate::auth::backend::ComputeCredentials;
use crate::auth_proxy::AuthProxyStream;
use crate::auth_proxy::TLS_SERVER_END_POINT;
use crate::console::NodeInfo;
use crate::stream::AuthProxyStreamExt;
use crate::ConnectionInitiatedPayload;
use crate::PglbControlMessage;
use crate::PglbMessage;
use crate::{
auth,
cancellation::{self, CancellationHandlerMain, CancellationHandlerMainInternal},
@@ -56,8 +30,6 @@ use once_cell::sync::OnceCell;
use pq_proto::{BeMessage as Be, StartupMessageParams};
use regex::Regex;
use smol_str::{format_smolstr, SmolStr};
use std::net::IpAddr;
use std::net::SocketAddr;
use std::sync::Arc;
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
@@ -405,10 +377,10 @@ async fn prepare_client_connection<P>(
}
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct NeonOptions(Vec<(SmolStr, SmolStr)>);
pub(crate) struct NeonOptions(Vec<(SmolStr, SmolStr)>);
impl NeonOptions {
pub fn parse_params(params: &StartupMessageParams) -> Self {
pub(crate) fn parse_params(params: &StartupMessageParams) -> Self {
params
.options_raw()
.map(Self::parse_from_iter)
@@ -459,157 +431,3 @@ pub(crate) fn neon_option(bytes: &str) -> Option<(&str, &str)> {
let (_, [k, v]) = cap.extract();
Some((k, v))
}
/// Static configuration for the auth-proxy side of the PoC: the control-plane
/// backend plus client-authentication settings.
pub struct AuthProxyConfig {
    pub backend: crate::auth_proxy::Backend<'static, ()>,
    pub auth: crate::config::AuthenticationConfig,
}
pub async fn handle_stream(
config: &'static AuthProxyConfig,
send: SendStream,
recv: RecvStream,
) -> anyhow::Result<()> {
let mut stream: AuthProxyStream = Framed::new(join(recv, send), crate::PglbCodec);
// recv connection metadata
let first_msg = stream.try_next().await?;
let Some(PglbMessage::Control(PglbControlMessage::ConnectionInitiated(conn_info))) = first_msg
else {
panic!("invalid first msg")
};
println!("new conn: {conn_info:?}");
// read startup packet
let startup = stream.read_startup_packet().await?;
let FeStartupPacket::StartupMessage { version: _, params } = startup else {
panic!("invalid startup message")
};
println!("params: {params:?}");
let user_info = auth_with_user(&mut stream, config, &conn_info, &params).await?;
println!("authenticated");
// wake the compute
let node_info = user_info.wake_compute(&RequestMonitoring::test()).await?;
println!("woke compute");
let addr: IpAddr = node_info.config.get_host()?.parse()?;
let socket = SocketAddr::new(addr, node_info.config.get_ports()[0]);
// tell pglb that the compute is up
stream
.send(PglbMessage::Control(PglbControlMessage::ConnectToCompute {
socket,
}))
.await?;
// send startup message to compute
let mut buf = BytesMut::new();
frontend::startup_message(params.iter(), &mut buf)?;
stream.send(PglbMessage::Postgres(buf.split())).await?;
auth_with_compute(&mut stream, user_info.get_keys()).await?;
stream
.send(PglbMessage::Control(PglbControlMessage::ComputeEstablish))
.await?;
Ok(())
}
/// Authenticates the connecting client against the configured backend.
///
/// The username comes from the startup parameters; the endpoint id is
/// derived from the SNI hostname carried in `conn_info`. On auth failure
/// the error is forwarded to the client via `throw_error` before being
/// propagated to the caller.
async fn auth_with_user(
    stream: &mut AuthProxyStream,
    config: &'static AuthProxyConfig,
    conn_info: &ConnectionInitiatedPayload,
    params: &StartupMessageParams,
) -> anyhow::Result<crate::auth_proxy::Backend<'static, ComputeCredentials>> {
    // Extract credentials which we're going to use for auth.
    let user_info = auth::ComputeUserInfoMaybeEndpoint {
        user: params.get("user").context("missing user")?.into(),
        // The endpoint id is the first dot-separated label of the SNI
        // hostname (e.g. "ep-foo.region.host" -> "ep-foo"); a hostname
        // without dots is used as-is.
        endpoint_id: conn_info
            .server_name
            .as_deref()
            .map(|h| h.split_once('.').map_or(h, |(ep, _)| ep).into()),
        options: NeonOptions::parse_params(params),
    };

    // authenticate the user; the TLS channel-binding value is exposed to
    // the SCRAM exchange through the TLS_SERVER_END_POINT task-local.
    let user_info = config.backend.as_ref().map(|()| user_info);
    let res = TLS_SERVER_END_POINT
        .scope(
            conn_info.tls_server_end_point,
            user_info.authenticate(stream, &config.auth),
        )
        .await;

    let user_info = match res {
        Ok(auth_result) => auth_result,
        Err(e) => {
            // Report a user-facing error to the client, then re-throw.
            return stream.throw_error(e).await?;
        }
    };

    Ok(user_info)
}
/// Completes the SCRAM-SHA-256 handshake with the compute node on behalf of
/// the already-authenticated client, using the stored server keys instead of
/// a cleartext password.
///
/// # Errors
/// Fails if `keys` is not SCRAM key material, or if the compute sends any
/// message out of the expected AuthenticationSasl -> SaslContinue ->
/// SaslFinal -> AuthenticationOk sequence.
async fn auth_with_compute(
    stream: &mut AuthProxyStream,
    keys: &ComputeCredentialKeys,
) -> anyhow::Result<()> {
    // Only key-based SCRAM credentials are supported on this path.
    let ComputeCredentialKeys::AuthKeys(AuthKeys::ScramSha256(scram_keys)) = keys else {
        bail!("missing keys");
    };

    // compute offers sasl
    stream
        .read_backend_message(|m| match m {
            Message::AuthenticationSasl(_body) => Ok(()),
            _ => bail!("invalid message"),
        })
        .await?;

    let mut buf = BytesMut::new();

    // send auth message
    // NOTE(review): channel binding is disabled towards the compute —
    // presumably the proxy<->compute link is considered trusted; confirm.
    let mut scram = ScramSha256::new_with_keys(*scram_keys, ChannelBinding::unsupported());
    frontend::sasl_initial_response(sasl::SCRAM_SHA_256, scram.message(), &mut buf)?;
    stream.send(PglbMessage::Postgres(buf.split())).await?;

    // SASLContinue: feed the server challenge into the SCRAM state machine
    // and reply with the client proof.
    let cont_body = stream
        .read_backend_message(|m| match m {
            Message::AuthenticationSaslContinue(body) => Ok(body),
            _ => bail!("invalid message"),
        })
        .await?;
    scram.update(cont_body.data()).await?;
    frontend::sasl_response(scram.message(), &mut buf)?;
    stream.send(PglbMessage::Postgres(buf.split())).await?;

    // SASLFinal: verify the server signature against our key material.
    let final_body = stream
        .read_backend_message(|m| match m {
            Message::AuthenticationSaslFinal(body) => Ok(body),
            _ => bail!("invalid message"),
        })
        .await?;
    scram.finish(final_body.data())?;

    // wait for ok.
    stream
        .read_backend_message(|m| match m {
            Message::AuthenticationOk => Ok(()),
            _ => bail!("invalid message"),
        })
        .await?;

    Ok(())
}

View File

@@ -9,7 +9,6 @@
mod channel_binding;
mod messages;
mod stream;
mod stream2;
use crate::error::{ReportableError, UserFacingError};
use std::io;
@@ -18,7 +17,6 @@ use thiserror::Error;
pub(crate) use channel_binding::ChannelBinding;
pub(crate) use messages::FirstMessage;
pub(crate) use stream::{Outcome, SaslStream};
pub(crate) use stream2::SaslStream2;
/// Fine-grained auth errors help in writing tests.
#[derive(Error, Debug)]

View File

@@ -1,85 +0,0 @@
//! Abstraction for the string-oriented SASL protocols.
use crate::{
auth_proxy::AuthProxyStream,
sasl::{messages::ServerMessage, Mechanism},
stream::AuthProxyStreamExt,
};
use std::io;
use tracing::info;
use super::Outcome;
/// Abstracts away all peculiarities of the libpq's protocol.
pub(crate) struct SaslStream2<'a> {
    /// The underlying stream.
    stream: &'a mut AuthProxyStream,
    /// Current password message we received from client.
    /// Kept alive so `recv` can hand out a `&str` borrowed from it.
    current: bytes::Bytes,
    /// First SASL message produced by client (arrives piggybacked on
    /// SASLInitialResponse); consumed by the first `recv` call.
    first: Option<&'a str>,
}
impl<'a> SaslStream2<'a> {
pub(crate) fn new(stream: &'a mut AuthProxyStream, first: &'a str) -> Self {
Self {
stream,
current: bytes::Bytes::new(),
first: Some(first),
}
}
}
impl SaslStream2<'_> {
    /// Receive the next SASL message from the client. The first call yields
    /// the message captured from SASLInitialResponse; subsequent calls read
    /// a fresh password message off the wire and validate it as UTF-8.
    async fn recv(&mut self) -> io::Result<&str> {
        match self.first.take() {
            Some(initial) => Ok(initial),
            None => {
                self.current = self.stream.read_password_message().await?;
                std::str::from_utf8(&self.current)
                    .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "bad encoding"))
            }
        }
    }
}
impl SaslStream2<'_> {
    /// Forward a server-side SASL message (Continue/Final) to the client.
    async fn send(&mut self, msg: &ServerMessage<&str>) -> io::Result<()> {
        self.stream.write_message(&msg.to_reply()).await.map(|_| ())
    }
}
impl SaslStream2<'_> {
    /// Drive the SASL state machine to completion: keep exchanging messages
    /// with the client until authentication either succeeds or is denied.
    pub(crate) async fn authenticate<M: Mechanism>(
        mut self,
        mut mechanism: M,
    ) -> crate::sasl::Result<Outcome<M::Output>> {
        use crate::sasl::Step;

        loop {
            let input = self.recv().await?;
            let step = mechanism.exchange(input).map_err(|error| {
                info!(?error, "error during SASL exchange");
                error
            })?;

            match step {
                Step::Continue(next_mechanism, reply) => {
                    // More rounds needed: send the challenge and loop.
                    self.send(&ServerMessage::Continue(&reply)).await?;
                    mechanism = next_mechanism;
                }
                Step::Success(result, reply) => {
                    self.send(&ServerMessage::Final(&reply)).await?;
                    return Ok(Outcome::Success(result));
                }
                Step::Failure(reason) => return Ok(Outcome::Failure(reason)),
            }
        }
    }
}

View File

@@ -50,7 +50,9 @@ impl PoolingBackend {
.as_ref()
.map(|()| user_info.clone());
let (allowed_ips, maybe_secret) = backend.get_allowed_ips_and_secret(ctx).await?;
if !check_peer_addr_is_in_list(&ctx.peer_addr(), &allowed_ips) {
if config.ip_allowlist_check_enabled
&& !check_peer_addr_is_in_list(&ctx.peer_addr(), &allowed_ips)
{
return Err(AuthError::ip_address_not_allowed(ctx.peer_addr()));
}
if !self

View File

@@ -1,13 +1,8 @@
use crate::auth_proxy::AuthProxyStream;
use crate::config::TlsServerEndPoint;
use crate::error::{ErrorKind, ReportableError, UserFacingError};
use crate::metrics::Metrics;
use crate::PglbMessage;
use anyhow::{bail, Context};
use bytes::BytesMut;
use futures::{SinkExt, TryStreamExt};
use postgres_protocol::message::backend;
use pq_proto::framed::{ConnectionError, Framed};
use pq_proto::{BeMessage, FeMessage, FeStartupPacket, ProtocolError};
use rustls::ServerConfig;
@@ -299,140 +294,3 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite for Stream<S> {
}
}
}
/// Postgres-protocol convenience methods layered over the framed
/// PGLB stream: message (de)serialization in both directions plus
/// user-facing error reporting.
#[allow(async_fn_in_trait)]
pub(crate) trait AuthProxyStreamExt {
    /// Write the message into an internal buffer, but don't flush the underlying stream.
    fn write_message_noflush(&mut self, message: &BeMessage<'_>) -> io::Result<&mut Self>;

    /// Write the message into an internal buffer and flush it.
    async fn write_message(&mut self, message: &BeMessage<'_>) -> io::Result<&mut Self>;

    /// Write the error message using [`Self::write_message`], then re-throw it.
    /// Trait [`UserFacingError`] acts as an allowlist for error types.
    async fn throw_error<T, E>(&mut self, error: E) -> Result<T, ReportedError>
    where
        E: UserFacingError + Into<anyhow::Error>;

    /// Receive [`FeStartupPacket`], which is a first packet sent by a client.
    async fn read_startup_packet(&mut self) -> io::Result<FeStartupPacket>;

    /// Read the next frontend (client -> server) message from a Postgres frame.
    async fn read_message(&mut self) -> io::Result<FeMessage>;

    /// Read a PasswordMessage frame and return its payload; any other
    /// message type yields an `InvalidData` error.
    async fn read_password_message(&mut self) -> io::Result<bytes::Bytes>;

    /// Read one backend (server -> client) message and hand it to `f`,
    /// which extracts the expected variant or rejects the message.
    async fn read_backend_message<T>(
        &mut self,
        f: impl FnOnce(backend::Message) -> anyhow::Result<T>,
    ) -> anyhow::Result<T>;
}
impl AuthProxyStreamExt for AuthProxyStream {
    /// Write the message into an internal buffer, but don't flush the underlying stream.
    fn write_message_noflush(&mut self, message: &BeMessage<'_>) -> io::Result<&mut Self> {
        // Serialize into a scratch buffer and queue it on the framed sink
        // as a Postgres payload; flushing is left to the caller.
        let mut b = BytesMut::new();
        BeMessage::write(&mut b, message).map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
        self.start_send_unpin(PglbMessage::Postgres(b))
            .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
        Ok(self)
    }

    /// Write the message into an internal buffer and flush it.
    async fn write_message(&mut self, message: &BeMessage<'_>) -> io::Result<&mut Self> {
        self.write_message_noflush(message)?;
        self.flush()
            .await
            .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
        Ok(self)
    }

    /// Write the error message using [`Self::write_message`], then re-throw it.
    /// Trait [`UserFacingError`] acts as an allowlist for error types.
    async fn throw_error<T, E>(&mut self, error: E) -> Result<T, ReportedError>
    where
        E: UserFacingError + Into<anyhow::Error>,
    {
        let error_kind = error.get_error_kind();
        // Only the sanitized client-facing rendering of the error is sent
        // over the wire; the full error goes to the log.
        let msg = error.to_string_client();
        tracing::info!(
            kind=error_kind.to_metric_label(),
            error=%error,
            msg,
            "forwarding error to user"
        );

        // already error case, ignore client IO error
        self.write_message(&BeMessage::ErrorResponse(&msg, None))
            .await
            .inspect_err(|e| debug!("write_message failed: {e}"))
            .ok();

        Err(ReportedError {
            source: anyhow::anyhow!(error),
            error_kind,
        })
    }

    /// Receive [`FeStartupPacket`], which is a first packet sent by a client.
    async fn read_startup_packet(&mut self) -> io::Result<FeStartupPacket> {
        // A closed stream or a control frame at this point is a protocol
        // violation; only a Postgres payload may carry the startup packet.
        let msg = self
            .try_next()
            .await
            .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?
            .ok_or_else(err_connection)?;

        match msg {
            PglbMessage::Control(_) => Err(io::Error::new(
                io::ErrorKind::Other,
                "unexpected control message",
            )),
            PglbMessage::Postgres(pg) => {
                // FeStartupPacket::parse returns Ok(None) on an incomplete
                // packet, which we treat the same as a dropped connection.
                let mut buf = BytesMut::from(&*pg);
                FeStartupPacket::parse(&mut buf)
                    .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?
                    .ok_or_else(err_connection)
            }
        }
    }

    /// Read the next frontend message from a Postgres frame.
    async fn read_message(&mut self) -> io::Result<FeMessage> {
        let msg = self
            .try_next()
            .await
            .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?
            .ok_or_else(err_connection)?;

        match msg {
            PglbMessage::Control(_) => Err(io::Error::new(
                io::ErrorKind::Other,
                "unexpected control message",
            )),
            PglbMessage::Postgres(pg) => {
                let mut buf = BytesMut::from(&*pg);
                FeMessage::parse(&mut buf)
                    .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?
                    .ok_or_else(err_connection)
            }
        }
    }

    /// Read a PasswordMessage frame and return its payload bytes.
    async fn read_password_message(&mut self) -> io::Result<bytes::Bytes> {
        match self.read_message().await? {
            FeMessage::PasswordMessage(msg) => Ok(msg),
            bad => Err(io::Error::new(
                io::ErrorKind::InvalidData,
                format!("unexpected message type: {bad:?}"),
            )),
        }
    }

    /// Read one backend message and let `f` extract the expected variant.
    async fn read_backend_message<T>(
        &mut self,
        f: impl FnOnce(backend::Message) -> anyhow::Result<T>,
    ) -> anyhow::Result<T> {
        let PglbMessage::Postgres(mut buf) = self.try_next().await?.context("missing")? else {
            bail!("invalid message");
        };
        let message = backend::Message::parse(&mut buf)?.context("missing")?;
        f(message)
    }
}

View File

@@ -1,11 +1,11 @@
{
"v16": [
"16.4",
"6e9a4ff6249ac02b8175054b7b3f7dfb198be48b"
"0baa7346dfd42d61912eeca554c9bb0a190f0a1e"
],
"v15": [
"15.8",
"49d5e576a56e4cc59cd6a6a0791b2324b9fa675e"
"6f6d77fb5960602fcd3fd130aca9f99ecb1619c9"
],
"v14": [
"14.13",