mirror of
https://github.com/neondatabase/neon.git
synced 2026-02-05 19:50:36 +00:00
Compare commits
8 Commits
release-co
...
conrad/mem
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
cfb2e3c178 | ||
|
|
0a34084ba5 | ||
|
|
b33047df7e | ||
|
|
1c5477619f | ||
|
|
40f5b3e8df | ||
|
|
791b5d736b | ||
|
|
96bcfba79e | ||
|
|
8e95455aef |
@@ -30,6 +30,7 @@ workspace-members = [
|
||||
"vm_monitor",
|
||||
# All of these exist in libs and are not usually built independently.
|
||||
# Putting workspace hack there adds a bottleneck for cargo builds.
|
||||
"alloc-metrics",
|
||||
"compute_api",
|
||||
"consumption_metrics",
|
||||
"desim",
|
||||
|
||||
18
Cargo.lock
generated
18
Cargo.lock
generated
@@ -61,6 +61,17 @@ dependencies = [
|
||||
"equator",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "alloc-metrics"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"criterion",
|
||||
"measured",
|
||||
"metrics",
|
||||
"thread_local",
|
||||
"tikv-jemallocator",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "allocator-api2"
|
||||
version = "0.2.16"
|
||||
@@ -5301,6 +5312,7 @@ name = "proxy"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"ahash",
|
||||
"alloc-metrics",
|
||||
"anyhow",
|
||||
"arc-swap",
|
||||
"assert-json-diff",
|
||||
@@ -7332,12 +7344,10 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "thread_local"
|
||||
version = "1.1.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3fdd6f064ccff2d6567adcb3873ca630700f00b5ad3f060c25b5dcfd9a4ce152"
|
||||
version = "1.1.9"
|
||||
source = "git+https://github.com/conradludgate/thread_local-rs?branch=no-tls-destructor-get#f9ca3d375745c14a632ae3ffe6a7a646dc8421a0"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"once_cell",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
||||
@@ -195,6 +195,7 @@ sync_wrapper = "0.1.2"
|
||||
tar = "0.4"
|
||||
test-context = "0.3"
|
||||
thiserror = "1.0"
|
||||
thread_local = "1.1.9"
|
||||
tikv-jemallocator = { version = "0.6", features = ["profiling", "stats", "unprefixed_malloc_on_supported_platforms"] }
|
||||
tikv-jemalloc-ctl = { version = "0.6", features = ["stats"] }
|
||||
tokio = { version = "1.43.1", features = ["macros"] }
|
||||
@@ -253,6 +254,7 @@ azure_storage = { git = "https://github.com/neondatabase/azure-sdk-for-rust.git"
|
||||
azure_storage_blobs = { git = "https://github.com/neondatabase/azure-sdk-for-rust.git", branch = "neon", default-features = false, features = ["enable_reqwest_rustls"] }
|
||||
|
||||
## Local libraries
|
||||
alloc-metrics = { version = "0.1", path = "./libs/alloc-metrics/" }
|
||||
compute_api = { version = "0.1", path = "./libs/compute_api/" }
|
||||
consumption_metrics = { version = "0.1", path = "./libs/consumption_metrics/" }
|
||||
desim = { version = "0.1", path = "./libs/desim" }
|
||||
@@ -302,6 +304,9 @@ tonic-build = "0.13.1"
|
||||
# Needed to get `tokio-postgres-rustls` to depend on our fork.
|
||||
tokio-postgres = { git = "https://github.com/neondatabase/rust-postgres.git", branch = "neon" }
|
||||
|
||||
# Needed to fix a bug in alloc-metrics
|
||||
thread_local = { git = "https://github.com/conradludgate/thread_local-rs", branch = "no-tls-destructor-get" }
|
||||
|
||||
################# Binary contents sections
|
||||
|
||||
[profile.release]
|
||||
|
||||
@@ -8,10 +8,10 @@ code changes locally, but not suitable for running production systems.
|
||||
|
||||
## Example: Start with Postgres 16
|
||||
|
||||
To create and start a local development environment with Postgres 16, you will need to provide `--pg-version` flag to 3 of the start-up commands.
|
||||
To create and start a local development environment with Postgres 16, you will need to provide `--pg-version` flag to 2 of the start-up commands.
|
||||
|
||||
```shell
|
||||
cargo neon init --pg-version 16
|
||||
cargo neon init
|
||||
cargo neon start
|
||||
cargo neon tenant create --set-default --pg-version 16
|
||||
cargo neon endpoint create main --pg-version 16
|
||||
|
||||
18
libs/alloc-metrics/Cargo.toml
Normal file
18
libs/alloc-metrics/Cargo.toml
Normal file
@@ -0,0 +1,18 @@
|
||||
[package]
|
||||
name = "alloc-metrics"
|
||||
version = "0.1.0"
|
||||
edition.workspace = true
|
||||
license.workspace = true
|
||||
|
||||
[dependencies]
|
||||
metrics.workspace = true
|
||||
measured.workspace = true
|
||||
thread_local.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
criterion.workspace = true
|
||||
tikv-jemallocator.workspace = true
|
||||
|
||||
[[bench]]
|
||||
harness = false
|
||||
name = "alloc"
|
||||
110
libs/alloc-metrics/benches/alloc.rs
Normal file
110
libs/alloc-metrics/benches/alloc.rs
Normal file
@@ -0,0 +1,110 @@
|
||||
use std::alloc::{GlobalAlloc, Layout, System, handle_alloc_error};
|
||||
|
||||
use alloc_metrics::TrackedAllocator;
|
||||
use criterion::{
|
||||
AxisScale, BenchmarkGroup, BenchmarkId, Criterion, PlotConfiguration, measurement::Measurement,
|
||||
};
|
||||
use measured::FixedCardinalityLabel;
|
||||
use tikv_jemallocator::Jemalloc;
|
||||
|
||||
fn main() {
|
||||
let mut c = Criterion::default().configure_from_args();
|
||||
bench(&mut c);
|
||||
c.final_summary();
|
||||
}
|
||||
|
||||
#[rustfmt::skip]
|
||||
fn bench(c: &mut Criterion) {
|
||||
bench_alloc(c.benchmark_group("alloc/system"), &System, &ALLOC_SYSTEM);
|
||||
bench_alloc(c.benchmark_group("alloc/jemalloc"), &Jemalloc, &ALLOC_JEMALLOC);
|
||||
|
||||
bench_dealloc(c.benchmark_group("dealloc/system"), &System, &ALLOC_SYSTEM);
|
||||
bench_dealloc(c.benchmark_group("dealloc/jemalloc"), &Jemalloc, &ALLOC_JEMALLOC);
|
||||
}
|
||||
|
||||
#[derive(FixedCardinalityLabel, Clone, Copy, Debug)]
|
||||
#[label(singleton = "memory_context")]
|
||||
pub enum MemoryContext {
|
||||
Root,
|
||||
Test,
|
||||
}
|
||||
|
||||
static ALLOC_SYSTEM: TrackedAllocator<System, MemoryContext> =
|
||||
unsafe { TrackedAllocator::new(System, MemoryContext::Root) };
|
||||
static ALLOC_JEMALLOC: TrackedAllocator<Jemalloc, MemoryContext> =
|
||||
unsafe { TrackedAllocator::new(Jemalloc, MemoryContext::Root) };
|
||||
|
||||
const KB: u64 = 1024;
|
||||
const SIZES: [u64; 6] = [64, 256, KB, 4 * KB, 16 * KB, KB * KB];
|
||||
|
||||
fn bench_alloc<A: GlobalAlloc>(
|
||||
mut g: BenchmarkGroup<'_, impl Measurement>,
|
||||
alloc1: &'static A,
|
||||
alloc2: &'static TrackedAllocator<A, MemoryContext>,
|
||||
) {
|
||||
g.plot_config(PlotConfiguration::default().summary_scale(AxisScale::Logarithmic));
|
||||
for size in SIZES {
|
||||
let layout = Layout::from_size_align(size as usize, 8).unwrap();
|
||||
|
||||
g.throughput(criterion::Throughput::Bytes(size));
|
||||
g.bench_with_input(BenchmarkId::new("default", size), &layout, |b, &layout| {
|
||||
let bs = criterion::BatchSize::NumBatches(10 + size.ilog2() as u64);
|
||||
b.iter_batched(|| {}, |()| Alloc::new(alloc1, layout), bs);
|
||||
});
|
||||
g.bench_with_input(BenchmarkId::new("tracked", size), &layout, |b, &layout| {
|
||||
let _scope = alloc2.scope(MemoryContext::Test);
|
||||
|
||||
let bs = criterion::BatchSize::NumBatches(10 + size.ilog2() as u64);
|
||||
b.iter_batched(|| {}, |()| Alloc::new(alloc2, layout), bs);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
fn bench_dealloc<A: GlobalAlloc>(
|
||||
mut g: BenchmarkGroup<'_, impl Measurement>,
|
||||
alloc1: &'static A,
|
||||
alloc2: &'static TrackedAllocator<A, MemoryContext>,
|
||||
) {
|
||||
g.plot_config(PlotConfiguration::default().summary_scale(AxisScale::Logarithmic));
|
||||
for size in SIZES {
|
||||
let layout = Layout::from_size_align(size as usize, 8).unwrap();
|
||||
|
||||
g.throughput(criterion::Throughput::Bytes(size));
|
||||
g.bench_with_input(BenchmarkId::new("default", size), &layout, |b, &layout| {
|
||||
let bs = criterion::BatchSize::NumBatches(10 + size.ilog2() as u64);
|
||||
b.iter_batched(|| Alloc::new(alloc1, layout), drop, bs);
|
||||
});
|
||||
g.bench_with_input(BenchmarkId::new("tracked", size), &layout, |b, &layout| {
|
||||
let _scope = alloc2.scope(MemoryContext::Test);
|
||||
|
||||
let bs = criterion::BatchSize::NumBatches(10 + size.ilog2() as u64);
|
||||
b.iter_batched(|| Alloc::new(alloc2, layout), drop, bs);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
struct Alloc<'a, A: GlobalAlloc> {
|
||||
alloc: &'a A,
|
||||
ptr: *mut u8,
|
||||
layout: Layout,
|
||||
}
|
||||
|
||||
impl<'a, A: GlobalAlloc> Alloc<'a, A> {
|
||||
fn new(alloc: &'a A, layout: Layout) -> Self {
|
||||
let ptr = unsafe { alloc.alloc(layout) };
|
||||
if ptr.is_null() {
|
||||
handle_alloc_error(layout);
|
||||
}
|
||||
|
||||
// actually make the page resident.
|
||||
unsafe { ptr.cast::<u8>().write(1) };
|
||||
|
||||
Self { alloc, ptr, layout }
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, A: GlobalAlloc> Drop for Alloc<'a, A> {
|
||||
fn drop(&mut self) {
|
||||
unsafe { self.alloc.dealloc(self.ptr, self.layout) };
|
||||
}
|
||||
}
|
||||
48
libs/alloc-metrics/src/counters.rs
Normal file
48
libs/alloc-metrics/src/counters.rs
Normal file
@@ -0,0 +1,48 @@
|
||||
use std::marker::PhantomData;
|
||||
|
||||
use measured::{
|
||||
FixedCardinalityLabel, LabelGroup, label::StaticLabelSet, metric::MetricFamilyEncoding,
|
||||
};
|
||||
use metrics::{CounterPairAssoc, Dec, Inc, MeasuredCounterPairState};
|
||||
|
||||
use crate::metric_vec::DenseMetricVec;
|
||||
|
||||
pub struct DenseCounterPairVec<
|
||||
A: CounterPairAssoc<LabelGroupSet = StaticLabelSet<L>>,
|
||||
L: FixedCardinalityLabel + LabelGroup,
|
||||
> {
|
||||
pub vec: DenseMetricVec<MeasuredCounterPairState, L>,
|
||||
pub _marker: PhantomData<A>,
|
||||
}
|
||||
|
||||
impl<A: CounterPairAssoc<LabelGroupSet = StaticLabelSet<L>>, L: FixedCardinalityLabel + LabelGroup>
|
||||
DenseCounterPairVec<A, L>
|
||||
{
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
vec: DenseMetricVec::new(),
|
||||
_marker: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, A, L> ::measured::metric::group::MetricGroup<T> for DenseCounterPairVec<A, L>
|
||||
where
|
||||
T: ::measured::metric::group::Encoding,
|
||||
::measured::metric::counter::CounterState: ::measured::metric::MetricEncoding<T>,
|
||||
A: CounterPairAssoc<LabelGroupSet = StaticLabelSet<L>>,
|
||||
L: FixedCardinalityLabel + LabelGroup,
|
||||
{
|
||||
fn collect_group_into(&self, enc: &mut T) -> Result<(), T::Err> {
|
||||
// write decrement first to avoid a race condition where inc - dec < 0
|
||||
T::write_help(enc, A::DEC_NAME, A::DEC_HELP)?;
|
||||
self.vec
|
||||
.collect_family_into(A::DEC_NAME, &mut Dec(&mut *enc))?;
|
||||
|
||||
T::write_help(enc, A::INC_NAME, A::INC_HELP)?;
|
||||
self.vec
|
||||
.collect_family_into(A::INC_NAME, &mut Inc(&mut *enc))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
441
libs/alloc-metrics/src/lib.rs
Normal file
441
libs/alloc-metrics/src/lib.rs
Normal file
@@ -0,0 +1,441 @@
|
||||
//! Tagged allocator measurements.
|
||||
|
||||
mod counters;
|
||||
mod metric_vec;
|
||||
|
||||
use std::{
|
||||
alloc::{GlobalAlloc, Layout},
|
||||
cell::Cell,
|
||||
marker::PhantomData,
|
||||
sync::{
|
||||
OnceLock,
|
||||
atomic::{AtomicU64, Ordering::Relaxed},
|
||||
},
|
||||
};
|
||||
|
||||
use measured::{
|
||||
FixedCardinalityLabel, LabelGroup, MetricGroup,
|
||||
label::StaticLabelSet,
|
||||
metric::{MetricEncoding, counter::CounterState, group::Encoding, name::MetricName},
|
||||
};
|
||||
use metrics::{CounterPairAssoc, MeasuredCounterPairState};
|
||||
use thread_local::ThreadLocal;
|
||||
|
||||
type AllocCounter<T> = counters::DenseCounterPairVec<AllocPair<T>, T>;
|
||||
|
||||
pub struct TrackedAllocator<A, T: 'static + Send + Sync + FixedCardinalityLabel + LabelGroup> {
|
||||
inner: A,
|
||||
|
||||
/// potentially high-content fallback if the thread was not registered.
|
||||
default_counters: MeasuredCounterPairState,
|
||||
/// Default tag to use if this thread is not registered.
|
||||
default_tag: T,
|
||||
|
||||
thread: OnceLock<RegisteredThread<T>>,
|
||||
|
||||
/// where thread alloc data is eventually saved to, even if threads are shutdown.
|
||||
global: OnceLock<AllocCounter<T>>,
|
||||
}
|
||||
|
||||
impl<A, T> TrackedAllocator<A, T>
|
||||
where
|
||||
T: 'static + Send + Sync + FixedCardinalityLabel + LabelGroup,
|
||||
{
|
||||
/// # Safety
|
||||
///
|
||||
/// [`FixedCardinalityLabel`] must be implemented correctly, fully dense, and must not panic.
|
||||
pub const unsafe fn new(alloc: A, default: T) -> Self {
|
||||
TrackedAllocator {
|
||||
inner: alloc,
|
||||
default_tag: default,
|
||||
default_counters: MeasuredCounterPairState {
|
||||
inc: CounterState {
|
||||
count: AtomicU64::new(0),
|
||||
},
|
||||
dec: CounterState {
|
||||
count: AtomicU64::new(0),
|
||||
},
|
||||
},
|
||||
thread: OnceLock::new(),
|
||||
global: OnceLock::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Allocations
|
||||
pub fn register_thread(&'static self) {
|
||||
self.register_thread_inner();
|
||||
}
|
||||
|
||||
pub fn scope(&'static self, tag: T) -> AllocScope<'static, T> {
|
||||
let cell = self.register_thread_inner();
|
||||
let last = cell.replace(tag);
|
||||
AllocScope { cell, last }
|
||||
}
|
||||
|
||||
fn register_thread_inner(&'static self) -> &'static Cell<T> {
|
||||
let thread = self.thread.get_or_init(|| RegisteredThread {
|
||||
scope: ThreadLocal::new(),
|
||||
state: ThreadLocal::new(),
|
||||
});
|
||||
|
||||
thread.state.get_or(|| ThreadState {
|
||||
counters: AllocCounter::new(),
|
||||
global: self.global.get_or_init(AllocCounter::new),
|
||||
});
|
||||
|
||||
thread.scope.get_or(|| Cell::new(self.default_tag))
|
||||
}
|
||||
}
|
||||
|
||||
macro_rules! alloc {
|
||||
($alloc_fn:ident) => {
|
||||
unsafe fn $alloc_fn(&self, layout: Layout) -> *mut u8 {
|
||||
let Ok((tagged_layout, tag_offset)) = layout.extend(Layout::new::<T>()) else {
|
||||
return std::ptr::null_mut();
|
||||
};
|
||||
let tagged_layout = tagged_layout.pad_to_align();
|
||||
|
||||
// Safety: The layout is not zero-sized.
|
||||
let ptr = unsafe { self.inner.$alloc_fn(tagged_layout) };
|
||||
|
||||
// allocation failed.
|
||||
if ptr.is_null() {
|
||||
return ptr;
|
||||
}
|
||||
|
||||
// We are being very careful here to not allocate or panic.
|
||||
let thread = self.thread.get().map(|s| (s.scope.get(), s.state.get()));
|
||||
let tag = thread.and_then(|t| t.0).map_or(self.default_tag, Cell::get);
|
||||
|
||||
// Allocation successful. Write our tag
|
||||
// Safety: tag_offset is inbounds of the ptr
|
||||
unsafe { ptr.add(tag_offset).cast::<T>().write(tag) }
|
||||
|
||||
let counters = thread.and_then(|t| t.1).map(|s| &s.counters);
|
||||
let metric = if let Some(counters) = counters {
|
||||
counters.vec.get_metric(tag)
|
||||
} else {
|
||||
// if tag is not default, then the thread state would have been registered, therefore tag must be default.
|
||||
&self.default_counters
|
||||
};
|
||||
|
||||
metric.inc.count.fetch_add(layout.size() as u64, Relaxed);
|
||||
|
||||
ptr
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
// We will tag our allocation by adding `T` to the end of the layout.
|
||||
// This is ok only as long as it does not overflow. If it does, we will
|
||||
// just fail the allocation by returning null.
|
||||
//
|
||||
// Safety: we will not unwind during alloc, and we will ensure layouts are handled correctly.
|
||||
unsafe impl<A, T> GlobalAlloc for TrackedAllocator<A, T>
|
||||
where
|
||||
A: GlobalAlloc,
|
||||
T: 'static + Send + Sync + FixedCardinalityLabel + LabelGroup,
|
||||
{
|
||||
alloc!(alloc);
|
||||
alloc!(alloc_zeroed);
|
||||
|
||||
unsafe fn realloc(&self, ptr: *mut u8, layout: Layout, new_size: usize) -> *mut u8 {
|
||||
// SAFETY: the caller must ensure that the `new_size` does not overflow.
|
||||
// `layout.align()` comes from a `Layout` and is thus guaranteed to be valid.
|
||||
let new_layout = unsafe { Layout::from_size_align_unchecked(new_size, layout.align()) };
|
||||
|
||||
let Ok((new_tagged_layout, new_tag_offset)) = new_layout.extend(Layout::new::<T>()) else {
|
||||
return std::ptr::null_mut();
|
||||
};
|
||||
let new_tagged_layout = new_tagged_layout.pad_to_align();
|
||||
|
||||
let Ok((tagged_layout, tag_offset)) = layout.extend(Layout::new::<T>()) else {
|
||||
// Safety: This layout clearly did not match what was originally allocated,
|
||||
// otherwise alloc() would have caught this error and returned null.
|
||||
unsafe { std::hint::unreachable_unchecked() }
|
||||
};
|
||||
let tagged_layout = tagged_layout.pad_to_align();
|
||||
|
||||
// get the tag set during alloc
|
||||
// Safety: tag_offset is inbounds of the ptr
|
||||
let tag = unsafe { ptr.add(tag_offset).cast::<T>().read() };
|
||||
|
||||
// Safety: layout sizes are correct
|
||||
let new_ptr = unsafe {
|
||||
self.inner
|
||||
.realloc(ptr, tagged_layout, new_tagged_layout.size())
|
||||
};
|
||||
|
||||
// allocation failed.
|
||||
if new_ptr.is_null() {
|
||||
return new_ptr;
|
||||
}
|
||||
|
||||
// We are being very careful here to not allocate or panic.
|
||||
let thread = self.thread.get().map(|s| (s.scope.get(), s.state.get()));
|
||||
let new_tag = thread.and_then(|t| t.0).map_or(self.default_tag, Cell::get);
|
||||
|
||||
// Allocation successful. Write our tag
|
||||
// Safety: new_tag_offset is inbounds of the ptr
|
||||
unsafe { new_ptr.add(new_tag_offset).cast::<T>().write(new_tag) }
|
||||
|
||||
let counters = thread.and_then(|t| t.1).map(|s| &s.counters);
|
||||
let counters = counters.or_else(|| self.global.get());
|
||||
let (new_metric, old_metric) = if let Some(counters) = counters {
|
||||
let new_metric = counters.vec.get_metric(new_tag);
|
||||
let old_metric = counters.vec.get_metric(tag);
|
||||
|
||||
(new_metric, old_metric)
|
||||
} else {
|
||||
// no tag was registered at all, therefore both tags must be default.
|
||||
(&self.default_counters, &self.default_counters)
|
||||
};
|
||||
|
||||
let (inc, dec) = if tag.encode() != new_tag.encode() {
|
||||
(new_layout.size() as u64, layout.size() as u64)
|
||||
} else if new_layout.size() > layout.size() {
|
||||
((new_layout.size() - layout.size()) as u64, 0)
|
||||
} else {
|
||||
(0, (layout.size() - new_layout.size()) as u64)
|
||||
};
|
||||
|
||||
new_metric.inc.count.fetch_add(inc, Relaxed);
|
||||
old_metric.dec.count.fetch_add(dec, Relaxed);
|
||||
|
||||
new_ptr
|
||||
}
|
||||
|
||||
unsafe fn dealloc(&self, ptr: *mut u8, layout: Layout) {
|
||||
let Ok((tagged_layout, tag_offset)) = layout.extend(Layout::new::<T>()) else {
|
||||
// Safety: This layout clearly did not match what was originally allocated,
|
||||
// otherwise alloc() would have caught this error and returned null.
|
||||
unsafe { std::hint::unreachable_unchecked() }
|
||||
};
|
||||
let tagged_layout = tagged_layout.pad_to_align();
|
||||
|
||||
// get the tag set during alloc
|
||||
// Safety: tag_offset is inbounds of the ptr
|
||||
let tag = unsafe { ptr.add(tag_offset).cast::<T>().read() };
|
||||
|
||||
// Safety: caller upholds contract for us
|
||||
unsafe { self.inner.dealloc(ptr, tagged_layout) }
|
||||
|
||||
// We are being very careful here to not allocate or panic.
|
||||
let thread = self.thread.get().map(|s| (s.scope.get(), s.state.get()));
|
||||
let counters = thread.and_then(|t| t.1).map(|s| &s.counters);
|
||||
let counters = counters.or_else(|| self.global.get());
|
||||
|
||||
let metric = if let Some(counters) = counters {
|
||||
counters.vec.get_metric(tag)
|
||||
} else {
|
||||
// if tag is not default, then global would have been registered, therefore tag must be default.
|
||||
&self.default_counters
|
||||
};
|
||||
|
||||
metric.dec.count.fetch_add(layout.size() as u64, Relaxed);
|
||||
}
|
||||
}
|
||||
|
||||
pub struct AllocScope<'a, T: FixedCardinalityLabel> {
|
||||
cell: &'a Cell<T>,
|
||||
last: T,
|
||||
}
|
||||
|
||||
impl<'a, T: FixedCardinalityLabel> Drop for AllocScope<'a, T> {
|
||||
fn drop(&mut self) {
|
||||
self.cell.set(self.last);
|
||||
}
|
||||
}
|
||||
|
||||
struct AllocPair<T>(PhantomData<T>);
|
||||
|
||||
impl<T: FixedCardinalityLabel + LabelGroup> CounterPairAssoc for AllocPair<T> {
|
||||
const INC_NAME: &'static MetricName = MetricName::from_str("allocated_bytes");
|
||||
const DEC_NAME: &'static MetricName = MetricName::from_str("deallocated_bytes");
|
||||
|
||||
const INC_HELP: &'static str = "total number of bytes allocated";
|
||||
const DEC_HELP: &'static str = "total number of bytes deallocated";
|
||||
|
||||
type LabelGroupSet = StaticLabelSet<T>;
|
||||
}
|
||||
|
||||
struct RegisteredThread<T: 'static + Send + Sync + FixedCardinalityLabel + LabelGroup> {
|
||||
/// Current memory context for this thread.
|
||||
scope: ThreadLocal<Cell<T>>,
|
||||
/// per thread state containing low contention counters for faster allocations.
|
||||
state: ThreadLocal<ThreadState<T>>,
|
||||
}
|
||||
|
||||
struct ThreadState<T: 'static + FixedCardinalityLabel + LabelGroup> {
|
||||
counters: AllocCounter<T>,
|
||||
global: &'static AllocCounter<T>,
|
||||
}
|
||||
|
||||
// Ensure the counters are measured on thread destruction.
|
||||
impl<T: 'static + FixedCardinalityLabel + LabelGroup> Drop for ThreadState<T> {
|
||||
fn drop(&mut self) {
|
||||
// iterate over all labels
|
||||
for tag in (0..T::cardinality()).map(T::decode) {
|
||||
// load and reset the counts in the thread-local counters.
|
||||
let m = self.counters.vec.get_metric_mut(tag);
|
||||
let inc = *m.inc.count.get_mut();
|
||||
let dec = *m.dec.count.get_mut();
|
||||
|
||||
// add the counts into the global counters.
|
||||
let m = self.global.vec.get_metric(tag);
|
||||
m.inc.count.fetch_add(inc, Relaxed);
|
||||
m.dec.count.fetch_add(dec, Relaxed);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<A, T, Enc> MetricGroup<Enc> for TrackedAllocator<A, T>
|
||||
where
|
||||
T: 'static + Send + Sync + FixedCardinalityLabel + LabelGroup,
|
||||
Enc: Encoding,
|
||||
CounterState: MetricEncoding<Enc>,
|
||||
{
|
||||
fn collect_group_into(&self, enc: &mut Enc) -> Result<(), Enc::Err> {
|
||||
let global = self.global.get_or_init(AllocCounter::new);
|
||||
|
||||
// iterate over all counter threads
|
||||
for s in self.thread.get().into_iter().flat_map(|s| s.state.iter()) {
|
||||
// iterate over all labels
|
||||
for tag in (0..T::cardinality()).map(T::decode) {
|
||||
sample(global, s.counters.vec.get_metric(tag), tag);
|
||||
}
|
||||
}
|
||||
|
||||
sample(global, &self.default_counters, self.default_tag);
|
||||
|
||||
global.collect_group_into(enc)
|
||||
}
|
||||
}
|
||||
|
||||
fn sample<T: FixedCardinalityLabel + LabelGroup>(
|
||||
global: &AllocCounter<T>,
|
||||
local: &MeasuredCounterPairState,
|
||||
tag: T,
|
||||
) {
|
||||
// load and reset the counts in the thread-local counters.
|
||||
let inc = local.inc.count.swap(0, Relaxed);
|
||||
let dec = local.dec.count.swap(0, Relaxed);
|
||||
|
||||
// add the counts into the global counters.
|
||||
let m = global.vec.get_metric(tag);
|
||||
m.inc.count.fetch_add(inc, Relaxed);
|
||||
m.dec.count.fetch_add(dec, Relaxed);
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::alloc::{GlobalAlloc, Layout, System};
|
||||
|
||||
use measured::{FixedCardinalityLabel, MetricGroup, text::BufferedTextEncoder};
|
||||
|
||||
use crate::TrackedAllocator;
|
||||
|
||||
#[derive(FixedCardinalityLabel, Clone, Copy, Debug)]
|
||||
#[label(singleton = "memory_context")]
|
||||
pub enum MemoryContext {
|
||||
Root,
|
||||
Test,
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn alloc() {
|
||||
// Safety: `MemoryContext` upholds the safety requirements.
|
||||
static GLOBAL: TrackedAllocator<System, MemoryContext> =
|
||||
unsafe { TrackedAllocator::new(System, MemoryContext::Root) };
|
||||
|
||||
GLOBAL.register_thread();
|
||||
|
||||
let _test = GLOBAL.scope(MemoryContext::Test);
|
||||
|
||||
let ptr = unsafe { GLOBAL.alloc(Layout::for_value(&[0_i32])) };
|
||||
let ptr = unsafe { GLOBAL.realloc(ptr, Layout::for_value(&[0_i32]), 8) };
|
||||
|
||||
drop(_test);
|
||||
|
||||
let ptr = unsafe { GLOBAL.realloc(ptr, Layout::for_value(&[0_i32, 1_i32]), 4) };
|
||||
unsafe { GLOBAL.dealloc(ptr, Layout::for_value(&[0_i32])) };
|
||||
|
||||
let mut text = BufferedTextEncoder::new();
|
||||
GLOBAL.collect_group_into(&mut text).unwrap();
|
||||
let text = String::from_utf8(text.finish().into()).unwrap();
|
||||
assert_eq!(
|
||||
text,
|
||||
r#"# HELP deallocated_bytes total number of bytes deallocated
|
||||
# TYPE deallocated_bytes counter
|
||||
deallocated_bytes{memory_context="root"} 4
|
||||
deallocated_bytes{memory_context="test"} 8
|
||||
|
||||
# HELP allocated_bytes total number of bytes allocated
|
||||
# TYPE allocated_bytes counter
|
||||
allocated_bytes{memory_context="root"} 4
|
||||
allocated_bytes{memory_context="test"} 8
|
||||
"#
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn unregistered_thread() {
|
||||
// Safety: `MemoryContext` upholds the safety requirements.
|
||||
static GLOBAL: TrackedAllocator<System, MemoryContext> =
|
||||
unsafe { TrackedAllocator::new(System, MemoryContext::Root) };
|
||||
|
||||
GLOBAL.register_thread();
|
||||
|
||||
// unregistered thread
|
||||
std::thread::spawn(|| {
|
||||
let ptr = unsafe { GLOBAL.alloc(Layout::for_value(&[0_i32])) };
|
||||
unsafe { GLOBAL.dealloc(ptr, Layout::for_value(&[0_i32])) };
|
||||
})
|
||||
.join()
|
||||
.unwrap();
|
||||
|
||||
let mut text = BufferedTextEncoder::new();
|
||||
GLOBAL.collect_group_into(&mut text).unwrap();
|
||||
let text = String::from_utf8(text.finish().into()).unwrap();
|
||||
assert_eq!(
|
||||
text,
|
||||
r#"# HELP deallocated_bytes total number of bytes deallocated
|
||||
# TYPE deallocated_bytes counter
|
||||
deallocated_bytes{memory_context="root"} 4
|
||||
deallocated_bytes{memory_context="test"} 0
|
||||
|
||||
# HELP allocated_bytes total number of bytes allocated
|
||||
# TYPE allocated_bytes counter
|
||||
allocated_bytes{memory_context="root"} 4
|
||||
allocated_bytes{memory_context="test"} 0
|
||||
"#
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn fully_unregistered() {
|
||||
// Safety: `MemoryContext` upholds the safety requirements.
|
||||
static GLOBAL: TrackedAllocator<System, MemoryContext> =
|
||||
unsafe { TrackedAllocator::new(System, MemoryContext::Root) };
|
||||
|
||||
let ptr = unsafe { GLOBAL.alloc(Layout::for_value(&[0_i32])) };
|
||||
unsafe { GLOBAL.dealloc(ptr, Layout::for_value(&[0_i32])) };
|
||||
|
||||
let mut text = BufferedTextEncoder::new();
|
||||
GLOBAL.collect_group_into(&mut text).unwrap();
|
||||
let text = String::from_utf8(text.finish().into()).unwrap();
|
||||
assert_eq!(
|
||||
text,
|
||||
r#"# HELP deallocated_bytes total number of bytes deallocated
|
||||
# TYPE deallocated_bytes counter
|
||||
deallocated_bytes{memory_context="root"} 4
|
||||
deallocated_bytes{memory_context="test"} 0
|
||||
|
||||
# HELP allocated_bytes total number of bytes allocated
|
||||
# TYPE allocated_bytes counter
|
||||
allocated_bytes{memory_context="root"} 4
|
||||
allocated_bytes{memory_context="test"} 0
|
||||
"#
|
||||
);
|
||||
}
|
||||
}
|
||||
72
libs/alloc-metrics/src/metric_vec.rs
Normal file
72
libs/alloc-metrics/src/metric_vec.rs
Normal file
@@ -0,0 +1,72 @@
|
||||
//! Dense metric vec
|
||||
|
||||
use measured::{
|
||||
FixedCardinalityLabel, LabelGroup,
|
||||
label::StaticLabelSet,
|
||||
metric::{
|
||||
MetricEncoding, MetricFamilyEncoding, MetricType, group::Encoding, name::MetricNameEncoder,
|
||||
},
|
||||
};
|
||||
|
||||
pub struct DenseMetricVec<M: MetricType, L: FixedCardinalityLabel + LabelGroup> {
|
||||
metrics: Box<[M]>,
|
||||
metadata: M::Metadata,
|
||||
_label_set: StaticLabelSet<L>,
|
||||
}
|
||||
|
||||
fn new_dense<M: MetricType>(c: usize) -> Box<[M]> {
|
||||
let mut vec = Vec::with_capacity(c);
|
||||
vec.resize_with(c, M::default);
|
||||
vec.into_boxed_slice()
|
||||
}
|
||||
|
||||
impl<M: MetricType, L: FixedCardinalityLabel + LabelGroup> DenseMetricVec<M, L>
|
||||
where
|
||||
M::Metadata: Default,
|
||||
{
|
||||
/// Create a new metric vec with the given label set and metric metadata
|
||||
pub fn new() -> Self {
|
||||
Self::with_metadata(<M::Metadata>::default())
|
||||
}
|
||||
}
|
||||
|
||||
impl<M: MetricType, L: FixedCardinalityLabel + LabelGroup> DenseMetricVec<M, L> {
|
||||
/// Create a new metric vec with the given label set and metric metadata
|
||||
pub fn with_metadata(metadata: M::Metadata) -> Self {
|
||||
Self {
|
||||
metrics: new_dense(L::cardinality()),
|
||||
metadata,
|
||||
_label_set: StaticLabelSet::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the individual metric at the given identifier.
|
||||
///
|
||||
/// # Panics
|
||||
/// Can panic or cause strange behaviour if the label ID comes from a different metric family.
|
||||
pub fn get_metric(&self, label: L) -> &M {
|
||||
// safety: The caller has guarantees that the label encoding is valid.
|
||||
unsafe { self.metrics.get_unchecked(label.encode()) }
|
||||
}
|
||||
|
||||
/// Get the individual metric at the given identifier.
|
||||
///
|
||||
/// # Panics
|
||||
/// Can panic or cause strange behaviour if the label ID comes from a different metric family.
|
||||
pub fn get_metric_mut(&mut self, label: L) -> &mut M {
|
||||
// safety: The caller has guarantees that the label encoding is valid.
|
||||
unsafe { self.metrics.get_unchecked_mut(label.encode()) }
|
||||
}
|
||||
}
|
||||
|
||||
impl<M: MetricEncoding<T>, L: FixedCardinalityLabel + LabelGroup, T: Encoding>
|
||||
MetricFamilyEncoding<T> for DenseMetricVec<M, L>
|
||||
{
|
||||
fn collect_family_into(&self, name: impl MetricNameEncoder, enc: &mut T) -> Result<(), T::Err> {
|
||||
M::write_type(&name, enc)?;
|
||||
for (index, value) in self.metrics.iter().enumerate() {
|
||||
value.collect_into(&self.metadata, L::decode(index), &name, enc)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -478,7 +478,7 @@ pub trait CounterPairAssoc {
|
||||
}
|
||||
|
||||
pub struct CounterPairVec<A: CounterPairAssoc> {
|
||||
vec: measured::metric::MetricVec<MeasuredCounterPairState, A::LabelGroupSet>,
|
||||
pub vec: measured::metric::MetricVec<MeasuredCounterPairState, A::LabelGroupSet>,
|
||||
}
|
||||
|
||||
impl<A: CounterPairAssoc> Default for CounterPairVec<A>
|
||||
@@ -492,6 +492,17 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
impl<A: CounterPairAssoc> CounterPairVec<A>
|
||||
where
|
||||
A::LabelGroupSet: Default,
|
||||
{
|
||||
pub fn dense() -> Self {
|
||||
Self {
|
||||
vec: measured::metric::MetricVec::dense(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<A: CounterPairAssoc> CounterPairVec<A> {
|
||||
pub fn guard(
|
||||
&self,
|
||||
@@ -501,14 +512,31 @@ impl<A: CounterPairAssoc> CounterPairVec<A> {
|
||||
self.vec.get_metric(id).inc.inc();
|
||||
MeasuredCounterPairGuard { vec: &self.vec, id }
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn inc(&self, labels: <A::LabelGroupSet as LabelGroupSet>::Group<'_>) {
|
||||
let id = self.vec.with_labels(labels);
|
||||
self.vec.get_metric(id).inc.inc();
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn dec(&self, labels: <A::LabelGroupSet as LabelGroupSet>::Group<'_>) {
|
||||
let id = self.vec.with_labels(labels);
|
||||
self.vec.get_metric(id).dec.inc();
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn inc_by(&self, labels: <A::LabelGroupSet as LabelGroupSet>::Group<'_>, x: u64) {
|
||||
let id = self.vec.with_labels(labels);
|
||||
self.vec.get_metric(id).inc.inc_by(x);
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn dec_by(&self, labels: <A::LabelGroupSet as LabelGroupSet>::Group<'_>, x: u64) {
|
||||
let id = self.vec.with_labels(labels);
|
||||
self.vec.get_metric(id).dec.inc_by(x);
|
||||
}
|
||||
|
||||
pub fn remove_metric(
|
||||
&self,
|
||||
labels: <A::LabelGroupSet as LabelGroupSet>::Group<'_>,
|
||||
@@ -553,6 +581,28 @@ pub struct MeasuredCounterPairState {
|
||||
pub dec: CounterState,
|
||||
}
|
||||
|
||||
impl MeasuredCounterPairState {
|
||||
#[inline]
|
||||
pub fn inc(&self) {
|
||||
self.inc.inc();
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn dec(&self) {
|
||||
self.dec.inc();
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn inc_by(&self, x: u64) {
|
||||
self.inc.inc_by(x);
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn dec_by(&self, x: u64) {
|
||||
self.dec.inc_by(x);
|
||||
}
|
||||
}
|
||||
|
||||
impl measured::metric::MetricType for MeasuredCounterPairState {
|
||||
type Metadata = ();
|
||||
}
|
||||
@@ -569,9 +619,9 @@ impl<A: CounterPairAssoc> Drop for MeasuredCounterPairGuard<'_, A> {
|
||||
}
|
||||
|
||||
/// [`MetricEncoding`] for [`MeasuredCounterPairState`] that only writes the inc counter to the inner encoder.
|
||||
struct Inc<T>(T);
|
||||
pub struct Inc<T>(pub T);
|
||||
/// [`MetricEncoding`] for [`MeasuredCounterPairState`] that only writes the dec counter to the inner encoder.
|
||||
struct Dec<T>(T);
|
||||
pub struct Dec<T>(pub T);
|
||||
|
||||
impl<T: Encoding> Encoding for Inc<T> {
|
||||
type Err = T::Err;
|
||||
|
||||
@@ -10,6 +10,7 @@ testing = ["dep:tokio-postgres"]
|
||||
|
||||
[dependencies]
|
||||
ahash.workspace = true
|
||||
alloc-metrics.workspace = true
|
||||
anyhow.workspace = true
|
||||
arc-swap.workspace = true
|
||||
async-compression.workspace = true
|
||||
|
||||
@@ -1,11 +1,22 @@
|
||||
use alloc_metrics::TrackedAllocator;
|
||||
use proxy::binary::proxy::MemoryContext;
|
||||
use tikv_jemallocator::Jemalloc;
|
||||
|
||||
#[global_allocator]
|
||||
static GLOBAL: tikv_jemallocator::Jemalloc = tikv_jemallocator::Jemalloc;
|
||||
// Safety: `MemoryContext` upholds the safety requirements.
|
||||
static GLOBAL: TrackedAllocator<Jemalloc, MemoryContext> =
|
||||
unsafe { TrackedAllocator::new(Jemalloc, MemoryContext::Root) };
|
||||
|
||||
#[allow(non_upper_case_globals)]
|
||||
#[unsafe(export_name = "malloc_conf")]
|
||||
pub static malloc_conf: &[u8] = b"prof:true,prof_active:true,lg_prof_sample:21\0";
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
proxy::binary::proxy::run().await
|
||||
fn main() -> anyhow::Result<()> {
|
||||
GLOBAL.register_thread();
|
||||
tokio::runtime::Builder::new_multi_thread()
|
||||
.enable_all()
|
||||
.on_thread_start(|| GLOBAL.register_thread())
|
||||
.build()
|
||||
.expect("Failed building the Runtime")
|
||||
.block_on(proxy::binary::proxy::run(&GLOBAL))
|
||||
}
|
||||
|
||||
@@ -111,7 +111,7 @@ pub async fn run() -> anyhow::Result<()> {
|
||||
let _panic_hook_guard = utils::logging::replace_panic_hook_with_tracing_panic_hook();
|
||||
let _sentry_guard = init_sentry(Some(GIT_VERSION.into()), &[]);
|
||||
|
||||
Metrics::install(Arc::new(ThreadPoolMetrics::new(0)));
|
||||
Metrics::install(Arc::new(ThreadPoolMetrics::new(0)), None);
|
||||
|
||||
// TODO: refactor these to use labels
|
||||
debug!("Version: {GIT_VERSION}");
|
||||
|
||||
@@ -80,7 +80,7 @@ pub async fn run() -> anyhow::Result<()> {
|
||||
let _panic_hook_guard = utils::logging::replace_panic_hook_with_tracing_panic_hook();
|
||||
let _sentry_guard = init_sentry(Some(GIT_VERSION.into()), &[]);
|
||||
|
||||
Metrics::install(Arc::new(ThreadPoolMetrics::new(0)));
|
||||
Metrics::install(Arc::new(ThreadPoolMetrics::new(0)), None);
|
||||
|
||||
let args = cli().get_matches();
|
||||
let destination: String = args
|
||||
|
||||
@@ -39,7 +39,8 @@ use crate::config::{
|
||||
};
|
||||
use crate::context::parquet::ParquetUploadArgs;
|
||||
use crate::http::health_server::AppMetrics;
|
||||
use crate::metrics::Metrics;
|
||||
pub use crate::metrics::MemoryContext;
|
||||
use crate::metrics::{Alloc, Metrics};
|
||||
use crate::rate_limiter::{EndpointRateLimiter, RateBucketInfo, WakeComputeRateLimiter};
|
||||
use crate::redis::connection_with_credentials_provider::ConnectionWithCredentialsProvider;
|
||||
use crate::redis::kv_ops::RedisKVClient;
|
||||
@@ -318,7 +319,7 @@ struct PgSniRouterArgs {
|
||||
dest: Option<String>,
|
||||
}
|
||||
|
||||
pub async fn run() -> anyhow::Result<()> {
|
||||
pub async fn run(alloc: &'static Alloc) -> anyhow::Result<()> {
|
||||
let _logging_guard = crate::logging::init().await?;
|
||||
let _panic_hook_guard = utils::logging::replace_panic_hook_with_tracing_panic_hook();
|
||||
let _sentry_guard = init_sentry(Some(GIT_VERSION.into()), &[]);
|
||||
@@ -340,7 +341,7 @@ pub async fn run() -> anyhow::Result<()> {
|
||||
};
|
||||
|
||||
let args = ProxyCliArgs::parse();
|
||||
let config = build_config(&args)?;
|
||||
let config = build_config(&args, alloc)?;
|
||||
let auth_backend = build_auth_backend(&args)?;
|
||||
|
||||
match auth_backend {
|
||||
@@ -589,9 +590,12 @@ pub async fn run() -> anyhow::Result<()> {
|
||||
}
|
||||
|
||||
/// ProxyConfig is created at proxy startup, and lives forever.
|
||||
fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
|
||||
fn build_config(
|
||||
args: &ProxyCliArgs,
|
||||
alloc: &'static Alloc,
|
||||
) -> anyhow::Result<&'static ProxyConfig> {
|
||||
let thread_pool = ThreadPool::new(args.scram_thread_pool_size);
|
||||
Metrics::install(thread_pool.metrics.clone());
|
||||
Metrics::install(thread_pool.metrics.clone(), Some(alloc));
|
||||
|
||||
let tls_config = match (&args.tls_key, &args.tls_cert) {
|
||||
(Some(key_path), Some(cert_path)) => Some(config::configure_tls(
|
||||
|
||||
286
proxy/src/cache/project_info.rs
vendored
286
proxy/src/cache/project_info.rs
vendored
@@ -10,6 +10,7 @@ use tokio::time::Instant;
|
||||
use tracing::{debug, info};
|
||||
|
||||
use crate::config::ProjectInfoCacheOptions;
|
||||
use crate::control_plane::messages::{ControlPlaneErrorMessage, Reason};
|
||||
use crate::control_plane::{EndpointAccessControl, RoleAccessControl};
|
||||
use crate::intern::{AccountIdInt, EndpointIdInt, ProjectIdInt, RoleNameInt};
|
||||
use crate::types::{EndpointId, RoleName};
|
||||
@@ -36,22 +37,37 @@ impl<T> Entry<T> {
|
||||
}
|
||||
|
||||
pub(crate) fn get(&self) -> Option<&T> {
|
||||
(self.expires_at > Instant::now()).then_some(&self.value)
|
||||
(!self.is_expired()).then_some(&self.value)
|
||||
}
|
||||
|
||||
fn is_expired(&self) -> bool {
|
||||
self.expires_at <= Instant::now()
|
||||
}
|
||||
}
|
||||
|
||||
struct EndpointInfo {
|
||||
role_controls: HashMap<RoleNameInt, Entry<RoleAccessControl>>,
|
||||
controls: Option<Entry<EndpointAccessControl>>,
|
||||
role_controls: HashMap<RoleNameInt, Entry<ControlPlaneResult<RoleAccessControl>>>,
|
||||
controls: Option<Entry<ControlPlaneResult<EndpointAccessControl>>>,
|
||||
}
|
||||
|
||||
type ControlPlaneResult<T> = Result<T, Box<ControlPlaneErrorMessage>>;
|
||||
|
||||
impl EndpointInfo {
|
||||
pub(crate) fn get_role_secret(&self, role_name: RoleNameInt) -> Option<RoleAccessControl> {
|
||||
self.role_controls.get(&role_name)?.get().cloned()
|
||||
pub(crate) fn get_role_secret_with_ttl(
|
||||
&self,
|
||||
role_name: RoleNameInt,
|
||||
) -> Option<(ControlPlaneResult<RoleAccessControl>, Duration)> {
|
||||
let entry = self.role_controls.get(&role_name)?;
|
||||
let ttl = entry.expires_at - Instant::now();
|
||||
Some((entry.get()?.clone(), ttl))
|
||||
}
|
||||
|
||||
pub(crate) fn get_controls(&self) -> Option<EndpointAccessControl> {
|
||||
self.controls.as_ref()?.get().cloned()
|
||||
pub(crate) fn get_controls_with_ttl(
|
||||
&self,
|
||||
) -> Option<(ControlPlaneResult<EndpointAccessControl>, Duration)> {
|
||||
let entry = self.controls.as_ref()?;
|
||||
let ttl = entry.expires_at - Instant::now();
|
||||
Some((entry.get()?.clone(), ttl))
|
||||
}
|
||||
|
||||
pub(crate) fn invalidate_endpoint(&mut self) {
|
||||
@@ -153,28 +169,28 @@ impl ProjectInfoCacheImpl {
|
||||
self.cache.get(&endpoint_id)
|
||||
}
|
||||
|
||||
pub(crate) fn get_role_secret(
|
||||
pub(crate) fn get_role_secret_with_ttl(
|
||||
&self,
|
||||
endpoint_id: &EndpointId,
|
||||
role_name: &RoleName,
|
||||
) -> Option<RoleAccessControl> {
|
||||
) -> Option<(ControlPlaneResult<RoleAccessControl>, Duration)> {
|
||||
let role_name = RoleNameInt::get(role_name)?;
|
||||
let endpoint_info = self.get_endpoint_cache(endpoint_id)?;
|
||||
endpoint_info.get_role_secret(role_name)
|
||||
endpoint_info.get_role_secret_with_ttl(role_name)
|
||||
}
|
||||
|
||||
pub(crate) fn get_endpoint_access(
|
||||
pub(crate) fn get_endpoint_access_with_ttl(
|
||||
&self,
|
||||
endpoint_id: &EndpointId,
|
||||
) -> Option<EndpointAccessControl> {
|
||||
) -> Option<(ControlPlaneResult<EndpointAccessControl>, Duration)> {
|
||||
let endpoint_info = self.get_endpoint_cache(endpoint_id)?;
|
||||
endpoint_info.get_controls()
|
||||
endpoint_info.get_controls_with_ttl()
|
||||
}
|
||||
|
||||
pub(crate) fn insert_endpoint_access(
|
||||
&self,
|
||||
account_id: Option<AccountIdInt>,
|
||||
project_id: ProjectIdInt,
|
||||
project_id: Option<ProjectIdInt>,
|
||||
endpoint_id: EndpointIdInt,
|
||||
role_name: RoleNameInt,
|
||||
controls: EndpointAccessControl,
|
||||
@@ -183,26 +199,89 @@ impl ProjectInfoCacheImpl {
|
||||
if let Some(account_id) = account_id {
|
||||
self.insert_account2endpoint(account_id, endpoint_id);
|
||||
}
|
||||
self.insert_project2endpoint(project_id, endpoint_id);
|
||||
if let Some(project_id) = project_id {
|
||||
self.insert_project2endpoint(project_id, endpoint_id);
|
||||
}
|
||||
|
||||
if self.cache.len() >= self.config.size {
|
||||
// If there are too many entries, wait until the next gc cycle.
|
||||
return;
|
||||
}
|
||||
|
||||
let controls = Entry::new(controls, self.config.ttl);
|
||||
let role_controls = Entry::new(role_controls, self.config.ttl);
|
||||
debug!(
|
||||
key = &*endpoint_id,
|
||||
"created a cache entry for endpoint access"
|
||||
);
|
||||
|
||||
let controls = Some(Entry::new(Ok(controls), self.config.ttl));
|
||||
let role_controls = Entry::new(Ok(role_controls), self.config.ttl);
|
||||
|
||||
match self.cache.entry(endpoint_id) {
|
||||
clashmap::Entry::Vacant(e) => {
|
||||
e.insert(EndpointInfo {
|
||||
role_controls: HashMap::from_iter([(role_name, role_controls)]),
|
||||
controls: Some(controls),
|
||||
controls,
|
||||
});
|
||||
}
|
||||
clashmap::Entry::Occupied(mut e) => {
|
||||
let ep = e.get_mut();
|
||||
ep.controls = Some(controls);
|
||||
ep.controls = controls;
|
||||
if ep.role_controls.len() < self.config.max_roles {
|
||||
ep.role_controls.insert(role_name, role_controls);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn insert_endpoint_access_err(
|
||||
&self,
|
||||
endpoint_id: EndpointIdInt,
|
||||
role_name: RoleNameInt,
|
||||
msg: Box<ControlPlaneErrorMessage>,
|
||||
ttl: Option<Duration>,
|
||||
) {
|
||||
if self.cache.len() >= self.config.size {
|
||||
// If there are too many entries, wait until the next gc cycle.
|
||||
return;
|
||||
}
|
||||
|
||||
debug!(
|
||||
key = &*endpoint_id,
|
||||
"created a cache entry for an endpoint access error"
|
||||
);
|
||||
|
||||
let ttl = ttl.unwrap_or(self.config.ttl);
|
||||
|
||||
let controls = if msg.get_reason() == Reason::RoleProtected {
|
||||
// RoleProtected is the only role-specific error that control plane can give us.
|
||||
// If a given role name does not exist, it still returns a successful response,
|
||||
// just with an empty secret.
|
||||
None
|
||||
} else {
|
||||
// We can cache all the other errors in EndpointInfo.controls,
|
||||
// because they don't depend on what role name we pass to control plane.
|
||||
Some(Entry::new(Err(msg.clone()), ttl))
|
||||
};
|
||||
|
||||
let role_controls = Entry::new(Err(msg), ttl);
|
||||
|
||||
match self.cache.entry(endpoint_id) {
|
||||
clashmap::Entry::Vacant(e) => {
|
||||
e.insert(EndpointInfo {
|
||||
role_controls: HashMap::from_iter([(role_name, role_controls)]),
|
||||
controls,
|
||||
});
|
||||
}
|
||||
clashmap::Entry::Occupied(mut e) => {
|
||||
let ep = e.get_mut();
|
||||
if let Some(entry) = &ep.controls
|
||||
&& !entry.is_expired()
|
||||
&& entry.value.is_ok()
|
||||
{
|
||||
// If we have cached non-expired, non-error controls, keep them.
|
||||
} else {
|
||||
ep.controls = controls;
|
||||
}
|
||||
if ep.role_controls.len() < self.config.max_roles {
|
||||
ep.role_controls.insert(role_name, role_controls);
|
||||
}
|
||||
@@ -245,7 +324,7 @@ impl ProjectInfoCacheImpl {
|
||||
return;
|
||||
};
|
||||
|
||||
if role_controls.get().expires_at <= Instant::now() {
|
||||
if role_controls.get().is_expired() {
|
||||
role_controls.remove();
|
||||
}
|
||||
}
|
||||
@@ -284,13 +363,11 @@ impl ProjectInfoCacheImpl {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::sync::Arc;
|
||||
|
||||
use super::*;
|
||||
use crate::control_plane::messages::EndpointRateLimitConfig;
|
||||
use crate::control_plane::messages::{Details, EndpointRateLimitConfig, ErrorInfo, Status};
|
||||
use crate::control_plane::{AccessBlockerFlags, AuthSecret};
|
||||
use crate::scram::ServerSecret;
|
||||
use crate::types::ProjectId;
|
||||
use std::sync::Arc;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_project_info_cache_settings() {
|
||||
@@ -301,9 +378,9 @@ mod tests {
|
||||
ttl: Duration::from_secs(1),
|
||||
gc_interval: Duration::from_secs(600),
|
||||
});
|
||||
let project_id: ProjectId = "project".into();
|
||||
let project_id: Option<ProjectIdInt> = Some(ProjectIdInt::from(&"project".into()));
|
||||
let endpoint_id: EndpointId = "endpoint".into();
|
||||
let account_id: Option<AccountIdInt> = None;
|
||||
let account_id = None;
|
||||
|
||||
let user1: RoleName = "user1".into();
|
||||
let user2: RoleName = "user2".into();
|
||||
@@ -316,7 +393,7 @@ mod tests {
|
||||
|
||||
cache.insert_endpoint_access(
|
||||
account_id,
|
||||
(&project_id).into(),
|
||||
project_id,
|
||||
(&endpoint_id).into(),
|
||||
(&user1).into(),
|
||||
EndpointAccessControl {
|
||||
@@ -332,7 +409,7 @@ mod tests {
|
||||
|
||||
cache.insert_endpoint_access(
|
||||
account_id,
|
||||
(&project_id).into(),
|
||||
project_id,
|
||||
(&endpoint_id).into(),
|
||||
(&user2).into(),
|
||||
EndpointAccessControl {
|
||||
@@ -346,11 +423,17 @@ mod tests {
|
||||
},
|
||||
);
|
||||
|
||||
let cached = cache.get_role_secret(&endpoint_id, &user1).unwrap();
|
||||
assert_eq!(cached.secret, secret1);
|
||||
let (cached, ttl) = cache
|
||||
.get_role_secret_with_ttl(&endpoint_id, &user1)
|
||||
.unwrap();
|
||||
assert_eq!(cached.unwrap().secret, secret1);
|
||||
assert_eq!(ttl, cache.config.ttl);
|
||||
|
||||
let cached = cache.get_role_secret(&endpoint_id, &user2).unwrap();
|
||||
assert_eq!(cached.secret, secret2);
|
||||
let (cached, ttl) = cache
|
||||
.get_role_secret_with_ttl(&endpoint_id, &user2)
|
||||
.unwrap();
|
||||
assert_eq!(cached.unwrap().secret, secret2);
|
||||
assert_eq!(ttl, cache.config.ttl);
|
||||
|
||||
// Shouldn't add more than 2 roles.
|
||||
let user3: RoleName = "user3".into();
|
||||
@@ -358,7 +441,7 @@ mod tests {
|
||||
|
||||
cache.insert_endpoint_access(
|
||||
account_id,
|
||||
(&project_id).into(),
|
||||
project_id,
|
||||
(&endpoint_id).into(),
|
||||
(&user3).into(),
|
||||
EndpointAccessControl {
|
||||
@@ -372,17 +455,144 @@ mod tests {
|
||||
},
|
||||
);
|
||||
|
||||
assert!(cache.get_role_secret(&endpoint_id, &user3).is_none());
|
||||
assert!(
|
||||
cache
|
||||
.get_role_secret_with_ttl(&endpoint_id, &user3)
|
||||
.is_none()
|
||||
);
|
||||
|
||||
let cached = cache.get_endpoint_access(&endpoint_id).unwrap();
|
||||
let cached = cache
|
||||
.get_endpoint_access_with_ttl(&endpoint_id)
|
||||
.unwrap()
|
||||
.0
|
||||
.unwrap();
|
||||
assert_eq!(cached.allowed_ips, allowed_ips);
|
||||
|
||||
tokio::time::advance(Duration::from_secs(2)).await;
|
||||
let cached = cache.get_role_secret(&endpoint_id, &user1);
|
||||
let cached = cache.get_role_secret_with_ttl(&endpoint_id, &user1);
|
||||
assert!(cached.is_none());
|
||||
let cached = cache.get_role_secret(&endpoint_id, &user2);
|
||||
let cached = cache.get_role_secret_with_ttl(&endpoint_id, &user2);
|
||||
assert!(cached.is_none());
|
||||
let cached = cache.get_endpoint_access(&endpoint_id);
|
||||
let cached = cache.get_endpoint_access_with_ttl(&endpoint_id);
|
||||
assert!(cached.is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_caching_project_info_errors() {
|
||||
let cache = ProjectInfoCacheImpl::new(ProjectInfoCacheOptions {
|
||||
size: 10,
|
||||
max_roles: 10,
|
||||
ttl: Duration::from_secs(1),
|
||||
gc_interval: Duration::from_secs(600),
|
||||
});
|
||||
let project_id = Some(ProjectIdInt::from(&"project".into()));
|
||||
let endpoint_id: EndpointId = "endpoint".into();
|
||||
let account_id = None;
|
||||
|
||||
let user1: RoleName = "user1".into();
|
||||
let user2: RoleName = "user2".into();
|
||||
let secret = Some(AuthSecret::Scram(ServerSecret::mock([1; 32])));
|
||||
|
||||
let role_msg = Box::new(ControlPlaneErrorMessage {
|
||||
error: "role is protected and cannot be used for password-based authentication"
|
||||
.to_owned()
|
||||
.into_boxed_str(),
|
||||
http_status_code: http::StatusCode::NOT_FOUND,
|
||||
status: Some(Status {
|
||||
code: "PERMISSION_DENIED".to_owned().into_boxed_str(),
|
||||
message: "role is protected and cannot be used for password-based authentication"
|
||||
.to_owned()
|
||||
.into_boxed_str(),
|
||||
details: Details {
|
||||
error_info: Some(ErrorInfo {
|
||||
reason: Reason::RoleProtected,
|
||||
}),
|
||||
retry_info: None,
|
||||
user_facing_message: None,
|
||||
},
|
||||
}),
|
||||
});
|
||||
|
||||
let generic_msg = Box::new(ControlPlaneErrorMessage {
|
||||
error: "oh noes".to_owned().into_boxed_str(),
|
||||
http_status_code: http::StatusCode::NOT_FOUND,
|
||||
status: None,
|
||||
});
|
||||
|
||||
let get_role_secret = |endpoint_id, role_name| {
|
||||
cache
|
||||
.get_role_secret_with_ttl(endpoint_id, role_name)
|
||||
.unwrap()
|
||||
.0
|
||||
};
|
||||
let get_endpoint_access =
|
||||
|endpoint_id| cache.get_endpoint_access_with_ttl(endpoint_id).unwrap().0;
|
||||
|
||||
// stores role-specific errors only for get_role_secret
|
||||
cache.insert_endpoint_access_err(
|
||||
(&endpoint_id).into(),
|
||||
(&user1).into(),
|
||||
role_msg.clone(),
|
||||
None,
|
||||
);
|
||||
assert_eq!(
|
||||
get_role_secret(&endpoint_id, &user1).unwrap_err().error,
|
||||
role_msg.error
|
||||
);
|
||||
assert!(cache.get_endpoint_access_with_ttl(&endpoint_id).is_none());
|
||||
|
||||
// stores non-role specific errors for both get_role_secret and get_endpoint_access
|
||||
cache.insert_endpoint_access_err(
|
||||
(&endpoint_id).into(),
|
||||
(&user1).into(),
|
||||
generic_msg.clone(),
|
||||
None,
|
||||
);
|
||||
assert_eq!(
|
||||
get_role_secret(&endpoint_id, &user1).unwrap_err().error,
|
||||
generic_msg.error
|
||||
);
|
||||
assert_eq!(
|
||||
get_endpoint_access(&endpoint_id).unwrap_err().error,
|
||||
generic_msg.error
|
||||
);
|
||||
|
||||
// error isn't returned for other roles in the same endpoint
|
||||
assert!(
|
||||
cache
|
||||
.get_role_secret_with_ttl(&endpoint_id, &user2)
|
||||
.is_none()
|
||||
);
|
||||
|
||||
// success for a role does not overwrite errors for other roles
|
||||
cache.insert_endpoint_access(
|
||||
account_id,
|
||||
project_id,
|
||||
(&endpoint_id).into(),
|
||||
(&user2).into(),
|
||||
EndpointAccessControl {
|
||||
allowed_ips: Arc::new(vec![]),
|
||||
allowed_vpce: Arc::new(vec![]),
|
||||
flags: AccessBlockerFlags::default(),
|
||||
rate_limits: EndpointRateLimitConfig::default(),
|
||||
},
|
||||
RoleAccessControl {
|
||||
secret: secret.clone(),
|
||||
},
|
||||
);
|
||||
assert!(get_role_secret(&endpoint_id, &user1).is_err());
|
||||
assert!(get_role_secret(&endpoint_id, &user2).is_ok());
|
||||
// ...but does clear the access control error
|
||||
assert!(get_endpoint_access(&endpoint_id).is_ok());
|
||||
|
||||
// storing an error does not overwrite successful access control response
|
||||
cache.insert_endpoint_access_err(
|
||||
(&endpoint_id).into(),
|
||||
(&user2).into(),
|
||||
generic_msg.clone(),
|
||||
None,
|
||||
);
|
||||
assert!(get_role_secret(&endpoint_id, &user2).is_err());
|
||||
assert!(get_endpoint_access(&endpoint_id).is_ok());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -68,6 +68,66 @@ impl NeonControlPlaneClient {
|
||||
self.endpoint.url().as_str()
|
||||
}
|
||||
|
||||
async fn get_and_cache_auth_info<T>(
|
||||
&self,
|
||||
ctx: &RequestContext,
|
||||
endpoint: &EndpointId,
|
||||
role: &RoleName,
|
||||
cache_key: &EndpointId,
|
||||
extract: impl FnOnce(&EndpointAccessControl, &RoleAccessControl) -> T,
|
||||
) -> Result<T, GetAuthInfoError> {
|
||||
match self.do_get_auth_req(ctx, endpoint, role).await {
|
||||
Ok(auth_info) => {
|
||||
let control = EndpointAccessControl {
|
||||
allowed_ips: Arc::new(auth_info.allowed_ips),
|
||||
allowed_vpce: Arc::new(auth_info.allowed_vpc_endpoint_ids),
|
||||
flags: auth_info.access_blocker_flags,
|
||||
rate_limits: auth_info.rate_limits,
|
||||
};
|
||||
let role_control = RoleAccessControl {
|
||||
secret: auth_info.secret,
|
||||
};
|
||||
let res = extract(&control, &role_control);
|
||||
|
||||
self.caches.project_info.insert_endpoint_access(
|
||||
auth_info.account_id,
|
||||
auth_info.project_id,
|
||||
cache_key.into(),
|
||||
role.into(),
|
||||
control,
|
||||
role_control,
|
||||
);
|
||||
|
||||
if let Some(project_id) = auth_info.project_id {
|
||||
ctx.set_project_id(project_id);
|
||||
}
|
||||
|
||||
Ok(res)
|
||||
}
|
||||
Err(err) => match err {
|
||||
GetAuthInfoError::ApiError(ControlPlaneError::Message(ref msg)) => {
|
||||
let retry_info = msg.status.as_ref().and_then(|s| s.details.retry_info);
|
||||
|
||||
// If we can retry this error, do not cache it,
|
||||
// unless we were given a retry delay.
|
||||
if msg.could_retry() && retry_info.is_none() {
|
||||
return Err(err);
|
||||
}
|
||||
|
||||
self.caches.project_info.insert_endpoint_access_err(
|
||||
cache_key.into(),
|
||||
role.into(),
|
||||
msg.clone(),
|
||||
retry_info.map(|r| Duration::from_millis(r.retry_delay_ms)),
|
||||
);
|
||||
|
||||
Err(err)
|
||||
}
|
||||
err => Err(err),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
async fn do_get_auth_req(
|
||||
&self,
|
||||
ctx: &RequestContext,
|
||||
@@ -284,43 +344,34 @@ impl super::ControlPlaneApi for NeonControlPlaneClient {
|
||||
ctx: &RequestContext,
|
||||
endpoint: &EndpointId,
|
||||
role: &RoleName,
|
||||
) -> Result<RoleAccessControl, crate::control_plane::errors::GetAuthInfoError> {
|
||||
let normalized_ep = &endpoint.normalize();
|
||||
if let Some(secret) = self
|
||||
) -> Result<RoleAccessControl, GetAuthInfoError> {
|
||||
let key = endpoint.normalize();
|
||||
|
||||
if let Some((role_control, ttl)) = self
|
||||
.caches
|
||||
.project_info
|
||||
.get_role_secret(normalized_ep, role)
|
||||
.get_role_secret_with_ttl(&key, role)
|
||||
{
|
||||
return Ok(secret);
|
||||
return match role_control {
|
||||
Err(mut msg) => {
|
||||
info!(key = &*key, "found cached get_role_access_control error");
|
||||
|
||||
// if retry_delay_ms is set change it to the remaining TTL
|
||||
replace_retry_delay_ms(&mut msg, |_| ttl.as_millis() as u64);
|
||||
|
||||
Err(GetAuthInfoError::ApiError(ControlPlaneError::Message(msg)))
|
||||
}
|
||||
Ok(role_control) => {
|
||||
debug!(key = &*key, "found cached role access control");
|
||||
Ok(role_control)
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
let auth_info = self.do_get_auth_req(ctx, endpoint, role).await?;
|
||||
|
||||
let control = EndpointAccessControl {
|
||||
allowed_ips: Arc::new(auth_info.allowed_ips),
|
||||
allowed_vpce: Arc::new(auth_info.allowed_vpc_endpoint_ids),
|
||||
flags: auth_info.access_blocker_flags,
|
||||
rate_limits: auth_info.rate_limits,
|
||||
};
|
||||
let role_control = RoleAccessControl {
|
||||
secret: auth_info.secret,
|
||||
};
|
||||
|
||||
if let Some(project_id) = auth_info.project_id {
|
||||
let normalized_ep_int = normalized_ep.into();
|
||||
|
||||
self.caches.project_info.insert_endpoint_access(
|
||||
auth_info.account_id,
|
||||
project_id,
|
||||
normalized_ep_int,
|
||||
role.into(),
|
||||
control,
|
||||
role_control.clone(),
|
||||
);
|
||||
ctx.set_project_id(project_id);
|
||||
}
|
||||
|
||||
Ok(role_control)
|
||||
self.get_and_cache_auth_info(ctx, endpoint, role, &key, |_, role_control| {
|
||||
role_control.clone()
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all)]
|
||||
@@ -330,38 +381,30 @@ impl super::ControlPlaneApi for NeonControlPlaneClient {
|
||||
endpoint: &EndpointId,
|
||||
role: &RoleName,
|
||||
) -> Result<EndpointAccessControl, GetAuthInfoError> {
|
||||
let normalized_ep = &endpoint.normalize();
|
||||
if let Some(control) = self.caches.project_info.get_endpoint_access(normalized_ep) {
|
||||
return Ok(control);
|
||||
let key = endpoint.normalize();
|
||||
|
||||
if let Some((control, ttl)) = self.caches.project_info.get_endpoint_access_with_ttl(&key) {
|
||||
return match control {
|
||||
Err(mut msg) => {
|
||||
info!(
|
||||
key = &*key,
|
||||
"found cached get_endpoint_access_control error"
|
||||
);
|
||||
|
||||
// if retry_delay_ms is set change it to the remaining TTL
|
||||
replace_retry_delay_ms(&mut msg, |_| ttl.as_millis() as u64);
|
||||
|
||||
Err(GetAuthInfoError::ApiError(ControlPlaneError::Message(msg)))
|
||||
}
|
||||
Ok(control) => {
|
||||
debug!(key = &*key, "found cached endpoint access control");
|
||||
Ok(control)
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
let auth_info = self.do_get_auth_req(ctx, endpoint, role).await?;
|
||||
|
||||
let control = EndpointAccessControl {
|
||||
allowed_ips: Arc::new(auth_info.allowed_ips),
|
||||
allowed_vpce: Arc::new(auth_info.allowed_vpc_endpoint_ids),
|
||||
flags: auth_info.access_blocker_flags,
|
||||
rate_limits: auth_info.rate_limits,
|
||||
};
|
||||
let role_control = RoleAccessControl {
|
||||
secret: auth_info.secret,
|
||||
};
|
||||
|
||||
if let Some(project_id) = auth_info.project_id {
|
||||
let normalized_ep_int = normalized_ep.into();
|
||||
|
||||
self.caches.project_info.insert_endpoint_access(
|
||||
auth_info.account_id,
|
||||
project_id,
|
||||
normalized_ep_int,
|
||||
role.into(),
|
||||
control.clone(),
|
||||
role_control,
|
||||
);
|
||||
ctx.set_project_id(project_id);
|
||||
}
|
||||
|
||||
Ok(control)
|
||||
self.get_and_cache_auth_info(ctx, endpoint, role, &key, |control, _| control.clone())
|
||||
.await
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all)]
|
||||
@@ -390,13 +433,9 @@ impl super::ControlPlaneApi for NeonControlPlaneClient {
|
||||
info!(key = &*key, "found cached wake_compute error");
|
||||
|
||||
// if retry_delay_ms is set, reduce it by the amount of time it spent in cache
|
||||
if let Some(status) = &mut msg.status {
|
||||
if let Some(retry_info) = &mut status.details.retry_info {
|
||||
retry_info.retry_delay_ms = retry_info
|
||||
.retry_delay_ms
|
||||
.saturating_sub(created_at.elapsed().as_millis() as u64)
|
||||
}
|
||||
}
|
||||
replace_retry_delay_ms(&mut msg, |delay| {
|
||||
delay.saturating_sub(created_at.elapsed().as_millis() as u64)
|
||||
});
|
||||
|
||||
Err(WakeComputeError::ControlPlane(ControlPlaneError::Message(
|
||||
msg,
|
||||
@@ -478,6 +517,14 @@ impl super::ControlPlaneApi for NeonControlPlaneClient {
|
||||
}
|
||||
}
|
||||
|
||||
fn replace_retry_delay_ms(msg: &mut ControlPlaneErrorMessage, f: impl FnOnce(u64) -> u64) {
|
||||
if let Some(status) = &mut msg.status
|
||||
&& let Some(retry_info) = &mut status.details.retry_info
|
||||
{
|
||||
retry_info.retry_delay_ms = f(retry_info.retry_delay_ms);
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse http response body, taking status code into account.
|
||||
fn parse_body<T: for<'a> serde::Deserialize<'a>>(
|
||||
status: StatusCode,
|
||||
|
||||
@@ -52,7 +52,7 @@ impl ReportableError for ControlPlaneError {
|
||||
| Reason::EndpointNotFound
|
||||
| Reason::EndpointDisabled
|
||||
| Reason::BranchNotFound
|
||||
| Reason::InvalidEphemeralEndpointOptions => ErrorKind::User,
|
||||
| Reason::WrongLsnOrTimestamp => ErrorKind::User,
|
||||
|
||||
Reason::RateLimitExceeded => ErrorKind::ServiceRateLimit,
|
||||
|
||||
|
||||
@@ -107,7 +107,7 @@ pub(crate) struct ErrorInfo {
|
||||
// Schema could also have `metadata` field, but it's not structured. Skip it for now.
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, Deserialize, Default)]
|
||||
#[derive(Clone, Copy, Debug, Deserialize, Default, PartialEq, Eq)]
|
||||
pub(crate) enum Reason {
|
||||
/// RoleProtected indicates that the role is protected and the attempted operation is not permitted on protected roles.
|
||||
#[serde(rename = "ROLE_PROTECTED")]
|
||||
@@ -133,9 +133,9 @@ pub(crate) enum Reason {
|
||||
/// or that the subject doesn't have enough permissions to access the requested branch.
|
||||
#[serde(rename = "BRANCH_NOT_FOUND")]
|
||||
BranchNotFound,
|
||||
/// InvalidEphemeralEndpointOptions indicates that the specified LSN or timestamp are wrong.
|
||||
#[serde(rename = "INVALID_EPHEMERAL_OPTIONS")]
|
||||
InvalidEphemeralEndpointOptions,
|
||||
/// WrongLsnOrTimestamp indicates that the specified LSN or timestamp are wrong.
|
||||
#[serde(rename = "WRONG_LSN_OR_TIMESTAMP")]
|
||||
WrongLsnOrTimestamp,
|
||||
/// RateLimitExceeded indicates that the rate limit for the operation has been exceeded.
|
||||
#[serde(rename = "RATE_LIMIT_EXCEEDED")]
|
||||
RateLimitExceeded,
|
||||
@@ -205,7 +205,7 @@ impl Reason {
|
||||
| Reason::EndpointNotFound
|
||||
| Reason::EndpointDisabled
|
||||
| Reason::BranchNotFound
|
||||
| Reason::InvalidEphemeralEndpointOptions => false,
|
||||
| Reason::WrongLsnOrTimestamp => false,
|
||||
// we were asked to go away
|
||||
Reason::RateLimitExceeded
|
||||
| Reason::NonDefaultBranchComputeTimeExceeded
|
||||
@@ -257,19 +257,19 @@ pub(crate) struct GetEndpointAccessControl {
|
||||
pub(crate) rate_limits: EndpointRateLimitConfig,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Deserialize, Default)]
|
||||
#[derive(Copy, Clone, Deserialize, Default, Debug)]
|
||||
pub struct EndpointRateLimitConfig {
|
||||
pub connection_attempts: ConnectionAttemptsLimit,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Deserialize, Default)]
|
||||
#[derive(Copy, Clone, Deserialize, Default, Debug)]
|
||||
pub struct ConnectionAttemptsLimit {
|
||||
pub tcp: Option<LeakyBucketSetting>,
|
||||
pub ws: Option<LeakyBucketSetting>,
|
||||
pub http: Option<LeakyBucketSetting>,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Deserialize)]
|
||||
#[derive(Copy, Clone, Deserialize, Debug)]
|
||||
pub struct LeakyBucketSetting {
|
||||
pub rps: f64,
|
||||
pub burst: f64,
|
||||
|
||||
@@ -82,7 +82,7 @@ impl NodeInfo {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Default)]
|
||||
#[derive(Copy, Clone, Default, Debug)]
|
||||
pub(crate) struct AccessBlockerFlags {
|
||||
pub public_access_blocked: bool,
|
||||
pub vpc_access_blocked: bool,
|
||||
@@ -92,12 +92,12 @@ pub(crate) type NodeInfoCache =
|
||||
TimedLru<EndpointCacheKey, Result<NodeInfo, Box<ControlPlaneErrorMessage>>>;
|
||||
pub(crate) type CachedNodeInfo = Cached<&'static NodeInfoCache, NodeInfo>;
|
||||
|
||||
#[derive(Clone)]
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct RoleAccessControl {
|
||||
pub secret: Option<AuthSecret>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct EndpointAccessControl {
|
||||
pub allowed_ips: Arc<Vec<IpPattern>>,
|
||||
pub allowed_vpce: Arc<Vec<String>>,
|
||||
|
||||
@@ -1,5 +1,11 @@
|
||||
#![expect(
|
||||
clippy::ref_option_ref,
|
||||
reason = "generated from measured derived output"
|
||||
)]
|
||||
|
||||
use std::sync::{Arc, OnceLock};
|
||||
|
||||
use alloc_metrics::TrackedAllocator;
|
||||
use lasso::ThreadedRodeo;
|
||||
use measured::label::{
|
||||
FixedCardinalitySet, LabelGroupSet, LabelName, LabelSet, LabelValue, StaticLabelSet,
|
||||
@@ -11,26 +17,33 @@ use measured::{
|
||||
MetricGroup,
|
||||
};
|
||||
use metrics::{CounterPairAssoc, CounterPairVec, HyperLogLogVec};
|
||||
use tikv_jemallocator::Jemalloc;
|
||||
use tokio::time::{self, Instant};
|
||||
|
||||
use crate::control_plane::messages::ColdStartInfo;
|
||||
use crate::error::ErrorKind;
|
||||
|
||||
pub type Alloc = TrackedAllocator<Jemalloc, MemoryContext>;
|
||||
|
||||
#[derive(MetricGroup)]
|
||||
#[metric(new(thread_pool: Arc<ThreadPoolMetrics>))]
|
||||
#[metric(new(thread_pool: Arc<ThreadPoolMetrics>, alloc: Option<&'static Alloc>))]
|
||||
pub struct Metrics {
|
||||
#[metric(namespace = "proxy")]
|
||||
#[metric(init = ProxyMetrics::new(thread_pool))]
|
||||
pub proxy: ProxyMetrics,
|
||||
|
||||
#[metric(namespace = "alloc")]
|
||||
#[metric(init = alloc)]
|
||||
pub alloc: Option<&'static Alloc>,
|
||||
|
||||
#[metric(namespace = "wake_compute_lock")]
|
||||
pub wake_compute_lock: ApiLockMetrics,
|
||||
}
|
||||
|
||||
static SELF: OnceLock<Metrics> = OnceLock::new();
|
||||
impl Metrics {
|
||||
pub fn install(thread_pool: Arc<ThreadPoolMetrics>) {
|
||||
let mut metrics = Metrics::new(thread_pool);
|
||||
pub fn install(thread_pool: Arc<ThreadPoolMetrics>, alloc: Option<&'static Alloc>) {
|
||||
let mut metrics = Metrics::new(thread_pool, alloc);
|
||||
|
||||
metrics.proxy.errors_total.init_all_dense();
|
||||
metrics.proxy.redis_errors_total.init_all_dense();
|
||||
@@ -45,7 +58,7 @@ impl Metrics {
|
||||
|
||||
pub fn get() -> &'static Self {
|
||||
#[cfg(test)]
|
||||
return SELF.get_or_init(|| Metrics::new(Arc::new(ThreadPoolMetrics::new(0))));
|
||||
return SELF.get_or_init(|| Metrics::new(Arc::new(ThreadPoolMetrics::new(0)), None));
|
||||
|
||||
#[cfg(not(test))]
|
||||
SELF.get()
|
||||
@@ -660,3 +673,9 @@ pub struct ThreadPoolMetrics {
|
||||
#[metric(init = CounterVec::with_label_set(ThreadPoolWorkers(workers)))]
|
||||
pub worker_task_skips_total: CounterVec<ThreadPoolWorkers>,
|
||||
}
|
||||
|
||||
#[derive(FixedCardinalityLabel, Clone, Copy, Debug)]
|
||||
#[label(singleton = "context")]
|
||||
pub enum MemoryContext {
|
||||
Root,
|
||||
}
|
||||
|
||||
2
vendor/postgres-v14
vendored
2
vendor/postgres-v14
vendored
Submodule vendor/postgres-v14 updated: ac3c460e01...47304b9215
2
vendor/postgres-v15
vendored
2
vendor/postgres-v15
vendored
Submodule vendor/postgres-v15 updated: 24313bf8f3...cef72d5308
2
vendor/postgres-v16
vendored
2
vendor/postgres-v16
vendored
Submodule vendor/postgres-v16 updated: 51194dc5ce...e9db1ff5a6
2
vendor/postgres-v17
vendored
2
vendor/postgres-v17
vendored
Submodule vendor/postgres-v17 updated: eac5279cd1...a50d80c750
8
vendor/revisions.json
vendored
8
vendor/revisions.json
vendored
@@ -1,18 +1,18 @@
|
||||
{
|
||||
"v17": [
|
||||
"17.5",
|
||||
"eac5279cd147d4086e0eb242198aae2f4b766d7b"
|
||||
"a50d80c7507e8ae9fc37bf1869051cf2d51370ab"
|
||||
],
|
||||
"v16": [
|
||||
"16.9",
|
||||
"51194dc5ce2e3523068d8607852e6c3125a17e58"
|
||||
"e9db1ff5a6f3ca18f626ba3d62ab475e6c688a96"
|
||||
],
|
||||
"v15": [
|
||||
"15.13",
|
||||
"24313bf8f3de722968a2fdf764de7ef77ed64f06"
|
||||
"cef72d5308ddce3795a9043fcd94f8849f7f4800"
|
||||
],
|
||||
"v14": [
|
||||
"14.18",
|
||||
"ac3c460e01a31f11fb52fd8d8e88e60f0e1069b4"
|
||||
"47304b921555b3f33eb3b49daada3078e774cfd7"
|
||||
]
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user