Compare commits

..

3 Commits

Author SHA1 Message Date
Conrad Ludgate
4454662349 use hget instead of hgetall 2025-06-10 07:05:23 -07:00
Conrad Ludgate
2e6acf7bed box the connect future and respect the redis retry methods on err 2025-06-10 06:43:31 -07:00
Conrad Ludgate
6b528426e7 split CancelToken into RawCancelToken for smaller sizes and better typesafety 2025-06-10 06:29:43 -07:00
31 changed files with 170 additions and 375 deletions

16
Cargo.lock generated
View File

@@ -753,7 +753,6 @@ dependencies = [
"axum",
"axum-core",
"bytes",
"form_urlencoded",
"futures-util",
"headers",
"http 1.1.0",
@@ -762,8 +761,6 @@ dependencies = [
"mime",
"pin-project-lite",
"serde",
"serde_html_form",
"serde_path_to_error",
"tower 0.5.2",
"tower-layer",
"tower-service",
@@ -6425,19 +6422,6 @@ dependencies = [
"syn 2.0.100",
]
[[package]]
name = "serde_html_form"
version = "0.2.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9d2de91cf02bbc07cde38891769ccd5d4f073d22a40683aa4bc7a95781aaa2c4"
dependencies = [
"form_urlencoded",
"indexmap 2.9.0",
"itoa",
"ryu",
"serde",
]
[[package]]
name = "serde_json"
version = "1.0.125"

View File

@@ -71,7 +71,7 @@ aws-credential-types = "1.2.0"
aws-sigv4 = { version = "1.2", features = ["sign-http"] }
aws-types = "1.3"
axum = { version = "0.8.1", features = ["ws"] }
axum-extra = { version = "0.10.0", features = ["typed-header", "query"] }
axum-extra = { version = "0.10.0", features = ["typed-header"] }
base64 = "0.13.0"
bincode = "1.3"
bindgen = "0.71"

View File

@@ -785,7 +785,7 @@ impl ComputeNode {
self.spawn_extension_stats_task();
if pspec.spec.autoprewarm {
self.prewarm_lfc(None);
self.prewarm_lfc();
}
Ok(())
}

View File

@@ -25,16 +25,11 @@ struct EndpointStoragePair {
}
const KEY: &str = "lfc_state";
impl EndpointStoragePair {
/// endpoint_id is set to None while prewarming from other endpoint, see replica promotion
/// If not None, takes precedence over pspec.spec.endpoint_id
fn from_spec_and_endpoint(
pspec: &crate::compute::ParsedSpec,
endpoint_id: Option<String>,
) -> Result<Self> {
let endpoint_id = endpoint_id.as_ref().or(pspec.spec.endpoint_id.as_ref());
let Some(ref endpoint_id) = endpoint_id else {
bail!("pspec.endpoint_id missing, other endpoint_id not provided")
impl TryFrom<&crate::compute::ParsedSpec> for EndpointStoragePair {
type Error = anyhow::Error;
fn try_from(pspec: &crate::compute::ParsedSpec) -> Result<Self, Self::Error> {
let Some(ref endpoint_id) = pspec.spec.endpoint_id else {
bail!("pspec.endpoint_id missing")
};
let Some(ref base_uri) = pspec.endpoint_storage_addr else {
bail!("pspec.endpoint_storage_addr missing")
@@ -89,7 +84,7 @@ impl ComputeNode {
}
/// Returns false if there is a prewarm request ongoing, true otherwise
pub fn prewarm_lfc(self: &Arc<Self>, from_endpoint: Option<String>) -> bool {
pub fn prewarm_lfc(self: &Arc<Self>) -> bool {
crate::metrics::LFC_PREWARM_REQUESTS.inc();
{
let state = &mut self.state.lock().unwrap().lfc_prewarm_state;
@@ -102,7 +97,7 @@ impl ComputeNode {
let cloned = self.clone();
spawn(async move {
let Err(err) = cloned.prewarm_impl(from_endpoint).await else {
let Err(err) = cloned.prewarm_impl().await else {
cloned.state.lock().unwrap().lfc_prewarm_state = LfcPrewarmState::Completed;
return;
};
@@ -114,14 +109,13 @@ impl ComputeNode {
true
}
/// from_endpoint: None for endpoint managed by this compute_ctl
fn endpoint_storage_pair(&self, from_endpoint: Option<String>) -> Result<EndpointStoragePair> {
fn endpoint_storage_pair(&self) -> Result<EndpointStoragePair> {
let state = self.state.lock().unwrap();
EndpointStoragePair::from_spec_and_endpoint(state.pspec.as_ref().unwrap(), from_endpoint)
state.pspec.as_ref().unwrap().try_into()
}
async fn prewarm_impl(&self, from_endpoint: Option<String>) -> Result<()> {
let EndpointStoragePair { url, token } = self.endpoint_storage_pair(from_endpoint)?;
async fn prewarm_impl(&self) -> Result<()> {
let EndpointStoragePair { url, token } = self.endpoint_storage_pair()?;
info!(%url, "requesting LFC state from endpoint storage");
let request = Client::new().get(&url).bearer_auth(token);
@@ -179,7 +173,7 @@ impl ComputeNode {
}
async fn offload_lfc_impl(&self) -> Result<()> {
let EndpointStoragePair { url, token } = self.endpoint_storage_pair(None)?;
let EndpointStoragePair { url, token } = self.endpoint_storage_pair()?;
info!(%url, "requesting LFC state from postgres");
let mut compressed = Vec::new();

View File

@@ -2,7 +2,6 @@ use crate::compute_prewarm::LfcPrewarmStateWithProgress;
use crate::http::JsonResponse;
use axum::response::{IntoResponse, Response};
use axum::{Json, http::StatusCode};
use axum_extra::extract::OptionalQuery;
use compute_api::responses::LfcOffloadState;
type Compute = axum::extract::State<std::sync::Arc<crate::compute::ComputeNode>>;
@@ -17,16 +16,8 @@ pub(in crate::http) async fn offload_state(compute: Compute) -> Json<LfcOffloadS
Json(compute.lfc_offload_state())
}
/// Query-string parameters accepted by the `prewarm` HTTP handler.
///
/// The handler extracts this through `OptionalQuery`, so the whole query may be absent.
#[derive(serde::Deserialize)]
pub struct PrewarmQuery {
/// Endpoint id to prewarm from; the handler forwards it to `prewarm_lfc` as `from_endpoint`.
pub from_endpoint: String,
}
pub(in crate::http) async fn prewarm(
compute: Compute,
OptionalQuery(query): OptionalQuery<PrewarmQuery>,
) -> Response {
if compute.prewarm_lfc(query.map(|q| q.from_endpoint)) {
pub(in crate::http) async fn prewarm(compute: Compute) -> Response {
if compute.prewarm_lfc() {
StatusCode::ACCEPTED.into_response()
} else {
JsonResponse::error(

View File

@@ -1,5 +1,3 @@
use std::io;
use tokio::net::TcpStream;
use crate::client::SocketConfig;
@@ -8,7 +6,7 @@ use crate::tls::MakeTlsConnect;
use crate::{Error, cancel_query_raw, connect_socket};
pub(crate) async fn cancel_query<T>(
config: Option<SocketConfig>,
config: SocketConfig,
ssl_mode: SslMode,
tls: T,
process_id: i32,
@@ -17,16 +15,6 @@ pub(crate) async fn cancel_query<T>(
where
T: MakeTlsConnect<TcpStream>,
{
let config = match config {
Some(config) => config,
None => {
return Err(Error::connect(io::Error::new(
io::ErrorKind::InvalidInput,
"unknown host",
)));
}
};
let hostname = match &config.host {
Host::Tcp(host) => &**host,
};

View File

@@ -9,9 +9,16 @@ use crate::{Error, cancel_query, cancel_query_raw};
/// The capability to request cancellation of in-progress queries on a
/// connection.
#[derive(Clone, Serialize, Deserialize)]
#[derive(Clone)]
pub struct CancelToken {
pub socket_config: Option<SocketConfig>,
pub socket_config: SocketConfig,
pub raw: RawCancelToken,
}
/// The capability to request cancellation of in-progress queries on a
/// connection.
#[derive(Clone, Serialize, Deserialize)]
pub struct RawCancelToken {
pub ssl_mode: SslMode,
pub process_id: i32,
pub secret_key: i32,
@@ -36,14 +43,16 @@ impl CancelToken {
{
cancel_query::cancel_query(
self.socket_config.clone(),
self.ssl_mode,
self.raw.ssl_mode,
tls,
self.process_id,
self.secret_key,
self.raw.process_id,
self.raw.secret_key,
)
.await
}
}
impl RawCancelToken {
/// Like `cancel_query`, but uses a stream which is already connected to the server rather than opening a new
/// connection itself.
pub async fn cancel_query_raw<S, T>(&self, stream: S, tls: T) -> Result<(), Error>

View File

@@ -12,6 +12,7 @@ use postgres_protocol2::message::frontend;
use serde::{Deserialize, Serialize};
use tokio::sync::mpsc;
use crate::cancel_token::RawCancelToken;
use crate::codec::{BackendMessages, FrontendMessage};
use crate::config::{Host, SslMode};
use crate::query::RowStream;
@@ -331,10 +332,12 @@ impl Client {
/// connection associated with this client.
pub fn cancel_token(&self) -> CancelToken {
CancelToken {
socket_config: Some(self.socket_config.clone()),
ssl_mode: self.ssl_mode,
process_id: self.process_id,
secret_key: self.secret_key,
socket_config: self.socket_config.clone(),
raw: RawCancelToken {
ssl_mode: self.ssl_mode,
process_id: self.process_id,
secret_key: self.secret_key,
},
}
}

View File

@@ -3,7 +3,7 @@
use postgres_protocol2::message::backend::ReadyForQueryBody;
pub use crate::cancel_token::CancelToken;
pub use crate::cancel_token::{CancelToken, RawCancelToken};
pub use crate::client::{Client, SocketConfig};
pub use crate::config::Config;
pub use crate::connect_raw::RawConnection;

View File

@@ -11,9 +11,6 @@
#include "utils/guc.h"
#include "utils/hsearch.h"
#if PG_MAJORVERSION_NUM > 14
#include "access/xlogrecovery.h"
#endif
typedef struct LastWrittenLsnCacheEntry
@@ -27,20 +24,14 @@ typedef struct LastWrittenLsnCacheEntry
typedef struct LwLsnCacheCtl {
int lastWrittenLsnCacheSize;
/*
* Highest (most recent) last written LSN, for pages not present in
* lastWrittenLsnCache
*/
XLogRecPtr maxLastWrittenLsnData;
* Maximal last written LSN for pages not present in lastWrittenLsnCache
*/
XLogRecPtr maxLastWrittenLsn;
/*
* Maximal last written LSN for metadata, not present in lastWrittenLsnCache
*/
XLogRecPtr maxLastWrittenLsnMetadata;
/*
* Double linked list to implement LRU replacement policy for last written LSN cache.
* Access to this list as well as to last written LSN cache is protected by 'LastWrittenLsnLock'.
*/
* Double linked list to implement LRU replacement policy for last written LSN cache.
* Access to this list as well as to last written LSN cache is protected by 'LastWrittenLsnLock'.
*/
dlist_head lastWrittenLsnLRU;
} LwLsnCacheCtl;
@@ -117,20 +108,19 @@ init_lwlsncache(void)
#else
shmemrequest();
#endif
#define SET_HOOK(name) do { \
prev_##name##_hook = name##_hook; \
name##_hook = neon_##name; \
} while (false)
SET_HOOK(set_lwlsn_block_range);
SET_HOOK(set_lwlsn_block_v);
SET_HOOK(set_lwlsn_block);
SET_HOOK(set_max_lwlsn);
SET_HOOK(set_lwlsn_relation);
SET_HOOK(set_lwlsn_db);
#undef SET_HOOK
prev_set_lwlsn_block_range_hook = set_lwlsn_block_range_hook;
set_lwlsn_block_range_hook = neon_set_lwlsn_block_range;
prev_set_lwlsn_block_v_hook = set_lwlsn_block_v_hook;
set_lwlsn_block_v_hook = neon_set_lwlsn_block_v;
prev_set_lwlsn_block_hook = set_lwlsn_block_hook;
set_lwlsn_block_hook = neon_set_lwlsn_block;
prev_set_max_lwlsn_hook = set_max_lwlsn_hook;
set_max_lwlsn_hook = neon_set_max_lwlsn;
prev_set_lwlsn_relation_hook = set_lwlsn_relation_hook;
set_lwlsn_relation_hook = neon_set_lwlsn_relation;
prev_set_lwlsn_db_hook = set_lwlsn_db_hook;
set_lwlsn_db_hook = neon_set_lwlsn_db;
}
@@ -149,34 +139,24 @@ static void shmemrequest(void) {
static void shmeminit(void) {
static HASHCTL info;
bool found = true;
bool found;
if (lwlsn_cache_size > 0)
{
info.keysize = sizeof(BufferTag);
info.entrysize = sizeof(LastWrittenLsnCacheEntry);
lastWrittenLsnCache = ShmemInitHash("last_written_lsn_cache",
lwlsn_cache_size, lwlsn_cache_size,
&info,
HASH_ELEM | HASH_BLOBS);
LwLsnCache = ShmemInitStruct("neon/LwLsnCacheCtl",
sizeof(LwLsnCacheCtl), &found);
}
/* initialize the shmem struct if we allocated it */
if (!found) {
XLogRecPtr redoPtr;
lwlsn_cache_size, lwlsn_cache_size,
&info,
HASH_ELEM | HASH_BLOBS);
LwLsnCache = ShmemInitStruct("neon/LwLsnCacheCtl", sizeof(LwLsnCacheCtl), &found);
// Now set the size in the struct
LwLsnCache->lastWrittenLsnCacheSize = lwlsn_cache_size;
dlist_init(&LwLsnCache->lastWrittenLsnLRU);
redoPtr = GetRedoRecPtr();
LwLsnCache->maxLastWrittenLsnMetadata = redoPtr;
LwLsnCache->maxLastWrittenLsnData = redoPtr;
if (found) {
return;
}
}
dlist_init(&LwLsnCache->lastWrittenLsnLRU);
LwLsnCache->maxLastWrittenLsn = GetRedoRecPtr();
if (prev_shmem_startup_hook) {
prev_shmem_startup_hook();
}
@@ -200,18 +180,17 @@ neon_get_lwlsn(NRelFileInfo rlocator, ForkNumber forknum, BlockNumber blkno)
LWLockAcquire(LastWrittenLsnLock, LW_SHARED);
if (NInfoGetRelNumber(rlocator) != InvalidOid) /* data page*/
/* Maximal last written LSN among all non-cached pages */
lsn = LwLsnCache->maxLastWrittenLsn;
if (NInfoGetRelNumber(rlocator) != InvalidOid)
{
BufferTag key;
Oid spcOid = NInfoGetSpcOid(rlocator);
Oid dbOid = NInfoGetDbOid(rlocator);
Oid relNumber = NInfoGetRelNumber(rlocator);
BufTagInit(key, relNumber, forknum, blkno, spcOid, dbOid);
/* Maximal last written LSN among all non-cached data pages */
lsn = LwLsnCache->maxLastWrittenLsnData;
entry = hash_search(lastWrittenLsnCache, &key, HASH_FIND, NULL);
if (entry != NULL)
lsn = entry->lsn;
@@ -233,13 +212,9 @@ neon_get_lwlsn(NRelFileInfo rlocator, ForkNumber forknum, BlockNumber blkno)
lsn = SetLastWrittenLSNForBlockRangeInternal(lsn, rlocator, forknum, blkno, 1);
}
}
else /* metadata */
else
{
HASH_SEQ_STATUS seq;
/* Maximal last written LSN for metadata */
lsn = Max(LwLsnCache->maxLastWrittenLsnMetadata,
LwLsnCache->maxLastWrittenLsnData);
/* Find maximum of all cached LSNs */
hash_seq_init(&seq, lastWrittenLsnCache);
while ((entry = (LastWrittenLsnCacheEntry *) hash_seq_search(&seq)) != NULL)
@@ -255,8 +230,7 @@ neon_get_lwlsn(NRelFileInfo rlocator, ForkNumber forknum, BlockNumber blkno)
static void neon_set_max_lwlsn(XLogRecPtr lsn) {
LWLockAcquire(LastWrittenLsnLock, LW_EXCLUSIVE);
LwLsnCache->maxLastWrittenLsnMetadata = lsn;
LwLsnCache->maxLastWrittenLsnData = lsn;
LwLsnCache->maxLastWrittenLsn = lsn;
LWLockRelease(LastWrittenLsnLock);
}
@@ -317,7 +291,7 @@ neon_get_lwlsn_v(NRelFileInfo relfilenode, ForkNumber forknum,
LWLockRelease(LastWrittenLsnLock);
LWLockAcquire(LastWrittenLsnLock, LW_EXCLUSIVE);
lsn = LwLsnCache->maxLastWrittenLsnData;
lsn = LwLsnCache->maxLastWrittenLsn;
for (int i = 0; i < nblocks; i++)
{
@@ -332,8 +306,7 @@ neon_get_lwlsn_v(NRelFileInfo relfilenode, ForkNumber forknum,
else
{
HASH_SEQ_STATUS seq;
Assert(nblocks == 1);
lsn = LwLsnCache->maxLastWrittenLsnMetadata;
lsn = LwLsnCache->maxLastWrittenLsn;
/* Find maximum of all cached LSNs */
hash_seq_init(&seq, lastWrittenLsnCache);
while ((entry = (LastWrittenLsnCacheEntry *) hash_seq_search(&seq)) != NULL)
@@ -361,10 +334,10 @@ SetLastWrittenLSNForBlockRangeInternal(XLogRecPtr lsn,
{
if (NInfoGetRelNumber(rlocator) == InvalidOid)
{
if (lsn > LwLsnCache->maxLastWrittenLsnMetadata)
LwLsnCache->maxLastWrittenLsnMetadata = lsn;
if (lsn > LwLsnCache->maxLastWrittenLsn)
LwLsnCache->maxLastWrittenLsn = lsn;
else
lsn = LwLsnCache->maxLastWrittenLsnMetadata;
lsn = LwLsnCache->maxLastWrittenLsn;
}
else
{
@@ -396,19 +369,10 @@ SetLastWrittenLSNForBlockRangeInternal(XLogRecPtr lsn,
if (hash_get_num_entries(lastWrittenLsnCache) > LwLsnCache->lastWrittenLsnCacheSize)
{
/* Replace least recently used entry */
LastWrittenLsnCacheEntry* victim = NULL;
victim = dlist_container(LastWrittenLsnCacheEntry, lru_node, dlist_pop_head_node(&LwLsnCache->lastWrittenLsnLRU));
while (!XLogRecordReplayFinished(victim->lsn))
{
/* in recovery, we don't allow eviction of entries with the LSN of a record that has yet to be returned */
dlist_push_tail(&LwLsnCache->lastWrittenLsnLRU, &entry->lru_node);
victim = dlist_container(LastWrittenLsnCacheEntry, lru_node, dlist_pop_head_node(&LwLsnCache->lastWrittenLsnLRU));
}
LastWrittenLsnCacheEntry* victim = dlist_container(LastWrittenLsnCacheEntry, lru_node, dlist_pop_head_node(&LwLsnCache->lastWrittenLsnLRU));
/* Adjust max LSN for not cached relations/chunks if needed */
if (victim->lsn > LwLsnCache->maxLastWrittenLsnMetadata)
LwLsnCache->maxLastWrittenLsnMetadata = victim->lsn;
if (victim->lsn > LwLsnCache->maxLastWrittenLsn)
LwLsnCache->maxLastWrittenLsn = victim->lsn;
hash_search(lastWrittenLsnCache, victim, HASH_REMOVE, NULL);
}
@@ -469,13 +433,6 @@ neon_set_lwlsn_block_v(const XLogRecPtr *lsns, NRelFileInfo relfilenode,
Oid dbOid = NInfoGetDbOid(relfilenode);
Oid relNumber = NInfoGetRelNumber(relfilenode);
/*
* We ignore the operation when the input is invalid:
* - we must have gotten LSNs to set
* - we must have pages to write
* - the cache must be enabled
* - we must be processing a data page, not a metadata request
*/
if (lsns == NULL || nblocks == 0 || LwLsnCache->lastWrittenLsnCacheSize == 0 ||
NInfoGetRelNumber(relfilenode) == InvalidOid)
return InvalidXLogRecPtr;
@@ -509,25 +466,10 @@ neon_set_lwlsn_block_v(const XLogRecPtr *lsns, NRelFileInfo relfilenode,
if (hash_get_num_entries(lastWrittenLsnCache) > LwLsnCache->lastWrittenLsnCacheSize)
{
/* Replace least recently used entry */
LastWrittenLsnCacheEntry* victim = dlist_container(LastWrittenLsnCacheEntry, lru_node,
dlist_pop_head_node(&LwLsnCache->lastWrittenLsnLRU));
/*
* If replay is still working on this LSN, we can't evict the
* page. Therefore, we must find a different victim, and return
* the one we just found to the pool.
*/
while (!XLogRecordReplayFinished(victim->lsn))
{
dlist_push_tail(&LwLsnCache->lastWrittenLsnLRU,
&entry->lru_node);
victim = dlist_container(LastWrittenLsnCacheEntry, lru_node,
dlist_pop_head_node(&LwLsnCache->lastWrittenLsnLRU));
}
LastWrittenLsnCacheEntry* victim = dlist_container(LastWrittenLsnCacheEntry, lru_node, dlist_pop_head_node(&LwLsnCache->lastWrittenLsnLRU));
/* Adjust max LSN for not cached relations/chunks if needed */
if (victim->lsn > LwLsnCache->maxLastWrittenLsnData)
LwLsnCache->maxLastWrittenLsnData = victim->lsn;
if (victim->lsn > LwLsnCache->maxLastWrittenLsn)
LwLsnCache->maxLastWrittenLsn = victim->lsn;
hash_search(lastWrittenLsnCache, victim, HASH_REMOVE, NULL);
}

View File

@@ -14,13 +14,12 @@ use serde::{Deserialize, Serialize};
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::{debug, info};
use crate::auth::{self, ComputeUserInfoMaybeEndpoint, validate_password_and_exchange};
use crate::auth::{self, AuthError, ComputeUserInfoMaybeEndpoint, validate_password_and_exchange};
use crate::cache::Cached;
use crate::config::AuthenticationConfig;
use crate::context::RequestContext;
use crate::control_plane::client::ControlPlaneClient;
use crate::control_plane::errors::GetAuthInfoError;
use crate::control_plane::messages::EndpointRateLimitConfig;
use crate::control_plane::{
self, AccessBlockerFlags, AuthSecret, CachedNodeInfo, ControlPlaneApi, EndpointAccessControl,
RoleAccessControl,
@@ -231,8 +230,11 @@ async fn auth_quirks(
config.is_vpc_acccess_proxy,
)?;
access_controls.connection_attempt_rate_limit(ctx, &info.endpoint, &endpoint_rate_limiter)?;
let endpoint = EndpointIdInt::from(&info.endpoint);
let rate_limit_config = None;
if !endpoint_rate_limiter.check(endpoint, rate_limit_config, 1) {
return Err(AuthError::too_many_connections());
}
let role_access = api
.get_role_access_control(ctx, &info.endpoint, &info.user)
.await?;
@@ -399,7 +401,6 @@ impl Backend<'_, ComputeUserInfo> {
allowed_ips: Arc::new(vec![]),
allowed_vpce: Arc::new(vec![]),
flags: AccessBlockerFlags::default(),
rate_limits: EndpointRateLimitConfig::default(),
}),
}
}
@@ -438,7 +439,6 @@ mod tests {
use crate::auth::{ComputeUserInfoMaybeEndpoint, IpPattern};
use crate::config::AuthenticationConfig;
use crate::context::RequestContext;
use crate::control_plane::messages::EndpointRateLimitConfig;
use crate::control_plane::{
self, AccessBlockerFlags, CachedNodeInfo, EndpointAccessControl, RoleAccessControl,
};
@@ -477,7 +477,6 @@ mod tests {
allowed_ips: Arc::new(self.ips.clone()),
allowed_vpce: Arc::new(self.vpc_endpoint_ids.clone()),
flags: self.access_blocker_flags,
rate_limits: EndpointRateLimitConfig::default(),
})
}

View File

@@ -364,7 +364,6 @@ mod tests {
use std::sync::Arc;
use super::*;
use crate::control_plane::messages::EndpointRateLimitConfig;
use crate::control_plane::{AccessBlockerFlags, AuthSecret};
use crate::scram::ServerSecret;
use crate::types::ProjectId;
@@ -400,7 +399,6 @@ mod tests {
allowed_ips: allowed_ips.clone(),
allowed_vpce: Arc::new(vec![]),
flags: AccessBlockerFlags::default(),
rate_limits: EndpointRateLimitConfig::default(),
},
RoleAccessControl {
secret: secret1.clone(),
@@ -416,7 +414,6 @@ mod tests {
allowed_ips: allowed_ips.clone(),
allowed_vpce: Arc::new(vec![]),
flags: AccessBlockerFlags::default(),
rate_limits: EndpointRateLimitConfig::default(),
},
RoleAccessControl {
secret: secret2.clone(),
@@ -442,7 +439,6 @@ mod tests {
allowed_ips: allowed_ips.clone(),
allowed_vpce: Arc::new(vec![]),
flags: AccessBlockerFlags::default(),
rate_limits: EndpointRateLimitConfig::default(),
},
RoleAccessControl {
secret: secret3.clone(),

View File

@@ -3,7 +3,7 @@ use std::sync::Arc;
use anyhow::{Context, anyhow};
use ipnet::{IpNet, Ipv4Net, Ipv6Net};
use postgres_client::CancelToken;
use postgres_client::RawCancelToken;
use postgres_client::tls::MakeTlsConnect;
use redis::{Cmd, FromRedisValue, Value};
use serde::{Deserialize, Serialize};
@@ -33,7 +33,6 @@ const CANCEL_KEY_TTL: i64 = 1_209_600; // 2 weeks cancellation key expire time
pub enum CancelKeyOp {
StoreCancelKey {
key: String,
field: String,
value: String,
resp_tx: Option<oneshot::Sender<anyhow::Result<()>>>,
_guard: CancelChannelSizeGuard<'static>,
@@ -41,7 +40,7 @@ pub enum CancelKeyOp {
},
GetCancelData {
key: String,
resp_tx: oneshot::Sender<anyhow::Result<Vec<(String, String)>>>,
resp_tx: oneshot::Sender<anyhow::Result<String>>,
_guard: CancelChannelSizeGuard<'static>,
},
RemoveCancelKey {
@@ -120,7 +119,6 @@ impl CancelKeyOp {
match self {
CancelKeyOp::StoreCancelKey {
key,
field,
value,
resp_tx,
_guard,
@@ -128,7 +126,7 @@ impl CancelKeyOp {
} => {
let reply =
resp_tx.map(|resp_tx| CancelReplyOp::StoreCancelKey { resp_tx, _guard });
pipe.add_command(Cmd::hset(&key, field, value), reply);
pipe.add_command(Cmd::hset(&key, "data", value), reply);
pipe.add_command_no_reply(Cmd::expire(key, expire));
}
CancelKeyOp::GetCancelData {
@@ -137,7 +135,7 @@ impl CancelKeyOp {
_guard,
} => {
let reply = CancelReplyOp::GetCancelData { resp_tx, _guard };
pipe.add_command_with_reply(Cmd::hgetall(key), reply);
pipe.add_command_with_reply(Cmd::hget(key, "data"), reply);
}
CancelKeyOp::RemoveCancelKey {
key,
@@ -160,7 +158,7 @@ pub enum CancelReplyOp {
_guard: CancelChannelSizeGuard<'static>,
},
GetCancelData {
resp_tx: oneshot::Sender<anyhow::Result<Vec<(String, String)>>>,
resp_tx: oneshot::Sender<anyhow::Result<String>>,
_guard: CancelChannelSizeGuard<'static>,
},
RemoveCancelKey {
@@ -347,7 +345,7 @@ impl CancellationHandler {
_guard: Metrics::get()
.proxy
.cancel_channel_size
.guard(RedisMsgKind::HGetAll),
.guard(RedisMsgKind::HGet),
};
let Some(tx) = &self.tx else {
@@ -366,32 +364,21 @@ impl CancellationHandler {
CancelError::InternalError
})?;
let cancel_state_str: Option<String> = match result {
Ok(mut state) => {
if state.len() == 1 {
Some(state.remove(0).1)
} else {
tracing::warn!("unexpected number of entries in cancel state: {state:?}");
return Err(CancelError::InternalError);
}
}
let cancel_state_str: String = match result {
Ok(s) => s,
Err(e) => {
tracing::warn!("failed to receive cancel state from redis: {e}");
return Err(CancelError::InternalError);
}
};
let cancel_state: Option<CancelClosure> = match cancel_state_str {
Some(state) => {
let cancel_closure: CancelClosure = serde_json::from_str(&state).map_err(|e| {
tracing::warn!("failed to deserialize cancel state: {e}");
CancelError::InternalError
})?;
Some(cancel_closure)
}
None => None,
};
Ok(cancel_state)
let cancel_closure: CancelClosure =
serde_json::from_str(&cancel_state_str).map_err(|e| {
tracing::warn!("failed to deserialize cancel state: {e}");
CancelError::InternalError
})?;
Ok(Some(cancel_closure))
}
/// Try to cancel a running query for the corresponding connection.
/// If the cancellation key is not found, it will be published to Redis.
@@ -470,7 +457,7 @@ impl CancellationHandler {
#[derive(Clone, Serialize, Deserialize)]
pub struct CancelClosure {
socket_addr: SocketAddr,
cancel_token: CancelToken,
cancel_token: RawCancelToken,
hostname: String, // for pg_sni router
user_info: ComputeUserInfo,
}
@@ -478,7 +465,7 @@ pub struct CancelClosure {
impl CancelClosure {
pub(crate) fn new(
socket_addr: SocketAddr,
cancel_token: CancelToken,
cancel_token: RawCancelToken,
hostname: String,
user_info: ComputeUserInfo,
) -> Self {
@@ -538,7 +525,6 @@ impl Session {
let op = CancelKeyOp::StoreCancelKey {
key: self.redis_key.clone(),
field: "data".to_string(),
value: closure_json,
resp_tx: None,
_guard: Metrics::get()

View File

@@ -9,7 +9,7 @@ use itertools::Itertools;
use postgres_client::config::{AuthKeys, SslMode};
use postgres_client::maybe_tls_stream::MaybeTlsStream;
use postgres_client::tls::MakeTlsConnect;
use postgres_client::{CancelToken, NoTls, RawConnection};
use postgres_client::{NoTls, RawCancelToken, RawConnection};
use postgres_protocol::message::backend::NoticeResponseBody;
use thiserror::Error;
use tokio::net::{TcpStream, lookup_host};
@@ -327,8 +327,7 @@ impl ConnectInfo {
// Yet another reason to rework the connection establishing code.
let cancel_closure = CancelClosure::new(
socket_addr,
CancelToken {
socket_config: None,
RawCancelToken {
ssl_mode: self.ssl_mode,
process_id,
secret_key,

View File

@@ -146,7 +146,6 @@ impl NeonControlPlaneClient {
public_access_blocked: block_public_connections,
vpc_access_blocked: block_vpc_connections,
},
rate_limits: body.rate_limits,
})
}
.inspect_err(|e| tracing::debug!(error = ?e))
@@ -313,7 +312,6 @@ impl super::ControlPlaneApi for NeonControlPlaneClient {
allowed_ips: Arc::new(auth_info.allowed_ips),
allowed_vpce: Arc::new(auth_info.allowed_vpc_endpoint_ids),
flags: auth_info.access_blocker_flags,
rate_limits: auth_info.rate_limits,
};
let role_control = RoleAccessControl {
secret: auth_info.secret,
@@ -359,7 +357,6 @@ impl super::ControlPlaneApi for NeonControlPlaneClient {
allowed_ips: Arc::new(auth_info.allowed_ips),
allowed_vpce: Arc::new(auth_info.allowed_vpc_endpoint_ids),
flags: auth_info.access_blocker_flags,
rate_limits: auth_info.rate_limits,
};
let role_control = RoleAccessControl {
secret: auth_info.secret,

View File

@@ -20,7 +20,7 @@ use crate::context::RequestContext;
use crate::control_plane::errors::{
ControlPlaneError, GetAuthInfoError, GetEndpointJwksError, WakeComputeError,
};
use crate::control_plane::messages::{EndpointRateLimitConfig, MetricsAuxInfo};
use crate::control_plane::messages::MetricsAuxInfo;
use crate::control_plane::{
AccessBlockerFlags, AuthInfo, AuthSecret, CachedNodeInfo, EndpointAccessControl, NodeInfo,
RoleAccessControl,
@@ -130,7 +130,6 @@ impl MockControlPlane {
project_id: None,
account_id: None,
access_blocker_flags: AccessBlockerFlags::default(),
rate_limits: EndpointRateLimitConfig::default(),
})
}
@@ -234,7 +233,6 @@ impl super::ControlPlaneApi for MockControlPlane {
allowed_ips: Arc::new(info.allowed_ips),
allowed_vpce: Arc::new(info.allowed_vpc_endpoint_ids),
flags: info.access_blocker_flags,
rate_limits: info.rate_limits,
})
}

View File

@@ -10,7 +10,6 @@ use clashmap::ClashMap;
use tokio::time::Instant;
use tracing::{debug, info};
use super::{EndpointAccessControl, RoleAccessControl};
use crate::auth::backend::ComputeUserInfo;
use crate::auth::backend::jwt::{AuthRule, FetchAuthRules, FetchAuthRulesError};
use crate::cache::endpoints::EndpointsCache;
@@ -23,6 +22,8 @@ use crate::metrics::ApiLockMetrics;
use crate::rate_limiter::{DynamicLimiter, Outcome, RateLimiterConfig, Token};
use crate::types::EndpointId;
use super::{EndpointAccessControl, RoleAccessControl};
#[non_exhaustive]
#[derive(Clone)]
pub enum ControlPlaneClient {

View File

@@ -227,35 +227,12 @@ pub(crate) struct UserFacingMessage {
#[derive(Deserialize)]
pub(crate) struct GetEndpointAccessControl {
pub(crate) role_secret: Box<str>,
pub(crate) project_id: Option<ProjectIdInt>,
pub(crate) account_id: Option<AccountIdInt>,
pub(crate) allowed_ips: Option<Vec<IpPattern>>,
pub(crate) allowed_vpc_endpoint_ids: Option<Vec<String>>,
pub(crate) project_id: Option<ProjectIdInt>,
pub(crate) account_id: Option<AccountIdInt>,
pub(crate) block_public_connections: Option<bool>,
pub(crate) block_vpc_connections: Option<bool>,
#[serde(default)]
pub(crate) rate_limits: EndpointRateLimitConfig,
}
/// Rate limit settings attached to a single endpoint.
///
/// `Default` is derived and `GetEndpointAccessControl::rate_limits` is marked
/// `#[serde(default)]`, so a response without this object deserializes to the default value.
#[derive(Copy, Clone, Deserialize, Default)]
pub struct EndpointRateLimitConfig {
/// Limits applied to connection attempts, split by client protocol.
pub connection_attempts: ConnectionAttemptsLimit,
}
/// Connection-attempt limits, one optional setting per client protocol.
///
/// `connection_attempt_rate_limit` picks the field matching the request's protocol;
/// a `None` field means no endpoint-specific bucket config is handed to the rate limiter.
#[derive(Copy, Clone, Deserialize, Default)]
pub struct ConnectionAttemptsLimit {
/// Setting used for `Protocol::Tcp` requests.
pub tcp: Option<LeakyBucketSetting>,
/// Setting used for `Protocol::Ws` requests.
pub ws: Option<LeakyBucketSetting>,
/// Setting used for `Protocol::Http` requests.
pub http: Option<LeakyBucketSetting>,
}
/// Leaky-bucket parameters as received from the control plane.
///
/// `connection_attempt_rate_limit` ignores a setting whose `rps` or `burst` is `<= 0.0`;
/// otherwise the pair is converted with `LeakyBucketConfig::new(rps, burst)`.
#[derive(Copy, Clone, Deserialize)]
pub struct LeakyBucketSetting {
/// Passed as the `rps` argument of `LeakyBucketConfig::new` (presumably requests per second — confirm).
pub rps: f64,
/// Passed as the `max` argument of `LeakyBucketConfig::new`.
pub burst: f64,
}
/// Response which holds compute node's `host:port` pair.

View File

@@ -11,8 +11,6 @@ pub(crate) mod errors;
use std::sync::Arc;
use messages::EndpointRateLimitConfig;
use crate::auth::backend::ComputeUserInfo;
use crate::auth::backend::jwt::AuthRule;
use crate::auth::{AuthError, IpPattern, check_peer_addr_is_in_list};
@@ -20,9 +18,8 @@ use crate::cache::{Cached, TimedLru};
use crate::config::ComputeConfig;
use crate::context::RequestContext;
use crate::control_plane::messages::{ControlPlaneErrorMessage, MetricsAuxInfo};
use crate::intern::{AccountIdInt, EndpointIdInt, ProjectIdInt};
use crate::intern::{AccountIdInt, ProjectIdInt};
use crate::protocol2::ConnectionInfoExtra;
use crate::rate_limiter::{EndpointRateLimiter, LeakyBucketConfig};
use crate::types::{EndpointCacheKey, EndpointId, RoleName};
use crate::{compute, scram};
@@ -59,8 +56,6 @@ pub(crate) struct AuthInfo {
pub(crate) account_id: Option<AccountIdInt>,
/// Are public connections or VPC connections blocked?
pub(crate) access_blocker_flags: AccessBlockerFlags,
/// The rate limits for this endpoint.
pub(crate) rate_limits: EndpointRateLimitConfig,
}
/// Info for establishing a connection to a compute node.
@@ -106,8 +101,6 @@ pub struct EndpointAccessControl {
pub allowed_ips: Arc<Vec<IpPattern>>,
pub allowed_vpce: Arc<Vec<String>>,
pub flags: AccessBlockerFlags,
pub rate_limits: EndpointRateLimitConfig,
}
impl EndpointAccessControl {
@@ -146,36 +139,6 @@ impl EndpointAccessControl {
Ok(())
}
/// Applies this endpoint's connection-attempt rate limit to one incoming attempt.
///
/// The limit setting is chosen by the request's protocol (`http`, `ws` or `tcp`);
/// `SniRouter` requests are never limited here. A missing setting, or one whose
/// `rps` or `burst` is `<= 0.0`, results in `None` being passed to the limiter as
/// the bucket config.
///
/// # Errors
///
/// Returns `AuthError::too_many_connections()` when `rate_limiter.check` rejects the attempt.
pub fn connection_attempt_rate_limit(
    &self,
    ctx: &RequestContext,
    endpoint: &EndpointId,
    rate_limiter: &EndpointRateLimiter,
) -> Result<(), AuthError> {
    // Converted up front, before the protocol is inspected.
    let endpoint_key = EndpointIdInt::from(endpoint);
    let attempts = &self.rate_limits.connection_attempts;

    let setting = match ctx.protocol() {
        crate::metrics::Protocol::SniRouter => return Ok(()),
        crate::metrics::Protocol::Tcp => attempts.tcp,
        crate::metrics::Protocol::Ws => attempts.ws,
        crate::metrics::Protocol::Http => attempts.http,
    };

    // `LeakyBucketConfig::new` asserts positive arguments, so non-positive
    // settings are dropped instead of being converted.
    let bucket = setting.and_then(|s| {
        let disabled = s.rps <= 0.0 || s.burst <= 0.0;
        (!disabled).then(|| LeakyBucketConfig::new(s.rps, s.burst))
    });

    if rate_limiter.check(endpoint_key, bucket, 1) {
        Ok(())
    } else {
        Err(AuthError::too_many_connections())
    }
}
}
/// This will allocate per each call, but the http requests alone

View File

@@ -69,8 +69,9 @@ pub struct LeakyBucketConfig {
pub max: f64,
}
#[cfg(test)]
impl LeakyBucketConfig {
pub fn new(rps: f64, max: f64) -> Self {
pub(crate) fn new(rps: f64, max: f64) -> Self {
assert!(rps > 0.0, "rps must be positive");
assert!(max > 0.0, "max must be positive");
Self { rps, max }

View File

@@ -12,10 +12,11 @@ use rand::{Rng, SeedableRng};
use tokio::time::{Duration, Instant};
use tracing::info;
use super::LeakyBucketConfig;
use crate::ext::LockExt;
use crate::intern::EndpointIdInt;
use super::LeakyBucketConfig;
pub struct GlobalRateLimiter {
data: Vec<RateBucket>,
info: Vec<RateBucketInfo>,

View File

@@ -1,3 +1,6 @@
use std::time::Duration;
use futures::FutureExt;
use redis::aio::ConnectionLike;
use redis::{Cmd, FromRedisValue, Pipeline, RedisResult};
@@ -35,14 +38,11 @@ impl RedisKVClient {
}
pub async fn try_connect(&mut self) -> anyhow::Result<()> {
match self.client.connect().await {
Ok(()) => {}
Err(e) => {
tracing::error!("failed to connect to redis: {e}");
return Err(e);
}
}
Ok(())
self.client
.connect()
.boxed()
.await
.inspect_err(|e| tracing::error!("failed to connect to redis: {e}"))
}
pub(crate) async fn query<T: FromRedisValue>(
@@ -54,15 +54,25 @@ impl RedisKVClient {
return Err(anyhow::anyhow!("Rate limit exceeded"));
}
match q.query(&mut self.client).await {
let e = match q.query(&mut self.client).await {
Ok(t) => return Ok(t),
Err(e) => {
tracing::error!("failed to run query: {e}");
Err(e) => e,
};
tracing::error!("failed to run query: {e}");
match e.retry_method() {
redis::RetryMethod::Reconnect => {
tracing::info!("Redis client is disconnected. Reconnecting...");
self.try_connect().await?;
}
redis::RetryMethod::RetryImmediately => {}
redis::RetryMethod::WaitAndRetry => {
// somewhat arbitrary.
tokio::time::sleep(Duration::from_millis(100)).await;
}
_ => Err(e)?,
}
tracing::info!("Redis client is disconnected. Reconnecting...");
self.try_connect().await?;
Ok(q.query(&mut self.client).await?)
}
}

View File

@@ -68,20 +68,17 @@ impl PoolingBackend {
self.config.authentication_config.is_vpc_acccess_proxy,
)?;
access_control.connection_attempt_rate_limit(
ctx,
&user_info.endpoint,
&self.endpoint_rate_limiter,
)?;
let ep = EndpointIdInt::from(&user_info.endpoint);
let rate_limit_config = None;
if !self.endpoint_rate_limiter.check(ep, rate_limit_config, 1) {
return Err(AuthError::too_many_connections());
}
let role_access = backend.get_role_secret(ctx).await?;
let Some(secret) = role_access.secret else {
// If we don't have an authentication secret, for the http flow we can just return an error.
info!("authentication info not found");
return Err(AuthError::password_failed(&*user_info.user));
};
let ep = EndpointIdInt::from(&user_info.endpoint);
let auth_outcome = crate::auth::validate_password_and_exchange(
&self.config.authentication_config.thread_pool,
ep,

View File

@@ -69,10 +69,8 @@ class EndpointHttpClient(requests.Session):
json: dict[str, str] = res.json()
return json
def prewarm_lfc(self, from_endpoint_id: str | None = None):
url: str = f"http://localhost:{self.external_port}/lfc/prewarm"
params = {"from_endpoint": from_endpoint_id} if from_endpoint_id else dict()
self.post(url, params=params).raise_for_status()
def prewarm_lfc(self):
self.post(f"http://localhost:{self.external_port}/lfc/prewarm").raise_for_status()
def prewarmed():
json = self.prewarm_lfc_status()

View File

@@ -129,18 +129,6 @@ class NeonAPI:
return cast("dict[str, Any]", resp.json())
def get_project_limits(self, project_id: str) -> dict[str, Any]:
resp = self.__request(
"GET",
f"/projects/{project_id}/limits",
headers={
"Accept": "application/json",
"Content-Type": "application/json",
},
)
return cast("dict[str, Any]", resp.json())
def delete_project(
self,
project_id: str,

View File

@@ -45,8 +45,6 @@ class NeonEndpoint:
if self.branch.connect_env:
self.connect_env = self.branch.connect_env.copy()
self.connect_env["PGHOST"] = self.host
if self.type == "read_only":
self.project.read_only_endpoints_total += 1
def delete(self):
self.project.delete_endpoint(self.id)
@@ -230,13 +228,8 @@ class NeonProject:
self.benchmarks: dict[str, subprocess.Popen[Any]] = {}
self.restore_num: int = 0
self.restart_pgbench_on_console_errors: bool = False
self.limits: dict[str, Any] = self.get_limits()["limits"]
self.read_only_endpoints_total: int = 0
def get_limits(self) -> dict[str, Any]:
return self.neon_api.get_project_limits(self.id)
def delete(self) -> None:
def delete(self):
self.neon_api.delete_project(self.id)
def create_branch(self, parent_id: str | None = None) -> NeonBranch | None:
@@ -289,7 +282,6 @@ class NeonProject:
self.neon_api.delete_endpoint(self.id, endpoint_id)
self.endpoints[endpoint_id].branch.endpoints.pop(endpoint_id)
self.endpoints.pop(endpoint_id)
self.read_only_endpoints_total -= 1
self.wait()
def start_benchmark(self, target: str, clients: int = 10) -> subprocess.Popen[Any]:
@@ -377,64 +369,49 @@ def setup_class(
print(f"::warning::Retried on 524 error {neon_api.retries524} times")
if neon_api.retries4xx > 0:
print(f"::warning::Retried on 4xx error {neon_api.retries4xx} times")
log.info("Removing the project %s", project.id)
log.info("Removing the project")
project.delete()
def do_action(project: NeonProject, action: str) -> bool:
def do_action(project: NeonProject, action: str) -> None:
"""
Runs the action
"""
log.info("Action: %s", action)
if action == "new_branch":
log.info("Trying to create a new branch")
if 0 <= project.limits["max_branches"] <= len(project.branches):
log.info(
"Maximum branch limit exceeded (%s of %s)",
len(project.branches),
project.limits["max_branches"],
)
return False
parent = project.branches[
random.choice(list(set(project.branches.keys()) - project.reset_branches))
]
log.info("Parent: %s", parent)
child = parent.create_child_branch()
if child is None:
return False
return
log.info("Created branch %s", child)
child.start_benchmark()
elif action == "delete_branch":
if project.leaf_branches:
target: NeonBranch = random.choice(list(project.leaf_branches.values()))
target = random.choice(list(project.leaf_branches.values()))
log.info("Trying to delete branch %s", target)
target.delete()
else:
log.info("Leaf branches not found, skipping")
return False
elif action == "new_ro_endpoint":
if 0 <= project.limits["max_read_only_endpoints"] <= project.read_only_endpoints_total:
log.info(
"Maximum read only endpoint limit exceeded (%s of %s)",
project.read_only_endpoints_total,
project.limits["max_read_only_endpoints"],
)
return False
ep = random.choice(
[br for br in project.branches.values() if br.id not in project.reset_branches]
).create_ro_endpoint()
log.info("Created the RO endpoint with id %s branch: %s", ep.id, ep.branch.id)
ep.start_benchmark()
elif action == "delete_ro_endpoint":
if project.read_only_endpoints_total == 0:
log.info("no read_only endpoints present, skipping")
return False
ro_endpoints: list[NeonEndpoint] = [
endpoint for endpoint in project.endpoints.values() if endpoint.type == "read_only"
]
target_ep: NeonEndpoint = random.choice(ro_endpoints)
target_ep.delete()
log.info("endpoint %s deleted", target_ep.id)
if ro_endpoints:
target_ep: NeonEndpoint = random.choice(ro_endpoints)
target_ep.delete()
log.info("endpoint %s deleted", target_ep.id)
else:
log.info("no read_only endpoints present, skipping")
elif action == "restore_random_time":
if project.leaf_branches:
br: NeonBranch = random.choice(list(project.leaf_branches.values()))
@@ -442,10 +419,8 @@ def do_action(project: NeonProject, action: str) -> bool:
br.restore_random_time()
else:
log.info("No leaf branches found")
return False
else:
raise ValueError(f"The action {action} is unknown")
return True
@pytest.mark.timeout(7200)
@@ -482,9 +457,8 @@ def test_api_random(
pg_bin.run(["pgbench", "-i", "-I", "dtGvp", "-s100"], env=project.main_branch.connect_env)
for _ in range(num_operations):
log.info("Starting action #%s", _ + 1)
while not do_action(
do_action(
project, random.choices([a[0] for a in ACTIONS], weights=[w[1] for w in ACTIONS])[0]
):
log.info("Retrying...")
)
project.check_all_benchmarks()
assert True

View File

@@ -188,8 +188,7 @@ def test_lfc_prewarm_under_workload(neon_simple_env: NeonEnv, query: LfcQueryMet
pg_cur.execute("select pg_reload_conf()")
if query is LfcQueryMethod.COMPUTE_CTL:
# Same thing as prewarm_lfc(), testing other method
http_client.prewarm_lfc(endpoint.endpoint_id)
http_client.prewarm_lfc()
else:
pg_cur.execute("select prewarm_local_cache(%s)", (lfc_state,))