create memory context allocator tracking

This commit is contained in:
Conrad Ludgate
2025-07-20 17:08:50 +01:00
parent 791b5d736b
commit 40f5b3e8df
13 changed files with 1088 additions and 16 deletions

View File

@@ -30,6 +30,7 @@ workspace-members = [
"vm_monitor",
# All of these exist in libs and are not usually built independently.
# Putting workspace hack there adds a bottleneck for cargo builds.
"alloc-metrics",
"compute_api",
"consumption_metrics",
"desim",

Cargo.lock generated
View File

@@ -61,6 +61,14 @@ dependencies = [
"equator",
]
[[package]]
name = "alloc-metrics"
version = "0.1.0"
dependencies = [
"measured",
"metrics",
]
[[package]]
name = "allocator-api2"
version = "0.2.16"
@@ -5301,6 +5309,7 @@ name = "proxy"
version = "0.1.0"
dependencies = [
"ahash",
"alloc-metrics",
"anyhow",
"arc-swap",
"assert-json-diff",

View File

@@ -253,6 +253,7 @@ azure_storage = { git = "https://github.com/neondatabase/azure-sdk-for-rust.git"
azure_storage_blobs = { git = "https://github.com/neondatabase/azure-sdk-for-rust.git", branch = "neon", default-features = false, features = ["enable_reqwest_rustls"] }
## Local libraries
alloc-metrics = { version = "0.1", path = "./libs/alloc-metrics/" }
compute_api = { version = "0.1", path = "./libs/compute_api/" }
consumption_metrics = { version = "0.1", path = "./libs/consumption_metrics/" }
desim = { version = "0.1", path = "./libs/desim" }

View File

@@ -0,0 +1,9 @@
[package]
name = "alloc-metrics"
version = "0.1.0"
edition.workspace = true
license.workspace = true
[dependencies]
metrics.workspace = true
measured.workspace = true

View File

@@ -0,0 +1,394 @@
//! Tagged allocator measurements.
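//!
//! A minimal usage sketch, assuming a dense, non-panicking label type (the
//! `MemoryContext` enum and `worker` function below are illustrative and
//! mirror the tests at the bottom of this file):
//!
//! ```ignore
//! use std::alloc::System;
//!
//! use alloc_metrics::TrackedAllocator;
//! use measured::FixedCardinalityLabel;
//!
//! #[derive(FixedCardinalityLabel, Clone, Copy, Debug)]
//! #[label(singleton = "memory_context")]
//! enum MemoryContext {
//!     Root,
//!     Worker,
//! }
//!
//! // Safety: the derived `FixedCardinalityLabel` is dense and does not panic.
//! #[global_allocator]
//! static GLOBAL: TrackedAllocator<System, MemoryContext> =
//!     unsafe { TrackedAllocator::new(System, MemoryContext::Root) };
//!
//! fn worker() {
//!     // Give this thread its own low-contention counters.
//!     GLOBAL.register_thread();
//!     // Attribute allocations on this thread to `Worker` until the guard drops.
//!     let _scope = GLOBAL.scope(MemoryContext::Worker);
//! }
//! ```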
mod thread_local;
use std::{
alloc::{GlobalAlloc, Layout},
cell::Cell,
marker::PhantomData,
sync::{
OnceLock,
atomic::{AtomicU64, Ordering::Relaxed},
},
};
use measured::{
FixedCardinalityLabel, LabelGroup, MetricGroup,
label::StaticLabelSet,
metric::{MetricEncoding, counter::CounterState, group::Encoding, name::MetricName},
};
use metrics::{CounterPairAssoc, CounterPairVec, MeasuredCounterPairState};
use thread_local::ThreadLocal;
type AllocCounter<T> = CounterPairVec<AllocPair<T>>;
pub struct TrackedAllocator<A, T: 'static + Send + Sync + FixedCardinalityLabel + LabelGroup> {
inner: A,
/// Potentially high-contention fallback used if the thread was not registered.
default_counters: MeasuredCounterPairState,
/// Default tag to use if this thread is not registered.
default_tag: T,
/// Current memory context for this thread.
thread_scope: OnceLock<ThreadLocal<Cell<T>>>,
/// Per-thread state containing low-contention counters for faster allocations.
thread_state: OnceLock<ThreadLocal<ThreadState<T>>>,
/// Where thread alloc data is eventually saved, even if threads are shut down.
global: OnceLock<AllocCounter<T>>,
}
impl<A, T> TrackedAllocator<A, T>
where
T: 'static + Send + Sync + FixedCardinalityLabel + LabelGroup,
{
/// # Safety
///
/// [`FixedCardinalityLabel`] must be implemented correctly, fully dense, and must not panic.
pub const unsafe fn new(alloc: A, default: T) -> Self {
TrackedAllocator {
inner: alloc,
default_tag: default,
default_counters: MeasuredCounterPairState {
inc: CounterState {
count: AtomicU64::new(0),
},
dec: CounterState {
count: AtomicU64::new(0),
},
},
thread_scope: OnceLock::new(),
thread_state: OnceLock::new(),
global: OnceLock::new(),
}
}
/// Registers the current thread so its allocations use low-contention thread-local counters.
pub fn register_thread(&'static self) {
self.register_thread_inner();
}
pub fn scope(&'static self, tag: T) -> AllocScope<'static, T> {
let cell = self.register_thread_inner();
let last = cell.replace(tag);
AllocScope { cell, last }
}
fn register_thread_inner(&'static self) -> &'static Cell<T> {
self.thread_state
.get_or_init(ThreadLocal::new)
.get_or(|| ThreadState {
counters: CounterPairVec::dense(),
global: self.global.get_or_init(CounterPairVec::dense),
});
self.thread_scope
.get_or_init(ThreadLocal::new)
.get_or(|| Cell::new(self.default_tag))
}
fn current_counters_alloc_safe(&self) -> Option<&AllocCounter<T>> {
// We are being very careful here to not allocate or panic.
self.thread_state
.get()
.and_then(ThreadLocal::get)
.map(|s| &s.counters)
.or_else(|| self.global.get())
}
fn current_tag_alloc_safe(&self) -> T {
// We are being very careful here to not allocate or panic.
self.thread_scope
.get()
.and_then(ThreadLocal::get)
.map_or(self.default_tag, Cell::get)
}
}
impl<A, T> TrackedAllocator<A, T>
where
T: 'static + Send + Sync + FixedCardinalityLabel + LabelGroup,
{
unsafe fn alloc_inner(&self, layout: Layout, alloc: impl FnOnce(Layout) -> *mut u8) -> *mut u8 {
let Ok((tagged_layout, tag_offset)) = layout.extend(Layout::new::<T>()) else {
return std::ptr::null_mut();
};
let tagged_layout = tagged_layout.pad_to_align();
// Safety: The layout is not zero-sized.
let ptr = alloc(tagged_layout);
// allocation failed.
if ptr.is_null() {
return ptr;
}
let tag = self.current_tag_alloc_safe();
// Allocation successful. Write our tag
// Safety: tag_offset is inbounds of the ptr
unsafe { ptr.add(tag_offset).cast::<T>().write(tag) }
if let Some(counters) = self.current_counters_alloc_safe() {
// During `Self::new`, the caller has guaranteed that tag encoding will not panic.
counters.inc_by(tag, layout.size() as u64);
} else {
self.default_counters.inc_by(layout.size() as u64);
}
ptr
}
}
// We will tag our allocation by adding `T` to the end of the layout.
// This is ok only as long as it does not overflow. If it does, we will
// just fail the allocation by returning null.
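//
// For example (illustrative numbers): a 12-byte, 4-aligned request extended
// with a 1-byte tag becomes a 13-byte layout with the tag at offset 12, and
// `pad_to_align()` rounds the allocation up to 16 bytes so alignment is
// preserved. The counters, however, only record the caller's original
// `layout.size()`.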
//
// Safety: we will not unwind during alloc, and we will ensure layouts are handled correctly.
unsafe impl<A, T> GlobalAlloc for TrackedAllocator<A, T>
where
A: GlobalAlloc,
T: 'static + Send + Sync + FixedCardinalityLabel + LabelGroup,
{
unsafe fn alloc(&self, layout: Layout) -> *mut u8 {
// Safety: same as caller
unsafe { self.alloc_inner(layout, |tagged_layout| self.inner.alloc(tagged_layout)) }
}
unsafe fn alloc_zeroed(&self, layout: Layout) -> *mut u8 {
// Safety: same as caller
unsafe {
self.alloc_inner(layout, |tagged_layout| {
self.inner.alloc_zeroed(tagged_layout)
})
}
}
unsafe fn realloc(&self, ptr: *mut u8, layout: Layout, new_size: usize) -> *mut u8 {
// SAFETY: the caller must ensure that the `new_size` does not overflow.
// `layout.align()` comes from a `Layout` and is thus guaranteed to be valid.
let new_layout = unsafe { Layout::from_size_align_unchecked(new_size, layout.align()) };
let Ok((new_tagged_layout, new_tag_offset)) = new_layout.extend(Layout::new::<T>()) else {
return std::ptr::null_mut();
};
let new_tagged_layout = new_tagged_layout.pad_to_align();
let Ok((tagged_layout, tag_offset)) = layout.extend(Layout::new::<T>()) else {
// Safety: This layout clearly did not match what was originally allocated,
// otherwise alloc() would have caught this error and returned null.
unsafe { std::hint::unreachable_unchecked() }
};
let tagged_layout = tagged_layout.pad_to_align();
// get the tag set during alloc
// Safety: tag_offset is inbounds of the ptr
let tag = unsafe { ptr.add(tag_offset).cast::<T>().read() };
// Safety: layout sizes are correct
let new_ptr = unsafe {
self.inner
.realloc(ptr, tagged_layout, new_tagged_layout.size())
};
// allocation failed.
if new_ptr.is_null() {
return new_ptr;
}
let new_tag = self.current_tag_alloc_safe();
// Allocation successful. Write our tag
// Safety: new_tag_offset is inbounds of the ptr
unsafe { new_ptr.add(new_tag_offset).cast::<T>().write(new_tag) }
if let Some(counters) = self.current_counters_alloc_safe() {
if tag.encode() == new_tag.encode() {
let diff = usize::abs_diff(new_layout.size(), layout.size()) as u64;
if new_layout.size() > layout.size() {
counters.inc_by(tag, diff);
} else {
counters.dec_by(tag, diff);
}
} else {
counters.inc_by(new_tag, new_layout.size() as u64);
counters.dec_by(tag, layout.size() as u64);
}
} else {
// no tag was registered at all, therefore both tags must be default.
let diff = usize::abs_diff(new_layout.size(), layout.size()) as u64;
if new_layout.size() > layout.size() {
self.default_counters.inc_by(diff);
} else {
self.default_counters.dec_by(diff);
}
}
new_ptr
}
unsafe fn dealloc(&self, ptr: *mut u8, layout: Layout) {
let Ok((tagged_layout, tag_offset)) = layout.extend(Layout::new::<T>()) else {
// Safety: This layout clearly did not match what was originally allocated,
// otherwise alloc() would have caught this error and returned null.
unsafe { std::hint::unreachable_unchecked() }
};
let tagged_layout = tagged_layout.pad_to_align();
// get the tag set during alloc
// Safety: tag_offset is inbounds of the ptr
let tag = unsafe { ptr.add(tag_offset).cast::<T>().read() };
if let Some(counters) = self.current_counters_alloc_safe() {
counters.dec_by(tag, layout.size() as u64);
} else {
// if tag is not default, then global would have been registered,
// therefore tag must be default.
self.default_counters.dec_by(layout.size() as u64);
}
// Safety: caller upholds contract for us
unsafe { self.inner.dealloc(ptr, tagged_layout) }
}
}
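/// Tag guard returned by [`TrackedAllocator::scope`]. Restores the previous
/// tag for the current thread when dropped.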
pub struct AllocScope<'a, T: FixedCardinalityLabel> {
cell: &'a Cell<T>,
last: T,
}
impl<'a, T: FixedCardinalityLabel> Drop for AllocScope<'a, T> {
fn drop(&mut self) {
self.cell.set(self.last);
}
}
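/// Associates the allocation counters with the `allocated_bytes` /
/// `deallocated_bytes` metric names; with a `memory_context` label (as in the
/// test below) this exports series such as
/// `allocated_bytes{memory_context="root"}`.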
struct AllocPair<T>(PhantomData<T>);
impl<T: FixedCardinalityLabel + LabelGroup> CounterPairAssoc for AllocPair<T> {
const INC_NAME: &'static MetricName = MetricName::from_str("allocated_bytes");
const DEC_NAME: &'static MetricName = MetricName::from_str("deallocated_bytes");
const INC_HELP: &'static str = "total number of bytes allocated";
const DEC_HELP: &'static str = "total number of bytes deallocated";
type LabelGroupSet = StaticLabelSet<T>;
}
struct ThreadState<T: 'static + FixedCardinalityLabel + LabelGroup> {
counters: AllocCounter<T>,
global: &'static AllocCounter<T>,
}
// Flush this thread's counters into the global counters on thread destruction, so no counts are lost.
impl<T: 'static + FixedCardinalityLabel + LabelGroup> Drop for ThreadState<T> {
fn drop(&mut self) {
// iterate over all labels
for tag in (0..T::cardinality()).map(T::decode) {
// load and reset the counts in the thread-local counters.
let id = self.counters.vec.with_labels(tag);
let mut m = self.counters.vec.get_metric_mut(id);
let inc = *m.inc.count.get_mut();
let dec = *m.dec.count.get_mut();
// add the counts into the global counters.
let id = self.global.vec.with_labels(tag);
let m = self.global.vec.get_metric(id);
m.inc.count.fetch_add(inc, Relaxed);
m.dec.count.fetch_add(dec, Relaxed);
}
}
}
impl<A, T, Enc> MetricGroup<Enc> for TrackedAllocator<A, T>
where
T: 'static + Send + Sync + FixedCardinalityLabel + LabelGroup,
Enc: Encoding,
CounterState: MetricEncoding<Enc>,
{
fn collect_group_into(&self, enc: &mut Enc) -> Result<(), Enc::Err> {
let global = self.global.get_or_init(CounterPairVec::dense);
// iterate over every registered thread's counters
for s in self.thread_state.get().into_iter().flat_map(|s| s.iter()) {
// iterate over all labels
for tag in (0..T::cardinality()).map(T::decode) {
let id = s.counters.vec.with_labels(tag);
sample(global, &s.counters.vec.get_metric(id), tag);
}
}
sample(global, &self.default_counters, self.default_tag);
global.collect_group_into(enc)
}
}
fn sample<T: FixedCardinalityLabel + LabelGroup>(
global: &AllocCounter<T>,
local: &MeasuredCounterPairState,
tag: T,
) {
// load and reset the counts in the thread-local counters.
let inc = local.inc.count.swap(0, Relaxed);
let dec = local.dec.count.swap(0, Relaxed);
// add the counts into the global counters.
let id = global.vec.with_labels(tag);
let m = global.vec.get_metric(id);
m.inc.count.fetch_add(inc, Relaxed);
m.dec.count.fetch_add(dec, Relaxed);
}
#[cfg(test)]
mod tests {
use std::alloc::{GlobalAlloc, Layout, System};
use measured::{FixedCardinalityLabel, MetricGroup, text::BufferedTextEncoder};
use crate::TrackedAllocator;
#[derive(FixedCardinalityLabel, Clone, Copy, Debug)]
#[label(singleton = "memory_context")]
pub enum MemoryContext {
Root,
Test,
}
#[test]
fn alloc() {
// Safety: `MemoryContext` upholds the safety requirements.
static GLOBAL: TrackedAllocator<System, MemoryContext> =
unsafe { TrackedAllocator::new(System, MemoryContext::Root) };
GLOBAL.register_thread();
let _test = GLOBAL.scope(MemoryContext::Test);
let ptr = unsafe { GLOBAL.alloc(Layout::for_value(&[0_i32])) };
let ptr = unsafe { GLOBAL.realloc(ptr, Layout::for_value(&[0_i32]), 8) };
drop(_test);
let ptr = unsafe { GLOBAL.realloc(ptr, Layout::for_value(&[0_i32, 1_i32]), 4) };
unsafe { GLOBAL.dealloc(ptr, Layout::for_value(&[0_i32])) };
let mut text = BufferedTextEncoder::new();
GLOBAL.collect_group_into(&mut text).unwrap();
let text = String::from_utf8(text.finish().into()).unwrap();
assert_eq!(
text,
r#"# HELP deallocated_bytes total number of bytes deallocated
# TYPE deallocated_bytes counter
deallocated_bytes{memory_context="root"} 4
deallocated_bytes{memory_context="test"} 8
# HELP allocated_bytes total number of bytes allocated
# TYPE allocated_bytes counter
allocated_bytes{memory_context="root"} 4
allocated_bytes{memory_context="test"} 8
"#
);
}
}

View File

@@ -0,0 +1,581 @@
//! A vendoring of `thread_local` 1.1.9 with some changes needed for TLS destructors
// Copyright 2017 Amanieu d'Antras
//
// Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or
// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
// http://opensource.org/licenses/MIT>, at your option. This file may not be
// copied, modified, or distributed except according to those terms.
use std::cell::UnsafeCell;
use std::fmt;
use std::iter::FusedIterator;
use std::mem::MaybeUninit;
use std::panic::UnwindSafe;
use std::ptr;
use std::sync::atomic::{AtomicBool, AtomicPtr, AtomicUsize, Ordering};
// Use usize::BITS once it has stabilized and the MSRV has been bumped.
#[cfg(target_pointer_width = "16")]
const POINTER_WIDTH: u8 = 16;
#[cfg(target_pointer_width = "32")]
const POINTER_WIDTH: u8 = 32;
#[cfg(target_pointer_width = "64")]
const POINTER_WIDTH: u8 = 64;
/// The total number of buckets stored in each thread local.
/// All buckets combined can hold up to `usize::MAX - 1` entries.
const BUCKETS: usize = (POINTER_WIDTH - 1) as usize;
/// Thread-local variable wrapper
///
/// See the [module-level documentation](index.html) for more.
pub(crate) struct ThreadLocal<T: Send> {
/// The buckets in the thread local. The nth bucket contains `2^n`
/// elements. Each bucket is lazily allocated.
buckets: [AtomicPtr<Entry<T>>; BUCKETS],
/// The number of values in the thread local. This can be less than the real number of values,
/// but is never more.
values: AtomicUsize,
}
struct Entry<T> {
present: AtomicBool,
value: UnsafeCell<MaybeUninit<T>>,
}
impl<T> Drop for Entry<T> {
fn drop(&mut self) {
if *self.present.get_mut() {
// safety: If present is true, then this value is allocated
// and we cannot touch it ever again after this function
unsafe { ptr::drop_in_place((*self.value.get()).as_mut_ptr()) }
}
}
}
// Safety: ThreadLocal is always Sync, even if T isn't
unsafe impl<T: Send> Sync for ThreadLocal<T> {}
impl<T: Send> Default for ThreadLocal<T> {
fn default() -> ThreadLocal<T> {
ThreadLocal::new()
}
}
impl<T: Send> Drop for ThreadLocal<T> {
fn drop(&mut self) {
// Free each non-null bucket
for (i, bucket) in self.buckets.iter_mut().enumerate() {
let bucket_ptr = *bucket.get_mut();
let this_bucket_size = 1 << i;
if bucket_ptr.is_null() {
continue;
}
// Safety: the bucket_ptr is allocated and the bucket size is correct.
unsafe { deallocate_bucket(bucket_ptr, this_bucket_size) };
}
}
}
impl<T: Send> ThreadLocal<T> {
/// Creates a new empty `ThreadLocal`.
pub const fn new() -> ThreadLocal<T> {
Self {
buckets: [const { AtomicPtr::new(ptr::null_mut()) }; BUCKETS],
values: AtomicUsize::new(0),
}
}
/// Returns the element for the current thread, if it exists.
pub fn get(&self) -> Option<&T> {
thread_id::get().and_then(|t| self.get_inner(t))
}
/// Returns the element for the current thread, or creates it if it doesn't
/// exist.
pub fn get_or<F>(&self, create: F) -> &T
where
F: FnOnce() -> T,
{
let thread = thread_id::get_or_init();
if let Some(val) = self.get_inner(thread) {
return val;
}
self.insert(thread, create())
}
fn get_inner(&self, thread: thread_id::Thread) -> Option<&T> {
// Safety: There would need to be isize::MAX+1 threads for the bucket to overflow.
let bucket_ptr = unsafe { self.buckets.get_unchecked(thread.bucket) };
let bucket_ptr = bucket_ptr.load(Ordering::Acquire);
if bucket_ptr.is_null() {
return None;
}
// Safety: the bucket always has enough capacity for this index.
// Safety: If present is true, then this entry is allocated and it is safe to read.
unsafe {
let entry = &*bucket_ptr.add(thread.index);
if entry.present.load(Ordering::Relaxed) {
Some(&*(&*entry.value.get()).as_ptr())
} else {
None
}
}
}
#[cold]
fn insert(&self, thread: thread_id::Thread, data: T) -> &T {
// Safety: There would need to be isize::MAX+1 threads for the bucket to overflow.
let bucket_atomic_ptr = unsafe { self.buckets.get_unchecked(thread.bucket) };
let bucket_ptr: *const _ = bucket_atomic_ptr.load(Ordering::Acquire);
// If the bucket doesn't already exist, we need to allocate it
let bucket_ptr = if bucket_ptr.is_null() {
let new_bucket = allocate_bucket(thread.bucket_size());
match bucket_atomic_ptr.compare_exchange(
ptr::null_mut(),
new_bucket,
Ordering::AcqRel,
Ordering::Acquire,
) {
Ok(_) => new_bucket,
// If the bucket value changed (from null), that means
// another thread stored a new bucket before we could,
// and we can free our bucket and use that one instead
Err(bucket_ptr) => {
// Safety: the bucket_ptr is allocated and the bucket size is correct.
unsafe { deallocate_bucket(new_bucket, thread.bucket_size()) }
bucket_ptr
}
}
} else {
bucket_ptr
};
// Insert the new element into the bucket
// Safety: the bucket always has enough capacity for this index.
let entry = unsafe { &*bucket_ptr.add(thread.index) };
let value_ptr = entry.value.get();
// Safety: present is false, so no other threads will be reading this,
// and it is owned by our thread, so no other threads will be writing to this.
unsafe { value_ptr.write(MaybeUninit::new(data)) };
entry.present.store(true, Ordering::Release);
self.values.fetch_add(1, Ordering::Release);
// Safety: present is true, so it is now safe to read.
unsafe { &*(&*value_ptr).as_ptr() }
}
/// Returns an iterator over the local values of all threads in unspecified
/// order.
///
/// This call can be done safely, as `T` is required to implement [`Sync`].
pub fn iter(&self) -> Iter<'_, T>
where
T: Sync,
{
Iter {
thread_local: self,
yielded: 0,
bucket: 0,
bucket_size: 1,
index: 0,
}
}
}
impl<'a, T: Send + Sync> IntoIterator for &'a ThreadLocal<T> {
type Item = &'a T;
type IntoIter = Iter<'a, T>;
fn into_iter(self) -> Self::IntoIter {
self.iter()
}
}
impl<T: Send + fmt::Debug> fmt::Debug for ThreadLocal<T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "ThreadLocal {{ local_data: {:?} }}", self.get())
}
}
impl<T: Send + UnwindSafe> UnwindSafe for ThreadLocal<T> {}
/// Iterator over the contents of a `ThreadLocal`.
#[derive(Debug)]
pub struct Iter<'a, T: Send + Sync> {
thread_local: &'a ThreadLocal<T>,
yielded: usize,
bucket: usize,
bucket_size: usize,
index: usize,
}
impl<'a, T: Send + Sync> Iterator for Iter<'a, T> {
type Item = &'a T;
fn next(&mut self) -> Option<Self::Item> {
while self.bucket < BUCKETS {
let bucket = self.thread_local.buckets[self.bucket].load(Ordering::Acquire);
if !bucket.is_null() {
while self.index < self.bucket_size {
// Safety: the bucket always has enough capacity for this index.
let entry = unsafe { &*bucket.add(self.index) };
self.index += 1;
if entry.present.load(Ordering::Acquire) {
self.yielded += 1;
// Safety: If present is true, then this entry is allocated and it is safe to read.
return Some(unsafe { &*(&*entry.value.get()).as_ptr() });
}
}
}
self.bucket_size <<= 1;
self.bucket += 1;
self.index = 0;
}
None
}
fn size_hint(&self) -> (usize, Option<usize>) {
let total = self.thread_local.values.load(Ordering::Acquire);
(total - self.yielded, None)
}
}
impl<T: Send + Sync> FusedIterator for Iter<'_, T> {}
fn allocate_bucket<T>(size: usize) -> *mut Entry<T> {
Box::into_raw(
(0..size)
.map(|_| Entry::<T> {
present: AtomicBool::new(false),
value: UnsafeCell::new(MaybeUninit::uninit()),
})
.collect(),
) as *mut _
}
unsafe fn deallocate_bucket<T>(bucket: *mut Entry<T>, size: usize) {
// Safety: caller must guarantee that bucket is allocated and of the correct size.
let _ = unsafe { Box::from_raw(std::slice::from_raw_parts_mut(bucket, size)) };
}
mod thread_id {
// Copyright 2017 Amanieu d'Antras
//
// Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or
// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
// http://opensource.org/licenses/MIT>, at your option. This file may not be
// copied, modified, or distributed except according to those terms.
use super::POINTER_WIDTH;
use std::cell::Cell;
use std::cmp::Reverse;
use std::collections::BinaryHeap;
use std::sync::Mutex;
/// Thread ID manager which allocates thread IDs. It attempts to aggressively
/// reuse thread IDs where possible to avoid cases where a ThreadLocal grows
/// indefinitely when it is used by many short-lived threads.
struct ThreadIdManager {
free_from: usize,
free_list: Option<BinaryHeap<Reverse<usize>>>,
}
impl ThreadIdManager {
const fn new() -> Self {
Self {
free_from: 0,
free_list: None,
}
}
fn alloc(&mut self) -> usize {
if let Some(id) = self.free_list.as_mut().and_then(|heap| heap.pop()) {
id.0
} else {
// `free_from` can't overflow as each thread takes up at least 2 bytes of memory and
// thus we can't even have `usize::MAX / 2 + 1` threads.
let id = self.free_from;
self.free_from += 1;
id
}
}
fn free(&mut self, id: usize) {
self.free_list
.get_or_insert_with(BinaryHeap::new)
.push(Reverse(id));
}
}
static THREAD_ID_MANAGER: Mutex<ThreadIdManager> = Mutex::new(ThreadIdManager::new());
/// Data which is unique to the current thread while it is running.
/// A thread ID may be reused after a thread exits.
#[derive(Clone, Copy)]
pub(super) struct Thread {
/// The bucket this thread's local storage will be in.
pub(super) bucket: usize,
/// The index into the bucket this thread's local storage is in.
pub(super) index: usize,
}
impl Thread {
pub(super) fn new(id: usize) -> Self {
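// Bucket `n` stores the 2^n ids starting at 2^n - 1, so e.g. id 5 lands in
// bucket 2 (ids 3..=6) at index 2, and id 19 lands in bucket 4 (ids 15..=30)
// at index 4, matching `test_thread` below.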
let bucket = usize::from(POINTER_WIDTH) - ((id + 1).leading_zeros() as usize) - 1;
let bucket_size = 1 << bucket;
let index = id - (bucket_size - 1);
Self { bucket, index }
}
/// The size of the bucket this thread's local storage will be in.
pub(super) fn bucket_size(&self) -> usize {
1 << self.bucket
}
}
// This is split into 2 thread-local variables so that we can check whether the
// thread is initialized without having to register a thread-local destructor.
//
// This makes the fast path smaller, and it is necessary for GlobalAlloc as we are not allowed
// to use thread locals with destructors during alloc.
thread_local! { static THREAD: Cell<Option<Thread>> = const { Cell::new(None) }; }
thread_local! { static THREAD_GUARD: ThreadGuard = const { ThreadGuard { id: Cell::new(0) } }; }
// Guard to ensure the thread ID is released on thread exit.
struct ThreadGuard {
// We keep a copy of the thread ID in the ThreadGuard: we can't
// reliably access THREAD in our Drop impl due to the unpredictable
// order of TLS destructors.
id: Cell<usize>,
}
impl Drop for ThreadGuard {
fn drop(&mut self) {
// Release the thread ID. Any further accesses to the thread ID
// will go through get_slow which will either panic or
// initialize a new ThreadGuard.
let _ = THREAD.try_with(|thread| thread.set(None));
THREAD_ID_MANAGER.lock().unwrap().free(self.id.get());
}
}
/// Returns a thread ID for the current thread.
#[inline]
pub(crate) fn get() -> Option<Thread> {
THREAD.with(|thread| thread.get())
}
/// Returns a thread ID for the current thread, allocating one if needed.
#[inline]
pub(crate) fn get_or_init() -> Thread {
THREAD.with(|thread| {
if let Some(thread) = thread.get() {
thread
} else {
get_slow(thread)
}
})
}
/// Out-of-line slow path for allocating a thread ID.
#[cold]
fn get_slow(thread: &Cell<Option<Thread>>) -> Thread {
let id = THREAD_ID_MANAGER.lock().unwrap().alloc();
let new = Thread::new(id);
thread.set(Some(new));
THREAD_GUARD.with(|guard| guard.id.set(id));
new
}
#[test]
fn test_thread() {
let thread = Thread::new(0);
assert_eq!(thread.bucket, 0);
assert_eq!(thread.bucket_size(), 1);
assert_eq!(thread.index, 0);
let thread = Thread::new(1);
assert_eq!(thread.bucket, 1);
assert_eq!(thread.bucket_size(), 2);
assert_eq!(thread.index, 0);
let thread = Thread::new(2);
assert_eq!(thread.bucket, 1);
assert_eq!(thread.bucket_size(), 2);
assert_eq!(thread.index, 1);
let thread = Thread::new(3);
assert_eq!(thread.bucket, 2);
assert_eq!(thread.bucket_size(), 4);
assert_eq!(thread.index, 0);
let thread = Thread::new(19);
assert_eq!(thread.bucket, 4);
assert_eq!(thread.bucket_size(), 16);
assert_eq!(thread.index, 4);
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::cell::RefCell;
use std::sync::Arc;
use std::sync::atomic::AtomicUsize;
use std::sync::atomic::Ordering::Relaxed;
use std::thread;
fn make_create() -> Arc<dyn Fn() -> usize + Send + Sync> {
let count = AtomicUsize::new(0);
Arc::new(move || count.fetch_add(1, Relaxed))
}
#[test]
fn same_thread() {
let create = make_create();
let tls = ThreadLocal::new();
assert_eq!(None, tls.get());
assert_eq!("ThreadLocal { local_data: None }", format!("{:?}", &tls));
assert_eq!(0, *tls.get_or(|| create()));
assert_eq!(Some(&0), tls.get());
assert_eq!(0, *tls.get_or(|| create()));
assert_eq!(Some(&0), tls.get());
assert_eq!(0, *tls.get_or(|| create()));
assert_eq!(Some(&0), tls.get());
assert_eq!("ThreadLocal { local_data: Some(0) }", format!("{:?}", &tls));
// tls.clear();
// assert_eq!(None, tls.get());
}
#[test]
fn different_thread() {
let create = make_create();
let tls = Arc::new(ThreadLocal::new());
assert_eq!(None, tls.get());
assert_eq!(0, *tls.get_or(|| create()));
assert_eq!(Some(&0), tls.get());
let tls2 = tls.clone();
let create2 = create.clone();
thread::spawn(move || {
assert_eq!(None, tls2.get());
assert_eq!(1, *tls2.get_or(|| create2()));
assert_eq!(Some(&1), tls2.get());
})
.join()
.unwrap();
assert_eq!(Some(&0), tls.get());
assert_eq!(0, *tls.get_or(|| create()));
}
#[test]
fn iter() {
let tls = Arc::new(ThreadLocal::new());
tls.get_or(|| Box::new(1));
let tls2 = tls.clone();
thread::spawn(move || {
tls2.get_or(|| Box::new(2));
let tls3 = tls2.clone();
thread::spawn(move || {
tls3.get_or(|| Box::new(3));
})
.join()
.unwrap();
drop(tls2);
})
.join()
.unwrap();
let tls = Arc::try_unwrap(tls).unwrap();
let mut v = tls.iter().map(|x| **x).collect::<Vec<i32>>();
v.sort_unstable();
assert_eq!(vec![1, 2, 3], v);
}
#[test]
fn miri_iter_soundness_check() {
let tls = Arc::new(ThreadLocal::new());
let _local = tls.get_or(|| Box::new(1));
let tls2 = tls.clone();
let join_1 = thread::spawn(move || {
let _tls = tls2.get_or(|| Box::new(2));
let iter = tls2.iter();
for item in iter {
println!("{item:?}");
}
});
let iter = tls.iter();
for item in iter {
println!("{item:?}");
}
join_1.join().ok();
}
#[test]
fn test_drop() {
let local = ThreadLocal::new();
struct Dropped(Arc<AtomicUsize>);
impl Drop for Dropped {
fn drop(&mut self) {
self.0.fetch_add(1, Relaxed);
}
}
let dropped = Arc::new(AtomicUsize::new(0));
local.get_or(|| Dropped(dropped.clone()));
assert_eq!(dropped.load(Relaxed), 0);
drop(local);
assert_eq!(dropped.load(Relaxed), 1);
}
#[test]
fn test_earlyreturn_buckets() {
struct Dropped(Arc<AtomicUsize>);
impl Drop for Dropped {
fn drop(&mut self) {
self.0.fetch_add(1, Relaxed);
}
}
let dropped = Arc::new(AtomicUsize::new(0));
// We use a high `id` here to guarantee that a lazily allocated bucket somewhere in the middle is used.
// Neither iteration nor `Drop` must return early when the lower-index buckets are still `null`.
let thread = thread_id::Thread::new(1234);
assert!(thread.bucket > 1);
let local = ThreadLocal::new();
local.insert(thread, Dropped(dropped.clone()));
let item = local.iter().next().unwrap();
assert_eq!(item.0.load(Relaxed), 0);
drop(local);
assert_eq!(dropped.load(Relaxed), 1);
}
#[test]
fn is_sync() {
fn foo<T: Sync>() {}
foo::<ThreadLocal<String>>();
foo::<ThreadLocal<RefCell<String>>>();
}
}

View File

@@ -478,7 +478,7 @@ pub trait CounterPairAssoc {
}
pub struct CounterPairVec<A: CounterPairAssoc> {
vec: measured::metric::MetricVec<MeasuredCounterPairState, A::LabelGroupSet>,
pub vec: measured::metric::MetricVec<MeasuredCounterPairState, A::LabelGroupSet>,
}
impl<A: CounterPairAssoc> Default for CounterPairVec<A>
@@ -492,6 +492,17 @@ where
}
}
impl<A: CounterPairAssoc> CounterPairVec<A>
where
A::LabelGroupSet: Default,
{
pub fn dense() -> Self {
Self {
vec: measured::metric::MetricVec::dense(),
}
}
}
impl<A: CounterPairAssoc> CounterPairVec<A> {
pub fn guard(
&self,
@@ -501,14 +512,27 @@ impl<A: CounterPairAssoc> CounterPairVec<A> {
self.vec.get_metric(id).inc.inc();
MeasuredCounterPairGuard { vec: &self.vec, id }
}
pub fn inc(&self, labels: <A::LabelGroupSet as LabelGroupSet>::Group<'_>) {
let id = self.vec.with_labels(labels);
self.vec.get_metric(id).inc.inc();
}
pub fn dec(&self, labels: <A::LabelGroupSet as LabelGroupSet>::Group<'_>) {
let id = self.vec.with_labels(labels);
self.vec.get_metric(id).dec.inc();
}
pub fn inc_by(&self, labels: <A::LabelGroupSet as LabelGroupSet>::Group<'_>, x: u64) {
let id = self.vec.with_labels(labels);
self.vec.get_metric(id).inc.inc_by(x);
}
pub fn dec_by(&self, labels: <A::LabelGroupSet as LabelGroupSet>::Group<'_>, x: u64) {
let id = self.vec.with_labels(labels);
self.vec.get_metric(id).dec.inc_by(x);
}
pub fn remove_metric(
&self,
labels: <A::LabelGroupSet as LabelGroupSet>::Group<'_>,
@@ -553,6 +577,24 @@ pub struct MeasuredCounterPairState {
pub dec: CounterState,
}
impl MeasuredCounterPairState {
pub fn inc(&self) {
self.inc.inc();
}
pub fn dec(&self) {
self.dec.inc();
}
pub fn inc_by(&self, x: u64) {
self.inc.inc_by(x);
}
pub fn dec_by(&self, x: u64) {
self.dec.inc_by(x);
}
}
impl measured::metric::MetricType for MeasuredCounterPairState {
type Metadata = ();
}

View File

@@ -10,6 +10,7 @@ testing = ["dep:tokio-postgres"]
[dependencies]
ahash.workspace = true
alloc-metrics.workspace = true
anyhow.workspace = true
arc-swap.workspace = true
async-compression.workspace = true

View File

@@ -1,11 +1,22 @@
use alloc_metrics::TrackedAllocator;
use proxy::binary::proxy::MemoryContext;
use tikv_jemallocator::Jemalloc;
#[global_allocator]
static GLOBAL: tikv_jemallocator::Jemalloc = tikv_jemallocator::Jemalloc;
// Safety: `MemoryContext` upholds the safety requirements.
static GLOBAL: TrackedAllocator<Jemalloc, MemoryContext> =
unsafe { TrackedAllocator::new(Jemalloc, MemoryContext::Root) };
#[allow(non_upper_case_globals)]
#[unsafe(export_name = "malloc_conf")]
pub static malloc_conf: &[u8] = b"prof:true,prof_active:true,lg_prof_sample:21\0";
#[tokio::main]
async fn main() -> anyhow::Result<()> {
proxy::binary::proxy::run().await
fn main() -> anyhow::Result<()> {
GLOBAL.register_thread();
tokio::runtime::Builder::new_multi_thread()
.enable_all()
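// Register every runtime worker thread so its allocations hit the
// low-contention thread-local counters rather than the shared fallback.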
.on_thread_start(|| GLOBAL.register_thread())
.build()
.expect("Failed building the Runtime")
.block_on(proxy::binary::proxy::run(&GLOBAL))
}

View File

@@ -111,7 +111,7 @@ pub async fn run() -> anyhow::Result<()> {
let _panic_hook_guard = utils::logging::replace_panic_hook_with_tracing_panic_hook();
let _sentry_guard = init_sentry(Some(GIT_VERSION.into()), &[]);
Metrics::install(Arc::new(ThreadPoolMetrics::new(0)));
Metrics::install(Arc::new(ThreadPoolMetrics::new(0)), None);
// TODO: refactor these to use labels
debug!("Version: {GIT_VERSION}");

View File

@@ -80,7 +80,7 @@ pub async fn run() -> anyhow::Result<()> {
let _panic_hook_guard = utils::logging::replace_panic_hook_with_tracing_panic_hook();
let _sentry_guard = init_sentry(Some(GIT_VERSION.into()), &[]);
Metrics::install(Arc::new(ThreadPoolMetrics::new(0)));
Metrics::install(Arc::new(ThreadPoolMetrics::new(0)), None);
let args = cli().get_matches();
let destination: String = args

View File

@@ -39,7 +39,8 @@ use crate::config::{
};
use crate::context::parquet::ParquetUploadArgs;
use crate::http::health_server::AppMetrics;
use crate::metrics::Metrics;
pub use crate::metrics::MemoryContext;
use crate::metrics::{Alloc, Metrics};
use crate::rate_limiter::{EndpointRateLimiter, RateBucketInfo, WakeComputeRateLimiter};
use crate::redis::connection_with_credentials_provider::ConnectionWithCredentialsProvider;
use crate::redis::kv_ops::RedisKVClient;
@@ -318,7 +319,7 @@ struct PgSniRouterArgs {
dest: Option<String>,
}
pub async fn run() -> anyhow::Result<()> {
pub async fn run(alloc: &'static Alloc) -> anyhow::Result<()> {
let _logging_guard = crate::logging::init().await?;
let _panic_hook_guard = utils::logging::replace_panic_hook_with_tracing_panic_hook();
let _sentry_guard = init_sentry(Some(GIT_VERSION.into()), &[]);
@@ -340,7 +341,7 @@ pub async fn run() -> anyhow::Result<()> {
};
let args = ProxyCliArgs::parse();
let config = build_config(&args)?;
let config = build_config(&args, alloc)?;
let auth_backend = build_auth_backend(&args)?;
match auth_backend {
@@ -589,9 +590,12 @@ pub async fn run() -> anyhow::Result<()> {
}
/// ProxyConfig is created at proxy startup, and lives forever.
fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
fn build_config(
args: &ProxyCliArgs,
alloc: &'static Alloc,
) -> anyhow::Result<&'static ProxyConfig> {
let thread_pool = ThreadPool::new(args.scram_thread_pool_size);
Metrics::install(thread_pool.metrics.clone());
Metrics::install(thread_pool.metrics.clone(), Some(alloc));
let tls_config = match (&args.tls_key, &args.tls_cert) {
(Some(key_path), Some(cert_path)) => Some(config::configure_tls(

View File

@@ -1,5 +1,11 @@
#![expect(
clippy::ref_option_ref,
reason = "generated from measured derived output"
)]
use std::sync::{Arc, OnceLock};
use alloc_metrics::TrackedAllocator;
use lasso::ThreadedRodeo;
use measured::label::{
FixedCardinalitySet, LabelGroupSet, LabelName, LabelSet, LabelValue, StaticLabelSet,
@@ -11,26 +17,33 @@ use measured::{
MetricGroup,
};
use metrics::{CounterPairAssoc, CounterPairVec, HyperLogLogVec};
use tikv_jemallocator::Jemalloc;
use tokio::time::{self, Instant};
use crate::control_plane::messages::ColdStartInfo;
use crate::error::ErrorKind;
pub type Alloc = TrackedAllocator<Jemalloc, MemoryContext>;
#[derive(MetricGroup)]
#[metric(new(thread_pool: Arc<ThreadPoolMetrics>))]
#[metric(new(thread_pool: Arc<ThreadPoolMetrics>, alloc: Option<&'static Alloc>))]
pub struct Metrics {
#[metric(namespace = "proxy")]
#[metric(init = ProxyMetrics::new(thread_pool))]
pub proxy: ProxyMetrics,
#[metric(namespace = "alloc")]
#[metric(init = alloc)]
pub alloc: Option<&'static Alloc>,
#[metric(namespace = "wake_compute_lock")]
pub wake_compute_lock: ApiLockMetrics,
}
static SELF: OnceLock<Metrics> = OnceLock::new();
impl Metrics {
pub fn install(thread_pool: Arc<ThreadPoolMetrics>) {
let mut metrics = Metrics::new(thread_pool);
pub fn install(thread_pool: Arc<ThreadPoolMetrics>, alloc: Option<&'static Alloc>) {
let mut metrics = Metrics::new(thread_pool, alloc);
metrics.proxy.errors_total.init_all_dense();
metrics.proxy.redis_errors_total.init_all_dense();
@@ -45,7 +58,7 @@ impl Metrics {
pub fn get() -> &'static Self {
#[cfg(test)]
return SELF.get_or_init(|| Metrics::new(Arc::new(ThreadPoolMetrics::new(0))));
return SELF.get_or_init(|| Metrics::new(Arc::new(ThreadPoolMetrics::new(0)), None));
#[cfg(not(test))]
SELF.get()
@@ -660,3 +673,9 @@ pub struct ThreadPoolMetrics {
#[metric(init = CounterVec::with_label_set(ThreadPoolWorkers(workers)))]
pub worker_task_skips_total: CounterVec<ThreadPoolWorkers>,
}
#[derive(FixedCardinalityLabel, Clone, Copy, Debug)]
#[label(singleton = "context")]
pub enum MemoryContext {
Root,
}