Compare commits

...

3 Commits

Author SHA1 Message Date
Conrad Ludgate
0e551edb06 rename fn 2024-08-16 08:54:25 +01:00
Conrad Ludgate
484cdccbf2 fix cancellation 2024-08-16 08:44:03 +01:00
Conrad Ludgate
39d1b78817 fix http connect config 2024-08-16 08:42:14 +01:00
5 changed files with 100 additions and 47 deletions

View File

@@ -151,21 +151,34 @@ impl<P: CancellationPublisherMut> CancellationHandler<Option<Arc<Mutex<P>>>> {
#[derive(Clone)]
pub struct CancelClosure {
socket_addr: SocketAddr,
cancel_token: CancelToken,
cancel_token: Option<CancelToken>,
}
impl CancelClosure {
pub fn new(socket_addr: SocketAddr, cancel_token: CancelToken) -> Self {
Self {
socket_addr,
cancel_token,
cancel_token: Some(cancel_token),
}
}
#[cfg(test)]
pub fn test() -> Self {
use std::net::{Ipv4Addr, SocketAddrV4};
Self {
socket_addr: SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::from_bits(0), 0)),
cancel_token: None,
}
}
/// Cancels the query running on user's compute node.
pub async fn try_cancel_query(self) -> Result<(), CancelError> {
let socket = TcpStream::connect(self.socket_addr).await?;
self.cancel_token.cancel_query_raw(socket, NoTls).await?;
info!("query was cancelled");
if let Some(cancel_token) = self.cancel_token {
let socket = TcpStream::connect(self.socket_addr).await?;
cancel_token.cancel_query_raw(socket, NoTls).await?;
info!("query was cancelled");
}
Ok(())
}
}

View File

