mirror of
https://github.com/neondatabase/neon.git
synced 2026-01-18 02:42:56 +00:00
Compare commits
5 Commits
hack/compu
...
proxy-refa
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6fcf0f2754 | ||
|
|
d336b8b5d9 | ||
|
|
4d68e3108f | ||
|
|
3e150419ef | ||
|
|
9e424d2f84 |
@@ -16,7 +16,7 @@ use crate::{
|
||||
use once_cell::sync::Lazy;
|
||||
use std::borrow::Cow;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tracing::{info, warn};
|
||||
use tracing::{info, instrument, warn};
|
||||
|
||||
static CPLANE_WAITERS: Lazy<Waiters<mgmt::ComputeReady>> = Lazy::new(Default::default);
|
||||
|
||||
@@ -143,6 +143,7 @@ impl BackendType<'_, ClientCredentials<'_>> {
|
||||
&mut self,
|
||||
extra: &ConsoleReqExtra<'_>,
|
||||
client: &mut stream::PqStream<impl AsyncRead + AsyncWrite + Unpin + Send>,
|
||||
use_cleartext_password_flow: bool,
|
||||
) -> auth::Result<Option<AuthSuccess<NodeInfo>>> {
|
||||
use BackendType::*;
|
||||
|
||||
@@ -190,7 +191,7 @@ impl BackendType<'_, ClientCredentials<'_>> {
|
||||
|
||||
(node, payload)
|
||||
}
|
||||
Console(endpoint, creds) if creds.use_cleartext_password_flow => {
|
||||
Console(endpoint, creds) if use_cleartext_password_flow => {
|
||||
// This is a hack to allow cleartext password in secure connections (wss).
|
||||
let payload = fetch_plaintext_password(client).await?;
|
||||
let creds = creds.as_ref();
|
||||
@@ -220,16 +221,25 @@ impl BackendType<'_, ClientCredentials<'_>> {
|
||||
}
|
||||
|
||||
/// Authenticate the client via the requested backend, possibly using credentials.
|
||||
///
|
||||
/// If `use_cleartext_password_flow` is true, we use the old cleartext password
|
||||
/// flow. It is used for websocket connections, which want to minimize the number
|
||||
/// of round trips.
|
||||
#[instrument(skip_all)]
|
||||
pub async fn authenticate(
|
||||
mut self,
|
||||
extra: &ConsoleReqExtra<'_>,
|
||||
client: &mut stream::PqStream<impl AsyncRead + AsyncWrite + Unpin + Send>,
|
||||
use_cleartext_password_flow: bool,
|
||||
) -> auth::Result<AuthSuccess<NodeInfo>> {
|
||||
use BackendType::*;
|
||||
|
||||
// Handle cases when `project` is missing in `creds`.
|
||||
// TODO: type safety: return `creds` with irrefutable `project`.
|
||||
if let Some(res) = self.try_password_hack(extra, client).await? {
|
||||
if let Some(res) = self
|
||||
.try_password_hack(extra, client, use_cleartext_password_flow)
|
||||
.await?
|
||||
{
|
||||
info!("user successfully authenticated (using the password hack)");
|
||||
return Ok(res);
|
||||
}
|
||||
|
||||
@@ -34,9 +34,6 @@ pub struct ClientCredentials<'a> {
|
||||
pub user: &'a str,
|
||||
pub dbname: &'a str,
|
||||
pub project: Option<Cow<'a, str>>,
|
||||
/// If `True`, we'll use the old cleartext password flow. This is used for
|
||||
/// websocket connections, which want to minimize the number of round trips.
|
||||
pub use_cleartext_password_flow: bool,
|
||||
}
|
||||
|
||||
impl ClientCredentials<'_> {
|
||||
@@ -53,7 +50,6 @@ impl<'a> ClientCredentials<'a> {
|
||||
user: self.user,
|
||||
dbname: self.dbname,
|
||||
project: self.project().map(Cow::Borrowed),
|
||||
use_cleartext_password_flow: self.use_cleartext_password_flow,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -63,7 +59,6 @@ impl<'a> ClientCredentials<'a> {
|
||||
params: &'a StartupMessageParams,
|
||||
sni: Option<&str>,
|
||||
common_name: Option<&str>,
|
||||
use_cleartext_password_flow: bool,
|
||||
) -> Result<Self, ClientCredsParseError> {
|
||||
use ClientCredsParseError::*;
|
||||
|
||||
@@ -113,7 +108,6 @@ impl<'a> ClientCredentials<'a> {
|
||||
user = user,
|
||||
dbname = dbname,
|
||||
project = project.as_deref(),
|
||||
use_cleartext_password_flow = use_cleartext_password_flow,
|
||||
"credentials"
|
||||
);
|
||||
|
||||
@@ -121,7 +115,6 @@ impl<'a> ClientCredentials<'a> {
|
||||
user,
|
||||
dbname,
|
||||
project,
|
||||
use_cleartext_password_flow,
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -148,7 +141,7 @@ mod tests {
|
||||
let options = StartupMessageParams::new([("user", "john_doe")]);
|
||||
|
||||
// TODO: check that `creds.dbname` is None.
|
||||
let creds = ClientCredentials::parse(&options, None, None, false)?;
|
||||
let creds = ClientCredentials::parse(&options, None, None)?;
|
||||
assert_eq!(creds.user, "john_doe");
|
||||
|
||||
Ok(())
|
||||
@@ -158,7 +151,7 @@ mod tests {
|
||||
fn parse_missing_project() -> anyhow::Result<()> {
|
||||
let options = StartupMessageParams::new([("user", "john_doe"), ("database", "world")]);
|
||||
|
||||
let creds = ClientCredentials::parse(&options, None, None, false)?;
|
||||
let creds = ClientCredentials::parse(&options, None, None)?;
|
||||
assert_eq!(creds.user, "john_doe");
|
||||
assert_eq!(creds.dbname, "world");
|
||||
assert_eq!(creds.project, None);
|
||||
@@ -173,7 +166,7 @@ mod tests {
|
||||
let sni = Some("foo.localhost");
|
||||
let common_name = Some("localhost");
|
||||
|
||||
let creds = ClientCredentials::parse(&options, sni, common_name, false)?;
|
||||
let creds = ClientCredentials::parse(&options, sni, common_name)?;
|
||||
assert_eq!(creds.user, "john_doe");
|
||||
assert_eq!(creds.dbname, "world");
|
||||
assert_eq!(creds.project.as_deref(), Some("foo"));
|
||||
@@ -189,7 +182,7 @@ mod tests {
|
||||
("options", "-ckey=1 project=bar -c geqo=off"),
|
||||
]);
|
||||
|
||||
let creds = ClientCredentials::parse(&options, None, None, false)?;
|
||||
let creds = ClientCredentials::parse(&options, None, None)?;
|
||||
assert_eq!(creds.user, "john_doe");
|
||||
assert_eq!(creds.dbname, "world");
|
||||
assert_eq!(creds.project.as_deref(), Some("bar"));
|
||||
@@ -208,7 +201,7 @@ mod tests {
|
||||
let sni = Some("baz.localhost");
|
||||
let common_name = Some("localhost");
|
||||
|
||||
let creds = ClientCredentials::parse(&options, sni, common_name, false)?;
|
||||
let creds = ClientCredentials::parse(&options, sni, common_name)?;
|
||||
assert_eq!(creds.user, "john_doe");
|
||||
assert_eq!(creds.dbname, "world");
|
||||
assert_eq!(creds.project.as_deref(), Some("baz"));
|
||||
@@ -227,8 +220,7 @@ mod tests {
|
||||
let sni = Some("second.localhost");
|
||||
let common_name = Some("localhost");
|
||||
|
||||
let err =
|
||||
ClientCredentials::parse(&options, sni, common_name, false).expect_err("should fail");
|
||||
let err = ClientCredentials::parse(&options, sni, common_name).expect_err("should fail");
|
||||
match err {
|
||||
InconsistentProjectNames { domain, option } => {
|
||||
assert_eq!(option, "first");
|
||||
@@ -245,8 +237,7 @@ mod tests {
|
||||
let sni = Some("project.localhost");
|
||||
let common_name = Some("example.com");
|
||||
|
||||
let err =
|
||||
ClientCredentials::parse(&options, sni, common_name, false).expect_err("should fail");
|
||||
let err = ClientCredentials::parse(&options, sni, common_name).expect_err("should fail");
|
||||
match err {
|
||||
InconsistentSni { sni, cn } => {
|
||||
assert_eq!(sni, "project.localhost");
|
||||
|
||||
@@ -25,12 +25,11 @@ impl CancelMap {
|
||||
cancel_closure.try_cancel_query().await
|
||||
}
|
||||
|
||||
/// Run async action within an ephemeral session identified by [`CancelKeyData`].
|
||||
pub async fn with_session<'a, F, R, V>(&'a self, f: F) -> anyhow::Result<V>
|
||||
where
|
||||
F: FnOnce(Session<'a>) -> R,
|
||||
R: std::future::Future<Output = anyhow::Result<V>>,
|
||||
{
|
||||
/// Create a new session, with a new client-facing random cancellation key.
|
||||
///
|
||||
/// Use `enable_query_cancellation` to register a database cancellation
|
||||
/// key with it, and to get the client-facing key.
|
||||
pub fn new_session<'a>(&'a self) -> anyhow::Result<Session<'a>> {
|
||||
// HACK: We'd rather get the real backend_pid but tokio_postgres doesn't
|
||||
// expose it and we don't want to do another roundtrip to query
|
||||
// for it. The client will be able to notice that this is not the
|
||||
@@ -44,17 +43,9 @@ impl CancelMap {
|
||||
.lock()
|
||||
.try_insert(key, None)
|
||||
.map_err(|_| anyhow!("query cancellation key already exists: {key}"))?;
|
||||
|
||||
// This will guarantee that the session gets dropped
|
||||
// as soon as the future is finished.
|
||||
scopeguard::defer! {
|
||||
self.0.lock().remove(&key);
|
||||
info!("dropped query cancellation key {key}");
|
||||
}
|
||||
|
||||
info!("registered new query cancellation key {key}");
|
||||
let session = Session::new(key, self);
|
||||
f(session).await
|
||||
|
||||
Ok(Session::new(key, self))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -111,7 +102,7 @@ impl<'a> Session<'a> {
|
||||
impl Session<'_> {
|
||||
/// Store the cancel token for the given session.
|
||||
/// This enables query cancellation in [`crate::proxy::handshake`].
|
||||
pub fn enable_query_cancellation(self, cancel_closure: CancelClosure) -> CancelKeyData {
|
||||
pub fn enable_query_cancellation(&self, cancel_closure: CancelClosure) -> CancelKeyData {
|
||||
info!("enabling query cancellation for this session");
|
||||
self.cancel_map
|
||||
.0
|
||||
@@ -122,6 +113,14 @@ impl Session<'_> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> Drop for Session<'a> {
|
||||
fn drop(&mut self) {
|
||||
let key = &self.key;
|
||||
self.cancel_map.0.lock().remove(key);
|
||||
info!("dropped query cancellation key {key}");
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
@@ -132,14 +131,14 @@ mod tests {
|
||||
static CANCEL_MAP: Lazy<CancelMap> = Lazy::new(Default::default);
|
||||
|
||||
let (tx, rx) = tokio::sync::oneshot::channel();
|
||||
let task = tokio::spawn(CANCEL_MAP.with_session(|session| async move {
|
||||
|
||||
let session = CANCEL_MAP.new_session()?;
|
||||
let task = tokio::spawn(async move {
|
||||
assert!(CANCEL_MAP.contains(&session));
|
||||
|
||||
tx.send(()).expect("failed to send");
|
||||
futures::future::pending::<()>().await; // sleep forever
|
||||
|
||||
Ok(())
|
||||
}));
|
||||
});
|
||||
|
||||
// Wait until the task has been spawned.
|
||||
rx.await.context("failed to hear from the task")?;
|
||||
|
||||
@@ -14,7 +14,7 @@ use once_cell::sync::Lazy;
|
||||
use pq_proto::{BeMessage as Be, FeStartupPacket, StartupMessageParams};
|
||||
use std::sync::Arc;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tracing::{error, info, info_span, Instrument};
|
||||
use tracing::{error, info, info_span, instrument, Instrument};
|
||||
|
||||
const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmode=require`)";
|
||||
const ERR_PROTO_VIOLATION: &str = "protocol violation";
|
||||
@@ -71,17 +71,35 @@ pub async fn task_main(
|
||||
.set_nodelay(true)
|
||||
.context("failed to set socket option")?;
|
||||
|
||||
handle_client(config, &cancel_map, session_id, socket).await
|
||||
handle_postgres_client(config, &cancel_map, session_id, socket).await
|
||||
}
|
||||
.unwrap_or_else(|e| {
|
||||
// Acknowledge that the task has finished with an error.
|
||||
error!("per-client task finished with an error: {e:#}");
|
||||
})
|
||||
.instrument(info_span!("client", session = format_args!("{session_id}"))),
|
||||
}),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/// Handle an incoming PostgreSQL connection
|
||||
pub async fn handle_postgres_client(
|
||||
config: &ProxyConfig,
|
||||
cancel_map: &CancelMap,
|
||||
session_id: uuid::Uuid,
|
||||
stream: impl AsyncRead + AsyncWrite + Unpin + Send,
|
||||
) -> anyhow::Result<()> {
|
||||
handle_client(
|
||||
config,
|
||||
cancel_map,
|
||||
session_id,
|
||||
stream,
|
||||
HostnameMethod::Sni,
|
||||
false,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
/// Handle an incoming Postgres connection that's wrapped in websocket
|
||||
pub async fn handle_ws_client(
|
||||
config: &ProxyConfig,
|
||||
cancel_map: &CancelMap,
|
||||
@@ -89,45 +107,32 @@ pub async fn handle_ws_client(
|
||||
stream: impl AsyncRead + AsyncWrite + Unpin + Send,
|
||||
hostname: Option<String>,
|
||||
) -> anyhow::Result<()> {
|
||||
// The `closed` counter will increase when this future is destroyed.
|
||||
NUM_CONNECTIONS_ACCEPTED_COUNTER.inc();
|
||||
scopeguard::defer! {
|
||||
NUM_CONNECTIONS_CLOSED_COUNTER.inc();
|
||||
}
|
||||
|
||||
let tls = config.tls_config.as_ref();
|
||||
let hostname = hostname.as_deref();
|
||||
|
||||
// TLS is None here, because the connection is already encrypted.
|
||||
let do_handshake = handshake(stream, None, cancel_map).instrument(info_span!("handshake"));
|
||||
let (mut stream, params) = match do_handshake.await? {
|
||||
Some(x) => x,
|
||||
None => return Ok(()), // it's a cancellation request
|
||||
};
|
||||
|
||||
// Extract credentials which we're going to use for auth.
|
||||
let creds = {
|
||||
let common_name = tls.and_then(|tls| tls.common_name.as_deref());
|
||||
let result = config
|
||||
.auth_backend
|
||||
.as_ref()
|
||||
.map(|_| auth::ClientCredentials::parse(¶ms, hostname, common_name, true))
|
||||
.transpose();
|
||||
|
||||
async { result }.or_else(|e| stream.throw_error(e)).await?
|
||||
};
|
||||
|
||||
let client = Client::new(stream, creds, ¶ms, session_id);
|
||||
cancel_map
|
||||
.with_session(|session| client.connect_to_db(session))
|
||||
.await
|
||||
handle_client(
|
||||
config,
|
||||
cancel_map,
|
||||
session_id,
|
||||
stream,
|
||||
HostnameMethod::Param(hostname),
|
||||
true,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
enum HostnameMethod {
|
||||
Param(Option<String>),
|
||||
Sni,
|
||||
}
|
||||
|
||||
/// Handle an incoming client connection, handshake and authentication.
|
||||
/// After that, keeps forwarding all the data. This doesn't return until the
|
||||
/// connection is lost.
|
||||
async fn handle_client(
|
||||
config: &ProxyConfig,
|
||||
cancel_map: &CancelMap,
|
||||
session_id: uuid::Uuid,
|
||||
stream: impl AsyncRead + AsyncWrite + Unpin + Send,
|
||||
raw_stream: impl AsyncRead + AsyncWrite + Unpin + Send,
|
||||
hostname_method: HostnameMethod,
|
||||
use_cleartext_password_flow: bool,
|
||||
) -> anyhow::Result<()> {
|
||||
// The `closed` counter will increase when this future is destroyed.
|
||||
NUM_CONNECTIONS_ACCEPTED_COUNTER.inc();
|
||||
@@ -135,36 +140,73 @@ async fn handle_client(
|
||||
NUM_CONNECTIONS_CLOSED_COUNTER.inc();
|
||||
}
|
||||
|
||||
let tls = config.tls_config.as_ref();
|
||||
let do_handshake = handshake(stream, tls, cancel_map).instrument(info_span!("handshake"));
|
||||
let (mut stream, params) = match do_handshake.await? {
|
||||
Some(x) => x,
|
||||
None => return Ok(()), // it's a cancellation request
|
||||
};
|
||||
// Accept the connection from the client, authenticate it, and establish
|
||||
// connection to the database.
|
||||
//
|
||||
// We cover all these activities in a one tracing span, so that they are
|
||||
// traced as one request. That makes it convenient to investigate where
|
||||
// the time is spent when establishing a new connection. After the
|
||||
// connection has been established, we exit the span, and use a separate
|
||||
// span for the (rest of the) duration of the connection.
|
||||
let conn = async {
|
||||
// Process postgres startup packet and upgrade to TLS (if applicable)
|
||||
let tls = config.tls_config.as_ref();
|
||||
let (mut stream, params) = match handshake(raw_stream, tls, cancel_map).await? {
|
||||
Some(x) => x,
|
||||
None => return Ok::<_, anyhow::Error>(None), // it's a cancellation request
|
||||
};
|
||||
|
||||
// Extract credentials which we're going to use for auth.
|
||||
let creds = {
|
||||
let sni = stream.get_ref().sni_hostname();
|
||||
let common_name = tls.and_then(|tls| tls.common_name.as_deref());
|
||||
let result = config
|
||||
.auth_backend
|
||||
.as_ref()
|
||||
.map(|_| auth::ClientCredentials::parse(¶ms, sni, common_name, false))
|
||||
.transpose();
|
||||
// Extract credentials which we're going to use for auth.
|
||||
let creds = {
|
||||
let sni = match &hostname_method {
|
||||
HostnameMethod::Param(hostname) => hostname.as_deref(),
|
||||
HostnameMethod::Sni => stream.get_ref().sni_hostname(),
|
||||
};
|
||||
let common_name = tls.and_then(|tls| tls.common_name.as_deref());
|
||||
let result = config
|
||||
.auth_backend
|
||||
.as_ref()
|
||||
.map(|_| auth::ClientCredentials::parse(¶ms, sni, common_name))
|
||||
.transpose();
|
||||
|
||||
async { result }.or_else(|e| stream.throw_error(e)).await?
|
||||
};
|
||||
async { result }.or_else(|e| stream.throw_error(e)).await?
|
||||
};
|
||||
|
||||
let client = Client::new(stream, creds, ¶ms, session_id);
|
||||
cancel_map
|
||||
.with_session(|session| client.connect_to_db(session))
|
||||
.await
|
||||
Ok(Some(
|
||||
EstablishedConnection::connect_to_db(
|
||||
stream,
|
||||
creds,
|
||||
¶ms,
|
||||
session_id,
|
||||
use_cleartext_password_flow,
|
||||
cancel_map,
|
||||
)
|
||||
.await?,
|
||||
))
|
||||
}
|
||||
.instrument(info_span!("establish_connection", session_id=%session_id))
|
||||
.await?;
|
||||
|
||||
match conn {
|
||||
Some(conn) => {
|
||||
// Connection established in both ways. Forward all traffic until the
|
||||
// either connection is lost.
|
||||
conn.handle_connection()
|
||||
.instrument(info_span!("forward", session_id=%session_id))
|
||||
.await
|
||||
}
|
||||
None => {
|
||||
// It was a cancellation request. It was handled in 'handshake' already.
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Establish a (most probably, secure) connection with the client.
|
||||
/// For better testing experience, `stream` can be any object satisfying the traits.
|
||||
/// It's easier to work with owned `stream` here as we need to upgrade it to TLS;
|
||||
/// we also take an extra care of propagating only the select handshake errors to client.
|
||||
#[instrument(skip_all)]
|
||||
async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
stream: S,
|
||||
mut tls: Option<&TlsConfig>,
|
||||
@@ -227,43 +269,36 @@ async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
}
|
||||
|
||||
/// Thin connection context.
|
||||
struct Client<'a, S> {
|
||||
/// The underlying libpq protocol stream.
|
||||
stream: PqStream<S>,
|
||||
/// Client credentials that we care about.
|
||||
creds: auth::BackendType<'a, auth::ClientCredentials<'a>>,
|
||||
/// KV-dictionary with PostgreSQL connection params.
|
||||
params: &'a StartupMessageParams,
|
||||
/// Unique connection ID.
|
||||
session_id: uuid::Uuid,
|
||||
struct EstablishedConnection<'a, S> {
|
||||
client_stream: MeasuredStream<S>,
|
||||
db_stream: MeasuredStream<tokio::net::TcpStream>,
|
||||
|
||||
/// Hold on to the Session for as long as the connection is alive, so that
|
||||
/// it can be cancelled.
|
||||
_session: cancellation::Session<'a>,
|
||||
}
|
||||
|
||||
impl<'a, S> Client<'a, S> {
|
||||
/// Construct a new connection context.
|
||||
fn new(
|
||||
stream: PqStream<S>,
|
||||
creds: auth::BackendType<'a, auth::ClientCredentials<'a>>,
|
||||
params: &'a StartupMessageParams,
|
||||
session_id: uuid::Uuid,
|
||||
) -> Self {
|
||||
Self {
|
||||
stream,
|
||||
creds,
|
||||
params,
|
||||
session_id,
|
||||
}
|
||||
impl<S: AsyncRead + AsyncWrite + Unpin + Send> EstablishedConnection<'_, S> {
|
||||
async fn handle_connection(mut self) -> anyhow::Result<()> {
|
||||
// Starting from here we only proxy the client's traffic.
|
||||
info!("performing the proxy pass...");
|
||||
let _ = tokio::io::copy_bidirectional(&mut self.client_stream, &mut self.db_stream).await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: AsyncRead + AsyncWrite + Unpin + Send> Client<'_, S> {
|
||||
/// Let the client authenticate and connect to the designated compute node.
|
||||
async fn connect_to_db(self, session: cancellation::Session<'_>) -> anyhow::Result<()> {
|
||||
let Self {
|
||||
mut stream,
|
||||
creds,
|
||||
params,
|
||||
session_id,
|
||||
} = self;
|
||||
/// On return, the connection is fully established in both ways, and we can start
|
||||
/// forwarding the bytes.
|
||||
#[instrument(skip_all)]
|
||||
async fn connect_to_db<'a>(
|
||||
mut stream: PqStream<S>,
|
||||
creds: auth::BackendType<'a, auth::ClientCredentials<'_>>,
|
||||
params: &'_ StartupMessageParams,
|
||||
session_id: uuid::Uuid,
|
||||
use_cleartext_password_flow: bool,
|
||||
cancel_map: &'a CancelMap,
|
||||
) -> anyhow::Result<EstablishedConnection<'a, S>> {
|
||||
let session = cancel_map.new_session()?;
|
||||
|
||||
let extra = auth::ConsoleReqExtra {
|
||||
session_id, // aka this connection's id
|
||||
@@ -272,10 +307,11 @@ impl<S: AsyncRead + AsyncWrite + Unpin + Send> Client<'_, S> {
|
||||
|
||||
let auth_result = async {
|
||||
// `&mut stream` doesn't let us merge those 2 lines.
|
||||
let res = creds.authenticate(&extra, &mut stream).await;
|
||||
let res = creds
|
||||
.authenticate(&extra, &mut stream, use_cleartext_password_flow)
|
||||
.await;
|
||||
async { res }.or_else(|e| stream.throw_error(e)).await
|
||||
}
|
||||
.instrument(info_span!("auth"))
|
||||
.await?;
|
||||
|
||||
let node = auth_result.value;
|
||||
@@ -311,21 +347,15 @@ impl<S: AsyncRead + AsyncWrite + Unpin + Send> Client<'_, S> {
|
||||
.await?;
|
||||
|
||||
let m_sent = NUM_BYTES_PROXIED_COUNTER.with_label_values(&node.aux.traffic_labels("tx"));
|
||||
let mut client = MeasuredStream::new(stream.into_inner(), |cnt| {
|
||||
// Number of bytes we sent to the client (outbound).
|
||||
m_sent.inc_by(cnt as u64);
|
||||
});
|
||||
let client_stream = MeasuredStream::new(stream.into_inner(), m_sent);
|
||||
|
||||
let m_recv = NUM_BYTES_PROXIED_COUNTER.with_label_values(&node.aux.traffic_labels("rx"));
|
||||
let mut db = MeasuredStream::new(db.stream, |cnt| {
|
||||
// Number of bytes the client sent to the compute node (inbound).
|
||||
m_recv.inc_by(cnt as u64);
|
||||
});
|
||||
let db_stream = MeasuredStream::new(db.stream, m_recv);
|
||||
|
||||
// Starting from here we only proxy the client's traffic.
|
||||
info!("performing the proxy pass...");
|
||||
let _ = tokio::io::copy_bidirectional(&mut client, &mut db).await?;
|
||||
|
||||
Ok(())
|
||||
Ok(EstablishedConnection {
|
||||
client_stream,
|
||||
db_stream,
|
||||
_session: session,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ use std::{io, task};
|
||||
use thiserror::Error;
|
||||
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf};
|
||||
use tokio_rustls::server::TlsStream;
|
||||
use tracing::instrument;
|
||||
|
||||
pin_project! {
|
||||
/// Stream wrapper which implements libpq's protocol.
|
||||
@@ -105,6 +106,7 @@ impl<S: AsyncWrite + Unpin> PqStream<S> {
|
||||
/// Write the error message using [`Self::write_message`], then re-throw it.
|
||||
/// Allowing string literals is safe under the assumption they might not contain any runtime info.
|
||||
/// This method exists due to `&str` not implementing `Into<anyhow::Error>`.
|
||||
#[instrument(skip_all)]
|
||||
pub async fn throw_error_str<T>(&mut self, error: &'static str) -> anyhow::Result<T> {
|
||||
tracing::info!("forwarding error to user: {error}");
|
||||
self.write_message(&BeMessage::ErrorResponse(error, None))
|
||||
@@ -114,6 +116,7 @@ impl<S: AsyncWrite + Unpin> PqStream<S> {
|
||||
|
||||
/// Write the error message using [`Self::write_message`], then re-throw it.
|
||||
/// Trait [`UserFacingError`] acts as an allowlist for error types.
|
||||
#[instrument(skip_all)]
|
||||
pub async fn throw_error<T, E>(&mut self, error: E) -> anyhow::Result<T>
|
||||
where
|
||||
E: UserFacingError + Into<anyhow::Error>,
|
||||
@@ -228,27 +231,27 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite for Stream<S> {
|
||||
}
|
||||
|
||||
pin_project! {
|
||||
/// This stream tracks all writes and calls user provided
|
||||
/// callback when the underlying stream is flushed.
|
||||
pub struct MeasuredStream<S, W> {
|
||||
/// This stream tracks all writes, and whenever the stream is flushed,
|
||||
/// increments the user-provided counter by the number of bytes flushed.
|
||||
pub struct MeasuredStream<S> {
|
||||
#[pin]
|
||||
stream: S,
|
||||
write_count: usize,
|
||||
inc_write_count: W,
|
||||
write_counter: prometheus::IntCounter,
|
||||
}
|
||||
}
|
||||
|
||||
impl<S, W> MeasuredStream<S, W> {
|
||||
pub fn new(stream: S, inc_write_count: W) -> Self {
|
||||
impl<S> MeasuredStream<S> {
|
||||
pub fn new(stream: S, write_counter: prometheus::IntCounter) -> Self {
|
||||
Self {
|
||||
stream,
|
||||
write_count: 0,
|
||||
inc_write_count,
|
||||
write_counter,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: AsyncRead + Unpin, W> AsyncRead for MeasuredStream<S, W> {
|
||||
impl<S: AsyncRead + Unpin> AsyncRead for MeasuredStream<S> {
|
||||
fn poll_read(
|
||||
self: Pin<&mut Self>,
|
||||
context: &mut task::Context<'_>,
|
||||
@@ -258,7 +261,7 @@ impl<S: AsyncRead + Unpin, W> AsyncRead for MeasuredStream<S, W> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: AsyncWrite + Unpin, W: FnMut(usize)> AsyncWrite for MeasuredStream<S, W> {
|
||||
impl<S: AsyncWrite + Unpin> AsyncWrite for MeasuredStream<S> {
|
||||
fn poll_write(
|
||||
self: Pin<&mut Self>,
|
||||
context: &mut task::Context<'_>,
|
||||
@@ -279,7 +282,7 @@ impl<S: AsyncWrite + Unpin, W: FnMut(usize)> AsyncWrite for MeasuredStream<S, W>
|
||||
let this = self.project();
|
||||
this.stream.poll_flush(context).map_ok(|()| {
|
||||
// Call the user provided callback and reset the write count.
|
||||
(this.inc_write_count)(*this.write_count);
|
||||
this.write_counter.inc_by(*this.write_count as u64);
|
||||
*this.write_count = 0;
|
||||
})
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user