[proxy] Replace private static map with a public CancelMap

This is a cleaner approach which might facilitate testing.
This commit is contained in:
Dmitry Ivanov
2022-02-16 18:51:18 +03:00
parent a47dade622
commit a26d565282
2 changed files with 85 additions and 60 deletions

View File

@@ -1,10 +1,11 @@
use crate::auth;
use crate::cancellation::{self, CancelClosure};
use crate::cancellation::{self, CancelClosure, CancelMap};
use crate::compute::DatabaseInfo;
use crate::config::{ProxyConfig, TlsConfig};
use crate::stream::{MetricsStream, PqStream, Stream};
use anyhow::{bail, Context};
use lazy_static::lazy_static;
use std::sync::Arc;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::TcpStream;
use tokio_postgres::NoTls;
@@ -47,22 +48,25 @@ pub async fn thread_main(
println!("proxy has shut down");
}
let cancel_map = Arc::new(CancelMap::default());
loop {
let (socket, peer_addr) = listener.accept().await?;
println!("accepted connection from {}", peer_addr);
tokio::spawn(log_error(async {
let cancel_map = Arc::clone(&cancel_map);
tokio::spawn(log_error(async move {
socket
.set_nodelay(true)
.context("failed to set socket option")?;
handle_client(config, socket).await
handle_client(config, &cancel_map, socket).await
}));
}
}
async fn handle_client(
config: &ProxyConfig,
cancel_map: &CancelMap,
stream: impl AsyncRead + AsyncWrite + Unpin,
) -> anyhow::Result<()> {
// The `closed` counter will increase when this future is destroyed.
@@ -72,11 +76,12 @@ async fn handle_client(
}
let tls = config.tls_config.clone();
if let Some((stream, creds)) = handshake(stream, tls).await? {
cancellation::with_session(|session| async {
connect_client_to_db(config, stream, creds, session).await
})
.await?;
if let Some((client, creds)) = handshake(stream, tls, cancel_map).await? {
cancel_map
.with_session(|session| async {
connect_client_to_db(config, session, client, creds).await
})
.await?;
}
Ok(())
@@ -88,6 +93,7 @@ async fn handle_client(
async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
stream: S,
mut tls: Option<TlsConfig>,
cancel_map: &CancelMap,
) -> anyhow::Result<Option<(PqStream<Stream<S>>, auth::ClientCredentials)>> {
// Client may try upgrading to each protocol only once
let (mut tried_ssl, mut tried_gss) = (false, false);
@@ -136,7 +142,7 @@ async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
break Ok(Some((stream, params.try_into()?)));
}
CancelRequest(cancel_key_data) => {
cancellation::cancel_session(cancel_key_data).await?;
cancel_map.cancel_session(cancel_key_data).await?;
break Ok(None);
}
@@ -146,9 +152,9 @@ async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
async fn connect_client_to_db(
config: &ProxyConfig,
session: cancellation::Session<'_>,
mut client: PqStream<impl AsyncRead + AsyncWrite + Unpin>,
creds: auth::ClientCredentials,
session: cancellation::Session,
) -> anyhow::Result<()> {
let db_info = creds.authenticate(config, &mut client).await?;
let (db, version, cancel_closure) = connect_to_db(db_info).await?;
@@ -211,8 +217,12 @@ mod tests {
client: impl AsyncRead + AsyncWrite + Unpin,
tls: Option<TlsConfig>,
) -> anyhow::Result<()> {
let cancel_map = CancelMap::default();
// TODO: add some infra + tests for credentials
let (mut stream, _creds) = handshake(client, tls).await?.context("no stream")?;
let (mut stream, _creds) = handshake(client, tls, &cancel_map)
.await?
.context("no stream")?;
stream
.write_message_noflush(&Be::AuthenticationOk)?