mirror of
https://github.com/neondatabase/neon.git
synced 2026-02-08 05:00:38 +00:00
Compare commits
15 Commits
conrad/ref
...
conrad/rew
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
fb02e54843 | ||
|
|
41a5b8524a | ||
|
|
6da7b87a32 | ||
|
|
72b1c573b1 | ||
|
|
b509982bbf | ||
|
|
a78a52acb5 | ||
|
|
3370e8cb00 | ||
|
|
f37a558280 | ||
|
|
744011437a | ||
|
|
a10d26a083 | ||
|
|
aece520365 | ||
|
|
9017811d61 | ||
|
|
551a33aa04 | ||
|
|
95216ae6ec | ||
|
|
a3a10d1839 |
@@ -1,5 +1,3 @@
|
||||
use std::io;
|
||||
|
||||
use tokio::net::TcpStream;
|
||||
|
||||
use crate::client::SocketConfig;
|
||||
@@ -8,7 +6,7 @@ use crate::tls::MakeTlsConnect;
|
||||
use crate::{Error, cancel_query_raw, connect_socket};
|
||||
|
||||
pub(crate) async fn cancel_query<T>(
|
||||
config: Option<SocketConfig>,
|
||||
config: SocketConfig,
|
||||
ssl_mode: SslMode,
|
||||
tls: T,
|
||||
process_id: i32,
|
||||
@@ -17,16 +15,6 @@ pub(crate) async fn cancel_query<T>(
|
||||
where
|
||||
T: MakeTlsConnect<TcpStream>,
|
||||
{
|
||||
let config = match config {
|
||||
Some(config) => config,
|
||||
None => {
|
||||
return Err(Error::connect(io::Error::new(
|
||||
io::ErrorKind::InvalidInput,
|
||||
"unknown host",
|
||||
)));
|
||||
}
|
||||
};
|
||||
|
||||
let hostname = match &config.host {
|
||||
Host::Tcp(host) => &**host,
|
||||
};
|
||||
|
||||
@@ -9,9 +9,16 @@ use crate::{Error, cancel_query, cancel_query_raw};
|
||||
|
||||
/// The capability to request cancellation of in-progress queries on a
|
||||
/// connection.
|
||||
#[derive(Clone, Serialize, Deserialize)]
|
||||
#[derive(Clone)]
|
||||
pub struct CancelToken {
|
||||
pub socket_config: Option<SocketConfig>,
|
||||
pub socket_config: SocketConfig,
|
||||
pub raw: RawCancelToken,
|
||||
}
|
||||
|
||||
/// The capability to request cancellation of in-progress queries on a
|
||||
/// connection.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct RawCancelToken {
|
||||
pub ssl_mode: SslMode,
|
||||
pub process_id: i32,
|
||||
pub secret_key: i32,
|
||||
@@ -36,14 +43,16 @@ impl CancelToken {
|
||||
{
|
||||
cancel_query::cancel_query(
|
||||
self.socket_config.clone(),
|
||||
self.ssl_mode,
|
||||
self.raw.ssl_mode,
|
||||
tls,
|
||||
self.process_id,
|
||||
self.secret_key,
|
||||
self.raw.process_id,
|
||||
self.raw.secret_key,
|
||||
)
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
impl RawCancelToken {
|
||||
/// Like `cancel_query`, but uses a stream which is already connected to the server rather than opening a new
|
||||
/// connection itself.
|
||||
pub async fn cancel_query_raw<S, T>(&self, stream: S, tls: T) -> Result<(), Error>
|
||||
|
||||
@@ -12,6 +12,7 @@ use postgres_protocol2::message::frontend;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
use crate::cancel_token::RawCancelToken;
|
||||
use crate::codec::{BackendMessages, FrontendMessage};
|
||||
use crate::config::{Host, SslMode};
|
||||
use crate::query::RowStream;
|
||||
@@ -331,10 +332,12 @@ impl Client {
|
||||
/// connection associated with this client.
|
||||
pub fn cancel_token(&self) -> CancelToken {
|
||||
CancelToken {
|
||||
socket_config: Some(self.socket_config.clone()),
|
||||
ssl_mode: self.ssl_mode,
|
||||
process_id: self.process_id,
|
||||
secret_key: self.secret_key,
|
||||
socket_config: self.socket_config.clone(),
|
||||
raw: RawCancelToken {
|
||||
ssl_mode: self.ssl_mode,
|
||||
process_id: self.process_id,
|
||||
secret_key: self.secret_key,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
|
||||
use postgres_protocol2::message::backend::ReadyForQueryBody;
|
||||
|
||||
pub use crate::cancel_token::CancelToken;
|
||||
pub use crate::cancel_token::{CancelToken, RawCancelToken};
|
||||
pub use crate::client::{Client, SocketConfig};
|
||||
pub use crate::config::Config;
|
||||
pub use crate::connect_raw::RawConnection;
|
||||
|
||||
@@ -14,7 +14,7 @@ pub(crate) mod private {
|
||||
|
||||
/// Channel binding information returned from a TLS handshake.
|
||||
pub struct ChannelBinding {
|
||||
pub(crate) tls_server_end_point: Option<Vec<u8>>,
|
||||
pub tls_server_end_point: Option<Vec<u8>>,
|
||||
}
|
||||
|
||||
impl ChannelBinding {
|
||||
|
||||
@@ -7,13 +7,13 @@ use crate::auth::{self, AuthFlow};
|
||||
use crate::config::AuthenticationConfig;
|
||||
use crate::context::RequestContext;
|
||||
use crate::control_plane::AuthSecret;
|
||||
use crate::stream::{PqStream, Stream};
|
||||
use crate::stream::{PqFeStream, Stream};
|
||||
use crate::{compute, sasl};
|
||||
|
||||
pub(super) async fn authenticate(
|
||||
ctx: &RequestContext,
|
||||
creds: ComputeUserInfo,
|
||||
client: &mut PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
|
||||
client: &mut PqFeStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
|
||||
config: &'static AuthenticationConfig,
|
||||
secret: AuthSecret,
|
||||
) -> auth::Result<ComputeCredentials> {
|
||||
|
||||
@@ -17,7 +17,7 @@ use crate::error::{ReportableError, UserFacingError};
|
||||
use crate::pqproto::BeMessage;
|
||||
use crate::proxy::NeonOptions;
|
||||
use crate::proxy::wake_compute::WakeComputeBackend;
|
||||
use crate::stream::PqStream;
|
||||
use crate::stream::PqFeStream;
|
||||
use crate::types::RoleName;
|
||||
use crate::{auth, compute, waiters};
|
||||
|
||||
@@ -96,7 +96,7 @@ impl ConsoleRedirectBackend {
|
||||
&self,
|
||||
ctx: &RequestContext,
|
||||
auth_config: &'static AuthenticationConfig,
|
||||
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
|
||||
client: &mut PqFeStream<impl AsyncRead + AsyncWrite + Unpin>,
|
||||
) -> auth::Result<(ConsoleRedirectNodeInfo, AuthInfo, ComputeUserInfo)> {
|
||||
authenticate(ctx, auth_config, &self.console_uri, client)
|
||||
.await
|
||||
@@ -122,7 +122,7 @@ async fn authenticate(
|
||||
ctx: &RequestContext,
|
||||
auth_config: &'static AuthenticationConfig,
|
||||
link_uri: &reqwest::Url,
|
||||
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
|
||||
client: &mut PqFeStream<impl AsyncRead + AsyncWrite + Unpin>,
|
||||
) -> auth::Result<(NodeInfo, AuthInfo, ComputeUserInfo)> {
|
||||
ctx.set_auth_method(crate::context::AuthMethod::ConsoleRedirect);
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@ use crate::stream::{self, Stream};
|
||||
pub(crate) async fn authenticate_cleartext(
|
||||
ctx: &RequestContext,
|
||||
info: ComputeUserInfo,
|
||||
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
|
||||
client: &mut stream::PqFeStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
|
||||
secret: AuthSecret,
|
||||
config: &'static AuthenticationConfig,
|
||||
) -> auth::Result<ComputeCredentials> {
|
||||
@@ -61,7 +61,7 @@ pub(crate) async fn authenticate_cleartext(
|
||||
pub(crate) async fn password_hack_no_authentication(
|
||||
ctx: &RequestContext,
|
||||
info: ComputeUserInfoNoEndpoint,
|
||||
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
|
||||
client: &mut stream::PqFeStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
|
||||
) -> auth::Result<(ComputeUserInfo, Vec<u8>)> {
|
||||
debug!("project not specified, resorting to the password hack auth flow");
|
||||
ctx.set_auth_method(crate::context::AuthMethod::Cleartext);
|
||||
|
||||
@@ -201,7 +201,7 @@ async fn auth_quirks(
|
||||
ctx: &RequestContext,
|
||||
api: &impl control_plane::ControlPlaneApi,
|
||||
user_info: ComputeUserInfoMaybeEndpoint,
|
||||
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
|
||||
client: &mut stream::PqFeStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
|
||||
allow_cleartext: bool,
|
||||
config: &'static AuthenticationConfig,
|
||||
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||
@@ -267,7 +267,7 @@ async fn authenticate_with_secret(
|
||||
ctx: &RequestContext,
|
||||
secret: AuthSecret,
|
||||
info: ComputeUserInfo,
|
||||
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
|
||||
client: &mut stream::PqFeStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
|
||||
unauthenticated_password: Option<Vec<u8>>,
|
||||
allow_cleartext: bool,
|
||||
config: &'static AuthenticationConfig,
|
||||
@@ -318,7 +318,7 @@ impl<'a> Backend<'a, ComputeUserInfoMaybeEndpoint> {
|
||||
pub(crate) async fn authenticate(
|
||||
self,
|
||||
ctx: &RequestContext,
|
||||
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
|
||||
client: &mut stream::PqFeStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
|
||||
allow_cleartext: bool,
|
||||
config: &'static AuthenticationConfig,
|
||||
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||
@@ -446,7 +446,7 @@ mod tests {
|
||||
use crate::rate_limiter::EndpointRateLimiter;
|
||||
use crate::scram::ServerSecret;
|
||||
use crate::scram::threadpool::ThreadPool;
|
||||
use crate::stream::{PqStream, Stream};
|
||||
use crate::stream::{PqFeStream, Stream};
|
||||
|
||||
struct Auth {
|
||||
ips: Vec<IpPattern>,
|
||||
@@ -522,7 +522,7 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn auth_quirks_scram() {
|
||||
let (mut client, server) = tokio::io::duplex(1024);
|
||||
let mut stream = PqStream::new_skip_handshake(Stream::from_raw(server));
|
||||
let mut stream = PqFeStream::new_skip_handshake(Stream::from_raw(server));
|
||||
|
||||
let ctx = RequestContext::test();
|
||||
let api = Auth {
|
||||
@@ -604,7 +604,7 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn auth_quirks_cleartext() {
|
||||
let (mut client, server) = tokio::io::duplex(1024);
|
||||
let mut stream = PqStream::new_skip_handshake(Stream::from_raw(server));
|
||||
let mut stream = PqFeStream::new_skip_handshake(Stream::from_raw(server));
|
||||
|
||||
let ctx = RequestContext::test();
|
||||
let api = Auth {
|
||||
@@ -658,7 +658,7 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn auth_quirks_password_hack() {
|
||||
let (mut client, server) = tokio::io::duplex(1024);
|
||||
let mut stream = PqStream::new_skip_handshake(Stream::from_raw(server));
|
||||
let mut stream = PqFeStream::new_skip_handshake(Stream::from_raw(server));
|
||||
|
||||
let ctx = RequestContext::test();
|
||||
let api = Auth {
|
||||
|
||||
@@ -15,7 +15,7 @@ use crate::pqproto::{BeAuthenticationSaslMessage, BeMessage};
|
||||
use crate::sasl;
|
||||
use crate::scram::threadpool::ThreadPool;
|
||||
use crate::scram::{self};
|
||||
use crate::stream::{PqStream, Stream};
|
||||
use crate::stream::{PqFeStream, Stream};
|
||||
use crate::tls::TlsServerEndPoint;
|
||||
|
||||
/// Use [SCRAM](crate::scram)-based auth in [`AuthFlow`].
|
||||
@@ -53,7 +53,7 @@ pub(crate) struct CleartextPassword {
|
||||
#[must_use]
|
||||
pub(crate) struct AuthFlow<'a, S, State> {
|
||||
/// The underlying stream which implements libpq's protocol.
|
||||
stream: &'a mut PqStream<Stream<S>>,
|
||||
stream: &'a mut PqFeStream<Stream<S>>,
|
||||
/// State might contain ancillary data.
|
||||
state: State,
|
||||
tls_server_end_point: TlsServerEndPoint,
|
||||
@@ -62,7 +62,7 @@ pub(crate) struct AuthFlow<'a, S, State> {
|
||||
/// Initial state of the stream wrapper.
|
||||
impl<'a, S: AsyncRead + AsyncWrite + Unpin, M> AuthFlow<'a, S, M> {
|
||||
/// Create a new wrapper for client authentication.
|
||||
pub(crate) fn new(stream: &'a mut PqStream<Stream<S>>, method: M) -> Self {
|
||||
pub(crate) fn new(stream: &'a mut PqFeStream<Stream<S>>, method: M) -> Self {
|
||||
let tls_server_end_point = stream.get_ref().tls_server_end_point();
|
||||
|
||||
Self {
|
||||
|
||||
145
proxy/src/batch.rs
Normal file
145
proxy/src/batch.rs
Normal file
@@ -0,0 +1,145 @@
|
||||
//! Batch processing system based on an ordered queue of pending jobs.
//!
//! Jobs are keyed by a monotonically increasing id in a `BTreeMap` and are
//! batched in registration order, with direct support for cancelling jobs early.
|
||||
use std::collections::BTreeMap;
|
||||
use std::pin::pin;
|
||||
use std::sync::Mutex;
|
||||
|
||||
use futures::future::Either;
|
||||
use scopeguard::ScopeGuard;
|
||||
use tokio::sync::oneshot::error::TryRecvError;
|
||||
|
||||
use crate::ext::LockExt;
|
||||
|
||||
/// A processor that handles queued requests one batch at a time on behalf of a
/// [`BatchQueue`].
pub trait QueueProcessing: Send + 'static {
    /// Payload submitted by a single caller.
    type Req: Send + 'static;
    /// Reply handed back to the caller that submitted the matching request.
    type Res: Send;

    /// Returns the desired batch size.
    ///
    /// `queue_size` is the number of jobs queued at the moment the batch is assembled.
    fn batch_size(&self, queue_size: usize) -> usize;

    /// Applies a full batch of requests and must answer with a full batch of replies.
    ///
    /// Replies are matched to requests by position, so the returned vector has to
    /// carry one entry per request, in the same order.
    ///
    /// If applying the batch can fail, the error is expected to be forwarded
    /// inside each `Self::Res`.
    ///
    /// Batching does not need to happen atomically.
    fn apply(&mut self, req: Vec<Self::Req>) -> impl Future<Output = Vec<Self::Res>> + Send;
}
|
||||
|
||||
pub struct BatchQueue<P: QueueProcessing> {
|
||||
processor: tokio::sync::Mutex<P>,
|
||||
inner: Mutex<BatchQueueInner<P>>,
|
||||
}
|
||||
|
||||
struct BatchJob<P: QueueProcessing> {
|
||||
req: P::Req,
|
||||
res: tokio::sync::oneshot::Sender<P::Res>,
|
||||
}
|
||||
|
||||
impl<P: QueueProcessing> BatchQueue<P> {
|
||||
pub fn new(p: P) -> Self {
|
||||
Self {
|
||||
processor: tokio::sync::Mutex::new(p),
|
||||
inner: Mutex::new(BatchQueueInner {
|
||||
version: 0,
|
||||
queue: BTreeMap::new(),
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn call(&self, req: P::Req) -> P::Res {
|
||||
let (id, mut rx) = self.inner.lock_propagate_poison().register_job(req);
|
||||
let guard = scopeguard::guard(id, move |id| {
|
||||
if self.inner.lock_propagate_poison().queue.remove(&id).is_some() {
|
||||
tracing::debug!("batched task cancelled before completion");
|
||||
}
|
||||
});
|
||||
|
||||
let resp = loop {
|
||||
// try become the leader, or try wait for success.
|
||||
let mut processor = match futures::future::select(rx, pin!(self.processor.lock())).await
|
||||
{
|
||||
// we got the resp.
|
||||
Either::Left((resp, _)) => break resp.ok(),
|
||||
// we are the leader.
|
||||
Either::Right((p, rx_)) => {
|
||||
rx = rx_;
|
||||
p
|
||||
}
|
||||
};
|
||||
|
||||
let (reqs, resps) = self.inner.lock_propagate_poison().get_batch(&processor);
|
||||
|
||||
// apply a batch.
|
||||
let values = processor.apply(reqs).await;
|
||||
|
||||
// send response values.
|
||||
for (tx, value) in std::iter::zip(resps, values) {
|
||||
// sender hung up but that's fine.
|
||||
drop(tx.send(value));
|
||||
}
|
||||
|
||||
match rx.try_recv() {
|
||||
Ok(resp) => break Some(resp),
|
||||
Err(TryRecvError::Closed) => break None,
|
||||
// edge case - there was a race condition where
|
||||
// we became the leader but were not in the batch.
|
||||
//
|
||||
// Example:
|
||||
// thread 1: register job id=1
|
||||
// thread 2: register job id=2
|
||||
// thread 2: processor.lock().await
|
||||
// thread 1: processor.lock().await
|
||||
// thread 2: becomes leader, batch_size=1, jobs=[1].
|
||||
Err(TryRecvError::Empty) => {}
|
||||
}
|
||||
};
|
||||
|
||||
// already removed.
|
||||
ScopeGuard::into_inner(guard);
|
||||
|
||||
resp.expect("no response found. batch processer should not panic")
|
||||
}
|
||||
}
|
||||
|
||||
struct BatchQueueInner<P: QueueProcessing> {
|
||||
version: u64,
|
||||
queue: BTreeMap<u64, BatchJob<P>>,
|
||||
}
|
||||
|
||||
impl<P: QueueProcessing> BatchQueueInner<P> {
|
||||
fn register_job(&mut self, req: P::Req) -> (u64, tokio::sync::oneshot::Receiver<P::Res>) {
|
||||
let (tx, rx) = tokio::sync::oneshot::channel();
|
||||
|
||||
let id = self.version;
|
||||
|
||||
// Overflow concern:
|
||||
// This is a u64, and we might enqueue 2^16 tasks per second.
|
||||
// This gives us 2^48 seconds (9 million years).
|
||||
// Even if this does overflow, it will not break, but some
|
||||
// jobs with the higher version might never get prioritised.
|
||||
self.version += 1;
|
||||
|
||||
self.queue.insert(id, BatchJob { req, res: tx });
|
||||
|
||||
(id, rx)
|
||||
}
|
||||
|
||||
fn get_batch(&mut self, p: &P) -> (Vec<P::Req>, Vec<tokio::sync::oneshot::Sender<P::Res>>) {
|
||||
let batch_size = p.batch_size(self.queue.len());
|
||||
let mut reqs = Vec::with_capacity(batch_size);
|
||||
let mut resps = Vec::with_capacity(batch_size);
|
||||
|
||||
while reqs.len() < batch_size {
|
||||
let Some((_, job)) = self.queue.pop_first() else {
|
||||
break;
|
||||
};
|
||||
reqs.push(job.req);
|
||||
resps.push(job.res);
|
||||
}
|
||||
|
||||
(reqs, resps)
|
||||
}
|
||||
}
|
||||
@@ -201,7 +201,7 @@ pub async fn run() -> anyhow::Result<()> {
|
||||
auth_backend,
|
||||
http_listener,
|
||||
shutdown.clone(),
|
||||
Arc::new(CancellationHandler::new(&config.connect_to_compute, None)),
|
||||
Arc::new(CancellationHandler::new(&config.connect_to_compute)),
|
||||
endpoint_rate_limiter,
|
||||
);
|
||||
|
||||
|
||||
@@ -29,7 +29,7 @@ use crate::metrics::{Metrics, ThreadPoolMetrics};
|
||||
use crate::pqproto::FeStartupPacket;
|
||||
use crate::protocol2::ConnectionInfo;
|
||||
use crate::proxy::{ErrorSource, TlsRequired, copy_bidirectional_client_compute};
|
||||
use crate::stream::{PqStream, Stream};
|
||||
use crate::stream::{PqFeStream, Stream};
|
||||
use crate::util::run_until_cancelled;
|
||||
|
||||
project_git_version!(GIT_VERSION);
|
||||
@@ -262,7 +262,7 @@ async fn ssl_handshake<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
raw_stream: S,
|
||||
tls_config: Arc<rustls::ServerConfig>,
|
||||
) -> anyhow::Result<TlsStream<S>> {
|
||||
let (mut stream, msg) = PqStream::parse_startup(Stream::from_raw(raw_stream)).await?;
|
||||
let (mut stream, msg) = PqFeStream::parse_startup(Stream::from_raw(raw_stream)).await?;
|
||||
match msg {
|
||||
FeStartupPacket::SslRequest { direct: None } => {
|
||||
let raw = stream.accept_tls().await?;
|
||||
|
||||
@@ -21,7 +21,8 @@ use utils::{project_build_tag, project_git_version};
|
||||
|
||||
use crate::auth::backend::jwt::JwkCache;
|
||||
use crate::auth::backend::{ConsoleRedirectBackend, MaybeOwned};
|
||||
use crate::cancellation::{CancellationHandler, handle_cancel_messages};
|
||||
use crate::batch::BatchQueue;
|
||||
use crate::cancellation::{CancellationHandler, CancellationProcessor};
|
||||
use crate::config::{
|
||||
self, AuthenticationConfig, CacheOptions, ComputeConfig, HttpConfig, ProjectInfoCacheOptions,
|
||||
ProxyConfig, ProxyProtocolV2, remote_storage_from_toml,
|
||||
@@ -390,13 +391,7 @@ pub async fn run() -> anyhow::Result<()> {
|
||||
.as_ref()
|
||||
.map(|redis_publisher| RedisKVClient::new(redis_publisher.clone(), redis_rps_limit));
|
||||
|
||||
// channel size should be higher than redis client limit to avoid blocking
|
||||
let cancel_ch_size = args.cancellation_ch_size;
|
||||
let (tx_cancel, rx_cancel) = tokio::sync::mpsc::channel(cancel_ch_size);
|
||||
let cancellation_handler = Arc::new(CancellationHandler::new(
|
||||
&config.connect_to_compute,
|
||||
Some(tx_cancel),
|
||||
));
|
||||
let cancellation_handler = Arc::new(CancellationHandler::new(&config.connect_to_compute));
|
||||
|
||||
let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new_with_shards(
|
||||
RateBucketInfo::to_leaky_bucket(&args.endpoint_rps_limit)
|
||||
@@ -523,14 +518,10 @@ pub async fn run() -> anyhow::Result<()> {
|
||||
if let Some(mut redis_kv_client) = redis_kv_client {
|
||||
maintenance_tasks.spawn(async move {
|
||||
redis_kv_client.try_connect().await?;
|
||||
handle_cancel_messages(
|
||||
&mut redis_kv_client,
|
||||
rx_cancel,
|
||||
args.cancellation_batch_size,
|
||||
)
|
||||
.await?;
|
||||
|
||||
drop(redis_kv_client);
|
||||
cancellation_handler.init_tx(BatchQueue::new(CancellationProcessor {
|
||||
client: redis_kv_client,
|
||||
batch_size: args.cancellation_batch_size,
|
||||
}));
|
||||
|
||||
// `handle_cancel_messages` was terminated due to the tx_cancel
|
||||
// being dropped. this is not worthy of an error, and this task can only return `Err`,
|
||||
|
||||
@@ -1,19 +1,23 @@
|
||||
use std::convert::Infallible;
|
||||
use std::net::{IpAddr, SocketAddr};
|
||||
use std::sync::Arc;
|
||||
use std::sync::{Arc, OnceLock};
|
||||
use std::time::Duration;
|
||||
|
||||
use anyhow::{Context, anyhow};
|
||||
use anyhow::anyhow;
|
||||
use futures::FutureExt;
|
||||
use ipnet::{IpNet, Ipv4Net, Ipv6Net};
|
||||
use postgres_client::CancelToken;
|
||||
use postgres_client::RawCancelToken;
|
||||
use postgres_client::tls::MakeTlsConnect;
|
||||
use redis::{Cmd, FromRedisValue, Value};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use thiserror::Error;
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::sync::{mpsc, oneshot};
|
||||
use tracing::{debug, error, info, warn};
|
||||
use tokio::time::timeout;
|
||||
use tracing::{debug, error, info};
|
||||
|
||||
use crate::auth::AuthError;
|
||||
use crate::auth::backend::ComputeUserInfo;
|
||||
use crate::batch::{BatchQueue, QueueProcessing};
|
||||
use crate::config::ComputeConfig;
|
||||
use crate::context::RequestContext;
|
||||
use crate::control_plane::ControlPlaneApi;
|
||||
@@ -27,46 +31,36 @@ use crate::redis::kv_ops::RedisKVClient;
|
||||
|
||||
type IpSubnetKey = IpNet;
|
||||
|
||||
const CANCEL_KEY_TTL: i64 = 1_209_600; // 2 weeks cancellation key expire time
|
||||
const CANCEL_KEY_TTL: std::time::Duration = std::time::Duration::from_secs(600);
|
||||
const CANCEL_KEY_REFRESH: std::time::Duration = std::time::Duration::from_secs(570);
|
||||
|
||||
// Message types for sending through mpsc channel
|
||||
pub enum CancelKeyOp {
|
||||
StoreCancelKey {
|
||||
key: String,
|
||||
field: String,
|
||||
value: String,
|
||||
resp_tx: Option<oneshot::Sender<anyhow::Result<()>>>,
|
||||
_guard: CancelChannelSizeGuard<'static>,
|
||||
expire: i64, // TTL for key
|
||||
key: CancelKeyData,
|
||||
value: Box<str>,
|
||||
expire: std::time::Duration,
|
||||
},
|
||||
GetCancelData {
|
||||
key: String,
|
||||
resp_tx: oneshot::Sender<anyhow::Result<Vec<(String, String)>>>,
|
||||
_guard: CancelChannelSizeGuard<'static>,
|
||||
},
|
||||
RemoveCancelKey {
|
||||
key: String,
|
||||
field: String,
|
||||
resp_tx: Option<oneshot::Sender<anyhow::Result<()>>>,
|
||||
_guard: CancelChannelSizeGuard<'static>,
|
||||
key: CancelKeyData,
|
||||
},
|
||||
}
|
||||
|
||||
pub struct Pipeline {
|
||||
inner: redis::Pipeline,
|
||||
replies: Vec<CancelReplyOp>,
|
||||
replies: usize,
|
||||
}
|
||||
|
||||
impl Pipeline {
|
||||
fn with_capacity(n: usize) -> Self {
|
||||
Self {
|
||||
inner: redis::Pipeline::with_capacity(n),
|
||||
replies: Vec::with_capacity(n),
|
||||
replies: 0,
|
||||
}
|
||||
}
|
||||
|
||||
async fn execute(&mut self, client: &mut RedisKVClient) {
|
||||
let responses = self.replies.len();
|
||||
async fn execute(self, client: &mut RedisKVClient) -> Vec<anyhow::Result<Value>> {
|
||||
let responses = self.replies;
|
||||
let batch_size = self.inner.len();
|
||||
|
||||
match client.query(&self.inner).await {
|
||||
@@ -76,176 +70,73 @@ impl Pipeline {
|
||||
batch_size,
|
||||
responses, "successfully completed cancellation jobs",
|
||||
);
|
||||
for (value, reply) in std::iter::zip(values, self.replies.drain(..)) {
|
||||
reply.send_value(value);
|
||||
}
|
||||
values.into_iter().map(Ok).collect()
|
||||
}
|
||||
Ok(value) => {
|
||||
error!(batch_size, ?value, "unexpected redis return value");
|
||||
for reply in self.replies.drain(..) {
|
||||
reply.send_err(anyhow!("incorrect response type from redis"));
|
||||
}
|
||||
std::iter::repeat_with(|| Err(anyhow!("incorrect response type from redis")))
|
||||
.take(responses)
|
||||
.collect()
|
||||
}
|
||||
Err(err) => {
|
||||
for reply in self.replies.drain(..) {
|
||||
reply.send_err(anyhow!("could not send cmd to redis: {err}"));
|
||||
}
|
||||
std::iter::repeat_with(|| Err(anyhow!("could not send cmd to redis: {err}")))
|
||||
.take(responses)
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
self.inner.clear();
|
||||
self.replies.clear();
|
||||
}
|
||||
|
||||
fn add_command_with_reply(&mut self, cmd: Cmd, reply: CancelReplyOp) {
|
||||
fn add_command_with_reply(&mut self, cmd: Cmd) {
|
||||
self.inner.add_command(cmd);
|
||||
self.replies.push(reply);
|
||||
self.replies += 1;
|
||||
}
|
||||
|
||||
fn add_command_no_reply(&mut self, cmd: Cmd) {
|
||||
self.inner.add_command(cmd).ignore();
|
||||
}
|
||||
|
||||
fn add_command(&mut self, cmd: Cmd, reply: Option<CancelReplyOp>) {
|
||||
match reply {
|
||||
Some(reply) => self.add_command_with_reply(cmd, reply),
|
||||
None => self.add_command_no_reply(cmd),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl CancelKeyOp {
|
||||
fn register(self, pipe: &mut Pipeline) {
|
||||
fn register(&self, pipe: &mut Pipeline) {
|
||||
#[allow(clippy::used_underscore_binding)]
|
||||
match self {
|
||||
CancelKeyOp::StoreCancelKey {
|
||||
key,
|
||||
field,
|
||||
value,
|
||||
resp_tx,
|
||||
_guard,
|
||||
expire,
|
||||
} => {
|
||||
let reply =
|
||||
resp_tx.map(|resp_tx| CancelReplyOp::StoreCancelKey { resp_tx, _guard });
|
||||
pipe.add_command(Cmd::hset(&key, field, value), reply);
|
||||
pipe.add_command_no_reply(Cmd::expire(key, expire));
|
||||
CancelKeyOp::StoreCancelKey { key, value, expire } => {
|
||||
let key = KeyPrefix::Cancel(*key).build_redis_key();
|
||||
pipe.add_command_with_reply(Cmd::hset(&key, "data", &**value));
|
||||
pipe.add_command_no_reply(Cmd::expire(&key, expire.as_secs() as i64));
|
||||
}
|
||||
CancelKeyOp::GetCancelData {
|
||||
key,
|
||||
resp_tx,
|
||||
_guard,
|
||||
} => {
|
||||
let reply = CancelReplyOp::GetCancelData { resp_tx, _guard };
|
||||
pipe.add_command_with_reply(Cmd::hgetall(key), reply);
|
||||
}
|
||||
CancelKeyOp::RemoveCancelKey {
|
||||
key,
|
||||
field,
|
||||
resp_tx,
|
||||
_guard,
|
||||
} => {
|
||||
let reply =
|
||||
resp_tx.map(|resp_tx| CancelReplyOp::RemoveCancelKey { resp_tx, _guard });
|
||||
pipe.add_command(Cmd::hdel(key, field), reply);
|
||||
CancelKeyOp::GetCancelData { key } => {
|
||||
let key = KeyPrefix::Cancel(*key).build_redis_key();
|
||||
pipe.add_command_with_reply(Cmd::hget(key, "data"));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Message types for sending through mpsc channel
|
||||
pub enum CancelReplyOp {
|
||||
StoreCancelKey {
|
||||
resp_tx: oneshot::Sender<anyhow::Result<()>>,
|
||||
_guard: CancelChannelSizeGuard<'static>,
|
||||
},
|
||||
GetCancelData {
|
||||
resp_tx: oneshot::Sender<anyhow::Result<Vec<(String, String)>>>,
|
||||
_guard: CancelChannelSizeGuard<'static>,
|
||||
},
|
||||
RemoveCancelKey {
|
||||
resp_tx: oneshot::Sender<anyhow::Result<()>>,
|
||||
_guard: CancelChannelSizeGuard<'static>,
|
||||
},
|
||||
pub struct CancellationProcessor {
|
||||
pub client: RedisKVClient,
|
||||
pub batch_size: usize,
|
||||
}
|
||||
|
||||
impl CancelReplyOp {
|
||||
fn send_err(self, e: anyhow::Error) {
|
||||
match self {
|
||||
CancelReplyOp::StoreCancelKey { resp_tx, _guard } => {
|
||||
resp_tx
|
||||
.send(Err(e))
|
||||
.inspect_err(|_| tracing::debug!("could not send reply"))
|
||||
.ok();
|
||||
}
|
||||
CancelReplyOp::GetCancelData { resp_tx, _guard } => {
|
||||
resp_tx
|
||||
.send(Err(e))
|
||||
.inspect_err(|_| tracing::debug!("could not send reply"))
|
||||
.ok();
|
||||
}
|
||||
CancelReplyOp::RemoveCancelKey { resp_tx, _guard } => {
|
||||
resp_tx
|
||||
.send(Err(e))
|
||||
.inspect_err(|_| tracing::debug!("could not send reply"))
|
||||
.ok();
|
||||
}
|
||||
}
|
||||
impl QueueProcessing for CancellationProcessor {
|
||||
type Req = (CancelChannelSizeGuard<'static>, CancelKeyOp);
|
||||
type Res = anyhow::Result<redis::Value>;
|
||||
|
||||
fn batch_size(&self, _queue_size: usize) -> usize {
|
||||
self.batch_size
|
||||
}
|
||||
|
||||
fn send_value(self, v: redis::Value) {
|
||||
match self {
|
||||
CancelReplyOp::StoreCancelKey { resp_tx, _guard } => {
|
||||
let send =
|
||||
FromRedisValue::from_owned_redis_value(v).context("could not parse value");
|
||||
resp_tx
|
||||
.send(send)
|
||||
.inspect_err(|_| tracing::debug!("could not send reply"))
|
||||
.ok();
|
||||
}
|
||||
CancelReplyOp::GetCancelData { resp_tx, _guard } => {
|
||||
let send =
|
||||
FromRedisValue::from_owned_redis_value(v).context("could not parse value");
|
||||
resp_tx
|
||||
.send(send)
|
||||
.inspect_err(|_| tracing::debug!("could not send reply"))
|
||||
.ok();
|
||||
}
|
||||
CancelReplyOp::RemoveCancelKey { resp_tx, _guard } => {
|
||||
let send =
|
||||
FromRedisValue::from_owned_redis_value(v).context("could not parse value");
|
||||
resp_tx
|
||||
.send(send)
|
||||
.inspect_err(|_| tracing::debug!("could not send reply"))
|
||||
.ok();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Running as a separate task to accept messages through the rx channel
|
||||
pub async fn handle_cancel_messages(
|
||||
client: &mut RedisKVClient,
|
||||
mut rx: mpsc::Receiver<CancelKeyOp>,
|
||||
batch_size: usize,
|
||||
) -> anyhow::Result<()> {
|
||||
let mut batch = Vec::with_capacity(batch_size);
|
||||
let mut pipeline = Pipeline::with_capacity(batch_size);
|
||||
|
||||
loop {
|
||||
if rx.recv_many(&mut batch, batch_size).await == 0 {
|
||||
warn!("shutting down cancellation queue");
|
||||
break Ok(());
|
||||
}
|
||||
async fn apply(&mut self, batch: Vec<Self::Req>) -> Vec<Self::Res> {
|
||||
let mut pipeline = Pipeline::with_capacity(batch.len());
|
||||
|
||||
let batch_size = batch.len();
|
||||
debug!(batch_size, "running cancellation jobs");
|
||||
|
||||
for msg in batch.drain(..) {
|
||||
msg.register(&mut pipeline);
|
||||
for (_, op) in &batch {
|
||||
op.register(&mut pipeline);
|
||||
}
|
||||
|
||||
pipeline.execute(client).await;
|
||||
pipeline.execute(&mut self.client).await
|
||||
}
|
||||
}
|
||||
|
||||
@@ -256,7 +147,7 @@ pub struct CancellationHandler {
|
||||
compute_config: &'static ComputeConfig,
|
||||
// rate limiter of cancellation requests
|
||||
limiter: Arc<std::sync::Mutex<LeakyBucketRateLimiter<IpSubnetKey>>>,
|
||||
tx: Option<mpsc::Sender<CancelKeyOp>>, // send messages to the redis KV client task
|
||||
tx: OnceLock<BatchQueue<CancellationProcessor>>, // send messages to the redis KV client task
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
@@ -296,13 +187,10 @@ impl ReportableError for CancelError {
|
||||
}
|
||||
|
||||
impl CancellationHandler {
|
||||
pub fn new(
|
||||
compute_config: &'static ComputeConfig,
|
||||
tx: Option<mpsc::Sender<CancelKeyOp>>,
|
||||
) -> Self {
|
||||
pub fn new(compute_config: &'static ComputeConfig) -> Self {
|
||||
Self {
|
||||
compute_config,
|
||||
tx,
|
||||
tx: OnceLock::new(),
|
||||
limiter: Arc::new(std::sync::Mutex::new(
|
||||
LeakyBucketRateLimiter::<IpSubnetKey>::new_with_shards(
|
||||
LeakyBucketRateLimiter::<IpSubnetKey>::DEFAULT,
|
||||
@@ -312,7 +200,14 @@ impl CancellationHandler {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn get_key(self: &Arc<Self>) -> Session {
|
||||
pub fn init_tx(&self, queue: BatchQueue<CancellationProcessor>) {
|
||||
self.tx
|
||||
.set(queue)
|
||||
.map_err(|_| {})
|
||||
.expect("cancellation queue should be registered once");
|
||||
}
|
||||
|
||||
pub(crate) fn get_key(self: Arc<Self>) -> Session {
|
||||
// we intentionally generate a random "backend pid" and "secret key" here.
|
||||
// we use the corresponding u64 as an identifier for the
|
||||
// actual endpoint+pid+secret for postgres/pgbouncer.
|
||||
@@ -322,14 +217,10 @@ impl CancellationHandler {
|
||||
|
||||
let key: CancelKeyData = rand::random();
|
||||
|
||||
let prefix_key: KeyPrefix = KeyPrefix::Cancel(key);
|
||||
let redis_key = prefix_key.build_redis_key();
|
||||
|
||||
debug!("registered new query cancellation key {key}");
|
||||
Session {
|
||||
key,
|
||||
redis_key,
|
||||
cancellation_handler: Arc::clone(self),
|
||||
cancellation_handler: self,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -337,62 +228,43 @@ impl CancellationHandler {
|
||||
&self,
|
||||
key: CancelKeyData,
|
||||
) -> Result<Option<CancelClosure>, CancelError> {
|
||||
let prefix_key: KeyPrefix = KeyPrefix::Cancel(key);
|
||||
let redis_key = prefix_key.build_redis_key();
|
||||
let guard = Metrics::get()
|
||||
.proxy
|
||||
.cancel_channel_size
|
||||
.guard(RedisMsgKind::HGet);
|
||||
let op = CancelKeyOp::GetCancelData { key };
|
||||
|
||||
let (resp_tx, resp_rx) = tokio::sync::oneshot::channel();
|
||||
let op = CancelKeyOp::GetCancelData {
|
||||
key: redis_key,
|
||||
resp_tx,
|
||||
_guard: Metrics::get()
|
||||
.proxy
|
||||
.cancel_channel_size
|
||||
.guard(RedisMsgKind::HGetAll),
|
||||
};
|
||||
|
||||
let Some(tx) = &self.tx else {
|
||||
let Some(tx) = self.tx.get() else {
|
||||
tracing::warn!("cancellation handler is not available");
|
||||
return Err(CancelError::InternalError);
|
||||
};
|
||||
|
||||
tx.try_send(op)
|
||||
const TIMEOUT: Duration = Duration::from_secs(5);
|
||||
let result = timeout(TIMEOUT, tx.call((guard, op)))
|
||||
.await
|
||||
.map_err(|_| {
|
||||
tracing::warn!("timed out waiting to receive GetCancelData response");
|
||||
CancelError::RateLimit
|
||||
})?
|
||||
.map_err(|e| {
|
||||
tracing::warn!("failed to send GetCancelData for {key}: {e}");
|
||||
})
|
||||
.map_err(|()| CancelError::InternalError)?;
|
||||
tracing::warn!("failed to receive GetCancelData response: {e}");
|
||||
CancelError::InternalError
|
||||
})?;
|
||||
|
||||
let result = resp_rx.await.map_err(|e| {
|
||||
let cancel_state_str = String::from_owned_redis_value(result).map_err(|e| {
|
||||
tracing::warn!("failed to receive GetCancelData response: {e}");
|
||||
CancelError::InternalError
|
||||
})?;
|
||||
|
||||
let cancel_state_str: Option<String> = match result {
|
||||
Ok(mut state) => {
|
||||
if state.len() == 1 {
|
||||
Some(state.remove(0).1)
|
||||
} else {
|
||||
tracing::warn!("unexpected number of entries in cancel state: {state:?}");
|
||||
return Err(CancelError::InternalError);
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!("failed to receive cancel state from redis: {e}");
|
||||
return Err(CancelError::InternalError);
|
||||
}
|
||||
};
|
||||
let cancel_closure: CancelClosure =
|
||||
serde_json::from_str(&cancel_state_str).map_err(|e| {
|
||||
tracing::warn!("failed to deserialize cancel state: {e}");
|
||||
CancelError::InternalError
|
||||
})?;
|
||||
|
||||
let cancel_state: Option<CancelClosure> = match cancel_state_str {
|
||||
Some(state) => {
|
||||
let cancel_closure: CancelClosure = serde_json::from_str(&state).map_err(|e| {
|
||||
tracing::warn!("failed to deserialize cancel state: {e}");
|
||||
CancelError::InternalError
|
||||
})?;
|
||||
Some(cancel_closure)
|
||||
}
|
||||
None => None,
|
||||
};
|
||||
Ok(cancel_state)
|
||||
Ok(Some(cancel_closure))
|
||||
}
|
||||
|
||||
/// Try to cancel a running query for the corresponding connection.
|
||||
/// If the cancellation key is not found, it will be published to Redis.
|
||||
/// check_allowed - if true, check if the IP is allowed to cancel the query.
|
||||
@@ -467,10 +339,10 @@ impl CancellationHandler {
|
||||
/// This should've been a [`std::future::Future`], but
|
||||
/// it's impossible to name a type of an unboxed future
|
||||
/// (we'd need something like `#![feature(type_alias_impl_trait)]`).
|
||||
#[derive(Clone, Serialize, Deserialize)]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct CancelClosure {
|
||||
socket_addr: SocketAddr,
|
||||
cancel_token: CancelToken,
|
||||
cancel_token: RawCancelToken,
|
||||
hostname: String, // for pg_sni router
|
||||
user_info: ComputeUserInfo,
|
||||
}
|
||||
@@ -478,7 +350,7 @@ pub struct CancelClosure {
|
||||
impl CancelClosure {
|
||||
pub(crate) fn new(
|
||||
socket_addr: SocketAddr,
|
||||
cancel_token: CancelToken,
|
||||
cancel_token: RawCancelToken,
|
||||
hostname: String,
|
||||
user_info: ComputeUserInfo,
|
||||
) -> Self {
|
||||
@@ -491,7 +363,7 @@ impl CancelClosure {
|
||||
}
|
||||
/// Cancels the query running on user's compute node.
|
||||
pub(crate) async fn try_cancel_query(
|
||||
self,
|
||||
&self,
|
||||
compute_config: &ComputeConfig,
|
||||
) -> Result<(), CancelError> {
|
||||
let socket = TcpStream::connect(self.socket_addr).await?;
|
||||
@@ -512,7 +384,6 @@ impl CancelClosure {
|
||||
pub(crate) struct Session {
|
||||
/// The user-facing key identifying this session.
|
||||
key: CancelKeyData,
|
||||
redis_key: String,
|
||||
cancellation_handler: Arc<CancellationHandler>,
|
||||
}
|
||||
|
||||
@@ -521,60 +392,66 @@ impl Session {
|
||||
&self.key
|
||||
}
|
||||
|
||||
// Send the store key op to the cancellation handler and set TTL for the key
|
||||
pub(crate) fn write_cancel_key(
|
||||
/// Ensure the cancel key is continuously refreshed,
|
||||
/// but stop when the channel is dropped.
|
||||
pub(crate) async fn maintain_cancel_key(
|
||||
&self,
|
||||
cancel_closure: CancelClosure,
|
||||
) -> Result<(), CancelError> {
|
||||
let Some(tx) = &self.cancellation_handler.tx else {
|
||||
tracing::warn!("cancellation handler is not available");
|
||||
return Err(CancelError::InternalError);
|
||||
};
|
||||
session_id: uuid::Uuid,
|
||||
cancel: tokio::sync::oneshot::Receiver<Infallible>,
|
||||
cancel_closure: &CancelClosure,
|
||||
compute_config: &ComputeConfig,
|
||||
) {
|
||||
futures::future::select(
|
||||
std::pin::pin!(self.maintain_redis_cancel_key(cancel_closure)),
|
||||
cancel,
|
||||
)
|
||||
.await;
|
||||
|
||||
let closure_json = serde_json::to_string(&cancel_closure).map_err(|e| {
|
||||
tracing::warn!("failed to serialize cancel closure: {e}");
|
||||
CancelError::InternalError
|
||||
})?;
|
||||
|
||||
let op = CancelKeyOp::StoreCancelKey {
|
||||
key: self.redis_key.clone(),
|
||||
field: "data".to_string(),
|
||||
value: closure_json,
|
||||
resp_tx: None,
|
||||
_guard: Metrics::get()
|
||||
.proxy
|
||||
.cancel_channel_size
|
||||
.guard(RedisMsgKind::HSet),
|
||||
expire: CANCEL_KEY_TTL,
|
||||
};
|
||||
|
||||
let _ = tx.try_send(op).map_err(|e| {
|
||||
let key = self.key;
|
||||
tracing::warn!("failed to send StoreCancelKey for {key}: {e}");
|
||||
});
|
||||
Ok(())
|
||||
if let Err(err) = cancel_closure
|
||||
.try_cancel_query(compute_config)
|
||||
.boxed()
|
||||
.await
|
||||
{
|
||||
tracing::warn!(
|
||||
?session_id,
|
||||
?err,
|
||||
"could not cancel the query in the database"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn remove_cancel_key(&self) -> Result<(), CancelError> {
|
||||
let Some(tx) = &self.cancellation_handler.tx else {
|
||||
// Ensure the cancel key is continuously refreshed.
|
||||
async fn maintain_redis_cancel_key(&self, cancel_closure: &CancelClosure) -> ! {
|
||||
let Some(tx) = self.cancellation_handler.tx.get() else {
|
||||
tracing::warn!("cancellation handler is not available");
|
||||
return Err(CancelError::InternalError);
|
||||
// don't exit, as we only want to exit if cancelled externally.
|
||||
std::future::pending().await
|
||||
};
|
||||
|
||||
let op = CancelKeyOp::RemoveCancelKey {
|
||||
key: self.redis_key.clone(),
|
||||
field: "data".to_string(),
|
||||
resp_tx: None,
|
||||
_guard: Metrics::get()
|
||||
let closure_json = serde_json::to_string(&cancel_closure)
|
||||
.expect("serialising to json string should not fail")
|
||||
.into_boxed_str();
|
||||
|
||||
loop {
|
||||
let guard = Metrics::get()
|
||||
.proxy
|
||||
.cancel_channel_size
|
||||
.guard(RedisMsgKind::HDel),
|
||||
};
|
||||
.guard(RedisMsgKind::HSet);
|
||||
let op = CancelKeyOp::StoreCancelKey {
|
||||
key: self.key,
|
||||
value: closure_json.clone(),
|
||||
expire: CANCEL_KEY_TTL,
|
||||
};
|
||||
|
||||
let _ = tx.try_send(op).map_err(|e| {
|
||||
let key = self.key;
|
||||
tracing::warn!("failed to send RemoveCancelKey for {key}: {e}");
|
||||
});
|
||||
Ok(())
|
||||
tracing::debug!(
|
||||
src=%self.key,
|
||||
dest=?cancel_closure.cancel_token,
|
||||
"registering cancellation key"
|
||||
);
|
||||
|
||||
if tx.call((guard, op)).await.is_ok() {
|
||||
tokio::time::sleep(CANCEL_KEY_REFRESH).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
146
proxy/src/compute/authenticate.rs
Normal file
146
proxy/src/compute/authenticate.rs
Normal file
@@ -0,0 +1,146 @@
|
||||
use bytes::BufMut;
|
||||
use postgres_client::tls::{ChannelBinding, TlsStream};
|
||||
use postgres_protocol::authentication::sasl;
|
||||
use postgres_protocol::authentication::sasl::{SCRAM_SHA_256, SCRAM_SHA_256_PLUS};
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
|
||||
use super::{Auth, MaybeRustlsStream};
|
||||
use crate::compute::RustlsStream;
|
||||
use crate::pqproto::{
|
||||
AUTH_OK, AUTH_SASL, AUTH_SASL_CONT, AUTH_SASL_FINAL, FE_PASSWORD_MESSAGE, StartupMessageParams,
|
||||
};
|
||||
use crate::stream::{PostgresError, PqBeStream};
|
||||
|
||||
/// Runs the authentication exchange with a compute node over `stream`.
///
/// The stream is wrapped in a [`PqBeStream`] built from `params` and flushed
/// (presumably this sends the startup message — confirm in `PqBeStream::new`).
/// The first authentication message then decides the flow:
///
/// - `AUTH_OK`: the wrapped stream is returned immediately.
/// - `AUTH_SASL`: a SCRAM mechanism is chosen from the offered list and the
///   stream's channel binding, the SCRAM exchange is driven with the
///   credentials in `auth` until `AUTH_SASL_FINAL`, and a closing `AUTH_OK`
///   is required before the stream is returned.
///
/// # Errors
///
/// Returns [`PostgresError::InvalidAuthMessage`] when an unexpected auth
/// message is received, when `SCRAM-SHA-256` is not offered and the
/// `SCRAM-SHA-256-PLUS` arm does not apply, or when SASL is requested but
/// `auth` is `None`. Errors from reading, flushing and the SCRAM steps are
/// propagated with `?`.
pub async fn authenticate<S>(
|
||||
stream: MaybeRustlsStream<S>,
|
||||
auth: Option<&Auth>,
|
||||
params: &StartupMessageParams,
|
||||
) -> Result<PqBeStream<MaybeRustlsStream<S>>, PostgresError>
|
||||
where
|
||||
S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
|
||||
RustlsStream<S>: TlsStream + Unpin,
|
||||
{
|
||||
let mut stream = PqBeStream::new(stream, params);
|
||||
stream.flush().await?;
|
||||
|
||||
// Channel binding data of the underlying stream; used below to pick the SCRAM variant.
let channel_binding = stream.get_ref().channel_binding();
|
||||
|
||||
// TODO: rather than checking for SASL, maybe we can just assume it.
|
||||
// With SCRAM_SHA_256 if we're not using TLS,
|
||||
// and SCRAM_SHA_256_PLUS if we are using TLS.
|
||||
|
||||
let (channel_binding, mechanism) = match stream.read_auth_message().await? {
|
||||
(AUTH_OK, _) => return Ok(stream),
|
||||
(AUTH_SASL, mechanisms) => {
|
||||
let mut has_scram = false;
|
||||
let mut has_scram_plus = false;
|
||||
// The mechanism list is split on NUL bytes; only the two SCRAM variants are recognised.
for mechanism in mechanisms.split(|&b| b == b'\0') {
|
||||
match mechanism {
|
||||
b"SCRAM-SHA-256" => has_scram = true,
|
||||
b"SCRAM-SHA-256-PLUS" => has_scram_plus = true,
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
match (channel_binding, has_scram, has_scram_plus) {
|
||||
// Only plain SCRAM is offered: use it without requesting channel binding.
(cb, true, false) => {
|
||||
if cb.tls_server_end_point.is_some() {
|
||||
// I don't think this can happen in our setup, but I would like to monitor it.
|
||||
tracing::warn!(
|
||||
"TLS is enabled, but compute doesn't support SCRAM-SHA-256-PLUS."
|
||||
);
|
||||
}
|
||||
(sasl::ChannelBinding::unrequested(), SCRAM_SHA_256)
|
||||
}
|
||||
// The stream has no `tls_server_end_point`: plain SCRAM, channel binding unsupported.
(
|
||||
ChannelBinding {
|
||||
tls_server_end_point: None,
|
||||
},
|
||||
true,
|
||||
_,
|
||||
) => (sasl::ChannelBinding::unsupported(), SCRAM_SHA_256),
|
||||
// `tls_server_end_point` is available and the PLUS variant is offered: bind to it.
(
|
||||
ChannelBinding {
|
||||
tls_server_end_point: Some(h),
|
||||
},
|
||||
_,
|
||||
true,
|
||||
) => (
|
||||
sasl::ChannelBinding::tls_server_end_point(h),
|
||||
SCRAM_SHA_256_PLUS,
|
||||
),
|
||||
// `SCRAM-SHA-256` was not offered and the PLUS arm above did not match.
(_, false, _) => {
|
||||
tracing::error!(
|
||||
"compute responded with unsupported auth mechanisms: {}",
|
||||
String::from_utf8_lossy(mechanisms)
|
||||
);
|
||||
return Err(PostgresError::InvalidAuthMessage);
|
||||
}
|
||||
}
|
||||
}
|
||||
(tag, msg) => {
|
||||
tracing::error!(
|
||||
"compute responded with unexpected auth message with tag[{tag}]: {}",
|
||||
String::from_utf8_lossy(msg)
|
||||
);
|
||||
return Err(PostgresError::InvalidAuthMessage);
|
||||
}
|
||||
};
|
||||
|
||||
let mut scram = match auth {
|
||||
// We only touch passwords when it comes to console-redirect.
|
||||
Some(Auth::Password(pw)) => sasl::ScramSha256::new(pw, channel_binding),
|
||||
Some(Auth::Scram(keys)) => sasl::ScramSha256::new_with_keys(**keys, channel_binding),
|
||||
None => {
|
||||
// local_proxy does not set credentials, since it relies on trust and expects an OK message above
|
||||
tracing::error!("compute requested SASL auth, but there are no credentials available",);
|
||||
return Err(PostgresError::InvalidAuthMessage);
|
||||
}
|
||||
};
|
||||
|
||||
// Initial SASL response: mechanism name, NUL terminator, then the
// length-prefixed current SCRAM message.
stream.write_raw(0, FE_PASSWORD_MESSAGE.0, |buf| {
|
||||
buf.put_slice(mechanism.as_bytes());
|
||||
buf.put_u8(b'\0');
|
||||
|
||||
let data = scram.message();
|
||||
// NOTE(review): `as u32` would truncate a message longer than u32::MAX bytes;
// presumably SCRAM messages never get that large — confirm.
buf.put_u32(data.len() as u32);
|
||||
buf.put_slice(data);
|
||||
});
|
||||
stream.flush().await?;
|
||||
|
||||
loop {
|
||||
// wait for SASLContinue or SASLFinal.
|
||||
match stream.read_auth_message().await? {
|
||||
(AUTH_SASL_CONT, data) => scram.update(data).await?,
|
||||
(AUTH_SASL_FINAL, data) => {
|
||||
scram.finish(data)?;
|
||||
break;
|
||||
}
|
||||
(tag, msg) => {
|
||||
tracing::error!(
|
||||
"compute responded with unexpected auth message with tag[{tag}]: {}",
|
||||
String::from_utf8_lossy(msg)
|
||||
);
|
||||
return Err(PostgresError::InvalidAuthMessage);
|
||||
}
|
||||
}
|
||||
|
||||
// Only reached after SASLContinue: reply with the next SCRAM message.
stream.write_raw(0, FE_PASSWORD_MESSAGE.0, |buf| {
|
||||
buf.put_slice(scram.message());
|
||||
});
|
||||
stream.flush().await?;
|
||||
}
|
||||
|
||||
// After SASLFinal, the only accepted auth message is AUTH_OK.
match stream.read_auth_message().await? {
|
||||
(AUTH_OK, _) => {}
|
||||
(tag, msg) => {
|
||||
tracing::error!(
|
||||
"compute responded with unexpected auth message with tag[{tag}]: {}",
|
||||
String::from_utf8_lossy(msg)
|
||||
);
|
||||
return Err(PostgresError::InvalidAuthMessage);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(stream)
|
||||
}
|
||||
@@ -1,3 +1,4 @@
|
||||
mod authenticate;
|
||||
mod tls;
|
||||
|
||||
use std::fmt::Debug;
|
||||
@@ -9,15 +10,12 @@ use itertools::Itertools;
|
||||
use postgres_client::config::{AuthKeys, SslMode};
|
||||
use postgres_client::maybe_tls_stream::MaybeTlsStream;
|
||||
use postgres_client::tls::MakeTlsConnect;
|
||||
use postgres_client::{CancelToken, NoTls, RawConnection};
|
||||
use postgres_protocol::message::backend::NoticeResponseBody;
|
||||
use thiserror::Error;
|
||||
use tokio::net::{TcpStream, lookup_host};
|
||||
use tracing::{debug, error, info, warn};
|
||||
|
||||
use crate::auth::backend::{ComputeCredentialKeys, ComputeUserInfo};
|
||||
use crate::auth::backend::ComputeCredentialKeys;
|
||||
use crate::auth::parse_endpoint_param;
|
||||
use crate::cancellation::CancelClosure;
|
||||
use crate::compute::tls::TlsError;
|
||||
use crate::config::ComputeConfig;
|
||||
use crate::context::RequestContext;
|
||||
@@ -28,6 +26,7 @@ use crate::error::{ReportableError, UserFacingError};
|
||||
use crate::metrics::{Metrics, NumDbConnectionsGuard};
|
||||
use crate::pqproto::StartupMessageParams;
|
||||
use crate::proxy::neon_option;
|
||||
use crate::stream::{PostgresError, PqBeStream};
|
||||
use crate::types::Host;
|
||||
|
||||
pub const COULD_NOT_CONNECT: &str = "Couldn't connect to compute node";
|
||||
@@ -37,7 +36,7 @@ pub(crate) enum ConnectionError {
|
||||
/// This error doesn't seem to reveal any secrets; for instance,
|
||||
/// `postgres_client::error::Kind` doesn't contain ip addresses and such.
|
||||
#[error("{COULD_NOT_CONNECT}: {0}")]
|
||||
Postgres(#[from] postgres_client::Error),
|
||||
Postgres(#[from] PostgresError),
|
||||
|
||||
#[error("{COULD_NOT_CONNECT}: {0}")]
|
||||
TlsError(#[from] TlsError),
|
||||
@@ -54,20 +53,21 @@ impl UserFacingError for ConnectionError {
|
||||
match self {
|
||||
// This helps us drop irrelevant library-specific prefixes.
|
||||
// TODO: propagate severity level and other parameters.
|
||||
ConnectionError::Postgres(err) => match err.as_db_error() {
|
||||
Some(err) => {
|
||||
let msg = err.message();
|
||||
ConnectionError::Postgres(PostgresError::Error(err)) => {
|
||||
let (_code, msg) = err.parse();
|
||||
let msg = String::from_utf8_lossy(msg);
|
||||
|
||||
if msg.starts_with("unsupported startup parameter: ")
|
||||
|| msg.starts_with("unsupported startup parameter in options: ")
|
||||
{
|
||||
format!("{msg}. Please use unpooled connection or remove this parameter from the startup package. More details: https://neon.tech/docs/connect/connection-errors#unsupported-startup-parameter")
|
||||
} else {
|
||||
msg.to_owned()
|
||||
}
|
||||
if msg.starts_with("unsupported startup parameter: ")
|
||||
|| msg.starts_with("unsupported startup parameter in options: ")
|
||||
{
|
||||
format!(
|
||||
"{msg}. Please use unpooled connection or remove this parameter from the startup package. More details: https://neon.tech/docs/connect/connection-errors#unsupported-startup-parameter"
|
||||
)
|
||||
} else {
|
||||
msg.into_owned()
|
||||
}
|
||||
None => err.to_string(),
|
||||
},
|
||||
}
|
||||
ConnectionError::Postgres(err) => err.to_string(),
|
||||
ConnectionError::WakeComputeError(err) => err.to_string_client(),
|
||||
ConnectionError::TooManyConnectionAttempts(_) => {
|
||||
"Failed to acquire permit to connect to the database. Too many database connection attempts are currently ongoing.".to_owned()
|
||||
@@ -80,10 +80,12 @@ impl UserFacingError for ConnectionError {
|
||||
impl ReportableError for ConnectionError {
|
||||
fn get_error_kind(&self) -> crate::error::ErrorKind {
|
||||
match self {
|
||||
ConnectionError::Postgres(e) if e.as_db_error().is_some() => {
|
||||
crate::error::ErrorKind::Postgres
|
||||
}
|
||||
ConnectionError::Postgres(_) => crate::error::ErrorKind::Compute,
|
||||
ConnectionError::Postgres(PostgresError::Io(_)) => crate::error::ErrorKind::Compute,
|
||||
ConnectionError::Postgres(
|
||||
PostgresError::Error(_)
|
||||
| PostgresError::InvalidAuthMessage
|
||||
| PostgresError::Unexpected(_),
|
||||
) => crate::error::ErrorKind::Postgres,
|
||||
ConnectionError::TlsError(_) => crate::error::ErrorKind::Compute,
|
||||
ConnectionError::WakeComputeError(e) => e.get_error_kind(),
|
||||
ConnectionError::TooManyConnectionAttempts(e) => e.get_error_kind(),
|
||||
@@ -162,18 +164,6 @@ impl ConnectInfo {
|
||||
}
|
||||
|
||||
impl AuthInfo {
|
||||
fn enrich(&self, mut config: postgres_client::Config) -> postgres_client::Config {
|
||||
match &self.auth {
|
||||
Some(Auth::Scram(keys)) => config.auth_keys(AuthKeys::ScramSha256(**keys)),
|
||||
Some(Auth::Password(pw)) => config.password(pw),
|
||||
None => &mut config,
|
||||
};
|
||||
for (k, v) in self.server_params.iter() {
|
||||
config.set_param(k, v);
|
||||
}
|
||||
config
|
||||
}
|
||||
|
||||
/// Apply startup message params to the connection config.
|
||||
pub(crate) fn set_startup_params(
|
||||
&mut self,
|
||||
@@ -213,7 +203,7 @@ impl ConnectInfo {
|
||||
async fn connect_raw(
|
||||
&self,
|
||||
config: &ComputeConfig,
|
||||
) -> Result<(SocketAddr, MaybeTlsStream<TcpStream, RustlsStream>), TlsError> {
|
||||
) -> Result<(SocketAddr, MaybeRustlsStream<TcpStream>), TlsError> {
|
||||
let timeout = config.timeout;
|
||||
|
||||
// wrap TcpStream::connect with timeout
|
||||
@@ -265,21 +255,19 @@ impl ConnectInfo {
|
||||
}
|
||||
}
|
||||
|
||||
type RustlsStream = <ComputeConfig as MakeTlsConnect<tokio::net::TcpStream>>::Stream;
|
||||
pub type RustlsStream<S> = <ComputeConfig as MakeTlsConnect<S>>::Stream;
|
||||
pub type MaybeRustlsStream<S> = MaybeTlsStream<S, RustlsStream<S>>;
|
||||
|
||||
pub(crate) struct PostgresConnection {
|
||||
pub struct PostgresConnection {
|
||||
/// Socket connected to a compute node.
|
||||
pub(crate) stream: MaybeTlsStream<tokio::net::TcpStream, RustlsStream>,
|
||||
/// PostgreSQL connection parameters.
|
||||
pub(crate) params: std::collections::HashMap<String, String>,
|
||||
/// Query cancellation token.
|
||||
pub(crate) cancel_closure: CancelClosure,
|
||||
/// Labels for proxy's metrics.
|
||||
pub(crate) aux: MetricsAuxInfo,
|
||||
/// Notices received from compute after authenticating
|
||||
pub(crate) delayed_notice: Vec<NoticeResponseBody>,
|
||||
pub stream: PqBeStream<MaybeRustlsStream<TcpStream>>,
|
||||
|
||||
_guage: NumDbConnectionsGuard<'static>,
|
||||
pub socket_addr: SocketAddr,
|
||||
pub hostname: String,
|
||||
pub ssl_mode: SslMode,
|
||||
pub aux: MetricsAuxInfo,
|
||||
|
||||
pub guage: NumDbConnectionsGuard<'static>,
|
||||
}
|
||||
|
||||
impl ConnectInfo {
|
||||
@@ -287,31 +275,18 @@ impl ConnectInfo {
|
||||
pub(crate) async fn connect(
|
||||
&self,
|
||||
ctx: &RequestContext,
|
||||
aux: MetricsAuxInfo,
|
||||
aux: &MetricsAuxInfo,
|
||||
auth: &AuthInfo,
|
||||
config: &ComputeConfig,
|
||||
user_info: ComputeUserInfo,
|
||||
) -> Result<PostgresConnection, ConnectionError> {
|
||||
let mut tmp_config = auth.enrich(self.to_postgres_client_config());
|
||||
// we setup SSL early in `ConnectInfo::connect_raw`.
|
||||
tmp_config.ssl_mode(SslMode::Disable);
|
||||
|
||||
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
|
||||
let (socket_addr, stream) = self.connect_raw(config).await?;
|
||||
let connection = tmp_config.connect_raw(stream, NoTls).await?;
|
||||
let stream =
|
||||
authenticate::authenticate(stream, auth.auth.as_ref(), &auth.server_params).await?;
|
||||
drop(pause);
|
||||
|
||||
let RawConnection {
|
||||
stream,
|
||||
parameters,
|
||||
delayed_notice,
|
||||
process_id,
|
||||
secret_key,
|
||||
} = connection;
|
||||
|
||||
tracing::Span::current().record("pid", tracing::field::display(process_id));
|
||||
// tracing::Span::current().record("pid", tracing::field::display(process_id));
|
||||
tracing::Span::current().record("compute_id", tracing::field::display(&aux.compute_id));
|
||||
let MaybeTlsStream::Raw(stream) = stream.into_inner();
|
||||
|
||||
// TODO: lots of useful info but maybe we can move it elsewhere (eg traces?)
|
||||
info!(
|
||||
@@ -323,27 +298,13 @@ impl ConnectInfo {
|
||||
ctx.get_testodrome_id().unwrap_or_default(),
|
||||
);
|
||||
|
||||
// NB: CancelToken is supposed to hold socket_addr, but we use connect_raw.
|
||||
// Yet another reason to rework the connection establishing code.
|
||||
let cancel_closure = CancelClosure::new(
|
||||
socket_addr,
|
||||
CancelToken {
|
||||
socket_config: None,
|
||||
ssl_mode: self.ssl_mode,
|
||||
process_id,
|
||||
secret_key,
|
||||
},
|
||||
self.host.to_string(),
|
||||
user_info,
|
||||
);
|
||||
|
||||
let connection = PostgresConnection {
|
||||
stream,
|
||||
params: parameters,
|
||||
delayed_notice,
|
||||
cancel_closure,
|
||||
aux,
|
||||
_guage: Metrics::get().proxy.db_connections.guard(ctx.protocol()),
|
||||
socket_addr,
|
||||
hostname: self.host.to_string(),
|
||||
ssl_mode: self.ssl_mode,
|
||||
aux: aux.clone(),
|
||||
guage: Metrics::get().proxy.db_connections.guard(ctx.protocol()),
|
||||
};
|
||||
|
||||
Ok(connection)
|
||||
|
||||
@@ -120,7 +120,7 @@ pub async fn task_main(
|
||||
Ok(Some(p)) => {
|
||||
ctx.set_success();
|
||||
let _disconnect = ctx.log_connect();
|
||||
match p.proxy_pass(&config.connect_to_compute).await {
|
||||
match p.proxy_pass().await {
|
||||
Ok(()) => {}
|
||||
Err(ErrorSource::Client(e)) => {
|
||||
error!(
|
||||
@@ -177,7 +177,7 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
|
||||
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
|
||||
let do_handshake = handshake(ctx, stream, tls, record_handshake_error);
|
||||
|
||||
let (mut stream, params) = match tokio::time::timeout(config.handshake_timeout, do_handshake)
|
||||
let (mut client, params) = match tokio::time::timeout(config.handshake_timeout, do_handshake)
|
||||
.await??
|
||||
{
|
||||
HandshakeData::Startup(stream, params) => (stream, params),
|
||||
@@ -210,18 +210,17 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
|
||||
ctx.set_db_options(params.clone());
|
||||
|
||||
let (node_info, mut auth_info, user_info) = match backend
|
||||
.authenticate(ctx, &config.authentication_config, &mut stream)
|
||||
.authenticate(ctx, &config.authentication_config, &mut client)
|
||||
.await
|
||||
{
|
||||
Ok(auth_result) => auth_result,
|
||||
Err(e) => Err(stream.throw_error(e, Some(ctx)).await)?,
|
||||
Err(e) => Err(client.throw_error(e, Some(ctx)).await)?,
|
||||
};
|
||||
auth_info.set_startup_params(¶ms, true);
|
||||
|
||||
let node = connect_to_compute(
|
||||
let mut node = connect_to_compute(
|
||||
ctx,
|
||||
&TcpMechanism {
|
||||
user_info,
|
||||
auth: auth_info,
|
||||
locks: &config.connect_compute_locks,
|
||||
},
|
||||
@@ -229,25 +228,41 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
|
||||
config.wake_compute_retry_config,
|
||||
&config.connect_to_compute,
|
||||
)
|
||||
.or_else(|e| async { Err(stream.throw_error(e, Some(ctx)).await) })
|
||||
.or_else(|e| async { Err(client.throw_error(e, Some(ctx)).await) })
|
||||
.await?;
|
||||
|
||||
let cancellation_handler_clone = Arc::clone(&cancellation_handler);
|
||||
let session = cancellation_handler_clone.get_key();
|
||||
let session = cancellation_handler.get_key();
|
||||
|
||||
session.write_cancel_key(node.cancel_closure.clone())?;
|
||||
let cancel_closure =
|
||||
prepare_client_connection(&mut node, session.key(), &mut client, user_info).await?;
|
||||
|
||||
prepare_client_connection(&node, *session.key(), &mut stream);
|
||||
let stream = stream.flush_and_into_inner().await?;
|
||||
let session_id = ctx.session_id();
|
||||
let (cancel_on_shutdown, cancel) = tokio::sync::oneshot::channel();
|
||||
tokio::spawn(async move {
|
||||
session
|
||||
.maintain_cancel_key(
|
||||
session_id,
|
||||
cancel,
|
||||
&cancel_closure,
|
||||
&config.connect_to_compute,
|
||||
)
|
||||
.await;
|
||||
});
|
||||
|
||||
let client = client.flush_and_into_inner().await?;
|
||||
let compute = node.stream.flush_and_into_inner().await?;
|
||||
|
||||
Ok(Some(ProxyPassthrough {
|
||||
client: stream,
|
||||
aux: node.aux.clone(),
|
||||
client,
|
||||
compute,
|
||||
|
||||
aux: node.aux,
|
||||
private_link_id: None,
|
||||
compute: node,
|
||||
session_id: ctx.session_id(),
|
||||
cancel: session,
|
||||
|
||||
_cancel_on_shutdown: cancel_on_shutdown,
|
||||
|
||||
_req: request_gauge,
|
||||
_conn: conn_gauge,
|
||||
_db_conn: node.guage,
|
||||
}))
|
||||
}
|
||||
|
||||
@@ -78,11 +78,8 @@ impl NodeInfo {
|
||||
ctx: &RequestContext,
|
||||
auth: &compute::AuthInfo,
|
||||
config: &ComputeConfig,
|
||||
user_info: ComputeUserInfo,
|
||||
) -> Result<compute::PostgresConnection, compute::ConnectionError> {
|
||||
self.conn_info
|
||||
.connect(ctx, self.aux.clone(), auth, config, user_info)
|
||||
.await
|
||||
self.conn_info.connect(ctx, &self.aux, auth, config).await
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -75,6 +75,7 @@
|
||||
pub mod binary;
|
||||
|
||||
mod auth;
|
||||
mod batch;
|
||||
mod cache;
|
||||
mod cancellation;
|
||||
mod compute;
|
||||
|
||||
@@ -12,7 +12,7 @@ use crate::pqproto::{
|
||||
BeMessage, CancelKeyData, FeStartupPacket, ProtocolVersion, StartupMessageParams,
|
||||
};
|
||||
use crate::proxy::TlsRequired;
|
||||
use crate::stream::{PqStream, Stream, StreamUpgradeError};
|
||||
use crate::stream::{PqFeStream, Stream, StreamUpgradeError};
|
||||
use crate::tls::PG_ALPN_PROTOCOL;
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
@@ -49,7 +49,7 @@ impl ReportableError for HandshakeError {
|
||||
}
|
||||
|
||||
pub(crate) enum HandshakeData<S> {
|
||||
Startup(PqStream<Stream<S>>, StartupMessageParams),
|
||||
Startup(PqFeStream<Stream<S>>, StartupMessageParams),
|
||||
Cancel(CancelKeyData),
|
||||
}
|
||||
|
||||
@@ -70,7 +70,7 @@ pub(crate) async fn handshake<S: AsyncRead + AsyncWrite + Unpin + Send>(
|
||||
const PG_PROTOCOL_EARLIEST: ProtocolVersion = ProtocolVersion::new(3, 0);
|
||||
const PG_PROTOCOL_LATEST: ProtocolVersion = ProtocolVersion::new(3, 0);
|
||||
|
||||
let (mut stream, mut msg) = PqStream::parse_startup(Stream::from_raw(stream)).await?;
|
||||
let (mut stream, mut msg) = PqFeStream::parse_startup(Stream::from_raw(stream)).await?;
|
||||
loop {
|
||||
match msg {
|
||||
FeStartupPacket::SslRequest { direct } => match stream.get_ref() {
|
||||
@@ -152,7 +152,7 @@ pub(crate) async fn handshake<S: AsyncRead + AsyncWrite + Unpin + Send>(
|
||||
tls: tls_stream,
|
||||
tls_server_end_point,
|
||||
};
|
||||
(stream, msg) = PqStream::parse_startup(tls).await?;
|
||||
(stream, msg) = PqFeStream::parse_startup(tls).await?;
|
||||
} else {
|
||||
if direct.is_some() {
|
||||
// client sent us a ClientHello already, we can't do anything with it.
|
||||
|
||||
@@ -1,15 +1,18 @@
|
||||
use futures::FutureExt;
|
||||
use std::convert::Infallible;
|
||||
|
||||
use smol_str::SmolStr;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tokio::net::TcpStream;
|
||||
use tracing::debug;
|
||||
use utils::measured_stream::MeasuredStream;
|
||||
|
||||
use super::copy_bidirectional::ErrorSource;
|
||||
use crate::cancellation;
|
||||
use crate::compute::PostgresConnection;
|
||||
use crate::config::ComputeConfig;
|
||||
use crate::compute::MaybeRustlsStream;
|
||||
use crate::control_plane::messages::MetricsAuxInfo;
|
||||
use crate::metrics::{Direction, Metrics, NumClientConnectionsGuard, NumConnectionRequestsGuard};
|
||||
use crate::metrics::{
|
||||
Direction, Metrics, NumClientConnectionsGuard, NumConnectionRequestsGuard,
|
||||
NumDbConnectionsGuard,
|
||||
};
|
||||
use crate::stream::Stream;
|
||||
use crate::usage_metrics::{Ids, MetricCounterRecorder, USAGE_METRICS};
|
||||
|
||||
@@ -64,40 +67,19 @@ pub(crate) async fn proxy_pass(
|
||||
|
||||
pub(crate) struct ProxyPassthrough<S> {
|
||||
pub(crate) client: Stream<S>,
|
||||
pub(crate) compute: PostgresConnection,
|
||||
pub(crate) compute: MaybeRustlsStream<TcpStream>,
|
||||
pub(crate) aux: MetricsAuxInfo,
|
||||
pub(crate) session_id: uuid::Uuid,
|
||||
pub(crate) private_link_id: Option<SmolStr>,
|
||||
pub(crate) cancel: cancellation::Session,
|
||||
|
||||
pub(crate) _cancel_on_shutdown: tokio::sync::oneshot::Sender<Infallible>,
|
||||
|
||||
pub(crate) _req: NumConnectionRequestsGuard<'static>,
|
||||
pub(crate) _conn: NumClientConnectionsGuard<'static>,
|
||||
pub(crate) _db_conn: NumDbConnectionsGuard<'static>,
|
||||
}
|
||||
|
||||
impl<S: AsyncRead + AsyncWrite + Unpin> ProxyPassthrough<S> {
|
||||
pub(crate) async fn proxy_pass(
|
||||
self,
|
||||
compute_config: &ComputeConfig,
|
||||
) -> Result<(), ErrorSource> {
|
||||
let res = proxy_pass(
|
||||
self.client,
|
||||
self.compute.stream,
|
||||
self.aux,
|
||||
self.private_link_id,
|
||||
)
|
||||
.await;
|
||||
if let Err(err) = self
|
||||
.compute
|
||||
.cancel_closure
|
||||
.try_cancel_query(compute_config)
|
||||
.boxed()
|
||||
.await
|
||||
{
|
||||
tracing::warn!(session_id = ?self.session_id, ?err, "could not cancel the query in the database");
|
||||
}
|
||||
|
||||
drop(self.cancel.remove_cancel_key()); // we don't need a result. If the queue is full, we just log the error
|
||||
|
||||
res
|
||||
pub(crate) async fn proxy_pass(self) -> Result<(), ErrorSource> {
|
||||
proxy_pass(self.client, self.compute, self.aux, self.private_link_id).await
|
||||
}
|
||||
}
|
||||
|
||||
@@ -11,11 +11,86 @@ use rand::distributions::{Distribution, Standard};
|
||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
||||
use zerocopy::{FromBytes, Immutable, IntoBytes, big_endian};
|
||||
|
||||
pub type ErrorCode = [u8; 5];
|
||||
/// A five-byte SQLSTATE error code (e.g. `XX000`).
///
/// `Eq` is derived alongside `PartialEq` so that the constants below can be
/// used as `match` patterns on every supported toolchain.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub struct ErrorCode(pub [u8; 5]);

/// One-byte type tag of a message sent by the frontend (client).
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub struct FeTag(pub u8);

/// One-byte type tag of a message sent by the backend (server).
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub struct BeTag(pub u8);

/// Numeric sub-type carried inside an authentication (`R`) message.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub struct AuthTag(pub i32);

// Frontend message tags.
pub const FE_PASSWORD_MESSAGE: FeTag = FeTag(b'p');

// Backend message tags.
pub const BE_AUTH_MESSAGE: BeTag = BeTag(b'R');
pub const BE_ERR_MESSAGE: BeTag = BeTag(b'E');
pub const BE_KEY_MESSAGE: BeTag = BeTag(b'K');
pub const BE_READY_MESSAGE: BeTag = BeTag(b'Z');
pub const BE_NEGOTIATE_MESSAGE: BeTag = BeTag(b'v');

// Authentication sub-types.
pub const AUTH_OK: AuthTag = AuthTag(0);
pub const AUTH_SASL: AuthTag = AuthTag(10);
pub const AUTH_SASL_CONT: AuthTag = AuthTag(11);
pub const AUTH_SASL_FINAL: AuthTag = AuthTag(12);

// SQLSTATE codes.
pub const SQLSTATE_INTERNAL_ERROR: ErrorCode = ErrorCode(*b"XX000");
pub const CONNECTION_EXCEPTION: ErrorCode = ErrorCode(*b"08000");
pub const CONNECTION_DOES_NOT_EXIST: ErrorCode = ErrorCode(*b"08003");
pub const CONNECTION_FAILURE: ErrorCode = ErrorCode(*b"08006");
pub const SQLCLIENT_UNABLE_TO_ESTABLISH_SQLCONNECTION: ErrorCode = ErrorCode(*b"08001");
pub const PROTOCOL_VIOLATION: ErrorCode = ErrorCode(*b"08P01");
pub const INVALID_PARAMETER_VALUE: ErrorCode = ErrorCode(*b"22023");
pub const INVALID_CATALOG_NAME: ErrorCode = ErrorCode(*b"3D000");
pub const INVALID_SCHEMA_NAME: ErrorCode = ErrorCode(*b"3F000");
pub const T_R_SERIALIZATION_FAILURE: ErrorCode = ErrorCode(*b"40001");
pub const SYNTAX_ERROR: ErrorCode = ErrorCode(*b"42601");
pub const OUT_OF_MEMORY: ErrorCode = ErrorCode(*b"53200");
pub const TOO_MANY_CONNECTIONS: ErrorCode = ErrorCode(*b"53300");

impl fmt::Display for AuthTag {
    /// Prints the symbolic name of a known authentication sub-type,
    /// or the raw number for anything else.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match self.0 {
            0 => "Ok",
            2 => "KerberosV5",
            3 => "CleartextPassword",
            5 => "MD5Password",
            7 => "GSS",
            8 => "GSSContinue",
            9 => "SSPI",
            10 => "SASL",
            11 => "SASLContinue",
            12 => "SASLFinal",
            other => return write!(f, "{other}"),
        };
        f.write_str(name)
    }
}

impl fmt::Display for BeTag {
    /// Prints the message name for a known backend tag,
    /// or the tag byte as a quoted character otherwise.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match *self {
            BE_AUTH_MESSAGE => "Authentication",
            BE_KEY_MESSAGE => "BackendKeyData",
            BE_ERR_MESSAGE => "ErrorResponse",
            BE_READY_MESSAGE => "ReadyForQuery",
            BE_NEGOTIATE_MESSAGE => "NegotiateProtocolVersion",
            BeTag(b'S') => "ParameterStatus",
            BeTag(b'N') => "NoticeMessage",
            BeTag(other) => return write!(f, "{:?}", char::from(other)),
        };
        f.write_str(name)
    }
}

impl fmt::Display for FeTag {
    /// Prints the message name for a known frontend tag,
    /// or the tag byte as a quoted character otherwise.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match *self {
            FE_PASSWORD_MESSAGE => f.write_str("Password"),
            FeTag(other) => write!(f, "{:?}", char::from(other)),
        }
    }
}
|
||||
|
||||
/// The protocol version number.
|
||||
///
|
||||
@@ -280,6 +355,10 @@ impl WriteBuf {
|
||||
Self(Cursor::new(Vec::new()))
|
||||
}
|
||||
|
||||
pub const fn len(&self) -> usize {
|
||||
self.0.get_ref().len()
|
||||
}
|
||||
|
||||
/// Use a heuristic to determine if we should shrink the write buffer.
|
||||
#[inline]
|
||||
fn should_shrink(&self) -> bool {
|
||||
@@ -313,6 +392,19 @@ impl WriteBuf {
|
||||
self.0.set_position(0);
|
||||
}
|
||||
|
||||
/// Write a startup message.
|
||||
pub fn startup(&mut self, params: &StartupMessageParams) {
|
||||
self.0.get_mut().extend_from_slice(
|
||||
StartupHeader {
|
||||
len: big_endian::U32::new(params.params.len() as u32 + 9),
|
||||
version: ProtocolVersion::new(3, 0),
|
||||
}
|
||||
.as_bytes(),
|
||||
);
|
||||
self.0.get_mut().extend_from_slice(params.params.as_bytes());
|
||||
self.0.get_mut().push(0);
|
||||
}
|
||||
|
||||
/// Write a raw message to the internal buffer.
|
||||
///
|
||||
/// The size_hint value is only a hint for reserving space. It's ok if it's incorrect, since
|
||||
@@ -353,7 +445,7 @@ impl WriteBuf {
|
||||
|
||||
// Code: error_code
|
||||
buf.put_u8(b'C');
|
||||
buf.put_slice(&error_code);
|
||||
buf.put_slice(&error_code.0);
|
||||
buf.put_u8(0);
|
||||
|
||||
// Message: msg
|
||||
@@ -468,11 +560,11 @@ pub enum BeMessage<'a> {
|
||||
AuthenticationOk,
|
||||
AuthenticationSasl(BeAuthenticationSaslMessage<'a>),
|
||||
AuthenticationCleartextPassword,
|
||||
BackendKeyData(CancelKeyData),
|
||||
ParameterStatus {
|
||||
name: &'a [u8],
|
||||
value: &'a [u8],
|
||||
},
|
||||
#[cfg(test)]
|
||||
ReadyForQuery,
|
||||
NoticeResponse(&'a str),
|
||||
NegotiateProtocolVersion {
|
||||
@@ -494,17 +586,17 @@ impl BeMessage<'_> {
|
||||
match self {
|
||||
// <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-AUTHENTICATIONCLEARTEXTPASSWORD>
|
||||
BeMessage::AuthenticationOk => {
|
||||
buf.write_raw(1, b'R', |buf| buf.put_i32(0));
|
||||
buf.write_raw(1, BE_AUTH_MESSAGE.0, |buf| buf.put_i32(0));
|
||||
}
|
||||
// <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-AUTHENTICATIONCLEARTEXTPASSWORD>
|
||||
BeMessage::AuthenticationCleartextPassword => {
|
||||
buf.write_raw(1, b'R', |buf| buf.put_i32(3));
|
||||
buf.write_raw(1, BE_AUTH_MESSAGE.0, |buf| buf.put_i32(3));
|
||||
}
|
||||
|
||||
// <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-AUTHENTICATIONSASL>
|
||||
BeMessage::AuthenticationSasl(BeAuthenticationSaslMessage::Methods(methods)) => {
|
||||
let len: usize = methods.iter().map(|m| m.len() + 1).sum();
|
||||
buf.write_raw(len + 2, b'R', |buf| {
|
||||
buf.write_raw(len + 2, BE_AUTH_MESSAGE.0, |buf| {
|
||||
buf.put_i32(10); // Specifies that SASL auth method is used.
|
||||
for method in methods {
|
||||
buf.put_slice(method.as_bytes());
|
||||
@@ -515,24 +607,19 @@ impl BeMessage<'_> {
|
||||
}
|
||||
// <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-AUTHENTICATIONSASL>
|
||||
BeMessage::AuthenticationSasl(BeAuthenticationSaslMessage::Continue(extra)) => {
|
||||
buf.write_raw(extra.len() + 1, b'R', |buf| {
|
||||
buf.write_raw(extra.len() + 1, BE_AUTH_MESSAGE.0, |buf| {
|
||||
buf.put_i32(11); // Continue SASL auth.
|
||||
buf.put_slice(extra);
|
||||
});
|
||||
}
|
||||
// <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-AUTHENTICATIONSASL>
|
||||
BeMessage::AuthenticationSasl(BeAuthenticationSaslMessage::Final(extra)) => {
|
||||
buf.write_raw(extra.len() + 1, b'R', |buf| {
|
||||
buf.write_raw(extra.len() + 1, BE_AUTH_MESSAGE.0, |buf| {
|
||||
buf.put_i32(12); // Send final SASL message.
|
||||
buf.put_slice(extra);
|
||||
});
|
||||
}
|
||||
|
||||
// <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-BACKENDKEYDATA>
|
||||
BeMessage::BackendKeyData(key_data) => {
|
||||
buf.write_raw(8, b'K', |buf| buf.put_slice(key_data.as_bytes()));
|
||||
}
|
||||
|
||||
// <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-NOTICERESPONSE>
|
||||
// <https://www.postgresql.org/docs/current/protocol-error-fields.html>
|
||||
BeMessage::NoticeResponse(msg) => {
|
||||
@@ -564,15 +651,16 @@ impl BeMessage<'_> {
|
||||
});
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
// <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-NEGOTIATEPROTOCOLVERSION>
|
||||
BeMessage::ReadyForQuery => {
|
||||
buf.write_raw(1, b'Z', |buf| buf.put_u8(b'I'));
|
||||
buf.write_raw(1, BE_READY_MESSAGE.0, |buf| buf.put_u8(b'I'));
|
||||
}
|
||||
|
||||
// <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-NEGOTIATEPROTOCOLVERSION>
|
||||
BeMessage::NegotiateProtocolVersion { version, options } => {
|
||||
let len: usize = options.iter().map(|o| o.len() + 1).sum();
|
||||
buf.write_raw(8 + len, b'v', |buf| {
|
||||
buf.write_raw(8 + len, BE_NEGOTIATE_MESSAGE.0, |buf| {
|
||||
buf.put_slice(version.as_bytes());
|
||||
buf.put_u32(options.len() as u32);
|
||||
for option in options {
|
||||
|
||||
@@ -2,7 +2,6 @@ use async_trait::async_trait;
|
||||
use tokio::time;
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
use crate::auth::backend::ComputeUserInfo;
|
||||
use crate::compute::{self, AuthInfo, COULD_NOT_CONNECT, PostgresConnection};
|
||||
use crate::config::{ComputeConfig, RetryConfig};
|
||||
use crate::context::RequestContext;
|
||||
@@ -53,7 +52,6 @@ pub(crate) struct TcpMechanism {
|
||||
pub(crate) auth: AuthInfo,
|
||||
/// connect_to_compute concurrency lock
|
||||
pub(crate) locks: &'static ApiLocks<Host>,
|
||||
pub(crate) user_info: ComputeUserInfo,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
@@ -73,11 +71,7 @@ impl ConnectMechanism for TcpMechanism {
|
||||
config: &ComputeConfig,
|
||||
) -> Result<PostgresConnection, Self::Error> {
|
||||
let permit = self.locks.get_permit(&node_info.conn_info.host).await?;
|
||||
permit.release_result(
|
||||
node_info
|
||||
.connect(ctx, &self.auth, config, self.user_info.clone())
|
||||
.await,
|
||||
)
|
||||
permit.release_result(node_info.connect(ctx, &self.auth, config).await)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ use std::sync::Arc;
|
||||
use futures::FutureExt;
|
||||
use itertools::Itertools;
|
||||
use once_cell::sync::OnceCell;
|
||||
use postgres_client::RawCancelToken;
|
||||
use regex::Regex;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use smol_str::{SmolStr, ToSmolStr, format_smolstr};
|
||||
@@ -18,7 +19,8 @@ use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::{Instrument, debug, error, info, warn};
|
||||
|
||||
use crate::cancellation::{self, CancellationHandler};
|
||||
use crate::auth::backend::ComputeUserInfo;
|
||||
use crate::cancellation::{self, CancelClosure, CancellationHandler};
|
||||
use crate::config::{ProxyConfig, ProxyProtocolV2, TlsConfig};
|
||||
use crate::context::RequestContext;
|
||||
use crate::error::{ReportableError, UserFacingError};
|
||||
@@ -26,11 +28,11 @@ use crate::metrics::{Metrics, NumClientConnectionsGuard};
|
||||
pub use crate::pglb::copy_bidirectional::{ErrorSource, copy_bidirectional_client_compute};
|
||||
use crate::pglb::handshake::{HandshakeData, HandshakeError, handshake};
|
||||
use crate::pglb::passthrough::ProxyPassthrough;
|
||||
use crate::pqproto::{BeMessage, CancelKeyData, StartupMessageParams};
|
||||
use crate::pqproto::{CancelKeyData, StartupMessageParams};
|
||||
use crate::protocol2::{ConnectHeader, ConnectionInfo, ConnectionInfoExtra, read_proxy_protocol};
|
||||
use crate::proxy::connect_compute::{TcpMechanism, connect_to_compute};
|
||||
use crate::rate_limiter::EndpointRateLimiter;
|
||||
use crate::stream::{PqStream, Stream};
|
||||
use crate::stream::{PostgresError, PqFeStream, Stream};
|
||||
use crate::types::EndpointCacheKey;
|
||||
use crate::util::run_until_cancelled;
|
||||
use crate::{auth, compute};
|
||||
@@ -155,7 +157,7 @@ pub async fn task_main(
|
||||
Ok(Some(p)) => {
|
||||
ctx.set_success();
|
||||
let _disconnect = ctx.log_connect();
|
||||
match p.proxy_pass(&config.connect_to_compute).await {
|
||||
match p.proxy_pass().await {
|
||||
Ok(()) => {}
|
||||
Err(ErrorSource::Client(e)) => {
|
||||
warn!(
|
||||
@@ -253,7 +255,7 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
|
||||
auth_backend: &'static auth::Backend<'static, ()>,
|
||||
ctx: &RequestContext,
|
||||
cancellation_handler: Arc<CancellationHandler>,
|
||||
stream: S,
|
||||
client: S,
|
||||
mode: ClientMode,
|
||||
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||
conn_gauge: NumClientConnectionsGuard<'static>,
|
||||
@@ -273,9 +275,9 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
|
||||
|
||||
let record_handshake_error = !ctx.has_private_peer_addr();
|
||||
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
|
||||
let do_handshake = handshake(ctx, stream, mode.handshake_tls(tls), record_handshake_error);
|
||||
let do_handshake = handshake(ctx, client, mode.handshake_tls(tls), record_handshake_error);
|
||||
|
||||
let (mut stream, params) = match tokio::time::timeout(config.handshake_timeout, do_handshake)
|
||||
let (mut client, params) = match tokio::time::timeout(config.handshake_timeout, do_handshake)
|
||||
.await??
|
||||
{
|
||||
HandshakeData::Startup(stream, params) => (stream, params),
|
||||
@@ -307,7 +309,7 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
|
||||
|
||||
ctx.set_db_options(params.clone());
|
||||
|
||||
let hostname = mode.hostname(stream.get_ref());
|
||||
let hostname = mode.hostname(client.get_ref());
|
||||
|
||||
let common_names = tls.map(|tls| &tls.common_names);
|
||||
|
||||
@@ -319,14 +321,14 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
|
||||
|
||||
let user_info = match result {
|
||||
Ok(user_info) => user_info,
|
||||
Err(e) => Err(stream.throw_error(e, Some(ctx)).await)?,
|
||||
Err(e) => Err(client.throw_error(e, Some(ctx)).await)?,
|
||||
};
|
||||
|
||||
let user = user_info.get_user().to_owned();
|
||||
let user_info = match user_info
|
||||
.authenticate(
|
||||
ctx,
|
||||
&mut stream,
|
||||
&mut client,
|
||||
mode.allow_cleartext(),
|
||||
&config.authentication_config,
|
||||
endpoint_rate_limiter,
|
||||
@@ -339,7 +341,7 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
|
||||
let app = params.get("application_name");
|
||||
let params_span = tracing::info_span!("", ?user, ?db, ?app);
|
||||
|
||||
return Err(stream
|
||||
return Err(client
|
||||
.throw_error(e, Some(ctx))
|
||||
.instrument(params_span)
|
||||
.await)?;
|
||||
@@ -357,27 +359,40 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
|
||||
let res = connect_to_compute(
|
||||
ctx,
|
||||
&TcpMechanism {
|
||||
user_info: creds.info.clone(),
|
||||
auth: auth_info,
|
||||
locks: &config.connect_compute_locks,
|
||||
},
|
||||
&auth::Backend::ControlPlane(cplane, creds.info),
|
||||
&auth::Backend::ControlPlane(cplane, creds.info.clone()),
|
||||
config.wake_compute_retry_config,
|
||||
&config.connect_to_compute,
|
||||
)
|
||||
.await;
|
||||
|
||||
let node = match res {
|
||||
let mut node = match res {
|
||||
Ok(node) => node,
|
||||
Err(e) => Err(stream.throw_error(e, Some(ctx)).await)?,
|
||||
Err(e) => Err(client.throw_error(e, Some(ctx)).await)?,
|
||||
};
|
||||
|
||||
let cancellation_handler_clone = Arc::clone(&cancellation_handler);
|
||||
let session = cancellation_handler_clone.get_key();
|
||||
let session = cancellation_handler.get_key();
|
||||
|
||||
session.write_cancel_key(node.cancel_closure.clone())?;
|
||||
prepare_client_connection(&node, *session.key(), &mut stream);
|
||||
let stream = stream.flush_and_into_inner().await?;
|
||||
let cancel_closure =
|
||||
prepare_client_connection(&mut node, session.key(), &mut client, creds.info).await?;
|
||||
|
||||
let session_id = ctx.session_id();
|
||||
let (cancel_on_shutdown, cancel) = tokio::sync::oneshot::channel();
|
||||
tokio::spawn(async move {
|
||||
session
|
||||
.maintain_cancel_key(
|
||||
session_id,
|
||||
cancel,
|
||||
&cancel_closure,
|
||||
&config.connect_to_compute,
|
||||
)
|
||||
.await;
|
||||
});
|
||||
|
||||
let client = client.flush_and_into_inner().await?;
|
||||
let compute = node.stream.flush_and_into_inner().await?;
|
||||
|
||||
let private_link_id = match ctx.extra() {
|
||||
Some(ConnectionInfoExtra::Aws { vpce_id }) => Some(vpce_id.clone()),
|
||||
@@ -386,40 +401,75 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
|
||||
};
|
||||
|
||||
Ok(Some(ProxyPassthrough {
|
||||
client: stream,
|
||||
aux: node.aux.clone(),
|
||||
client,
|
||||
compute,
|
||||
|
||||
aux: node.aux,
|
||||
private_link_id,
|
||||
compute: node,
|
||||
session_id: ctx.session_id(),
|
||||
cancel: session,
|
||||
|
||||
_cancel_on_shutdown: cancel_on_shutdown,
|
||||
|
||||
_req: request_gauge,
|
||||
_conn: conn_gauge,
|
||||
_db_conn: node.guage,
|
||||
}))
|
||||
}
|
||||
|
||||
/// Finish client connection initialization: confirm auth success, send params, etc.
|
||||
pub(crate) fn prepare_client_connection(
|
||||
node: &compute::PostgresConnection,
|
||||
cancel_key_data: CancelKeyData,
|
||||
stream: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
|
||||
) {
|
||||
// Forward all deferred notices to the client.
|
||||
for notice in &node.delayed_notice {
|
||||
stream.write_raw(notice.as_bytes().len(), b'N', |buf| {
|
||||
buf.extend_from_slice(notice.as_bytes());
|
||||
});
|
||||
pub(crate) async fn prepare_client_connection(
|
||||
node: &mut compute::PostgresConnection,
|
||||
key_data: &CancelKeyData,
|
||||
stream: &mut PqFeStream<impl AsyncRead + AsyncWrite + Unpin>,
|
||||
user_info: ComputeUserInfo,
|
||||
) -> Result<CancelClosure, std::io::Error> {
|
||||
use zerocopy::{FromBytes, IntoBytes};
|
||||
|
||||
use crate::pqproto::{BE_KEY_MESSAGE, BE_READY_MESSAGE};
|
||||
|
||||
let mut process_id = 0;
|
||||
let mut secret_key = 0;
|
||||
|
||||
loop {
|
||||
match node.stream.read_raw_be(1024).await {
|
||||
// parse backend keys, and substitute our own.
|
||||
Ok((tag @ BE_KEY_MESSAGE, msg)) => {
|
||||
stream.write_raw(8, tag, |b| b.extend_from_slice(key_data.as_bytes()));
|
||||
|
||||
let key_data = CancelKeyData::read_from_bytes(msg)
|
||||
.map_err(|_| std::io::Error::other("invalid msg len"))?;
|
||||
|
||||
process_id = (key_data.0.get() >> 32) as i32;
|
||||
secret_key = (key_data.0.get() & 0xffff_ffff) as i32;
|
||||
}
|
||||
// ready for query, we're done :)
|
||||
Ok((tag @ BE_READY_MESSAGE, msg)) => {
|
||||
stream.write_raw(msg.len(), tag, |b| b.extend_from_slice(msg.as_bytes()));
|
||||
break;
|
||||
}
|
||||
// either a notice or a parameter status.
|
||||
Ok((tag, msg)) => {
|
||||
stream.write_raw(msg.len(), tag, |b| b.extend_from_slice(msg.as_bytes()));
|
||||
}
|
||||
Err(PostgresError::Io(io)) => return Err(io),
|
||||
Err(PostgresError::Error(e)) => return Err(std::io::Error::other(e)),
|
||||
Err(_) => unreachable!("read_raw_be only returns IO or BackendError types"),
|
||||
}
|
||||
|
||||
if stream.write_buf_len() > 512 {
|
||||
stream.flush().await?;
|
||||
}
|
||||
}
|
||||
|
||||
// Forward all postgres connection params to the client.
|
||||
for (name, value) in &node.params {
|
||||
stream.write_message(BeMessage::ParameterStatus {
|
||||
name: name.as_bytes(),
|
||||
value: value.as_bytes(),
|
||||
});
|
||||
}
|
||||
|
||||
stream.write_message(BeMessage::BackendKeyData(cancel_key_data));
|
||||
stream.write_message(BeMessage::ReadyForQuery);
|
||||
Ok(CancelClosure::new(
|
||||
node.socket_addr,
|
||||
RawCancelToken {
|
||||
ssl_mode: node.ssl_mode,
|
||||
process_id,
|
||||
secret_key,
|
||||
},
|
||||
node.hostname.clone(),
|
||||
user_info,
|
||||
))
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)]
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
use std::error::Error;
|
||||
use std::io;
|
||||
|
||||
use bstr::ByteSlice;
|
||||
use tokio::time;
|
||||
|
||||
use crate::compute;
|
||||
use crate::config::RetryConfig;
|
||||
use crate::stream::{BackendError, PostgresError};
|
||||
|
||||
pub(crate) trait CouldRetry {
|
||||
/// Returns true if the error could be retried
|
||||
@@ -96,10 +98,55 @@ impl ShouldRetryWakeCompute for postgres_client::Error {
|
||||
}
|
||||
}
|
||||
|
||||
impl CouldRetry for BackendError {
|
||||
fn could_retry(&self) -> bool {
|
||||
let (code, _message) = self.parse();
|
||||
matches!(
|
||||
code,
|
||||
crate::pqproto::CONNECTION_FAILURE
|
||||
| crate::pqproto::CONNECTION_EXCEPTION
|
||||
| crate::pqproto::CONNECTION_DOES_NOT_EXIST
|
||||
| crate::pqproto::SQLCLIENT_UNABLE_TO_ESTABLISH_SQLCONNECTION,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl ShouldRetryWakeCompute for BackendError {
|
||||
fn should_retry_wake_compute(&self) -> bool {
|
||||
let (code, message) = self.parse();
|
||||
|
||||
// Here are errors that happens after the user successfully authenticated to the database.
|
||||
let non_retriable_pg_errors = matches!(
|
||||
code,
|
||||
crate::pqproto::TOO_MANY_CONNECTIONS
|
||||
| crate::pqproto::OUT_OF_MEMORY
|
||||
| crate::pqproto::SYNTAX_ERROR
|
||||
| crate::pqproto::T_R_SERIALIZATION_FAILURE
|
||||
| crate::pqproto::INVALID_CATALOG_NAME
|
||||
| crate::pqproto::INVALID_SCHEMA_NAME
|
||||
| crate::pqproto::INVALID_PARAMETER_VALUE,
|
||||
);
|
||||
if non_retriable_pg_errors {
|
||||
return false;
|
||||
}
|
||||
|
||||
// PGBouncer errors that should not trigger a wake_compute retry.
|
||||
if code == crate::pqproto::PROTOCOL_VIOLATION {
|
||||
// Source for the error message:
|
||||
// https://github.com/pgbouncer/pgbouncer/blob/f15997fe3effe3a94ba8bcc1ea562e6117d1a131/src/client.c#L1070
|
||||
return message.contains_str("no more connections allowed (max_client_conn)");
|
||||
}
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
impl CouldRetry for compute::ConnectionError {
|
||||
fn could_retry(&self) -> bool {
|
||||
match self {
|
||||
compute::ConnectionError::Postgres(err) => err.could_retry(),
|
||||
compute::ConnectionError::Postgres(PostgresError::Error(err)) => err.could_retry(),
|
||||
compute::ConnectionError::Postgres(PostgresError::Io(err)) => err.could_retry(),
|
||||
compute::ConnectionError::Postgres(PostgresError::Unexpected(_)) => false,
|
||||
compute::ConnectionError::Postgres(PostgresError::InvalidAuthMessage) => false,
|
||||
compute::ConnectionError::TlsError(err) => err.could_retry(),
|
||||
compute::ConnectionError::WakeComputeError(err) => err.could_retry(),
|
||||
compute::ConnectionError::TooManyConnectionAttempts(_) => false,
|
||||
@@ -109,7 +156,12 @@ impl CouldRetry for compute::ConnectionError {
|
||||
impl ShouldRetryWakeCompute for compute::ConnectionError {
|
||||
fn should_retry_wake_compute(&self) -> bool {
|
||||
match self {
|
||||
compute::ConnectionError::Postgres(err) => err.should_retry_wake_compute(),
|
||||
compute::ConnectionError::Postgres(PostgresError::Error(err)) => {
|
||||
err.should_retry_wake_compute()
|
||||
}
|
||||
compute::ConnectionError::Postgres(PostgresError::Io(_)) => true,
|
||||
compute::ConnectionError::Postgres(PostgresError::Unexpected(_)) => false,
|
||||
compute::ConnectionError::Postgres(PostgresError::InvalidAuthMessage) => false,
|
||||
// the cache entry was not checked for validity
|
||||
compute::ConnectionError::TooManyConnectionAttempts(_) => false,
|
||||
_ => true,
|
||||
|
||||
@@ -25,6 +25,7 @@ use crate::control_plane::client::{ControlPlaneClient, TestControlPlaneClient};
|
||||
use crate::control_plane::messages::{ControlPlaneErrorMessage, Details, MetricsAuxInfo, Status};
|
||||
use crate::control_plane::{self, CachedNodeInfo, NodeInfo, NodeInfoCache};
|
||||
use crate::error::ErrorKind;
|
||||
use crate::pqproto::BeMessage;
|
||||
use crate::proxy::connect_compute::ConnectMechanism;
|
||||
use crate::tls::client_config::compute_client_config_with_certs;
|
||||
use crate::tls::server_config::CertResolver;
|
||||
@@ -122,7 +123,7 @@ fn generate_tls_config<'a>(
|
||||
trait TestAuth: Sized {
|
||||
async fn authenticate<S: AsyncRead + AsyncWrite + Unpin + Send>(
|
||||
self,
|
||||
stream: &mut PqStream<Stream<S>>,
|
||||
stream: &mut PqFeStream<Stream<S>>,
|
||||
) -> anyhow::Result<()> {
|
||||
stream.write_message(BeMessage::AuthenticationOk);
|
||||
Ok(())
|
||||
@@ -151,7 +152,7 @@ impl Scram {
|
||||
impl TestAuth for Scram {
|
||||
async fn authenticate<S: AsyncRead + AsyncWrite + Unpin + Send>(
|
||||
self,
|
||||
stream: &mut PqStream<Stream<S>>,
|
||||
stream: &mut PqFeStream<Stream<S>>,
|
||||
) -> anyhow::Result<()> {
|
||||
let outcome = auth::AuthFlow::new(stream, auth::Scram(&self.0, &RequestContext::test()))
|
||||
.authenticate()
|
||||
|
||||
@@ -1,8 +1,4 @@
|
||||
use std::io::ErrorKind;
|
||||
|
||||
use anyhow::Ok;
|
||||
|
||||
use crate::pqproto::{CancelKeyData, id_to_cancel_key};
|
||||
use crate::pqproto::CancelKeyData;
|
||||
|
||||
pub mod keyspace {
|
||||
pub const CANCEL_PREFIX: &str = "cancel";
|
||||
@@ -23,39 +19,12 @@ impl KeyPrefix {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub(crate) fn as_str(&self) -> &'static str {
|
||||
match self {
|
||||
KeyPrefix::Cancel(_) => keyspace::CANCEL_PREFIX,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub(crate) fn parse_redis_key(key: &str) -> anyhow::Result<KeyPrefix> {
|
||||
let (prefix, key_str) = key.split_once(':').ok_or_else(|| {
|
||||
anyhow::anyhow!(std::io::Error::new(
|
||||
ErrorKind::InvalidData,
|
||||
"missing prefix"
|
||||
))
|
||||
})?;
|
||||
|
||||
match prefix {
|
||||
keyspace::CANCEL_PREFIX => {
|
||||
let id = u64::from_str_radix(key_str, 16)?;
|
||||
|
||||
Ok(KeyPrefix::Cancel(id_to_cancel_key(id)))
|
||||
}
|
||||
_ => Err(anyhow::anyhow!(std::io::Error::new(
|
||||
ErrorKind::InvalidData,
|
||||
"unknown prefix"
|
||||
))),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::pqproto::id_to_cancel_key;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
@@ -65,16 +34,4 @@ mod tests {
|
||||
let redis_key = cancel_key.build_redis_key();
|
||||
assert_eq!(redis_key, "cancel:30390000d431");
|
||||
}
|
||||
|
||||
/// Parsing a well-formed cancel key yields the same prefix and key data
/// as constructing it directly from the numeric id.
#[test]
fn test_parse_redis_key() {
    let expected = id_to_cancel_key(12345 << 32 | 54321);

    let parsed: KeyPrefix =
        parse_redis_key("cancel:30390000d431").expect("Failed to parse key");

    assert_eq!(parsed.as_str(), KeyPrefix::Cancel(expected).as_str());

    let KeyPrefix::Cancel(actual) = parsed;
    assert_eq!(expected, actual);
}
|
||||
}
|
||||
|
||||
@@ -1,3 +1,6 @@
|
||||
use std::time::Duration;
|
||||
|
||||
use futures::FutureExt;
|
||||
use redis::aio::ConnectionLike;
|
||||
use redis::{Cmd, FromRedisValue, Pipeline, RedisResult};
|
||||
|
||||
@@ -35,14 +38,11 @@ impl RedisKVClient {
|
||||
}
|
||||
|
||||
pub async fn try_connect(&mut self) -> anyhow::Result<()> {
|
||||
match self.client.connect().await {
|
||||
Ok(()) => {}
|
||||
Err(e) => {
|
||||
tracing::error!("failed to connect to redis: {e}");
|
||||
return Err(e);
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
self.client
|
||||
.connect()
|
||||
.boxed()
|
||||
.await
|
||||
.inspect_err(|e| tracing::error!("failed to connect to redis: {e}"))
|
||||
}
|
||||
|
||||
pub(crate) async fn query<T: FromRedisValue>(
|
||||
@@ -54,15 +54,25 @@ impl RedisKVClient {
|
||||
return Err(anyhow::anyhow!("Rate limit exceeded"));
|
||||
}
|
||||
|
||||
match q.query(&mut self.client).await {
|
||||
let e = match q.query(&mut self.client).await {
|
||||
Ok(t) => return Ok(t),
|
||||
Err(e) => {
|
||||
tracing::error!("failed to run query: {e}");
|
||||
Err(e) => e,
|
||||
};
|
||||
|
||||
tracing::error!("failed to run query: {e}");
|
||||
match e.retry_method() {
|
||||
redis::RetryMethod::Reconnect => {
|
||||
tracing::info!("Redis client is disconnected. Reconnecting...");
|
||||
self.try_connect().await?;
|
||||
}
|
||||
redis::RetryMethod::RetryImmediately => {}
|
||||
redis::RetryMethod::WaitAndRetry => {
|
||||
// somewhat arbitrary.
|
||||
tokio::time::sleep(Duration::from_millis(100)).await;
|
||||
}
|
||||
_ => Err(e)?,
|
||||
}
|
||||
|
||||
tracing::info!("Redis client is disconnected. Reconnecting...");
|
||||
self.try_connect().await?;
|
||||
Ok(q.query(&mut self.client).await?)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,7 +7,7 @@ use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use super::{Mechanism, Step};
|
||||
use crate::context::RequestContext;
|
||||
use crate::pqproto::{BeAuthenticationSaslMessage, BeMessage};
|
||||
use crate::stream::PqStream;
|
||||
use crate::stream::PqFeStream;
|
||||
|
||||
/// SASL authentication outcome.
|
||||
/// It's much easier to match on those two variants
|
||||
@@ -22,7 +22,7 @@ pub(crate) enum Outcome<R> {
|
||||
|
||||
pub async fn authenticate<S, F, M>(
|
||||
ctx: &RequestContext,
|
||||
stream: &mut PqStream<S>,
|
||||
stream: &mut PqFeStream<S>,
|
||||
mechanism: F,
|
||||
) -> super::Result<Outcome<M::Output>>
|
||||
where
|
||||
|
||||
@@ -167,7 +167,7 @@ pub(crate) async fn serve_websocket(
|
||||
Ok(Some(p)) => {
|
||||
ctx.set_success();
|
||||
ctx.log_connect();
|
||||
match p.proxy_pass(&config.connect_to_compute).await {
|
||||
match p.proxy_pass().await {
|
||||
Ok(()) => Ok(()),
|
||||
Err(ErrorSource::Client(err)) => Err(err).context("client"),
|
||||
Err(ErrorSource::Compute(err)) => Err(err).context("compute"),
|
||||
|
||||
@@ -1,351 +0,0 @@
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::{io, task};
|
||||
|
||||
use rustls::ServerConfig;
|
||||
use thiserror::Error;
|
||||
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf};
|
||||
use tokio_rustls::server::TlsStream;
|
||||
|
||||
use crate::error::{ErrorKind, ReportableError, UserFacingError};
|
||||
use crate::metrics::Metrics;
|
||||
use crate::pqproto::{
|
||||
BeMessage, FE_PASSWORD_MESSAGE, FeStartupPacket, SQLSTATE_INTERNAL_ERROR, WriteBuf,
|
||||
read_message, read_startup,
|
||||
};
|
||||
use crate::tls::TlsServerEndPoint;
|
||||
|
||||
/// Stream wrapper which implements libpq's protocol.
|
||||
///
|
||||
/// NOTE: This object deliberately doesn't implement [`AsyncRead`]
|
||||
/// or [`AsyncWrite`] to prevent subtle errors (e.g. trying
|
||||
/// to pass random malformed bytes through the connection).
|
||||
pub struct PqStream<S> {
|
||||
stream: S,
|
||||
read: Vec<u8>,
|
||||
write: WriteBuf,
|
||||
}
|
||||
|
||||
impl<S> PqStream<S> {
|
||||
pub fn get_ref(&self) -> &S {
|
||||
&self.stream
|
||||
}
|
||||
|
||||
/// Construct a new libpq protocol wrapper over a stream without the first startup message.
|
||||
#[cfg(test)]
|
||||
pub fn new_skip_handshake(stream: S) -> Self {
|
||||
Self {
|
||||
stream,
|
||||
read: Vec::new(),
|
||||
write: WriteBuf::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: AsyncRead + AsyncWrite + Unpin> PqStream<S> {
|
||||
/// Construct a new libpq protocol wrapper and read the first startup message.
|
||||
///
|
||||
/// This is not cancel safe.
|
||||
pub async fn parse_startup(mut stream: S) -> io::Result<(Self, FeStartupPacket)> {
|
||||
let startup = read_startup(&mut stream).await?;
|
||||
Ok((
|
||||
Self {
|
||||
stream,
|
||||
read: Vec::new(),
|
||||
write: WriteBuf::new(),
|
||||
},
|
||||
startup,
|
||||
))
|
||||
}
|
||||
|
||||
/// Tell the client that encryption is not supported.
|
||||
///
|
||||
/// This is not cancel safe
|
||||
pub async fn reject_encryption(&mut self) -> io::Result<FeStartupPacket> {
|
||||
// N for No.
|
||||
self.write.encryption(b'N');
|
||||
self.flush().await?;
|
||||
read_startup(&mut self.stream).await
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: AsyncRead + Unpin> PqStream<S> {
|
||||
/// Read a raw postgres packet, which will respect the max length requested.
|
||||
/// This is not cancel safe.
|
||||
async fn read_raw_expect(&mut self, tag: u8, max: u32) -> io::Result<&mut [u8]> {
|
||||
let (actual_tag, msg) = read_message(&mut self.stream, &mut self.read, max).await?;
|
||||
if actual_tag != tag {
|
||||
return Err(io::Error::other(format!(
|
||||
"incorrect message tag, expected {:?}, got {:?}",
|
||||
tag as char, actual_tag as char,
|
||||
)));
|
||||
}
|
||||
Ok(msg)
|
||||
}
|
||||
|
||||
/// Read a postgres password message, which will respect the max length requested.
|
||||
/// This is not cancel safe.
|
||||
pub async fn read_password_message(&mut self) -> io::Result<&mut [u8]> {
|
||||
// passwords are usually pretty short
|
||||
// and SASL SCRAM messages are no longer than 256 bytes in my testing
|
||||
// (a few hashes and random bytes, encoded into base64).
|
||||
const MAX_PASSWORD_LENGTH: u32 = 512;
|
||||
self.read_raw_expect(FE_PASSWORD_MESSAGE, MAX_PASSWORD_LENGTH)
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct ReportedError {
|
||||
source: anyhow::Error,
|
||||
error_kind: ErrorKind,
|
||||
}
|
||||
|
||||
impl ReportedError {
|
||||
pub fn new(e: (impl UserFacingError + Into<anyhow::Error>)) -> Self {
|
||||
let error_kind = e.get_error_kind();
|
||||
Self {
|
||||
source: e.into(),
|
||||
error_kind,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Displaying a `ReportedError` is forwarded to the wrapped `source` error.
impl std::fmt::Display for ReportedError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
// NOTE(review): `fmt` is found by method lookup on `self.source` rather than
// through an explicit trait path — confirm which impl is selected before changing.
self.source.fmt(f)
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for ReportedError {
|
||||
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
|
||||
self.source.source()
|
||||
}
|
||||
}
|
||||
|
||||
impl ReportableError for ReportedError {
|
||||
fn get_error_kind(&self) -> ErrorKind {
|
||||
self.error_kind
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: AsyncWrite + Unpin> PqStream<S> {
|
||||
/// Tell the client that we are willing to accept SSL.
|
||||
/// This is not cancel safe
|
||||
pub async fn accept_tls(mut self) -> io::Result<S> {
|
||||
// S for SSL.
|
||||
self.write.encryption(b'S');
|
||||
self.flush().await?;
|
||||
Ok(self.stream)
|
||||
}
|
||||
|
||||
/// Assert that we are using direct TLS.
|
||||
pub fn accept_direct_tls(self) -> S {
|
||||
self.stream
|
||||
}
|
||||
|
||||
/// Write a raw message to the internal buffer.
|
||||
pub fn write_raw(&mut self, size_hint: usize, tag: u8, f: impl FnOnce(&mut Vec<u8>)) {
|
||||
self.write.write_raw(size_hint, tag, f);
|
||||
}
|
||||
|
||||
/// Write the message into an internal buffer
|
||||
pub fn write_message(&mut self, message: BeMessage<'_>) {
|
||||
message.write_message(&mut self.write);
|
||||
}
|
||||
|
||||
/// Flush the output buffer into the underlying stream.
|
||||
///
|
||||
/// This is cancel safe.
|
||||
pub async fn flush(&mut self) -> io::Result<()> {
|
||||
self.stream.write_all_buf(&mut self.write).await?;
|
||||
self.write.reset();
|
||||
|
||||
self.stream.flush().await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Flush the output buffer into the underlying stream.
|
||||
///
|
||||
/// This is cancel safe.
|
||||
pub async fn flush_and_into_inner(mut self) -> io::Result<S> {
|
||||
self.flush().await?;
|
||||
Ok(self.stream)
|
||||
}
|
||||
|
||||
/// Write the error message to the client, then re-throw it.
|
||||
///
|
||||
/// Trait [`UserFacingError`] acts as an allowlist for error types.
|
||||
/// If `ctx` is provided and has testodrome_id set, error messages will be prefixed according to error kind.
|
||||
pub(crate) async fn throw_error<E>(
|
||||
&mut self,
|
||||
error: E,
|
||||
ctx: Option<&crate::context::RequestContext>,
|
||||
) -> ReportedError
|
||||
where
|
||||
E: UserFacingError + Into<anyhow::Error>,
|
||||
{
|
||||
let error_kind = error.get_error_kind();
|
||||
let msg = error.to_string_client();
|
||||
|
||||
if error_kind != ErrorKind::RateLimit && error_kind != ErrorKind::User {
|
||||
tracing::info!(
|
||||
kind = error_kind.to_metric_label(),
|
||||
msg,
|
||||
"forwarding error to user"
|
||||
);
|
||||
}
|
||||
|
||||
let probe_msg;
|
||||
let mut msg = &*msg;
|
||||
if let Some(ctx) = ctx {
|
||||
if ctx.get_testodrome_id().is_some() {
|
||||
let tag = match error_kind {
|
||||
ErrorKind::User => "client",
|
||||
ErrorKind::ClientDisconnect => "client",
|
||||
ErrorKind::RateLimit => "proxy",
|
||||
ErrorKind::ServiceRateLimit => "proxy",
|
||||
ErrorKind::Quota => "proxy",
|
||||
ErrorKind::Service => "proxy",
|
||||
ErrorKind::ControlPlane => "controlplane",
|
||||
ErrorKind::Postgres => "other",
|
||||
ErrorKind::Compute => "compute",
|
||||
};
|
||||
probe_msg = typed_json::json!({
|
||||
"tag": tag,
|
||||
"msg": msg,
|
||||
"cold_start_info": ctx.cold_start_info(),
|
||||
})
|
||||
.to_string();
|
||||
msg = &probe_msg;
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: either preserve the error code from postgres, or assign error codes to proxy errors.
|
||||
self.write.write_error(msg, SQLSTATE_INTERNAL_ERROR);
|
||||
|
||||
self.flush()
|
||||
.await
|
||||
.unwrap_or_else(|e| tracing::debug!("write_message failed: {e}"));
|
||||
|
||||
ReportedError::new(error)
|
||||
}
|
||||
}
|
||||
|
||||
/// Wrapper for upgrading raw streams into secure streams.
|
||||
pub enum Stream<S> {
|
||||
/// We always begin with a raw stream,
|
||||
/// which may then be upgraded into a secure stream.
|
||||
Raw { raw: S },
|
||||
Tls {
|
||||
/// We box [`TlsStream`] since it can be quite large.
|
||||
tls: Box<TlsStream<S>>,
|
||||
/// Channel binding parameter
|
||||
tls_server_end_point: TlsServerEndPoint,
|
||||
},
|
||||
}
|
||||
|
||||
// `Stream<S>` is declared `Unpin` whenever the wrapped stream type is.
impl<S: Unpin> Unpin for Stream<S> {}
|
||||
|
||||
impl<S> Stream<S> {
|
||||
/// Construct a new instance from a raw stream.
|
||||
pub fn from_raw(raw: S) -> Self {
|
||||
Self::Raw { raw }
|
||||
}
|
||||
|
||||
/// Return SNI hostname when it's available.
|
||||
pub fn sni_hostname(&self) -> Option<&str> {
|
||||
match self {
|
||||
Stream::Raw { .. } => None,
|
||||
Stream::Tls { tls, .. } => tls.get_ref().1.server_name(),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn tls_server_end_point(&self) -> TlsServerEndPoint {
|
||||
match self {
|
||||
Stream::Raw { .. } => TlsServerEndPoint::Undefined,
|
||||
Stream::Tls {
|
||||
tls_server_end_point,
|
||||
..
|
||||
} => *tls_server_end_point,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
#[error("Can't upgrade TLS stream")]
|
||||
pub enum StreamUpgradeError {
|
||||
#[error("Bad state reached: can't upgrade TLS stream")]
|
||||
AlreadyTls,
|
||||
|
||||
#[error("Can't upgrade stream: IO error: {0}")]
|
||||
Io(#[from] io::Error),
|
||||
}
|
||||
|
||||
impl<S: AsyncRead + AsyncWrite + Unpin> Stream<S> {
|
||||
/// If possible, upgrade raw stream into a secure TLS-based stream.
|
||||
pub async fn upgrade(
|
||||
self,
|
||||
cfg: Arc<ServerConfig>,
|
||||
record_handshake_error: bool,
|
||||
) -> Result<TlsStream<S>, StreamUpgradeError> {
|
||||
match self {
|
||||
Stream::Raw { raw } => Ok(tokio_rustls::TlsAcceptor::from(cfg)
|
||||
.accept(raw)
|
||||
.await
|
||||
.inspect_err(|_| {
|
||||
if record_handshake_error {
|
||||
Metrics::get().proxy.tls_handshake_failures.inc();
|
||||
}
|
||||
})?),
|
||||
Stream::Tls { .. } => Err(StreamUpgradeError::AlreadyTls),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: AsyncRead + AsyncWrite + Unpin> AsyncRead for Stream<S> {
|
||||
fn poll_read(
|
||||
mut self: Pin<&mut Self>,
|
||||
context: &mut task::Context<'_>,
|
||||
buf: &mut ReadBuf<'_>,
|
||||
) -> task::Poll<io::Result<()>> {
|
||||
match &mut *self {
|
||||
Self::Raw { raw } => Pin::new(raw).poll_read(context, buf),
|
||||
Self::Tls { tls, .. } => Pin::new(tls).poll_read(context, buf),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite for Stream<S> {
|
||||
fn poll_write(
|
||||
mut self: Pin<&mut Self>,
|
||||
context: &mut task::Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> task::Poll<io::Result<usize>> {
|
||||
match &mut *self {
|
||||
Self::Raw { raw } => Pin::new(raw).poll_write(context, buf),
|
||||
Self::Tls { tls, .. } => Pin::new(tls).poll_write(context, buf),
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_flush(
|
||||
mut self: Pin<&mut Self>,
|
||||
context: &mut task::Context<'_>,
|
||||
) -> task::Poll<io::Result<()>> {
|
||||
match &mut *self {
|
||||
Self::Raw { raw } => Pin::new(raw).poll_flush(context),
|
||||
Self::Tls { tls, .. } => Pin::new(tls).poll_flush(context),
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_shutdown(
|
||||
mut self: Pin<&mut Self>,
|
||||
context: &mut task::Context<'_>,
|
||||
) -> task::Poll<io::Result<()>> {
|
||||
match &mut *self {
|
||||
Self::Raw { raw } => Pin::new(raw).poll_shutdown(context),
|
||||
Self::Tls { tls, .. } => Pin::new(tls).poll_shutdown(context),
|
||||
}
|
||||
}
|
||||
}
|
||||
168
proxy/src/stream/mod.rs
Normal file
168
proxy/src/stream/mod.rs
Normal file
@@ -0,0 +1,168 @@
|
||||
mod pq_backend;
|
||||
mod pq_frontend;
|
||||
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::{io, task};
|
||||
|
||||
pub use pq_backend::{BackendError, PostgresError, PqBeStream};
|
||||
pub use pq_frontend::PqFeStream;
|
||||
use rustls::ServerConfig;
|
||||
use thiserror::Error;
|
||||
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
|
||||
use tokio_rustls::server::TlsStream;
|
||||
|
||||
use crate::error::{ErrorKind, ReportableError, UserFacingError};
|
||||
use crate::metrics::Metrics;
|
||||
use crate::tls::TlsServerEndPoint;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct ReportedError {
|
||||
source: anyhow::Error,
|
||||
error_kind: ErrorKind,
|
||||
}
|
||||
|
||||
impl ReportedError {
|
||||
pub fn new(e: (impl UserFacingError + Into<anyhow::Error>)) -> Self {
|
||||
let error_kind = e.get_error_kind();
|
||||
Self {
|
||||
source: e.into(),
|
||||
error_kind,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for ReportedError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
self.source.fmt(f)
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for ReportedError {
|
||||
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
|
||||
self.source.source()
|
||||
}
|
||||
}
|
||||
|
||||
impl ReportableError for ReportedError {
|
||||
fn get_error_kind(&self) -> ErrorKind {
|
||||
self.error_kind
|
||||
}
|
||||
}
|
||||
|
||||
/// Wrapper for upgrading raw streams into secure streams.
|
||||
pub enum Stream<S> {
|
||||
/// We always begin with a raw stream,
|
||||
/// which may then be upgraded into a secure stream.
|
||||
Raw { raw: S },
|
||||
Tls {
|
||||
/// We box [`TlsStream`] since it can be quite large.
|
||||
tls: Box<TlsStream<S>>,
|
||||
/// Channel binding parameter
|
||||
tls_server_end_point: TlsServerEndPoint,
|
||||
},
|
||||
}
|
||||
|
||||
// `Stream<S>` is declared `Unpin` whenever the wrapped stream type is.
impl<S: Unpin> Unpin for Stream<S> {}
|
||||
|
||||
impl<S> Stream<S> {
|
||||
/// Construct a new instance from a raw stream.
|
||||
pub fn from_raw(raw: S) -> Self {
|
||||
Self::Raw { raw }
|
||||
}
|
||||
|
||||
/// Return SNI hostname when it's available.
|
||||
pub fn sni_hostname(&self) -> Option<&str> {
|
||||
match self {
|
||||
Stream::Raw { .. } => None,
|
||||
Stream::Tls { tls, .. } => tls.get_ref().1.server_name(),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn tls_server_end_point(&self) -> TlsServerEndPoint {
|
||||
match self {
|
||||
Stream::Raw { .. } => TlsServerEndPoint::Undefined,
|
||||
Stream::Tls {
|
||||
tls_server_end_point,
|
||||
..
|
||||
} => *tls_server_end_point,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
#[error("Can't upgrade TLS stream")]
|
||||
pub enum StreamUpgradeError {
|
||||
#[error("Bad state reached: can't upgrade TLS stream")]
|
||||
AlreadyTls,
|
||||
|
||||
#[error("Can't upgrade stream: IO error: {0}")]
|
||||
Io(#[from] io::Error),
|
||||
}
|
||||
|
||||
impl<S: AsyncRead + AsyncWrite + Unpin> Stream<S> {
|
||||
/// If possible, upgrade raw stream into a secure TLS-based stream.
|
||||
pub async fn upgrade(
|
||||
self,
|
||||
cfg: Arc<ServerConfig>,
|
||||
record_handshake_error: bool,
|
||||
) -> Result<TlsStream<S>, StreamUpgradeError> {
|
||||
match self {
|
||||
Stream::Raw { raw } => Ok(tokio_rustls::TlsAcceptor::from(cfg)
|
||||
.accept(raw)
|
||||
.await
|
||||
.inspect_err(|_| {
|
||||
if record_handshake_error {
|
||||
Metrics::get().proxy.tls_handshake_failures.inc();
|
||||
}
|
||||
})?),
|
||||
Stream::Tls { .. } => Err(StreamUpgradeError::AlreadyTls),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: AsyncRead + AsyncWrite + Unpin> AsyncRead for Stream<S> {
|
||||
fn poll_read(
|
||||
mut self: Pin<&mut Self>,
|
||||
context: &mut task::Context<'_>,
|
||||
buf: &mut ReadBuf<'_>,
|
||||
) -> task::Poll<io::Result<()>> {
|
||||
match &mut *self {
|
||||
Self::Raw { raw } => Pin::new(raw).poll_read(context, buf),
|
||||
Self::Tls { tls, .. } => Pin::new(tls).poll_read(context, buf),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite for Stream<S> {
|
||||
fn poll_write(
|
||||
mut self: Pin<&mut Self>,
|
||||
context: &mut task::Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> task::Poll<io::Result<usize>> {
|
||||
match &mut *self {
|
||||
Self::Raw { raw } => Pin::new(raw).poll_write(context, buf),
|
||||
Self::Tls { tls, .. } => Pin::new(tls).poll_write(context, buf),
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_flush(
|
||||
mut self: Pin<&mut Self>,
|
||||
context: &mut task::Context<'_>,
|
||||
) -> task::Poll<io::Result<()>> {
|
||||
match &mut *self {
|
||||
Self::Raw { raw } => Pin::new(raw).poll_flush(context),
|
||||
Self::Tls { tls, .. } => Pin::new(tls).poll_flush(context),
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_shutdown(
|
||||
mut self: Pin<&mut Self>,
|
||||
context: &mut task::Context<'_>,
|
||||
) -> task::Poll<io::Result<()>> {
|
||||
match &mut *self {
|
||||
Self::Raw { raw } => Pin::new(raw).poll_shutdown(context),
|
||||
Self::Tls { tls, .. } => Pin::new(tls).poll_shutdown(context),
|
||||
}
|
||||
}
|
||||
}
|
||||
165
proxy/src/stream/pq_backend.rs
Normal file
165
proxy/src/stream/pq_backend.rs
Normal file
@@ -0,0 +1,165 @@
|
||||
//! Postgres connection from backend, proxy is the frontend.
|
||||
|
||||
use std::io;
|
||||
|
||||
use bytes::Bytes;
|
||||
use thiserror::Error;
|
||||
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
|
||||
|
||||
use crate::pqproto::{
|
||||
AuthTag, BE_AUTH_MESSAGE, BE_ERR_MESSAGE, BeTag, ErrorCode, SQLSTATE_INTERNAL_ERROR,
|
||||
StartupMessageParams, WriteBuf, read_message,
|
||||
};
|
||||
|
||||
/// Stream wrapper which implements libpq's protocol.
|
||||
pub struct PqBeStream<S> {
|
||||
stream: S,
|
||||
read: Vec<u8>,
|
||||
write: WriteBuf,
|
||||
}
|
||||
|
||||
impl<S> PqBeStream<S> {
|
||||
pub fn get_ref(&self) -> &S {
|
||||
&self.stream
|
||||
}
|
||||
|
||||
/// Construct a new libpq protocol wrapper and write the first startup message.
|
||||
pub fn new(stream: S, params: &StartupMessageParams) -> Self {
|
||||
let mut write = WriteBuf::new();
|
||||
write.startup(params);
|
||||
Self {
|
||||
stream,
|
||||
read: Vec::new(),
|
||||
write,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: AsyncRead + Unpin> PqBeStream<S> {
|
||||
/// Read a raw postgres packet from the backend, which will respect the max length requested,
|
||||
/// as well as handling postgres error messages.
|
||||
///
|
||||
/// This is not cancel safe.
|
||||
pub async fn read_raw_be(&mut self, max: u32) -> Result<(BeTag, &mut [u8]), PostgresError> {
|
||||
let (tag, msg) = read_message(&mut self.stream, &mut self.read, max).await?;
|
||||
match BeTag(tag) {
|
||||
BE_ERR_MESSAGE => Err(PostgresError::Error(BackendError {
|
||||
data: msg.to_vec().into(),
|
||||
})),
|
||||
tag => Ok((tag, msg)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Read a raw postgres packet, which will respect the max length requested.
|
||||
/// This is not cancel safe.
|
||||
async fn read_raw_be_expect(
|
||||
&mut self,
|
||||
tag: BeTag,
|
||||
max: u32,
|
||||
) -> Result<&mut [u8], PostgresError> {
|
||||
let (actual_tag, msg) = self.read_raw_be(max).await?;
|
||||
if actual_tag != tag {
|
||||
return Err(PostgresError::Unexpected(UnexpectedMessage {
|
||||
expected: tag,
|
||||
tag: actual_tag,
|
||||
data: msg.to_vec().into(),
|
||||
}));
|
||||
}
|
||||
Ok(msg)
|
||||
}
|
||||
|
||||
/// Read a postgres backend auth message.
|
||||
/// This is not cancel safe.
|
||||
pub async fn read_auth_message(&mut self) -> Result<(AuthTag, &mut [u8]), PostgresError> {
|
||||
const MAX_AUTH_LENGTH: u32 = 512;
|
||||
|
||||
self.read_raw_be_expect(BE_AUTH_MESSAGE, MAX_AUTH_LENGTH)
|
||||
.await?
|
||||
.split_first_chunk_mut()
|
||||
.map(|(tag, msg)| (AuthTag(i32::from_be_bytes(*tag)), msg))
|
||||
.ok_or(PostgresError::InvalidAuthMessage)
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: AsyncWrite + Unpin> PqBeStream<S> {
|
||||
/// Write a raw message to the internal buffer.
|
||||
pub fn write_raw(&mut self, size_hint: usize, tag: u8, f: impl FnOnce(&mut Vec<u8>)) {
|
||||
self.write.write_raw(size_hint, tag, f);
|
||||
}
|
||||
|
||||
/// Flush the output buffer into the underlying stream.
|
||||
///
|
||||
/// This is cancel safe.
|
||||
pub async fn flush(&mut self) -> io::Result<()> {
|
||||
self.stream.write_all_buf(&mut self.write).await?;
|
||||
self.write.reset();
|
||||
|
||||
self.stream.flush().await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Flush the output buffer into the underlying stream.
|
||||
///
|
||||
/// This is cancel safe.
|
||||
pub async fn flush_and_into_inner(mut self) -> io::Result<S> {
|
||||
self.flush().await?;
|
||||
Ok(self.stream)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum PostgresError {
|
||||
#[error("postgres responded with error {0}")]
|
||||
Error(#[from] BackendError),
|
||||
#[error("postgres responded with an unexpected message: {0}")]
|
||||
Unexpected(#[from] UnexpectedMessage),
|
||||
#[error("postgres responded with an invalid authentication message")]
|
||||
InvalidAuthMessage,
|
||||
#[error("IO error from compute: {0}")]
|
||||
Io(#[from] io::Error),
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
#[error("expected {expected}, got {tag} with data {data:?}")]
|
||||
pub struct UnexpectedMessage {
|
||||
expected: BeTag,
|
||||
tag: BeTag,
|
||||
data: Bytes,
|
||||
}
|
||||
|
||||
pub struct BackendError {
|
||||
data: Bytes,
|
||||
}
|
||||
|
||||
impl BackendError {
|
||||
pub fn parse(&self) -> (ErrorCode, &[u8]) {
|
||||
let mut code = &[] as &[u8];
|
||||
let mut message = &[] as &[u8];
|
||||
|
||||
for param in self.data.split(|b| *b == 0) {
|
||||
match param {
|
||||
[b'M', rest @ ..] => message = rest,
|
||||
[b'C', rest @ ..] => code = rest,
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
let code = code.try_into().map_or(SQLSTATE_INTERNAL_ERROR, ErrorCode);
|
||||
|
||||
(code, message)
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for BackendError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "{self}")
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for BackendError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "{:?}", &self.data)
|
||||
}
|
||||
}
|
||||
// Marker impl with the default methods only, so no underlying source is reported.
impl std::error::Error for BackendError {}
|
||||
197
proxy/src/stream/pq_frontend.rs
Normal file
197
proxy/src/stream/pq_frontend.rs
Normal file
@@ -0,0 +1,197 @@
|
||||
//! Postgres connection from frontend, proxy is the backend.
|
||||
|
||||
use std::io;
|
||||
|
||||
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
|
||||
|
||||
use crate::error::{ErrorKind, UserFacingError};
|
||||
use crate::pqproto::{
|
||||
BeMessage, BeTag, FE_PASSWORD_MESSAGE, FeStartupPacket, FeTag, SQLSTATE_INTERNAL_ERROR,
|
||||
WriteBuf, read_message, read_startup,
|
||||
};
|
||||
use crate::stream::ReportedError;
|
||||
|
||||
/// Stream wrapper which implements libpq's protocol.
|
||||
pub struct PqFeStream<S> {
|
||||
stream: S,
|
||||
read: Vec<u8>,
|
||||
write: WriteBuf,
|
||||
}
|
||||
|
||||
impl<S> PqFeStream<S> {
|
||||
pub fn get_ref(&self) -> &S {
|
||||
&self.stream
|
||||
}
|
||||
|
||||
/// Construct a new libpq protocol wrapper over a stream without the first startup message.
|
||||
#[cfg(test)]
|
||||
pub fn new_skip_handshake(stream: S) -> Self {
|
||||
Self {
|
||||
stream,
|
||||
read: Vec::new(),
|
||||
write: WriteBuf::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn write_buf_len(&self) -> usize {
|
||||
self.write.len()
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: AsyncRead + AsyncWrite + Unpin> PqFeStream<S> {
|
||||
/// Construct a new libpq protocol wrapper and read the first startup message.
|
||||
///
|
||||
/// This is not cancel safe.
|
||||
pub async fn parse_startup(mut stream: S) -> io::Result<(Self, FeStartupPacket)> {
|
||||
let startup = read_startup(&mut stream).await?;
|
||||
Ok((
|
||||
Self {
|
||||
stream,
|
||||
read: Vec::new(),
|
||||
write: WriteBuf::new(),
|
||||
},
|
||||
startup,
|
||||
))
|
||||
}
|
||||
|
||||
/// Tell the client that encryption is not supported.
|
||||
///
|
||||
/// This is not cancel safe
|
||||
pub async fn reject_encryption(&mut self) -> io::Result<FeStartupPacket> {
|
||||
// N for No.
|
||||
self.write.encryption(b'N');
|
||||
self.flush().await?;
|
||||
read_startup(&mut self.stream).await
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: AsyncRead + Unpin> PqFeStream<S> {
|
||||
/// Read a raw postgres packet, which will respect the max length requested.
|
||||
/// This is not cancel safe.
|
||||
async fn read_raw_expect(&mut self, tag: FeTag, max: u32) -> io::Result<&mut [u8]> {
|
||||
let (actual_tag, msg) = read_message(&mut self.stream, &mut self.read, max).await?;
|
||||
let actual_tag = FeTag(actual_tag);
|
||||
if actual_tag != tag {
|
||||
return Err(io::Error::other(format!(
|
||||
"incorrect message tag, expected {tag}, got {actual_tag}",
|
||||
)));
|
||||
}
|
||||
Ok(msg)
|
||||
}
|
||||
|
||||
/// Read a postgres password message, which will respect the max length requested.
|
||||
/// This is not cancel safe.
|
||||
pub async fn read_password_message(&mut self) -> io::Result<&mut [u8]> {
|
||||
// passwords are usually pretty short
|
||||
// and SASL SCRAM messages are no longer than 256 bytes in my testing
|
||||
// (a few hashes and random bytes, encoded into base64).
|
||||
const MAX_PASSWORD_LENGTH: u32 = 512;
|
||||
self.read_raw_expect(FE_PASSWORD_MESSAGE, MAX_PASSWORD_LENGTH)
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: AsyncWrite + Unpin> PqFeStream<S> {
|
||||
/// Tell the client that we are willing to accept SSL.
|
||||
/// This is not cancel safe
|
||||
pub async fn accept_tls(mut self) -> io::Result<S> {
|
||||
// S for SSL.
|
||||
self.write.encryption(b'S');
|
||||
self.flush().await?;
|
||||
Ok(self.stream)
|
||||
}
|
||||
|
||||
/// Assert that we are using direct TLS.
|
||||
pub fn accept_direct_tls(self) -> S {
|
||||
self.stream
|
||||
}
|
||||
|
||||
/// Write a raw message to the internal buffer.
|
||||
pub fn write_raw(&mut self, size_hint: usize, tag: BeTag, f: impl FnOnce(&mut Vec<u8>)) {
|
||||
self.write.write_raw(size_hint, tag.0, f);
|
||||
}
|
||||
|
||||
/// Write the message into an internal buffer
|
||||
pub fn write_message(&mut self, message: BeMessage<'_>) {
|
||||
message.write_message(&mut self.write);
|
||||
}
|
||||
|
||||
/// Flush the output buffer into the underlying stream.
|
||||
///
|
||||
/// This is cancel safe.
|
||||
pub async fn flush(&mut self) -> io::Result<()> {
|
||||
self.stream.write_all_buf(&mut self.write).await?;
|
||||
self.write.reset();
|
||||
|
||||
self.stream.flush().await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Flush the output buffer into the underlying stream.
|
||||
///
|
||||
/// This is cancel safe.
|
||||
pub async fn flush_and_into_inner(mut self) -> io::Result<S> {
|
||||
self.flush().await?;
|
||||
Ok(self.stream)
|
||||
}
|
||||
|
||||
/// Write the error message to the client, then re-throw it.
|
||||
///
|
||||
/// Trait [`UserFacingError`] acts as an allowlist for error types.
|
||||
/// If `ctx` is provided and has testodrome_id set, error messages will be prefixed according to error kind.
|
||||
pub(crate) async fn throw_error<E>(
|
||||
&mut self,
|
||||
error: E,
|
||||
ctx: Option<&crate::context::RequestContext>,
|
||||
) -> ReportedError
|
||||
where
|
||||
E: UserFacingError + Into<anyhow::Error>,
|
||||
{
|
||||
let error_kind = error.get_error_kind();
|
||||
let msg = error.to_string_client();
|
||||
|
||||
if error_kind != ErrorKind::RateLimit && error_kind != ErrorKind::User {
|
||||
tracing::info!(
|
||||
kind = error_kind.to_metric_label(),
|
||||
%error,
|
||||
msg,
|
||||
"forwarding error to user"
|
||||
);
|
||||
}
|
||||
|
||||
let probe_msg;
|
||||
let mut msg = &*msg;
|
||||
if let Some(ctx) = ctx {
|
||||
if ctx.get_testodrome_id().is_some() {
|
||||
let tag = match error_kind {
|
||||
ErrorKind::User => "client",
|
||||
ErrorKind::ClientDisconnect => "client",
|
||||
ErrorKind::RateLimit => "proxy",
|
||||
ErrorKind::ServiceRateLimit => "proxy",
|
||||
ErrorKind::Quota => "proxy",
|
||||
ErrorKind::Service => "proxy",
|
||||
ErrorKind::ControlPlane => "controlplane",
|
||||
ErrorKind::Postgres => "other",
|
||||
ErrorKind::Compute => "compute",
|
||||
};
|
||||
probe_msg = typed_json::json!({
|
||||
"tag": tag,
|
||||
"msg": msg,
|
||||
"cold_start_info": ctx.cold_start_info(),
|
||||
})
|
||||
.to_string();
|
||||
msg = &probe_msg;
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: either preserve the error code from postgres, or assign error codes to proxy errors.
|
||||
self.write.write_error(msg, SQLSTATE_INTERNAL_ERROR);
|
||||
|
||||
self.flush()
|
||||
.await
|
||||
.unwrap_or_else(|e| tracing::debug!("write_message failed: {e}"));
|
||||
|
||||
ReportedError::new(error)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user