mirror of
https://github.com/neondatabase/neon.git
synced 2026-01-08 22:12:56 +00:00
[proxy] Replace private static map with a public CancelMap
This is a cleaner approach which might facilitate testing.
This commit is contained in:
@@ -1,15 +1,57 @@
|
||||
use anyhow::{anyhow, Context};
|
||||
use hashbrown::HashMap;
|
||||
use lazy_static::lazy_static;
|
||||
use parking_lot::Mutex;
|
||||
use std::net::SocketAddr;
|
||||
use tokio::net::TcpStream;
|
||||
use tokio_postgres::{CancelToken, NoTls};
|
||||
use zenith_utils::pq_proto::CancelKeyData;
|
||||
|
||||
lazy_static! {
|
||||
/// Enables serving CancelRequests.
|
||||
static ref CANCEL_MAP: Mutex<HashMap<CancelKeyData, Option<CancelClosure>>> = Default::default();
|
||||
/// Enables serving CancelRequests.
|
||||
#[derive(Default)]
|
||||
pub struct CancelMap(Mutex<HashMap<CancelKeyData, Option<CancelClosure>>>);
|
||||
|
||||
impl CancelMap {
|
||||
/// Cancel a running query for the corresponding connection.
|
||||
pub async fn cancel_session(&self, key: CancelKeyData) -> anyhow::Result<()> {
|
||||
let cancel_closure = self
|
||||
.0
|
||||
.lock()
|
||||
.get(&key)
|
||||
.and_then(|x| x.clone())
|
||||
.with_context(|| format!("unknown session: {:?}", key))?;
|
||||
|
||||
cancel_closure.try_cancel_query().await
|
||||
}
|
||||
|
||||
/// Run async action within an ephemeral session identified by [`CancelKeyData`].
|
||||
pub async fn with_session<'a, F, R, V>(&'a self, f: F) -> anyhow::Result<V>
|
||||
where
|
||||
F: FnOnce(Session<'a>) -> R,
|
||||
R: std::future::Future<Output = anyhow::Result<V>>,
|
||||
{
|
||||
// HACK: We'd rather get the real backend_pid but tokio_postgres doesn't
|
||||
// expose it and we don't want to do another roundtrip to query
|
||||
// for it. The client will be able to notice that this is not the
|
||||
// actual backend_pid, but backend_pid is not used for anything
|
||||
// so it doesn't matter.
|
||||
let key = rand::random();
|
||||
|
||||
// Random key collisions are unlikely to happen here, but they're still possible,
|
||||
// which is why we have to take care not to rewrite an existing key.
|
||||
self.0
|
||||
.lock()
|
||||
.try_insert(key, None)
|
||||
.map_err(|_| anyhow!("session already exists: {:?}", key))?;
|
||||
|
||||
// This will guarantee that the session gets dropped
|
||||
// as soon as the future is finished.
|
||||
scopeguard::defer! {
|
||||
self.0.lock().remove(&key);
|
||||
}
|
||||
|
||||
let session = Session::new(key, self);
|
||||
f(session).await
|
||||
}
|
||||
}
|
||||
|
||||
/// This should've been a [`std::future::Future`], but
|
||||
@@ -38,54 +80,27 @@ impl CancelClosure {
|
||||
}
|
||||
}
|
||||
|
||||
/// Cancel a running query for the corresponding connection.
|
||||
pub async fn cancel_session(key: CancelKeyData) -> anyhow::Result<()> {
|
||||
let cancel_closure = CANCEL_MAP
|
||||
.lock()
|
||||
.get(&key)
|
||||
.and_then(|x| x.clone())
|
||||
.with_context(|| format!("unknown session: {:?}", key))?;
|
||||
|
||||
cancel_closure.try_cancel_query().await
|
||||
}
|
||||
|
||||
/// Helper for registering query cancellation tokens.
|
||||
pub struct Session(CancelKeyData);
|
||||
pub struct Session<'a> {
|
||||
/// The user-facing key identifying this session.
|
||||
key: CancelKeyData,
|
||||
/// The [`CancelMap`] this session belongs to.
|
||||
cancel_map: &'a CancelMap,
|
||||
}
|
||||
|
||||
impl<'a> Session<'a> {
|
||||
fn new(key: CancelKeyData, cancel_map: &'a CancelMap) -> Self {
|
||||
Self { key, cancel_map }
|
||||
}
|
||||
|
||||
impl Session {
|
||||
/// Store the cancel token for the given session.
|
||||
/// This enables query cancellation in [`crate::proxy::handshake`].
|
||||
pub fn enable_cancellation(self, cancel_closure: CancelClosure) -> CancelKeyData {
|
||||
CANCEL_MAP.lock().insert(self.0, Some(cancel_closure));
|
||||
self.0
|
||||
self.cancel_map
|
||||
.0
|
||||
.lock()
|
||||
.insert(self.key, Some(cancel_closure));
|
||||
|
||||
self.key
|
||||
}
|
||||
}
|
||||
|
||||
/// Run async action within an ephemeral session identified by [`CancelKeyData`].
|
||||
pub async fn with_session<F, R, V>(f: F) -> anyhow::Result<V>
|
||||
where
|
||||
F: FnOnce(Session) -> R,
|
||||
R: std::future::Future<Output = anyhow::Result<V>>,
|
||||
{
|
||||
// HACK: We'd rather get the real backend_pid but tokio_postgres doesn't
|
||||
// expose it and we don't want to do another roundtrip to query
|
||||
// for it. The client will be able to notice that this is not the
|
||||
// actual backend_pid, but backend_pid is not used for anything
|
||||
// so it doesn't matter.
|
||||
let key = rand::random();
|
||||
|
||||
// Random key collisions are unlikely to happen here, but they're still possible,
|
||||
// which is why we have to take care not to rewrite an existing key.
|
||||
CANCEL_MAP
|
||||
.lock()
|
||||
.try_insert(key, None)
|
||||
.map_err(|_| anyhow!("session already exists: {:?}", key))?;
|
||||
|
||||
// This will guarantee that the session gets dropped
|
||||
// as soon as the future is finished.
|
||||
scopeguard::defer! {
|
||||
CANCEL_MAP.lock().remove(&key);
|
||||
}
|
||||
|
||||
let session = Session(key);
|
||||
f(session).await
|
||||
}
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
use crate::auth;
|
||||
use crate::cancellation::{self, CancelClosure};
|
||||
use crate::cancellation::{self, CancelClosure, CancelMap};
|
||||
use crate::compute::DatabaseInfo;
|
||||
use crate::config::{ProxyConfig, TlsConfig};
|
||||
use crate::stream::{MetricsStream, PqStream, Stream};
|
||||
use anyhow::{bail, Context};
|
||||
use lazy_static::lazy_static;
|
||||
use std::sync::Arc;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tokio::net::TcpStream;
|
||||
use tokio_postgres::NoTls;
|
||||
@@ -47,22 +48,25 @@ pub async fn thread_main(
|
||||
println!("proxy has shut down");
|
||||
}
|
||||
|
||||
let cancel_map = Arc::new(CancelMap::default());
|
||||
loop {
|
||||
let (socket, peer_addr) = listener.accept().await?;
|
||||
println!("accepted connection from {}", peer_addr);
|
||||
|
||||
tokio::spawn(log_error(async {
|
||||
let cancel_map = Arc::clone(&cancel_map);
|
||||
tokio::spawn(log_error(async move {
|
||||
socket
|
||||
.set_nodelay(true)
|
||||
.context("failed to set socket option")?;
|
||||
|
||||
handle_client(config, socket).await
|
||||
handle_client(config, &cancel_map, socket).await
|
||||
}));
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_client(
|
||||
config: &ProxyConfig,
|
||||
cancel_map: &CancelMap,
|
||||
stream: impl AsyncRead + AsyncWrite + Unpin,
|
||||
) -> anyhow::Result<()> {
|
||||
// The `closed` counter will increase when this future is destroyed.
|
||||
@@ -72,11 +76,12 @@ async fn handle_client(
|
||||
}
|
||||
|
||||
let tls = config.tls_config.clone();
|
||||
if let Some((stream, creds)) = handshake(stream, tls).await? {
|
||||
cancellation::with_session(|session| async {
|
||||
connect_client_to_db(config, stream, creds, session).await
|
||||
})
|
||||
.await?;
|
||||
if let Some((client, creds)) = handshake(stream, tls, cancel_map).await? {
|
||||
cancel_map
|
||||
.with_session(|session| async {
|
||||
connect_client_to_db(config, session, client, creds).await
|
||||
})
|
||||
.await?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
@@ -88,6 +93,7 @@ async fn handle_client(
|
||||
async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
stream: S,
|
||||
mut tls: Option<TlsConfig>,
|
||||
cancel_map: &CancelMap,
|
||||
) -> anyhow::Result<Option<(PqStream<Stream<S>>, auth::ClientCredentials)>> {
|
||||
// Client may try upgrading to each protocol only once
|
||||
let (mut tried_ssl, mut tried_gss) = (false, false);
|
||||
@@ -136,7 +142,7 @@ async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
break Ok(Some((stream, params.try_into()?)));
|
||||
}
|
||||
CancelRequest(cancel_key_data) => {
|
||||
cancellation::cancel_session(cancel_key_data).await?;
|
||||
cancel_map.cancel_session(cancel_key_data).await?;
|
||||
|
||||
break Ok(None);
|
||||
}
|
||||
@@ -146,9 +152,9 @@ async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
|
||||
async fn connect_client_to_db(
|
||||
config: &ProxyConfig,
|
||||
session: cancellation::Session<'_>,
|
||||
mut client: PqStream<impl AsyncRead + AsyncWrite + Unpin>,
|
||||
creds: auth::ClientCredentials,
|
||||
session: cancellation::Session,
|
||||
) -> anyhow::Result<()> {
|
||||
let db_info = creds.authenticate(config, &mut client).await?;
|
||||
let (db, version, cancel_closure) = connect_to_db(db_info).await?;
|
||||
@@ -211,8 +217,12 @@ mod tests {
|
||||
client: impl AsyncRead + AsyncWrite + Unpin,
|
||||
tls: Option<TlsConfig>,
|
||||
) -> anyhow::Result<()> {
|
||||
let cancel_map = CancelMap::default();
|
||||
|
||||
// TODO: add some infra + tests for credentials
|
||||
let (mut stream, _creds) = handshake(client, tls).await?.context("no stream")?;
|
||||
let (mut stream, _creds) = handshake(client, tls, &cancel_map)
|
||||
.await?
|
||||
.context("no stream")?;
|
||||
|
||||
stream
|
||||
.write_message_noflush(&Be::AuthenticationOk)?
|
||||
|
||||
Reference in New Issue
Block a user