@@ -16,8 +16,10 @@ use rustls::{client::danger::ServerCertVerifier, pki_types::InvalidDnsNameError}
use std::{io, net::SocketAddr, sync::Arc, time::Duration};
use thiserror::Error;
use tokio::net::TcpStream;
use tokio_postgres::tls::MakeTlsConnect;
use tokio_postgres_rustls::MakeRustlsConnect;
use tokio_postgres::{
tls::{MakeTlsConnect, NoTlsError},
Client, Connection,
};
use tracing::{error, info, warn};
const COULD_NOT_CONNECT: &str = "Couldn't connect to compute node";
@@ -42,6 +44,12 @@ pub enum ConnectionError {
TooManyConnectionAttempts(#[from] ApiLockError),
}
impl From<NoTlsError> for ConnectionError {
fn from(value: NoTlsError) -> Self {
Self::CouldNotConnect(io::Error::new(io::ErrorKind::Other, value.to_string()))
}
}
impl UserFacingError for ConnectionError {
fn to_string_client(&self) -> String {
use ConnectionError::*;
@@ -273,6 +281,30 @@ pub struct PostgresConnection {
}
impl ConnCfg {
/// Connect to a corresponding compute node.
pub async fn managed_connect<M: MakeTlsConnect<tokio::net::TcpStream>>(
&self,
ctx: &RequestMonitoring,
timeout: Duration,
mktls: &mut M,
) -> Result<(SocketAddr, Client, Connection<TcpStream, M::Stream>), ConnectionError>
where
ConnectionError: From<M::Error>,
{
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
let (socket_addr, stream, host) = self.connect_raw(timeout).await?;
drop(pause);
let tls = mktls.make_tls_connect(host)?;
// connect_raw() will not use TLS if sslmode is "disable"
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
let (client, connection) = self.0.connect_raw(stream, tls).await?;
drop(pause);
Ok((socket_addr, client, connection))
}
/// Connect to a corresponding compute node.
pub async fn connect(
&self,
@@ -281,10 +313,6 @@ impl ConnCfg {
aux: MetricsAuxInfo,
timeout: Duration,
) -> Result<PostgresConnection, ConnectionError> {
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
let (socket_addr, stream, host) = self.connect_raw(timeout).await?;
drop(pause);
let client_config = if allow_self_signed_compute {
// Allow all certificates for creating the connection
let verifier = Arc::new(AcceptEverythingVerifier) as Arc<dyn ServerCertVerifier>;
@@ -298,21 +326,15 @@ impl ConnCfg {
let client_config = client_config.with_no_client_auth();
let mut mk_tls = tokio_postgres_rustls::MakeRustlsConnect::new(client_config);
let tls = <MakeRustlsConnect as MakeTlsConnect<tokio::net::TcpStream>>::make_tls_connect(
&mut mk_tls,
host,
)?;
// connect_raw() will not use TLS if sslmode is "disable"
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
let (client, connection) = self.0.connect_raw(stream, tls).await?;
drop(pause);
let (socket_addr, client, connection) =
self.managed_connect(ctx, timeout, &mut mk_tls).await?;
tracing::Span::current().record("pid", tracing::field::display(client.get_process_id()));
let stream = connection.stream.into_inner();
info!(
cold_start_info = ctx.cold_start_info().as_str(),
"connected to compute node at {host} ({socket_addr}) sslmode={:?}",
"connected to compute node ({socket_addr}) sslmode={:?}",
self.0.get_ssl_mode()
);

View File

@@ -5,7 +5,8 @@ use tracing::{field::display, info};
use crate::{
auth::{backend::ComputeCredentials, check_peer_addr_is_in_list, AuthError},
compute,
cancellation::CancelClosure,
compute::{self, ConnectionError},
config::{AuthenticationConfig, ProxyConfig},
console::{
errors::{GetAuthInfoError, WakeComputeError},
@@ -142,7 +143,7 @@ pub enum HttpConnError {
#[error("pooled connection closed at inconsistent state")]
ConnectionClosedAbruptly(#[from] tokio::sync::watch::error::SendError<uuid::Uuid>),
#[error("could not connection to compute")]
ConnectionError(#[from] tokio_postgres::Error),
ConnectionError(#[from] ConnectionError),
#[error("could not get auth info")]
GetAuthInfo(#[from] GetAuthInfoError),
@@ -229,17 +230,16 @@ impl ConnectMechanism for TokioMechanism {
let host = node_info.config.get_host()?;
let permit = self.locks.get_permit(&host).await?;
let mut config = (*node_info.config).clone();
let config = config
.user(&self.conn_info.user_info.user)
.password(&*self.conn_info.password)
.dbname(&self.conn_info.dbname)
.connect_timeout(timeout);
let (socket_addr, client, connection) = permit.release_result(
node_info
.config
.managed_connect(ctx, timeout, &mut tokio_postgres::NoTls)
.await,
)?;
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
let res = config.connect(tokio_postgres::NoTls).await;
drop(pause);
let (client, connection) = permit.release_result(res)?;
// NB: CancelToken is supposed to hold socket_addr, but we use connect_raw.
// Yet another reason to rework the connection establishing code.
let cancel_closure = CancelClosure::new(socket_addr, client.cancel_token());
tracing::Span::current().record("pid", tracing::field::display(client.get_process_id()));
Ok(poll_client(
@@ -250,8 +250,14 @@ impl ConnectMechanism for TokioMechanism {
connection,
self.conn_id,
node_info.aux.clone(),
cancel_closure,
))
}
fn update_connect_config(&self, _config: &mut compute::ConnCfg) {}
fn update_connect_config(&self, config: &mut compute::ConnCfg) {
config
.user(&self.conn_info.user_info.user)
.dbname(&self.conn_info.dbname)
.password(&self.conn_info.password);
}
}

View File

@@ -12,11 +12,13 @@ use std::{
ops::Deref,
sync::atomic::{self, AtomicUsize},
};
use tokio::net::TcpStream;
use tokio::time::Instant;
use tokio_postgres::tls::NoTlsStream;
use tokio_postgres::{AsyncMessage, ReadyForQueryStatus, Socket};
use tokio_postgres::{AsyncMessage, ReadyForQueryStatus};
use tokio_util::sync::CancellationToken;
use crate::cancellation::CancelClosure;
use crate::console::messages::{ColdStartInfo, MetricsAuxInfo};
use crate::metrics::{HttpEndpointPoolsGuard, Metrics};
use crate::usage_metrics::{Ids, MetricCounter, USAGE_METRICS};
@@ -463,14 +465,16 @@ impl<C: ClientInnerExt> GlobalConnPool<C> {
}
}
#[allow(clippy::too_many_arguments)]
pub fn poll_client<C: ClientInnerExt>(
global_pool: Arc<GlobalConnPool<C>>,
ctx: &RequestMonitoring,
conn_info: ConnInfo,
client: C,
mut connection: tokio_postgres::Connection<Socket, NoTlsStream>,
mut connection: tokio_postgres::Connection<TcpStream, NoTlsStream>,
conn_id: uuid::Uuid,
aux: MetricsAuxInfo,
cancel_closure: CancelClosure,
) -> Client<C> {
let conn_gauge = Metrics::get().proxy.db_connections.guard(ctx.protocol());
let mut session_id = ctx.session_id();
@@ -572,6 +576,7 @@ pub fn poll_client<C: ClientInnerExt>(
cancel,
aux,
conn_id,
cancel_closure,
};
Client::new(inner, conn_info, pool_clone)
}
@@ -582,6 +587,7 @@ struct ClientInner<C: ClientInnerExt> {
cancel: CancellationToken,
aux: MetricsAuxInfo,
conn_id: uuid::Uuid,
cancel_closure: CancelClosure,
}
impl<C: ClientInnerExt> Drop for ClientInner<C> {
@@ -646,7 +652,7 @@ impl<C: ClientInnerExt> Client<C> {
pool,
}
}
pub fn inner(&mut self) -> (&mut C, Discard<'_, C>) {
pub fn inner(&mut self) -> (&mut C, &CancelClosure, Discard<'_, C>) {
let Self {
inner,
pool,
@@ -654,7 +660,11 @@ impl<C: ClientInnerExt> Client<C> {
span: _,
} = self;
let inner = inner.as_mut().expect("client inner should not be removed");
(&mut inner.inner, Discard { pool, conn_info })
(
&mut inner.inner,
&inner.cancel_closure,
Discard { pool, conn_info },
)
}
}
@@ -751,6 +761,7 @@ mod tests {
cold_start_info: crate::console::messages::ColdStartInfo::Warm,
},
conn_id: uuid::Uuid::new_v4(),
cancel_closure: CancelClosure::test(),
}
}
@@ -785,7 +796,7 @@ mod tests {
{
let mut client = Client::new(create_inner(), conn_info.clone(), ep_pool.clone());
assert_eq!(0, pool.get_global_connections_count());
client.inner().1.discard();
client.inner().2.discard();
// Discard should not add the connection from the pool.
assert_eq!(0, pool.get_global_connections_count());
}

View File

@@ -26,7 +26,6 @@ use tokio_postgres::error::ErrorPosition;
use tokio_postgres::error::SqlState;
use tokio_postgres::GenericClient;
use tokio_postgres::IsolationLevel;
use tokio_postgres::NoTls;
use tokio_postgres::ReadyForQueryStatus;
use tokio_postgres::Transaction;
use tokio_util::sync::CancellationToken;
@@ -261,7 +260,9 @@ pub async fn handle(
let mut message = e.to_string_client();
let db_error = match &e {
SqlOverHttpError::ConnectCompute(HttpConnError::ConnectionError(e))
SqlOverHttpError::ConnectCompute(HttpConnError::ConnectionError(
crate::compute::ConnectionError::Postgres(e),
))
| SqlOverHttpError::Postgres(e) => e.as_db_error(),
_ => None,
};
@@ -622,8 +623,7 @@ impl QueryData {
client: &mut Client<tokio_postgres::Client>,
parsed_headers: HttpHeaders,
) -> Result<String, SqlOverHttpError> {
let (inner, mut discard) = client.inner();
let cancel_token = inner.cancel_token();
let (inner, cancel_token, mut discard) = client.inner();
let res = match select(
pin!(query_to_json(&*inner, self, &mut 0, parsed_headers)),
@@ -647,7 +647,7 @@ impl QueryData {
// The query was cancelled.
Either::Right((_cancelled, query)) => {
tracing::info!("cancelling query");
if let Err(err) = cancel_token.cancel_query(NoTls).await {
if let Err(err) = cancel_token.clone().try_cancel_query().await {
tracing::error!(?err, "could not cancel query");
}
// wait for the query cancellation
@@ -663,7 +663,9 @@ impl QueryData {
// query failed or was cancelled.
Ok(Err(error)) => {
let db_error = match &error {
SqlOverHttpError::ConnectCompute(HttpConnError::ConnectionError(e))
SqlOverHttpError::ConnectCompute(HttpConnError::ConnectionError(
crate::compute::ConnectionError::Postgres(e),
))
| SqlOverHttpError::Postgres(e) => e.as_db_error(),
_ => None,
};
@@ -694,8 +696,7 @@ impl BatchQueryData {
parsed_headers: HttpHeaders,
) -> Result<String, SqlOverHttpError> {
info!("starting transaction");
let (inner, mut discard) = client.inner();
let cancel_token = inner.cancel_token();
let (inner, cancel_token, mut discard) = client.inner();
let mut builder = inner.build_transaction();
if let Some(isolation_level) = parsed_headers.txn_isolation_level {
builder = builder.isolation_level(isolation_level);
@@ -728,7 +729,7 @@ impl BatchQueryData {
json_output
}
Err(SqlOverHttpError::Cancelled(_)) => {
if let Err(err) = cancel_token.cancel_query(NoTls).await {
if let Err(err) = cancel_token.clone().try_cancel_query().await {
tracing::error!(?err, "could not cancel query");
}
// TODO: after cancelling, wait to see if we can get a status. maybe the connection is still safe.