mirror of
https://github.com/neondatabase/neon.git
synced 2026-05-28 02:20:42 +00:00
Compare commits
15 Commits
conrad/min
...
conrad/rem
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d8ddf5c850 | ||
|
|
7c469b30aa | ||
|
|
a78a52acb5 | ||
|
|
3370e8cb00 | ||
|
|
f37a558280 | ||
|
|
744011437a | ||
|
|
a10d26a083 | ||
|
|
aece520365 | ||
|
|
9017811d61 | ||
|
|
551a33aa04 | ||
|
|
95216ae6ec | ||
|
|
a3a10d1839 | ||
|
|
1b935b1958 | ||
|
|
3f16ca2c18 | ||
|
|
67b94c5992 |
16
Cargo.lock
generated
16
Cargo.lock
generated
@@ -753,6 +753,7 @@ dependencies = [
|
||||
"axum",
|
||||
"axum-core",
|
||||
"bytes",
|
||||
"form_urlencoded",
|
||||
"futures-util",
|
||||
"headers",
|
||||
"http 1.1.0",
|
||||
@@ -761,6 +762,8 @@ dependencies = [
|
||||
"mime",
|
||||
"pin-project-lite",
|
||||
"serde",
|
||||
"serde_html_form",
|
||||
"serde_path_to_error",
|
||||
"tower 0.5.2",
|
||||
"tower-layer",
|
||||
"tower-service",
|
||||
@@ -6422,6 +6425,19 @@ dependencies = [
|
||||
"syn 2.0.100",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_html_form"
|
||||
version = "0.2.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9d2de91cf02bbc07cde38891769ccd5d4f073d22a40683aa4bc7a95781aaa2c4"
|
||||
dependencies = [
|
||||
"form_urlencoded",
|
||||
"indexmap 2.9.0",
|
||||
"itoa",
|
||||
"ryu",
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_json"
|
||||
version = "1.0.125"
|
||||
|
||||
@@ -71,7 +71,7 @@ aws-credential-types = "1.2.0"
|
||||
aws-sigv4 = { version = "1.2", features = ["sign-http"] }
|
||||
aws-types = "1.3"
|
||||
axum = { version = "0.8.1", features = ["ws"] }
|
||||
axum-extra = { version = "0.10.0", features = ["typed-header"] }
|
||||
axum-extra = { version = "0.10.0", features = ["typed-header", "query"] }
|
||||
base64 = "0.13.0"
|
||||
bincode = "1.3"
|
||||
bindgen = "0.71"
|
||||
|
||||
@@ -785,7 +785,7 @@ impl ComputeNode {
|
||||
self.spawn_extension_stats_task();
|
||||
|
||||
if pspec.spec.autoprewarm {
|
||||
self.prewarm_lfc();
|
||||
self.prewarm_lfc(None);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -25,11 +25,16 @@ struct EndpointStoragePair {
|
||||
}
|
||||
|
||||
const KEY: &str = "lfc_state";
|
||||
impl TryFrom<&crate::compute::ParsedSpec> for EndpointStoragePair {
|
||||
type Error = anyhow::Error;
|
||||
fn try_from(pspec: &crate::compute::ParsedSpec) -> Result<Self, Self::Error> {
|
||||
let Some(ref endpoint_id) = pspec.spec.endpoint_id else {
|
||||
bail!("pspec.endpoint_id missing")
|
||||
impl EndpointStoragePair {
|
||||
/// endpoint_id is set to None while prewarming from other endpoint, see replica promotion
|
||||
/// If not None, takes precedence over pspec.spec.endpoint_id
|
||||
fn from_spec_and_endpoint(
|
||||
pspec: &crate::compute::ParsedSpec,
|
||||
endpoint_id: Option<String>,
|
||||
) -> Result<Self> {
|
||||
let endpoint_id = endpoint_id.as_ref().or(pspec.spec.endpoint_id.as_ref());
|
||||
let Some(ref endpoint_id) = endpoint_id else {
|
||||
bail!("pspec.endpoint_id missing, other endpoint_id not provided")
|
||||
};
|
||||
let Some(ref base_uri) = pspec.endpoint_storage_addr else {
|
||||
bail!("pspec.endpoint_storage_addr missing")
|
||||
@@ -84,7 +89,7 @@ impl ComputeNode {
|
||||
}
|
||||
|
||||
/// Returns false if there is a prewarm request ongoing, true otherwise
|
||||
pub fn prewarm_lfc(self: &Arc<Self>) -> bool {
|
||||
pub fn prewarm_lfc(self: &Arc<Self>, from_endpoint: Option<String>) -> bool {
|
||||
crate::metrics::LFC_PREWARM_REQUESTS.inc();
|
||||
{
|
||||
let state = &mut self.state.lock().unwrap().lfc_prewarm_state;
|
||||
@@ -97,7 +102,7 @@ impl ComputeNode {
|
||||
|
||||
let cloned = self.clone();
|
||||
spawn(async move {
|
||||
let Err(err) = cloned.prewarm_impl().await else {
|
||||
let Err(err) = cloned.prewarm_impl(from_endpoint).await else {
|
||||
cloned.state.lock().unwrap().lfc_prewarm_state = LfcPrewarmState::Completed;
|
||||
return;
|
||||
};
|
||||
@@ -109,13 +114,14 @@ impl ComputeNode {
|
||||
true
|
||||
}
|
||||
|
||||
fn endpoint_storage_pair(&self) -> Result<EndpointStoragePair> {
|
||||
/// from_endpoint: None for endpoint managed by this compute_ctl
|
||||
fn endpoint_storage_pair(&self, from_endpoint: Option<String>) -> Result<EndpointStoragePair> {
|
||||
let state = self.state.lock().unwrap();
|
||||
state.pspec.as_ref().unwrap().try_into()
|
||||
EndpointStoragePair::from_spec_and_endpoint(state.pspec.as_ref().unwrap(), from_endpoint)
|
||||
}
|
||||
|
||||
async fn prewarm_impl(&self) -> Result<()> {
|
||||
let EndpointStoragePair { url, token } = self.endpoint_storage_pair()?;
|
||||
async fn prewarm_impl(&self, from_endpoint: Option<String>) -> Result<()> {
|
||||
let EndpointStoragePair { url, token } = self.endpoint_storage_pair(from_endpoint)?;
|
||||
info!(%url, "requesting LFC state from endpoint storage");
|
||||
|
||||
let request = Client::new().get(&url).bearer_auth(token);
|
||||
@@ -173,7 +179,7 @@ impl ComputeNode {
|
||||
}
|
||||
|
||||
async fn offload_lfc_impl(&self) -> Result<()> {
|
||||
let EndpointStoragePair { url, token } = self.endpoint_storage_pair()?;
|
||||
let EndpointStoragePair { url, token } = self.endpoint_storage_pair(None)?;
|
||||
info!(%url, "requesting LFC state from postgres");
|
||||
|
||||
let mut compressed = Vec::new();
|
||||
|
||||
@@ -2,6 +2,7 @@ use crate::compute_prewarm::LfcPrewarmStateWithProgress;
|
||||
use crate::http::JsonResponse;
|
||||
use axum::response::{IntoResponse, Response};
|
||||
use axum::{Json, http::StatusCode};
|
||||
use axum_extra::extract::OptionalQuery;
|
||||
use compute_api::responses::LfcOffloadState;
|
||||
type Compute = axum::extract::State<std::sync::Arc<crate::compute::ComputeNode>>;
|
||||
|
||||
@@ -16,8 +17,16 @@ pub(in crate::http) async fn offload_state(compute: Compute) -> Json<LfcOffloadS
|
||||
Json(compute.lfc_offload_state())
|
||||
}
|
||||
|
||||
pub(in crate::http) async fn prewarm(compute: Compute) -> Response {
|
||||
if compute.prewarm_lfc() {
|
||||
#[derive(serde::Deserialize)]
|
||||
pub struct PrewarmQuery {
|
||||
pub from_endpoint: String,
|
||||
}
|
||||
|
||||
pub(in crate::http) async fn prewarm(
|
||||
compute: Compute,
|
||||
OptionalQuery(query): OptionalQuery<PrewarmQuery>,
|
||||
) -> Response {
|
||||
if compute.prewarm_lfc(query.map(|q| q.from_endpoint)) {
|
||||
StatusCode::ACCEPTED.into_response()
|
||||
} else {
|
||||
JsonResponse::error(
|
||||
|
||||
@@ -1,15 +1,12 @@
|
||||
use std::io;
|
||||
|
||||
use tokio::net::TcpStream;
|
||||
|
||||
use crate::client::SocketConfig;
|
||||
use crate::config::{Host, SslMode};
|
||||
use crate::config::Host;
|
||||
use crate::tls::MakeTlsConnect;
|
||||
use crate::{Error, cancel_query_raw, connect_socket};
|
||||
use crate::{Error, cancel_query_raw, connect_socket, connect_tls};
|
||||
|
||||
pub(crate) async fn cancel_query<T>(
|
||||
config: Option<SocketConfig>,
|
||||
ssl_mode: SslMode,
|
||||
config: SocketConfig,
|
||||
tls: T,
|
||||
process_id: i32,
|
||||
secret_key: i32,
|
||||
@@ -17,16 +14,6 @@ pub(crate) async fn cancel_query<T>(
|
||||
where
|
||||
T: MakeTlsConnect<TcpStream>,
|
||||
{
|
||||
let config = match config {
|
||||
Some(config) => config,
|
||||
None => {
|
||||
return Err(Error::connect(io::Error::new(
|
||||
io::ErrorKind::InvalidInput,
|
||||
"unknown host",
|
||||
)));
|
||||
}
|
||||
};
|
||||
|
||||
let hostname = match &config.host {
|
||||
Host::Tcp(host) => &**host,
|
||||
};
|
||||
@@ -42,5 +29,6 @@ where
|
||||
)
|
||||
.await?;
|
||||
|
||||
cancel_query_raw::cancel_query_raw(socket, ssl_mode, tls, process_id, secret_key).await
|
||||
let stream = connect_tls::connect_tls(socket, config.ssl_mode, tls).await?;
|
||||
cancel_query_raw::cancel_query_raw(stream, process_id, secret_key).await
|
||||
}
|
||||
|
||||
@@ -2,23 +2,16 @@ use bytes::BytesMut;
|
||||
use postgres_protocol2::message::frontend;
|
||||
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
|
||||
|
||||
use crate::config::SslMode;
|
||||
use crate::tls::TlsConnect;
|
||||
use crate::{Error, connect_tls};
|
||||
use crate::Error;
|
||||
|
||||
pub async fn cancel_query_raw<S, T>(
|
||||
stream: S,
|
||||
mode: SslMode,
|
||||
tls: T,
|
||||
pub async fn cancel_query_raw<S>(
|
||||
mut stream: S,
|
||||
process_id: i32,
|
||||
secret_key: i32,
|
||||
) -> Result<(), Error>
|
||||
where
|
||||
S: AsyncRead + AsyncWrite + Unpin,
|
||||
T: TlsConnect<S>,
|
||||
{
|
||||
let mut stream = connect_tls::connect_tls(stream, mode, tls).await?;
|
||||
|
||||
let mut buf = BytesMut::new();
|
||||
frontend::cancel_request(process_id, secret_key, &mut buf);
|
||||
|
||||
|
||||
@@ -3,16 +3,21 @@ use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tokio::net::TcpStream;
|
||||
|
||||
use crate::client::SocketConfig;
|
||||
use crate::config::SslMode;
|
||||
use crate::tls::{MakeTlsConnect, TlsConnect};
|
||||
use crate::tls::MakeTlsConnect;
|
||||
use crate::{Error, cancel_query, cancel_query_raw};
|
||||
|
||||
/// The capability to request cancellation of in-progress queries on a
|
||||
/// connection.
|
||||
#[derive(Clone, Serialize, Deserialize)]
|
||||
#[derive(Clone)]
|
||||
pub struct CancelToken {
|
||||
pub socket_config: Option<SocketConfig>,
|
||||
pub ssl_mode: SslMode,
|
||||
pub socket_config: SocketConfig,
|
||||
pub raw: RawCancelToken,
|
||||
}
|
||||
|
||||
/// The capability to request cancellation of in-progress queries on a
|
||||
/// connection.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct RawCancelToken {
|
||||
pub process_id: i32,
|
||||
pub secret_key: i32,
|
||||
}
|
||||
@@ -36,28 +41,21 @@ impl CancelToken {
|
||||
{
|
||||
cancel_query::cancel_query(
|
||||
self.socket_config.clone(),
|
||||
self.ssl_mode,
|
||||
tls,
|
||||
self.process_id,
|
||||
self.secret_key,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
/// Like `cancel_query`, but uses a stream which is already connected to the server rather than opening a new
|
||||
/// connection itself.
|
||||
pub async fn cancel_query_raw<S, T>(&self, stream: S, tls: T) -> Result<(), Error>
|
||||
where
|
||||
S: AsyncRead + AsyncWrite + Unpin,
|
||||
T: TlsConnect<S>,
|
||||
{
|
||||
cancel_query_raw::cancel_query_raw(
|
||||
stream,
|
||||
self.ssl_mode,
|
||||
tls,
|
||||
self.process_id,
|
||||
self.secret_key,
|
||||
self.raw.process_id,
|
||||
self.raw.secret_key,
|
||||
)
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
impl RawCancelToken {
|
||||
/// Like `cancel_query`, but uses a stream which is already connected to the server rather than opening a new
|
||||
/// connection itself.
|
||||
pub async fn cancel_query_raw<S>(&self, stream: S) -> Result<(), Error>
|
||||
where
|
||||
S: AsyncRead + AsyncWrite + Unpin,
|
||||
{
|
||||
cancel_query_raw::cancel_query_raw(stream, self.process_id, self.secret_key).await
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,6 +12,7 @@ use postgres_protocol2::message::frontend;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
use crate::cancel_token::RawCancelToken;
|
||||
use crate::codec::{BackendMessages, FrontendMessage};
|
||||
use crate::config::{Host, SslMode};
|
||||
use crate::query::RowStream;
|
||||
@@ -166,6 +167,7 @@ pub struct SocketConfig {
|
||||
pub host: Host,
|
||||
pub port: u16,
|
||||
pub connect_timeout: Option<Duration>,
|
||||
pub ssl_mode: SslMode,
|
||||
}
|
||||
|
||||
/// An asynchronous PostgreSQL client.
|
||||
@@ -177,7 +179,6 @@ pub struct Client {
|
||||
cached_typeinfo: CachedTypeInfo,
|
||||
|
||||
socket_config: SocketConfig,
|
||||
ssl_mode: SslMode,
|
||||
process_id: i32,
|
||||
secret_key: i32,
|
||||
}
|
||||
@@ -187,7 +188,6 @@ impl Client {
|
||||
sender: mpsc::UnboundedSender<FrontendMessage>,
|
||||
receiver: mpsc::Receiver<BackendMessages>,
|
||||
socket_config: SocketConfig,
|
||||
ssl_mode: SslMode,
|
||||
process_id: i32,
|
||||
secret_key: i32,
|
||||
) -> Client {
|
||||
@@ -205,7 +205,6 @@ impl Client {
|
||||
cached_typeinfo: Default::default(),
|
||||
|
||||
socket_config,
|
||||
ssl_mode,
|
||||
process_id,
|
||||
secret_key,
|
||||
}
|
||||
@@ -331,10 +330,11 @@ impl Client {
|
||||
/// connection associated with this client.
|
||||
pub fn cancel_token(&self) -> CancelToken {
|
||||
CancelToken {
|
||||
socket_config: Some(self.socket_config.clone()),
|
||||
ssl_mode: self.ssl_mode,
|
||||
process_id: self.process_id,
|
||||
secret_key: self.secret_key,
|
||||
socket_config: self.socket_config.clone(),
|
||||
raw: RawCancelToken {
|
||||
process_id: self.process_id,
|
||||
secret_key: self.secret_key,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -57,6 +57,7 @@ where
|
||||
host: host.clone(),
|
||||
port,
|
||||
connect_timeout: config.connect_timeout,
|
||||
ssl_mode: config.ssl_mode,
|
||||
};
|
||||
|
||||
let (client_tx, conn_rx) = mpsc::unbounded_channel();
|
||||
@@ -65,7 +66,6 @@ where
|
||||
client_tx,
|
||||
client_rx,
|
||||
socket_config,
|
||||
config.ssl_mode,
|
||||
process_id,
|
||||
secret_key,
|
||||
);
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
|
||||
use postgres_protocol2::message::backend::ReadyForQueryBody;
|
||||
|
||||
pub use crate::cancel_token::CancelToken;
|
||||
pub use crate::cancel_token::{CancelToken, RawCancelToken};
|
||||
pub use crate::client::{Client, SocketConfig};
|
||||
pub use crate::config::Config;
|
||||
pub use crate::connect_raw::RawConnection;
|
||||
|
||||
@@ -14,12 +14,13 @@ use serde::{Deserialize, Serialize};
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tracing::{debug, info};
|
||||
|
||||
use crate::auth::{self, AuthError, ComputeUserInfoMaybeEndpoint, validate_password_and_exchange};
|
||||
use crate::auth::{self, ComputeUserInfoMaybeEndpoint, validate_password_and_exchange};
|
||||
use crate::cache::Cached;
|
||||
use crate::config::AuthenticationConfig;
|
||||
use crate::context::RequestContext;
|
||||
use crate::control_plane::client::ControlPlaneClient;
|
||||
use crate::control_plane::errors::GetAuthInfoError;
|
||||
use crate::control_plane::messages::EndpointRateLimitConfig;
|
||||
use crate::control_plane::{
|
||||
self, AccessBlockerFlags, AuthSecret, CachedNodeInfo, ControlPlaneApi, EndpointAccessControl,
|
||||
RoleAccessControl,
|
||||
@@ -230,11 +231,8 @@ async fn auth_quirks(
|
||||
config.is_vpc_acccess_proxy,
|
||||
)?;
|
||||
|
||||
let endpoint = EndpointIdInt::from(&info.endpoint);
|
||||
let rate_limit_config = None;
|
||||
if !endpoint_rate_limiter.check(endpoint, rate_limit_config, 1) {
|
||||
return Err(AuthError::too_many_connections());
|
||||
}
|
||||
access_controls.connection_attempt_rate_limit(ctx, &info.endpoint, &endpoint_rate_limiter)?;
|
||||
|
||||
let role_access = api
|
||||
.get_role_access_control(ctx, &info.endpoint, &info.user)
|
||||
.await?;
|
||||
@@ -401,6 +399,7 @@ impl Backend<'_, ComputeUserInfo> {
|
||||
allowed_ips: Arc::new(vec![]),
|
||||
allowed_vpce: Arc::new(vec![]),
|
||||
flags: AccessBlockerFlags::default(),
|
||||
rate_limits: EndpointRateLimitConfig::default(),
|
||||
}),
|
||||
}
|
||||
}
|
||||
@@ -439,6 +438,7 @@ mod tests {
|
||||
use crate::auth::{ComputeUserInfoMaybeEndpoint, IpPattern};
|
||||
use crate::config::AuthenticationConfig;
|
||||
use crate::context::RequestContext;
|
||||
use crate::control_plane::messages::EndpointRateLimitConfig;
|
||||
use crate::control_plane::{
|
||||
self, AccessBlockerFlags, CachedNodeInfo, EndpointAccessControl, RoleAccessControl,
|
||||
};
|
||||
@@ -477,6 +477,7 @@ mod tests {
|
||||
allowed_ips: Arc::new(self.ips.clone()),
|
||||
allowed_vpce: Arc::new(self.vpc_endpoint_ids.clone()),
|
||||
flags: self.access_blocker_flags,
|
||||
rate_limits: EndpointRateLimitConfig::default(),
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
146
proxy/src/batch.rs
Normal file
146
proxy/src/batch.rs
Normal file
@@ -0,0 +1,146 @@
|
||||
//! Batch processing system based on an ordered, in-memory queue of pending jobs.
|
||||
//!
|
||||
//! Enqueuing a batch job takes one short lock on the shared queue, with
|
||||
//! direct support for cancelling jobs early.
|
||||
use std::collections::BTreeMap;
|
||||
use std::pin::pin;
|
||||
use std::sync::Mutex;
|
||||
|
||||
use futures::future::Either;
|
||||
use scopeguard::ScopeGuard;
|
||||
use tokio::sync::oneshot::error::TryRecvError;
|
||||
|
||||
use crate::ext::LockExt;
|
||||
|
||||
/// A processor that answers queued requests one batch at a time.
///
/// [`BatchQueue`] collects individual requests, hands them to [`QueueProcessing::apply`]
/// in groups, and routes each reply back to the caller that submitted the request.
pub trait QueueProcessing: Send + 'static {
    /// A single request submitted to the queue.
    type Req: Send + 'static;
    /// The reply produced for a single request.
    type Res: Send;

    /// Get the desired batch size.
    ///
    /// `queue_size` is the number of jobs queued at the moment the batch is assembled.
    /// The batch may end up smaller than the returned value if fewer jobs are queued.
    fn batch_size(&self, queue_size: usize) -> usize;

    /// Applies a full batch of requests and must respond with a full batch of replies.
    ///
    /// Replies are matched to requests by position: the i-th reply is delivered to the
    /// caller of the i-th request. A request that receives no reply causes the waiting
    /// [`BatchQueue::call`] to panic, so the returned vector must not be shorter than `req`.
    ///
    /// If this operation can fail, the error is expected to be forwarded into each
    /// `Self::Res` rather than reported separately.
    ///
    /// Batching does not need to happen atomically.
    fn apply(&mut self, req: Vec<Self::Req>) -> impl Future<Output = Vec<Self::Res>> + Send;
}
|
||||
|
||||
pub struct BatchQueue<P: QueueProcessing> {
|
||||
processor: tokio::sync::Mutex<P>,
|
||||
inner: Mutex<BatchQueueInner<P>>,
|
||||
}
|
||||
|
||||
struct BatchJob<P: QueueProcessing> {
|
||||
req: P::Req,
|
||||
res: tokio::sync::oneshot::Sender<P::Res>,
|
||||
}
|
||||
|
||||
impl<P: QueueProcessing> BatchQueue<P> {
|
||||
pub fn new(p: P) -> Self {
|
||||
Self {
|
||||
processor: tokio::sync::Mutex::new(p),
|
||||
inner: Mutex::new(BatchQueueInner {
|
||||
version: 0,
|
||||
queue: BTreeMap::new(),
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn call(&self, req: P::Req) -> P::Res {
|
||||
let (id, mut rx) = self.inner.lock_propagate_poison().register_job(req);
|
||||
let guard = scopeguard::guard(id, move |id| {
|
||||
let mut inner = self.inner.lock_propagate_poison();
|
||||
if inner.queue.remove(&id).is_some() {
|
||||
tracing::debug!("batched task cancelled before completion");
|
||||
}
|
||||
});
|
||||
|
||||
let resp = loop {
|
||||
// try become the leader, or try wait for success.
|
||||
let mut processor = match futures::future::select(rx, pin!(self.processor.lock())).await
|
||||
{
|
||||
// we got the resp.
|
||||
Either::Left((resp, _)) => break resp.ok(),
|
||||
// we are the leader.
|
||||
Either::Right((p, rx_)) => {
|
||||
rx = rx_;
|
||||
p
|
||||
}
|
||||
};
|
||||
|
||||
let (reqs, resps) = self.inner.lock_propagate_poison().get_batch(&processor);
|
||||
|
||||
// apply a batch.
|
||||
let values = processor.apply(reqs).await;
|
||||
|
||||
// send response values.
|
||||
for (tx, value) in std::iter::zip(resps, values) {
|
||||
// sender hung up but that's fine.
|
||||
drop(tx.send(value));
|
||||
}
|
||||
|
||||
match rx.try_recv() {
|
||||
Ok(resp) => break Some(resp),
|
||||
Err(TryRecvError::Closed) => break None,
|
||||
// edge case - there was a race condition where
|
||||
// we became the leader but were not in the batch.
|
||||
//
|
||||
// Example:
|
||||
// thread 1: register job id=1
|
||||
// thread 2: register job id=2
|
||||
// thread 2: processor.lock().await
|
||||
// thread 1: processor.lock().await
|
||||
// thread 2: becomes leader, batch_size=1, jobs=[1].
|
||||
Err(TryRecvError::Empty) => {}
|
||||
}
|
||||
};
|
||||
|
||||
// already removed.
|
||||
ScopeGuard::into_inner(guard);
|
||||
|
||||
resp.expect("no response found. batch processer should not panic")
|
||||
}
|
||||
}
|
||||
|
||||
struct BatchQueueInner<P: QueueProcessing> {
|
||||
version: u64,
|
||||
queue: BTreeMap<u64, BatchJob<P>>,
|
||||
}
|
||||
|
||||
impl<P: QueueProcessing> BatchQueueInner<P> {
|
||||
fn register_job(&mut self, req: P::Req) -> (u64, tokio::sync::oneshot::Receiver<P::Res>) {
|
||||
let (tx, rx) = tokio::sync::oneshot::channel();
|
||||
|
||||
let id = self.version;
|
||||
|
||||
// Overflow concern:
|
||||
// This is a u64, and we might enqueue 2^16 tasks per second.
|
||||
// This gives us 2^48 seconds (9 million years).
|
||||
// Even if this does overflow, it will not break, but some
|
||||
// jobs with the higher version might never get prioritised.
|
||||
self.version += 1;
|
||||
|
||||
self.queue.insert(id, BatchJob { req, res: tx });
|
||||
|
||||
(id, rx)
|
||||
}
|
||||
|
||||
fn get_batch(&mut self, p: &P) -> (Vec<P::Req>, Vec<tokio::sync::oneshot::Sender<P::Res>>) {
|
||||
let batch_size = p.batch_size(self.queue.len());
|
||||
let mut reqs = Vec::with_capacity(batch_size);
|
||||
let mut resps = Vec::with_capacity(batch_size);
|
||||
|
||||
while reqs.len() < batch_size {
|
||||
let Some((_, job)) = self.queue.pop_first() else {
|
||||
break;
|
||||
};
|
||||
reqs.push(job.req);
|
||||
resps.push(job.res);
|
||||
}
|
||||
|
||||
(reqs, resps)
|
||||
}
|
||||
}
|
||||
@@ -201,7 +201,7 @@ pub async fn run() -> anyhow::Result<()> {
|
||||
auth_backend,
|
||||
http_listener,
|
||||
shutdown.clone(),
|
||||
Arc::new(CancellationHandler::new(&config.connect_to_compute, None)),
|
||||
Arc::new(CancellationHandler::new()),
|
||||
endpoint_rate_limiter,
|
||||
);
|
||||
|
||||
|
||||
@@ -21,7 +21,8 @@ use utils::{project_build_tag, project_git_version};
|
||||
|
||||
use crate::auth::backend::jwt::JwkCache;
|
||||
use crate::auth::backend::{ConsoleRedirectBackend, MaybeOwned};
|
||||
use crate::cancellation::{CancellationHandler, handle_cancel_messages};
|
||||
use crate::batch::BatchQueue;
|
||||
use crate::cancellation::{CancellationHandler, CancellationProcessor};
|
||||
use crate::config::{
|
||||
self, AuthenticationConfig, CacheOptions, ComputeConfig, HttpConfig, ProjectInfoCacheOptions,
|
||||
ProxyConfig, ProxyProtocolV2, remote_storage_from_toml,
|
||||
@@ -390,13 +391,7 @@ pub async fn run() -> anyhow::Result<()> {
|
||||
.as_ref()
|
||||
.map(|redis_publisher| RedisKVClient::new(redis_publisher.clone(), redis_rps_limit));
|
||||
|
||||
// channel size should be higher than redis client limit to avoid blocking
|
||||
let cancel_ch_size = args.cancellation_ch_size;
|
||||
let (tx_cancel, rx_cancel) = tokio::sync::mpsc::channel(cancel_ch_size);
|
||||
let cancellation_handler = Arc::new(CancellationHandler::new(
|
||||
&config.connect_to_compute,
|
||||
Some(tx_cancel),
|
||||
));
|
||||
let cancellation_handler = Arc::new(CancellationHandler::new());
|
||||
|
||||
let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new_with_shards(
|
||||
RateBucketInfo::to_leaky_bucket(&args.endpoint_rps_limit)
|
||||
@@ -523,14 +518,10 @@ pub async fn run() -> anyhow::Result<()> {
|
||||
if let Some(mut redis_kv_client) = redis_kv_client {
|
||||
maintenance_tasks.spawn(async move {
|
||||
redis_kv_client.try_connect().await?;
|
||||
handle_cancel_messages(
|
||||
&mut redis_kv_client,
|
||||
rx_cancel,
|
||||
args.cancellation_batch_size,
|
||||
)
|
||||
.await?;
|
||||
|
||||
drop(redis_kv_client);
|
||||
cancellation_handler.init_tx(BatchQueue::new(CancellationProcessor {
|
||||
client: redis_kv_client,
|
||||
batch_size: args.cancellation_batch_size,
|
||||
}));
|
||||
|
||||
// `handle_cancel_messages` was terminated due to the tx_cancel
|
||||
// being dropped. this is not worthy of an error, and this task can only return `Err`,
|
||||
|
||||
4
proxy/src/cache/project_info.rs
vendored
4
proxy/src/cache/project_info.rs
vendored
@@ -364,6 +364,7 @@ mod tests {
|
||||
use std::sync::Arc;
|
||||
|
||||
use super::*;
|
||||
use crate::control_plane::messages::EndpointRateLimitConfig;
|
||||
use crate::control_plane::{AccessBlockerFlags, AuthSecret};
|
||||
use crate::scram::ServerSecret;
|
||||
use crate::types::ProjectId;
|
||||
@@ -399,6 +400,7 @@ mod tests {
|
||||
allowed_ips: allowed_ips.clone(),
|
||||
allowed_vpce: Arc::new(vec![]),
|
||||
flags: AccessBlockerFlags::default(),
|
||||
rate_limits: EndpointRateLimitConfig::default(),
|
||||
},
|
||||
RoleAccessControl {
|
||||
secret: secret1.clone(),
|
||||
@@ -414,6 +416,7 @@ mod tests {
|
||||
allowed_ips: allowed_ips.clone(),
|
||||
allowed_vpce: Arc::new(vec![]),
|
||||
flags: AccessBlockerFlags::default(),
|
||||
rate_limits: EndpointRateLimitConfig::default(),
|
||||
},
|
||||
RoleAccessControl {
|
||||
secret: secret2.clone(),
|
||||
@@ -439,6 +442,7 @@ mod tests {
|
||||
allowed_ips: allowed_ips.clone(),
|
||||
allowed_vpce: Arc::new(vec![]),
|
||||
flags: AccessBlockerFlags::default(),
|
||||
rate_limits: EndpointRateLimitConfig::default(),
|
||||
},
|
||||
RoleAccessControl {
|
||||
secret: secret3.clone(),
|
||||
|
||||
@@ -1,20 +1,22 @@
|
||||
use std::convert::Infallible;
|
||||
use std::net::{IpAddr, SocketAddr};
|
||||
use std::sync::Arc;
|
||||
use std::sync::{Arc, OnceLock};
|
||||
use std::time::Duration;
|
||||
|
||||
use anyhow::{Context, anyhow};
|
||||
use anyhow::anyhow;
|
||||
use futures::FutureExt;
|
||||
use ipnet::{IpNet, Ipv4Net, Ipv6Net};
|
||||
use postgres_client::CancelToken;
|
||||
use postgres_client::tls::MakeTlsConnect;
|
||||
use postgres_client::RawCancelToken;
|
||||
use redis::{Cmd, FromRedisValue, Value};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use thiserror::Error;
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::sync::{mpsc, oneshot};
|
||||
use tracing::{debug, error, info, warn};
|
||||
use tokio::time::timeout;
|
||||
use tracing::{debug, error, info};
|
||||
|
||||
use crate::auth::AuthError;
|
||||
use crate::auth::backend::ComputeUserInfo;
|
||||
use crate::config::ComputeConfig;
|
||||
use crate::batch::{BatchQueue, QueueProcessing};
|
||||
use crate::context::RequestContext;
|
||||
use crate::control_plane::ControlPlaneApi;
|
||||
use crate::error::ReportableError;
|
||||
@@ -27,46 +29,36 @@ use crate::redis::kv_ops::RedisKVClient;
|
||||
|
||||
type IpSubnetKey = IpNet;
|
||||
|
||||
const CANCEL_KEY_TTL: i64 = 1_209_600; // 2 weeks cancellation key expire time
|
||||
const CANCEL_KEY_TTL: std::time::Duration = std::time::Duration::from_secs(600);
|
||||
const CANCEL_KEY_REFRESH: std::time::Duration = std::time::Duration::from_secs(570);
|
||||
|
||||
// Message types for sending through mpsc channel
|
||||
pub enum CancelKeyOp {
|
||||
StoreCancelKey {
|
||||
key: String,
|
||||
field: String,
|
||||
value: String,
|
||||
resp_tx: Option<oneshot::Sender<anyhow::Result<()>>>,
|
||||
_guard: CancelChannelSizeGuard<'static>,
|
||||
expire: i64, // TTL for key
|
||||
key: CancelKeyData,
|
||||
value: Box<str>,
|
||||
expire: std::time::Duration,
|
||||
},
|
||||
GetCancelData {
|
||||
key: String,
|
||||
resp_tx: oneshot::Sender<anyhow::Result<Vec<(String, String)>>>,
|
||||
_guard: CancelChannelSizeGuard<'static>,
|
||||
},
|
||||
RemoveCancelKey {
|
||||
key: String,
|
||||
field: String,
|
||||
resp_tx: Option<oneshot::Sender<anyhow::Result<()>>>,
|
||||
_guard: CancelChannelSizeGuard<'static>,
|
||||
key: CancelKeyData,
|
||||
},
|
||||
}
|
||||
|
||||
pub struct Pipeline {
|
||||
inner: redis::Pipeline,
|
||||
replies: Vec<CancelReplyOp>,
|
||||
replies: usize,
|
||||
}
|
||||
|
||||
impl Pipeline {
|
||||
fn with_capacity(n: usize) -> Self {
|
||||
Self {
|
||||
inner: redis::Pipeline::with_capacity(n),
|
||||
replies: Vec::with_capacity(n),
|
||||
replies: 0,
|
||||
}
|
||||
}
|
||||
|
||||
async fn execute(&mut self, client: &mut RedisKVClient) {
|
||||
let responses = self.replies.len();
|
||||
async fn execute(self, client: &mut RedisKVClient) -> Vec<anyhow::Result<Value>> {
|
||||
let responses = self.replies;
|
||||
let batch_size = self.inner.len();
|
||||
|
||||
match client.query(&self.inner).await {
|
||||
@@ -76,176 +68,73 @@ impl Pipeline {
|
||||
batch_size,
|
||||
responses, "successfully completed cancellation jobs",
|
||||
);
|
||||
for (value, reply) in std::iter::zip(values, self.replies.drain(..)) {
|
||||
reply.send_value(value);
|
||||
}
|
||||
values.into_iter().map(Ok).collect()
|
||||
}
|
||||
Ok(value) => {
|
||||
error!(batch_size, ?value, "unexpected redis return value");
|
||||
for reply in self.replies.drain(..) {
|
||||
reply.send_err(anyhow!("incorrect response type from redis"));
|
||||
}
|
||||
std::iter::repeat_with(|| Err(anyhow!("incorrect response type from redis")))
|
||||
.take(responses)
|
||||
.collect()
|
||||
}
|
||||
Err(err) => {
|
||||
for reply in self.replies.drain(..) {
|
||||
reply.send_err(anyhow!("could not send cmd to redis: {err}"));
|
||||
}
|
||||
std::iter::repeat_with(|| Err(anyhow!("could not send cmd to redis: {err}")))
|
||||
.take(responses)
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
self.inner.clear();
|
||||
self.replies.clear();
|
||||
}
|
||||
|
||||
fn add_command_with_reply(&mut self, cmd: Cmd, reply: CancelReplyOp) {
|
||||
fn add_command_with_reply(&mut self, cmd: Cmd) {
|
||||
self.inner.add_command(cmd);
|
||||
self.replies.push(reply);
|
||||
self.replies += 1;
|
||||
}
|
||||
|
||||
fn add_command_no_reply(&mut self, cmd: Cmd) {
|
||||
self.inner.add_command(cmd).ignore();
|
||||
}
|
||||
|
||||
fn add_command(&mut self, cmd: Cmd, reply: Option<CancelReplyOp>) {
|
||||
match reply {
|
||||
Some(reply) => self.add_command_with_reply(cmd, reply),
|
||||
None => self.add_command_no_reply(cmd),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl CancelKeyOp {
|
||||
fn register(self, pipe: &mut Pipeline) {
|
||||
fn register(&self, pipe: &mut Pipeline) {
|
||||
#[allow(clippy::used_underscore_binding)]
|
||||
match self {
|
||||
CancelKeyOp::StoreCancelKey {
|
||||
key,
|
||||
field,
|
||||
value,
|
||||
resp_tx,
|
||||
_guard,
|
||||
expire,
|
||||
} => {
|
||||
let reply =
|
||||
resp_tx.map(|resp_tx| CancelReplyOp::StoreCancelKey { resp_tx, _guard });
|
||||
pipe.add_command(Cmd::hset(&key, field, value), reply);
|
||||
pipe.add_command_no_reply(Cmd::expire(key, expire));
|
||||
CancelKeyOp::StoreCancelKey { key, value, expire } => {
|
||||
let key = KeyPrefix::Cancel(*key).build_redis_key();
|
||||
pipe.add_command_with_reply(Cmd::hset(&key, "data", &**value));
|
||||
pipe.add_command_no_reply(Cmd::expire(&key, expire.as_secs() as i64));
|
||||
}
|
||||
CancelKeyOp::GetCancelData {
|
||||
key,
|
||||
resp_tx,
|
||||
_guard,
|
||||
} => {
|
||||
let reply = CancelReplyOp::GetCancelData { resp_tx, _guard };
|
||||
pipe.add_command_with_reply(Cmd::hgetall(key), reply);
|
||||
}
|
||||
CancelKeyOp::RemoveCancelKey {
|
||||
key,
|
||||
field,
|
||||
resp_tx,
|
||||
_guard,
|
||||
} => {
|
||||
let reply =
|
||||
resp_tx.map(|resp_tx| CancelReplyOp::RemoveCancelKey { resp_tx, _guard });
|
||||
pipe.add_command(Cmd::hdel(key, field), reply);
|
||||
CancelKeyOp::GetCancelData { key } => {
|
||||
let key = KeyPrefix::Cancel(*key).build_redis_key();
|
||||
pipe.add_command_with_reply(Cmd::hget(key, "data"));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Message types for sending through mpsc channel
|
||||
pub enum CancelReplyOp {
|
||||
StoreCancelKey {
|
||||
resp_tx: oneshot::Sender<anyhow::Result<()>>,
|
||||
_guard: CancelChannelSizeGuard<'static>,
|
||||
},
|
||||
GetCancelData {
|
||||
resp_tx: oneshot::Sender<anyhow::Result<Vec<(String, String)>>>,
|
||||
_guard: CancelChannelSizeGuard<'static>,
|
||||
},
|
||||
RemoveCancelKey {
|
||||
resp_tx: oneshot::Sender<anyhow::Result<()>>,
|
||||
_guard: CancelChannelSizeGuard<'static>,
|
||||
},
|
||||
pub struct CancellationProcessor {
|
||||
pub client: RedisKVClient,
|
||||
pub batch_size: usize,
|
||||
}
|
||||
|
||||
impl CancelReplyOp {
|
||||
fn send_err(self, e: anyhow::Error) {
|
||||
match self {
|
||||
CancelReplyOp::StoreCancelKey { resp_tx, _guard } => {
|
||||
resp_tx
|
||||
.send(Err(e))
|
||||
.inspect_err(|_| tracing::debug!("could not send reply"))
|
||||
.ok();
|
||||
}
|
||||
CancelReplyOp::GetCancelData { resp_tx, _guard } => {
|
||||
resp_tx
|
||||
.send(Err(e))
|
||||
.inspect_err(|_| tracing::debug!("could not send reply"))
|
||||
.ok();
|
||||
}
|
||||
CancelReplyOp::RemoveCancelKey { resp_tx, _guard } => {
|
||||
resp_tx
|
||||
.send(Err(e))
|
||||
.inspect_err(|_| tracing::debug!("could not send reply"))
|
||||
.ok();
|
||||
}
|
||||
}
|
||||
impl QueueProcessing for CancellationProcessor {
|
||||
type Req = (CancelChannelSizeGuard<'static>, CancelKeyOp);
|
||||
type Res = anyhow::Result<redis::Value>;
|
||||
|
||||
fn batch_size(&self, _queue_size: usize) -> usize {
|
||||
self.batch_size
|
||||
}
|
||||
|
||||
fn send_value(self, v: redis::Value) {
|
||||
match self {
|
||||
CancelReplyOp::StoreCancelKey { resp_tx, _guard } => {
|
||||
let send =
|
||||
FromRedisValue::from_owned_redis_value(v).context("could not parse value");
|
||||
resp_tx
|
||||
.send(send)
|
||||
.inspect_err(|_| tracing::debug!("could not send reply"))
|
||||
.ok();
|
||||
}
|
||||
CancelReplyOp::GetCancelData { resp_tx, _guard } => {
|
||||
let send =
|
||||
FromRedisValue::from_owned_redis_value(v).context("could not parse value");
|
||||
resp_tx
|
||||
.send(send)
|
||||
.inspect_err(|_| tracing::debug!("could not send reply"))
|
||||
.ok();
|
||||
}
|
||||
CancelReplyOp::RemoveCancelKey { resp_tx, _guard } => {
|
||||
let send =
|
||||
FromRedisValue::from_owned_redis_value(v).context("could not parse value");
|
||||
resp_tx
|
||||
.send(send)
|
||||
.inspect_err(|_| tracing::debug!("could not send reply"))
|
||||
.ok();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Running as a separate task to accept messages through the rx channel
|
||||
pub async fn handle_cancel_messages(
|
||||
client: &mut RedisKVClient,
|
||||
mut rx: mpsc::Receiver<CancelKeyOp>,
|
||||
batch_size: usize,
|
||||
) -> anyhow::Result<()> {
|
||||
let mut batch = Vec::with_capacity(batch_size);
|
||||
let mut pipeline = Pipeline::with_capacity(batch_size);
|
||||
|
||||
loop {
|
||||
if rx.recv_many(&mut batch, batch_size).await == 0 {
|
||||
warn!("shutting down cancellation queue");
|
||||
break Ok(());
|
||||
}
|
||||
async fn apply(&mut self, batch: Vec<Self::Req>) -> Vec<Self::Res> {
|
||||
let mut pipeline = Pipeline::with_capacity(batch.len());
|
||||
|
||||
let batch_size = batch.len();
|
||||
debug!(batch_size, "running cancellation jobs");
|
||||
|
||||
for msg in batch.drain(..) {
|
||||
msg.register(&mut pipeline);
|
||||
for (_, op) in &batch {
|
||||
op.register(&mut pipeline);
|
||||
}
|
||||
|
||||
pipeline.execute(client).await;
|
||||
pipeline.execute(&mut self.client).await
|
||||
}
|
||||
}
|
||||
|
||||
@@ -253,10 +142,9 @@ pub async fn handle_cancel_messages(
|
||||
///
|
||||
/// If `CancellationPublisher` is available, cancel request will be used to publish the cancellation key to other proxy instances.
|
||||
pub struct CancellationHandler {
|
||||
compute_config: &'static ComputeConfig,
|
||||
// rate limiter of cancellation requests
|
||||
limiter: Arc<std::sync::Mutex<LeakyBucketRateLimiter<IpSubnetKey>>>,
|
||||
tx: Option<mpsc::Sender<CancelKeyOp>>, // send messages to the redis KV client task
|
||||
tx: OnceLock<BatchQueue<CancellationProcessor>>, // send messages to the redis KV client task
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
@@ -296,13 +184,9 @@ impl ReportableError for CancelError {
|
||||
}
|
||||
|
||||
impl CancellationHandler {
|
||||
pub fn new(
|
||||
compute_config: &'static ComputeConfig,
|
||||
tx: Option<mpsc::Sender<CancelKeyOp>>,
|
||||
) -> Self {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
compute_config,
|
||||
tx,
|
||||
tx: OnceLock::new(),
|
||||
limiter: Arc::new(std::sync::Mutex::new(
|
||||
LeakyBucketRateLimiter::<IpSubnetKey>::new_with_shards(
|
||||
LeakyBucketRateLimiter::<IpSubnetKey>::DEFAULT,
|
||||
@@ -312,7 +196,14 @@ impl CancellationHandler {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn get_key(self: &Arc<Self>) -> Session {
|
||||
pub fn init_tx(&self, queue: BatchQueue<CancellationProcessor>) {
|
||||
self.tx
|
||||
.set(queue)
|
||||
.map_err(|_| {})
|
||||
.expect("cancellation queue should be registered once");
|
||||
}
|
||||
|
||||
pub(crate) fn get_key(self: Arc<Self>) -> Session {
|
||||
// we intentionally generate a random "backend pid" and "secret key" here.
|
||||
// we use the corresponding u64 as an identifier for the
|
||||
// actual endpoint+pid+secret for postgres/pgbouncer.
|
||||
@@ -322,14 +213,10 @@ impl CancellationHandler {
|
||||
|
||||
let key: CancelKeyData = rand::random();
|
||||
|
||||
let prefix_key: KeyPrefix = KeyPrefix::Cancel(key);
|
||||
let redis_key = prefix_key.build_redis_key();
|
||||
|
||||
debug!("registered new query cancellation key {key}");
|
||||
Session {
|
||||
key,
|
||||
redis_key,
|
||||
cancellation_handler: Arc::clone(self),
|
||||
cancellation_handler: self,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -337,62 +224,43 @@ impl CancellationHandler {
|
||||
&self,
|
||||
key: CancelKeyData,
|
||||
) -> Result<Option<CancelClosure>, CancelError> {
|
||||
let prefix_key: KeyPrefix = KeyPrefix::Cancel(key);
|
||||
let redis_key = prefix_key.build_redis_key();
|
||||
let guard = Metrics::get()
|
||||
.proxy
|
||||
.cancel_channel_size
|
||||
.guard(RedisMsgKind::HGet);
|
||||
let op = CancelKeyOp::GetCancelData { key };
|
||||
|
||||
let (resp_tx, resp_rx) = tokio::sync::oneshot::channel();
|
||||
let op = CancelKeyOp::GetCancelData {
|
||||
key: redis_key,
|
||||
resp_tx,
|
||||
_guard: Metrics::get()
|
||||
.proxy
|
||||
.cancel_channel_size
|
||||
.guard(RedisMsgKind::HGetAll),
|
||||
};
|
||||
|
||||
let Some(tx) = &self.tx else {
|
||||
let Some(tx) = self.tx.get() else {
|
||||
tracing::warn!("cancellation handler is not available");
|
||||
return Err(CancelError::InternalError);
|
||||
};
|
||||
|
||||
tx.try_send(op)
|
||||
const TIMEOUT: Duration = Duration::from_secs(5);
|
||||
let result = timeout(TIMEOUT, tx.call((guard, op)))
|
||||
.await
|
||||
.map_err(|_| {
|
||||
tracing::warn!("timed out waiting to receive GetCancelData response");
|
||||
CancelError::RateLimit
|
||||
})?
|
||||
.map_err(|e| {
|
||||
tracing::warn!("failed to send GetCancelData for {key}: {e}");
|
||||
})
|
||||
.map_err(|()| CancelError::InternalError)?;
|
||||
tracing::warn!("failed to receive GetCancelData response: {e}");
|
||||
CancelError::InternalError
|
||||
})?;
|
||||
|
||||
let result = resp_rx.await.map_err(|e| {
|
||||
let cancel_state_str = String::from_owned_redis_value(result).map_err(|e| {
|
||||
tracing::warn!("failed to receive GetCancelData response: {e}");
|
||||
CancelError::InternalError
|
||||
})?;
|
||||
|
||||
let cancel_state_str: Option<String> = match result {
|
||||
Ok(mut state) => {
|
||||
if state.len() == 1 {
|
||||
Some(state.remove(0).1)
|
||||
} else {
|
||||
tracing::warn!("unexpected number of entries in cancel state: {state:?}");
|
||||
return Err(CancelError::InternalError);
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!("failed to receive cancel state from redis: {e}");
|
||||
return Err(CancelError::InternalError);
|
||||
}
|
||||
};
|
||||
let cancel_closure: CancelClosure =
|
||||
serde_json::from_str(&cancel_state_str).map_err(|e| {
|
||||
tracing::warn!("failed to deserialize cancel state: {e}");
|
||||
CancelError::InternalError
|
||||
})?;
|
||||
|
||||
let cancel_state: Option<CancelClosure> = match cancel_state_str {
|
||||
Some(state) => {
|
||||
let cancel_closure: CancelClosure = serde_json::from_str(&state).map_err(|e| {
|
||||
tracing::warn!("failed to deserialize cancel state: {e}");
|
||||
CancelError::InternalError
|
||||
})?;
|
||||
Some(cancel_closure)
|
||||
}
|
||||
None => None,
|
||||
};
|
||||
Ok(cancel_state)
|
||||
Ok(Some(cancel_closure))
|
||||
}
|
||||
|
||||
/// Try to cancel a running query for the corresponding connection.
|
||||
/// If the cancellation key is not found, it will be published to Redis.
|
||||
/// check_allowed - if true, check if the IP is allowed to cancel the query.
|
||||
@@ -460,17 +328,17 @@ impl CancellationHandler {
|
||||
kind: crate::metrics::CancellationOutcome::Found,
|
||||
});
|
||||
info!("cancelling query per user's request using key {key}");
|
||||
cancel_closure.try_cancel_query(self.compute_config).await
|
||||
cancel_closure.try_cancel_query().await
|
||||
}
|
||||
}
|
||||
|
||||
/// This should've been a [`std::future::Future`], but
|
||||
/// it's impossible to name a type of an unboxed future
|
||||
/// (we'd need something like `#![feature(type_alias_impl_trait)]`).
|
||||
#[derive(Clone, Serialize, Deserialize)]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct CancelClosure {
|
||||
socket_addr: SocketAddr,
|
||||
cancel_token: CancelToken,
|
||||
cancel_token: RawCancelToken,
|
||||
hostname: String, // for pg_sni router
|
||||
user_info: ComputeUserInfo,
|
||||
}
|
||||
@@ -478,7 +346,7 @@ pub struct CancelClosure {
|
||||
impl CancelClosure {
|
||||
pub(crate) fn new(
|
||||
socket_addr: SocketAddr,
|
||||
cancel_token: CancelToken,
|
||||
cancel_token: RawCancelToken,
|
||||
hostname: String,
|
||||
user_info: ComputeUserInfo,
|
||||
) -> Self {
|
||||
@@ -490,19 +358,9 @@ impl CancelClosure {
|
||||
}
|
||||
}
|
||||
/// Cancels the query running on user's compute node.
|
||||
pub(crate) async fn try_cancel_query(
|
||||
self,
|
||||
compute_config: &ComputeConfig,
|
||||
) -> Result<(), CancelError> {
|
||||
pub(crate) async fn try_cancel_query(&self) -> Result<(), CancelError> {
|
||||
let socket = TcpStream::connect(self.socket_addr).await?;
|
||||
|
||||
let tls = <_ as MakeTlsConnect<tokio::net::TcpStream>>::make_tls_connect(
|
||||
compute_config,
|
||||
&self.hostname,
|
||||
)
|
||||
.map_err(|e| CancelError::IO(std::io::Error::other(e.to_string())))?;
|
||||
|
||||
self.cancel_token.cancel_query_raw(socket, tls).await?;
|
||||
self.cancel_token.cancel_query_raw(socket).await?;
|
||||
debug!("query was cancelled");
|
||||
Ok(())
|
||||
}
|
||||
@@ -512,7 +370,6 @@ impl CancelClosure {
|
||||
pub(crate) struct Session {
|
||||
/// The user-facing key identifying this session.
|
||||
key: CancelKeyData,
|
||||
redis_key: String,
|
||||
cancellation_handler: Arc<CancellationHandler>,
|
||||
}
|
||||
|
||||
@@ -521,60 +378,61 @@ impl Session {
|
||||
&self.key
|
||||
}
|
||||
|
||||
// Send the store key op to the cancellation handler and set TTL for the key
|
||||
pub(crate) fn write_cancel_key(
|
||||
/// Ensure the cancel key is continuously refreshed,
|
||||
/// but stop when the channel is dropped.
|
||||
pub(crate) async fn maintain_cancel_key(
|
||||
&self,
|
||||
cancel_closure: CancelClosure,
|
||||
) -> Result<(), CancelError> {
|
||||
let Some(tx) = &self.cancellation_handler.tx else {
|
||||
tracing::warn!("cancellation handler is not available");
|
||||
return Err(CancelError::InternalError);
|
||||
};
|
||||
session_id: uuid::Uuid,
|
||||
cancel: tokio::sync::oneshot::Receiver<Infallible>,
|
||||
cancel_closure: &CancelClosure,
|
||||
) {
|
||||
futures::future::select(
|
||||
std::pin::pin!(self.maintain_redis_cancel_key(cancel_closure)),
|
||||
cancel,
|
||||
)
|
||||
.await;
|
||||
|
||||
let closure_json = serde_json::to_string(&cancel_closure).map_err(|e| {
|
||||
tracing::warn!("failed to serialize cancel closure: {e}");
|
||||
CancelError::InternalError
|
||||
})?;
|
||||
|
||||
let op = CancelKeyOp::StoreCancelKey {
|
||||
key: self.redis_key.clone(),
|
||||
field: "data".to_string(),
|
||||
value: closure_json,
|
||||
resp_tx: None,
|
||||
_guard: Metrics::get()
|
||||
.proxy
|
||||
.cancel_channel_size
|
||||
.guard(RedisMsgKind::HSet),
|
||||
expire: CANCEL_KEY_TTL,
|
||||
};
|
||||
|
||||
let _ = tx.try_send(op).map_err(|e| {
|
||||
let key = self.key;
|
||||
tracing::warn!("failed to send StoreCancelKey for {key}: {e}");
|
||||
});
|
||||
Ok(())
|
||||
if let Err(err) = cancel_closure.try_cancel_query().boxed().await {
|
||||
tracing::warn!(
|
||||
?session_id,
|
||||
?err,
|
||||
"could not cancel the query in the database"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn remove_cancel_key(&self) -> Result<(), CancelError> {
|
||||
let Some(tx) = &self.cancellation_handler.tx else {
|
||||
// Ensure the cancel key is continuously refreshed.
|
||||
async fn maintain_redis_cancel_key(&self, cancel_closure: &CancelClosure) -> ! {
|
||||
let Some(tx) = self.cancellation_handler.tx.get() else {
|
||||
tracing::warn!("cancellation handler is not available");
|
||||
return Err(CancelError::InternalError);
|
||||
// don't exit, as we only want to exit if cancelled externally.
|
||||
std::future::pending().await
|
||||
};
|
||||
|
||||
let op = CancelKeyOp::RemoveCancelKey {
|
||||
key: self.redis_key.clone(),
|
||||
field: "data".to_string(),
|
||||
resp_tx: None,
|
||||
_guard: Metrics::get()
|
||||
let closure_json = serde_json::to_string(&cancel_closure)
|
||||
.expect("serialising to json string should not fail")
|
||||
.into_boxed_str();
|
||||
|
||||
loop {
|
||||
let guard = Metrics::get()
|
||||
.proxy
|
||||
.cancel_channel_size
|
||||
.guard(RedisMsgKind::HDel),
|
||||
};
|
||||
.guard(RedisMsgKind::HSet);
|
||||
let op = CancelKeyOp::StoreCancelKey {
|
||||
key: self.key,
|
||||
value: closure_json.clone(),
|
||||
expire: CANCEL_KEY_TTL,
|
||||
};
|
||||
|
||||
let _ = tx.try_send(op).map_err(|e| {
|
||||
let key = self.key;
|
||||
tracing::warn!("failed to send RemoveCancelKey for {key}: {e}");
|
||||
});
|
||||
Ok(())
|
||||
tracing::debug!(
|
||||
src=%self.key,
|
||||
dest=?cancel_closure.cancel_token,
|
||||
"registering cancellation key"
|
||||
);
|
||||
|
||||
if tx.call((guard, op)).await.is_ok() {
|
||||
tokio::time::sleep(CANCEL_KEY_REFRESH).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,7 +9,7 @@ use itertools::Itertools;
|
||||
use postgres_client::config::{AuthKeys, SslMode};
|
||||
use postgres_client::maybe_tls_stream::MaybeTlsStream;
|
||||
use postgres_client::tls::MakeTlsConnect;
|
||||
use postgres_client::{CancelToken, NoTls, RawConnection};
|
||||
use postgres_client::{NoTls, RawCancelToken, RawConnection};
|
||||
use postgres_protocol::message::backend::NoticeResponseBody;
|
||||
use thiserror::Error;
|
||||
use tokio::net::{TcpStream, lookup_host};
|
||||
@@ -265,7 +265,8 @@ impl ConnectInfo {
|
||||
}
|
||||
}
|
||||
|
||||
type RustlsStream = <ComputeConfig as MakeTlsConnect<tokio::net::TcpStream>>::Stream;
|
||||
pub type RustlsStream = <ComputeConfig as MakeTlsConnect<tokio::net::TcpStream>>::Stream;
|
||||
pub type MaybeRustlsStream = MaybeTlsStream<tokio::net::TcpStream, RustlsStream>;
|
||||
|
||||
pub(crate) struct PostgresConnection {
|
||||
/// Socket connected to a compute node.
|
||||
@@ -279,7 +280,7 @@ pub(crate) struct PostgresConnection {
|
||||
/// Notices received from compute after authenticating
|
||||
pub(crate) delayed_notice: Vec<NoticeResponseBody>,
|
||||
|
||||
_guage: NumDbConnectionsGuard<'static>,
|
||||
pub(crate) guage: NumDbConnectionsGuard<'static>,
|
||||
}
|
||||
|
||||
impl ConnectInfo {
|
||||
@@ -327,9 +328,7 @@ impl ConnectInfo {
|
||||
// Yet another reason to rework the connection establishing code.
|
||||
let cancel_closure = CancelClosure::new(
|
||||
socket_addr,
|
||||
CancelToken {
|
||||
socket_config: None,
|
||||
ssl_mode: self.ssl_mode,
|
||||
RawCancelToken {
|
||||
process_id,
|
||||
secret_key,
|
||||
},
|
||||
@@ -343,7 +342,7 @@ impl ConnectInfo {
|
||||
delayed_notice,
|
||||
cancel_closure,
|
||||
aux,
|
||||
_guage: Metrics::get().proxy.db_connections.guard(ctx.protocol()),
|
||||
guage: Metrics::get().proxy.db_connections.guard(ctx.protocol()),
|
||||
};
|
||||
|
||||
Ok(connection)
|
||||
|
||||
@@ -120,7 +120,7 @@ pub async fn task_main(
|
||||
Ok(Some(p)) => {
|
||||
ctx.set_success();
|
||||
let _disconnect = ctx.log_connect();
|
||||
match p.proxy_pass(&config.connect_to_compute).await {
|
||||
match p.proxy_pass().await {
|
||||
Ok(()) => {}
|
||||
Err(ErrorSource::Client(e)) => {
|
||||
error!(
|
||||
@@ -232,22 +232,30 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
|
||||
.or_else(|e| async { Err(stream.throw_error(e, Some(ctx)).await) })
|
||||
.await?;
|
||||
|
||||
let cancellation_handler_clone = Arc::clone(&cancellation_handler);
|
||||
let session = cancellation_handler_clone.get_key();
|
||||
|
||||
session.write_cancel_key(node.cancel_closure.clone())?;
|
||||
let session = cancellation_handler.get_key();
|
||||
|
||||
prepare_client_connection(&node, *session.key(), &mut stream);
|
||||
let stream = stream.flush_and_into_inner().await?;
|
||||
|
||||
let session_id = ctx.session_id();
|
||||
let (cancel_on_shutdown, cancel) = tokio::sync::oneshot::channel();
|
||||
tokio::spawn(async move {
|
||||
session
|
||||
.maintain_cancel_key(session_id, cancel, &node.cancel_closure)
|
||||
.await;
|
||||
});
|
||||
|
||||
Ok(Some(ProxyPassthrough {
|
||||
client: stream,
|
||||
aux: node.aux.clone(),
|
||||
compute: node.stream,
|
||||
|
||||
aux: node.aux,
|
||||
private_link_id: None,
|
||||
compute: node,
|
||||
session_id: ctx.session_id(),
|
||||
cancel: session,
|
||||
|
||||
_cancel_on_shutdown: cancel_on_shutdown,
|
||||
|
||||
_req: request_gauge,
|
||||
_conn: conn_gauge,
|
||||
_db_conn: node.guage,
|
||||
}))
|
||||
}
|
||||
|
||||
@@ -146,6 +146,7 @@ impl NeonControlPlaneClient {
|
||||
public_access_blocked: block_public_connections,
|
||||
vpc_access_blocked: block_vpc_connections,
|
||||
},
|
||||
rate_limits: body.rate_limits,
|
||||
})
|
||||
}
|
||||
.inspect_err(|e| tracing::debug!(error = ?e))
|
||||
@@ -312,6 +313,7 @@ impl super::ControlPlaneApi for NeonControlPlaneClient {
|
||||
allowed_ips: Arc::new(auth_info.allowed_ips),
|
||||
allowed_vpce: Arc::new(auth_info.allowed_vpc_endpoint_ids),
|
||||
flags: auth_info.access_blocker_flags,
|
||||
rate_limits: auth_info.rate_limits,
|
||||
};
|
||||
let role_control = RoleAccessControl {
|
||||
secret: auth_info.secret,
|
||||
@@ -357,6 +359,7 @@ impl super::ControlPlaneApi for NeonControlPlaneClient {
|
||||
allowed_ips: Arc::new(auth_info.allowed_ips),
|
||||
allowed_vpce: Arc::new(auth_info.allowed_vpc_endpoint_ids),
|
||||
flags: auth_info.access_blocker_flags,
|
||||
rate_limits: auth_info.rate_limits,
|
||||
};
|
||||
let role_control = RoleAccessControl {
|
||||
secret: auth_info.secret,
|
||||
|
||||
@@ -20,7 +20,7 @@ use crate::context::RequestContext;
|
||||
use crate::control_plane::errors::{
|
||||
ControlPlaneError, GetAuthInfoError, GetEndpointJwksError, WakeComputeError,
|
||||
};
|
||||
use crate::control_plane::messages::MetricsAuxInfo;
|
||||
use crate::control_plane::messages::{EndpointRateLimitConfig, MetricsAuxInfo};
|
||||
use crate::control_plane::{
|
||||
AccessBlockerFlags, AuthInfo, AuthSecret, CachedNodeInfo, EndpointAccessControl, NodeInfo,
|
||||
RoleAccessControl,
|
||||
@@ -130,6 +130,7 @@ impl MockControlPlane {
|
||||
project_id: None,
|
||||
account_id: None,
|
||||
access_blocker_flags: AccessBlockerFlags::default(),
|
||||
rate_limits: EndpointRateLimitConfig::default(),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -233,6 +234,7 @@ impl super::ControlPlaneApi for MockControlPlane {
|
||||
allowed_ips: Arc::new(info.allowed_ips),
|
||||
allowed_vpce: Arc::new(info.allowed_vpc_endpoint_ids),
|
||||
flags: info.access_blocker_flags,
|
||||
rate_limits: info.rate_limits,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ use clashmap::ClashMap;
|
||||
use tokio::time::Instant;
|
||||
use tracing::{debug, info};
|
||||
|
||||
use super::{EndpointAccessControl, RoleAccessControl};
|
||||
use crate::auth::backend::ComputeUserInfo;
|
||||
use crate::auth::backend::jwt::{AuthRule, FetchAuthRules, FetchAuthRulesError};
|
||||
use crate::cache::endpoints::EndpointsCache;
|
||||
@@ -22,8 +23,6 @@ use crate::metrics::ApiLockMetrics;
|
||||
use crate::rate_limiter::{DynamicLimiter, Outcome, RateLimiterConfig, Token};
|
||||
use crate::types::EndpointId;
|
||||
|
||||
use super::{EndpointAccessControl, RoleAccessControl};
|
||||
|
||||
#[non_exhaustive]
|
||||
#[derive(Clone)]
|
||||
pub enum ControlPlaneClient {
|
||||
|
||||
@@ -227,12 +227,35 @@ pub(crate) struct UserFacingMessage {
|
||||
#[derive(Deserialize)]
|
||||
pub(crate) struct GetEndpointAccessControl {
|
||||
pub(crate) role_secret: Box<str>,
|
||||
pub(crate) allowed_ips: Option<Vec<IpPattern>>,
|
||||
pub(crate) allowed_vpc_endpoint_ids: Option<Vec<String>>,
|
||||
|
||||
pub(crate) project_id: Option<ProjectIdInt>,
|
||||
pub(crate) account_id: Option<AccountIdInt>,
|
||||
|
||||
pub(crate) allowed_ips: Option<Vec<IpPattern>>,
|
||||
pub(crate) allowed_vpc_endpoint_ids: Option<Vec<String>>,
|
||||
pub(crate) block_public_connections: Option<bool>,
|
||||
pub(crate) block_vpc_connections: Option<bool>,
|
||||
|
||||
#[serde(default)]
|
||||
pub(crate) rate_limits: EndpointRateLimitConfig,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Deserialize, Default)]
|
||||
pub struct EndpointRateLimitConfig {
|
||||
pub connection_attempts: ConnectionAttemptsLimit,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Deserialize, Default)]
|
||||
pub struct ConnectionAttemptsLimit {
|
||||
pub tcp: Option<LeakyBucketSetting>,
|
||||
pub ws: Option<LeakyBucketSetting>,
|
||||
pub http: Option<LeakyBucketSetting>,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Deserialize)]
|
||||
pub struct LeakyBucketSetting {
|
||||
pub rps: f64,
|
||||
pub burst: f64,
|
||||
}
|
||||
|
||||
/// Response which holds compute node's `host:port` pair.
|
||||
|
||||
@@ -11,6 +11,8 @@ pub(crate) mod errors;
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use messages::EndpointRateLimitConfig;
|
||||
|
||||
use crate::auth::backend::ComputeUserInfo;
|
||||
use crate::auth::backend::jwt::AuthRule;
|
||||
use crate::auth::{AuthError, IpPattern, check_peer_addr_is_in_list};
|
||||
@@ -18,8 +20,9 @@ use crate::cache::{Cached, TimedLru};
|
||||
use crate::config::ComputeConfig;
|
||||
use crate::context::RequestContext;
|
||||
use crate::control_plane::messages::{ControlPlaneErrorMessage, MetricsAuxInfo};
|
||||
use crate::intern::{AccountIdInt, ProjectIdInt};
|
||||
use crate::intern::{AccountIdInt, EndpointIdInt, ProjectIdInt};
|
||||
use crate::protocol2::ConnectionInfoExtra;
|
||||
use crate::rate_limiter::{EndpointRateLimiter, LeakyBucketConfig};
|
||||
use crate::types::{EndpointCacheKey, EndpointId, RoleName};
|
||||
use crate::{compute, scram};
|
||||
|
||||
@@ -56,6 +59,8 @@ pub(crate) struct AuthInfo {
|
||||
pub(crate) account_id: Option<AccountIdInt>,
|
||||
/// Are public connections or VPC connections blocked?
|
||||
pub(crate) access_blocker_flags: AccessBlockerFlags,
|
||||
/// The rate limits for this endpoint.
|
||||
pub(crate) rate_limits: EndpointRateLimitConfig,
|
||||
}
|
||||
|
||||
/// Info for establishing a connection to a compute node.
|
||||
@@ -101,6 +106,8 @@ pub struct EndpointAccessControl {
|
||||
pub allowed_ips: Arc<Vec<IpPattern>>,
|
||||
pub allowed_vpce: Arc<Vec<String>>,
|
||||
pub flags: AccessBlockerFlags,
|
||||
|
||||
pub rate_limits: EndpointRateLimitConfig,
|
||||
}
|
||||
|
||||
impl EndpointAccessControl {
|
||||
@@ -139,6 +146,36 @@ impl EndpointAccessControl {
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn connection_attempt_rate_limit(
|
||||
&self,
|
||||
ctx: &RequestContext,
|
||||
endpoint: &EndpointId,
|
||||
rate_limiter: &EndpointRateLimiter,
|
||||
) -> Result<(), AuthError> {
|
||||
let endpoint = EndpointIdInt::from(endpoint);
|
||||
|
||||
let limits = &self.rate_limits.connection_attempts;
|
||||
let config = match ctx.protocol() {
|
||||
crate::metrics::Protocol::Http => limits.http,
|
||||
crate::metrics::Protocol::Ws => limits.ws,
|
||||
crate::metrics::Protocol::Tcp => limits.tcp,
|
||||
crate::metrics::Protocol::SniRouter => return Ok(()),
|
||||
};
|
||||
let config = config.and_then(|config| {
|
||||
if config.rps <= 0.0 || config.burst <= 0.0 {
|
||||
return None;
|
||||
}
|
||||
|
||||
Some(LeakyBucketConfig::new(config.rps, config.burst))
|
||||
});
|
||||
|
||||
if !rate_limiter.check(endpoint, config, 1) {
|
||||
return Err(AuthError::too_many_connections());
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// This will allocate per each call, but the http requests alone
|
||||
|
||||
@@ -75,6 +75,7 @@
|
||||
pub mod binary;
|
||||
|
||||
mod auth;
|
||||
mod batch;
|
||||
mod cache;
|
||||
mod cancellation;
|
||||
mod compute;
|
||||
|
||||
@@ -1,15 +1,17 @@
|
||||
use futures::FutureExt;
|
||||
use std::convert::Infallible;
|
||||
|
||||
use smol_str::SmolStr;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tracing::debug;
|
||||
use utils::measured_stream::MeasuredStream;
|
||||
|
||||
use super::copy_bidirectional::ErrorSource;
|
||||
use crate::cancellation;
|
||||
use crate::compute::PostgresConnection;
|
||||
use crate::config::ComputeConfig;
|
||||
use crate::compute::MaybeRustlsStream;
|
||||
use crate::control_plane::messages::MetricsAuxInfo;
|
||||
use crate::metrics::{Direction, Metrics, NumClientConnectionsGuard, NumConnectionRequestsGuard};
|
||||
use crate::metrics::{
|
||||
Direction, Metrics, NumClientConnectionsGuard, NumConnectionRequestsGuard,
|
||||
NumDbConnectionsGuard,
|
||||
};
|
||||
use crate::stream::Stream;
|
||||
use crate::usage_metrics::{Ids, MetricCounterRecorder, USAGE_METRICS};
|
||||
|
||||
@@ -64,40 +66,20 @@ pub(crate) async fn proxy_pass(
|
||||
|
||||
pub(crate) struct ProxyPassthrough<S> {
|
||||
pub(crate) client: Stream<S>,
|
||||
pub(crate) compute: PostgresConnection,
|
||||
pub(crate) compute: MaybeRustlsStream,
|
||||
|
||||
pub(crate) aux: MetricsAuxInfo,
|
||||
pub(crate) session_id: uuid::Uuid,
|
||||
pub(crate) private_link_id: Option<SmolStr>,
|
||||
pub(crate) cancel: cancellation::Session,
|
||||
|
||||
pub(crate) _cancel_on_shutdown: tokio::sync::oneshot::Sender<Infallible>,
|
||||
|
||||
pub(crate) _req: NumConnectionRequestsGuard<'static>,
|
||||
pub(crate) _conn: NumClientConnectionsGuard<'static>,
|
||||
pub(crate) _db_conn: NumDbConnectionsGuard<'static>,
|
||||
}
|
||||
|
||||
impl<S: AsyncRead + AsyncWrite + Unpin> ProxyPassthrough<S> {
|
||||
pub(crate) async fn proxy_pass(
|
||||
self,
|
||||
compute_config: &ComputeConfig,
|
||||
) -> Result<(), ErrorSource> {
|
||||
let res = proxy_pass(
|
||||
self.client,
|
||||
self.compute.stream,
|
||||
self.aux,
|
||||
self.private_link_id,
|
||||
)
|
||||
.await;
|
||||
if let Err(err) = self
|
||||
.compute
|
||||
.cancel_closure
|
||||
.try_cancel_query(compute_config)
|
||||
.boxed()
|
||||
.await
|
||||
{
|
||||
tracing::warn!(session_id = ?self.session_id, ?err, "could not cancel the query in the database");
|
||||
}
|
||||
|
||||
drop(self.cancel.remove_cancel_key()); // we don't need a result. If the queue is full, we just log the error
|
||||
|
||||
res
|
||||
pub(crate) async fn proxy_pass(self) -> Result<(), ErrorSource> {
|
||||
proxy_pass(self.client, self.compute, self.aux, self.private_link_id).await
|
||||
}
|
||||
}
|
||||
|
||||
@@ -155,7 +155,7 @@ pub async fn task_main(
|
||||
Ok(Some(p)) => {
|
||||
ctx.set_success();
|
||||
let _disconnect = ctx.log_connect();
|
||||
match p.proxy_pass(&config.connect_to_compute).await {
|
||||
match p.proxy_pass().await {
|
||||
Ok(()) => {}
|
||||
Err(ErrorSource::Client(e)) => {
|
||||
warn!(
|
||||
@@ -372,13 +372,19 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
|
||||
Err(e) => Err(stream.throw_error(e, Some(ctx)).await)?,
|
||||
};
|
||||
|
||||
let cancellation_handler_clone = Arc::clone(&cancellation_handler);
|
||||
let session = cancellation_handler_clone.get_key();
|
||||
let session = cancellation_handler.get_key();
|
||||
|
||||
session.write_cancel_key(node.cancel_closure.clone())?;
|
||||
prepare_client_connection(&node, *session.key(), &mut stream);
|
||||
let stream = stream.flush_and_into_inner().await?;
|
||||
|
||||
let session_id = ctx.session_id();
|
||||
let (cancel_on_shutdown, cancel) = tokio::sync::oneshot::channel();
|
||||
tokio::spawn(async move {
|
||||
session
|
||||
.maintain_cancel_key(session_id, cancel, &node.cancel_closure)
|
||||
.await;
|
||||
});
|
||||
|
||||
let private_link_id = match ctx.extra() {
|
||||
Some(ConnectionInfoExtra::Aws { vpce_id }) => Some(vpce_id.clone()),
|
||||
Some(ConnectionInfoExtra::Azure { link_id }) => Some(link_id.to_smolstr()),
|
||||
@@ -387,13 +393,16 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
|
||||
|
||||
Ok(Some(ProxyPassthrough {
|
||||
client: stream,
|
||||
aux: node.aux.clone(),
|
||||
compute: node.stream,
|
||||
|
||||
aux: node.aux,
|
||||
private_link_id,
|
||||
compute: node,
|
||||
session_id: ctx.session_id(),
|
||||
cancel: session,
|
||||
|
||||
_cancel_on_shutdown: cancel_on_shutdown,
|
||||
|
||||
_req: request_gauge,
|
||||
_conn: conn_gauge,
|
||||
_db_conn: node.guage,
|
||||
}))
|
||||
}
|
||||
|
||||
|
||||
@@ -69,9 +69,8 @@ pub struct LeakyBucketConfig {
|
||||
pub max: f64,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
impl LeakyBucketConfig {
|
||||
pub(crate) fn new(rps: f64, max: f64) -> Self {
|
||||
pub fn new(rps: f64, max: f64) -> Self {
|
||||
assert!(rps > 0.0, "rps must be positive");
|
||||
assert!(max > 0.0, "max must be positive");
|
||||
Self { rps, max }
|
||||
|
||||
@@ -12,11 +12,10 @@ use rand::{Rng, SeedableRng};
|
||||
use tokio::time::{Duration, Instant};
|
||||
use tracing::info;
|
||||
|
||||
use super::LeakyBucketConfig;
|
||||
use crate::ext::LockExt;
|
||||
use crate::intern::EndpointIdInt;
|
||||
|
||||
use super::LeakyBucketConfig;
|
||||
|
||||
pub struct GlobalRateLimiter {
|
||||
data: Vec<RateBucket>,
|
||||
info: Vec<RateBucketInfo>,
|
||||
|
||||
@@ -1,8 +1,4 @@
|
||||
use std::io::ErrorKind;
|
||||
|
||||
use anyhow::Ok;
|
||||
|
||||
use crate::pqproto::{CancelKeyData, id_to_cancel_key};
|
||||
use crate::pqproto::CancelKeyData;
|
||||
|
||||
pub mod keyspace {
|
||||
pub const CANCEL_PREFIX: &str = "cancel";
|
||||
@@ -23,39 +19,12 @@ impl KeyPrefix {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub(crate) fn as_str(&self) -> &'static str {
|
||||
match self {
|
||||
KeyPrefix::Cancel(_) => keyspace::CANCEL_PREFIX,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub(crate) fn parse_redis_key(key: &str) -> anyhow::Result<KeyPrefix> {
|
||||
let (prefix, key_str) = key.split_once(':').ok_or_else(|| {
|
||||
anyhow::anyhow!(std::io::Error::new(
|
||||
ErrorKind::InvalidData,
|
||||
"missing prefix"
|
||||
))
|
||||
})?;
|
||||
|
||||
match prefix {
|
||||
keyspace::CANCEL_PREFIX => {
|
||||
let id = u64::from_str_radix(key_str, 16)?;
|
||||
|
||||
Ok(KeyPrefix::Cancel(id_to_cancel_key(id)))
|
||||
}
|
||||
_ => Err(anyhow::anyhow!(std::io::Error::new(
|
||||
ErrorKind::InvalidData,
|
||||
"unknown prefix"
|
||||
))),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::pqproto::id_to_cancel_key;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
@@ -65,16 +34,4 @@ mod tests {
|
||||
let redis_key = cancel_key.build_redis_key();
|
||||
assert_eq!(redis_key, "cancel:30390000d431");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_redis_key() {
|
||||
let redis_key = "cancel:30390000d431";
|
||||
let key: KeyPrefix = parse_redis_key(redis_key).expect("Failed to parse key");
|
||||
|
||||
let ref_key = id_to_cancel_key(12345 << 32 | 54321);
|
||||
|
||||
assert_eq!(key.as_str(), KeyPrefix::Cancel(ref_key).as_str());
|
||||
let KeyPrefix::Cancel(cancel_key) = key;
|
||||
assert_eq!(ref_key, cancel_key);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,3 +1,6 @@
|
||||
use std::time::Duration;
|
||||
|
||||
use futures::FutureExt;
|
||||
use redis::aio::ConnectionLike;
|
||||
use redis::{Cmd, FromRedisValue, Pipeline, RedisResult};
|
||||
|
||||
@@ -35,14 +38,11 @@ impl RedisKVClient {
|
||||
}
|
||||
|
||||
pub async fn try_connect(&mut self) -> anyhow::Result<()> {
|
||||
match self.client.connect().await {
|
||||
Ok(()) => {}
|
||||
Err(e) => {
|
||||
tracing::error!("failed to connect to redis: {e}");
|
||||
return Err(e);
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
self.client
|
||||
.connect()
|
||||
.boxed()
|
||||
.await
|
||||
.inspect_err(|e| tracing::error!("failed to connect to redis: {e}"))
|
||||
}
|
||||
|
||||
pub(crate) async fn query<T: FromRedisValue>(
|
||||
@@ -54,15 +54,25 @@ impl RedisKVClient {
|
||||
return Err(anyhow::anyhow!("Rate limit exceeded"));
|
||||
}
|
||||
|
||||
match q.query(&mut self.client).await {
|
||||
let e = match q.query(&mut self.client).await {
|
||||
Ok(t) => return Ok(t),
|
||||
Err(e) => {
|
||||
tracing::error!("failed to run query: {e}");
|
||||
Err(e) => e,
|
||||
};
|
||||
|
||||
tracing::error!("failed to run query: {e}");
|
||||
match e.retry_method() {
|
||||
redis::RetryMethod::Reconnect => {
|
||||
tracing::info!("Redis client is disconnected. Reconnecting...");
|
||||
self.try_connect().await?;
|
||||
}
|
||||
redis::RetryMethod::RetryImmediately => {}
|
||||
redis::RetryMethod::WaitAndRetry => {
|
||||
// somewhat arbitrary.
|
||||
tokio::time::sleep(Duration::from_millis(100)).await;
|
||||
}
|
||||
_ => Err(e)?,
|
||||
}
|
||||
|
||||
tracing::info!("Redis client is disconnected. Reconnecting...");
|
||||
self.try_connect().await?;
|
||||
Ok(q.query(&mut self.client).await?)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -68,17 +68,20 @@ impl PoolingBackend {
|
||||
self.config.authentication_config.is_vpc_acccess_proxy,
|
||||
)?;
|
||||
|
||||
let ep = EndpointIdInt::from(&user_info.endpoint);
|
||||
let rate_limit_config = None;
|
||||
if !self.endpoint_rate_limiter.check(ep, rate_limit_config, 1) {
|
||||
return Err(AuthError::too_many_connections());
|
||||
}
|
||||
access_control.connection_attempt_rate_limit(
|
||||
ctx,
|
||||
&user_info.endpoint,
|
||||
&self.endpoint_rate_limiter,
|
||||
)?;
|
||||
|
||||
let role_access = backend.get_role_secret(ctx).await?;
|
||||
let Some(secret) = role_access.secret else {
|
||||
// If we don't have an authentication secret, for the http flow we can just return an error.
|
||||
info!("authentication info not found");
|
||||
return Err(AuthError::password_failed(&*user_info.user));
|
||||
};
|
||||
|
||||
let ep = EndpointIdInt::from(&user_info.endpoint);
|
||||
let auth_outcome = crate::auth::validate_password_and_exchange(
|
||||
&self.config.authentication_config.thread_pool,
|
||||
ep,
|
||||
|
||||
@@ -167,7 +167,7 @@ pub(crate) async fn serve_websocket(
|
||||
Ok(Some(p)) => {
|
||||
ctx.set_success();
|
||||
ctx.log_connect();
|
||||
match p.proxy_pass(&config.connect_to_compute).await {
|
||||
match p.proxy_pass().await {
|
||||
Ok(()) => Ok(()),
|
||||
Err(ErrorSource::Client(err)) => Err(err).context("client"),
|
||||
Err(ErrorSource::Compute(err)) => Err(err).context("compute"),
|
||||
|
||||
@@ -69,8 +69,10 @@ class EndpointHttpClient(requests.Session):
|
||||
json: dict[str, str] = res.json()
|
||||
return json
|
||||
|
||||
def prewarm_lfc(self):
|
||||
self.post(f"http://localhost:{self.external_port}/lfc/prewarm").raise_for_status()
|
||||
def prewarm_lfc(self, from_endpoint_id: str | None = None):
|
||||
url: str = f"http://localhost:{self.external_port}/lfc/prewarm"
|
||||
params = {"from_endpoint": from_endpoint_id} if from_endpoint_id else dict()
|
||||
self.post(url, params=params).raise_for_status()
|
||||
|
||||
def prewarmed():
|
||||
json = self.prewarm_lfc_status()
|
||||
|
||||
@@ -129,6 +129,18 @@ class NeonAPI:
|
||||
|
||||
return cast("dict[str, Any]", resp.json())
|
||||
|
||||
def get_project_limits(self, project_id: str) -> dict[str, Any]:
|
||||
resp = self.__request(
|
||||
"GET",
|
||||
f"/projects/{project_id}/limits",
|
||||
headers={
|
||||
"Accept": "application/json",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
|
||||
return cast("dict[str, Any]", resp.json())
|
||||
|
||||
def delete_project(
|
||||
self,
|
||||
project_id: str,
|
||||
|
||||
@@ -45,6 +45,8 @@ class NeonEndpoint:
|
||||
if self.branch.connect_env:
|
||||
self.connect_env = self.branch.connect_env.copy()
|
||||
self.connect_env["PGHOST"] = self.host
|
||||
if self.type == "read_only":
|
||||
self.project.read_only_endpoints_total += 1
|
||||
|
||||
def delete(self):
|
||||
self.project.delete_endpoint(self.id)
|
||||
@@ -228,8 +230,13 @@ class NeonProject:
|
||||
self.benchmarks: dict[str, subprocess.Popen[Any]] = {}
|
||||
self.restore_num: int = 0
|
||||
self.restart_pgbench_on_console_errors: bool = False
|
||||
self.limits: dict[str, Any] = self.get_limits()["limits"]
|
||||
self.read_only_endpoints_total: int = 0
|
||||
|
||||
def delete(self):
|
||||
def get_limits(self) -> dict[str, Any]:
|
||||
return self.neon_api.get_project_limits(self.id)
|
||||
|
||||
def delete(self) -> None:
|
||||
self.neon_api.delete_project(self.id)
|
||||
|
||||
def create_branch(self, parent_id: str | None = None) -> NeonBranch | None:
|
||||
@@ -282,6 +289,7 @@ class NeonProject:
|
||||
self.neon_api.delete_endpoint(self.id, endpoint_id)
|
||||
self.endpoints[endpoint_id].branch.endpoints.pop(endpoint_id)
|
||||
self.endpoints.pop(endpoint_id)
|
||||
self.read_only_endpoints_total -= 1
|
||||
self.wait()
|
||||
|
||||
def start_benchmark(self, target: str, clients: int = 10) -> subprocess.Popen[Any]:
|
||||
@@ -369,49 +377,64 @@ def setup_class(
|
||||
print(f"::warning::Retried on 524 error {neon_api.retries524} times")
|
||||
if neon_api.retries4xx > 0:
|
||||
print(f"::warning::Retried on 4xx error {neon_api.retries4xx} times")
|
||||
log.info("Removing the project")
|
||||
log.info("Removing the project %s", project.id)
|
||||
project.delete()
|
||||
|
||||
|
||||
def do_action(project: NeonProject, action: str) -> None:
|
||||
def do_action(project: NeonProject, action: str) -> bool:
|
||||
"""
|
||||
Runs the action
|
||||
"""
|
||||
log.info("Action: %s", action)
|
||||
if action == "new_branch":
|
||||
log.info("Trying to create a new branch")
|
||||
if 0 <= project.limits["max_branches"] <= len(project.branches):
|
||||
log.info(
|
||||
"Maximum branch limit exceeded (%s of %s)",
|
||||
len(project.branches),
|
||||
project.limits["max_branches"],
|
||||
)
|
||||
return False
|
||||
parent = project.branches[
|
||||
random.choice(list(set(project.branches.keys()) - project.reset_branches))
|
||||
]
|
||||
log.info("Parent: %s", parent)
|
||||
child = parent.create_child_branch()
|
||||
if child is None:
|
||||
return
|
||||
return False
|
||||
log.info("Created branch %s", child)
|
||||
child.start_benchmark()
|
||||
elif action == "delete_branch":
|
||||
if project.leaf_branches:
|
||||
target = random.choice(list(project.leaf_branches.values()))
|
||||
target: NeonBranch = random.choice(list(project.leaf_branches.values()))
|
||||
log.info("Trying to delete branch %s", target)
|
||||
target.delete()
|
||||
else:
|
||||
log.info("Leaf branches not found, skipping")
|
||||
return False
|
||||
elif action == "new_ro_endpoint":
|
||||
if 0 <= project.limits["max_read_only_endpoints"] <= project.read_only_endpoints_total:
|
||||
log.info(
|
||||
"Maximum read only endpoint limit exceeded (%s of %s)",
|
||||
project.read_only_endpoints_total,
|
||||
project.limits["max_read_only_endpoints"],
|
||||
)
|
||||
return False
|
||||
ep = random.choice(
|
||||
[br for br in project.branches.values() if br.id not in project.reset_branches]
|
||||
).create_ro_endpoint()
|
||||
log.info("Created the RO endpoint with id %s branch: %s", ep.id, ep.branch.id)
|
||||
ep.start_benchmark()
|
||||
elif action == "delete_ro_endpoint":
|
||||
if project.read_only_endpoints_total == 0:
|
||||
log.info("no read_only endpoints present, skipping")
|
||||
return False
|
||||
ro_endpoints: list[NeonEndpoint] = [
|
||||
endpoint for endpoint in project.endpoints.values() if endpoint.type == "read_only"
|
||||
]
|
||||
if ro_endpoints:
|
||||
target_ep: NeonEndpoint = random.choice(ro_endpoints)
|
||||
target_ep.delete()
|
||||
log.info("endpoint %s deleted", target_ep.id)
|
||||
else:
|
||||
log.info("no read_only endpoints present, skipping")
|
||||
target_ep: NeonEndpoint = random.choice(ro_endpoints)
|
||||
target_ep.delete()
|
||||
log.info("endpoint %s deleted", target_ep.id)
|
||||
elif action == "restore_random_time":
|
||||
if project.leaf_branches:
|
||||
br: NeonBranch = random.choice(list(project.leaf_branches.values()))
|
||||
@@ -419,8 +442,10 @@ def do_action(project: NeonProject, action: str) -> None:
|
||||
br.restore_random_time()
|
||||
else:
|
||||
log.info("No leaf branches found")
|
||||
return False
|
||||
else:
|
||||
raise ValueError(f"The action {action} is unknown")
|
||||
return True
|
||||
|
||||
|
||||
@pytest.mark.timeout(7200)
|
||||
@@ -457,8 +482,9 @@ def test_api_random(
|
||||
pg_bin.run(["pgbench", "-i", "-I", "dtGvp", "-s100"], env=project.main_branch.connect_env)
|
||||
for _ in range(num_operations):
|
||||
log.info("Starting action #%s", _ + 1)
|
||||
do_action(
|
||||
while not do_action(
|
||||
project, random.choices([a[0] for a in ACTIONS], weights=[w[1] for w in ACTIONS])[0]
|
||||
)
|
||||
):
|
||||
log.info("Retrying...")
|
||||
project.check_all_benchmarks()
|
||||
assert True
|
||||
|
||||
@@ -188,7 +188,8 @@ def test_lfc_prewarm_under_workload(neon_simple_env: NeonEnv, query: LfcQueryMet
|
||||
pg_cur.execute("select pg_reload_conf()")
|
||||
|
||||
if query is LfcQueryMethod.COMPUTE_CTL:
|
||||
http_client.prewarm_lfc()
|
||||
# Same thing as prewarm_lfc(), testing other method
|
||||
http_client.prewarm_lfc(endpoint.endpoint_id)
|
||||
else:
|
||||
pg_cur.execute("select prewarm_local_cache(%s)", (lfc_state,))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user