mirror of
https://github.com/neondatabase/neon.git
synced 2026-03-17 15:20:37 +00:00
Compare commits
4 Commits
main
...
cloneable/
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3f6a971e4f | ||
|
|
3e962ea400 | ||
|
|
68550d3040 | ||
|
|
f8d530d031 |
@@ -28,10 +28,9 @@ use crate::context::RequestContext;
|
|||||||
use crate::metrics::{Metrics, ThreadPoolMetrics};
|
use crate::metrics::{Metrics, ThreadPoolMetrics};
|
||||||
use crate::pqproto::FeStartupPacket;
|
use crate::pqproto::FeStartupPacket;
|
||||||
use crate::protocol2::ConnectionInfo;
|
use crate::protocol2::ConnectionInfo;
|
||||||
use crate::proxy::{
|
use crate::proxy::{ErrorSource, TlsRequired, copy_bidirectional_client_compute};
|
||||||
ErrorSource, TlsRequired, copy_bidirectional_client_compute, run_until_cancelled,
|
|
||||||
};
|
|
||||||
use crate::stream::{PqStream, Stream};
|
use crate::stream::{PqStream, Stream};
|
||||||
|
use crate::util::run_until_cancelled;
|
||||||
|
|
||||||
project_git_version!(GIT_VERSION);
|
project_git_version!(GIT_VERSION);
|
||||||
|
|
||||||
|
|||||||
@@ -410,7 +410,7 @@ pub async fn run() -> anyhow::Result<()> {
|
|||||||
match auth_backend {
|
match auth_backend {
|
||||||
Either::Left(auth_backend) => {
|
Either::Left(auth_backend) => {
|
||||||
if let Some(proxy_listener) = proxy_listener {
|
if let Some(proxy_listener) = proxy_listener {
|
||||||
client_tasks.spawn(crate::proxy::task_main(
|
client_tasks.spawn(crate::pglb::task_main(
|
||||||
config,
|
config,
|
||||||
auth_backend,
|
auth_backend,
|
||||||
proxy_listener,
|
proxy_listener,
|
||||||
|
|||||||
@@ -103,6 +103,8 @@ pub enum Auth {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// A config for authenticating to the compute node.
|
/// A config for authenticating to the compute node.
|
||||||
|
// XXX: clone
|
||||||
|
#[derive(Clone)]
|
||||||
pub(crate) struct AuthInfo {
|
pub(crate) struct AuthInfo {
|
||||||
/// None for local-proxy, as we use trust-based localhost auth.
|
/// None for local-proxy, as we use trust-based localhost auth.
|
||||||
/// Some for sql-over-http, ws, tcp, and in most cases for console-redirect.
|
/// Some for sql-over-http, ws, tcp, and in most cases for console-redirect.
|
||||||
|
|||||||
@@ -11,13 +11,13 @@ use crate::config::{ProxyConfig, ProxyProtocolV2};
|
|||||||
use crate::context::RequestContext;
|
use crate::context::RequestContext;
|
||||||
use crate::error::ReportableError;
|
use crate::error::ReportableError;
|
||||||
use crate::metrics::{Metrics, NumClientConnectionsGuard};
|
use crate::metrics::{Metrics, NumClientConnectionsGuard};
|
||||||
use crate::pglb::connect_compute::{TcpMechanism, connect_to_compute};
|
use crate::pglb::connect_compute::TcpMechanism;
|
||||||
use crate::pglb::handshake::{HandshakeData, handshake};
|
use crate::pglb::handshake::{HandshakeData, handshake};
|
||||||
use crate::pglb::passthrough::ProxyPassthrough;
|
use crate::pglb::passthrough::ProxyPassthrough;
|
||||||
|
use crate::pglb::{ClientRequestError, ErrorSource};
|
||||||
use crate::protocol2::{ConnectHeader, ConnectionInfo, read_proxy_protocol};
|
use crate::protocol2::{ConnectHeader, ConnectionInfo, read_proxy_protocol};
|
||||||
use crate::proxy::{
|
use crate::proxy::{connect_to_compute, prepare_client_connection};
|
||||||
ClientRequestError, ErrorSource, prepare_client_connection, run_until_cancelled,
|
use crate::util::run_until_cancelled;
|
||||||
};
|
|
||||||
|
|
||||||
pub async fn task_main(
|
pub async fn task_main(
|
||||||
config: &'static ProxyConfig,
|
config: &'static ProxyConfig,
|
||||||
|
|||||||
@@ -106,4 +106,5 @@ mod tls;
|
|||||||
mod types;
|
mod types;
|
||||||
mod url;
|
mod url;
|
||||||
mod usage_metrics;
|
mod usage_metrics;
|
||||||
|
mod util;
|
||||||
mod waiters;
|
mod waiters;
|
||||||
|
|||||||
@@ -1,41 +1,14 @@
|
|||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use tokio::time;
|
|
||||||
use tracing::{debug, info, warn};
|
|
||||||
|
|
||||||
use crate::auth::backend::ComputeUserInfo;
|
use crate::auth::backend::ComputeUserInfo;
|
||||||
use crate::compute::{self, AuthInfo, COULD_NOT_CONNECT, PostgresConnection};
|
use crate::compute::{self, AuthInfo, PostgresConnection};
|
||||||
use crate::config::{ComputeConfig, RetryConfig};
|
use crate::config::ComputeConfig;
|
||||||
use crate::context::RequestContext;
|
use crate::context::RequestContext;
|
||||||
use crate::control_plane::errors::WakeComputeError;
|
|
||||||
use crate::control_plane::locks::ApiLocks;
|
use crate::control_plane::locks::ApiLocks;
|
||||||
use crate::control_plane::{self, CachedNodeInfo, NodeInfo};
|
use crate::control_plane::{self, CachedNodeInfo};
|
||||||
use crate::error::ReportableError;
|
use crate::error::ReportableError;
|
||||||
use crate::metrics::{
|
|
||||||
ConnectOutcome, ConnectionFailureKind, Metrics, RetriesMetricGroup, RetryType,
|
|
||||||
};
|
|
||||||
use crate::proxy::retry::{CouldRetry, ShouldRetryWakeCompute, retry_after, should_retry};
|
|
||||||
use crate::proxy::wake_compute::wake_compute;
|
|
||||||
use crate::types::Host;
|
use crate::types::Host;
|
||||||
|
|
||||||
/// If we couldn't connect, a cached connection info might be to blame
|
|
||||||
/// (e.g. the compute node's address might've changed at the wrong time).
|
|
||||||
/// Invalidate the cache entry (if any) to prevent subsequent errors.
|
|
||||||
#[tracing::instrument(name = "invalidate_cache", skip_all)]
|
|
||||||
pub(crate) fn invalidate_cache(node_info: control_plane::CachedNodeInfo) -> NodeInfo {
|
|
||||||
let is_cached = node_info.cached();
|
|
||||||
if is_cached {
|
|
||||||
warn!("invalidating stalled compute node info cache entry");
|
|
||||||
}
|
|
||||||
let label = if is_cached {
|
|
||||||
ConnectionFailureKind::ComputeCached
|
|
||||||
} else {
|
|
||||||
ConnectionFailureKind::ComputeUncached
|
|
||||||
};
|
|
||||||
Metrics::get().proxy.connection_failures_total.inc(label);
|
|
||||||
|
|
||||||
node_info.invalidate()
|
|
||||||
}
|
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
pub(crate) trait ConnectMechanism {
|
pub(crate) trait ConnectMechanism {
|
||||||
type Connection;
|
type Connection;
|
||||||
@@ -88,106 +61,3 @@ impl ConnectMechanism for TcpMechanism {
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Try to connect to the compute node, retrying if necessary.
|
|
||||||
#[tracing::instrument(skip_all)]
|
|
||||||
pub(crate) async fn connect_to_compute<M: ConnectMechanism, B: ComputeConnectBackend>(
|
|
||||||
ctx: &RequestContext,
|
|
||||||
mechanism: &M,
|
|
||||||
user_info: &B,
|
|
||||||
wake_compute_retry_config: RetryConfig,
|
|
||||||
compute: &ComputeConfig,
|
|
||||||
) -> Result<M::Connection, M::Error>
|
|
||||||
where
|
|
||||||
M::ConnectError: CouldRetry + ShouldRetryWakeCompute + std::fmt::Debug,
|
|
||||||
M::Error: From<WakeComputeError>,
|
|
||||||
{
|
|
||||||
let mut num_retries = 0;
|
|
||||||
let node_info =
|
|
||||||
wake_compute(&mut num_retries, ctx, user_info, wake_compute_retry_config).await?;
|
|
||||||
|
|
||||||
// try once
|
|
||||||
let err = match mechanism.connect_once(ctx, &node_info, compute).await {
|
|
||||||
Ok(res) => {
|
|
||||||
ctx.success();
|
|
||||||
Metrics::get().proxy.retries_metric.observe(
|
|
||||||
RetriesMetricGroup {
|
|
||||||
outcome: ConnectOutcome::Success,
|
|
||||||
retry_type: RetryType::ConnectToCompute,
|
|
||||||
},
|
|
||||||
num_retries.into(),
|
|
||||||
);
|
|
||||||
return Ok(res);
|
|
||||||
}
|
|
||||||
Err(e) => e,
|
|
||||||
};
|
|
||||||
|
|
||||||
debug!(error = ?err, COULD_NOT_CONNECT);
|
|
||||||
|
|
||||||
let node_info = if !node_info.cached() || !err.should_retry_wake_compute() {
|
|
||||||
// If we just recieved this from cplane and didn't get it from cache, we shouldn't retry.
|
|
||||||
// Do not need to retrieve a new node_info, just return the old one.
|
|
||||||
if should_retry(&err, num_retries, compute.retry) {
|
|
||||||
Metrics::get().proxy.retries_metric.observe(
|
|
||||||
RetriesMetricGroup {
|
|
||||||
outcome: ConnectOutcome::Failed,
|
|
||||||
retry_type: RetryType::ConnectToCompute,
|
|
||||||
},
|
|
||||||
num_retries.into(),
|
|
||||||
);
|
|
||||||
return Err(err.into());
|
|
||||||
}
|
|
||||||
node_info
|
|
||||||
} else {
|
|
||||||
// if we failed to connect, it's likely that the compute node was suspended, wake a new compute node
|
|
||||||
debug!("compute node's state has likely changed; requesting a wake-up");
|
|
||||||
invalidate_cache(node_info);
|
|
||||||
// TODO: increment num_retries?
|
|
||||||
wake_compute(&mut num_retries, ctx, user_info, wake_compute_retry_config).await?
|
|
||||||
};
|
|
||||||
|
|
||||||
// now that we have a new node, try connect to it repeatedly.
|
|
||||||
// this can error for a few reasons, for instance:
|
|
||||||
// * DNS connection settings haven't quite propagated yet
|
|
||||||
debug!("wake_compute success. attempting to connect");
|
|
||||||
num_retries = 1;
|
|
||||||
loop {
|
|
||||||
match mechanism.connect_once(ctx, &node_info, compute).await {
|
|
||||||
Ok(res) => {
|
|
||||||
ctx.success();
|
|
||||||
Metrics::get().proxy.retries_metric.observe(
|
|
||||||
RetriesMetricGroup {
|
|
||||||
outcome: ConnectOutcome::Success,
|
|
||||||
retry_type: RetryType::ConnectToCompute,
|
|
||||||
},
|
|
||||||
num_retries.into(),
|
|
||||||
);
|
|
||||||
// TODO: is this necessary? We have a metric.
|
|
||||||
info!(?num_retries, "connected to compute node after");
|
|
||||||
return Ok(res);
|
|
||||||
}
|
|
||||||
Err(e) => {
|
|
||||||
if !should_retry(&e, num_retries, compute.retry) {
|
|
||||||
// Don't log an error here, caller will print the error
|
|
||||||
Metrics::get().proxy.retries_metric.observe(
|
|
||||||
RetriesMetricGroup {
|
|
||||||
outcome: ConnectOutcome::Failed,
|
|
||||||
retry_type: RetryType::ConnectToCompute,
|
|
||||||
},
|
|
||||||
num_retries.into(),
|
|
||||||
);
|
|
||||||
return Err(e.into());
|
|
||||||
}
|
|
||||||
|
|
||||||
warn!(error = ?e, num_retries, retriable = true, COULD_NOT_CONNECT);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
let wait_duration = retry_after(num_retries, compute.retry);
|
|
||||||
num_retries += 1;
|
|
||||||
|
|
||||||
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::RetryTimeout);
|
|
||||||
time::sleep(wait_duration).await;
|
|
||||||
drop(pause);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -3,3 +3,360 @@ pub mod copy_bidirectional;
|
|||||||
pub mod handshake;
|
pub mod handshake;
|
||||||
pub mod inprocess;
|
pub mod inprocess;
|
||||||
pub mod passthrough;
|
pub mod passthrough;
|
||||||
|
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use futures::FutureExt;
|
||||||
|
use itertools::Itertools;
|
||||||
|
use once_cell::sync::OnceCell;
|
||||||
|
use regex::Regex;
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use smol_str::{SmolStr, ToSmolStr, format_smolstr};
|
||||||
|
use thiserror::Error;
|
||||||
|
use tokio::io::{AsyncRead, AsyncWrite};
|
||||||
|
use tokio::time;
|
||||||
|
use tokio_util::sync::CancellationToken;
|
||||||
|
use tracing::{Instrument, debug, error, info, warn};
|
||||||
|
|
||||||
|
use crate::cancellation::{self, CancellationHandler};
|
||||||
|
use crate::compute::COULD_NOT_CONNECT;
|
||||||
|
use crate::config::{ComputeConfig, ProxyConfig, ProxyProtocolV2, RetryConfig, TlsConfig};
|
||||||
|
use crate::context::RequestContext;
|
||||||
|
use crate::control_plane::NodeInfo;
|
||||||
|
use crate::control_plane::errors::WakeComputeError;
|
||||||
|
use crate::error::{ReportableError, UserFacingError};
|
||||||
|
use crate::metrics::{
|
||||||
|
ConnectOutcome, ConnectionFailureKind, Metrics, NumClientConnectionsGuard, RetriesMetricGroup,
|
||||||
|
RetryType,
|
||||||
|
};
|
||||||
|
use crate::pglb::connect_compute::{ComputeConnectBackend, ConnectMechanism, TcpMechanism};
|
||||||
|
pub use crate::pglb::copy_bidirectional::{ErrorSource, copy_bidirectional_client_compute};
|
||||||
|
use crate::pglb::handshake::{HandshakeData, HandshakeError, handshake};
|
||||||
|
use crate::pglb::passthrough::ProxyPassthrough;
|
||||||
|
use crate::pqproto::{BeMessage, CancelKeyData, StartupMessageParams};
|
||||||
|
use crate::protocol2::{ConnectHeader, ConnectionInfo, ConnectionInfoExtra, read_proxy_protocol};
|
||||||
|
use crate::proxy::handle_connect_request;
|
||||||
|
use crate::proxy::retry::{CouldRetry, ShouldRetryWakeCompute, retry_after, should_retry};
|
||||||
|
use crate::proxy::wake_compute::wake_compute;
|
||||||
|
use crate::rate_limiter::EndpointRateLimiter;
|
||||||
|
use crate::stream::{PqStream, Stream};
|
||||||
|
use crate::types::EndpointCacheKey;
|
||||||
|
use crate::util::run_until_cancelled;
|
||||||
|
use crate::{auth, compute, control_plane};
|
||||||
|
|
||||||
|
const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmode=require`)";
|
||||||
|
|
||||||
|
#[derive(Error, Debug)]
|
||||||
|
#[error("{ERR_INSECURE_CONNECTION}")]
|
||||||
|
pub struct TlsRequired;
|
||||||
|
|
||||||
|
impl ReportableError for TlsRequired {
|
||||||
|
fn get_error_kind(&self) -> crate::error::ErrorKind {
|
||||||
|
crate::error::ErrorKind::User
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl UserFacingError for TlsRequired {}
|
||||||
|
|
||||||
|
pub async fn task_main(
|
||||||
|
config: &'static ProxyConfig,
|
||||||
|
auth_backend: &'static auth::Backend<'static, ()>,
|
||||||
|
listener: tokio::net::TcpListener,
|
||||||
|
cancellation_token: CancellationToken,
|
||||||
|
cancellation_handler: Arc<CancellationHandler>,
|
||||||
|
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||||
|
) -> anyhow::Result<()> {
|
||||||
|
scopeguard::defer! {
|
||||||
|
info!("proxy has shut down");
|
||||||
|
}
|
||||||
|
|
||||||
|
// When set for the server socket, the keepalive setting
|
||||||
|
// will be inherited by all accepted client sockets.
|
||||||
|
socket2::SockRef::from(&listener).set_keepalive(true)?;
|
||||||
|
|
||||||
|
let connections = tokio_util::task::task_tracker::TaskTracker::new();
|
||||||
|
let cancellations = tokio_util::task::task_tracker::TaskTracker::new();
|
||||||
|
|
||||||
|
while let Some(accept_result) =
|
||||||
|
run_until_cancelled(listener.accept(), &cancellation_token).await
|
||||||
|
{
|
||||||
|
let (socket, peer_addr) = accept_result?;
|
||||||
|
|
||||||
|
let conn_gauge = Metrics::get()
|
||||||
|
.proxy
|
||||||
|
.client_connections
|
||||||
|
.guard(crate::metrics::Protocol::Tcp);
|
||||||
|
|
||||||
|
let session_id = uuid::Uuid::new_v4();
|
||||||
|
let cancellation_handler = Arc::clone(&cancellation_handler);
|
||||||
|
let cancellations = cancellations.clone();
|
||||||
|
|
||||||
|
debug!(protocol = "tcp", %session_id, "accepted new TCP connection");
|
||||||
|
let endpoint_rate_limiter2 = endpoint_rate_limiter.clone();
|
||||||
|
|
||||||
|
connections.spawn(async move {
|
||||||
|
let (socket, conn_info) = match config.proxy_protocol_v2 {
|
||||||
|
ProxyProtocolV2::Required => {
|
||||||
|
match read_proxy_protocol(socket).await {
|
||||||
|
Err(e) => {
|
||||||
|
warn!("per-client task finished with an error: {e:#}");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
// our load balancers will not send any more data. let's just exit immediately
|
||||||
|
Ok((_socket, ConnectHeader::Local)) => {
|
||||||
|
debug!("healthcheck received");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
Ok((socket, ConnectHeader::Proxy(info))) => (socket, info),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// ignore the header - it cannot be confused for a postgres or http connection so will
|
||||||
|
// error later.
|
||||||
|
ProxyProtocolV2::Rejected => (
|
||||||
|
socket,
|
||||||
|
ConnectionInfo {
|
||||||
|
addr: peer_addr,
|
||||||
|
extra: None,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
};
|
||||||
|
|
||||||
|
match socket.set_nodelay(true) {
|
||||||
|
Ok(()) => {}
|
||||||
|
Err(e) => {
|
||||||
|
error!(
|
||||||
|
"per-client task finished with an error: failed to set socket option: {e:#}"
|
||||||
|
);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let ctx = RequestContext::new(
|
||||||
|
session_id,
|
||||||
|
conn_info,
|
||||||
|
crate::metrics::Protocol::Tcp,
|
||||||
|
&config.region,
|
||||||
|
);
|
||||||
|
|
||||||
|
let res = handle_client(
|
||||||
|
config,
|
||||||
|
auth_backend,
|
||||||
|
&ctx,
|
||||||
|
cancellation_handler,
|
||||||
|
socket,
|
||||||
|
ClientMode::Tcp,
|
||||||
|
endpoint_rate_limiter2,
|
||||||
|
conn_gauge,
|
||||||
|
cancellations,
|
||||||
|
)
|
||||||
|
.instrument(ctx.span())
|
||||||
|
.boxed()
|
||||||
|
.await;
|
||||||
|
|
||||||
|
match res {
|
||||||
|
Err(e) => {
|
||||||
|
ctx.set_error_kind(e.get_error_kind());
|
||||||
|
warn!(parent: &ctx.span(), "per-client task finished with an error: {e:#}");
|
||||||
|
}
|
||||||
|
Ok(None) => {
|
||||||
|
ctx.set_success();
|
||||||
|
}
|
||||||
|
Ok(Some(p)) => {
|
||||||
|
ctx.set_success();
|
||||||
|
let _disconnect = ctx.log_connect();
|
||||||
|
match p.proxy_pass(&config.connect_to_compute).await {
|
||||||
|
Ok(()) => {}
|
||||||
|
Err(ErrorSource::Client(e)) => {
|
||||||
|
warn!(
|
||||||
|
?session_id,
|
||||||
|
"per-client task finished with an IO error from the client: {e:#}"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
Err(ErrorSource::Compute(e)) => {
|
||||||
|
error!(
|
||||||
|
?session_id,
|
||||||
|
"per-client task finished with an IO error from the compute: {e:#}"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
connections.close();
|
||||||
|
cancellations.close();
|
||||||
|
drop(listener);
|
||||||
|
|
||||||
|
// Drain connections
|
||||||
|
connections.wait().await;
|
||||||
|
cancellations.wait().await;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) enum ClientMode {
|
||||||
|
Tcp,
|
||||||
|
Websockets { hostname: Option<String> },
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Abstracts the logic of handling TCP vs WS clients
|
||||||
|
impl ClientMode {
|
||||||
|
pub(crate) fn allow_cleartext(&self) -> bool {
|
||||||
|
match self {
|
||||||
|
ClientMode::Tcp => false,
|
||||||
|
ClientMode::Websockets { .. } => true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn hostname<'a, S>(&'a self, s: &'a Stream<S>) -> Option<&'a str> {
|
||||||
|
match self {
|
||||||
|
ClientMode::Tcp => s.sni_hostname(),
|
||||||
|
ClientMode::Websockets { hostname } => hostname.as_deref(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn handshake_tls<'a>(&self, tls: Option<&'a TlsConfig>) -> Option<&'a TlsConfig> {
|
||||||
|
match self {
|
||||||
|
ClientMode::Tcp => tls,
|
||||||
|
// TLS is None here if using websockets, because the connection is already encrypted.
|
||||||
|
ClientMode::Websockets { .. } => None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Error)]
|
||||||
|
// almost all errors should be reported to the user, but there's a few cases where we cannot
|
||||||
|
// 1. Cancellation: we are not allowed to tell the client any cancellation statuses for security reasons
|
||||||
|
// 2. Handshake: handshake reports errors if it can, otherwise if the handshake fails due to protocol violation,
|
||||||
|
// we cannot be sure the client even understands our error message
|
||||||
|
// 3. PrepareClient: The client disconnected, so we can't tell them anyway...
|
||||||
|
pub(crate) enum ClientRequestError {
|
||||||
|
#[error("{0}")]
|
||||||
|
Cancellation(#[from] cancellation::CancelError),
|
||||||
|
#[error("{0}")]
|
||||||
|
Handshake(#[from] HandshakeError),
|
||||||
|
#[error("{0}")]
|
||||||
|
HandshakeTimeout(#[from] tokio::time::error::Elapsed),
|
||||||
|
#[error("{0}")]
|
||||||
|
PrepareClient(#[from] std::io::Error),
|
||||||
|
#[error("{0}")]
|
||||||
|
ReportedError(#[from] crate::stream::ReportedError),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ReportableError for ClientRequestError {
|
||||||
|
fn get_error_kind(&self) -> crate::error::ErrorKind {
|
||||||
|
match self {
|
||||||
|
ClientRequestError::Cancellation(e) => e.get_error_kind(),
|
||||||
|
ClientRequestError::Handshake(e) => e.get_error_kind(),
|
||||||
|
ClientRequestError::HandshakeTimeout(_) => crate::error::ErrorKind::RateLimit,
|
||||||
|
ClientRequestError::ReportedError(e) => e.get_error_kind(),
|
||||||
|
ClientRequestError::PrepareClient(_) => crate::error::ErrorKind::ClientDisconnect,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
|
pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
|
||||||
|
config: &'static ProxyConfig,
|
||||||
|
auth_backend: &'static auth::Backend<'static, ()>,
|
||||||
|
ctx: &RequestContext,
|
||||||
|
cancellation_handler: Arc<CancellationHandler>,
|
||||||
|
client: S,
|
||||||
|
mode: ClientMode,
|
||||||
|
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||||
|
conn_gauge: NumClientConnectionsGuard<'static>,
|
||||||
|
cancellations: tokio_util::task::task_tracker::TaskTracker,
|
||||||
|
) -> Result<Option<ProxyPassthrough<S>>, ClientRequestError> {
|
||||||
|
debug!(
|
||||||
|
protocol = %ctx.protocol(),
|
||||||
|
"handling interactive connection from client"
|
||||||
|
);
|
||||||
|
|
||||||
|
let metrics = &Metrics::get().proxy;
|
||||||
|
let proto = ctx.protocol();
|
||||||
|
let request_gauge = metrics.connection_requests.guard(proto);
|
||||||
|
|
||||||
|
let tls = config.tls_config.load();
|
||||||
|
let tls = tls.as_deref();
|
||||||
|
|
||||||
|
let record_handshake_error = !ctx.has_private_peer_addr();
|
||||||
|
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
|
||||||
|
let do_handshake = handshake(ctx, client, mode.handshake_tls(tls), record_handshake_error);
|
||||||
|
|
||||||
|
let (mut client, params) = match tokio::time::timeout(config.handshake_timeout, do_handshake)
|
||||||
|
.await??
|
||||||
|
{
|
||||||
|
HandshakeData::Startup(client, params) => (client, params),
|
||||||
|
HandshakeData::Cancel(cancel_key_data) => {
|
||||||
|
// spawn a task to cancel the session, but don't wait for it
|
||||||
|
cancellations.spawn({
|
||||||
|
let cancellation_handler_clone = Arc::clone(&cancellation_handler);
|
||||||
|
let ctx = ctx.clone();
|
||||||
|
let cancel_span = tracing::span!(parent: None, tracing::Level::INFO, "cancel_session", session_id = ?ctx.session_id());
|
||||||
|
cancel_span.follows_from(tracing::Span::current());
|
||||||
|
async move {
|
||||||
|
cancellation_handler_clone
|
||||||
|
.cancel_session(
|
||||||
|
cancel_key_data,
|
||||||
|
ctx,
|
||||||
|
config.authentication_config.ip_allowlist_check_enabled,
|
||||||
|
config.authentication_config.is_vpc_acccess_proxy,
|
||||||
|
auth_backend.get_api(),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.inspect_err(|e | debug!(error = ?e, "cancel_session failed")).ok();
|
||||||
|
}.instrument(cancel_span)
|
||||||
|
});
|
||||||
|
|
||||||
|
return Ok(None);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
drop(pause);
|
||||||
|
|
||||||
|
ctx.set_db_options(params.clone());
|
||||||
|
|
||||||
|
let hostname = mode.hostname(client.get_ref());
|
||||||
|
|
||||||
|
let common_names = tls.map(|tls| &tls.common_names);
|
||||||
|
|
||||||
|
let private_link_id = match ctx.extra() {
|
||||||
|
Some(ConnectionInfoExtra::Aws { vpce_id }) => Some(vpce_id.clone()),
|
||||||
|
Some(ConnectionInfoExtra::Azure { link_id }) => Some(link_id.to_smolstr()),
|
||||||
|
None => None,
|
||||||
|
};
|
||||||
|
|
||||||
|
let client = handle_connect_request(
|
||||||
|
config,
|
||||||
|
auth_backend,
|
||||||
|
ctx,
|
||||||
|
cancellation_handler,
|
||||||
|
client,
|
||||||
|
mode,
|
||||||
|
endpoint_rate_limiter,
|
||||||
|
¶ms,
|
||||||
|
hostname,
|
||||||
|
common_names,
|
||||||
|
async |ctx, node_info, auth_info, creds, compute_config| {
|
||||||
|
let mech = &TcpMechanism {
|
||||||
|
user_info: creds.info.clone(),
|
||||||
|
auth: auth_info.clone(),
|
||||||
|
locks: &config.connect_compute_locks,
|
||||||
|
};
|
||||||
|
|
||||||
|
mech.connect_once(ctx, node_info, compute_config).await
|
||||||
|
},
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
Ok(Some(ProxyPassthrough {
|
||||||
|
client,
|
||||||
|
aux: node.aux.clone(),
|
||||||
|
private_link_id,
|
||||||
|
compute: node,
|
||||||
|
session_id: ctx.session_id(),
|
||||||
|
cancel: session,
|
||||||
|
_req: request_gauge,
|
||||||
|
_conn: conn_gauge,
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|||||||
@@ -3,34 +3,41 @@ mod tests;
|
|||||||
|
|
||||||
pub(crate) mod retry;
|
pub(crate) mod retry;
|
||||||
pub(crate) mod wake_compute;
|
pub(crate) mod wake_compute;
|
||||||
|
|
||||||
|
use std::collections::HashSet;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
use futures::FutureExt;
|
|
||||||
use itertools::Itertools;
|
use itertools::Itertools;
|
||||||
use once_cell::sync::OnceCell;
|
use once_cell::sync::OnceCell;
|
||||||
use regex::Regex;
|
use regex::Regex;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use smol_str::{SmolStr, ToSmolStr, format_smolstr};
|
use smol_str::{SmolStr, format_smolstr};
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
use tokio::io::{AsyncRead, AsyncWrite};
|
use tokio::io::{AsyncRead, AsyncWrite};
|
||||||
use tokio_util::sync::CancellationToken;
|
use tokio::time;
|
||||||
use tracing::{Instrument, debug, error, info, warn};
|
use tracing::{Instrument, debug, error, info, warn};
|
||||||
|
|
||||||
use crate::cancellation::{self, CancellationHandler};
|
use crate::auth::backend::ComputeCredentials;
|
||||||
use crate::config::{ProxyConfig, ProxyProtocolV2, TlsConfig};
|
use crate::cancellation::CancellationHandler;
|
||||||
|
use crate::compute::{AuthInfo, COULD_NOT_CONNECT};
|
||||||
|
use crate::config::{ComputeConfig, ProxyConfig, RetryConfig};
|
||||||
use crate::context::RequestContext;
|
use crate::context::RequestContext;
|
||||||
|
use crate::control_plane::errors::WakeComputeError;
|
||||||
|
use crate::control_plane::{CachedNodeInfo, NodeInfo};
|
||||||
use crate::error::{ReportableError, UserFacingError};
|
use crate::error::{ReportableError, UserFacingError};
|
||||||
use crate::metrics::{Metrics, NumClientConnectionsGuard};
|
use crate::metrics::{
|
||||||
use crate::pglb::connect_compute::{TcpMechanism, connect_to_compute};
|
ConnectOutcome, ConnectionFailureKind, Metrics, RetriesMetricGroup, RetryType,
|
||||||
|
};
|
||||||
|
use crate::pglb::connect_compute::{ComputeConnectBackend, ConnectMechanism};
|
||||||
pub use crate::pglb::copy_bidirectional::{ErrorSource, copy_bidirectional_client_compute};
|
pub use crate::pglb::copy_bidirectional::{ErrorSource, copy_bidirectional_client_compute};
|
||||||
use crate::pglb::handshake::{HandshakeData, HandshakeError, handshake};
|
use crate::pglb::{ClientMode, ClientRequestError};
|
||||||
use crate::pglb::passthrough::ProxyPassthrough;
|
|
||||||
use crate::pqproto::{BeMessage, CancelKeyData, StartupMessageParams};
|
use crate::pqproto::{BeMessage, CancelKeyData, StartupMessageParams};
|
||||||
use crate::protocol2::{ConnectHeader, ConnectionInfo, ConnectionInfoExtra, read_proxy_protocol};
|
use crate::proxy::retry::{CouldRetry, ShouldRetryWakeCompute, retry_after, should_retry};
|
||||||
|
use crate::proxy::wake_compute::wake_compute;
|
||||||
use crate::rate_limiter::EndpointRateLimiter;
|
use crate::rate_limiter::EndpointRateLimiter;
|
||||||
use crate::stream::{PqStream, Stream};
|
use crate::stream::{PqStream, Stream};
|
||||||
use crate::types::EndpointCacheKey;
|
use crate::types::EndpointCacheKey;
|
||||||
use crate::{auth, compute};
|
use crate::{auth, compute, control_plane};
|
||||||
|
|
||||||
const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmode=require`)";
|
const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmode=require`)";
|
||||||
|
|
||||||
@@ -46,283 +53,29 @@ impl ReportableError for TlsRequired {
|
|||||||
|
|
||||||
impl UserFacingError for TlsRequired {}
|
impl UserFacingError for TlsRequired {}
|
||||||
|
|
||||||
pub async fn run_until_cancelled<F: std::future::Future>(
|
|
||||||
f: F,
|
|
||||||
cancellation_token: &CancellationToken,
|
|
||||||
) -> Option<F::Output> {
|
|
||||||
match futures::future::select(
|
|
||||||
std::pin::pin!(f),
|
|
||||||
std::pin::pin!(cancellation_token.cancelled()),
|
|
||||||
)
|
|
||||||
.await
|
|
||||||
{
|
|
||||||
futures::future::Either::Left((f, _)) => Some(f),
|
|
||||||
futures::future::Either::Right(((), _)) => None,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn task_main(
|
|
||||||
config: &'static ProxyConfig,
|
|
||||||
auth_backend: &'static auth::Backend<'static, ()>,
|
|
||||||
listener: tokio::net::TcpListener,
|
|
||||||
cancellation_token: CancellationToken,
|
|
||||||
cancellation_handler: Arc<CancellationHandler>,
|
|
||||||
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
|
||||||
) -> anyhow::Result<()> {
|
|
||||||
scopeguard::defer! {
|
|
||||||
info!("proxy has shut down");
|
|
||||||
}
|
|
||||||
|
|
||||||
// When set for the server socket, the keepalive setting
|
|
||||||
// will be inherited by all accepted client sockets.
|
|
||||||
socket2::SockRef::from(&listener).set_keepalive(true)?;
|
|
||||||
|
|
||||||
let connections = tokio_util::task::task_tracker::TaskTracker::new();
|
|
||||||
let cancellations = tokio_util::task::task_tracker::TaskTracker::new();
|
|
||||||
|
|
||||||
while let Some(accept_result) =
|
|
||||||
run_until_cancelled(listener.accept(), &cancellation_token).await
|
|
||||||
{
|
|
||||||
let (socket, peer_addr) = accept_result?;
|
|
||||||
|
|
||||||
let conn_gauge = Metrics::get()
|
|
||||||
.proxy
|
|
||||||
.client_connections
|
|
||||||
.guard(crate::metrics::Protocol::Tcp);
|
|
||||||
|
|
||||||
let session_id = uuid::Uuid::new_v4();
|
|
||||||
let cancellation_handler = Arc::clone(&cancellation_handler);
|
|
||||||
let cancellations = cancellations.clone();
|
|
||||||
|
|
||||||
debug!(protocol = "tcp", %session_id, "accepted new TCP connection");
|
|
||||||
let endpoint_rate_limiter2 = endpoint_rate_limiter.clone();
|
|
||||||
|
|
||||||
connections.spawn(async move {
|
|
||||||
let (socket, conn_info) = match config.proxy_protocol_v2 {
|
|
||||||
ProxyProtocolV2::Required => {
|
|
||||||
match read_proxy_protocol(socket).await {
|
|
||||||
Err(e) => {
|
|
||||||
warn!("per-client task finished with an error: {e:#}");
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
// our load balancers will not send any more data. let's just exit immediately
|
|
||||||
Ok((_socket, ConnectHeader::Local)) => {
|
|
||||||
debug!("healthcheck received");
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
Ok((socket, ConnectHeader::Proxy(info))) => (socket, info),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// ignore the header - it cannot be confused for a postgres or http connection so will
|
|
||||||
// error later.
|
|
||||||
ProxyProtocolV2::Rejected => (
|
|
||||||
socket,
|
|
||||||
ConnectionInfo {
|
|
||||||
addr: peer_addr,
|
|
||||||
extra: None,
|
|
||||||
},
|
|
||||||
),
|
|
||||||
};
|
|
||||||
|
|
||||||
match socket.set_nodelay(true) {
|
|
||||||
Ok(()) => {}
|
|
||||||
Err(e) => {
|
|
||||||
error!(
|
|
||||||
"per-client task finished with an error: failed to set socket option: {e:#}"
|
|
||||||
);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
let ctx = RequestContext::new(
|
|
||||||
session_id,
|
|
||||||
conn_info,
|
|
||||||
crate::metrics::Protocol::Tcp,
|
|
||||||
&config.region,
|
|
||||||
);
|
|
||||||
|
|
||||||
let res = handle_client(
|
|
||||||
config,
|
|
||||||
auth_backend,
|
|
||||||
&ctx,
|
|
||||||
cancellation_handler,
|
|
||||||
socket,
|
|
||||||
ClientMode::Tcp,
|
|
||||||
endpoint_rate_limiter2,
|
|
||||||
conn_gauge,
|
|
||||||
cancellations,
|
|
||||||
)
|
|
||||||
.instrument(ctx.span())
|
|
||||||
.boxed()
|
|
||||||
.await;
|
|
||||||
|
|
||||||
match res {
|
|
||||||
Err(e) => {
|
|
||||||
ctx.set_error_kind(e.get_error_kind());
|
|
||||||
warn!(parent: &ctx.span(), "per-client task finished with an error: {e:#}");
|
|
||||||
}
|
|
||||||
Ok(None) => {
|
|
||||||
ctx.set_success();
|
|
||||||
}
|
|
||||||
Ok(Some(p)) => {
|
|
||||||
ctx.set_success();
|
|
||||||
let _disconnect = ctx.log_connect();
|
|
||||||
match p.proxy_pass(&config.connect_to_compute).await {
|
|
||||||
Ok(()) => {}
|
|
||||||
Err(ErrorSource::Client(e)) => {
|
|
||||||
warn!(
|
|
||||||
?session_id,
|
|
||||||
"per-client task finished with an IO error from the client: {e:#}"
|
|
||||||
);
|
|
||||||
}
|
|
||||||
Err(ErrorSource::Compute(e)) => {
|
|
||||||
error!(
|
|
||||||
?session_id,
|
|
||||||
"per-client task finished with an IO error from the compute: {e:#}"
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
connections.close();
|
|
||||||
cancellations.close();
|
|
||||||
drop(listener);
|
|
||||||
|
|
||||||
// Drain connections
|
|
||||||
connections.wait().await;
|
|
||||||
cancellations.wait().await;
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(crate) enum ClientMode {
|
|
||||||
Tcp,
|
|
||||||
Websockets { hostname: Option<String> },
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Abstracts the logic of handling TCP vs WS clients
|
|
||||||
impl ClientMode {
|
|
||||||
pub(crate) fn allow_cleartext(&self) -> bool {
|
|
||||||
match self {
|
|
||||||
ClientMode::Tcp => false,
|
|
||||||
ClientMode::Websockets { .. } => true,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn hostname<'a, S>(&'a self, s: &'a Stream<S>) -> Option<&'a str> {
|
|
||||||
match self {
|
|
||||||
ClientMode::Tcp => s.sni_hostname(),
|
|
||||||
ClientMode::Websockets { hostname } => hostname.as_deref(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn handshake_tls<'a>(&self, tls: Option<&'a TlsConfig>) -> Option<&'a TlsConfig> {
|
|
||||||
match self {
|
|
||||||
ClientMode::Tcp => tls,
|
|
||||||
// TLS is None here if using websockets, because the connection is already encrypted.
|
|
||||||
ClientMode::Websockets { .. } => None,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Error)]
|
|
||||||
// almost all errors should be reported to the user, but there's a few cases where we cannot
|
|
||||||
// 1. Cancellation: we are not allowed to tell the client any cancellation statuses for security reasons
|
|
||||||
// 2. Handshake: handshake reports errors if it can, otherwise if the handshake fails due to protocol violation,
|
|
||||||
// we cannot be sure the client even understands our error message
|
|
||||||
// 3. PrepareClient: The client disconnected, so we can't tell them anyway...
|
|
||||||
pub(crate) enum ClientRequestError {
|
|
||||||
#[error("{0}")]
|
|
||||||
Cancellation(#[from] cancellation::CancelError),
|
|
||||||
#[error("{0}")]
|
|
||||||
Handshake(#[from] HandshakeError),
|
|
||||||
#[error("{0}")]
|
|
||||||
HandshakeTimeout(#[from] tokio::time::error::Elapsed),
|
|
||||||
#[error("{0}")]
|
|
||||||
PrepareClient(#[from] std::io::Error),
|
|
||||||
#[error("{0}")]
|
|
||||||
ReportedError(#[from] crate::stream::ReportedError),
|
|
||||||
}
|
|
||||||
|
|
||||||
impl ReportableError for ClientRequestError {
|
|
||||||
fn get_error_kind(&self) -> crate::error::ErrorKind {
|
|
||||||
match self {
|
|
||||||
ClientRequestError::Cancellation(e) => e.get_error_kind(),
|
|
||||||
ClientRequestError::Handshake(e) => e.get_error_kind(),
|
|
||||||
ClientRequestError::HandshakeTimeout(_) => crate::error::ErrorKind::RateLimit,
|
|
||||||
ClientRequestError::ReportedError(e) => e.get_error_kind(),
|
|
||||||
ClientRequestError::PrepareClient(_) => crate::error::ErrorKind::ClientDisconnect,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[allow(clippy::too_many_arguments)]
|
#[allow(clippy::too_many_arguments)]
|
||||||
pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
|
pub(crate) async fn handle_connect_request<
|
||||||
|
S: AsyncRead + AsyncWrite + Unpin + Send,
|
||||||
|
C: AsyncFnMut(
|
||||||
|
&RequestContext,
|
||||||
|
&CachedNodeInfo,
|
||||||
|
&AuthInfo,
|
||||||
|
&ComputeCredentials,
|
||||||
|
&ComputeConfig,
|
||||||
|
) -> Result<compute::PostgresConnection, compute::ConnectionError>,
|
||||||
|
>(
|
||||||
config: &'static ProxyConfig,
|
config: &'static ProxyConfig,
|
||||||
auth_backend: &'static auth::Backend<'static, ()>,
|
auth_backend: &'static auth::Backend<'static, ()>,
|
||||||
ctx: &RequestContext,
|
ctx: &RequestContext,
|
||||||
cancellation_handler: Arc<CancellationHandler>,
|
cancellation_handler: Arc<CancellationHandler>,
|
||||||
stream: S,
|
mut client: PqStream<Stream<S>>,
|
||||||
mode: ClientMode,
|
mode: ClientMode,
|
||||||
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||||
conn_gauge: NumClientConnectionsGuard<'static>,
|
params: &StartupMessageParams,
|
||||||
cancellations: tokio_util::task::task_tracker::TaskTracker,
|
hostname: Option<&str>,
|
||||||
) -> Result<Option<ProxyPassthrough<S>>, ClientRequestError> {
|
common_names: Option<&HashSet<String>>,
|
||||||
debug!(
|
connect_compute_fn: C,
|
||||||
protocol = %ctx.protocol(),
|
) -> Result<Stream<S>, ClientRequestError> {
|
||||||
"handling interactive connection from client"
|
|
||||||
);
|
|
||||||
|
|
||||||
let metrics = &Metrics::get().proxy;
|
|
||||||
let proto = ctx.protocol();
|
|
||||||
let request_gauge = metrics.connection_requests.guard(proto);
|
|
||||||
|
|
||||||
let tls = config.tls_config.load();
|
|
||||||
let tls = tls.as_deref();
|
|
||||||
|
|
||||||
let record_handshake_error = !ctx.has_private_peer_addr();
|
|
||||||
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
|
|
||||||
let do_handshake = handshake(ctx, stream, mode.handshake_tls(tls), record_handshake_error);
|
|
||||||
|
|
||||||
let (mut stream, params) = match tokio::time::timeout(config.handshake_timeout, do_handshake)
|
|
||||||
.await??
|
|
||||||
{
|
|
||||||
HandshakeData::Startup(stream, params) => (stream, params),
|
|
||||||
HandshakeData::Cancel(cancel_key_data) => {
|
|
||||||
// spawn a task to cancel the session, but don't wait for it
|
|
||||||
cancellations.spawn({
|
|
||||||
let cancellation_handler_clone = Arc::clone(&cancellation_handler);
|
|
||||||
let ctx = ctx.clone();
|
|
||||||
let cancel_span = tracing::span!(parent: None, tracing::Level::INFO, "cancel_session", session_id = ?ctx.session_id());
|
|
||||||
cancel_span.follows_from(tracing::Span::current());
|
|
||||||
async move {
|
|
||||||
cancellation_handler_clone
|
|
||||||
.cancel_session(
|
|
||||||
cancel_key_data,
|
|
||||||
ctx,
|
|
||||||
config.authentication_config.ip_allowlist_check_enabled,
|
|
||||||
config.authentication_config.is_vpc_acccess_proxy,
|
|
||||||
auth_backend.get_api(),
|
|
||||||
)
|
|
||||||
.await
|
|
||||||
.inspect_err(|e | debug!(error = ?e, "cancel_session failed")).ok();
|
|
||||||
}.instrument(cancel_span)
|
|
||||||
});
|
|
||||||
|
|
||||||
return Ok(None);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
drop(pause);
|
|
||||||
|
|
||||||
ctx.set_db_options(params.clone());
|
|
||||||
|
|
||||||
let hostname = mode.hostname(stream.get_ref());
|
|
||||||
|
|
||||||
let common_names = tls.map(|tls| &tls.common_names);
|
|
||||||
|
|
||||||
// Extract credentials which we're going to use for auth.
|
// Extract credentials which we're going to use for auth.
|
||||||
let result = auth_backend
|
let result = auth_backend
|
||||||
.as_ref()
|
.as_ref()
|
||||||
@@ -331,14 +84,14 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
|
|||||||
|
|
||||||
let user_info = match result {
|
let user_info = match result {
|
||||||
Ok(user_info) => user_info,
|
Ok(user_info) => user_info,
|
||||||
Err(e) => Err(stream.throw_error(e, Some(ctx)).await)?,
|
Err(e) => Err(client.throw_error(e, Some(ctx)).await)?,
|
||||||
};
|
};
|
||||||
|
|
||||||
let user = user_info.get_user().to_owned();
|
let user = user_info.get_user().to_owned();
|
||||||
let user_info = match user_info
|
let user_info = match user_info
|
||||||
.authenticate(
|
.authenticate(
|
||||||
ctx,
|
ctx,
|
||||||
&mut stream,
|
&mut client,
|
||||||
mode.allow_cleartext(),
|
mode.allow_cleartext(),
|
||||||
&config.authentication_config,
|
&config.authentication_config,
|
||||||
endpoint_rate_limiter,
|
endpoint_rate_limiter,
|
||||||
@@ -351,7 +104,7 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
|
|||||||
let app = params.get("application_name");
|
let app = params.get("application_name");
|
||||||
let params_span = tracing::info_span!("", ?user, ?db, ?app);
|
let params_span = tracing::info_span!("", ?user, ?db, ?app);
|
||||||
|
|
||||||
return Err(stream
|
return Err(client
|
||||||
.throw_error(e, Some(ctx))
|
.throw_error(e, Some(ctx))
|
||||||
.instrument(params_span)
|
.instrument(params_span)
|
||||||
.await)?;
|
.await)?;
|
||||||
@@ -366,14 +119,12 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
|
|||||||
let mut auth_info = compute::AuthInfo::with_auth_keys(&creds.keys);
|
let mut auth_info = compute::AuthInfo::with_auth_keys(&creds.keys);
|
||||||
auth_info.set_startup_params(¶ms, params_compat);
|
auth_info.set_startup_params(¶ms, params_compat);
|
||||||
|
|
||||||
let res = connect_to_compute(
|
let res = connect_to_compute_pglb(
|
||||||
ctx,
|
ctx,
|
||||||
&TcpMechanism {
|
connect_compute_fn,
|
||||||
user_info: creds.info.clone(),
|
|
||||||
auth: auth_info,
|
|
||||||
locks: &config.connect_compute_locks,
|
|
||||||
},
|
|
||||||
&user_info,
|
&user_info,
|
||||||
|
&auth_info,
|
||||||
|
&creds,
|
||||||
config.wake_compute_retry_config,
|
config.wake_compute_retry_config,
|
||||||
&config.connect_to_compute,
|
&config.connect_to_compute,
|
||||||
)
|
)
|
||||||
@@ -381,32 +132,248 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
|
|||||||
|
|
||||||
let node = match res {
|
let node = match res {
|
||||||
Ok(node) => node,
|
Ok(node) => node,
|
||||||
Err(e) => Err(stream.throw_error(e, Some(ctx)).await)?,
|
Err(e) => Err(client.throw_error(e, Some(ctx)).await)?,
|
||||||
};
|
};
|
||||||
|
|
||||||
let cancellation_handler_clone = Arc::clone(&cancellation_handler);
|
let cancellation_handler_clone = Arc::clone(&cancellation_handler);
|
||||||
let session = cancellation_handler_clone.get_key();
|
let session = cancellation_handler_clone.get_key();
|
||||||
|
|
||||||
session.write_cancel_key(node.cancel_closure.clone())?;
|
session.write_cancel_key(node.cancel_closure.clone())?;
|
||||||
prepare_client_connection(&node, *session.key(), &mut stream);
|
prepare_client_connection(&node, *session.key(), &mut client);
|
||||||
let stream = stream.flush_and_into_inner().await?;
|
let client = client.flush_and_into_inner().await?;
|
||||||
|
|
||||||
let private_link_id = match ctx.extra() {
|
Ok(client)
|
||||||
Some(ConnectionInfoExtra::Aws { vpce_id }) => Some(vpce_id.clone()),
|
}
|
||||||
Some(ConnectionInfoExtra::Azure { link_id }) => Some(link_id.to_smolstr()),
|
|
||||||
None => None,
|
/// If we couldn't connect, a cached connection info might be to blame
|
||||||
|
/// (e.g. the compute node's address might've changed at the wrong time).
|
||||||
|
/// Invalidate the cache entry (if any) to prevent subsequent errors.
|
||||||
|
#[tracing::instrument(skip_all)]
|
||||||
|
pub(crate) fn invalidate_cache(node_info: control_plane::CachedNodeInfo) -> NodeInfo {
|
||||||
|
let is_cached = node_info.cached();
|
||||||
|
if is_cached {
|
||||||
|
warn!("invalidating stalled compute node info cache entry");
|
||||||
|
}
|
||||||
|
let label = if is_cached {
|
||||||
|
ConnectionFailureKind::ComputeCached
|
||||||
|
} else {
|
||||||
|
ConnectionFailureKind::ComputeUncached
|
||||||
|
};
|
||||||
|
Metrics::get().proxy.connection_failures_total.inc(label);
|
||||||
|
|
||||||
|
node_info.invalidate()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tracing::instrument(skip_all)]
|
||||||
|
pub(crate) async fn connect_to_compute_pglb<
|
||||||
|
C: AsyncFnMut(
|
||||||
|
&RequestContext,
|
||||||
|
&CachedNodeInfo,
|
||||||
|
&AuthInfo,
|
||||||
|
&ComputeCredentials,
|
||||||
|
&ComputeConfig,
|
||||||
|
) -> Result<compute::PostgresConnection, compute::ConnectionError>,
|
||||||
|
B: ComputeConnectBackend,
|
||||||
|
>(
|
||||||
|
ctx: &RequestContext,
|
||||||
|
mut connect_compute_fn: C,
|
||||||
|
user_info: &B,
|
||||||
|
auth_info: &AuthInfo,
|
||||||
|
creds: &ComputeCredentials,
|
||||||
|
wake_compute_retry_config: RetryConfig,
|
||||||
|
compute: &ComputeConfig,
|
||||||
|
) -> Result<compute::PostgresConnection, compute::ConnectionError> {
|
||||||
|
let mut num_retries = 0;
|
||||||
|
let node_info =
|
||||||
|
wake_compute(&mut num_retries, ctx, user_info, wake_compute_retry_config).await?;
|
||||||
|
|
||||||
|
// try once
|
||||||
|
let err = match connect_compute_fn(ctx, &node_info, &auth_info, creds, compute).await {
|
||||||
|
Ok(res) => {
|
||||||
|
ctx.success();
|
||||||
|
Metrics::get().proxy.retries_metric.observe(
|
||||||
|
RetriesMetricGroup {
|
||||||
|
outcome: ConnectOutcome::Success,
|
||||||
|
retry_type: RetryType::ConnectToCompute,
|
||||||
|
},
|
||||||
|
num_retries.into(),
|
||||||
|
);
|
||||||
|
return Ok(res);
|
||||||
|
}
|
||||||
|
Err(e) => e,
|
||||||
};
|
};
|
||||||
|
|
||||||
Ok(Some(ProxyPassthrough {
|
debug!(error = ?err, COULD_NOT_CONNECT);
|
||||||
client: stream,
|
|
||||||
aux: node.aux.clone(),
|
let node_info = if !node_info.cached() || !err.should_retry_wake_compute() {
|
||||||
private_link_id,
|
// If we just recieved this from cplane and didn't get it from cache, we shouldn't retry.
|
||||||
compute: node,
|
// Do not need to retrieve a new node_info, just return the old one.
|
||||||
session_id: ctx.session_id(),
|
if should_retry(&err, num_retries, compute.retry) {
|
||||||
cancel: session,
|
Metrics::get().proxy.retries_metric.observe(
|
||||||
_req: request_gauge,
|
RetriesMetricGroup {
|
||||||
_conn: conn_gauge,
|
outcome: ConnectOutcome::Failed,
|
||||||
}))
|
retry_type: RetryType::ConnectToCompute,
|
||||||
|
},
|
||||||
|
num_retries.into(),
|
||||||
|
);
|
||||||
|
return Err(err.into());
|
||||||
|
}
|
||||||
|
node_info
|
||||||
|
} else {
|
||||||
|
// if we failed to connect, it's likely that the compute node was suspended, wake a new compute node
|
||||||
|
debug!("compute node's state has likely changed; requesting a wake-up");
|
||||||
|
invalidate_cache(node_info);
|
||||||
|
// TODO: increment num_retries?
|
||||||
|
wake_compute(&mut num_retries, ctx, user_info, wake_compute_retry_config).await?
|
||||||
|
};
|
||||||
|
|
||||||
|
// now that we have a new node, try connect to it repeatedly.
|
||||||
|
// this can error for a few reasons, for instance:
|
||||||
|
// * DNS connection settings haven't quite propagated yet
|
||||||
|
debug!("wake_compute success. attempting to connect");
|
||||||
|
num_retries = 1;
|
||||||
|
loop {
|
||||||
|
match connect_compute_fn(ctx, &node_info, &auth_info, creds, compute).await {
|
||||||
|
Ok(res) => {
|
||||||
|
ctx.success();
|
||||||
|
Metrics::get().proxy.retries_metric.observe(
|
||||||
|
RetriesMetricGroup {
|
||||||
|
outcome: ConnectOutcome::Success,
|
||||||
|
retry_type: RetryType::ConnectToCompute,
|
||||||
|
},
|
||||||
|
num_retries.into(),
|
||||||
|
);
|
||||||
|
// TODO: is this necessary? We have a metric.
|
||||||
|
info!(?num_retries, "connected to compute node after");
|
||||||
|
return Ok(res);
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
if !should_retry(&e, num_retries, compute.retry) {
|
||||||
|
// Don't log an error here, caller will print the error
|
||||||
|
Metrics::get().proxy.retries_metric.observe(
|
||||||
|
RetriesMetricGroup {
|
||||||
|
outcome: ConnectOutcome::Failed,
|
||||||
|
retry_type: RetryType::ConnectToCompute,
|
||||||
|
},
|
||||||
|
num_retries.into(),
|
||||||
|
);
|
||||||
|
return Err(e.into());
|
||||||
|
}
|
||||||
|
|
||||||
|
warn!(error = ?e, num_retries, retriable = true, COULD_NOT_CONNECT);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let wait_duration = retry_after(num_retries, compute.retry);
|
||||||
|
num_retries += 1;
|
||||||
|
|
||||||
|
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::RetryTimeout);
|
||||||
|
time::sleep(wait_duration).await;
|
||||||
|
drop(pause);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Try to connect to the compute node, retrying if necessary.
|
||||||
|
#[tracing::instrument(skip_all)]
|
||||||
|
pub(crate) async fn connect_to_compute<M: ConnectMechanism, B: ComputeConnectBackend>(
|
||||||
|
ctx: &RequestContext,
|
||||||
|
mechanism: &M,
|
||||||
|
user_info: &B,
|
||||||
|
wake_compute_retry_config: RetryConfig,
|
||||||
|
compute: &ComputeConfig,
|
||||||
|
) -> Result<M::Connection, M::Error>
|
||||||
|
where
|
||||||
|
M::ConnectError: CouldRetry + ShouldRetryWakeCompute + std::fmt::Debug,
|
||||||
|
M::Error: From<WakeComputeError>,
|
||||||
|
{
|
||||||
|
let mut num_retries = 0;
|
||||||
|
let node_info =
|
||||||
|
wake_compute(&mut num_retries, ctx, user_info, wake_compute_retry_config).await?;
|
||||||
|
|
||||||
|
// try once
|
||||||
|
let err = match mechanism.connect_once(ctx, &node_info, compute).await {
|
||||||
|
Ok(res) => {
|
||||||
|
ctx.success();
|
||||||
|
Metrics::get().proxy.retries_metric.observe(
|
||||||
|
RetriesMetricGroup {
|
||||||
|
outcome: ConnectOutcome::Success,
|
||||||
|
retry_type: RetryType::ConnectToCompute,
|
||||||
|
},
|
||||||
|
num_retries.into(),
|
||||||
|
);
|
||||||
|
return Ok(res);
|
||||||
|
}
|
||||||
|
Err(e) => e,
|
||||||
|
};
|
||||||
|
|
||||||
|
debug!(error = ?err, COULD_NOT_CONNECT);
|
||||||
|
|
||||||
|
let node_info = if !node_info.cached() || !err.should_retry_wake_compute() {
|
||||||
|
// If we just recieved this from cplane and didn't get it from cache, we shouldn't retry.
|
||||||
|
// Do not need to retrieve a new node_info, just return the old one.
|
||||||
|
if should_retry(&err, num_retries, compute.retry) {
|
||||||
|
Metrics::get().proxy.retries_metric.observe(
|
||||||
|
RetriesMetricGroup {
|
||||||
|
outcome: ConnectOutcome::Failed,
|
||||||
|
retry_type: RetryType::ConnectToCompute,
|
||||||
|
},
|
||||||
|
num_retries.into(),
|
||||||
|
);
|
||||||
|
return Err(err.into());
|
||||||
|
}
|
||||||
|
node_info
|
||||||
|
} else {
|
||||||
|
// if we failed to connect, it's likely that the compute node was suspended, wake a new compute node
|
||||||
|
debug!("compute node's state has likely changed; requesting a wake-up");
|
||||||
|
invalidate_cache(node_info);
|
||||||
|
// TODO: increment num_retries?
|
||||||
|
wake_compute(&mut num_retries, ctx, user_info, wake_compute_retry_config).await?
|
||||||
|
};
|
||||||
|
|
||||||
|
// now that we have a new node, try connect to it repeatedly.
|
||||||
|
// this can error for a few reasons, for instance:
|
||||||
|
// * DNS connection settings haven't quite propagated yet
|
||||||
|
debug!("wake_compute success. attempting to connect");
|
||||||
|
num_retries = 1;
|
||||||
|
loop {
|
||||||
|
match mechanism.connect_once(ctx, &node_info, compute).await {
|
||||||
|
Ok(res) => {
|
||||||
|
ctx.success();
|
||||||
|
Metrics::get().proxy.retries_metric.observe(
|
||||||
|
RetriesMetricGroup {
|
||||||
|
outcome: ConnectOutcome::Success,
|
||||||
|
retry_type: RetryType::ConnectToCompute,
|
||||||
|
},
|
||||||
|
num_retries.into(),
|
||||||
|
);
|
||||||
|
// TODO: is this necessary? We have a metric.
|
||||||
|
info!(?num_retries, "connected to compute node after");
|
||||||
|
return Ok(res);
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
if !should_retry(&e, num_retries, compute.retry) {
|
||||||
|
// Don't log an error here, caller will print the error
|
||||||
|
Metrics::get().proxy.retries_metric.observe(
|
||||||
|
RetriesMetricGroup {
|
||||||
|
outcome: ConnectOutcome::Failed,
|
||||||
|
retry_type: RetryType::ConnectToCompute,
|
||||||
|
},
|
||||||
|
num_retries.into(),
|
||||||
|
);
|
||||||
|
return Err(e.into());
|
||||||
|
}
|
||||||
|
|
||||||
|
warn!(error = ?e, num_retries, retriable = true, COULD_NOT_CONNECT);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let wait_duration = retry_after(num_retries, compute.retry);
|
||||||
|
num_retries += 1;
|
||||||
|
|
||||||
|
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::RetryTimeout);
|
||||||
|
time::sleep(wait_duration).await;
|
||||||
|
drop(pause);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Finish client connection initialization: confirm auth success, send params, etc.
|
/// Finish client connection initialization: confirm auth success, send params, etc.
|
||||||
|
|||||||
@@ -22,12 +22,13 @@ use super::*;
|
|||||||
use crate::auth::backend::{
|
use crate::auth::backend::{
|
||||||
ComputeCredentialKeys, ComputeCredentials, ComputeUserInfo, MaybeOwned,
|
ComputeCredentialKeys, ComputeCredentials, ComputeUserInfo, MaybeOwned,
|
||||||
};
|
};
|
||||||
use crate::config::{ComputeConfig, RetryConfig};
|
use crate::config::{ComputeConfig, RetryConfig, TlsConfig};
|
||||||
use crate::control_plane::client::{ControlPlaneClient, TestControlPlaneClient};
|
use crate::control_plane::client::{ControlPlaneClient, TestControlPlaneClient};
|
||||||
use crate::control_plane::messages::{ControlPlaneErrorMessage, Details, MetricsAuxInfo, Status};
|
use crate::control_plane::messages::{ControlPlaneErrorMessage, Details, MetricsAuxInfo, Status};
|
||||||
use crate::control_plane::{self, CachedNodeInfo, NodeInfo, NodeInfoCache};
|
use crate::control_plane::{self, CachedNodeInfo, NodeInfo, NodeInfoCache};
|
||||||
use crate::error::ErrorKind;
|
use crate::error::ErrorKind;
|
||||||
use crate::pglb::connect_compute::ConnectMechanism;
|
use crate::pglb::connect_compute::ConnectMechanism;
|
||||||
|
use crate::pglb::handshake::{HandshakeData, handshake};
|
||||||
use crate::tls::client_config::compute_client_config_with_certs;
|
use crate::tls::client_config::compute_client_config_with_certs;
|
||||||
use crate::tls::server_config::CertResolver;
|
use crate::tls::server_config::CertResolver;
|
||||||
use crate::types::{BranchId, EndpointId, ProjectId};
|
use crate::types::{BranchId, EndpointId, ProjectId};
|
||||||
|
|||||||
@@ -181,7 +181,7 @@ impl PoolingBackend {
|
|||||||
tracing::Span::current().record("conn_id", display(conn_id));
|
tracing::Span::current().record("conn_id", display(conn_id));
|
||||||
info!(%conn_id, "pool: opening a new connection '{conn_info}'");
|
info!(%conn_id, "pool: opening a new connection '{conn_info}'");
|
||||||
let backend = self.auth_backend.as_ref().map(|()| keys);
|
let backend = self.auth_backend.as_ref().map(|()| keys);
|
||||||
crate::pglb::connect_compute::connect_to_compute(
|
crate::proxy::connect_to_compute(
|
||||||
ctx,
|
ctx,
|
||||||
&TokioMechanism {
|
&TokioMechanism {
|
||||||
conn_id,
|
conn_id,
|
||||||
@@ -225,7 +225,7 @@ impl PoolingBackend {
|
|||||||
},
|
},
|
||||||
keys: crate::auth::backend::ComputeCredentialKeys::None,
|
keys: crate::auth::backend::ComputeCredentialKeys::None,
|
||||||
});
|
});
|
||||||
crate::pglb::connect_compute::connect_to_compute(
|
crate::proxy::connect_to_compute(
|
||||||
ctx,
|
ctx,
|
||||||
&HyperMechanism {
|
&HyperMechanism {
|
||||||
conn_id,
|
conn_id,
|
||||||
|
|||||||
@@ -50,10 +50,10 @@ use crate::context::RequestContext;
|
|||||||
use crate::ext::TaskExt;
|
use crate::ext::TaskExt;
|
||||||
use crate::metrics::Metrics;
|
use crate::metrics::Metrics;
|
||||||
use crate::protocol2::{ConnectHeader, ConnectionInfo, read_proxy_protocol};
|
use crate::protocol2::{ConnectHeader, ConnectionInfo, read_proxy_protocol};
|
||||||
use crate::proxy::run_until_cancelled;
|
|
||||||
use crate::rate_limiter::EndpointRateLimiter;
|
use crate::rate_limiter::EndpointRateLimiter;
|
||||||
use crate::serverless::backend::PoolingBackend;
|
use crate::serverless::backend::PoolingBackend;
|
||||||
use crate::serverless::http_util::{api_error_into_response, json_response};
|
use crate::serverless::http_util::{api_error_into_response, json_response};
|
||||||
|
use crate::util::run_until_cancelled;
|
||||||
|
|
||||||
pub(crate) const SERVERLESS_DRIVER_SNI: &str = "api";
|
pub(crate) const SERVERLESS_DRIVER_SNI: &str = "api";
|
||||||
pub(crate) const AUTH_BROKER_SNI: &str = "apiauth";
|
pub(crate) const AUTH_BROKER_SNI: &str = "apiauth";
|
||||||
|
|||||||
@@ -41,10 +41,11 @@ use crate::error::{ErrorKind, ReportableError, UserFacingError};
|
|||||||
use crate::http::{ReadBodyError, read_body_with_limit};
|
use crate::http::{ReadBodyError, read_body_with_limit};
|
||||||
use crate::metrics::{HttpDirection, Metrics, SniGroup, SniKind};
|
use crate::metrics::{HttpDirection, Metrics, SniGroup, SniKind};
|
||||||
use crate::pqproto::StartupMessageParams;
|
use crate::pqproto::StartupMessageParams;
|
||||||
use crate::proxy::{NeonOptions, run_until_cancelled};
|
use crate::proxy::NeonOptions;
|
||||||
use crate::serverless::backend::HttpConnError;
|
use crate::serverless::backend::HttpConnError;
|
||||||
use crate::types::{DbName, RoleName};
|
use crate::types::{DbName, RoleName};
|
||||||
use crate::usage_metrics::{MetricCounter, MetricCounterRecorder};
|
use crate::usage_metrics::{MetricCounter, MetricCounterRecorder};
|
||||||
|
use crate::util::run_until_cancelled;
|
||||||
|
|
||||||
#[derive(serde::Deserialize)]
|
#[derive(serde::Deserialize)]
|
||||||
#[serde(rename_all = "camelCase")]
|
#[serde(rename_all = "camelCase")]
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ use crate::config::ProxyConfig;
|
|||||||
use crate::context::RequestContext;
|
use crate::context::RequestContext;
|
||||||
use crate::error::ReportableError;
|
use crate::error::ReportableError;
|
||||||
use crate::metrics::Metrics;
|
use crate::metrics::Metrics;
|
||||||
use crate::proxy::{ClientMode, ErrorSource, handle_client};
|
use crate::pglb::{ClientMode, ErrorSource, handle_client};
|
||||||
use crate::rate_limiter::EndpointRateLimiter;
|
use crate::rate_limiter::EndpointRateLimiter;
|
||||||
|
|
||||||
pin_project! {
|
pin_project! {
|
||||||
|
|||||||
16
proxy/src/util.rs
Normal file
16
proxy/src/util.rs
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
use tokio_util::sync::CancellationToken;
|
||||||
|
|
||||||
|
pub async fn run_until_cancelled<F: Future>(
|
||||||
|
f: F,
|
||||||
|
cancellation_token: &CancellationToken,
|
||||||
|
) -> Option<F::Output> {
|
||||||
|
match futures::future::select(
|
||||||
|
std::pin::pin!(f),
|
||||||
|
std::pin::pin!(cancellation_token.cancelled()),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
futures::future::Either::Left((f, _)) => Some(f),
|
||||||
|
futures::future::Either::Right(((), _)) => None,
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user