Compare commits

...

15 Commits

Author  SHA1  Message  Date
Conrad Ludgate  fb02e54843  re-impl auth startup to compute  2025-06-10 22:39:22 -07:00
Conrad Ludgate  41a5b8524a  remove user_info from connect  2025-06-10 22:21:23 -07:00
Conrad Ludgate  6da7b87a32  move PqStream to pq_frontend.rs and rename to PgFeStream  2025-06-10 22:13:43 -07:00
Conrad Ludgate  72b1c573b1  add some more typesafety to pqproto and move stream.rs to a folder  2025-06-10 22:13:43 -07:00
Conrad Ludgate  b509982bbf  add another debug and overflow comment  2025-06-10 13:51:26 -07:00
Conrad Ludgate  a78a52acb5  add in timeout for cancellation  2025-06-10 13:44:30 -07:00
Conrad Ludgate  3370e8cb00  optimise future sizes for cancel maintenance  2025-06-10 13:40:39 -07:00
Conrad Ludgate  f37a558280  move the cancel-on-shutdown handling to the cancel session maintenance task  2025-06-10 13:40:27 -07:00
Conrad Ludgate  744011437a  create batch processing struct  2025-06-10 13:37:40 -07:00
Conrad Ludgate  a10d26a083  no explicit remove, only passive ttl  2025-06-10 13:36:11 -07:00
Conrad Ludgate  aece520365  remove replies for store/remove ops  2025-06-10 13:18:15 -07:00
Conrad Ludgate  9017811d61  remove dead code for redis keys  2025-06-10 13:18:15 -07:00
Conrad Ludgate  551a33aa04  use hget instead of hgetall  2025-06-10 13:18:15 -07:00
Conrad Ludgate  95216ae6ec  box the connect future and respect the redis retry methods on err  2025-06-10 13:18:15 -07:00
Conrad Ludgate  a3a10d1839  split CancelToken into RawCancelToken for smaller sizes and better typesafety  2025-06-10 13:18:15 -07:00
35 changed files with 1395 additions and 949 deletions

View File

@@ -1,5 +1,3 @@
use std::io;
use tokio::net::TcpStream;
use crate::client::SocketConfig;
@@ -8,7 +6,7 @@ use crate::tls::MakeTlsConnect;
use crate::{Error, cancel_query_raw, connect_socket};
pub(crate) async fn cancel_query<T>(
config: Option<SocketConfig>,
config: SocketConfig,
ssl_mode: SslMode,
tls: T,
process_id: i32,
@@ -17,16 +15,6 @@ pub(crate) async fn cancel_query<T>(
where
T: MakeTlsConnect<TcpStream>,
{
let config = match config {
Some(config) => config,
None => {
return Err(Error::connect(io::Error::new(
io::ErrorKind::InvalidInput,
"unknown host",
)));
}
};
let hostname = match &config.host {
Host::Tcp(host) => &**host,
};

View File

@@ -9,9 +9,16 @@ use crate::{Error, cancel_query, cancel_query_raw};
/// The capability to request cancellation of in-progress queries on a
/// connection.
#[derive(Clone, Serialize, Deserialize)]
#[derive(Clone)]
pub struct CancelToken {
pub socket_config: Option<SocketConfig>,
pub socket_config: SocketConfig,
pub raw: RawCancelToken,
}
/// The capability to request cancellation of in-progress queries on a
/// connection.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RawCancelToken {
pub ssl_mode: SslMode,
pub process_id: i32,
pub secret_key: i32,
@@ -36,14 +43,16 @@ impl CancelToken {
{
cancel_query::cancel_query(
self.socket_config.clone(),
self.ssl_mode,
self.raw.ssl_mode,
tls,
self.process_id,
self.secret_key,
self.raw.process_id,
self.raw.secret_key,
)
.await
}
}
impl RawCancelToken {
/// Like `cancel_query`, but uses a stream which is already connected to the server rather than opening a new
/// connection itself.
pub async fn cancel_query_raw<S, T>(&self, stream: S, tls: T) -> Result<(), Error>
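For orientation, a minimal sketch of how the split types compose after this change (not part of the diff; the stream, tls, and field values are hypothetical placeholders):

// A client-facing token still carries the socket config plus the raw part.
let token = CancelToken {
    socket_config: socket_config.clone(),
    raw: RawCancelToken {
        ssl_mode: SslMode::Disable,
        process_id,
        secret_key,
    },
};

// The raw token is serializable on its own and suffices when the caller
// already holds a connected stream, which is how the proxy uses it:
token.raw.cancel_query_raw(stream, tls).await?;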

View File

@@ -12,6 +12,7 @@ use postgres_protocol2::message::frontend;
use serde::{Deserialize, Serialize};
use tokio::sync::mpsc;
use crate::cancel_token::RawCancelToken;
use crate::codec::{BackendMessages, FrontendMessage};
use crate::config::{Host, SslMode};
use crate::query::RowStream;
@@ -331,10 +332,12 @@ impl Client {
/// connection associated with this client.
pub fn cancel_token(&self) -> CancelToken {
CancelToken {
socket_config: Some(self.socket_config.clone()),
ssl_mode: self.ssl_mode,
process_id: self.process_id,
secret_key: self.secret_key,
socket_config: self.socket_config.clone(),
raw: RawCancelToken {
ssl_mode: self.ssl_mode,
process_id: self.process_id,
secret_key: self.secret_key,
},
}
}

View File

@@ -3,7 +3,7 @@
use postgres_protocol2::message::backend::ReadyForQueryBody;
pub use crate::cancel_token::CancelToken;
pub use crate::cancel_token::{CancelToken, RawCancelToken};
pub use crate::client::{Client, SocketConfig};
pub use crate::config::Config;
pub use crate::connect_raw::RawConnection;

View File

@@ -14,7 +14,7 @@ pub(crate) mod private {
/// Channel binding information returned from a TLS handshake.
pub struct ChannelBinding {
pub(crate) tls_server_end_point: Option<Vec<u8>>,
pub tls_server_end_point: Option<Vec<u8>>,
}
impl ChannelBinding {

View File

@@ -7,13 +7,13 @@ use crate::auth::{self, AuthFlow};
use crate::config::AuthenticationConfig;
use crate::context::RequestContext;
use crate::control_plane::AuthSecret;
use crate::stream::{PqStream, Stream};
use crate::stream::{PqFeStream, Stream};
use crate::{compute, sasl};
pub(super) async fn authenticate(
ctx: &RequestContext,
creds: ComputeUserInfo,
client: &mut PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
client: &mut PqFeStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
config: &'static AuthenticationConfig,
secret: AuthSecret,
) -> auth::Result<ComputeCredentials> {

View File

@@ -17,7 +17,7 @@ use crate::error::{ReportableError, UserFacingError};
use crate::pqproto::BeMessage;
use crate::proxy::NeonOptions;
use crate::proxy::wake_compute::WakeComputeBackend;
use crate::stream::PqStream;
use crate::stream::PqFeStream;
use crate::types::RoleName;
use crate::{auth, compute, waiters};
@@ -96,7 +96,7 @@ impl ConsoleRedirectBackend {
&self,
ctx: &RequestContext,
auth_config: &'static AuthenticationConfig,
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
client: &mut PqFeStream<impl AsyncRead + AsyncWrite + Unpin>,
) -> auth::Result<(ConsoleRedirectNodeInfo, AuthInfo, ComputeUserInfo)> {
authenticate(ctx, auth_config, &self.console_uri, client)
.await
@@ -122,7 +122,7 @@ async fn authenticate(
ctx: &RequestContext,
auth_config: &'static AuthenticationConfig,
link_uri: &reqwest::Url,
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
client: &mut PqFeStream<impl AsyncRead + AsyncWrite + Unpin>,
) -> auth::Result<(NodeInfo, AuthInfo, ComputeUserInfo)> {
ctx.set_auth_method(crate::context::AuthMethod::ConsoleRedirect);

View File

@@ -17,7 +17,7 @@ use crate::stream::{self, Stream};
pub(crate) async fn authenticate_cleartext(
ctx: &RequestContext,
info: ComputeUserInfo,
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
client: &mut stream::PqFeStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
secret: AuthSecret,
config: &'static AuthenticationConfig,
) -> auth::Result<ComputeCredentials> {
@@ -61,7 +61,7 @@ pub(crate) async fn authenticate_cleartext(
pub(crate) async fn password_hack_no_authentication(
ctx: &RequestContext,
info: ComputeUserInfoNoEndpoint,
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
client: &mut stream::PqFeStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
) -> auth::Result<(ComputeUserInfo, Vec<u8>)> {
debug!("project not specified, resorting to the password hack auth flow");
ctx.set_auth_method(crate::context::AuthMethod::Cleartext);

View File

@@ -201,7 +201,7 @@ async fn auth_quirks(
ctx: &RequestContext,
api: &impl control_plane::ControlPlaneApi,
user_info: ComputeUserInfoMaybeEndpoint,
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
client: &mut stream::PqFeStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
allow_cleartext: bool,
config: &'static AuthenticationConfig,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
@@ -267,7 +267,7 @@ async fn authenticate_with_secret(
ctx: &RequestContext,
secret: AuthSecret,
info: ComputeUserInfo,
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
client: &mut stream::PqFeStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
unauthenticated_password: Option<Vec<u8>>,
allow_cleartext: bool,
config: &'static AuthenticationConfig,
@@ -318,7 +318,7 @@ impl<'a> Backend<'a, ComputeUserInfoMaybeEndpoint> {
pub(crate) async fn authenticate(
self,
ctx: &RequestContext,
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
client: &mut stream::PqFeStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
allow_cleartext: bool,
config: &'static AuthenticationConfig,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
@@ -446,7 +446,7 @@ mod tests {
use crate::rate_limiter::EndpointRateLimiter;
use crate::scram::ServerSecret;
use crate::scram::threadpool::ThreadPool;
use crate::stream::{PqStream, Stream};
use crate::stream::{PqFeStream, Stream};
struct Auth {
ips: Vec<IpPattern>,
@@ -522,7 +522,7 @@ mod tests {
#[tokio::test]
async fn auth_quirks_scram() {
let (mut client, server) = tokio::io::duplex(1024);
let mut stream = PqStream::new_skip_handshake(Stream::from_raw(server));
let mut stream = PqFeStream::new_skip_handshake(Stream::from_raw(server));
let ctx = RequestContext::test();
let api = Auth {
@@ -604,7 +604,7 @@ mod tests {
#[tokio::test]
async fn auth_quirks_cleartext() {
let (mut client, server) = tokio::io::duplex(1024);
let mut stream = PqStream::new_skip_handshake(Stream::from_raw(server));
let mut stream = PqFeStream::new_skip_handshake(Stream::from_raw(server));
let ctx = RequestContext::test();
let api = Auth {
@@ -658,7 +658,7 @@ mod tests {
#[tokio::test]
async fn auth_quirks_password_hack() {
let (mut client, server) = tokio::io::duplex(1024);
let mut stream = PqStream::new_skip_handshake(Stream::from_raw(server));
let mut stream = PqFeStream::new_skip_handshake(Stream::from_raw(server));
let ctx = RequestContext::test();
let api = Auth {

View File

@@ -15,7 +15,7 @@ use crate::pqproto::{BeAuthenticationSaslMessage, BeMessage};
use crate::sasl;
use crate::scram::threadpool::ThreadPool;
use crate::scram::{self};
use crate::stream::{PqStream, Stream};
use crate::stream::{PqFeStream, Stream};
use crate::tls::TlsServerEndPoint;
/// Use [SCRAM](crate::scram)-based auth in [`AuthFlow`].
@@ -53,7 +53,7 @@ pub(crate) struct CleartextPassword {
#[must_use]
pub(crate) struct AuthFlow<'a, S, State> {
/// The underlying stream which implements libpq's protocol.
stream: &'a mut PqStream<Stream<S>>,
stream: &'a mut PqFeStream<Stream<S>>,
/// State might contain ancillary data.
state: State,
tls_server_end_point: TlsServerEndPoint,
@@ -62,7 +62,7 @@ pub(crate) struct AuthFlow<'a, S, State> {
/// Initial state of the stream wrapper.
impl<'a, S: AsyncRead + AsyncWrite + Unpin, M> AuthFlow<'a, S, M> {
/// Create a new wrapper for client authentication.
pub(crate) fn new(stream: &'a mut PqStream<Stream<S>>, method: M) -> Self {
pub(crate) fn new(stream: &'a mut PqFeStream<Stream<S>>, method: M) -> Self {
let tls_server_end_point = stream.get_ref().tls_server_end_point();
Self {

proxy/src/batch.rs (new file, +145 lines)
View File

@@ -0,0 +1,145 @@
//! Batch processing system.
//!
//! Callers enqueue jobs into a shared queue; one caller takes the processor
//! lock, applies the whole batch, and fans the results back out. A pending
//! job is cancelled early by dropping its `call` future.
use std::collections::BTreeMap;
use std::pin::pin;
use std::sync::Mutex;
use futures::future::Either;
use scopeguard::ScopeGuard;
use tokio::sync::oneshot::error::TryRecvError;
use crate::ext::LockExt;
pub trait QueueProcessing: Send + 'static {
type Req: Send + 'static;
type Res: Send;
/// Get the desired batch size.
fn batch_size(&self, queue_size: usize) -> usize;
/// This applies a full batch of events.
/// Must respond with a full batch of replies.
///
/// If this apply can error, it's expected that errors be forwarded to each Self::Res.
///
/// Batching does not need to happen atomically.
fn apply(&mut self, req: Vec<Self::Req>) -> impl Future<Output = Vec<Self::Res>> + Send;
}
pub struct BatchQueue<P: QueueProcessing> {
processor: tokio::sync::Mutex<P>,
inner: Mutex<BatchQueueInner<P>>,
}
struct BatchJob<P: QueueProcessing> {
req: P::Req,
res: tokio::sync::oneshot::Sender<P::Res>,
}
impl<P: QueueProcessing> BatchQueue<P> {
pub fn new(p: P) -> Self {
Self {
processor: tokio::sync::Mutex::new(p),
inner: Mutex::new(BatchQueueInner {
version: 0,
queue: BTreeMap::new(),
}),
}
}
pub async fn call(&self, req: P::Req) -> P::Res {
let (id, mut rx) = self.inner.lock_propagate_poison().register_job(req);
let guard = scopeguard::guard(id, move |id| {
if self.inner.lock_propagate_poison().queue.remove(&id).is_some() {
tracing::debug!("batched task cancelled before completion");
}
});
let resp = loop {
// try become the leader, or try wait for success.
let mut processor = match futures::future::select(rx, pin!(self.processor.lock())).await
{
// we got the resp.
Either::Left((resp, _)) => break resp.ok(),
// we are the leader.
Either::Right((p, rx_)) => {
rx = rx_;
p
}
};
let (reqs, resps) = self.inner.lock_propagate_poison().get_batch(&processor);
// apply a batch.
let values = processor.apply(reqs).await;
// send response values.
for (tx, value) in std::iter::zip(resps, values) {
// sender hung up but that's fine.
drop(tx.send(value));
}
match rx.try_recv() {
Ok(resp) => break Some(resp),
Err(TryRecvError::Closed) => break None,
// edge case - there was a race condition where
// we became the leader but were not in the batch.
//
// Example:
// thread 1: register job id=1
// thread 2: register job id=2
// thread 2: processor.lock().await
// thread 1: processor.lock().await
// thread 2: becomes leader, batch_size=1, jobs=[1].
Err(TryRecvError::Empty) => {}
}
};
// already removed.
ScopeGuard::into_inner(guard);
resp.expect("no response found. batch processor should not panic")
}
}
struct BatchQueueInner<P: QueueProcessing> {
version: u64,
queue: BTreeMap<u64, BatchJob<P>>,
}
impl<P: QueueProcessing> BatchQueueInner<P> {
fn register_job(&mut self, req: P::Req) -> (u64, tokio::sync::oneshot::Receiver<P::Res>) {
let (tx, rx) = tokio::sync::oneshot::channel();
let id = self.version;
// Overflow concern:
// This is a u64, and we might enqueue 2^16 tasks per second.
// This gives us 2^48 seconds (9 million years).
// Even if this does overflow, it will not break, but some
// jobs with the higher version might never get prioritised.
self.version += 1;
self.queue.insert(id, BatchJob { req, res: tx });
(id, rx)
}
fn get_batch(&mut self, p: &P) -> (Vec<P::Req>, Vec<tokio::sync::oneshot::Sender<P::Res>>) {
let batch_size = p.batch_size(self.queue.len());
let mut reqs = Vec::with_capacity(batch_size);
let mut resps = Vec::with_capacity(batch_size);
while reqs.len() < batch_size {
let Some((_, job)) = self.queue.pop_first() else {
break;
};
reqs.push(job.req);
resps.push(job.res);
}
(reqs, resps)
}
}
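A minimal usage sketch for the queue above, using a hypothetical Doubler processor (the trait items are the ones defined in this file; assumes a tokio runtime):

// Toy processor: services a whole batch in one call.
struct Doubler;

impl QueueProcessing for Doubler {
    type Req = u64;
    type Res = u64;

    fn batch_size(&self, queue_size: usize) -> usize {
        queue_size.min(32)
    }

    fn apply(&mut self, req: Vec<u64>) -> impl Future<Output = Vec<u64>> + Send {
        async move { req.into_iter().map(|x| x * 2).collect() }
    }
}

// Many tasks share one queue; whichever caller wins the processor lock
// becomes the leader, applies the whole batch, and distributes the results.
// let queue = BatchQueue::new(Doubler);
// assert_eq!(queue.call(21).await, 42);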

View File

@@ -201,7 +201,7 @@ pub async fn run() -> anyhow::Result<()> {
auth_backend,
http_listener,
shutdown.clone(),
Arc::new(CancellationHandler::new(&config.connect_to_compute, None)),
Arc::new(CancellationHandler::new(&config.connect_to_compute)),
endpoint_rate_limiter,
);

View File

@@ -29,7 +29,7 @@ use crate::metrics::{Metrics, ThreadPoolMetrics};
use crate::pqproto::FeStartupPacket;
use crate::protocol2::ConnectionInfo;
use crate::proxy::{ErrorSource, TlsRequired, copy_bidirectional_client_compute};
use crate::stream::{PqStream, Stream};
use crate::stream::{PqFeStream, Stream};
use crate::util::run_until_cancelled;
project_git_version!(GIT_VERSION);
@@ -262,7 +262,7 @@ async fn ssl_handshake<S: AsyncRead + AsyncWrite + Unpin>(
raw_stream: S,
tls_config: Arc<rustls::ServerConfig>,
) -> anyhow::Result<TlsStream<S>> {
let (mut stream, msg) = PqStream::parse_startup(Stream::from_raw(raw_stream)).await?;
let (mut stream, msg) = PqFeStream::parse_startup(Stream::from_raw(raw_stream)).await?;
match msg {
FeStartupPacket::SslRequest { direct: None } => {
let raw = stream.accept_tls().await?;

View File

@@ -21,7 +21,8 @@ use utils::{project_build_tag, project_git_version};
use crate::auth::backend::jwt::JwkCache;
use crate::auth::backend::{ConsoleRedirectBackend, MaybeOwned};
use crate::cancellation::{CancellationHandler, handle_cancel_messages};
use crate::batch::BatchQueue;
use crate::cancellation::{CancellationHandler, CancellationProcessor};
use crate::config::{
self, AuthenticationConfig, CacheOptions, ComputeConfig, HttpConfig, ProjectInfoCacheOptions,
ProxyConfig, ProxyProtocolV2, remote_storage_from_toml,
@@ -390,13 +391,7 @@ pub async fn run() -> anyhow::Result<()> {
.as_ref()
.map(|redis_publisher| RedisKVClient::new(redis_publisher.clone(), redis_rps_limit));
// channel size should be higher than redis client limit to avoid blocking
let cancel_ch_size = args.cancellation_ch_size;
let (tx_cancel, rx_cancel) = tokio::sync::mpsc::channel(cancel_ch_size);
let cancellation_handler = Arc::new(CancellationHandler::new(
&config.connect_to_compute,
Some(tx_cancel),
));
let cancellation_handler = Arc::new(CancellationHandler::new(&config.connect_to_compute));
let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new_with_shards(
RateBucketInfo::to_leaky_bucket(&args.endpoint_rps_limit)
@@ -523,14 +518,10 @@ pub async fn run() -> anyhow::Result<()> {
if let Some(mut redis_kv_client) = redis_kv_client {
maintenance_tasks.spawn(async move {
redis_kv_client.try_connect().await?;
handle_cancel_messages(
&mut redis_kv_client,
rx_cancel,
args.cancellation_batch_size,
)
.await?;
drop(redis_kv_client);
cancellation_handler.init_tx(BatchQueue::new(CancellationProcessor {
client: redis_kv_client,
batch_size: args.cancellation_batch_size,
}));
// `handle_cancel_messages` was terminated due to the tx_cancel
// being dropped. this is not worthy of an error, and this task can only return `Err`,

View File

@@ -1,19 +1,23 @@
use std::convert::Infallible;
use std::net::{IpAddr, SocketAddr};
use std::sync::Arc;
use std::sync::{Arc, OnceLock};
use std::time::Duration;
use anyhow::{Context, anyhow};
use anyhow::anyhow;
use futures::FutureExt;
use ipnet::{IpNet, Ipv4Net, Ipv6Net};
use postgres_client::CancelToken;
use postgres_client::RawCancelToken;
use postgres_client::tls::MakeTlsConnect;
use redis::{Cmd, FromRedisValue, Value};
use serde::{Deserialize, Serialize};
use thiserror::Error;
use tokio::net::TcpStream;
use tokio::sync::{mpsc, oneshot};
use tracing::{debug, error, info, warn};
use tokio::time::timeout;
use tracing::{debug, error, info};
use crate::auth::AuthError;
use crate::auth::backend::ComputeUserInfo;
use crate::batch::{BatchQueue, QueueProcessing};
use crate::config::ComputeConfig;
use crate::context::RequestContext;
use crate::control_plane::ControlPlaneApi;
@@ -27,46 +31,36 @@ use crate::redis::kv_ops::RedisKVClient;
type IpSubnetKey = IpNet;
const CANCEL_KEY_TTL: i64 = 1_209_600; // 2 weeks cancellation key expire time
const CANCEL_KEY_TTL: std::time::Duration = std::time::Duration::from_secs(600);
const CANCEL_KEY_REFRESH: std::time::Duration = std::time::Duration::from_secs(570);
// Message types for sending through mpsc channel
pub enum CancelKeyOp {
StoreCancelKey {
key: String,
field: String,
value: String,
resp_tx: Option<oneshot::Sender<anyhow::Result<()>>>,
_guard: CancelChannelSizeGuard<'static>,
expire: i64, // TTL for key
key: CancelKeyData,
value: Box<str>,
expire: std::time::Duration,
},
GetCancelData {
key: String,
resp_tx: oneshot::Sender<anyhow::Result<Vec<(String, String)>>>,
_guard: CancelChannelSizeGuard<'static>,
},
RemoveCancelKey {
key: String,
field: String,
resp_tx: Option<oneshot::Sender<anyhow::Result<()>>>,
_guard: CancelChannelSizeGuard<'static>,
key: CancelKeyData,
},
}
pub struct Pipeline {
inner: redis::Pipeline,
replies: Vec<CancelReplyOp>,
replies: usize,
}
impl Pipeline {
fn with_capacity(n: usize) -> Self {
Self {
inner: redis::Pipeline::with_capacity(n),
replies: Vec::with_capacity(n),
replies: 0,
}
}
async fn execute(&mut self, client: &mut RedisKVClient) {
let responses = self.replies.len();
async fn execute(self, client: &mut RedisKVClient) -> Vec<anyhow::Result<Value>> {
let responses = self.replies;
let batch_size = self.inner.len();
match client.query(&self.inner).await {
@@ -76,176 +70,73 @@ impl Pipeline {
batch_size,
responses, "successfully completed cancellation jobs",
);
for (value, reply) in std::iter::zip(values, self.replies.drain(..)) {
reply.send_value(value);
}
values.into_iter().map(Ok).collect()
}
Ok(value) => {
error!(batch_size, ?value, "unexpected redis return value");
for reply in self.replies.drain(..) {
reply.send_err(anyhow!("incorrect response type from redis"));
}
std::iter::repeat_with(|| Err(anyhow!("incorrect response type from redis")))
.take(responses)
.collect()
}
Err(err) => {
for reply in self.replies.drain(..) {
reply.send_err(anyhow!("could not send cmd to redis: {err}"));
}
std::iter::repeat_with(|| Err(anyhow!("could not send cmd to redis: {err}")))
.take(responses)
.collect()
}
}
self.inner.clear();
self.replies.clear();
}
fn add_command_with_reply(&mut self, cmd: Cmd, reply: CancelReplyOp) {
fn add_command_with_reply(&mut self, cmd: Cmd) {
self.inner.add_command(cmd);
self.replies.push(reply);
self.replies += 1;
}
fn add_command_no_reply(&mut self, cmd: Cmd) {
self.inner.add_command(cmd).ignore();
}
fn add_command(&mut self, cmd: Cmd, reply: Option<CancelReplyOp>) {
match reply {
Some(reply) => self.add_command_with_reply(cmd, reply),
None => self.add_command_no_reply(cmd),
}
}
}
impl CancelKeyOp {
fn register(self, pipe: &mut Pipeline) {
fn register(&self, pipe: &mut Pipeline) {
#[allow(clippy::used_underscore_binding)]
match self {
CancelKeyOp::StoreCancelKey {
key,
field,
value,
resp_tx,
_guard,
expire,
} => {
let reply =
resp_tx.map(|resp_tx| CancelReplyOp::StoreCancelKey { resp_tx, _guard });
pipe.add_command(Cmd::hset(&key, field, value), reply);
pipe.add_command_no_reply(Cmd::expire(key, expire));
CancelKeyOp::StoreCancelKey { key, value, expire } => {
let key = KeyPrefix::Cancel(*key).build_redis_key();
pipe.add_command_with_reply(Cmd::hset(&key, "data", &**value));
pipe.add_command_no_reply(Cmd::expire(&key, expire.as_secs() as i64));
}
CancelKeyOp::GetCancelData {
key,
resp_tx,
_guard,
} => {
let reply = CancelReplyOp::GetCancelData { resp_tx, _guard };
pipe.add_command_with_reply(Cmd::hgetall(key), reply);
}
CancelKeyOp::RemoveCancelKey {
key,
field,
resp_tx,
_guard,
} => {
let reply =
resp_tx.map(|resp_tx| CancelReplyOp::RemoveCancelKey { resp_tx, _guard });
pipe.add_command(Cmd::hdel(key, field), reply);
CancelKeyOp::GetCancelData { key } => {
let key = KeyPrefix::Cancel(*key).build_redis_key();
pipe.add_command_with_reply(Cmd::hget(key, "data"));
}
}
}
}
// Message types for sending through mpsc channel
pub enum CancelReplyOp {
StoreCancelKey {
resp_tx: oneshot::Sender<anyhow::Result<()>>,
_guard: CancelChannelSizeGuard<'static>,
},
GetCancelData {
resp_tx: oneshot::Sender<anyhow::Result<Vec<(String, String)>>>,
_guard: CancelChannelSizeGuard<'static>,
},
RemoveCancelKey {
resp_tx: oneshot::Sender<anyhow::Result<()>>,
_guard: CancelChannelSizeGuard<'static>,
},
pub struct CancellationProcessor {
pub client: RedisKVClient,
pub batch_size: usize,
}
impl CancelReplyOp {
fn send_err(self, e: anyhow::Error) {
match self {
CancelReplyOp::StoreCancelKey { resp_tx, _guard } => {
resp_tx
.send(Err(e))
.inspect_err(|_| tracing::debug!("could not send reply"))
.ok();
}
CancelReplyOp::GetCancelData { resp_tx, _guard } => {
resp_tx
.send(Err(e))
.inspect_err(|_| tracing::debug!("could not send reply"))
.ok();
}
CancelReplyOp::RemoveCancelKey { resp_tx, _guard } => {
resp_tx
.send(Err(e))
.inspect_err(|_| tracing::debug!("could not send reply"))
.ok();
}
}
impl QueueProcessing for CancellationProcessor {
type Req = (CancelChannelSizeGuard<'static>, CancelKeyOp);
type Res = anyhow::Result<redis::Value>;
fn batch_size(&self, _queue_size: usize) -> usize {
self.batch_size
}
fn send_value(self, v: redis::Value) {
match self {
CancelReplyOp::StoreCancelKey { resp_tx, _guard } => {
let send =
FromRedisValue::from_owned_redis_value(v).context("could not parse value");
resp_tx
.send(send)
.inspect_err(|_| tracing::debug!("could not send reply"))
.ok();
}
CancelReplyOp::GetCancelData { resp_tx, _guard } => {
let send =
FromRedisValue::from_owned_redis_value(v).context("could not parse value");
resp_tx
.send(send)
.inspect_err(|_| tracing::debug!("could not send reply"))
.ok();
}
CancelReplyOp::RemoveCancelKey { resp_tx, _guard } => {
let send =
FromRedisValue::from_owned_redis_value(v).context("could not parse value");
resp_tx
.send(send)
.inspect_err(|_| tracing::debug!("could not send reply"))
.ok();
}
}
}
}
// Running as a separate task to accept messages through the rx channel
pub async fn handle_cancel_messages(
client: &mut RedisKVClient,
mut rx: mpsc::Receiver<CancelKeyOp>,
batch_size: usize,
) -> anyhow::Result<()> {
let mut batch = Vec::with_capacity(batch_size);
let mut pipeline = Pipeline::with_capacity(batch_size);
loop {
if rx.recv_many(&mut batch, batch_size).await == 0 {
warn!("shutting down cancellation queue");
break Ok(());
}
async fn apply(&mut self, batch: Vec<Self::Req>) -> Vec<Self::Res> {
let mut pipeline = Pipeline::with_capacity(batch.len());
let batch_size = batch.len();
debug!(batch_size, "running cancellation jobs");
for msg in batch.drain(..) {
msg.register(&mut pipeline);
for (_, op) in &batch {
op.register(&mut pipeline);
}
pipeline.execute(client).await;
pipeline.execute(&mut self.client).await
}
}
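On the caller side (see the hunks below), an operation is enqueued through the shared queue roughly like this; `queue` stands for the BatchQueue<CancellationProcessor> installed via init_tx:

let guard = Metrics::get()
    .proxy
    .cancel_channel_size
    .guard(RedisMsgKind::HSet);
let op = CancelKeyOp::StoreCancelKey {
    key,
    value: closure_json.clone(),
    expire: CANCEL_KEY_TTL,
};
// A single await covers both batching and the redis round-trip.
let result: anyhow::Result<redis::Value> = queue.call((guard, op)).await;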
@@ -256,7 +147,7 @@ pub struct CancellationHandler {
compute_config: &'static ComputeConfig,
// rate limiter of cancellation requests
limiter: Arc<std::sync::Mutex<LeakyBucketRateLimiter<IpSubnetKey>>>,
tx: Option<mpsc::Sender<CancelKeyOp>>, // send messages to the redis KV client task
tx: OnceLock<BatchQueue<CancellationProcessor>>, // send messages to the redis KV client task
}
#[derive(Debug, Error)]
@@ -296,13 +187,10 @@ impl ReportableError for CancelError {
}
impl CancellationHandler {
pub fn new(
compute_config: &'static ComputeConfig,
tx: Option<mpsc::Sender<CancelKeyOp>>,
) -> Self {
pub fn new(compute_config: &'static ComputeConfig) -> Self {
Self {
compute_config,
tx,
tx: OnceLock::new(),
limiter: Arc::new(std::sync::Mutex::new(
LeakyBucketRateLimiter::<IpSubnetKey>::new_with_shards(
LeakyBucketRateLimiter::<IpSubnetKey>::DEFAULT,
@@ -312,7 +200,14 @@ impl CancellationHandler {
}
}
pub(crate) fn get_key(self: &Arc<Self>) -> Session {
pub fn init_tx(&self, queue: BatchQueue<CancellationProcessor>) {
self.tx
.set(queue)
.map_err(|_| {})
.expect("cancellation queue should be registered once");
}
pub(crate) fn get_key(self: Arc<Self>) -> Session {
// we intentionally generate a random "backend pid" and "secret key" here.
// we use the corresponding u64 as an identifier for the
// actual endpoint+pid+secret for postgres/pgbouncer.
@@ -322,14 +217,10 @@ impl CancellationHandler {
let key: CancelKeyData = rand::random();
let prefix_key: KeyPrefix = KeyPrefix::Cancel(key);
let redis_key = prefix_key.build_redis_key();
debug!("registered new query cancellation key {key}");
Session {
key,
redis_key,
cancellation_handler: Arc::clone(self),
cancellation_handler: self,
}
}
@@ -337,62 +228,43 @@ impl CancellationHandler {
&self,
key: CancelKeyData,
) -> Result<Option<CancelClosure>, CancelError> {
let prefix_key: KeyPrefix = KeyPrefix::Cancel(key);
let redis_key = prefix_key.build_redis_key();
let guard = Metrics::get()
.proxy
.cancel_channel_size
.guard(RedisMsgKind::HGet);
let op = CancelKeyOp::GetCancelData { key };
let (resp_tx, resp_rx) = tokio::sync::oneshot::channel();
let op = CancelKeyOp::GetCancelData {
key: redis_key,
resp_tx,
_guard: Metrics::get()
.proxy
.cancel_channel_size
.guard(RedisMsgKind::HGetAll),
};
let Some(tx) = &self.tx else {
let Some(tx) = self.tx.get() else {
tracing::warn!("cancellation handler is not available");
return Err(CancelError::InternalError);
};
tx.try_send(op)
const TIMEOUT: Duration = Duration::from_secs(5);
let result = timeout(TIMEOUT, tx.call((guard, op)))
.await
.map_err(|_| {
tracing::warn!("timed out waiting to receive GetCancelData response");
CancelError::RateLimit
})?
.map_err(|e| {
tracing::warn!("failed to send GetCancelData for {key}: {e}");
})
.map_err(|()| CancelError::InternalError)?;
tracing::warn!("failed to receive GetCancelData response: {e}");
CancelError::InternalError
})?;
let result = resp_rx.await.map_err(|e| {
let cancel_state_str = String::from_owned_redis_value(result).map_err(|e| {
tracing::warn!("failed to receive GetCancelData response: {e}");
CancelError::InternalError
})?;
let cancel_state_str: Option<String> = match result {
Ok(mut state) => {
if state.len() == 1 {
Some(state.remove(0).1)
} else {
tracing::warn!("unexpected number of entries in cancel state: {state:?}");
return Err(CancelError::InternalError);
}
}
Err(e) => {
tracing::warn!("failed to receive cancel state from redis: {e}");
return Err(CancelError::InternalError);
}
};
let cancel_closure: CancelClosure =
serde_json::from_str(&cancel_state_str).map_err(|e| {
tracing::warn!("failed to deserialize cancel state: {e}");
CancelError::InternalError
})?;
let cancel_state: Option<CancelClosure> = match cancel_state_str {
Some(state) => {
let cancel_closure: CancelClosure = serde_json::from_str(&state).map_err(|e| {
tracing::warn!("failed to deserialize cancel state: {e}");
CancelError::InternalError
})?;
Some(cancel_closure)
}
None => None,
};
Ok(cancel_state)
Ok(Some(cancel_closure))
}
/// Try to cancel a running query for the corresponding connection.
/// If the cancellation key is not found, it will be published to Redis.
/// check_allowed - if true, check if the IP is allowed to cancel the query.
@@ -467,10 +339,10 @@ impl CancellationHandler {
/// This should've been a [`std::future::Future`], but
/// it's impossible to name a type of an unboxed future
/// (we'd need something like `#![feature(type_alias_impl_trait)]`).
#[derive(Clone, Serialize, Deserialize)]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CancelClosure {
socket_addr: SocketAddr,
cancel_token: CancelToken,
cancel_token: RawCancelToken,
hostname: String, // for pg_sni router
user_info: ComputeUserInfo,
}
@@ -478,7 +350,7 @@ pub struct CancelClosure {
impl CancelClosure {
pub(crate) fn new(
socket_addr: SocketAddr,
cancel_token: CancelToken,
cancel_token: RawCancelToken,
hostname: String,
user_info: ComputeUserInfo,
) -> Self {
@@ -491,7 +363,7 @@ impl CancelClosure {
}
/// Cancels the query running on user's compute node.
pub(crate) async fn try_cancel_query(
self,
&self,
compute_config: &ComputeConfig,
) -> Result<(), CancelError> {
let socket = TcpStream::connect(self.socket_addr).await?;
@@ -512,7 +384,6 @@ impl CancelClosure {
pub(crate) struct Session {
/// The user-facing key identifying this session.
key: CancelKeyData,
redis_key: String,
cancellation_handler: Arc<CancellationHandler>,
}
@@ -521,60 +392,66 @@ impl Session {
&self.key
}
// Send the store key op to the cancellation handler and set TTL for the key
pub(crate) fn write_cancel_key(
/// Ensure the cancel key is continuously refreshed,
/// but stop when the channel is dropped.
pub(crate) async fn maintain_cancel_key(
&self,
cancel_closure: CancelClosure,
) -> Result<(), CancelError> {
let Some(tx) = &self.cancellation_handler.tx else {
tracing::warn!("cancellation handler is not available");
return Err(CancelError::InternalError);
};
session_id: uuid::Uuid,
cancel: tokio::sync::oneshot::Receiver<Infallible>,
cancel_closure: &CancelClosure,
compute_config: &ComputeConfig,
) {
futures::future::select(
std::pin::pin!(self.maintain_redis_cancel_key(cancel_closure)),
cancel,
)
.await;
let closure_json = serde_json::to_string(&cancel_closure).map_err(|e| {
tracing::warn!("failed to serialize cancel closure: {e}");
CancelError::InternalError
})?;
let op = CancelKeyOp::StoreCancelKey {
key: self.redis_key.clone(),
field: "data".to_string(),
value: closure_json,
resp_tx: None,
_guard: Metrics::get()
.proxy
.cancel_channel_size
.guard(RedisMsgKind::HSet),
expire: CANCEL_KEY_TTL,
};
let _ = tx.try_send(op).map_err(|e| {
let key = self.key;
tracing::warn!("failed to send StoreCancelKey for {key}: {e}");
});
Ok(())
if let Err(err) = cancel_closure
.try_cancel_query(compute_config)
.boxed()
.await
{
tracing::warn!(
?session_id,
?err,
"could not cancel the query in the database"
);
}
}
pub(crate) fn remove_cancel_key(&self) -> Result<(), CancelError> {
let Some(tx) = &self.cancellation_handler.tx else {
// Ensure the cancel key is continuously refreshed.
async fn maintain_redis_cancel_key(&self, cancel_closure: &CancelClosure) -> ! {
let Some(tx) = self.cancellation_handler.tx.get() else {
tracing::warn!("cancellation handler is not available");
return Err(CancelError::InternalError);
// don't exit, as we only want to exit if cancelled externally.
std::future::pending().await
};
let op = CancelKeyOp::RemoveCancelKey {
key: self.redis_key.clone(),
field: "data".to_string(),
resp_tx: None,
_guard: Metrics::get()
let closure_json = serde_json::to_string(&cancel_closure)
.expect("serialising to json string should not fail")
.into_boxed_str();
loop {
let guard = Metrics::get()
.proxy
.cancel_channel_size
.guard(RedisMsgKind::HDel),
};
.guard(RedisMsgKind::HSet);
let op = CancelKeyOp::StoreCancelKey {
key: self.key,
value: closure_json.clone(),
expire: CANCEL_KEY_TTL,
};
let _ = tx.try_send(op).map_err(|e| {
let key = self.key;
tracing::warn!("failed to send RemoveCancelKey for {key}: {e}");
});
Ok(())
tracing::debug!(
src=%self.key,
dest=?cancel_closure.cancel_token,
"registering cancellation key"
);
if tx.call((guard, op)).await.is_ok() {
tokio::time::sleep(CANCEL_KEY_REFRESH).await;
}
}
}
}
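The shutdown wiring above hinges on a oneshot sender used purely as a drop guard; a minimal sketch of that pattern (illustrative names, assumes a tokio runtime and the futures crate):

use std::convert::Infallible;
use std::pin::pin;

async fn demo() {
    // Infallible payload: nothing can ever be sent, so the receiver only
    // completes once the sender is dropped.
    let (guard, stop) = tokio::sync::oneshot::channel::<Infallible>();

    let maintenance = tokio::spawn(async move {
        let refresh = pin!(async {
            loop {
                // store/refresh the cancel key, then sleep until the next refresh
                tokio::time::sleep(std::time::Duration::from_secs(570)).await;
            }
        });
        // Runs until `guard` is dropped by the owning connection task...
        futures::future::select(refresh, stop).await;
        // ...after which the final CancelRequest to compute would be issued.
    });

    drop(guard); // connection finished: stop refreshing the key
    let _ = maintenance.await;
}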

View File

@@ -0,0 +1,146 @@
use bytes::BufMut;
use postgres_client::tls::{ChannelBinding, TlsStream};
use postgres_protocol::authentication::sasl;
use postgres_protocol::authentication::sasl::{SCRAM_SHA_256, SCRAM_SHA_256_PLUS};
use tokio::io::{AsyncRead, AsyncWrite};
use super::{Auth, MaybeRustlsStream};
use crate::compute::RustlsStream;
use crate::pqproto::{
AUTH_OK, AUTH_SASL, AUTH_SASL_CONT, AUTH_SASL_FINAL, FE_PASSWORD_MESSAGE, StartupMessageParams,
};
use crate::stream::{PostgresError, PqBeStream};
pub async fn authenticate<S>(
stream: MaybeRustlsStream<S>,
auth: Option<&Auth>,
params: &StartupMessageParams,
) -> Result<PqBeStream<MaybeRustlsStream<S>>, PostgresError>
where
S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
RustlsStream<S>: TlsStream + Unpin,
{
let mut stream = PqBeStream::new(stream, params);
stream.flush().await?;
let channel_binding = stream.get_ref().channel_binding();
// TODO: rather than checking for SASL, maybe we can just assume it.
// With SCRAM_SHA_256 if we're not using TLS,
// and SCRAM_SHA_256_PLUS if we are using TLS.
let (channel_binding, mechanism) = match stream.read_auth_message().await? {
(AUTH_OK, _) => return Ok(stream),
(AUTH_SASL, mechanisms) => {
let mut has_scram = false;
let mut has_scram_plus = false;
for mechanism in mechanisms.split(|&b| b == b'\0') {
match mechanism {
b"SCRAM-SHA-256" => has_scram = true,
b"SCRAM-SHA-256-PLUS" => has_scram_plus = true,
_ => {}
}
}
match (channel_binding, has_scram, has_scram_plus) {
(cb, true, false) => {
if cb.tls_server_end_point.is_some() {
// I don't think this can happen in our setup, but I would like to monitor it.
tracing::warn!(
"TLS is enabled, but compute doesn't support SCRAM-SHA-256-PLUS."
);
}
(sasl::ChannelBinding::unrequested(), SCRAM_SHA_256)
}
(
ChannelBinding {
tls_server_end_point: None,
},
true,
_,
) => (sasl::ChannelBinding::unsupported(), SCRAM_SHA_256),
(
ChannelBinding {
tls_server_end_point: Some(h),
},
_,
true,
) => (
sasl::ChannelBinding::tls_server_end_point(h),
SCRAM_SHA_256_PLUS,
),
(_, false, _) => {
tracing::error!(
"compute responded with unsupported auth mechanisms: {}",
String::from_utf8_lossy(mechanisms)
);
return Err(PostgresError::InvalidAuthMessage);
}
}
}
(tag, msg) => {
tracing::error!(
"compute responded with unexpected auth message with tag[{tag}]: {}",
String::from_utf8_lossy(msg)
);
return Err(PostgresError::InvalidAuthMessage);
}
};
let mut scram = match auth {
// We only touch passwords when it comes to console-redirect.
Some(Auth::Password(pw)) => sasl::ScramSha256::new(pw, channel_binding),
Some(Auth::Scram(keys)) => sasl::ScramSha256::new_with_keys(**keys, channel_binding),
None => {
// local_proxy does not set credentials, since it relies on trust and expects an OK message above
tracing::error!("compute requested SASL auth, but there are no credentials available",);
return Err(PostgresError::InvalidAuthMessage);
}
};
stream.write_raw(0, FE_PASSWORD_MESSAGE.0, |buf| {
buf.put_slice(mechanism.as_bytes());
buf.put_u8(b'\0');
let data = scram.message();
buf.put_u32(data.len() as u32);
buf.put_slice(data);
});
stream.flush().await?;
loop {
// wait for SASLContinue or SASLFinal.
match stream.read_auth_message().await? {
(AUTH_SASL_CONT, data) => scram.update(data).await?,
(AUTH_SASL_FINAL, data) => {
scram.finish(data)?;
break;
}
(tag, msg) => {
tracing::error!(
"compute responded with unexpected auth message with tag[{tag}]: {}",
String::from_utf8_lossy(msg)
);
return Err(PostgresError::InvalidAuthMessage);
}
}
stream.write_raw(0, FE_PASSWORD_MESSAGE.0, |buf| {
buf.put_slice(scram.message());
});
stream.flush().await?;
}
match stream.read_auth_message().await? {
(AUTH_OK, _) => {}
(tag, msg) => {
tracing::error!(
"compute responded with unexpected auth message with tag[{tag}]: {}",
String::from_utf8_lossy(msg)
);
return Err(PostgresError::InvalidAuthMessage);
}
}
Ok(stream)
}
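For reference, the first password ('p') message written above is a SASLInitialResponse with this layout (per the PostgreSQL protocol; the later 'p' messages carry only the raw SCRAM payload, with no length prefix):

    mechanism name      "SCRAM-SHA-256" or "SCRAM-SHA-256-PLUS"
    1 byte              0x00 terminator after the name
    int32 (big-endian)  length of the client-first message
    bytes               client-first message from ScramSha256::message()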

View File

@@ -1,3 +1,4 @@
mod authenticate;
mod tls;
use std::fmt::Debug;
@@ -9,15 +10,12 @@ use itertools::Itertools;
use postgres_client::config::{AuthKeys, SslMode};
use postgres_client::maybe_tls_stream::MaybeTlsStream;
use postgres_client::tls::MakeTlsConnect;
use postgres_client::{CancelToken, NoTls, RawConnection};
use postgres_protocol::message::backend::NoticeResponseBody;
use thiserror::Error;
use tokio::net::{TcpStream, lookup_host};
use tracing::{debug, error, info, warn};
use crate::auth::backend::{ComputeCredentialKeys, ComputeUserInfo};
use crate::auth::backend::ComputeCredentialKeys;
use crate::auth::parse_endpoint_param;
use crate::cancellation::CancelClosure;
use crate::compute::tls::TlsError;
use crate::config::ComputeConfig;
use crate::context::RequestContext;
@@ -28,6 +26,7 @@ use crate::error::{ReportableError, UserFacingError};
use crate::metrics::{Metrics, NumDbConnectionsGuard};
use crate::pqproto::StartupMessageParams;
use crate::proxy::neon_option;
use crate::stream::{PostgresError, PqBeStream};
use crate::types::Host;
pub const COULD_NOT_CONNECT: &str = "Couldn't connect to compute node";
@@ -37,7 +36,7 @@ pub(crate) enum ConnectionError {
/// This error doesn't seem to reveal any secrets; for instance,
/// `postgres_client::error::Kind` doesn't contain ip addresses and such.
#[error("{COULD_NOT_CONNECT}: {0}")]
Postgres(#[from] postgres_client::Error),
Postgres(#[from] PostgresError),
#[error("{COULD_NOT_CONNECT}: {0}")]
TlsError(#[from] TlsError),
@@ -54,20 +53,21 @@ impl UserFacingError for ConnectionError {
match self {
// This helps us drop irrelevant library-specific prefixes.
// TODO: propagate severity level and other parameters.
ConnectionError::Postgres(err) => match err.as_db_error() {
Some(err) => {
let msg = err.message();
ConnectionError::Postgres(PostgresError::Error(err)) => {
let (_code, msg) = err.parse();
let msg = String::from_utf8_lossy(msg);
if msg.starts_with("unsupported startup parameter: ")
|| msg.starts_with("unsupported startup parameter in options: ")
{
format!("{msg}. Please use unpooled connection or remove this parameter from the startup package. More details: https://neon.tech/docs/connect/connection-errors#unsupported-startup-parameter")
} else {
msg.to_owned()
}
if msg.starts_with("unsupported startup parameter: ")
|| msg.starts_with("unsupported startup parameter in options: ")
{
format!(
"{msg}. Please use unpooled connection or remove this parameter from the startup package. More details: https://neon.tech/docs/connect/connection-errors#unsupported-startup-parameter"
)
} else {
msg.into_owned()
}
None => err.to_string(),
},
}
ConnectionError::Postgres(err) => err.to_string(),
ConnectionError::WakeComputeError(err) => err.to_string_client(),
ConnectionError::TooManyConnectionAttempts(_) => {
"Failed to acquire permit to connect to the database. Too many database connection attempts are currently ongoing.".to_owned()
@@ -80,10 +80,12 @@ impl UserFacingError for ConnectionError {
impl ReportableError for ConnectionError {
fn get_error_kind(&self) -> crate::error::ErrorKind {
match self {
ConnectionError::Postgres(e) if e.as_db_error().is_some() => {
crate::error::ErrorKind::Postgres
}
ConnectionError::Postgres(_) => crate::error::ErrorKind::Compute,
ConnectionError::Postgres(PostgresError::Io(_)) => crate::error::ErrorKind::Compute,
ConnectionError::Postgres(
PostgresError::Error(_)
| PostgresError::InvalidAuthMessage
| PostgresError::Unexpected(_),
) => crate::error::ErrorKind::Postgres,
ConnectionError::TlsError(_) => crate::error::ErrorKind::Compute,
ConnectionError::WakeComputeError(e) => e.get_error_kind(),
ConnectionError::TooManyConnectionAttempts(e) => e.get_error_kind(),
@@ -162,18 +164,6 @@ impl ConnectInfo {
}
impl AuthInfo {
fn enrich(&self, mut config: postgres_client::Config) -> postgres_client::Config {
match &self.auth {
Some(Auth::Scram(keys)) => config.auth_keys(AuthKeys::ScramSha256(**keys)),
Some(Auth::Password(pw)) => config.password(pw),
None => &mut config,
};
for (k, v) in self.server_params.iter() {
config.set_param(k, v);
}
config
}
/// Apply startup message params to the connection config.
pub(crate) fn set_startup_params(
&mut self,
@@ -213,7 +203,7 @@ impl ConnectInfo {
async fn connect_raw(
&self,
config: &ComputeConfig,
) -> Result<(SocketAddr, MaybeTlsStream<TcpStream, RustlsStream>), TlsError> {
) -> Result<(SocketAddr, MaybeRustlsStream<TcpStream>), TlsError> {
let timeout = config.timeout;
// wrap TcpStream::connect with timeout
@@ -265,21 +255,19 @@ impl ConnectInfo {
}
}
type RustlsStream = <ComputeConfig as MakeTlsConnect<tokio::net::TcpStream>>::Stream;
pub type RustlsStream<S> = <ComputeConfig as MakeTlsConnect<S>>::Stream;
pub type MaybeRustlsStream<S> = MaybeTlsStream<S, RustlsStream<S>>;
pub(crate) struct PostgresConnection {
pub struct PostgresConnection {
/// Socket connected to a compute node.
pub(crate) stream: MaybeTlsStream<tokio::net::TcpStream, RustlsStream>,
/// PostgreSQL connection parameters.
pub(crate) params: std::collections::HashMap<String, String>,
/// Query cancellation token.
pub(crate) cancel_closure: CancelClosure,
/// Labels for proxy's metrics.
pub(crate) aux: MetricsAuxInfo,
/// Notices received from compute after authenticating
pub(crate) delayed_notice: Vec<NoticeResponseBody>,
pub stream: PqBeStream<MaybeRustlsStream<TcpStream>>,
_guage: NumDbConnectionsGuard<'static>,
pub socket_addr: SocketAddr,
pub hostname: String,
pub ssl_mode: SslMode,
pub aux: MetricsAuxInfo,
pub guage: NumDbConnectionsGuard<'static>,
}
impl ConnectInfo {
@@ -287,31 +275,18 @@ impl ConnectInfo {
pub(crate) async fn connect(
&self,
ctx: &RequestContext,
aux: MetricsAuxInfo,
aux: &MetricsAuxInfo,
auth: &AuthInfo,
config: &ComputeConfig,
user_info: ComputeUserInfo,
) -> Result<PostgresConnection, ConnectionError> {
let mut tmp_config = auth.enrich(self.to_postgres_client_config());
// we setup SSL early in `ConnectInfo::connect_raw`.
tmp_config.ssl_mode(SslMode::Disable);
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
let (socket_addr, stream) = self.connect_raw(config).await?;
let connection = tmp_config.connect_raw(stream, NoTls).await?;
let stream =
authenticate::authenticate(stream, auth.auth.as_ref(), &auth.server_params).await?;
drop(pause);
let RawConnection {
stream,
parameters,
delayed_notice,
process_id,
secret_key,
} = connection;
tracing::Span::current().record("pid", tracing::field::display(process_id));
// tracing::Span::current().record("pid", tracing::field::display(process_id));
tracing::Span::current().record("compute_id", tracing::field::display(&aux.compute_id));
let MaybeTlsStream::Raw(stream) = stream.into_inner();
// TODO: lots of useful info but maybe we can move it elsewhere (eg traces?)
info!(
@@ -323,27 +298,13 @@ impl ConnectInfo {
ctx.get_testodrome_id().unwrap_or_default(),
);
// NB: CancelToken is supposed to hold socket_addr, but we use connect_raw.
// Yet another reason to rework the connection establishing code.
let cancel_closure = CancelClosure::new(
socket_addr,
CancelToken {
socket_config: None,
ssl_mode: self.ssl_mode,
process_id,
secret_key,
},
self.host.to_string(),
user_info,
);
let connection = PostgresConnection {
stream,
params: parameters,
delayed_notice,
cancel_closure,
aux,
_guage: Metrics::get().proxy.db_connections.guard(ctx.protocol()),
socket_addr,
hostname: self.host.to_string(),
ssl_mode: self.ssl_mode,
aux: aux.clone(),
guage: Metrics::get().proxy.db_connections.guard(ctx.protocol()),
};
Ok(connection)

View File

@@ -120,7 +120,7 @@ pub async fn task_main(
Ok(Some(p)) => {
ctx.set_success();
let _disconnect = ctx.log_connect();
match p.proxy_pass(&config.connect_to_compute).await {
match p.proxy_pass().await {
Ok(()) => {}
Err(ErrorSource::Client(e)) => {
error!(
@@ -177,7 +177,7 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
let do_handshake = handshake(ctx, stream, tls, record_handshake_error);
let (mut stream, params) = match tokio::time::timeout(config.handshake_timeout, do_handshake)
let (mut client, params) = match tokio::time::timeout(config.handshake_timeout, do_handshake)
.await??
{
HandshakeData::Startup(stream, params) => (stream, params),
@@ -210,18 +210,17 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
ctx.set_db_options(params.clone());
let (node_info, mut auth_info, user_info) = match backend
.authenticate(ctx, &config.authentication_config, &mut stream)
.authenticate(ctx, &config.authentication_config, &mut client)
.await
{
Ok(auth_result) => auth_result,
Err(e) => Err(stream.throw_error(e, Some(ctx)).await)?,
Err(e) => Err(client.throw_error(e, Some(ctx)).await)?,
};
auth_info.set_startup_params(&params, true);
let node = connect_to_compute(
let mut node = connect_to_compute(
ctx,
&TcpMechanism {
user_info,
auth: auth_info,
locks: &config.connect_compute_locks,
},
@@ -229,25 +228,41 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
config.wake_compute_retry_config,
&config.connect_to_compute,
)
.or_else(|e| async { Err(stream.throw_error(e, Some(ctx)).await) })
.or_else(|e| async { Err(client.throw_error(e, Some(ctx)).await) })
.await?;
let cancellation_handler_clone = Arc::clone(&cancellation_handler);
let session = cancellation_handler_clone.get_key();
let session = cancellation_handler.get_key();
session.write_cancel_key(node.cancel_closure.clone())?;
let cancel_closure =
prepare_client_connection(&mut node, session.key(), &mut client, user_info).await?;
prepare_client_connection(&node, *session.key(), &mut stream);
let stream = stream.flush_and_into_inner().await?;
let session_id = ctx.session_id();
let (cancel_on_shutdown, cancel) = tokio::sync::oneshot::channel();
tokio::spawn(async move {
session
.maintain_cancel_key(
session_id,
cancel,
&cancel_closure,
&config.connect_to_compute,
)
.await;
});
let client = client.flush_and_into_inner().await?;
let compute = node.stream.flush_and_into_inner().await?;
Ok(Some(ProxyPassthrough {
client: stream,
aux: node.aux.clone(),
client,
compute,
aux: node.aux,
private_link_id: None,
compute: node,
session_id: ctx.session_id(),
cancel: session,
_cancel_on_shutdown: cancel_on_shutdown,
_req: request_gauge,
_conn: conn_gauge,
_db_conn: node.guage,
}))
}

View File

@@ -78,11 +78,8 @@ impl NodeInfo {
ctx: &RequestContext,
auth: &compute::AuthInfo,
config: &ComputeConfig,
user_info: ComputeUserInfo,
) -> Result<compute::PostgresConnection, compute::ConnectionError> {
self.conn_info
.connect(ctx, self.aux.clone(), auth, config, user_info)
.await
self.conn_info.connect(ctx, &self.aux, auth, config).await
}
}

View File

@@ -75,6 +75,7 @@
pub mod binary;
mod auth;
mod batch;
mod cache;
mod cancellation;
mod compute;

View File

@@ -12,7 +12,7 @@ use crate::pqproto::{
BeMessage, CancelKeyData, FeStartupPacket, ProtocolVersion, StartupMessageParams,
};
use crate::proxy::TlsRequired;
use crate::stream::{PqStream, Stream, StreamUpgradeError};
use crate::stream::{PqFeStream, Stream, StreamUpgradeError};
use crate::tls::PG_ALPN_PROTOCOL;
#[derive(Error, Debug)]
@@ -49,7 +49,7 @@ impl ReportableError for HandshakeError {
}
pub(crate) enum HandshakeData<S> {
Startup(PqStream<Stream<S>>, StartupMessageParams),
Startup(PqFeStream<Stream<S>>, StartupMessageParams),
Cancel(CancelKeyData),
}
@@ -70,7 +70,7 @@ pub(crate) async fn handshake<S: AsyncRead + AsyncWrite + Unpin + Send>(
const PG_PROTOCOL_EARLIEST: ProtocolVersion = ProtocolVersion::new(3, 0);
const PG_PROTOCOL_LATEST: ProtocolVersion = ProtocolVersion::new(3, 0);
let (mut stream, mut msg) = PqStream::parse_startup(Stream::from_raw(stream)).await?;
let (mut stream, mut msg) = PqFeStream::parse_startup(Stream::from_raw(stream)).await?;
loop {
match msg {
FeStartupPacket::SslRequest { direct } => match stream.get_ref() {
@@ -152,7 +152,7 @@ pub(crate) async fn handshake<S: AsyncRead + AsyncWrite + Unpin + Send>(
tls: tls_stream,
tls_server_end_point,
};
(stream, msg) = PqStream::parse_startup(tls).await?;
(stream, msg) = PqFeStream::parse_startup(tls).await?;
} else {
if direct.is_some() {
// client sent us a ClientHello already, we can't do anything with it.

View File

@@ -1,15 +1,18 @@
use futures::FutureExt;
use std::convert::Infallible;
use smol_str::SmolStr;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::TcpStream;
use tracing::debug;
use utils::measured_stream::MeasuredStream;
use super::copy_bidirectional::ErrorSource;
use crate::cancellation;
use crate::compute::PostgresConnection;
use crate::config::ComputeConfig;
use crate::compute::MaybeRustlsStream;
use crate::control_plane::messages::MetricsAuxInfo;
use crate::metrics::{Direction, Metrics, NumClientConnectionsGuard, NumConnectionRequestsGuard};
use crate::metrics::{
Direction, Metrics, NumClientConnectionsGuard, NumConnectionRequestsGuard,
NumDbConnectionsGuard,
};
use crate::stream::Stream;
use crate::usage_metrics::{Ids, MetricCounterRecorder, USAGE_METRICS};
@@ -64,40 +67,19 @@ pub(crate) async fn proxy_pass(
pub(crate) struct ProxyPassthrough<S> {
pub(crate) client: Stream<S>,
pub(crate) compute: PostgresConnection,
pub(crate) compute: MaybeRustlsStream<TcpStream>,
pub(crate) aux: MetricsAuxInfo,
pub(crate) session_id: uuid::Uuid,
pub(crate) private_link_id: Option<SmolStr>,
pub(crate) cancel: cancellation::Session,
pub(crate) _cancel_on_shutdown: tokio::sync::oneshot::Sender<Infallible>,
pub(crate) _req: NumConnectionRequestsGuard<'static>,
pub(crate) _conn: NumClientConnectionsGuard<'static>,
pub(crate) _db_conn: NumDbConnectionsGuard<'static>,
}
impl<S: AsyncRead + AsyncWrite + Unpin> ProxyPassthrough<S> {
pub(crate) async fn proxy_pass(
self,
compute_config: &ComputeConfig,
) -> Result<(), ErrorSource> {
let res = proxy_pass(
self.client,
self.compute.stream,
self.aux,
self.private_link_id,
)
.await;
if let Err(err) = self
.compute
.cancel_closure
.try_cancel_query(compute_config)
.boxed()
.await
{
tracing::warn!(session_id = ?self.session_id, ?err, "could not cancel the query in the database");
}
drop(self.cancel.remove_cancel_key()); // we don't need a result. If the queue is full, we just log the error
res
pub(crate) async fn proxy_pass(self) -> Result<(), ErrorSource> {
proxy_pass(self.client, self.compute, self.aux, self.private_link_id).await
}
}

View File

@@ -11,11 +11,86 @@ use rand::distributions::{Distribution, Standard};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use zerocopy::{FromBytes, Immutable, IntoBytes, big_endian};
pub type ErrorCode = [u8; 5];
#[derive(Debug, Copy, Clone, PartialEq)]
pub struct ErrorCode(pub [u8; 5]);
pub const FE_PASSWORD_MESSAGE: u8 = b'p';
#[derive(Debug, Copy, Clone, PartialEq)]
pub struct FeTag(pub u8);
pub const SQLSTATE_INTERNAL_ERROR: [u8; 5] = *b"XX000";
#[derive(Debug, Copy, Clone, PartialEq)]
pub struct BeTag(pub u8);
#[derive(Debug, Copy, Clone, PartialEq)]
pub struct AuthTag(pub i32);
pub const FE_PASSWORD_MESSAGE: FeTag = FeTag(b'p');
pub const BE_AUTH_MESSAGE: BeTag = BeTag(b'R');
pub const BE_ERR_MESSAGE: BeTag = BeTag(b'E');
pub const BE_KEY_MESSAGE: BeTag = BeTag(b'K');
pub const BE_READY_MESSAGE: BeTag = BeTag(b'Z');
pub const BE_NEGOTIATE_MESSAGE: BeTag = BeTag(b'v');
pub const AUTH_OK: AuthTag = AuthTag(0);
pub const AUTH_SASL: AuthTag = AuthTag(10);
pub const AUTH_SASL_CONT: AuthTag = AuthTag(11);
pub const AUTH_SASL_FINAL: AuthTag = AuthTag(12);
pub const SQLSTATE_INTERNAL_ERROR: ErrorCode = ErrorCode(*b"XX000");
pub const CONNECTION_EXCEPTION: ErrorCode = ErrorCode(*b"08000");
pub const CONNECTION_DOES_NOT_EXIST: ErrorCode = ErrorCode(*b"08003");
pub const CONNECTION_FAILURE: ErrorCode = ErrorCode(*b"08006");
pub const SQLCLIENT_UNABLE_TO_ESTABLISH_SQLCONNECTION: ErrorCode = ErrorCode(*b"08001");
pub const PROTOCOL_VIOLATION: ErrorCode = ErrorCode(*b"08P01");
pub const INVALID_PARAMETER_VALUE: ErrorCode = ErrorCode(*b"22023");
pub const INVALID_CATALOG_NAME: ErrorCode = ErrorCode(*b"3D000");
pub const INVALID_SCHEMA_NAME: ErrorCode = ErrorCode(*b"3F000");
pub const T_R_SERIALIZATION_FAILURE: ErrorCode = ErrorCode(*b"40001");
pub const SYNTAX_ERROR: ErrorCode = ErrorCode(*b"42601");
pub const OUT_OF_MEMORY: ErrorCode = ErrorCode(*b"53200");
pub const TOO_MANY_CONNECTIONS: ErrorCode = ErrorCode(*b"53300");
impl fmt::Display for AuthTag {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self.0 {
0 => f.write_str("Ok"),
2 => f.write_str("KerberosV5"),
3 => f.write_str("CleartextPassword"),
5 => f.write_str("MD5Password"),
7 => f.write_str("GSS"),
8 => f.write_str("GSSContinue"),
9 => f.write_str("SSPI"),
10 => f.write_str("SASL"),
11 => f.write_str("SASLContinue"),
12 => f.write_str("SASLFinal"),
x => write!(f, "{x}"),
}
}
}
impl fmt::Display for BeTag {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match *self {
BE_AUTH_MESSAGE => f.write_str("Authentication"),
BE_KEY_MESSAGE => f.write_str("BackendKeyData"),
BE_ERR_MESSAGE => f.write_str("ErrorResponse"),
BE_READY_MESSAGE => f.write_str("ReadyForQuery"),
BE_NEGOTIATE_MESSAGE => f.write_str("NegotiateProtocolVersion"),
BeTag(b'S') => f.write_str("ParameterStatus"),
BeTag(b'N') => f.write_str("NoticeMessage"),
BeTag(x) => write!(f, "{:?}", char::from(x)),
}
}
}
impl fmt::Display for FeTag {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match *self {
FE_PASSWORD_MESSAGE => f.write_str("Password"),
FeTag(x) => write!(f, "{:?}", char::from(x)),
}
}
}
/// The protocol version number.
///
@@ -280,6 +355,10 @@ impl WriteBuf {
Self(Cursor::new(Vec::new()))
}
pub const fn len(&self) -> usize {
self.0.get_ref().len()
}
/// Use a heuristic to determine if we should shrink the write buffer.
#[inline]
fn should_shrink(&self) -> bool {
@@ -313,6 +392,19 @@ impl WriteBuf {
self.0.set_position(0);
}
/// Write a startup message.
pub fn startup(&mut self, params: &StartupMessageParams) {
self.0.get_mut().extend_from_slice(
StartupHeader {
len: big_endian::U32::new(params.params.len() as u32 + 9),
version: ProtocolVersion::new(3, 0),
}
.as_bytes(),
);
self.0.get_mut().extend_from_slice(params.params.as_bytes());
self.0.get_mut().push(0);
}
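// Annotation (not part of this file): the "+ 9" above is the fixed overhead of
// a v3 StartupMessage, whose self-inclusive length works out as:
//   int32 len      4 (len) + 4 (version) + params.len() + 1 = params.len() + 9
//   int32 version  protocol 3.0
//   bytes params   key\0value\0 ... pairs
//   byte  0        trailing terminator (the `push(0)` above)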
/// Write a raw message to the internal buffer.
///
/// The size_hint value is only a hint for reserving space. It's ok if it's incorrect, since
@@ -353,7 +445,7 @@ impl WriteBuf {
// Code: error_code
buf.put_u8(b'C');
buf.put_slice(&error_code);
buf.put_slice(&error_code.0);
buf.put_u8(0);
// Message: msg
@@ -468,11 +560,11 @@ pub enum BeMessage<'a> {
AuthenticationOk,
AuthenticationSasl(BeAuthenticationSaslMessage<'a>),
AuthenticationCleartextPassword,
BackendKeyData(CancelKeyData),
ParameterStatus {
name: &'a [u8],
value: &'a [u8],
},
#[cfg(test)]
ReadyForQuery,
NoticeResponse(&'a str),
NegotiateProtocolVersion {
@@ -494,17 +586,17 @@ impl BeMessage<'_> {
match self {
// <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-AUTHENTICATIONCLEARTEXTPASSWORD>
BeMessage::AuthenticationOk => {
buf.write_raw(1, b'R', |buf| buf.put_i32(0));
buf.write_raw(1, BE_AUTH_MESSAGE.0, |buf| buf.put_i32(0));
}
// <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-AUTHENTICATIONCLEARTEXTPASSWORD>
BeMessage::AuthenticationCleartextPassword => {
buf.write_raw(1, b'R', |buf| buf.put_i32(3));
buf.write_raw(1, BE_AUTH_MESSAGE.0, |buf| buf.put_i32(3));
}
// <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-AUTHENTICATIONSASL>
BeMessage::AuthenticationSasl(BeAuthenticationSaslMessage::Methods(methods)) => {
let len: usize = methods.iter().map(|m| m.len() + 1).sum();
buf.write_raw(len + 2, b'R', |buf| {
buf.write_raw(len + 2, BE_AUTH_MESSAGE.0, |buf| {
buf.put_i32(10); // Specifies that SASL auth method is used.
for method in methods {
buf.put_slice(method.as_bytes());
@@ -515,24 +607,19 @@ impl BeMessage<'_> {
}
// <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-AUTHENTICATIONSASL>
BeMessage::AuthenticationSasl(BeAuthenticationSaslMessage::Continue(extra)) => {
buf.write_raw(extra.len() + 1, b'R', |buf| {
buf.write_raw(extra.len() + 1, BE_AUTH_MESSAGE.0, |buf| {
buf.put_i32(11); // Continue SASL auth.
buf.put_slice(extra);
});
}
// <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-AUTHENTICATIONSASL>
BeMessage::AuthenticationSasl(BeAuthenticationSaslMessage::Final(extra)) => {
buf.write_raw(extra.len() + 1, b'R', |buf| {
buf.write_raw(extra.len() + 1, BE_AUTH_MESSAGE.0, |buf| {
buf.put_i32(12); // Send final SASL message.
buf.put_slice(extra);
});
}
// <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-BACKENDKEYDATA>
BeMessage::BackendKeyData(key_data) => {
buf.write_raw(8, b'K', |buf| buf.put_slice(key_data.as_bytes()));
}
// <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-NOTICERESPONSE>
// <https://www.postgresql.org/docs/current/protocol-error-fields.html>
BeMessage::NoticeResponse(msg) => {
@@ -564,15 +651,16 @@ impl BeMessage<'_> {
});
}
#[cfg(test)]
// <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-READYFORQUERY>
BeMessage::ReadyForQuery => {
buf.write_raw(1, b'Z', |buf| buf.put_u8(b'I'));
buf.write_raw(1, BE_READY_MESSAGE.0, |buf| buf.put_u8(b'I'));
}
// <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-NEGOTIATEPROTOCOLVERSION>
BeMessage::NegotiateProtocolVersion { version, options } => {
let len: usize = options.iter().map(|o| o.len() + 1).sum();
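// Size hint: protocol version (4 bytes) + option count (4 bytes), plus each unrecognized
// option name with its NUL terminator.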
buf.write_raw(8 + len, b'v', |buf| {
buf.write_raw(8 + len, BE_NEGOTIATE_MESSAGE.0, |buf| {
buf.put_slice(version.as_bytes());
buf.put_u32(options.len() as u32);
for option in options {

View File

@@ -2,7 +2,6 @@ use async_trait::async_trait;
use tokio::time;
use tracing::{debug, info, warn};
use crate::auth::backend::ComputeUserInfo;
use crate::compute::{self, AuthInfo, COULD_NOT_CONNECT, PostgresConnection};
use crate::config::{ComputeConfig, RetryConfig};
use crate::context::RequestContext;
@@ -53,7 +52,6 @@ pub(crate) struct TcpMechanism {
pub(crate) auth: AuthInfo,
/// connect_to_compute concurrency lock
pub(crate) locks: &'static ApiLocks<Host>,
pub(crate) user_info: ComputeUserInfo,
}
#[async_trait]
@@ -73,11 +71,7 @@ impl ConnectMechanism for TcpMechanism {
config: &ComputeConfig,
) -> Result<PostgresConnection, Self::Error> {
let permit = self.locks.get_permit(&node_info.conn_info.host).await?;
permit.release_result(
node_info
.connect(ctx, &self.auth, config, self.user_info.clone())
.await,
)
permit.release_result(node_info.connect(ctx, &self.auth, config).await)
}
}

View File

@@ -10,6 +10,7 @@ use std::sync::Arc;
use futures::FutureExt;
use itertools::Itertools;
use once_cell::sync::OnceCell;
use postgres_client::RawCancelToken;
use regex::Regex;
use serde::{Deserialize, Serialize};
use smol_str::{SmolStr, ToSmolStr, format_smolstr};
@@ -18,7 +19,8 @@ use tokio::io::{AsyncRead, AsyncWrite};
use tokio_util::sync::CancellationToken;
use tracing::{Instrument, debug, error, info, warn};
use crate::cancellation::{self, CancellationHandler};
use crate::auth::backend::ComputeUserInfo;
use crate::cancellation::{self, CancelClosure, CancellationHandler};
use crate::config::{ProxyConfig, ProxyProtocolV2, TlsConfig};
use crate::context::RequestContext;
use crate::error::{ReportableError, UserFacingError};
@@ -26,11 +28,11 @@ use crate::metrics::{Metrics, NumClientConnectionsGuard};
pub use crate::pglb::copy_bidirectional::{ErrorSource, copy_bidirectional_client_compute};
use crate::pglb::handshake::{HandshakeData, HandshakeError, handshake};
use crate::pglb::passthrough::ProxyPassthrough;
use crate::pqproto::{BeMessage, CancelKeyData, StartupMessageParams};
use crate::pqproto::{CancelKeyData, StartupMessageParams};
use crate::protocol2::{ConnectHeader, ConnectionInfo, ConnectionInfoExtra, read_proxy_protocol};
use crate::proxy::connect_compute::{TcpMechanism, connect_to_compute};
use crate::rate_limiter::EndpointRateLimiter;
use crate::stream::{PqStream, Stream};
use crate::stream::{PostgresError, PqFeStream, Stream};
use crate::types::EndpointCacheKey;
use crate::util::run_until_cancelled;
use crate::{auth, compute};
@@ -155,7 +157,7 @@ pub async fn task_main(
Ok(Some(p)) => {
ctx.set_success();
let _disconnect = ctx.log_connect();
match p.proxy_pass(&config.connect_to_compute).await {
match p.proxy_pass().await {
Ok(()) => {}
Err(ErrorSource::Client(e)) => {
warn!(
@@ -253,7 +255,7 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
auth_backend: &'static auth::Backend<'static, ()>,
ctx: &RequestContext,
cancellation_handler: Arc<CancellationHandler>,
stream: S,
client: S,
mode: ClientMode,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
conn_gauge: NumClientConnectionsGuard<'static>,
@@ -273,9 +275,9 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
let record_handshake_error = !ctx.has_private_peer_addr();
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
let do_handshake = handshake(ctx, stream, mode.handshake_tls(tls), record_handshake_error);
let do_handshake = handshake(ctx, client, mode.handshake_tls(tls), record_handshake_error);
let (mut stream, params) = match tokio::time::timeout(config.handshake_timeout, do_handshake)
let (mut client, params) = match tokio::time::timeout(config.handshake_timeout, do_handshake)
.await??
{
HandshakeData::Startup(stream, params) => (stream, params),
@@ -307,7 +309,7 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
ctx.set_db_options(params.clone());
let hostname = mode.hostname(stream.get_ref());
let hostname = mode.hostname(client.get_ref());
let common_names = tls.map(|tls| &tls.common_names);
@@ -319,14 +321,14 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
let user_info = match result {
Ok(user_info) => user_info,
Err(e) => Err(stream.throw_error(e, Some(ctx)).await)?,
Err(e) => Err(client.throw_error(e, Some(ctx)).await)?,
};
let user = user_info.get_user().to_owned();
let user_info = match user_info
.authenticate(
ctx,
&mut stream,
&mut client,
mode.allow_cleartext(),
&config.authentication_config,
endpoint_rate_limiter,
@@ -339,7 +341,7 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
let app = params.get("application_name");
let params_span = tracing::info_span!("", ?user, ?db, ?app);
return Err(stream
return Err(client
.throw_error(e, Some(ctx))
.instrument(params_span)
.await)?;
@@ -357,27 +359,40 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
let res = connect_to_compute(
ctx,
&TcpMechanism {
user_info: creds.info.clone(),
auth: auth_info,
locks: &config.connect_compute_locks,
},
&auth::Backend::ControlPlane(cplane, creds.info),
&auth::Backend::ControlPlane(cplane, creds.info.clone()),
config.wake_compute_retry_config,
&config.connect_to_compute,
)
.await;
let node = match res {
let mut node = match res {
Ok(node) => node,
Err(e) => Err(stream.throw_error(e, Some(ctx)).await)?,
Err(e) => Err(client.throw_error(e, Some(ctx)).await)?,
};
let cancellation_handler_clone = Arc::clone(&cancellation_handler);
let session = cancellation_handler_clone.get_key();
let session = cancellation_handler.get_key();
session.write_cancel_key(node.cancel_closure.clone())?;
prepare_client_connection(&node, *session.key(), &mut stream);
let stream = stream.flush_and_into_inner().await?;
let cancel_closure =
prepare_client_connection(&mut node, session.key(), &mut client, creds.info).await?;
let session_id = ctx.session_id();
let (cancel_on_shutdown, cancel) = tokio::sync::oneshot::channel();
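// The sender half is stored in ProxyPassthrough below as `_cancel_on_shutdown`; dropping it
// when the passthrough ends lets the spawned maintenance task observe the disconnect via `cancel`.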
tokio::spawn(async move {
session
.maintain_cancel_key(
session_id,
cancel,
&cancel_closure,
&config.connect_to_compute,
)
.await;
});
let client = client.flush_and_into_inner().await?;
let compute = node.stream.flush_and_into_inner().await?;
let private_link_id = match ctx.extra() {
Some(ConnectionInfoExtra::Aws { vpce_id }) => Some(vpce_id.clone()),
@@ -386,40 +401,75 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
};
Ok(Some(ProxyPassthrough {
client: stream,
aux: node.aux.clone(),
client,
compute,
aux: node.aux,
private_link_id,
compute: node,
session_id: ctx.session_id(),
cancel: session,
_cancel_on_shutdown: cancel_on_shutdown,
_req: request_gauge,
_conn: conn_gauge,
_db_conn: node.guage,
}))
}
/// Finish client connection initialization: confirm auth success, send params, etc.
pub(crate) fn prepare_client_connection(
node: &compute::PostgresConnection,
cancel_key_data: CancelKeyData,
stream: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
) {
// Forward all deferred notices to the client.
for notice in &node.delayed_notice {
stream.write_raw(notice.as_bytes().len(), b'N', |buf| {
buf.extend_from_slice(notice.as_bytes());
});
pub(crate) async fn prepare_client_connection(
node: &mut compute::PostgresConnection,
key_data: &CancelKeyData,
stream: &mut PqFeStream<impl AsyncRead + AsyncWrite + Unpin>,
user_info: ComputeUserInfo,
) -> Result<CancelClosure, std::io::Error> {
use zerocopy::{FromBytes, IntoBytes};
use crate::pqproto::{BE_KEY_MESSAGE, BE_READY_MESSAGE};
let mut process_id = 0;
let mut secret_key = 0;
loop {
match node.stream.read_raw_be(1024).await {
// parse backend keys, and substitute our own.
Ok((tag @ BE_KEY_MESSAGE, msg)) => {
stream.write_raw(8, tag, |b| b.extend_from_slice(key_data.as_bytes()));
let key_data = CancelKeyData::read_from_bytes(msg)
.map_err(|_| std::io::Error::other("invalid msg len"))?;
process_id = (key_data.0.get() >> 32) as i32;
secret_key = (key_data.0.get() & 0xffff_ffff) as i32;
}
// ready for query, we're done :)
Ok((tag @ BE_READY_MESSAGE, msg)) => {
stream.write_raw(msg.len(), tag, |b| b.extend_from_slice(msg.as_bytes()));
break;
}
// either a notice or a parameter status.
Ok((tag, msg)) => {
stream.write_raw(msg.len(), tag, |b| b.extend_from_slice(msg.as_bytes()));
}
Err(PostgresError::Io(io)) => return Err(io),
Err(PostgresError::Error(e)) => return Err(std::io::Error::other(e)),
Err(_) => unreachable!("read_raw_be only returns IO or BackendError types"),
}
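// Flush periodically so forwarded notices and parameters don't accumulate unbounded in the write buffer.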
if stream.write_buf_len() > 512 {
stream.flush().await?;
}
}
// Forward all postgres connection params to the client.
for (name, value) in &node.params {
stream.write_message(BeMessage::ParameterStatus {
name: name.as_bytes(),
value: value.as_bytes(),
});
}
stream.write_message(BeMessage::BackendKeyData(cancel_key_data));
stream.write_message(BeMessage::ReadyForQuery);
Ok(CancelClosure::new(
node.socket_addr,
RawCancelToken {
ssl_mode: node.ssl_mode,
process_id,
secret_key,
},
node.hostname.clone(),
user_info,
))
}
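// Illustrative sketch, not part of the diff: how the backend's cancel key splits, assuming
// CancelKeyData wraps one big-endian u64 (as `key_data.0.get()` above suggests), with the
// process ID in the high 32 bits and the secret key in the low 32 bits.
fn split_cancel_key(raw: u64) -> (i32, i32) {
    let process_id = (raw >> 32) as i32; // high half: backend process ID
    let secret_key = (raw & 0xffff_ffff) as i32; // low half: per-backend secret
    (process_id, secret_key)
}
// e.g. split_cancel_key((12345u64 << 32) | 54321) == (12345, 54321), matching the
// "cancel:30390000d431" redis key checked in the keyspace test further down.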
#[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)]

View File

@@ -1,10 +1,12 @@
use std::error::Error;
use std::io;
use bstr::ByteSlice;
use tokio::time;
use crate::compute;
use crate::config::RetryConfig;
use crate::stream::{BackendError, PostgresError};
pub(crate) trait CouldRetry {
/// Returns true if the error could be retried
@@ -96,10 +98,55 @@ impl ShouldRetryWakeCompute for postgres_client::Error {
}
}
impl CouldRetry for BackendError {
fn could_retry(&self) -> bool {
let (code, _message) = self.parse();
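// All of these are SQLSTATE class 08 (connection exception) codes: the connection itself
// failed, so retrying the connection may succeed.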
matches!(
code,
crate::pqproto::CONNECTION_FAILURE
| crate::pqproto::CONNECTION_EXCEPTION
| crate::pqproto::CONNECTION_DOES_NOT_EXIST
| crate::pqproto::SQLCLIENT_UNABLE_TO_ESTABLISH_SQLCONNECTION,
)
}
}
impl ShouldRetryWakeCompute for BackendError {
fn should_retry_wake_compute(&self) -> bool {
let (code, message) = self.parse();
// These errors occur after the user has successfully authenticated to the database.
let non_retriable_pg_errors = matches!(
code,
crate::pqproto::TOO_MANY_CONNECTIONS
| crate::pqproto::OUT_OF_MEMORY
| crate::pqproto::SYNTAX_ERROR
| crate::pqproto::T_R_SERIALIZATION_FAILURE
| crate::pqproto::INVALID_CATALOG_NAME
| crate::pqproto::INVALID_SCHEMA_NAME
| crate::pqproto::INVALID_PARAMETER_VALUE,
);
if non_retriable_pg_errors {
return false;
}
// PGBouncer errors that should not trigger a wake_compute retry.
if code == crate::pqproto::PROTOCOL_VIOLATION {
// Source for the error message:
// https://github.com/pgbouncer/pgbouncer/blob/f15997fe3effe3a94ba8bcc1ea562e6117d1a131/src/client.c#L1070
return message.contains_str("no more connections allowed (max_client_conn)");
}
true
}
}
impl CouldRetry for compute::ConnectionError {
fn could_retry(&self) -> bool {
match self {
compute::ConnectionError::Postgres(err) => err.could_retry(),
compute::ConnectionError::Postgres(PostgresError::Error(err)) => err.could_retry(),
compute::ConnectionError::Postgres(PostgresError::Io(err)) => err.could_retry(),
compute::ConnectionError::Postgres(PostgresError::Unexpected(_)) => false,
compute::ConnectionError::Postgres(PostgresError::InvalidAuthMessage) => false,
compute::ConnectionError::TlsError(err) => err.could_retry(),
compute::ConnectionError::WakeComputeError(err) => err.could_retry(),
compute::ConnectionError::TooManyConnectionAttempts(_) => false,
@@ -109,7 +156,12 @@ impl CouldRetry for compute::ConnectionError {
impl ShouldRetryWakeCompute for compute::ConnectionError {
fn should_retry_wake_compute(&self) -> bool {
match self {
compute::ConnectionError::Postgres(err) => err.should_retry_wake_compute(),
compute::ConnectionError::Postgres(PostgresError::Error(err)) => {
err.should_retry_wake_compute()
}
compute::ConnectionError::Postgres(PostgresError::Io(_)) => true,
compute::ConnectionError::Postgres(PostgresError::Unexpected(_)) => false,
compute::ConnectionError::Postgres(PostgresError::InvalidAuthMessage) => false,
// the cache entry was not checked for validity
compute::ConnectionError::TooManyConnectionAttempts(_) => false,
_ => true,

View File

@@ -25,6 +25,7 @@ use crate::control_plane::client::{ControlPlaneClient, TestControlPlaneClient};
use crate::control_plane::messages::{ControlPlaneErrorMessage, Details, MetricsAuxInfo, Status};
use crate::control_plane::{self, CachedNodeInfo, NodeInfo, NodeInfoCache};
use crate::error::ErrorKind;
use crate::pqproto::BeMessage;
use crate::proxy::connect_compute::ConnectMechanism;
use crate::tls::client_config::compute_client_config_with_certs;
use crate::tls::server_config::CertResolver;
@@ -122,7 +123,7 @@ fn generate_tls_config<'a>(
trait TestAuth: Sized {
async fn authenticate<S: AsyncRead + AsyncWrite + Unpin + Send>(
self,
stream: &mut PqStream<Stream<S>>,
stream: &mut PqFeStream<Stream<S>>,
) -> anyhow::Result<()> {
stream.write_message(BeMessage::AuthenticationOk);
Ok(())
@@ -151,7 +152,7 @@ impl Scram {
impl TestAuth for Scram {
async fn authenticate<S: AsyncRead + AsyncWrite + Unpin + Send>(
self,
stream: &mut PqStream<Stream<S>>,
stream: &mut PqFeStream<Stream<S>>,
) -> anyhow::Result<()> {
let outcome = auth::AuthFlow::new(stream, auth::Scram(&self.0, &RequestContext::test()))
.authenticate()

View File

@@ -1,8 +1,4 @@
use std::io::ErrorKind;
use anyhow::Ok;
use crate::pqproto::{CancelKeyData, id_to_cancel_key};
use crate::pqproto::CancelKeyData;
pub mod keyspace {
pub const CANCEL_PREFIX: &str = "cancel";
@@ -23,39 +19,12 @@ impl KeyPrefix {
}
}
}
#[allow(dead_code)]
pub(crate) fn as_str(&self) -> &'static str {
match self {
KeyPrefix::Cancel(_) => keyspace::CANCEL_PREFIX,
}
}
}
#[allow(dead_code)]
pub(crate) fn parse_redis_key(key: &str) -> anyhow::Result<KeyPrefix> {
let (prefix, key_str) = key.split_once(':').ok_or_else(|| {
anyhow::anyhow!(std::io::Error::new(
ErrorKind::InvalidData,
"missing prefix"
))
})?;
match prefix {
keyspace::CANCEL_PREFIX => {
let id = u64::from_str_radix(key_str, 16)?;
Ok(KeyPrefix::Cancel(id_to_cancel_key(id)))
}
_ => Err(anyhow::anyhow!(std::io::Error::new(
ErrorKind::InvalidData,
"unknown prefix"
))),
}
}
#[cfg(test)]
mod tests {
use crate::pqproto::id_to_cancel_key;
use super::*;
#[test]
@@ -65,16 +34,4 @@ mod tests {
let redis_key = cancel_key.build_redis_key();
assert_eq!(redis_key, "cancel:30390000d431");
}
#[test]
fn test_parse_redis_key() {
let redis_key = "cancel:30390000d431";
let key: KeyPrefix = parse_redis_key(redis_key).expect("Failed to parse key");
let ref_key = id_to_cancel_key(12345 << 32 | 54321);
assert_eq!(key.as_str(), KeyPrefix::Cancel(ref_key).as_str());
let KeyPrefix::Cancel(cancel_key) = key;
assert_eq!(ref_key, cancel_key);
}
}

View File

@@ -1,3 +1,6 @@
use std::time::Duration;
use futures::FutureExt;
use redis::aio::ConnectionLike;
use redis::{Cmd, FromRedisValue, Pipeline, RedisResult};
@@ -35,14 +38,11 @@ impl RedisKVClient {
}
pub async fn try_connect(&mut self) -> anyhow::Result<()> {
match self.client.connect().await {
Ok(()) => {}
Err(e) => {
tracing::error!("failed to connect to redis: {e}");
return Err(e);
}
}
Ok(())
self.client
.connect()
.boxed()
.await
.inspect_err(|e| tracing::error!("failed to connect to redis: {e}"))
}
pub(crate) async fn query<T: FromRedisValue>(
@@ -54,15 +54,25 @@ impl RedisKVClient {
return Err(anyhow::anyhow!("Rate limit exceeded"));
}
match q.query(&mut self.client).await {
let e = match q.query(&mut self.client).await {
Ok(t) => return Ok(t),
Err(e) => {
tracing::error!("failed to run query: {e}");
Err(e) => e,
};
tracing::error!("failed to run query: {e}");
match e.retry_method() {
redis::RetryMethod::Reconnect => {
tracing::info!("Redis client is disconnected. Reconnecting...");
self.try_connect().await?;
}
redis::RetryMethod::RetryImmediately => {}
redis::RetryMethod::WaitAndRetry => {
// somewhat arbitrary.
tokio::time::sleep(Duration::from_millis(100)).await;
}
_ => Err(e)?,
}
tracing::info!("Redis client is disconnected. Reconnecting...");
self.try_connect().await?;
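// After acting on the retry hint, the query is attempted exactly once more; a second
// failure propagates to the caller.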
Ok(q.query(&mut self.client).await?)
}
}

View File

@@ -7,7 +7,7 @@ use tokio::io::{AsyncRead, AsyncWrite};
use super::{Mechanism, Step};
use crate::context::RequestContext;
use crate::pqproto::{BeAuthenticationSaslMessage, BeMessage};
use crate::stream::PqStream;
use crate::stream::PqFeStream;
/// SASL authentication outcome.
/// It's much easier to match on those two variants
@@ -22,7 +22,7 @@ pub(crate) enum Outcome<R> {
pub async fn authenticate<S, F, M>(
ctx: &RequestContext,
stream: &mut PqStream<S>,
stream: &mut PqFeStream<S>,
mechanism: F,
) -> super::Result<Outcome<M::Output>>
where

View File

@@ -167,7 +167,7 @@ pub(crate) async fn serve_websocket(
Ok(Some(p)) => {
ctx.set_success();
ctx.log_connect();
match p.proxy_pass(&config.connect_to_compute).await {
match p.proxy_pass().await {
Ok(()) => Ok(()),
Err(ErrorSource::Client(err)) => Err(err).context("client"),
Err(ErrorSource::Compute(err)) => Err(err).context("compute"),

View File

@@ -1,351 +0,0 @@
use std::pin::Pin;
use std::sync::Arc;
use std::{io, task};
use rustls::ServerConfig;
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf};
use tokio_rustls::server::TlsStream;
use crate::error::{ErrorKind, ReportableError, UserFacingError};
use crate::metrics::Metrics;
use crate::pqproto::{
BeMessage, FE_PASSWORD_MESSAGE, FeStartupPacket, SQLSTATE_INTERNAL_ERROR, WriteBuf,
read_message, read_startup,
};
use crate::tls::TlsServerEndPoint;
/// Stream wrapper which implements libpq's protocol.
///
/// NOTE: This object deliberately doesn't implement [`AsyncRead`]
/// or [`AsyncWrite`] to prevent subtle errors (e.g. trying
/// to pass random malformed bytes through the connection).
pub struct PqStream<S> {
stream: S,
read: Vec<u8>,
write: WriteBuf,
}
impl<S> PqStream<S> {
pub fn get_ref(&self) -> &S {
&self.stream
}
/// Construct a new libpq protocol wrapper over a stream without the first startup message.
#[cfg(test)]
pub fn new_skip_handshake(stream: S) -> Self {
Self {
stream,
read: Vec::new(),
write: WriteBuf::new(),
}
}
}
impl<S: AsyncRead + AsyncWrite + Unpin> PqStream<S> {
/// Construct a new libpq protocol wrapper and read the first startup message.
///
/// This is not cancel safe.
pub async fn parse_startup(mut stream: S) -> io::Result<(Self, FeStartupPacket)> {
let startup = read_startup(&mut stream).await?;
Ok((
Self {
stream,
read: Vec::new(),
write: WriteBuf::new(),
},
startup,
))
}
/// Tell the client that encryption is not supported.
///
/// This is not cancel safe
pub async fn reject_encryption(&mut self) -> io::Result<FeStartupPacket> {
// N for No.
self.write.encryption(b'N');
self.flush().await?;
read_startup(&mut self.stream).await
}
}
impl<S: AsyncRead + Unpin> PqStream<S> {
/// Read a raw postgres packet, which will respect the max length requested.
/// This is not cancel safe.
async fn read_raw_expect(&mut self, tag: u8, max: u32) -> io::Result<&mut [u8]> {
let (actual_tag, msg) = read_message(&mut self.stream, &mut self.read, max).await?;
if actual_tag != tag {
return Err(io::Error::other(format!(
"incorrect message tag, expected {:?}, got {:?}",
tag as char, actual_tag as char,
)));
}
Ok(msg)
}
/// Read a postgres password message, which will respect the max length requested.
/// This is not cancel safe.
pub async fn read_password_message(&mut self) -> io::Result<&mut [u8]> {
// passwords are usually pretty short
// and SASL SCRAM messages are no longer than 256 bytes in my testing
// (a few hashes and random bytes, encoded into base64).
const MAX_PASSWORD_LENGTH: u32 = 512;
self.read_raw_expect(FE_PASSWORD_MESSAGE, MAX_PASSWORD_LENGTH)
.await
}
}
#[derive(Debug)]
pub struct ReportedError {
source: anyhow::Error,
error_kind: ErrorKind,
}
impl ReportedError {
pub fn new(e: (impl UserFacingError + Into<anyhow::Error>)) -> Self {
let error_kind = e.get_error_kind();
Self {
source: e.into(),
error_kind,
}
}
}
impl std::fmt::Display for ReportedError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
self.source.fmt(f)
}
}
impl std::error::Error for ReportedError {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
self.source.source()
}
}
impl ReportableError for ReportedError {
fn get_error_kind(&self) -> ErrorKind {
self.error_kind
}
}
impl<S: AsyncWrite + Unpin> PqStream<S> {
/// Tell the client that we are willing to accept SSL.
/// This is not cancel safe
pub async fn accept_tls(mut self) -> io::Result<S> {
// S for SSL.
self.write.encryption(b'S');
self.flush().await?;
Ok(self.stream)
}
/// Assert that we are using direct TLS.
pub fn accept_direct_tls(self) -> S {
self.stream
}
/// Write a raw message to the internal buffer.
pub fn write_raw(&mut self, size_hint: usize, tag: u8, f: impl FnOnce(&mut Vec<u8>)) {
self.write.write_raw(size_hint, tag, f);
}
/// Write the message into an internal buffer
pub fn write_message(&mut self, message: BeMessage<'_>) {
message.write_message(&mut self.write);
}
/// Flush the output buffer into the underlying stream.
///
/// This is cancel safe.
pub async fn flush(&mut self) -> io::Result<()> {
self.stream.write_all_buf(&mut self.write).await?;
self.write.reset();
self.stream.flush().await?;
Ok(())
}
/// Flush the output buffer into the underlying stream.
///
/// This is cancel safe.
pub async fn flush_and_into_inner(mut self) -> io::Result<S> {
self.flush().await?;
Ok(self.stream)
}
/// Write the error message to the client, then re-throw it.
///
/// Trait [`UserFacingError`] acts as an allowlist for error types.
/// If `ctx` is provided and has testodrome_id set, error messages will be prefixed according to error kind.
pub(crate) async fn throw_error<E>(
&mut self,
error: E,
ctx: Option<&crate::context::RequestContext>,
) -> ReportedError
where
E: UserFacingError + Into<anyhow::Error>,
{
let error_kind = error.get_error_kind();
let msg = error.to_string_client();
if error_kind != ErrorKind::RateLimit && error_kind != ErrorKind::User {
tracing::info!(
kind = error_kind.to_metric_label(),
msg,
"forwarding error to user"
);
}
let probe_msg;
let mut msg = &*msg;
if let Some(ctx) = ctx {
if ctx.get_testodrome_id().is_some() {
let tag = match error_kind {
ErrorKind::User => "client",
ErrorKind::ClientDisconnect => "client",
ErrorKind::RateLimit => "proxy",
ErrorKind::ServiceRateLimit => "proxy",
ErrorKind::Quota => "proxy",
ErrorKind::Service => "proxy",
ErrorKind::ControlPlane => "controlplane",
ErrorKind::Postgres => "other",
ErrorKind::Compute => "compute",
};
probe_msg = typed_json::json!({
"tag": tag,
"msg": msg,
"cold_start_info": ctx.cold_start_info(),
})
.to_string();
msg = &probe_msg;
}
}
// TODO: either preserve the error code from postgres, or assign error codes to proxy errors.
self.write.write_error(msg, SQLSTATE_INTERNAL_ERROR);
self.flush()
.await
.unwrap_or_else(|e| tracing::debug!("write_message failed: {e}"));
ReportedError::new(error)
}
}
/// Wrapper for upgrading raw streams into secure streams.
pub enum Stream<S> {
/// We always begin with a raw stream,
/// which may then be upgraded into a secure stream.
Raw { raw: S },
Tls {
/// We box [`TlsStream`] since it can be quite large.
tls: Box<TlsStream<S>>,
/// Channel binding parameter
tls_server_end_point: TlsServerEndPoint,
},
}
impl<S: Unpin> Unpin for Stream<S> {}
impl<S> Stream<S> {
/// Construct a new instance from a raw stream.
pub fn from_raw(raw: S) -> Self {
Self::Raw { raw }
}
/// Return SNI hostname when it's available.
pub fn sni_hostname(&self) -> Option<&str> {
match self {
Stream::Raw { .. } => None,
Stream::Tls { tls, .. } => tls.get_ref().1.server_name(),
}
}
pub(crate) fn tls_server_end_point(&self) -> TlsServerEndPoint {
match self {
Stream::Raw { .. } => TlsServerEndPoint::Undefined,
Stream::Tls {
tls_server_end_point,
..
} => *tls_server_end_point,
}
}
}
#[derive(Debug, Error)]
#[error("Can't upgrade TLS stream")]
pub enum StreamUpgradeError {
#[error("Bad state reached: can't upgrade TLS stream")]
AlreadyTls,
#[error("Can't upgrade stream: IO error: {0}")]
Io(#[from] io::Error),
}
impl<S: AsyncRead + AsyncWrite + Unpin> Stream<S> {
/// If possible, upgrade raw stream into a secure TLS-based stream.
pub async fn upgrade(
self,
cfg: Arc<ServerConfig>,
record_handshake_error: bool,
) -> Result<TlsStream<S>, StreamUpgradeError> {
match self {
Stream::Raw { raw } => Ok(tokio_rustls::TlsAcceptor::from(cfg)
.accept(raw)
.await
.inspect_err(|_| {
if record_handshake_error {
Metrics::get().proxy.tls_handshake_failures.inc();
}
})?),
Stream::Tls { .. } => Err(StreamUpgradeError::AlreadyTls),
}
}
}
impl<S: AsyncRead + AsyncWrite + Unpin> AsyncRead for Stream<S> {
fn poll_read(
mut self: Pin<&mut Self>,
context: &mut task::Context<'_>,
buf: &mut ReadBuf<'_>,
) -> task::Poll<io::Result<()>> {
match &mut *self {
Self::Raw { raw } => Pin::new(raw).poll_read(context, buf),
Self::Tls { tls, .. } => Pin::new(tls).poll_read(context, buf),
}
}
}
impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite for Stream<S> {
fn poll_write(
mut self: Pin<&mut Self>,
context: &mut task::Context<'_>,
buf: &[u8],
) -> task::Poll<io::Result<usize>> {
match &mut *self {
Self::Raw { raw } => Pin::new(raw).poll_write(context, buf),
Self::Tls { tls, .. } => Pin::new(tls).poll_write(context, buf),
}
}
fn poll_flush(
mut self: Pin<&mut Self>,
context: &mut task::Context<'_>,
) -> task::Poll<io::Result<()>> {
match &mut *self {
Self::Raw { raw } => Pin::new(raw).poll_flush(context),
Self::Tls { tls, .. } => Pin::new(tls).poll_flush(context),
}
}
fn poll_shutdown(
mut self: Pin<&mut Self>,
context: &mut task::Context<'_>,
) -> task::Poll<io::Result<()>> {
match &mut *self {
Self::Raw { raw } => Pin::new(raw).poll_shutdown(context),
Self::Tls { tls, .. } => Pin::new(tls).poll_shutdown(context),
}
}
}

proxy/src/stream/mod.rs Normal file
View File

@@ -0,0 +1,168 @@
mod pq_backend;
mod pq_frontend;
use std::pin::Pin;
use std::sync::Arc;
use std::{io, task};
pub use pq_backend::{BackendError, PostgresError, PqBeStream};
pub use pq_frontend::PqFeStream;
use rustls::ServerConfig;
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio_rustls::server::TlsStream;
use crate::error::{ErrorKind, ReportableError, UserFacingError};
use crate::metrics::Metrics;
use crate::tls::TlsServerEndPoint;
#[derive(Debug)]
pub struct ReportedError {
source: anyhow::Error,
error_kind: ErrorKind,
}
impl ReportedError {
pub fn new(e: (impl UserFacingError + Into<anyhow::Error>)) -> Self {
let error_kind = e.get_error_kind();
Self {
source: e.into(),
error_kind,
}
}
}
impl std::fmt::Display for ReportedError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
self.source.fmt(f)
}
}
impl std::error::Error for ReportedError {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
self.source.source()
}
}
impl ReportableError for ReportedError {
fn get_error_kind(&self) -> ErrorKind {
self.error_kind
}
}
/// Wrapper for upgrading raw streams into secure streams.
pub enum Stream<S> {
/// We always begin with a raw stream,
/// which may then be upgraded into a secure stream.
Raw { raw: S },
Tls {
/// We box [`TlsStream`] since it can be quite large.
tls: Box<TlsStream<S>>,
/// Channel binding parameter
tls_server_end_point: TlsServerEndPoint,
},
}
impl<S: Unpin> Unpin for Stream<S> {}
impl<S> Stream<S> {
/// Construct a new instance from a raw stream.
pub fn from_raw(raw: S) -> Self {
Self::Raw { raw }
}
/// Return SNI hostname when it's available.
pub fn sni_hostname(&self) -> Option<&str> {
match self {
Stream::Raw { .. } => None,
Stream::Tls { tls, .. } => tls.get_ref().1.server_name(),
}
}
pub(crate) fn tls_server_end_point(&self) -> TlsServerEndPoint {
match self {
Stream::Raw { .. } => TlsServerEndPoint::Undefined,
Stream::Tls {
tls_server_end_point,
..
} => *tls_server_end_point,
}
}
}
#[derive(Debug, Error)]
#[error("Can't upgrade TLS stream")]
pub enum StreamUpgradeError {
#[error("Bad state reached: can't upgrade TLS stream")]
AlreadyTls,
#[error("Can't upgrade stream: IO error: {0}")]
Io(#[from] io::Error),
}
impl<S: AsyncRead + AsyncWrite + Unpin> Stream<S> {
/// If possible, upgrade raw stream into a secure TLS-based stream.
pub async fn upgrade(
self,
cfg: Arc<ServerConfig>,
record_handshake_error: bool,
) -> Result<TlsStream<S>, StreamUpgradeError> {
match self {
Stream::Raw { raw } => Ok(tokio_rustls::TlsAcceptor::from(cfg)
.accept(raw)
.await
.inspect_err(|_| {
if record_handshake_error {
Metrics::get().proxy.tls_handshake_failures.inc();
}
})?),
Stream::Tls { .. } => Err(StreamUpgradeError::AlreadyTls),
}
}
}
impl<S: AsyncRead + AsyncWrite + Unpin> AsyncRead for Stream<S> {
fn poll_read(
mut self: Pin<&mut Self>,
context: &mut task::Context<'_>,
buf: &mut ReadBuf<'_>,
) -> task::Poll<io::Result<()>> {
match &mut *self {
Self::Raw { raw } => Pin::new(raw).poll_read(context, buf),
Self::Tls { tls, .. } => Pin::new(tls).poll_read(context, buf),
}
}
}
impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite for Stream<S> {
fn poll_write(
mut self: Pin<&mut Self>,
context: &mut task::Context<'_>,
buf: &[u8],
) -> task::Poll<io::Result<usize>> {
match &mut *self {
Self::Raw { raw } => Pin::new(raw).poll_write(context, buf),
Self::Tls { tls, .. } => Pin::new(tls).poll_write(context, buf),
}
}
fn poll_flush(
mut self: Pin<&mut Self>,
context: &mut task::Context<'_>,
) -> task::Poll<io::Result<()>> {
match &mut *self {
Self::Raw { raw } => Pin::new(raw).poll_flush(context),
Self::Tls { tls, .. } => Pin::new(tls).poll_flush(context),
}
}
fn poll_shutdown(
mut self: Pin<&mut Self>,
context: &mut task::Context<'_>,
) -> task::Poll<io::Result<()>> {
match &mut *self {
Self::Raw { raw } => Pin::new(raw).poll_shutdown(context),
Self::Tls { tls, .. } => Pin::new(tls).poll_shutdown(context),
}
}
}

View File

@@ -0,0 +1,165 @@
//! Postgres connection from backend, proxy is the frontend.
use std::io;
use bytes::Bytes;
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use crate::pqproto::{
AuthTag, BE_AUTH_MESSAGE, BE_ERR_MESSAGE, BeTag, ErrorCode, SQLSTATE_INTERNAL_ERROR,
StartupMessageParams, WriteBuf, read_message,
};
/// Stream wrapper which implements libpq's protocol.
pub struct PqBeStream<S> {
stream: S,
read: Vec<u8>,
write: WriteBuf,
}
impl<S> PqBeStream<S> {
pub fn get_ref(&self) -> &S {
&self.stream
}
/// Construct a new libpq protocol wrapper and write the first startup message.
pub fn new(stream: S, params: &StartupMessageParams) -> Self {
let mut write = WriteBuf::new();
write.startup(params);
Self {
stream,
read: Vec::new(),
write,
}
}
}
impl<S: AsyncRead + Unpin> PqBeStream<S> {
/// Read a raw postgres packet from the backend, respecting the requested maximum length
/// and converting Postgres ErrorResponse messages into typed `BackendError`s.
///
/// This is not cancel safe.
pub async fn read_raw_be(&mut self, max: u32) -> Result<(BeTag, &mut [u8]), PostgresError> {
let (tag, msg) = read_message(&mut self.stream, &mut self.read, max).await?;
match BeTag(tag) {
BE_ERR_MESSAGE => Err(PostgresError::Error(BackendError {
data: msg.to_vec().into(),
})),
tag => Ok((tag, msg)),
}
}
/// Read a raw postgres packet, which will respect the max length requested.
/// This is not cancel safe.
async fn read_raw_be_expect(
&mut self,
tag: BeTag,
max: u32,
) -> Result<&mut [u8], PostgresError> {
let (actual_tag, msg) = self.read_raw_be(max).await?;
if actual_tag != tag {
return Err(PostgresError::Unexpected(UnexpectedMessage {
expected: tag,
tag: actual_tag,
data: msg.to_vec().into(),
}));
}
Ok(msg)
}
/// Read a postgres backend auth message.
/// This is not cancel safe.
pub async fn read_auth_message(&mut self) -> Result<(AuthTag, &mut [u8]), PostgresError> {
const MAX_AUTH_LENGTH: u32 = 512;
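// The body of an 'R' message starts with a 4-byte big-endian auth code (see AuthTag),
// followed by method-specific data.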
self.read_raw_be_expect(BE_AUTH_MESSAGE, MAX_AUTH_LENGTH)
.await?
.split_first_chunk_mut()
.map(|(tag, msg)| (AuthTag(i32::from_be_bytes(*tag)), msg))
.ok_or(PostgresError::InvalidAuthMessage)
}
}
impl<S: AsyncWrite + Unpin> PqBeStream<S> {
/// Write a raw message to the internal buffer.
pub fn write_raw(&mut self, size_hint: usize, tag: u8, f: impl FnOnce(&mut Vec<u8>)) {
self.write.write_raw(size_hint, tag, f);
}
/// Flush the output buffer into the underlying stream.
///
/// This is cancel safe.
pub async fn flush(&mut self) -> io::Result<()> {
self.stream.write_all_buf(&mut self.write).await?;
self.write.reset();
self.stream.flush().await?;
Ok(())
}
/// Flush the output buffer into the underlying stream.
///
/// This is cancel safe.
pub async fn flush_and_into_inner(mut self) -> io::Result<S> {
self.flush().await?;
Ok(self.stream)
}
}
#[derive(Debug, Error)]
pub enum PostgresError {
#[error("postgres responded with error {0}")]
Error(#[from] BackendError),
#[error("postgres responded with an unexpected message: {0}")]
Unexpected(#[from] UnexpectedMessage),
#[error("postgres responded with an invalid authentication message")]
InvalidAuthMessage,
#[error("IO error from compute: {0}")]
Io(#[from] io::Error),
}
#[derive(Debug, Error)]
#[error("expected {expected}, got {tag} with data {data:?}")]
pub struct UnexpectedMessage {
expected: BeTag,
tag: BeTag,
data: Bytes,
}
pub struct BackendError {
data: Bytes,
}
impl BackendError {
pub fn parse(&self) -> (ErrorCode, &[u8]) {
let mut code = &[] as &[u8];
let mut message = &[] as &[u8];
for param in self.data.split(|b| *b == 0) {
match param {
[b'M', rest @ ..] => message = rest,
[b'C', rest @ ..] => code = rest,
_ => {}
}
}
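// A missing or malformed code field (anything other than 5 bytes) falls back to the generic internal-error SQLSTATE.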
let code = code.try_into().map_or(SQLSTATE_INTERNAL_ERROR, ErrorCode);
(code, message)
}
}
impl std::fmt::Debug for BackendError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{self}")
}
}
impl std::fmt::Display for BackendError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{:?}", &self.data)
}
}
impl std::error::Error for BackendError {}
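// Illustrative sketch, not part of the diff: `BackendError::parse` walks the NUL-separated
// fields of a Postgres ErrorResponse body, keeping only the 'C' (SQLSTATE code) and
// 'M' (human-readable message) fields and ignoring the rest.
fn parse_error_fields(data: &[u8]) -> (&[u8], &[u8]) {
    let (mut code, mut message) = (&[][..], &[][..]);
    for field in data.split(|b| *b == 0) {
        match field {
            [b'C', rest @ ..] => code = rest,
            [b'M', rest @ ..] => message = rest,
            _ => {}
        }
    }
    (code, message)
}
// e.g. parse_error_fields(b"SERROR\0C53300\0Mtoo many connections\0")
// yields (b"53300", b"too many connections").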

View File

@@ -0,0 +1,197 @@
//! Postgres connection from frontend, proxy is the backend.
use std::io;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use crate::error::{ErrorKind, UserFacingError};
use crate::pqproto::{
BeMessage, BeTag, FE_PASSWORD_MESSAGE, FeStartupPacket, FeTag, SQLSTATE_INTERNAL_ERROR,
WriteBuf, read_message, read_startup,
};
use crate::stream::ReportedError;
/// Stream wrapper which implements libpq's protocol.
pub struct PqFeStream<S> {
stream: S,
read: Vec<u8>,
write: WriteBuf,
}
impl<S> PqFeStream<S> {
pub fn get_ref(&self) -> &S {
&self.stream
}
/// Construct a new libpq protocol wrapper over a stream without the first startup message.
#[cfg(test)]
pub fn new_skip_handshake(stream: S) -> Self {
Self {
stream,
read: Vec::new(),
write: WriteBuf::new(),
}
}
pub fn write_buf_len(&self) -> usize {
self.write.len()
}
}
impl<S: AsyncRead + AsyncWrite + Unpin> PqFeStream<S> {
/// Construct a new libpq protocol wrapper and read the first startup message.
///
/// This is not cancel safe.
pub async fn parse_startup(mut stream: S) -> io::Result<(Self, FeStartupPacket)> {
let startup = read_startup(&mut stream).await?;
Ok((
Self {
stream,
read: Vec::new(),
write: WriteBuf::new(),
},
startup,
))
}
/// Tell the client that encryption is not supported.
///
/// This is not cancel safe.
pub async fn reject_encryption(&mut self) -> io::Result<FeStartupPacket> {
// N for No.
self.write.encryption(b'N');
self.flush().await?;
read_startup(&mut self.stream).await
}
}
impl<S: AsyncRead + Unpin> PqFeStream<S> {
/// Read a raw postgres packet, which will respect the max length requested.
/// This is not cancel safe.
async fn read_raw_expect(&mut self, tag: FeTag, max: u32) -> io::Result<&mut [u8]> {
let (actual_tag, msg) = read_message(&mut self.stream, &mut self.read, max).await?;
let actual_tag = FeTag(actual_tag);
if actual_tag != tag {
return Err(io::Error::other(format!(
"incorrect message tag, expected {tag}, got {actual_tag}",
)));
}
Ok(msg)
}
/// Read a postgres password message, which will respect the max length requested.
/// This is not cancel safe.
pub async fn read_password_message(&mut self) -> io::Result<&mut [u8]> {
// passwords are usually pretty short
// and SASL SCRAM messages are no longer than 256 bytes in my testing
// (a few hashes and random bytes, encoded into base64).
const MAX_PASSWORD_LENGTH: u32 = 512;
self.read_raw_expect(FE_PASSWORD_MESSAGE, MAX_PASSWORD_LENGTH)
.await
}
}
impl<S: AsyncWrite + Unpin> PqFeStream<S> {
/// Tell the client that we are willing to accept SSL.
/// This is not cancel safe.
pub async fn accept_tls(mut self) -> io::Result<S> {
// S for SSL.
self.write.encryption(b'S');
self.flush().await?;
Ok(self.stream)
}
/// Assert that we are using direct TLS.
pub fn accept_direct_tls(self) -> S {
self.stream
}
/// Write a raw message to the internal buffer.
pub fn write_raw(&mut self, size_hint: usize, tag: BeTag, f: impl FnOnce(&mut Vec<u8>)) {
self.write.write_raw(size_hint, tag.0, f);
}
/// Write the message into an internal buffer
pub fn write_message(&mut self, message: BeMessage<'_>) {
message.write_message(&mut self.write);
}
/// Flush the output buffer into the underlying stream.
///
/// This is cancel safe.
pub async fn flush(&mut self) -> io::Result<()> {
self.stream.write_all_buf(&mut self.write).await?;
self.write.reset();
self.stream.flush().await?;
Ok(())
}
/// Flush the output buffer into the underlying stream.
///
/// This is cancel safe.
pub async fn flush_and_into_inner(mut self) -> io::Result<S> {
self.flush().await?;
Ok(self.stream)
}
/// Write the error message to the client, then re-throw it.
///
/// Trait [`UserFacingError`] acts as an allowlist for error types.
/// If `ctx` is provided and has testodrome_id set, error messages will be prefixed according to error kind.
pub(crate) async fn throw_error<E>(
&mut self,
error: E,
ctx: Option<&crate::context::RequestContext>,
) -> ReportedError
where
E: UserFacingError + Into<anyhow::Error>,
{
let error_kind = error.get_error_kind();
let msg = error.to_string_client();
if error_kind != ErrorKind::RateLimit && error_kind != ErrorKind::User {
tracing::info!(
kind = error_kind.to_metric_label(),
%error,
msg,
"forwarding error to user"
);
}
let probe_msg;
let mut msg = &*msg;
if let Some(ctx) = ctx {
if ctx.get_testodrome_id().is_some() {
let tag = match error_kind {
ErrorKind::User => "client",
ErrorKind::ClientDisconnect => "client",
ErrorKind::RateLimit => "proxy",
ErrorKind::ServiceRateLimit => "proxy",
ErrorKind::Quota => "proxy",
ErrorKind::Service => "proxy",
ErrorKind::ControlPlane => "controlplane",
ErrorKind::Postgres => "other",
ErrorKind::Compute => "compute",
};
probe_msg = typed_json::json!({
"tag": tag,
"msg": msg,
"cold_start_info": ctx.cold_start_info(),
})
.to_string();
msg = &probe_msg;
}
}
// TODO: either preserve the error code from postgres, or assign error codes to proxy errors.
self.write.write_error(msg, SQLSTATE_INTERNAL_ERROR);
self.flush()
.await
.unwrap_or_else(|e| tracing::debug!("write_message failed: {e}"));
ReportedError::new(error)
}
}