Mirror of https://github.com/neondatabase/neon.git, synced 2026-01-08 05:52:55 +00:00
# flatten proxy flow (#6447)
## Problem

Taking my ideas from https://github.com/neondatabase/neon/pull/6283, but with less radical changes and smaller commits. The proxy flow was quite deeply nested, which made adding more interesting error handling quite tricky.

## Summary of changes

I recommend reviewing commit by commit.

1. Move handshake logic into a separate file.
2. Move passthrough logic into a separate file.
3. No longer accept a closure in the CancelMap session logic (see the sketch below).
4. Remove connect_to_db and copy its logic into handle_client.
5. Flatten auth_and_wake_compute in authenticate.
6. Record info for link auth.
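To make change 3 concrete, here is a minimal, dependency-free sketch of the idea (not the actual proxy code): `CancelMap::get_session` hands out an RAII guard that registers a cancellation key and removes it in `Drop`, instead of the old `with_session(|session| ...)` closure API. The types are simplified stand-ins: a `Mutex<HashMap>` and a toy key generator in place of `DashMap`, `CancelKeyData`, and `rand`.

```rust
use std::collections::HashMap;
use std::sync::{Arc, Mutex};

/// Simplified stand-in for the proxy's CancelMap.
#[derive(Default)]
struct CancelMap(Mutex<HashMap<u64, Option<String>>>);

/// RAII guard: registers a key on creation, removes it on Drop.
struct Session {
    key: u64,
    cancel_map: Arc<CancelMap>,
}

impl CancelMap {
    /// Replaces the closure-based `with_session`: the caller now owns the guard.
    fn get_session(self: Arc<Self>) -> Session {
        // Retry on (unlikely) key collisions instead of bailing out.
        let key = loop {
            let key = toy_random_key();
            let mut map = self.0.lock().unwrap();
            if let std::collections::hash_map::Entry::Vacant(e) = map.entry(key) {
                e.insert(None);
                break key;
            }
        };
        Session { key, cancel_map: self }
    }

    fn is_empty(&self) -> bool {
        self.0.lock().unwrap().is_empty()
    }
}

impl Drop for Session {
    fn drop(&mut self) {
        // Cleanup that previously relied on `scopeguard::defer!` inside the closure.
        self.cancel_map.0.lock().unwrap().remove(&self.key);
    }
}

/// Toy key source so the sketch stays dependency-free.
fn toy_random_key() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};
    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap()
        .as_nanos() as u64
}

fn main() {
    let cancel_map: Arc<CancelMap> = Default::default();
    let session = cancel_map.clone().get_session();
    assert!(!cancel_map.is_empty());
    drop(session); // the entry is removed here, no closure needed
    assert!(cancel_map.is_empty());
    println!("session entry cleaned up on drop");
}
```

In the real diff the guard is passed to `prepare_client_connection`, and the `Drop` impl takes over the cleanup that used to live in a `scopeguard::defer!` inside `with_session`.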
@@ -190,7 +190,10 @@ async fn auth_quirks(
Err(info) => {
let res = hacks::password_hack_no_authentication(info, client, &mut ctx.latency_timer)
.await?;
ctx.set_endpoint_id(Some(res.info.endpoint.clone()));

ctx.set_endpoint_id(res.info.endpoint.clone());
tracing::Span::current().record("ep", &tracing::field::display(&res.info.endpoint));

(res.info, Some(res.keys))
}
Ok(info) => (info, None),

@@ -271,19 +274,12 @@ async fn authenticate_with_secret(
classic::authenticate(info, client, config, &mut ctx.latency_timer, secret).await
}

/// Authenticate the user and then wake a compute (or retrieve an existing compute session from cache)
/// only if authentication was successfuly.
async fn auth_and_wake_compute(
/// wake a compute (or retrieve an existing compute session from cache)
async fn wake_compute(
ctx: &mut RequestMonitoring,
api: &impl console::Api,
user_info: ComputeUserInfoMaybeEndpoint,
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
allow_cleartext: bool,
config: &'static AuthenticationConfig,
compute_credentials: ComputeCredentials<ComputeCredentialKeys>,
) -> auth::Result<(CachedNodeInfo, ComputeUserInfo)> {
let compute_credentials =
auth_quirks(ctx, api, user_info, client, allow_cleartext, config).await?;

let mut num_retries = 0;
let mut node = loop {
let wake_res = api.wake_compute(ctx, &compute_credentials.info).await;

@@ -358,16 +354,16 @@ impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint> {
"performing authentication using the console"
);

let (cache_info, user_info) =
auth_and_wake_compute(ctx, &*api, user_info, client, allow_cleartext, config)
.await?;
let compute_credentials =
auth_quirks(ctx, &*api, user_info, client, allow_cleartext, config).await?;
let (cache_info, user_info) = wake_compute(ctx, &*api, compute_credentials).await?;
(cache_info, BackendType::Console(api, user_info))
}
// NOTE: this auth backend doesn't use client credentials.
Link(url) => {
info!("performing link authentication");

let node_info = link::authenticate(&url, client).await?;
let node_info = link::authenticate(ctx, &url, client).await?;

(
CachedNodeInfo::new_uncached(node_info),

@@ -1,6 +1,7 @@
use crate::{
auth, compute,
console::{self, provider::NodeInfo},
context::RequestMonitoring,
error::UserFacingError,
stream::PqStream,
waiters,

@@ -54,6 +55,7 @@ pub fn new_psql_session_id() -> String {
}

pub(super) async fn authenticate(
ctx: &mut RequestMonitoring,
link_uri: &reqwest::Url,
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
) -> auth::Result<NodeInfo> {

@@ -94,6 +96,10 @@ pub(super) async fn authenticate(
.dbname(&db_info.dbname)
.user(&db_info.user);

ctx.set_user(db_info.user.into());
ctx.set_project(db_info.aux.clone());
tracing::Span::current().record("ep", &tracing::field::display(&db_info.aux.endpoint_id));

// Backwards compatibility. pg_sni_proxy uses "--" in domain names
// while direct connections do not. Once we migrate to pg_sni_proxy
// everywhere, we can remove this.

@@ -126,7 +126,11 @@ impl ComputeUserInfoMaybeEndpoint {
}),
}
.transpose()?;
ctx.set_endpoint_id(endpoint.clone());

if let Some(ep) = &endpoint {
ctx.set_endpoint_id(ep.clone());
tracing::Span::current().record("ep", &tracing::field::display(ep));
}

info!(%user, project = endpoint.as_deref(), "credentials");
if sni.is_some() {

@@ -150,7 +154,7 @@ impl ComputeUserInfoMaybeEndpoint {

Ok(Self {
user,
endpoint_id: endpoint.map(EndpointId::from),
endpoint_id: endpoint,
options,
})
}

@@ -272,5 +272,5 @@ async fn handle_client(
let client = tokio::net::TcpStream::connect(destination).await?;

let metrics_aux: MetricsAuxInfo = Default::default();
proxy::proxy::proxy_pass(ctx, tls_stream, client, metrics_aux).await
proxy::proxy::passthrough::proxy_pass(ctx, tls_stream, client, metrics_aux).await
}

@@ -1,7 +1,7 @@
use anyhow::{bail, Context};
use anyhow::Context;
use dashmap::DashMap;
use pq_proto::CancelKeyData;
use std::net::SocketAddr;
use std::{net::SocketAddr, sync::Arc};
use tokio::net::TcpStream;
use tokio_postgres::{CancelToken, NoTls};
use tracing::info;

@@ -25,39 +25,31 @@ impl CancelMap {
}

/// Run async action within an ephemeral session identified by [`CancelKeyData`].
pub async fn with_session<'a, F, R, V>(&'a self, f: F) -> anyhow::Result<V>
where
F: FnOnce(Session<'a>) -> R,
R: std::future::Future<Output = anyhow::Result<V>>,
{
pub fn get_session(self: Arc<Self>) -> Session {
// HACK: We'd rather get the real backend_pid but tokio_postgres doesn't
// expose it and we don't want to do another roundtrip to query
// for it. The client will be able to notice that this is not the
// actual backend_pid, but backend_pid is not used for anything
// so it doesn't matter.
let key = rand::random();
let key = loop {
let key = rand::random();

// Random key collisions are unlikely to happen here, but they're still possible,
// which is why we have to take care not to rewrite an existing key.
match self.0.entry(key) {
dashmap::mapref::entry::Entry::Occupied(_) => {
bail!("query cancellation key already exists: {key}")
// Random key collisions are unlikely to happen here, but they're still possible,
// which is why we have to take care not to rewrite an existing key.
match self.0.entry(key) {
dashmap::mapref::entry::Entry::Occupied(_) => continue,
dashmap::mapref::entry::Entry::Vacant(e) => {
e.insert(None);
}
}
dashmap::mapref::entry::Entry::Vacant(e) => {
e.insert(None);
}
}

// This will guarantee that the session gets dropped
// as soon as the future is finished.
scopeguard::defer! {
self.0.remove(&key);
info!("dropped query cancellation key {key}");
}
break key;
};

info!("registered new query cancellation key {key}");
let session = Session::new(key, self);
f(session).await
Session {
key,
cancel_map: self,
}
}

#[cfg(test)]

@@ -98,23 +90,17 @@ impl CancelClosure {
}

/// Helper for registering query cancellation tokens.
pub struct Session<'a> {
pub struct Session {
/// The user-facing key identifying this session.
key: CancelKeyData,
/// The [`CancelMap`] this session belongs to.
cancel_map: &'a CancelMap,
cancel_map: Arc<CancelMap>,
}

impl<'a> Session<'a> {
fn new(key: CancelKeyData, cancel_map: &'a CancelMap) -> Self {
Self { key, cancel_map }
}
}

impl Session<'_> {
impl Session {
/// Store the cancel token for the given session.
/// This enables query cancellation in `crate::proxy::prepare_client_connection`.
pub fn enable_query_cancellation(self, cancel_closure: CancelClosure) -> CancelKeyData {
pub fn enable_query_cancellation(&self, cancel_closure: CancelClosure) -> CancelKeyData {
info!("enabling query cancellation for this session");
self.cancel_map.0.insert(self.key, Some(cancel_closure));

@@ -122,37 +108,26 @@ impl Session<'_> {
}
}

impl Drop for Session {
fn drop(&mut self) {
self.cancel_map.0.remove(&self.key);
info!("dropped query cancellation key {}", &self.key);
}
}

#[cfg(test)]
mod tests {
use super::*;
use once_cell::sync::Lazy;

#[tokio::test]
async fn check_session_drop() -> anyhow::Result<()> {
static CANCEL_MAP: Lazy<CancelMap> = Lazy::new(Default::default);

let (tx, rx) = tokio::sync::oneshot::channel();
let task = tokio::spawn(CANCEL_MAP.with_session(|session| async move {
assert!(CANCEL_MAP.contains(&session));

tx.send(()).expect("failed to send");
futures::future::pending::<()>().await; // sleep forever

Ok(())
}));

// Wait until the task has been spawned.
rx.await.context("failed to hear from the task")?;

// Drop the session's entry by cancelling the task.
task.abort();
let error = task.await.expect_err("task should have failed");
if !error.is_cancelled() {
anyhow::bail!(error);
}
let cancel_map: Arc<CancelMap> = Default::default();

let session = cancel_map.clone().get_session();
assert!(cancel_map.contains(&session));
drop(session);
// Check that the session has been dropped.
assert!(CANCEL_MAP.is_empty());
assert!(cancel_map.is_empty());

Ok(())
}

@@ -89,13 +89,11 @@ impl RequestMonitoring {
self.project = Some(x.project_id);
}

pub fn set_endpoint_id(&mut self, endpoint_id: Option<EndpointId>) {
self.endpoint_id = endpoint_id.or_else(|| self.endpoint_id.clone());
if let Some(ep) = &self.endpoint_id {
crate::metrics::CONNECTING_ENDPOINTS
.with_label_values(&[self.protocol])
.measure(&ep);
}
pub fn set_endpoint_id(&mut self, endpoint_id: EndpointId) {
crate::metrics::CONNECTING_ENDPOINTS
.with_label_values(&[self.protocol])
.measure(&endpoint_id);
self.endpoint_id = Some(endpoint_id);
}

pub fn set_application(&mut self, app: Option<SmolStr>) {

@@ -2,37 +2,34 @@
mod tests;

pub mod connect_compute;
pub mod handshake;
pub mod passthrough;
pub mod retry;

use crate::{
auth,
cancellation::{self, CancelMap},
compute,
config::{AuthenticationConfig, ProxyConfig, TlsConfig},
console::messages::MetricsAuxInfo,
config::{ProxyConfig, TlsConfig},
context::RequestMonitoring,
metrics::{
NUM_BYTES_PROXIED_COUNTER, NUM_BYTES_PROXIED_PER_CLIENT_COUNTER,
NUM_CLIENT_CONNECTION_GAUGE, NUM_CONNECTION_REQUESTS_GAUGE,
},
metrics::{NUM_CLIENT_CONNECTION_GAUGE, NUM_CONNECTION_REQUESTS_GAUGE},
protocol2::WithClientIp,
proxy::{handshake::handshake, passthrough::proxy_pass},
rate_limiter::EndpointRateLimiter,
stream::{PqStream, Stream},
usage_metrics::{Ids, USAGE_METRICS},
EndpointCacheKey,
};
use anyhow::{bail, Context};
use futures::TryFutureExt;
use itertools::Itertools;
use once_cell::sync::OnceCell;
use pq_proto::{BeMessage as Be, FeStartupPacket, StartupMessageParams};
use pq_proto::{BeMessage as Be, StartupMessageParams};
use regex::Regex;
use smol_str::{format_smolstr, SmolStr};
use std::sync::Arc;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use tokio_util::sync::CancellationToken;
use tracing::{error, info, info_span, Instrument};
use utils::measured_stream::MeasuredStream;

use self::connect_compute::{connect_to_compute, TcpMechanism};

@@ -80,6 +77,13 @@ pub async fn task_main(
let cancel_map = Arc::clone(&cancel_map);
let endpoint_rate_limiter = endpoint_rate_limiter.clone();

let session_span = info_span!(
"handle_client",
?session_id,
peer_addr = tracing::field::Empty,
ep = tracing::field::Empty,
);

connections.spawn(
async move {
info!("accepted postgres client connection");

@@ -103,22 +107,18 @@ pub async fn task_main(
handle_client(
config,
&mut ctx,
&cancel_map,
cancel_map,
socket,
ClientMode::Tcp,
endpoint_rate_limiter,
)
.await
}
.instrument(info_span!(
"handle_client",
?session_id,
peer_addr = tracing::field::Empty
))
.unwrap_or_else(move |e| {
// Acknowledge that the task has finished with an error.
error!(?session_id, "per-client task finished with an error: {e:#}");
}),
error!("per-client task finished with an error: {e:#}");
})
.instrument(session_span),
);
}

@@ -171,7 +171,7 @@ impl ClientMode {
pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
config: &'static ProxyConfig,
ctx: &mut RequestMonitoring,
cancel_map: &CancelMap,
cancel_map: Arc<CancelMap>,
stream: S,
mode: ClientMode,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,

@@ -192,138 +192,88 @@ pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
let tls = config.tls_config.as_ref();

let pause = ctx.latency_timer.pause();
let do_handshake = handshake(stream, mode.handshake_tls(tls), cancel_map);
let do_handshake = handshake(stream, mode.handshake_tls(tls), &cancel_map);
let (mut stream, params) = match do_handshake.await? {
Some(x) => x,
None => return Ok(()), // it's a cancellation request
};
drop(pause);

let hostname = mode.hostname(stream.get_ref());

let common_names = tls.map(|tls| &tls.common_names);

// Extract credentials which we're going to use for auth.
let user_info = {
let hostname = mode.hostname(stream.get_ref());
let result = config
.auth_backend
.as_ref()
.map(|_| auth::ComputeUserInfoMaybeEndpoint::parse(ctx, &params, hostname, common_names))
.transpose();

let common_names = tls.map(|tls| &tls.common_names);
let result = config
.auth_backend
.as_ref()
.map(|_| {
auth::ComputeUserInfoMaybeEndpoint::parse(ctx, &params, hostname, common_names)
})
.transpose();
let user_info = match result {
Ok(user_info) => user_info,
Err(e) => stream.throw_error(e).await?,
};

match result {
Ok(user_info) => user_info,
Err(e) => stream.throw_error(e).await?,
// check rate limit
if let Some(ep) = user_info.get_endpoint() {
if !endpoint_rate_limiter.check(ep) {
return stream
.throw_error(auth::AuthError::too_many_connections())
.await;
}
}

let user = user_info.get_user().to_owned();
let (mut node_info, user_info) = match user_info
.authenticate(
ctx,
&mut stream,
mode.allow_cleartext(),
&config.authentication_config,
)
.await
{
Ok(auth_result) => auth_result,
Err(e) => {
let db = params.get("database");
let app = params.get("application_name");
let params_span = tracing::info_span!("", ?user, ?db, ?app);

return stream.throw_error(e).instrument(params_span).await;
}
};

ctx.set_endpoint_id(user_info.get_endpoint());
node_info.allow_self_signed_compute = mode.allow_self_signed_compute(config);

let client = Client::new(
stream,
user_info,
&params,
mode.allow_self_signed_compute(config),
endpoint_rate_limiter,
);
cancel_map
.with_session(|session| {
client.connect_to_db(ctx, session, mode, &config.authentication_config)
})
.await
}
let aux = node_info.aux.clone();
let mut node = connect_to_compute(
ctx,
&TcpMechanism { params: &params },
node_info,
&user_info,
)
.or_else(|e| stream.throw_error(e))
.await?;

/// Establish a (most probably, secure) connection with the client.
/// For better testing experience, `stream` can be any object satisfying the traits.
/// It's easier to work with owned `stream` here as we need to upgrade it to TLS;
/// we also take an extra care of propagating only the select handshake errors to client.
#[tracing::instrument(skip_all)]
async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
stream: S,
mut tls: Option<&TlsConfig>,
cancel_map: &CancelMap,
) -> anyhow::Result<Option<(PqStream<Stream<S>>, StartupMessageParams)>> {
// Client may try upgrading to each protocol only once
let (mut tried_ssl, mut tried_gss) = (false, false);
let session = cancel_map.get_session();
prepare_client_connection(&node, &session, &mut stream).await?;

let mut stream = PqStream::new(Stream::from_raw(stream));
loop {
let msg = stream.read_startup_packet().await?;
info!("received {msg:?}");
// Before proxy passing, forward to compute whatever data is left in the
// PqStream input buffer. Normally there is none, but our serverless npm
// driver in pipeline mode sends startup, password and first query
// immediately after opening the connection.
let (stream, read_buf) = stream.into_inner();
node.stream.write_all(&read_buf).await?;

use FeStartupPacket::*;
match msg {
SslRequest => match stream.get_ref() {
Stream::Raw { .. } if !tried_ssl => {
tried_ssl = true;

// We can't perform TLS handshake without a config
let enc = tls.is_some();
stream.write_message(&Be::EncryptionResponse(enc)).await?;
if let Some(tls) = tls.take() {
// Upgrade raw stream into a secure TLS-backed stream.
// NOTE: We've consumed `tls`; this fact will be used later.

let (raw, read_buf) = stream.into_inner();
// TODO: Normally, client doesn't send any data before
// server says TLS handshake is ok and read_buf is empy.
// However, you could imagine pipelining of postgres
// SSLRequest + TLS ClientHello in one hunk similar to
// pipelining in our node js driver. We should probably
// support that by chaining read_buf with the stream.
if !read_buf.is_empty() {
bail!("data is sent before server replied with EncryptionResponse");
}
let tls_stream = raw.upgrade(tls.to_server_config()).await?;

let (_, tls_server_end_point) = tls
.cert_resolver
.resolve(tls_stream.get_ref().1.server_name())
.context("missing certificate")?;

stream = PqStream::new(Stream::Tls {
tls: Box::new(tls_stream),
tls_server_end_point,
});
}
}
_ => bail!(ERR_PROTO_VIOLATION),
},
GssEncRequest => match stream.get_ref() {
Stream::Raw { .. } if !tried_gss => {
tried_gss = true;

// Currently, we don't support GSSAPI
stream.write_message(&Be::EncryptionResponse(false)).await?;
}
_ => bail!(ERR_PROTO_VIOLATION),
},
StartupMessage { params, .. } => {
// Check that the config has been consumed during upgrade
// OR we didn't provide it at all (for dev purposes).
if tls.is_some() {
stream.throw_error_str(ERR_INSECURE_CONNECTION).await?;
}

info!(session_type = "normal", "successful handshake");
break Ok(Some((stream, params)));
}
CancelRequest(cancel_key_data) => {
cancel_map.cancel_session(cancel_key_data).await?;

info!(session_type = "cancellation", "successful handshake");
break Ok(None);
}
}
}
proxy_pass(ctx, stream, node.stream, aux).await
}

/// Finish client connection initialization: confirm auth success, send params, etc.
#[tracing::instrument(skip_all)]
async fn prepare_client_connection(
node: &compute::PostgresConnection,
session: cancellation::Session<'_>,
session: &cancellation::Session,
stream: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
) -> anyhow::Result<()> {
// Register compute's query cancellation token and produce a new, unique one.

@@ -349,151 +299,6 @@ async fn prepare_client_connection(
Ok(())
}

/// Forward bytes in both directions (client <-> compute).
#[tracing::instrument(skip_all)]
pub async fn proxy_pass(
ctx: &mut RequestMonitoring,
client: impl AsyncRead + AsyncWrite + Unpin,
compute: impl AsyncRead + AsyncWrite + Unpin,
aux: MetricsAuxInfo,
) -> anyhow::Result<()> {
ctx.set_success();
ctx.log();

let usage = USAGE_METRICS.register(Ids {
endpoint_id: aux.endpoint_id.clone(),
branch_id: aux.branch_id.clone(),
});

let m_sent = NUM_BYTES_PROXIED_COUNTER.with_label_values(&["tx"]);
let m_sent2 = NUM_BYTES_PROXIED_PER_CLIENT_COUNTER.with_label_values(&aux.traffic_labels("tx"));
let mut client = MeasuredStream::new(
client,
|_| {},
|cnt| {
// Number of bytes we sent to the client (outbound).
m_sent.inc_by(cnt as u64);
m_sent2.inc_by(cnt as u64);
usage.record_egress(cnt as u64);
},
);

let m_recv = NUM_BYTES_PROXIED_COUNTER.with_label_values(&["rx"]);
let m_recv2 = NUM_BYTES_PROXIED_PER_CLIENT_COUNTER.with_label_values(&aux.traffic_labels("rx"));
let mut compute = MeasuredStream::new(
compute,
|_| {},
|cnt| {
// Number of bytes the client sent to the compute node (inbound).
m_recv.inc_by(cnt as u64);
m_recv2.inc_by(cnt as u64);
},
);

// Starting from here we only proxy the client's traffic.
info!("performing the proxy pass...");
let _ = tokio::io::copy_bidirectional(&mut client, &mut compute).await?;

Ok(())
}

/// Thin connection context.
struct Client<'a, S> {
/// The underlying libpq protocol stream.
stream: PqStream<Stream<S>>,
/// Client credentials that we care about.
user_info: auth::BackendType<'a, auth::ComputeUserInfoMaybeEndpoint>,
/// KV-dictionary with PostgreSQL connection params.
params: &'a StartupMessageParams,
/// Allow self-signed certificates (for testing).
allow_self_signed_compute: bool,
/// Rate limiter for endpoints
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
}

impl<'a, S> Client<'a, S> {
/// Construct a new connection context.
fn new(
stream: PqStream<Stream<S>>,
user_info: auth::BackendType<'a, auth::ComputeUserInfoMaybeEndpoint>,
params: &'a StartupMessageParams,
allow_self_signed_compute: bool,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
) -> Self {
Self {
stream,
user_info,
params,
allow_self_signed_compute,
endpoint_rate_limiter,
}
}
}

impl<S: AsyncRead + AsyncWrite + Unpin> Client<'_, S> {
/// Let the client authenticate and connect to the designated compute node.
// Instrumentation logs endpoint name everywhere. Doesn't work for link
// auth; strictly speaking we don't know endpoint name in its case.
#[tracing::instrument(name = "", fields(ep = %self.user_info.get_endpoint().unwrap_or_default()), skip_all)]
async fn connect_to_db(
self,
ctx: &mut RequestMonitoring,
session: cancellation::Session<'_>,
mode: ClientMode,
config: &'static AuthenticationConfig,
) -> anyhow::Result<()> {
let Self {
mut stream,
user_info,
params,
allow_self_signed_compute,
endpoint_rate_limiter,
} = self;

// check rate limit
if let Some(ep) = user_info.get_endpoint() {
if !endpoint_rate_limiter.check(ep) {
return stream
.throw_error(auth::AuthError::too_many_connections())
.await;
}
}

let user = user_info.get_user().to_owned();
let auth_result = match user_info
.authenticate(ctx, &mut stream, mode.allow_cleartext(), config)
.await
{
Ok(auth_result) => auth_result,
Err(e) => {
let db = params.get("database");
let app = params.get("application_name");
let params_span = tracing::info_span!("", ?user, ?db, ?app);

return stream.throw_error(e).instrument(params_span).await;
}
};

let (mut node_info, user_info) = auth_result;

node_info.allow_self_signed_compute = allow_self_signed_compute;

let aux = node_info.aux.clone();
let mut node = connect_to_compute(ctx, &TcpMechanism { params }, node_info, &user_info)
.or_else(|e| stream.throw_error(e))
.await?;

prepare_client_connection(&node, session, &mut stream).await?;
// Before proxy passing, forward to compute whatever data is left in the
// PqStream input buffer. Normally there is none, but our serverless npm
// driver in pipeline mode sends startup, password and first query
// immediately after opening the connection.
let (stream, read_buf) = stream.into_inner();
node.stream.write_all(&read_buf).await?;
proxy_pass(ctx, stream, node.stream, aux).await
}
}

#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct NeonOptions(Vec<(SmolStr, SmolStr)>);

proxy/src/proxy/handshake.rs (new file, 96 lines)

@@ -0,0 +1,96 @@
use anyhow::{bail, Context};
use pq_proto::{BeMessage as Be, FeStartupPacket, StartupMessageParams};
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::info;

use crate::{
cancellation::CancelMap,
config::TlsConfig,
proxy::{ERR_INSECURE_CONNECTION, ERR_PROTO_VIOLATION},
stream::{PqStream, Stream},
};

/// Establish a (most probably, secure) connection with the client.
/// For better testing experience, `stream` can be any object satisfying the traits.
/// It's easier to work with owned `stream` here as we need to upgrade it to TLS;
/// we also take an extra care of propagating only the select handshake errors to client.
#[tracing::instrument(skip_all)]
pub async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
stream: S,
mut tls: Option<&TlsConfig>,
cancel_map: &CancelMap,
) -> anyhow::Result<Option<(PqStream<Stream<S>>, StartupMessageParams)>> {
// Client may try upgrading to each protocol only once
let (mut tried_ssl, mut tried_gss) = (false, false);

let mut stream = PqStream::new(Stream::from_raw(stream));
loop {
let msg = stream.read_startup_packet().await?;
info!("received {msg:?}");

use FeStartupPacket::*;
match msg {
SslRequest => match stream.get_ref() {
Stream::Raw { .. } if !tried_ssl => {
tried_ssl = true;

// We can't perform TLS handshake without a config
let enc = tls.is_some();
stream.write_message(&Be::EncryptionResponse(enc)).await?;
if let Some(tls) = tls.take() {
// Upgrade raw stream into a secure TLS-backed stream.
// NOTE: We've consumed `tls`; this fact will be used later.

let (raw, read_buf) = stream.into_inner();
// TODO: Normally, client doesn't send any data before
// server says TLS handshake is ok and read_buf is empy.
// However, you could imagine pipelining of postgres
// SSLRequest + TLS ClientHello in one hunk similar to
// pipelining in our node js driver. We should probably
// support that by chaining read_buf with the stream.
if !read_buf.is_empty() {
bail!("data is sent before server replied with EncryptionResponse");
}
let tls_stream = raw.upgrade(tls.to_server_config()).await?;

let (_, tls_server_end_point) = tls
.cert_resolver
.resolve(tls_stream.get_ref().1.server_name())
.context("missing certificate")?;

stream = PqStream::new(Stream::Tls {
tls: Box::new(tls_stream),
tls_server_end_point,
});
}
}
_ => bail!(ERR_PROTO_VIOLATION),
},
GssEncRequest => match stream.get_ref() {
Stream::Raw { .. } if !tried_gss => {
tried_gss = true;

// Currently, we don't support GSSAPI
stream.write_message(&Be::EncryptionResponse(false)).await?;
}
_ => bail!(ERR_PROTO_VIOLATION),
},
StartupMessage { params, .. } => {
// Check that the config has been consumed during upgrade
// OR we didn't provide it at all (for dev purposes).
if tls.is_some() {
stream.throw_error_str(ERR_INSECURE_CONNECTION).await?;
}

info!(session_type = "normal", "successful handshake");
break Ok(Some((stream, params)));
}
CancelRequest(cancel_key_data) => {
cancel_map.cancel_session(cancel_key_data).await?;

info!(session_type = "cancellation", "successful handshake");
break Ok(None);
}
}
}
}

proxy/src/proxy/passthrough.rs (new file, 57 lines)

@@ -0,0 +1,57 @@
use crate::{
console::messages::MetricsAuxInfo,
context::RequestMonitoring,
metrics::{NUM_BYTES_PROXIED_COUNTER, NUM_BYTES_PROXIED_PER_CLIENT_COUNTER},
usage_metrics::{Ids, USAGE_METRICS},
};
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::info;
use utils::measured_stream::MeasuredStream;

/// Forward bytes in both directions (client <-> compute).
#[tracing::instrument(skip_all)]
pub async fn proxy_pass(
ctx: &mut RequestMonitoring,
client: impl AsyncRead + AsyncWrite + Unpin,
compute: impl AsyncRead + AsyncWrite + Unpin,
aux: MetricsAuxInfo,
) -> anyhow::Result<()> {
ctx.set_success();
ctx.log();

let usage = USAGE_METRICS.register(Ids {
endpoint_id: aux.endpoint_id.clone(),
branch_id: aux.branch_id.clone(),
});

let m_sent = NUM_BYTES_PROXIED_COUNTER.with_label_values(&["tx"]);
let m_sent2 = NUM_BYTES_PROXIED_PER_CLIENT_COUNTER.with_label_values(&aux.traffic_labels("tx"));
let mut client = MeasuredStream::new(
client,
|_| {},
|cnt| {
// Number of bytes we sent to the client (outbound).
m_sent.inc_by(cnt as u64);
m_sent2.inc_by(cnt as u64);
usage.record_egress(cnt as u64);
},
);

let m_recv = NUM_BYTES_PROXIED_COUNTER.with_label_values(&["rx"]);
let m_recv2 = NUM_BYTES_PROXIED_PER_CLIENT_COUNTER.with_label_values(&aux.traffic_labels("rx"));
let mut compute = MeasuredStream::new(
compute,
|_| {},
|cnt| {
// Number of bytes the client sent to the compute node (inbound).
m_recv.inc_by(cnt as u64);
m_recv2.inc_by(cnt as u64);
},
);

// Starting from here we only proxy the client's traffic.
info!("performing the proxy pass...");
let _ = tokio::io::copy_bidirectional(&mut client, &mut compute).await?;

Ok(())
}

@@ -230,7 +230,7 @@ async fn request_handler(
config,
&mut ctx,
websocket,
&cancel_map,
cancel_map,
host,
endpoint_rate_limiter,
)

@@ -189,7 +189,7 @@ fn get_conn_info(
}

let endpoint = endpoint_sni(hostname, &tls.common_names)?.context("malformed endpoint")?;
ctx.set_endpoint_id(Some(endpoint.clone()));
ctx.set_endpoint_id(endpoint.clone());

let pairs = connection_url.query_pairs();

@@ -133,7 +133,7 @@ pub async fn serve_websocket(
config: &'static ProxyConfig,
ctx: &mut RequestMonitoring,
websocket: HyperWebsocket,
cancel_map: &CancelMap,
cancel_map: Arc<CancelMap>,
hostname: Option<String>,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
) -> anyhow::Result<()> {