mirror of
https://github.com/neondatabase/neon.git
synced 2026-02-04 11:10:37 +00:00
Compare commits
5 Commits
remove_ini
...
proxy-refa
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6fcf0f2754 | ||
|
|
d336b8b5d9 | ||
|
|
4d68e3108f | ||
|
|
3e150419ef | ||
|
|
9e424d2f84 |
@@ -16,7 +16,7 @@ use crate::{
|
|||||||
use once_cell::sync::Lazy;
|
use once_cell::sync::Lazy;
|
||||||
use std::borrow::Cow;
|
use std::borrow::Cow;
|
||||||
use tokio::io::{AsyncRead, AsyncWrite};
|
use tokio::io::{AsyncRead, AsyncWrite};
|
||||||
use tracing::{info, warn};
|
use tracing::{info, instrument, warn};
|
||||||
|
|
||||||
static CPLANE_WAITERS: Lazy<Waiters<mgmt::ComputeReady>> = Lazy::new(Default::default);
|
static CPLANE_WAITERS: Lazy<Waiters<mgmt::ComputeReady>> = Lazy::new(Default::default);
|
||||||
|
|
||||||
@@ -143,6 +143,7 @@ impl BackendType<'_, ClientCredentials<'_>> {
|
|||||||
&mut self,
|
&mut self,
|
||||||
extra: &ConsoleReqExtra<'_>,
|
extra: &ConsoleReqExtra<'_>,
|
||||||
client: &mut stream::PqStream<impl AsyncRead + AsyncWrite + Unpin + Send>,
|
client: &mut stream::PqStream<impl AsyncRead + AsyncWrite + Unpin + Send>,
|
||||||
|
use_cleartext_password_flow: bool,
|
||||||
) -> auth::Result<Option<AuthSuccess<NodeInfo>>> {
|
) -> auth::Result<Option<AuthSuccess<NodeInfo>>> {
|
||||||
use BackendType::*;
|
use BackendType::*;
|
||||||
|
|
||||||
@@ -190,7 +191,7 @@ impl BackendType<'_, ClientCredentials<'_>> {
|
|||||||
|
|
||||||
(node, payload)
|
(node, payload)
|
||||||
}
|
}
|
||||||
Console(endpoint, creds) if creds.use_cleartext_password_flow => {
|
Console(endpoint, creds) if use_cleartext_password_flow => {
|
||||||
// This is a hack to allow cleartext password in secure connections (wss).
|
// This is a hack to allow cleartext password in secure connections (wss).
|
||||||
let payload = fetch_plaintext_password(client).await?;
|
let payload = fetch_plaintext_password(client).await?;
|
||||||
let creds = creds.as_ref();
|
let creds = creds.as_ref();
|
||||||
@@ -220,16 +221,25 @@ impl BackendType<'_, ClientCredentials<'_>> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Authenticate the client via the requested backend, possibly using credentials.
|
/// Authenticate the client via the requested backend, possibly using credentials.
|
||||||
|
///
|
||||||
|
/// If `use_cleartext_password_flow` is true, we use the old cleartext password
|
||||||
|
/// flow. It is used for websocket connections, which want to minimize the number
|
||||||
|
/// of round trips.
|
||||||
|
#[instrument(skip_all)]
|
||||||
pub async fn authenticate(
|
pub async fn authenticate(
|
||||||
mut self,
|
mut self,
|
||||||
extra: &ConsoleReqExtra<'_>,
|
extra: &ConsoleReqExtra<'_>,
|
||||||
client: &mut stream::PqStream<impl AsyncRead + AsyncWrite + Unpin + Send>,
|
client: &mut stream::PqStream<impl AsyncRead + AsyncWrite + Unpin + Send>,
|
||||||
|
use_cleartext_password_flow: bool,
|
||||||
) -> auth::Result<AuthSuccess<NodeInfo>> {
|
) -> auth::Result<AuthSuccess<NodeInfo>> {
|
||||||
use BackendType::*;
|
use BackendType::*;
|
||||||
|
|
||||||
// Handle cases when `project` is missing in `creds`.
|
// Handle cases when `project` is missing in `creds`.
|
||||||
// TODO: type safety: return `creds` with irrefutable `project`.
|
// TODO: type safety: return `creds` with irrefutable `project`.
|
||||||
if let Some(res) = self.try_password_hack(extra, client).await? {
|
if let Some(res) = self
|
||||||
|
.try_password_hack(extra, client, use_cleartext_password_flow)
|
||||||
|
.await?
|
||||||
|
{
|
||||||
info!("user successfully authenticated (using the password hack)");
|
info!("user successfully authenticated (using the password hack)");
|
||||||
return Ok(res);
|
return Ok(res);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -34,9 +34,6 @@ pub struct ClientCredentials<'a> {
|
|||||||
pub user: &'a str,
|
pub user: &'a str,
|
||||||
pub dbname: &'a str,
|
pub dbname: &'a str,
|
||||||
pub project: Option<Cow<'a, str>>,
|
pub project: Option<Cow<'a, str>>,
|
||||||
/// If `True`, we'll use the old cleartext password flow. This is used for
|
|
||||||
/// websocket connections, which want to minimize the number of round trips.
|
|
||||||
pub use_cleartext_password_flow: bool,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ClientCredentials<'_> {
|
impl ClientCredentials<'_> {
|
||||||
@@ -53,7 +50,6 @@ impl<'a> ClientCredentials<'a> {
|
|||||||
user: self.user,
|
user: self.user,
|
||||||
dbname: self.dbname,
|
dbname: self.dbname,
|
||||||
project: self.project().map(Cow::Borrowed),
|
project: self.project().map(Cow::Borrowed),
|
||||||
use_cleartext_password_flow: self.use_cleartext_password_flow,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -63,7 +59,6 @@ impl<'a> ClientCredentials<'a> {
|
|||||||
params: &'a StartupMessageParams,
|
params: &'a StartupMessageParams,
|
||||||
sni: Option<&str>,
|
sni: Option<&str>,
|
||||||
common_name: Option<&str>,
|
common_name: Option<&str>,
|
||||||
use_cleartext_password_flow: bool,
|
|
||||||
) -> Result<Self, ClientCredsParseError> {
|
) -> Result<Self, ClientCredsParseError> {
|
||||||
use ClientCredsParseError::*;
|
use ClientCredsParseError::*;
|
||||||
|
|
||||||
@@ -113,7 +108,6 @@ impl<'a> ClientCredentials<'a> {
|
|||||||
user = user,
|
user = user,
|
||||||
dbname = dbname,
|
dbname = dbname,
|
||||||
project = project.as_deref(),
|
project = project.as_deref(),
|
||||||
use_cleartext_password_flow = use_cleartext_password_flow,
|
|
||||||
"credentials"
|
"credentials"
|
||||||
);
|
);
|
||||||
|
|
||||||
@@ -121,7 +115,6 @@ impl<'a> ClientCredentials<'a> {
|
|||||||
user,
|
user,
|
||||||
dbname,
|
dbname,
|
||||||
project,
|
project,
|
||||||
use_cleartext_password_flow,
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -148,7 +141,7 @@ mod tests {
|
|||||||
let options = StartupMessageParams::new([("user", "john_doe")]);
|
let options = StartupMessageParams::new([("user", "john_doe")]);
|
||||||
|
|
||||||
// TODO: check that `creds.dbname` is None.
|
// TODO: check that `creds.dbname` is None.
|
||||||
let creds = ClientCredentials::parse(&options, None, None, false)?;
|
let creds = ClientCredentials::parse(&options, None, None)?;
|
||||||
assert_eq!(creds.user, "john_doe");
|
assert_eq!(creds.user, "john_doe");
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
@@ -158,7 +151,7 @@ mod tests {
|
|||||||
fn parse_missing_project() -> anyhow::Result<()> {
|
fn parse_missing_project() -> anyhow::Result<()> {
|
||||||
let options = StartupMessageParams::new([("user", "john_doe"), ("database", "world")]);
|
let options = StartupMessageParams::new([("user", "john_doe"), ("database", "world")]);
|
||||||
|
|
||||||
let creds = ClientCredentials::parse(&options, None, None, false)?;
|
let creds = ClientCredentials::parse(&options, None, None)?;
|
||||||
assert_eq!(creds.user, "john_doe");
|
assert_eq!(creds.user, "john_doe");
|
||||||
assert_eq!(creds.dbname, "world");
|
assert_eq!(creds.dbname, "world");
|
||||||
assert_eq!(creds.project, None);
|
assert_eq!(creds.project, None);
|
||||||
@@ -173,7 +166,7 @@ mod tests {
|
|||||||
let sni = Some("foo.localhost");
|
let sni = Some("foo.localhost");
|
||||||
let common_name = Some("localhost");
|
let common_name = Some("localhost");
|
||||||
|
|
||||||
let creds = ClientCredentials::parse(&options, sni, common_name, false)?;
|
let creds = ClientCredentials::parse(&options, sni, common_name)?;
|
||||||
assert_eq!(creds.user, "john_doe");
|
assert_eq!(creds.user, "john_doe");
|
||||||
assert_eq!(creds.dbname, "world");
|
assert_eq!(creds.dbname, "world");
|
||||||
assert_eq!(creds.project.as_deref(), Some("foo"));
|
assert_eq!(creds.project.as_deref(), Some("foo"));
|
||||||
@@ -189,7 +182,7 @@ mod tests {
|
|||||||
("options", "-ckey=1 project=bar -c geqo=off"),
|
("options", "-ckey=1 project=bar -c geqo=off"),
|
||||||
]);
|
]);
|
||||||
|
|
||||||
let creds = ClientCredentials::parse(&options, None, None, false)?;
|
let creds = ClientCredentials::parse(&options, None, None)?;
|
||||||
assert_eq!(creds.user, "john_doe");
|
assert_eq!(creds.user, "john_doe");
|
||||||
assert_eq!(creds.dbname, "world");
|
assert_eq!(creds.dbname, "world");
|
||||||
assert_eq!(creds.project.as_deref(), Some("bar"));
|
assert_eq!(creds.project.as_deref(), Some("bar"));
|
||||||
@@ -208,7 +201,7 @@ mod tests {
|
|||||||
let sni = Some("baz.localhost");
|
let sni = Some("baz.localhost");
|
||||||
let common_name = Some("localhost");
|
let common_name = Some("localhost");
|
||||||
|
|
||||||
let creds = ClientCredentials::parse(&options, sni, common_name, false)?;
|
let creds = ClientCredentials::parse(&options, sni, common_name)?;
|
||||||
assert_eq!(creds.user, "john_doe");
|
assert_eq!(creds.user, "john_doe");
|
||||||
assert_eq!(creds.dbname, "world");
|
assert_eq!(creds.dbname, "world");
|
||||||
assert_eq!(creds.project.as_deref(), Some("baz"));
|
assert_eq!(creds.project.as_deref(), Some("baz"));
|
||||||
@@ -227,8 +220,7 @@ mod tests {
|
|||||||
let sni = Some("second.localhost");
|
let sni = Some("second.localhost");
|
||||||
let common_name = Some("localhost");
|
let common_name = Some("localhost");
|
||||||
|
|
||||||
let err =
|
let err = ClientCredentials::parse(&options, sni, common_name).expect_err("should fail");
|
||||||
ClientCredentials::parse(&options, sni, common_name, false).expect_err("should fail");
|
|
||||||
match err {
|
match err {
|
||||||
InconsistentProjectNames { domain, option } => {
|
InconsistentProjectNames { domain, option } => {
|
||||||
assert_eq!(option, "first");
|
assert_eq!(option, "first");
|
||||||
@@ -245,8 +237,7 @@ mod tests {
|
|||||||
let sni = Some("project.localhost");
|
let sni = Some("project.localhost");
|
||||||
let common_name = Some("example.com");
|
let common_name = Some("example.com");
|
||||||
|
|
||||||
let err =
|
let err = ClientCredentials::parse(&options, sni, common_name).expect_err("should fail");
|
||||||
ClientCredentials::parse(&options, sni, common_name, false).expect_err("should fail");
|
|
||||||
match err {
|
match err {
|
||||||
InconsistentSni { sni, cn } => {
|
InconsistentSni { sni, cn } => {
|
||||||
assert_eq!(sni, "project.localhost");
|
assert_eq!(sni, "project.localhost");
|
||||||
|
|||||||
@@ -25,12 +25,11 @@ impl CancelMap {
|
|||||||
cancel_closure.try_cancel_query().await
|
cancel_closure.try_cancel_query().await
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Run async action within an ephemeral session identified by [`CancelKeyData`].
|
/// Create a new session, with a new client-facing random cancellation key.
|
||||||
pub async fn with_session<'a, F, R, V>(&'a self, f: F) -> anyhow::Result<V>
|
///
|
||||||
where
|
/// Use `enable_query_cancellation` to register a database cancellation
|
||||||
F: FnOnce(Session<'a>) -> R,
|
/// key with it, and to get the client-facing key.
|
||||||
R: std::future::Future<Output = anyhow::Result<V>>,
|
pub fn new_session<'a>(&'a self) -> anyhow::Result<Session<'a>> {
|
||||||
{
|
|
||||||
// HACK: We'd rather get the real backend_pid but tokio_postgres doesn't
|
// HACK: We'd rather get the real backend_pid but tokio_postgres doesn't
|
||||||
// expose it and we don't want to do another roundtrip to query
|
// expose it and we don't want to do another roundtrip to query
|
||||||
// for it. The client will be able to notice that this is not the
|
// for it. The client will be able to notice that this is not the
|
||||||
@@ -44,17 +43,9 @@ impl CancelMap {
|
|||||||
.lock()
|
.lock()
|
||||||
.try_insert(key, None)
|
.try_insert(key, None)
|
||||||
.map_err(|_| anyhow!("query cancellation key already exists: {key}"))?;
|
.map_err(|_| anyhow!("query cancellation key already exists: {key}"))?;
|
||||||
|
|
||||||
// This will guarantee that the session gets dropped
|
|
||||||
// as soon as the future is finished.
|
|
||||||
scopeguard::defer! {
|
|
||||||
self.0.lock().remove(&key);
|
|
||||||
info!("dropped query cancellation key {key}");
|
|
||||||
}
|
|
||||||
|
|
||||||
info!("registered new query cancellation key {key}");
|
info!("registered new query cancellation key {key}");
|
||||||
let session = Session::new(key, self);
|
|
||||||
f(session).await
|
Ok(Session::new(key, self))
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
@@ -111,7 +102,7 @@ impl<'a> Session<'a> {
|
|||||||
impl Session<'_> {
|
impl Session<'_> {
|
||||||
/// Store the cancel token for the given session.
|
/// Store the cancel token for the given session.
|
||||||
/// This enables query cancellation in [`crate::proxy::handshake`].
|
/// This enables query cancellation in [`crate::proxy::handshake`].
|
||||||
pub fn enable_query_cancellation(self, cancel_closure: CancelClosure) -> CancelKeyData {
|
pub fn enable_query_cancellation(&self, cancel_closure: CancelClosure) -> CancelKeyData {
|
||||||
info!("enabling query cancellation for this session");
|
info!("enabling query cancellation for this session");
|
||||||
self.cancel_map
|
self.cancel_map
|
||||||
.0
|
.0
|
||||||
@@ -122,6 +113,14 @@ impl Session<'_> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl<'a> Drop for Session<'a> {
|
||||||
|
fn drop(&mut self) {
|
||||||
|
let key = &self.key;
|
||||||
|
self.cancel_map.0.lock().remove(key);
|
||||||
|
info!("dropped query cancellation key {key}");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
@@ -132,14 +131,14 @@ mod tests {
|
|||||||
static CANCEL_MAP: Lazy<CancelMap> = Lazy::new(Default::default);
|
static CANCEL_MAP: Lazy<CancelMap> = Lazy::new(Default::default);
|
||||||
|
|
||||||
let (tx, rx) = tokio::sync::oneshot::channel();
|
let (tx, rx) = tokio::sync::oneshot::channel();
|
||||||
let task = tokio::spawn(CANCEL_MAP.with_session(|session| async move {
|
|
||||||
|
let session = CANCEL_MAP.new_session()?;
|
||||||
|
let task = tokio::spawn(async move {
|
||||||
assert!(CANCEL_MAP.contains(&session));
|
assert!(CANCEL_MAP.contains(&session));
|
||||||
|
|
||||||
tx.send(()).expect("failed to send");
|
tx.send(()).expect("failed to send");
|
||||||
futures::future::pending::<()>().await; // sleep forever
|
futures::future::pending::<()>().await; // sleep forever
|
||||||
|
});
|
||||||
Ok(())
|
|
||||||
}));
|
|
||||||
|
|
||||||
// Wait until the task has been spawned.
|
// Wait until the task has been spawned.
|
||||||
rx.await.context("failed to hear from the task")?;
|
rx.await.context("failed to hear from the task")?;
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ use once_cell::sync::Lazy;
|
|||||||
use pq_proto::{BeMessage as Be, FeStartupPacket, StartupMessageParams};
|
use pq_proto::{BeMessage as Be, FeStartupPacket, StartupMessageParams};
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use tokio::io::{AsyncRead, AsyncWrite};
|
use tokio::io::{AsyncRead, AsyncWrite};
|
||||||
use tracing::{error, info, info_span, Instrument};
|
use tracing::{error, info, info_span, instrument, Instrument};
|
||||||
|
|
||||||
const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmode=require`)";
|
const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmode=require`)";
|
||||||
const ERR_PROTO_VIOLATION: &str = "protocol violation";
|
const ERR_PROTO_VIOLATION: &str = "protocol violation";
|
||||||
@@ -71,17 +71,35 @@ pub async fn task_main(
|
|||||||
.set_nodelay(true)
|
.set_nodelay(true)
|
||||||
.context("failed to set socket option")?;
|
.context("failed to set socket option")?;
|
||||||
|
|
||||||
handle_client(config, &cancel_map, session_id, socket).await
|
handle_postgres_client(config, &cancel_map, session_id, socket).await
|
||||||
}
|
}
|
||||||
.unwrap_or_else(|e| {
|
.unwrap_or_else(|e| {
|
||||||
// Acknowledge that the task has finished with an error.
|
// Acknowledge that the task has finished with an error.
|
||||||
error!("per-client task finished with an error: {e:#}");
|
error!("per-client task finished with an error: {e:#}");
|
||||||
})
|
}),
|
||||||
.instrument(info_span!("client", session = format_args!("{session_id}"))),
|
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Handle an incoming PostgreSQL connection
|
||||||
|
pub async fn handle_postgres_client(
|
||||||
|
config: &ProxyConfig,
|
||||||
|
cancel_map: &CancelMap,
|
||||||
|
session_id: uuid::Uuid,
|
||||||
|
stream: impl AsyncRead + AsyncWrite + Unpin + Send,
|
||||||
|
) -> anyhow::Result<()> {
|
||||||
|
handle_client(
|
||||||
|
config,
|
||||||
|
cancel_map,
|
||||||
|
session_id,
|
||||||
|
stream,
|
||||||
|
HostnameMethod::Sni,
|
||||||
|
false,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Handle an incoming Postgres connection that's wrapped in websocket
|
||||||
pub async fn handle_ws_client(
|
pub async fn handle_ws_client(
|
||||||
config: &ProxyConfig,
|
config: &ProxyConfig,
|
||||||
cancel_map: &CancelMap,
|
cancel_map: &CancelMap,
|
||||||
@@ -89,45 +107,32 @@ pub async fn handle_ws_client(
|
|||||||
stream: impl AsyncRead + AsyncWrite + Unpin + Send,
|
stream: impl AsyncRead + AsyncWrite + Unpin + Send,
|
||||||
hostname: Option<String>,
|
hostname: Option<String>,
|
||||||
) -> anyhow::Result<()> {
|
) -> anyhow::Result<()> {
|
||||||
// The `closed` counter will increase when this future is destroyed.
|
handle_client(
|
||||||
NUM_CONNECTIONS_ACCEPTED_COUNTER.inc();
|
config,
|
||||||
scopeguard::defer! {
|
cancel_map,
|
||||||
NUM_CONNECTIONS_CLOSED_COUNTER.inc();
|
session_id,
|
||||||
}
|
stream,
|
||||||
|
HostnameMethod::Param(hostname),
|
||||||
let tls = config.tls_config.as_ref();
|
true,
|
||||||
let hostname = hostname.as_deref();
|
)
|
||||||
|
.await
|
||||||
// TLS is None here, because the connection is already encrypted.
|
|
||||||
let do_handshake = handshake(stream, None, cancel_map).instrument(info_span!("handshake"));
|
|
||||||
let (mut stream, params) = match do_handshake.await? {
|
|
||||||
Some(x) => x,
|
|
||||||
None => return Ok(()), // it's a cancellation request
|
|
||||||
};
|
|
||||||
|
|
||||||
// Extract credentials which we're going to use for auth.
|
|
||||||
let creds = {
|
|
||||||
let common_name = tls.and_then(|tls| tls.common_name.as_deref());
|
|
||||||
let result = config
|
|
||||||
.auth_backend
|
|
||||||
.as_ref()
|
|
||||||
.map(|_| auth::ClientCredentials::parse(¶ms, hostname, common_name, true))
|
|
||||||
.transpose();
|
|
||||||
|
|
||||||
async { result }.or_else(|e| stream.throw_error(e)).await?
|
|
||||||
};
|
|
||||||
|
|
||||||
let client = Client::new(stream, creds, ¶ms, session_id);
|
|
||||||
cancel_map
|
|
||||||
.with_session(|session| client.connect_to_db(session))
|
|
||||||
.await
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
enum HostnameMethod {
|
||||||
|
Param(Option<String>),
|
||||||
|
Sni,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Handle an incoming client connection, handshake and authentication.
|
||||||
|
/// After that, keeps forwarding all the data. This doesn't return until the
|
||||||
|
/// connection is lost.
|
||||||
async fn handle_client(
|
async fn handle_client(
|
||||||
config: &ProxyConfig,
|
config: &ProxyConfig,
|
||||||
cancel_map: &CancelMap,
|
cancel_map: &CancelMap,
|
||||||
session_id: uuid::Uuid,
|
session_id: uuid::Uuid,
|
||||||
stream: impl AsyncRead + AsyncWrite + Unpin + Send,
|
raw_stream: impl AsyncRead + AsyncWrite + Unpin + Send,
|
||||||
|
hostname_method: HostnameMethod,
|
||||||
|
use_cleartext_password_flow: bool,
|
||||||
) -> anyhow::Result<()> {
|
) -> anyhow::Result<()> {
|
||||||
// The `closed` counter will increase when this future is destroyed.
|
// The `closed` counter will increase when this future is destroyed.
|
||||||
NUM_CONNECTIONS_ACCEPTED_COUNTER.inc();
|
NUM_CONNECTIONS_ACCEPTED_COUNTER.inc();
|
||||||
@@ -135,36 +140,73 @@ async fn handle_client(
|
|||||||
NUM_CONNECTIONS_CLOSED_COUNTER.inc();
|
NUM_CONNECTIONS_CLOSED_COUNTER.inc();
|
||||||
}
|
}
|
||||||
|
|
||||||
let tls = config.tls_config.as_ref();
|
// Accept the connection from the client, authenticate it, and establish
|
||||||
let do_handshake = handshake(stream, tls, cancel_map).instrument(info_span!("handshake"));
|
// connection to the database.
|
||||||
let (mut stream, params) = match do_handshake.await? {
|
//
|
||||||
Some(x) => x,
|
// We cover all these activities in a one tracing span, so that they are
|
||||||
None => return Ok(()), // it's a cancellation request
|
// traced as one request. That makes it convenient to investigate where
|
||||||
};
|
// the time is spent when establishing a new connection. After the
|
||||||
|
// connection has been established, we exit the span, and use a separate
|
||||||
|
// span for the (rest of the) duration of the connection.
|
||||||
|
let conn = async {
|
||||||
|
// Process postgres startup packet and upgrade to TLS (if applicable)
|
||||||
|
let tls = config.tls_config.as_ref();
|
||||||
|
let (mut stream, params) = match handshake(raw_stream, tls, cancel_map).await? {
|
||||||
|
Some(x) => x,
|
||||||
|
None => return Ok::<_, anyhow::Error>(None), // it's a cancellation request
|
||||||
|
};
|
||||||
|
|
||||||
// Extract credentials which we're going to use for auth.
|
// Extract credentials which we're going to use for auth.
|
||||||
let creds = {
|
let creds = {
|
||||||
let sni = stream.get_ref().sni_hostname();
|
let sni = match &hostname_method {
|
||||||
let common_name = tls.and_then(|tls| tls.common_name.as_deref());
|
HostnameMethod::Param(hostname) => hostname.as_deref(),
|
||||||
let result = config
|
HostnameMethod::Sni => stream.get_ref().sni_hostname(),
|
||||||
.auth_backend
|
};
|
||||||
.as_ref()
|
let common_name = tls.and_then(|tls| tls.common_name.as_deref());
|
||||||
.map(|_| auth::ClientCredentials::parse(¶ms, sni, common_name, false))
|
let result = config
|
||||||
.transpose();
|
.auth_backend
|
||||||
|
.as_ref()
|
||||||
|
.map(|_| auth::ClientCredentials::parse(¶ms, sni, common_name))
|
||||||
|
.transpose();
|
||||||
|
|
||||||
async { result }.or_else(|e| stream.throw_error(e)).await?
|
async { result }.or_else(|e| stream.throw_error(e)).await?
|
||||||
};
|
};
|
||||||
|
|
||||||
let client = Client::new(stream, creds, ¶ms, session_id);
|
Ok(Some(
|
||||||
cancel_map
|
EstablishedConnection::connect_to_db(
|
||||||
.with_session(|session| client.connect_to_db(session))
|
stream,
|
||||||
.await
|
creds,
|
||||||
|
¶ms,
|
||||||
|
session_id,
|
||||||
|
use_cleartext_password_flow,
|
||||||
|
cancel_map,
|
||||||
|
)
|
||||||
|
.await?,
|
||||||
|
))
|
||||||
|
}
|
||||||
|
.instrument(info_span!("establish_connection", session_id=%session_id))
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
match conn {
|
||||||
|
Some(conn) => {
|
||||||
|
// Connection established in both ways. Forward all traffic until the
|
||||||
|
// either connection is lost.
|
||||||
|
conn.handle_connection()
|
||||||
|
.instrument(info_span!("forward", session_id=%session_id))
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
None => {
|
||||||
|
// It was a cancellation request. It was handled in 'handshake' already.
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Establish a (most probably, secure) connection with the client.
|
/// Establish a (most probably, secure) connection with the client.
|
||||||
/// For better testing experience, `stream` can be any object satisfying the traits.
|
/// For better testing experience, `stream` can be any object satisfying the traits.
|
||||||
/// It's easier to work with owned `stream` here as we need to upgrade it to TLS;
|
/// It's easier to work with owned `stream` here as we need to upgrade it to TLS;
|
||||||
/// we also take an extra care of propagating only the select handshake errors to client.
|
/// we also take an extra care of propagating only the select handshake errors to client.
|
||||||
|
#[instrument(skip_all)]
|
||||||
async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
|
async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
|
||||||
stream: S,
|
stream: S,
|
||||||
mut tls: Option<&TlsConfig>,
|
mut tls: Option<&TlsConfig>,
|
||||||
@@ -227,43 +269,36 @@ async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Thin connection context.
|
/// Thin connection context.
|
||||||
struct Client<'a, S> {
|
struct EstablishedConnection<'a, S> {
|
||||||
/// The underlying libpq protocol stream.
|
client_stream: MeasuredStream<S>,
|
||||||
stream: PqStream<S>,
|
db_stream: MeasuredStream<tokio::net::TcpStream>,
|
||||||
/// Client credentials that we care about.
|
|
||||||
creds: auth::BackendType<'a, auth::ClientCredentials<'a>>,
|
/// Hold on to the Session for as long as the connection is alive, so that
|
||||||
/// KV-dictionary with PostgreSQL connection params.
|
/// it can be cancelled.
|
||||||
params: &'a StartupMessageParams,
|
_session: cancellation::Session<'a>,
|
||||||
/// Unique connection ID.
|
|
||||||
session_id: uuid::Uuid,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, S> Client<'a, S> {
|
impl<S: AsyncRead + AsyncWrite + Unpin + Send> EstablishedConnection<'_, S> {
|
||||||
/// Construct a new connection context.
|
async fn handle_connection(mut self) -> anyhow::Result<()> {
|
||||||
fn new(
|
// Starting from here we only proxy the client's traffic.
|
||||||
stream: PqStream<S>,
|
info!("performing the proxy pass...");
|
||||||
creds: auth::BackendType<'a, auth::ClientCredentials<'a>>,
|
let _ = tokio::io::copy_bidirectional(&mut self.client_stream, &mut self.db_stream).await?;
|
||||||
params: &'a StartupMessageParams,
|
Ok(())
|
||||||
session_id: uuid::Uuid,
|
|
||||||
) -> Self {
|
|
||||||
Self {
|
|
||||||
stream,
|
|
||||||
creds,
|
|
||||||
params,
|
|
||||||
session_id,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
impl<S: AsyncRead + AsyncWrite + Unpin + Send> Client<'_, S> {
|
|
||||||
/// Let the client authenticate and connect to the designated compute node.
|
/// Let the client authenticate and connect to the designated compute node.
|
||||||
async fn connect_to_db(self, session: cancellation::Session<'_>) -> anyhow::Result<()> {
|
/// On return, the connection is fully established in both ways, and we can start
|
||||||
let Self {
|
/// forwarding the bytes.
|
||||||
mut stream,
|
#[instrument(skip_all)]
|
||||||
creds,
|
async fn connect_to_db<'a>(
|
||||||
params,
|
mut stream: PqStream<S>,
|
||||||
session_id,
|
creds: auth::BackendType<'a, auth::ClientCredentials<'_>>,
|
||||||
} = self;
|
params: &'_ StartupMessageParams,
|
||||||
|
session_id: uuid::Uuid,
|
||||||
|
use_cleartext_password_flow: bool,
|
||||||
|
cancel_map: &'a CancelMap,
|
||||||
|
) -> anyhow::Result<EstablishedConnection<'a, S>> {
|
||||||
|
let session = cancel_map.new_session()?;
|
||||||
|
|
||||||
let extra = auth::ConsoleReqExtra {
|
let extra = auth::ConsoleReqExtra {
|
||||||
session_id, // aka this connection's id
|
session_id, // aka this connection's id
|
||||||
@@ -272,10 +307,11 @@ impl<S: AsyncRead + AsyncWrite + Unpin + Send> Client<'_, S> {
|
|||||||
|
|
||||||
let auth_result = async {
|
let auth_result = async {
|
||||||
// `&mut stream` doesn't let us merge those 2 lines.
|
// `&mut stream` doesn't let us merge those 2 lines.
|
||||||
let res = creds.authenticate(&extra, &mut stream).await;
|
let res = creds
|
||||||
|
.authenticate(&extra, &mut stream, use_cleartext_password_flow)
|
||||||
|
.await;
|
||||||
async { res }.or_else(|e| stream.throw_error(e)).await
|
async { res }.or_else(|e| stream.throw_error(e)).await
|
||||||
}
|
}
|
||||||
.instrument(info_span!("auth"))
|
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
let node = auth_result.value;
|
let node = auth_result.value;
|
||||||
@@ -311,21 +347,15 @@ impl<S: AsyncRead + AsyncWrite + Unpin + Send> Client<'_, S> {
|
|||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
let m_sent = NUM_BYTES_PROXIED_COUNTER.with_label_values(&node.aux.traffic_labels("tx"));
|
let m_sent = NUM_BYTES_PROXIED_COUNTER.with_label_values(&node.aux.traffic_labels("tx"));
|
||||||
let mut client = MeasuredStream::new(stream.into_inner(), |cnt| {
|
let client_stream = MeasuredStream::new(stream.into_inner(), m_sent);
|
||||||
// Number of bytes we sent to the client (outbound).
|
|
||||||
m_sent.inc_by(cnt as u64);
|
|
||||||
});
|
|
||||||
|
|
||||||
let m_recv = NUM_BYTES_PROXIED_COUNTER.with_label_values(&node.aux.traffic_labels("rx"));
|
let m_recv = NUM_BYTES_PROXIED_COUNTER.with_label_values(&node.aux.traffic_labels("rx"));
|
||||||
let mut db = MeasuredStream::new(db.stream, |cnt| {
|
let db_stream = MeasuredStream::new(db.stream, m_recv);
|
||||||
// Number of bytes the client sent to the compute node (inbound).
|
|
||||||
m_recv.inc_by(cnt as u64);
|
|
||||||
});
|
|
||||||
|
|
||||||
// Starting from here we only proxy the client's traffic.
|
Ok(EstablishedConnection {
|
||||||
info!("performing the proxy pass...");
|
client_stream,
|
||||||
let _ = tokio::io::copy_bidirectional(&mut client, &mut db).await?;
|
db_stream,
|
||||||
|
_session: session,
|
||||||
Ok(())
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ use std::{io, task};
|
|||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf};
|
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf};
|
||||||
use tokio_rustls::server::TlsStream;
|
use tokio_rustls::server::TlsStream;
|
||||||
|
use tracing::instrument;
|
||||||
|
|
||||||
pin_project! {
|
pin_project! {
|
||||||
/// Stream wrapper which implements libpq's protocol.
|
/// Stream wrapper which implements libpq's protocol.
|
||||||
@@ -105,6 +106,7 @@ impl<S: AsyncWrite + Unpin> PqStream<S> {
|
|||||||
/// Write the error message using [`Self::write_message`], then re-throw it.
|
/// Write the error message using [`Self::write_message`], then re-throw it.
|
||||||
/// Allowing string literals is safe under the assumption they might not contain any runtime info.
|
/// Allowing string literals is safe under the assumption they might not contain any runtime info.
|
||||||
/// This method exists due to `&str` not implementing `Into<anyhow::Error>`.
|
/// This method exists due to `&str` not implementing `Into<anyhow::Error>`.
|
||||||
|
#[instrument(skip_all)]
|
||||||
pub async fn throw_error_str<T>(&mut self, error: &'static str) -> anyhow::Result<T> {
|
pub async fn throw_error_str<T>(&mut self, error: &'static str) -> anyhow::Result<T> {
|
||||||
tracing::info!("forwarding error to user: {error}");
|
tracing::info!("forwarding error to user: {error}");
|
||||||
self.write_message(&BeMessage::ErrorResponse(error, None))
|
self.write_message(&BeMessage::ErrorResponse(error, None))
|
||||||
@@ -114,6 +116,7 @@ impl<S: AsyncWrite + Unpin> PqStream<S> {
|
|||||||
|
|
||||||
/// Write the error message using [`Self::write_message`], then re-throw it.
|
/// Write the error message using [`Self::write_message`], then re-throw it.
|
||||||
/// Trait [`UserFacingError`] acts as an allowlist for error types.
|
/// Trait [`UserFacingError`] acts as an allowlist for error types.
|
||||||
|
#[instrument(skip_all)]
|
||||||
pub async fn throw_error<T, E>(&mut self, error: E) -> anyhow::Result<T>
|
pub async fn throw_error<T, E>(&mut self, error: E) -> anyhow::Result<T>
|
||||||
where
|
where
|
||||||
E: UserFacingError + Into<anyhow::Error>,
|
E: UserFacingError + Into<anyhow::Error>,
|
||||||
@@ -228,27 +231,27 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite for Stream<S> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pin_project! {
|
pin_project! {
|
||||||
/// This stream tracks all writes and calls user provided
|
/// This stream tracks all writes, and whenever the stream is flushed,
|
||||||
/// callback when the underlying stream is flushed.
|
/// increments the user-provided counter by the number of bytes flushed.
|
||||||
pub struct MeasuredStream<S, W> {
|
pub struct MeasuredStream<S> {
|
||||||
#[pin]
|
#[pin]
|
||||||
stream: S,
|
stream: S,
|
||||||
write_count: usize,
|
write_count: usize,
|
||||||
inc_write_count: W,
|
write_counter: prometheus::IntCounter,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<S, W> MeasuredStream<S, W> {
|
impl<S> MeasuredStream<S> {
|
||||||
pub fn new(stream: S, inc_write_count: W) -> Self {
|
pub fn new(stream: S, write_counter: prometheus::IntCounter) -> Self {
|
||||||
Self {
|
Self {
|
||||||
stream,
|
stream,
|
||||||
write_count: 0,
|
write_count: 0,
|
||||||
inc_write_count,
|
write_counter,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<S: AsyncRead + Unpin, W> AsyncRead for MeasuredStream<S, W> {
|
impl<S: AsyncRead + Unpin> AsyncRead for MeasuredStream<S> {
|
||||||
fn poll_read(
|
fn poll_read(
|
||||||
self: Pin<&mut Self>,
|
self: Pin<&mut Self>,
|
||||||
context: &mut task::Context<'_>,
|
context: &mut task::Context<'_>,
|
||||||
@@ -258,7 +261,7 @@ impl<S: AsyncRead + Unpin, W> AsyncRead for MeasuredStream<S, W> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<S: AsyncWrite + Unpin, W: FnMut(usize)> AsyncWrite for MeasuredStream<S, W> {
|
impl<S: AsyncWrite + Unpin> AsyncWrite for MeasuredStream<S> {
|
||||||
fn poll_write(
|
fn poll_write(
|
||||||
self: Pin<&mut Self>,
|
self: Pin<&mut Self>,
|
||||||
context: &mut task::Context<'_>,
|
context: &mut task::Context<'_>,
|
||||||
@@ -279,7 +282,7 @@ impl<S: AsyncWrite + Unpin, W: FnMut(usize)> AsyncWrite for MeasuredStream<S, W>
|
|||||||
let this = self.project();
|
let this = self.project();
|
||||||
this.stream.poll_flush(context).map_ok(|()| {
|
this.stream.poll_flush(context).map_ok(|()| {
|
||||||
// Call the user provided callback and reset the write count.
|
// Call the user provided callback and reset the write count.
|
||||||
(this.inc_write_count)(*this.write_count);
|
this.write_counter.inc_by(*this.write_count as u64);
|
||||||
*this.write_count = 0;
|
*this.write_count = 0;
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user