diff --git a/.config/hakari.toml b/.config/hakari.toml index 9991cd92b0..77d6119223 100644 --- a/.config/hakari.toml +++ b/.config/hakari.toml @@ -30,6 +30,7 @@ workspace-members = [ "vm_monitor", # All of these exist in libs and are not usually built independently. # Putting workspace hack there adds a bottleneck for cargo builds. + "alloc-metrics", "compute_api", "consumption_metrics", "desim", diff --git a/Cargo.lock b/Cargo.lock index 137b883a6d..8d5b288f67 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -61,6 +61,14 @@ dependencies = [ "equator", ] +[[package]] +name = "alloc-metrics" +version = "0.1.0" +dependencies = [ + "measured", + "metrics", +] + [[package]] name = "allocator-api2" version = "0.2.16" @@ -5301,6 +5309,7 @@ name = "proxy" version = "0.1.0" dependencies = [ "ahash", + "alloc-metrics", "anyhow", "arc-swap", "assert-json-diff", diff --git a/Cargo.toml b/Cargo.toml index 6d91262882..012e3fe4f4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -253,6 +253,7 @@ azure_storage = { git = "https://github.com/neondatabase/azure-sdk-for-rust.git" azure_storage_blobs = { git = "https://github.com/neondatabase/azure-sdk-for-rust.git", branch = "neon", default-features = false, features = ["enable_reqwest_rustls"] } ## Local libraries +alloc-metrics = { version = "0.1", path = "./libs/alloc-metrics/" } compute_api = { version = "0.1", path = "./libs/compute_api/" } consumption_metrics = { version = "0.1", path = "./libs/consumption_metrics/" } desim = { version = "0.1", path = "./libs/desim" } diff --git a/libs/alloc-metrics/Cargo.toml b/libs/alloc-metrics/Cargo.toml new file mode 100644 index 0000000000..96e0ded8d8 --- /dev/null +++ b/libs/alloc-metrics/Cargo.toml @@ -0,0 +1,9 @@ +[package] +name = "alloc-metrics" +version = "0.1.0" +edition.workspace = true +license.workspace = true + +[dependencies] +metrics.workspace = true +measured.workspace = true diff --git a/libs/alloc-metrics/src/lib.rs b/libs/alloc-metrics/src/lib.rs new file mode 100644 index 0000000000..21725f7816 --- /dev/null +++ b/libs/alloc-metrics/src/lib.rs @@ -0,0 +1,394 @@ +//! Tagged allocator measurements. + +mod thread_local; + +use std::{ + alloc::{GlobalAlloc, Layout}, + cell::Cell, + marker::PhantomData, + sync::{ + OnceLock, + atomic::{AtomicU64, Ordering::Relaxed}, + }, +}; + +use measured::{ + FixedCardinalityLabel, LabelGroup, MetricGroup, + label::StaticLabelSet, + metric::{MetricEncoding, counter::CounterState, group::Encoding, name::MetricName}, +}; +use metrics::{CounterPairAssoc, CounterPairVec, MeasuredCounterPairState}; +use thread_local::ThreadLocal; + +type AllocCounter = CounterPairVec>; + +pub struct TrackedAllocator { + inner: A, + + /// potentially high-content fallback if the thread was not registered. + default_counters: MeasuredCounterPairState, + /// Default tag to use if this thread is not registered. + default_tag: T, + + /// Current memory context for this thread. + thread_scope: OnceLock>>, + /// per thread state containing low contention counters for faster allocations. + thread_state: OnceLock>>, + + /// where thread alloc data is eventually saved to, even if threads are shutdown. + global: OnceLock>, +} + +impl TrackedAllocator +where + T: 'static + Send + Sync + FixedCardinalityLabel + LabelGroup, +{ + /// # Safety + /// + /// [`FixedCardinalityLabel`] must be implemented correctly, fully dense, and must not panic. 
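The safety contract above matters because the allocator uses `encode()` to index a dense counter table from inside `alloc`, where an out-of-range index or a panic cannot be tolerated. A minimal stand-in sketch (the `TagLabel` trait and `Context` enum below are illustrative, not the real `measured` API):

    use std::sync::atomic::{AtomicU64, Ordering::Relaxed};

    // Stand-in for `FixedCardinalityLabel`: encode/decode must round-trip
    // over 0..cardinality() and must never panic.
    trait TagLabel: Copy {
        fn cardinality() -> usize;
        fn encode(self) -> usize; // must return a value < cardinality()
        fn decode(i: usize) -> Self;
    }

    #[derive(Clone, Copy)]
    enum Context { Root, Query }

    impl TagLabel for Context {
        fn cardinality() -> usize { 2 }
        fn encode(self) -> usize {
            match self { Context::Root => 0, Context::Query => 1 }
        }
        fn decode(i: usize) -> Self {
            if i == 0 { Context::Root } else { Context::Query }
        }
    }

    // Inside the allocator the encoded tag indexes a dense table of counters; a
    // sparse or panicking `encode` would corrupt accounting or abort in `alloc`.
    fn charge<T: TagLabel>(counters: &[AtomicU64], tag: T, bytes: u64) {
        counters[tag.encode()].fetch_add(bytes, Relaxed);
    }

    fn main() {
        let counters = [AtomicU64::new(0), AtomicU64::new(0)];
        charge(&counters, Context::Query, 128);
        assert_eq!(counters[Context::Query.encode()].load(Relaxed), 128);
        assert_eq!(Context::cardinality(), 2);
        assert!(matches!(Context::decode(1), Context::Query));
    }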
+ pub const unsafe fn new(alloc: A, default: T) -> Self { + TrackedAllocator { + inner: alloc, + default_tag: default, + default_counters: MeasuredCounterPairState { + inc: CounterState { + count: AtomicU64::new(0), + }, + dec: CounterState { + count: AtomicU64::new(0), + }, + }, + thread_scope: OnceLock::new(), + thread_state: OnceLock::new(), + global: OnceLock::new(), + } + } + + /// Allocations + pub fn register_thread(&'static self) { + self.register_thread_inner(); + } + + pub fn scope(&'static self, tag: T) -> AllocScope<'static, T> { + let cell = self.register_thread_inner(); + let last = cell.replace(tag); + AllocScope { cell, last } + } + + fn register_thread_inner(&'static self) -> &'static Cell { + self.thread_state + .get_or_init(ThreadLocal::new) + .get_or(|| ThreadState { + counters: CounterPairVec::dense(), + global: self.global.get_or_init(CounterPairVec::dense), + }); + + self.thread_scope + .get_or_init(ThreadLocal::new) + .get_or(|| Cell::new(self.default_tag)) + } + + fn current_counters_alloc_safe(&self) -> Option<&AllocCounter> { + // We are being very careful here to not allocate or panic. + self.thread_state + .get() + .and_then(ThreadLocal::get) + .map(|s| &s.counters) + .or_else(|| self.global.get()) + } + + fn current_tag_alloc_safe(&self) -> T { + // We are being very careful here to not allocate or panic. + self.thread_scope + .get() + .and_then(ThreadLocal::get) + .map_or(self.default_tag, Cell::get) + } +} + +impl TrackedAllocator +where + T: 'static + Send + Sync + FixedCardinalityLabel + LabelGroup, +{ + unsafe fn alloc_inner(&self, layout: Layout, alloc: impl FnOnce(Layout) -> *mut u8) -> *mut u8 { + let Ok((tagged_layout, tag_offset)) = layout.extend(Layout::new::()) else { + return std::ptr::null_mut(); + }; + let tagged_layout = tagged_layout.pad_to_align(); + + // Safety: The layout is not zero-sized. + let ptr = alloc(tagged_layout); + + // allocation failed. + if ptr.is_null() { + return ptr; + } + + let tag = self.current_tag_alloc_safe(); + + // Allocation successful. Write our tag + // Safety: tag_offset is inbounds of the ptr + unsafe { ptr.add(tag_offset).cast::().write(tag) } + + if let Some(counters) = self.current_counters_alloc_safe() { + // During `Self::new`, the caller has guaranteed that tag encoding will not panic. + counters.inc_by(tag, layout.size() as u64); + } else { + self.default_counters.inc_by(layout.size() as u64); + } + + ptr + } +} + +// We will tag our allocation by adding `T` to the end of the layout. +// This is ok only as long as it does not overflow. If it does, we will +// just fail the allocation by returning null. +// +// Safety: we will not unwind during alloc, and we will ensure layouts are handled correctly. +unsafe impl GlobalAlloc for TrackedAllocator +where + A: GlobalAlloc, + T: 'static + Send + Sync + FixedCardinalityLabel + LabelGroup, +{ + unsafe fn alloc(&self, layout: Layout) -> *mut u8 { + // safety: same as caller + unsafe { self.alloc_inner(layout, |tagged_layout| self.inner.alloc(tagged_layout)) } + } + + unsafe fn alloc_zeroed(&self, layout: Layout) -> *mut u8 { + // safety: same as caller + unsafe { + self.alloc_inner(layout, |tagged_layout| { + self.inner.alloc_zeroed(tagged_layout) + }) + } + } + + unsafe fn realloc(&self, ptr: *mut u8, layout: Layout, new_size: usize) -> *mut u8 { + // SAFETY: the caller must ensure that the `new_size` does not overflow. + // `layout.align()` comes from a `Layout` and is thus guaranteed to be valid. 
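`alloc_inner` hides the tag by growing every allocation: the requested layout is extended with room for a `T`, and the returned offset is where the tag is written and later read back in `realloc`/`dealloc`. A std-only illustration of that layout math, with a made-up `Tag` enum:

    use std::alloc::Layout;

    #[derive(Clone, Copy)]
    enum Tag { Root, Query }

    fn main() {
        // Layout the caller asked for: [u64; 4] = 32 bytes, align 8.
        let requested = Layout::array::<u64>(4).unwrap();

        // Extend it with room for one `Tag` at the end, then pad back to alignment.
        let (tagged, tag_offset) = requested.extend(Layout::new::<Tag>()).unwrap();
        let tagged = tagged.pad_to_align();

        // `tag_offset` is where alloc() writes the current tag and where
        // realloc()/dealloc() later read it back.
        assert_eq!(requested.size(), 32);
        assert_eq!(tag_offset, 32);
        assert_eq!(tagged.size(), 40); // 32 user bytes + 1 tag byte, padded to align 8

        let _ = (Tag::Root, Tag::Query); // only the layout math matters here
    }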
+ let new_layout = unsafe { Layout::from_size_align_unchecked(new_size, layout.align()) }; + + let Ok((new_tagged_layout, new_tag_offset)) = new_layout.extend(Layout::new::()) else { + return std::ptr::null_mut(); + }; + let new_tagged_layout = new_tagged_layout.pad_to_align(); + + let Ok((tagged_layout, tag_offset)) = layout.extend(Layout::new::()) else { + // Safety: This layout clearly did not match what was originally allocated, + // otherwise alloc() would have caught this error and returned null. + unsafe { std::hint::unreachable_unchecked() } + }; + let tagged_layout = tagged_layout.pad_to_align(); + + // get the tag set during alloc + // Safety: tag_offset is inbounds of the ptr + let tag = unsafe { ptr.add(tag_offset).cast::().read() }; + + // Safety: layout sizes are correct + let new_ptr = unsafe { + self.inner + .realloc(ptr, tagged_layout, new_tagged_layout.size()) + }; + + // allocation failed. + if new_ptr.is_null() { + return new_ptr; + } + + let new_tag = self.current_tag_alloc_safe(); + + // Allocation successful. Write our tag + // Safety: new_tag_offset is inbounds of the ptr + unsafe { new_ptr.add(new_tag_offset).cast::().write(new_tag) } + + if let Some(counters) = self.current_counters_alloc_safe() { + if tag.encode() == new_tag.encode() { + let diff = usize::abs_diff(new_layout.size(), layout.size()) as u64; + if new_layout.size() > layout.size() { + counters.inc_by(tag, diff); + } else { + counters.dec_by(tag, diff); + } + } else { + counters.inc_by(new_tag, new_layout.size() as u64); + counters.dec_by(tag, layout.size() as u64); + } + } else { + // no tag was registered at all, therefore both tags must be default. + let diff = usize::abs_diff(new_layout.size(), layout.size()) as u64; + if new_layout.size() > layout.size() { + self.default_counters.inc_by(diff); + } else { + self.default_counters.dec_by(diff); + } + } + + new_ptr + } + + unsafe fn dealloc(&self, ptr: *mut u8, layout: Layout) { + let Ok((tagged_layout, tag_offset)) = layout.extend(Layout::new::()) else { + // Safety: This layout clearly did not match what was originally allocated, + // otherwise alloc() would have caught this error and returned null. + unsafe { std::hint::unreachable_unchecked() } + }; + let tagged_layout = tagged_layout.pad_to_align(); + + // get the tag set during alloc + // Safety: tag_offset is inbounds of the ptr + let tag = unsafe { ptr.add(tag_offset).cast::().read() }; + + if let Some(counters) = self.current_counters_alloc_safe() { + counters.dec_by(tag, layout.size() as u64); + } else { + // if tag is not default, then global would have been registered, + // therefore tag must be default. 
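The branchy accounting in `realloc` above reduces to: if the tag did not change, only the size delta moves one counter; if it did change, the full new size is charged to the new tag and the full old size is credited back to the old tag. The same rule as a pure function (the `Pair` type is a simplified stand-in for the metric counters):

    #[derive(Default, Debug, PartialEq)]
    struct Pair {
        allocated: u64,
        deallocated: u64,
    }

    fn account_realloc(counters: &mut [Pair], old_tag: usize, new_tag: usize, old: u64, new: u64) {
        if old_tag == new_tag {
            // Same tag: only the delta moves.
            if new > old {
                counters[old_tag].allocated += new - old;
            } else {
                counters[old_tag].deallocated += old - new;
            }
        } else {
            // Tag changed between alloc and realloc: move the whole sizes.
            counters[new_tag].allocated += new;
            counters[old_tag].deallocated += old;
        }
    }

    fn main() {
        let mut counters = [Pair::default(), Pair::default()];
        account_realloc(&mut counters, 0, 0, 4, 8); // grow under the same tag
        assert_eq!(counters[0], Pair { allocated: 4, deallocated: 0 });
        account_realloc(&mut counters, 0, 1, 8, 4); // shrink while switching tags
        assert_eq!(counters[0], Pair { allocated: 4, deallocated: 8 });
        assert_eq!(counters[1], Pair { allocated: 4, deallocated: 0 });
    }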
+ self.default_counters.dec_by(layout.size() as u64); + } + + // Safety: caller upholds contract for us + unsafe { self.inner.dealloc(ptr, tagged_layout) } + } +} + +pub struct AllocScope<'a, T: FixedCardinalityLabel> { + cell: &'a Cell, + last: T, +} + +impl<'a, T: FixedCardinalityLabel> Drop for AllocScope<'a, T> { + fn drop(&mut self) { + self.cell.set(self.last); + } +} + +struct AllocPair(PhantomData); + +impl CounterPairAssoc for AllocPair { + const INC_NAME: &'static MetricName = MetricName::from_str("allocated_bytes"); + const DEC_NAME: &'static MetricName = MetricName::from_str("deallocated_bytes"); + + const INC_HELP: &'static str = "total number of bytes allocated"; + const DEC_HELP: &'static str = "total number of bytes deallocated"; + + type LabelGroupSet = StaticLabelSet; +} + +struct ThreadState { + counters: AllocCounter, + global: &'static AllocCounter, +} + +// Ensure the counters are measured on thread destruction. +impl Drop for ThreadState { + fn drop(&mut self) { + // iterate over all labels + for tag in (0..T::cardinality()).map(T::decode) { + // load and reset the counts in the thread-local counters. + let id = self.counters.vec.with_labels(tag); + let mut m = self.counters.vec.get_metric_mut(id); + let inc = *m.inc.count.get_mut(); + let dec = *m.dec.count.get_mut(); + + // add the counts into the global counters. + let id = self.global.vec.with_labels(tag); + let m = self.global.vec.get_metric(id); + m.inc.count.fetch_add(inc, Relaxed); + m.dec.count.fetch_add(dec, Relaxed); + } + } +} + +impl MetricGroup for TrackedAllocator +where + T: 'static + Send + Sync + FixedCardinalityLabel + LabelGroup, + Enc: Encoding, + CounterState: MetricEncoding, +{ + fn collect_group_into(&self, enc: &mut Enc) -> Result<(), Enc::Err> { + let global = self.global.get_or_init(CounterPairVec::dense); + + // iterate over all counter threads + for s in self.thread_state.get().into_iter().flat_map(|s| s.iter()) { + // iterate over all labels + for tag in (0..T::cardinality()).map(T::decode) { + let id = s.counters.vec.with_labels(tag); + sample(global, &s.counters.vec.get_metric(id), tag); + } + } + + sample(global, &self.default_counters, self.default_tag); + + global.collect_group_into(enc) + } +} + +fn sample( + global: &AllocCounter, + local: &MeasuredCounterPairState, + tag: T, +) { + // load and reset the counts in the thread-local counters. + let inc = local.inc.count.swap(0, Relaxed); + let dec = local.dec.count.swap(0, Relaxed); + + // add the counts into the global counters. + let id = global.vec.with_labels(tag); + let m = global.vec.get_metric(id); + m.inc.count.fetch_add(inc, Relaxed); + m.dec.count.fetch_add(dec, Relaxed); +} + +#[cfg(test)] +mod tests { + use std::alloc::{GlobalAlloc, Layout, System}; + + use measured::{FixedCardinalityLabel, MetricGroup, text::BufferedTextEncoder}; + + use crate::TrackedAllocator; + + #[derive(FixedCardinalityLabel, Clone, Copy, Debug)] + #[label(singleton = "memory_context")] + pub enum MemoryContext { + Root, + Test, + } + + #[test] + fn alloc() { + // Safety: `MemoryContext` upholds the safety requirements. 
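`AllocScope` is a plain RAII guard: `scope()` swaps the thread's current tag in and the guard's `Drop` swaps the previous one back, so scopes nest naturally. The same shape in isolation, with a `&'static str` standing in for `T`:

    use std::cell::Cell;

    struct Scope<'a> {
        cell: &'a Cell<&'static str>,
        last: &'static str,
    }

    impl Drop for Scope<'_> {
        fn drop(&mut self) {
            self.cell.set(self.last);
        }
    }

    fn scope<'a>(cell: &'a Cell<&'static str>, tag: &'static str) -> Scope<'a> {
        Scope { cell, last: cell.replace(tag) }
    }

    fn main() {
        let current = Cell::new("root");
        {
            let _outer = scope(&current, "query");
            assert_eq!(current.get(), "query");
            {
                let _inner = scope(&current, "parse");
                assert_eq!(current.get(), "parse");
            }
            assert_eq!(current.get(), "query"); // the inner guard restored "query"
        }
        assert_eq!(current.get(), "root"); // the outer guard restored "root"
    }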
+ static GLOBAL: TrackedAllocator = + unsafe { TrackedAllocator::new(System, MemoryContext::Root) }; + + GLOBAL.register_thread(); + + let _test = GLOBAL.scope(MemoryContext::Test); + + let ptr = unsafe { GLOBAL.alloc(Layout::for_value(&[0_i32])) }; + let ptr = unsafe { GLOBAL.realloc(ptr, Layout::for_value(&[0_i32]), 8) }; + + drop(_test); + + let ptr = unsafe { GLOBAL.realloc(ptr, Layout::for_value(&[0_i32, 1_i32]), 4) }; + unsafe { GLOBAL.dealloc(ptr, Layout::for_value(&[0_i32])) }; + + let mut text = BufferedTextEncoder::new(); + GLOBAL.collect_group_into(&mut text).unwrap(); + let text = String::from_utf8(text.finish().into()).unwrap(); + assert_eq!( + text, + r#"# HELP deallocated_bytes total number of bytes deallocated +# TYPE deallocated_bytes counter +deallocated_bytes{memory_context="root"} 4 +deallocated_bytes{memory_context="test"} 8 + +# HELP allocated_bytes total number of bytes allocated +# TYPE allocated_bytes counter +allocated_bytes{memory_context="root"} 4 +allocated_bytes{memory_context="test"} 8 +"# + ); + } +} diff --git a/libs/alloc-metrics/src/thread_local.rs b/libs/alloc-metrics/src/thread_local.rs new file mode 100644 index 0000000000..49c37b1b50 --- /dev/null +++ b/libs/alloc-metrics/src/thread_local.rs @@ -0,0 +1,581 @@ +//! A vendoring of `thread_local` 1.1.9 with some changes needed for TLS destructors + +// Copyright 2017 Amanieu d'Antras +// +// Licensed under the Apache License, Version 2.0, or the MIT license , at your option. This file may not be +// copied, modified, or distributed except according to those terms. + +use std::cell::UnsafeCell; +use std::fmt; +use std::iter::FusedIterator; +use std::mem::MaybeUninit; +use std::panic::UnwindSafe; +use std::ptr; +use std::sync::atomic::{AtomicBool, AtomicPtr, AtomicUsize, Ordering}; + +// Use usize::BITS once it has stabilized and the MSRV has been bumped. +#[cfg(target_pointer_width = "16")] +const POINTER_WIDTH: u8 = 16; +#[cfg(target_pointer_width = "32")] +const POINTER_WIDTH: u8 = 32; +#[cfg(target_pointer_width = "64")] +const POINTER_WIDTH: u8 = 64; + +/// The total number of buckets stored in each thread local. +/// All buckets combined can hold up to `usize::MAX - 1` entries. +const BUCKETS: usize = (POINTER_WIDTH - 1) as usize; + +/// Thread-local variable wrapper +/// +/// See the [module-level documentation](index.html) for more. +pub(crate) struct ThreadLocal { + /// The buckets in the thread local. The nth bucket contains `2^n` + /// elements. Each bucket is lazily allocated. + buckets: [AtomicPtr>; BUCKETS], + + /// The number of values in the thread local. This can be less than the real number of values, + /// but is never more. 
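For reference, the expected text in the `alloc` test above follows directly from the alloc/realloc/dealloc rules; a worked trace of the four calls (sizes in bytes, labels abbreviated):

    alloc 4 under Test                 -> allocated{test}     += 4
    realloc 4 -> 8, tag still Test     -> allocated{test}     += 4   (same tag: delta only)
    drop the Test scope (back to Root)
    realloc 8 -> 4, stored tag Test,
                    current tag Root   -> allocated{root}     += 4   (tag changed: full sizes)
                                       -> deallocated{test}   += 8
    dealloc 4, stored tag is Root      -> deallocated{root}   += 4

which yields allocated root=4 / test=8 and deallocated root=4 / test=8, matching the encoded output.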
+ values: AtomicUsize, +} + +struct Entry { + present: AtomicBool, + value: UnsafeCell>, +} + +impl Drop for Entry { + fn drop(&mut self) { + if *self.present.get_mut() { + // safety: If present is true, then this value is allocated + // and we cannot touch it ever again after this function + unsafe { ptr::drop_in_place((*self.value.get()).as_mut_ptr()) } + } + } +} + +// Safety: ThreadLocal is always Sync, even if T isn't +unsafe impl Sync for ThreadLocal {} + +impl Default for ThreadLocal { + fn default() -> ThreadLocal { + ThreadLocal::new() + } +} + +impl Drop for ThreadLocal { + fn drop(&mut self) { + // Free each non-null bucket + for (i, bucket) in self.buckets.iter_mut().enumerate() { + let bucket_ptr = *bucket.get_mut(); + + let this_bucket_size = 1 << i; + + if bucket_ptr.is_null() { + continue; + } + + // Safety: the bucket_ptr is allocated and the bucket size is correct. + unsafe { deallocate_bucket(bucket_ptr, this_bucket_size) }; + } + } +} + +impl ThreadLocal { + /// Creates a new empty `ThreadLocal`. + pub const fn new() -> ThreadLocal { + Self { + buckets: [const { AtomicPtr::new(ptr::null_mut()) }; BUCKETS], + values: AtomicUsize::new(0), + } + } + + /// Returns the element for the current thread, if it exists. + pub fn get(&self) -> Option<&T> { + thread_id::get().and_then(|t| self.get_inner(t)) + } + + /// Returns the element for the current thread, or creates it if it doesn't + /// exist. + pub fn get_or(&self, create: F) -> &T + where + F: FnOnce() -> T, + { + let thread = thread_id::get_or_init(); + if let Some(val) = self.get_inner(thread) { + return val; + } + + self.insert(thread, create()) + } + + fn get_inner(&self, thread: thread_id::Thread) -> Option<&T> { + // Safety: There would need to be isize::MAX+1 threads for the bucket to overflow. + let bucket_ptr = unsafe { self.buckets.get_unchecked(thread.bucket) }; + let bucket_ptr = bucket_ptr.load(Ordering::Acquire); + if bucket_ptr.is_null() { + return None; + } + // Safety: the bucket always has enough capacity for this index. + // Safety: If present is true, then this entry is allocated and it is safe to read. + unsafe { + let entry = &*bucket_ptr.add(thread.index); + if entry.present.load(Ordering::Relaxed) { + Some(&*(&*entry.value.get()).as_ptr()) + } else { + None + } + } + } + + #[cold] + fn insert(&self, thread: thread_id::Thread, data: T) -> &T { + // Safety: There would need to be isize::MAX+1 threads for the bucket to overflow. + let bucket_atomic_ptr = unsafe { self.buckets.get_unchecked(thread.bucket) }; + let bucket_ptr: *const _ = bucket_atomic_ptr.load(Ordering::Acquire); + + // If the bucket doesn't already exist, we need to allocate it + let bucket_ptr = if bucket_ptr.is_null() { + let new_bucket = allocate_bucket(thread.bucket_size()); + + match bucket_atomic_ptr.compare_exchange( + ptr::null_mut(), + new_bucket, + Ordering::AcqRel, + Ordering::Acquire, + ) { + Ok(_) => new_bucket, + // If the bucket value changed (from null), that means + // another thread stored a new bucket before we could, + // and we can free our bucket and use that one instead + Err(bucket_ptr) => { + // Safety: the bucket_ptr is allocated and the bucket size is correct. + unsafe { deallocate_bucket(new_bucket, thread.bucket_size()) } + bucket_ptr + } + } + } else { + bucket_ptr + }; + + // Insert the new element into the bucket + // Safety: the bucket always has enough capacity for this index. 
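Each `Entry` publishes its value with the flag-plus-`MaybeUninit` pattern: write the value first, then set `present` with `Release`, so a reader that observes `present == true` with `Acquire` also observes the initialized value. A self-contained sketch of that pattern (simplified; the vendored `get_inner` relaxes the load on the owner-thread fast path):

    use std::cell::UnsafeCell;
    use std::mem::MaybeUninit;
    use std::sync::atomic::{AtomicBool, Ordering};

    struct Slot<T> {
        present: AtomicBool,
        value: UnsafeCell<MaybeUninit<T>>,
    }

    // Safety: access to `value` is synchronized through `present`.
    unsafe impl<T: Send + Sync> Sync for Slot<T> {}

    impl<T> Slot<T> {
        fn empty() -> Self {
            Slot {
                present: AtomicBool::new(false),
                value: UnsafeCell::new(MaybeUninit::uninit()),
            }
        }

        /// Called once, by the thread that owns this slot.
        fn publish(&self, value: T) {
            // Safety: only the owning thread writes, and only before `present` is set.
            unsafe { (*self.value.get()).write(value) };
            self.present.store(true, Ordering::Release);
        }

        fn get(&self) -> Option<&T> {
            if self.present.load(Ordering::Acquire) {
                // Safety: `present` was set with Release *after* the value was written.
                Some(unsafe { (*self.value.get()).assume_init_ref() })
            } else {
                None
            }
        }
    }

    fn main() {
        let slot: Slot<u32> = Slot::empty();
        assert!(slot.get().is_none());
        slot.publish(42);
        assert_eq!(slot.get(), Some(&42));
    }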
+ let entry = unsafe { &*bucket_ptr.add(thread.index) }; + let value_ptr = entry.value.get(); + // Safety: present is false, so no other threads will be reading this, + // and it is owned by our thread, so no other threads will be writing to this. + unsafe { value_ptr.write(MaybeUninit::new(data)) }; + entry.present.store(true, Ordering::Release); + + self.values.fetch_add(1, Ordering::Release); + + // Safety: present is true, so it is now safe to read. + unsafe { &*(&*value_ptr).as_ptr() } + } + + /// Returns an iterator over the local values of all threads in unspecified + /// order. + /// + /// This call can be done safely, as `T` is required to implement [`Sync`]. + pub fn iter(&self) -> Iter<'_, T> + where + T: Sync, + { + Iter { + thread_local: self, + yielded: 0, + bucket: 0, + bucket_size: 1, + index: 0, + } + } +} + +impl<'a, T: Send + Sync> IntoIterator for &'a ThreadLocal { + type Item = &'a T; + type IntoIter = Iter<'a, T>; + + fn into_iter(self) -> Self::IntoIter { + self.iter() + } +} + +impl fmt::Debug for ThreadLocal { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "ThreadLocal {{ local_data: {:?} }}", self.get()) + } +} + +impl UnwindSafe for ThreadLocal {} + +/// Iterator over the contents of a `ThreadLocal`. +#[derive(Debug)] +pub struct Iter<'a, T: Send + Sync> { + thread_local: &'a ThreadLocal, + yielded: usize, + bucket: usize, + bucket_size: usize, + index: usize, +} + +impl<'a, T: Send + Sync> Iterator for Iter<'a, T> { + type Item = &'a T; + fn next(&mut self) -> Option { + while self.bucket < BUCKETS { + let bucket = self.thread_local.buckets[self.bucket].load(Ordering::Acquire); + + if !bucket.is_null() { + while self.index < self.bucket_size { + // Safety: the bucket always has enough capacity for this index. + let entry = unsafe { &*bucket.add(self.index) }; + self.index += 1; + if entry.present.load(Ordering::Acquire) { + self.yielded += 1; + // Safety: If present is true, then this entry is allocated and it is safe to read. + return Some(unsafe { &*(&*entry.value.get()).as_ptr() }); + } + } + } + + self.bucket_size <<= 1; + self.bucket += 1; + self.index = 0; + } + + None + } + + fn size_hint(&self) -> (usize, Option) { + let total = self.thread_local.values.load(Ordering::Acquire); + (total - self.yielded, None) + } +} +impl FusedIterator for Iter<'_, T> {} + +fn allocate_bucket(size: usize) -> *mut Entry { + Box::into_raw( + (0..size) + .map(|_| Entry:: { + present: AtomicBool::new(false), + value: UnsafeCell::new(MaybeUninit::uninit()), + }) + .collect(), + ) as *mut _ +} + +unsafe fn deallocate_bucket(bucket: *mut Entry, size: usize) { + // Safety: caller must guarantee that bucket is allocated and of the correct size. + let _ = unsafe { Box::from_raw(std::slice::from_raw_parts_mut(bucket, size)) }; +} + +mod thread_id { + // Copyright 2017 Amanieu d'Antras + // + // Licensed under the Apache License, Version 2.0, or the MIT license , at your option. This file may not be + // copied, modified, or distributed except according to those terms. + + use super::POINTER_WIDTH; + use std::cell::Cell; + use std::cmp::Reverse; + use std::collections::BinaryHeap; + use std::sync::Mutex; + + /// Thread ID manager which allocates thread IDs. It attempts to aggressively + /// reuse thread IDs where possible to avoid cases where a ThreadLocal grows + /// indefinitely when it is used by many short-lived threads. 
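The ID manager described above hands freed IDs back out smallest-first, so a process with many short-lived threads keeps reusing low IDs instead of growing its storage forever. The reuse strategy in miniature (names are illustrative):

    use std::cmp::Reverse;
    use std::collections::BinaryHeap;

    #[derive(Default)]
    struct IdManager {
        next: usize,
        free_list: BinaryHeap<Reverse<usize>>, // min-heap of freed IDs
    }

    impl IdManager {
        fn alloc(&mut self) -> usize {
            if let Some(Reverse(id)) = self.free_list.pop() {
                id // smallest freed ID is reused first
            } else {
                let id = self.next;
                self.next += 1;
                id
            }
        }

        fn free(&mut self, id: usize) {
            self.free_list.push(Reverse(id));
        }
    }

    fn main() {
        let mut ids = IdManager::default();
        assert_eq!((ids.alloc(), ids.alloc(), ids.alloc()), (0, 1, 2));
        ids.free(1);
        ids.free(0);
        assert_eq!(ids.alloc(), 0); // reuse, smallest first
        assert_eq!(ids.alloc(), 1);
        assert_eq!(ids.alloc(), 3); // only then grow
    }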
+ struct ThreadIdManager { + free_from: usize, + free_list: Option>>, + } + + impl ThreadIdManager { + const fn new() -> Self { + Self { + free_from: 0, + free_list: None, + } + } + + fn alloc(&mut self) -> usize { + if let Some(id) = self.free_list.as_mut().and_then(|heap| heap.pop()) { + id.0 + } else { + // `free_from` can't overflow as each thread takes up at least 2 bytes of memory and + // thus we can't even have `usize::MAX / 2 + 1` threads. + + let id = self.free_from; + self.free_from += 1; + id + } + } + + fn free(&mut self, id: usize) { + self.free_list + .get_or_insert_with(BinaryHeap::new) + .push(Reverse(id)); + } + } + + static THREAD_ID_MANAGER: Mutex = Mutex::new(ThreadIdManager::new()); + + /// Data which is unique to the current thread while it is running. + /// A thread ID may be reused after a thread exits. + #[derive(Clone, Copy)] + pub(super) struct Thread { + /// The bucket this thread's local storage will be in. + pub(super) bucket: usize, + /// The index into the bucket this thread's local storage is in. + pub(super) index: usize, + } + + impl Thread { + pub(super) fn new(id: usize) -> Self { + let bucket = usize::from(POINTER_WIDTH) - ((id + 1).leading_zeros() as usize) - 1; + let bucket_size = 1 << bucket; + let index = id - (bucket_size - 1); + + Self { bucket, index } + } + + /// The size of the bucket this thread's local storage will be in. + pub(super) fn bucket_size(&self) -> usize { + 1 << self.bucket + } + } + + // This is split into 2 thread-local variables so that we can check whether the + // thread is initialized without having to register a thread-local destructor. + // + // This makes the fast path smaller, and it is necessary for GlobalAlloc as we are not allowed + // to use thread locals with destructors during alloc. + thread_local! { static THREAD: Cell> = const { Cell::new(None) }; } + thread_local! { static THREAD_GUARD: ThreadGuard = const { ThreadGuard { id: Cell::new(0) } }; } + + // Guard to ensure the thread ID is released on thread exit. + struct ThreadGuard { + // We keep a copy of the thread ID in the ThreadGuard: we can't + // reliably access THREAD in our Drop impl due to the unpredictable + // order of TLS destructors. + id: Cell, + } + + impl Drop for ThreadGuard { + fn drop(&mut self) { + // Release the thread ID. Any further accesses to the thread ID + // will go through get_slow which will either panic or + // initialize a new ThreadGuard. + let _ = THREAD.try_with(|thread| thread.set(None)); + THREAD_ID_MANAGER.lock().unwrap().free(self.id.get()); + } + } + + /// Returns a thread ID for the current thread. + #[inline] + pub(crate) fn get() -> Option { + THREAD.with(|thread| thread.get()) + } + + /// Returns a thread ID for the current thread, allocating one if needed. + #[inline] + pub(crate) fn get_or_init() -> Thread { + THREAD.with(|thread| { + if let Some(thread) = thread.get() { + thread + } else { + get_slow(thread) + } + }) + } + + /// Out-of-line slow path for allocating a thread ID. 
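`Thread::new` packs a flat thread ID into a `(bucket, index)` pair where bucket `b` holds `2^b` slots. The same arithmetic as standalone code, mirroring the cases checked in `test_thread` below:

    fn bucket_and_index(id: usize) -> (usize, usize) {
        // bucket = floor(log2(id + 1)), index = (id + 1) - 2^bucket
        let bucket = (usize::BITS - (id + 1).leading_zeros() - 1) as usize;
        let index = id + 1 - (1usize << bucket);
        (bucket, index)
    }

    fn main() {
        assert_eq!(bucket_and_index(0), (0, 0));  // first thread: bucket 0, slot 0
        assert_eq!(bucket_and_index(1), (1, 0));  // bucket 1 holds ids 1 and 2
        assert_eq!(bucket_and_index(2), (1, 1));
        assert_eq!(bucket_and_index(3), (2, 0));  // bucket 2 holds ids 3..=6
        assert_eq!(bucket_and_index(19), (4, 4)); // matches `test_thread`
    }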
+ #[cold] + fn get_slow(thread: &Cell>) -> Thread { + let id = THREAD_ID_MANAGER.lock().unwrap().alloc(); + let new = Thread::new(id); + thread.set(Some(new)); + THREAD_GUARD.with(|guard| guard.id.set(id)); + new + } + + #[test] + fn test_thread() { + let thread = Thread::new(0); + assert_eq!(thread.bucket, 0); + assert_eq!(thread.bucket_size(), 1); + assert_eq!(thread.index, 0); + + let thread = Thread::new(1); + assert_eq!(thread.bucket, 1); + assert_eq!(thread.bucket_size(), 2); + assert_eq!(thread.index, 0); + + let thread = Thread::new(2); + assert_eq!(thread.bucket, 1); + assert_eq!(thread.bucket_size(), 2); + assert_eq!(thread.index, 1); + + let thread = Thread::new(3); + assert_eq!(thread.bucket, 2); + assert_eq!(thread.bucket_size(), 4); + assert_eq!(thread.index, 0); + + let thread = Thread::new(19); + assert_eq!(thread.bucket, 4); + assert_eq!(thread.bucket_size(), 16); + assert_eq!(thread.index, 4); + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use std::cell::RefCell; + use std::sync::Arc; + use std::sync::atomic::AtomicUsize; + use std::sync::atomic::Ordering::Relaxed; + use std::thread; + + fn make_create() -> Arc usize + Send + Sync> { + let count = AtomicUsize::new(0); + Arc::new(move || count.fetch_add(1, Relaxed)) + } + + #[test] + fn same_thread() { + let create = make_create(); + let tls = ThreadLocal::new(); + assert_eq!(None, tls.get()); + assert_eq!("ThreadLocal { local_data: None }", format!("{:?}", &tls)); + assert_eq!(0, *tls.get_or(|| create())); + assert_eq!(Some(&0), tls.get()); + assert_eq!(0, *tls.get_or(|| create())); + assert_eq!(Some(&0), tls.get()); + assert_eq!(0, *tls.get_or(|| create())); + assert_eq!(Some(&0), tls.get()); + assert_eq!("ThreadLocal { local_data: Some(0) }", format!("{:?}", &tls)); + // tls.clear(); + // assert_eq!(None, tls.get()); + } + + #[test] + fn different_thread() { + let create = make_create(); + let tls = Arc::new(ThreadLocal::new()); + assert_eq!(None, tls.get()); + assert_eq!(0, *tls.get_or(|| create())); + assert_eq!(Some(&0), tls.get()); + + let tls2 = tls.clone(); + let create2 = create.clone(); + thread::spawn(move || { + assert_eq!(None, tls2.get()); + assert_eq!(1, *tls2.get_or(|| create2())); + assert_eq!(Some(&1), tls2.get()); + }) + .join() + .unwrap(); + + assert_eq!(Some(&0), tls.get()); + assert_eq!(0, *tls.get_or(|| create())); + } + + #[test] + fn iter() { + let tls = Arc::new(ThreadLocal::new()); + tls.get_or(|| Box::new(1)); + + let tls2 = tls.clone(); + thread::spawn(move || { + tls2.get_or(|| Box::new(2)); + let tls3 = tls2.clone(); + thread::spawn(move || { + tls3.get_or(|| Box::new(3)); + }) + .join() + .unwrap(); + drop(tls2); + }) + .join() + .unwrap(); + + let tls = Arc::try_unwrap(tls).unwrap(); + + let mut v = tls.iter().map(|x| **x).collect::>(); + v.sort_unstable(); + assert_eq!(vec![1, 2, 3], v); + } + + #[test] + fn miri_iter_soundness_check() { + let tls = Arc::new(ThreadLocal::new()); + let _local = tls.get_or(|| Box::new(1)); + + let tls2 = tls.clone(); + let join_1 = thread::spawn(move || { + let _tls = tls2.get_or(|| Box::new(2)); + let iter = tls2.iter(); + for item in iter { + println!("{item:?}"); + } + }); + + let iter = tls.iter(); + for item in iter { + println!("{item:?}"); + } + + join_1.join().ok(); + } + + #[test] + fn test_drop() { + let local = ThreadLocal::new(); + struct Dropped(Arc); + impl Drop for Dropped { + fn drop(&mut self) { + self.0.fetch_add(1, Relaxed); + } + } + + let dropped = Arc::new(AtomicUsize::new(0)); + local.get_or(|| Dropped(dropped.clone())); + 
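`ThreadGuard` exists so the ID is returned to the manager when a thread exits, even though `THREAD` itself stays destructor-free for the allocation fast path. The mechanism in miniature (the "pool" here is just a `Vec` behind a `Mutex`; all names are made up):

    use std::cell::Cell;
    use std::sync::Mutex;

    static FREED: Mutex<Vec<usize>> = Mutex::new(Vec::new());

    struct Guard {
        id: Cell<usize>,
    }

    impl Drop for Guard {
        fn drop(&mut self) {
            // Runs as a TLS destructor when the owning thread exits.
            FREED.lock().unwrap().push(self.id.get());
        }
    }

    thread_local! {
        static GUARD: Guard = const { Guard { id: Cell::new(0) } };
    }

    fn main() {
        std::thread::spawn(|| GUARD.with(|g| g.id.set(7)))
            .join()
            .unwrap();
        // After the thread has been joined, its TLS destructor has released the ID.
        assert_eq!(FREED.lock().unwrap().as_slice(), &[7]);
    }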
assert_eq!(dropped.load(Relaxed), 0); + drop(local); + assert_eq!(dropped.load(Relaxed), 1); + } + + #[test] + fn test_earlyreturn_buckets() { + struct Dropped(Arc); + impl Drop for Dropped { + fn drop(&mut self) { + self.0.fetch_add(1, Relaxed); + } + } + let dropped = Arc::new(AtomicUsize::new(0)); + + // We use a high `id` here to guarantee that a lazily allocated bucket somewhere in the middle is used. + // Neither iteration nor `Drop` must early-return on `null` buckets that are used for lower `buckets`. + let thread = thread_id::Thread::new(1234); + assert!(thread.bucket > 1); + + let local = ThreadLocal::new(); + local.insert(thread, Dropped(dropped.clone())); + + let item = local.iter().next().unwrap(); + assert_eq!(item.0.load(Relaxed), 0); + drop(local); + assert_eq!(dropped.load(Relaxed), 1); + } + + #[test] + fn is_sync() { + fn foo() {} + foo::>(); + foo::>>(); + } +} diff --git a/libs/metrics/src/lib.rs b/libs/metrics/src/lib.rs index 41873cdcd6..517577a6c3 100644 --- a/libs/metrics/src/lib.rs +++ b/libs/metrics/src/lib.rs @@ -478,7 +478,7 @@ pub trait CounterPairAssoc { } pub struct CounterPairVec { - vec: measured::metric::MetricVec, + pub vec: measured::metric::MetricVec, } impl Default for CounterPairVec @@ -492,6 +492,17 @@ where } } +impl CounterPairVec +where + A::LabelGroupSet: Default, +{ + pub fn dense() -> Self { + Self { + vec: measured::metric::MetricVec::dense(), + } + } +} + impl CounterPairVec { pub fn guard( &self, @@ -501,14 +512,27 @@ impl CounterPairVec { self.vec.get_metric(id).inc.inc(); MeasuredCounterPairGuard { vec: &self.vec, id } } + pub fn inc(&self, labels: ::Group<'_>) { let id = self.vec.with_labels(labels); self.vec.get_metric(id).inc.inc(); } + pub fn dec(&self, labels: ::Group<'_>) { let id = self.vec.with_labels(labels); self.vec.get_metric(id).dec.inc(); } + + pub fn inc_by(&self, labels: ::Group<'_>, x: u64) { + let id = self.vec.with_labels(labels); + self.vec.get_metric(id).inc.inc_by(x); + } + + pub fn dec_by(&self, labels: ::Group<'_>, x: u64) { + let id = self.vec.with_labels(labels); + self.vec.get_metric(id).dec.inc_by(x); + } + pub fn remove_metric( &self, labels: ::Group<'_>, @@ -553,6 +577,24 @@ pub struct MeasuredCounterPairState { pub dec: CounterState, } +impl MeasuredCounterPairState { + pub fn inc(&self) { + self.inc.inc(); + } + + pub fn dec(&self) { + self.dec.inc(); + } + + pub fn inc_by(&self, x: u64) { + self.inc.inc_by(x); + } + + pub fn dec_by(&self, x: u64) { + self.dec.inc_by(x); + } +} + impl measured::metric::MetricType for MeasuredCounterPairState { type Metadata = (); } diff --git a/proxy/Cargo.toml b/proxy/Cargo.toml index 82fe6818e3..36a397c1fa 100644 --- a/proxy/Cargo.toml +++ b/proxy/Cargo.toml @@ -10,6 +10,7 @@ testing = ["dep:tokio-postgres"] [dependencies] ahash.workspace = true +alloc-metrics.workspace = true anyhow.workspace = true arc-swap.workspace = true async-compression.workspace = true diff --git a/proxy/src/bin/proxy.rs b/proxy/src/bin/proxy.rs index d60d32eb3b..d1bc0e21b5 100644 --- a/proxy/src/bin/proxy.rs +++ b/proxy/src/bin/proxy.rs @@ -1,11 +1,22 @@ +use alloc_metrics::TrackedAllocator; +use proxy::binary::proxy::MemoryContext; +use tikv_jemallocator::Jemalloc; + #[global_allocator] -static GLOBAL: tikv_jemallocator::Jemalloc = tikv_jemallocator::Jemalloc; +// Safety: `MemoryContext` upholds the safety requirements. 
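The new `inc_by`/`dec_by` helpers keep the counter-pair idea from `libs/metrics`: allocation and deallocation are two monotonically increasing counters, and "bytes currently in use" is their difference, which is friendlier to rate-style queries than a single up/down gauge. Reduced to plain atomics:

    use std::sync::atomic::{AtomicU64, Ordering::Relaxed};

    #[derive(Default)]
    struct CounterPair {
        inc: AtomicU64, // e.g. allocated_bytes
        dec: AtomicU64, // e.g. deallocated_bytes
    }

    impl CounterPair {
        fn inc_by(&self, x: u64) { self.inc.fetch_add(x, Relaxed); }
        fn dec_by(&self, x: u64) { self.dec.fetch_add(x, Relaxed); }
        fn in_use(&self) -> u64 { self.inc.load(Relaxed) - self.dec.load(Relaxed) }
    }

    fn main() {
        let pair = CounterPair::default();
        pair.inc_by(4096);
        pair.dec_by(1024);
        assert_eq!(pair.in_use(), 3072);
    }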
+static GLOBAL: TrackedAllocator = + unsafe { TrackedAllocator::new(Jemalloc, MemoryContext::Root) }; #[allow(non_upper_case_globals)] #[unsafe(export_name = "malloc_conf")] pub static malloc_conf: &[u8] = b"prof:true,prof_active:true,lg_prof_sample:21\0"; -#[tokio::main] -async fn main() -> anyhow::Result<()> { - proxy::binary::proxy::run().await +fn main() -> anyhow::Result<()> { + GLOBAL.register_thread(); + tokio::runtime::Builder::new_multi_thread() + .enable_all() + .on_thread_start(|| GLOBAL.register_thread()) + .build() + .expect("Failed building the Runtime") + .block_on(proxy::binary::proxy::run(&GLOBAL)) } diff --git a/proxy/src/binary/local_proxy.rs b/proxy/src/binary/local_proxy.rs index 401203d48c..ba3f596492 100644 --- a/proxy/src/binary/local_proxy.rs +++ b/proxy/src/binary/local_proxy.rs @@ -111,7 +111,7 @@ pub async fn run() -> anyhow::Result<()> { let _panic_hook_guard = utils::logging::replace_panic_hook_with_tracing_panic_hook(); let _sentry_guard = init_sentry(Some(GIT_VERSION.into()), &[]); - Metrics::install(Arc::new(ThreadPoolMetrics::new(0))); + Metrics::install(Arc::new(ThreadPoolMetrics::new(0)), None); // TODO: refactor these to use labels debug!("Version: {GIT_VERSION}"); diff --git a/proxy/src/binary/pg_sni_router.rs b/proxy/src/binary/pg_sni_router.rs index 4ac8b6a995..cfb1243d90 100644 --- a/proxy/src/binary/pg_sni_router.rs +++ b/proxy/src/binary/pg_sni_router.rs @@ -80,7 +80,7 @@ pub async fn run() -> anyhow::Result<()> { let _panic_hook_guard = utils::logging::replace_panic_hook_with_tracing_panic_hook(); let _sentry_guard = init_sentry(Some(GIT_VERSION.into()), &[]); - Metrics::install(Arc::new(ThreadPoolMetrics::new(0))); + Metrics::install(Arc::new(ThreadPoolMetrics::new(0)), None); let args = cli().get_matches(); let destination: String = args diff --git a/proxy/src/binary/proxy.rs b/proxy/src/binary/proxy.rs index 16a7dc7b67..720f90b04b 100644 --- a/proxy/src/binary/proxy.rs +++ b/proxy/src/binary/proxy.rs @@ -39,7 +39,8 @@ use crate::config::{ }; use crate::context::parquet::ParquetUploadArgs; use crate::http::health_server::AppMetrics; -use crate::metrics::Metrics; +pub use crate::metrics::MemoryContext; +use crate::metrics::{Alloc, Metrics}; use crate::rate_limiter::{EndpointRateLimiter, RateBucketInfo, WakeComputeRateLimiter}; use crate::redis::connection_with_credentials_provider::ConnectionWithCredentialsProvider; use crate::redis::kv_ops::RedisKVClient; @@ -318,7 +319,7 @@ struct PgSniRouterArgs { dest: Option, } -pub async fn run() -> anyhow::Result<()> { +pub async fn run(alloc: &'static Alloc) -> anyhow::Result<()> { let _logging_guard = crate::logging::init().await?; let _panic_hook_guard = utils::logging::replace_panic_hook_with_tracing_panic_hook(); let _sentry_guard = init_sentry(Some(GIT_VERSION.into()), &[]); @@ -340,7 +341,7 @@ pub async fn run() -> anyhow::Result<()> { }; let args = ProxyCliArgs::parse(); - let config = build_config(&args)?; + let config = build_config(&args, alloc)?; let auth_backend = build_auth_backend(&args)?; match auth_backend { @@ -589,9 +590,12 @@ pub async fn run() -> anyhow::Result<()> { } /// ProxyConfig is created at proxy startup, and lives forever. 
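The `main` change above follows one rule: any thread that will allocate should call `register_thread()` before it does real work, and `on_thread_start` covers every Tokio worker. A sketch of the same wiring with a stand-in registration function (the counter and thread-local below are made up for illustration):

    use std::cell::Cell;
    use std::sync::atomic::{AtomicUsize, Ordering::Relaxed};

    static REGISTERED_THREADS: AtomicUsize = AtomicUsize::new(0);

    thread_local! {
        static IS_REGISTERED: Cell<bool> = const { Cell::new(false) };
    }

    // Stand-in for `TrackedAllocator::register_thread`: idempotent per thread.
    fn register_thread() {
        IS_REGISTERED.with(|r| {
            if !r.replace(true) {
                REGISTERED_THREADS.fetch_add(1, Relaxed);
            }
        });
    }

    fn main() {
        register_thread(); // the main thread registers itself first
        let runtime = tokio::runtime::Builder::new_multi_thread()
            .worker_threads(2)
            .on_thread_start(register_thread) // every worker registers as it starts
            .enable_all()
            .build()
            .expect("failed to build runtime");
        runtime.block_on(async {
            // Work running here executes on registered worker threads.
        });
        // The main thread registered synchronously; workers register as they spin up.
        assert!(REGISTERED_THREADS.load(Relaxed) >= 1);
    }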
-fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> { +fn build_config( + args: &ProxyCliArgs, + alloc: &'static Alloc, +) -> anyhow::Result<&'static ProxyConfig> { let thread_pool = ThreadPool::new(args.scram_thread_pool_size); - Metrics::install(thread_pool.metrics.clone()); + Metrics::install(thread_pool.metrics.clone(), Some(alloc)); let tls_config = match (&args.tls_key, &args.tls_cert) { (Some(key_path), Some(cert_path)) => Some(config::configure_tls( diff --git a/proxy/src/metrics.rs b/proxy/src/metrics.rs index 916604e2ec..61a3f1d7e3 100644 --- a/proxy/src/metrics.rs +++ b/proxy/src/metrics.rs @@ -1,5 +1,11 @@ +#![expect( + clippy::ref_option_ref, + reason = "generated from measured derived output" +)] + use std::sync::{Arc, OnceLock}; +use alloc_metrics::TrackedAllocator; use lasso::ThreadedRodeo; use measured::label::{ FixedCardinalitySet, LabelGroupSet, LabelName, LabelSet, LabelValue, StaticLabelSet, @@ -11,26 +17,33 @@ use measured::{ MetricGroup, }; use metrics::{CounterPairAssoc, CounterPairVec, HyperLogLogVec}; +use tikv_jemallocator::Jemalloc; use tokio::time::{self, Instant}; use crate::control_plane::messages::ColdStartInfo; use crate::error::ErrorKind; +pub type Alloc = TrackedAllocator; + #[derive(MetricGroup)] -#[metric(new(thread_pool: Arc))] +#[metric(new(thread_pool: Arc, alloc: Option<&'static Alloc>))] pub struct Metrics { #[metric(namespace = "proxy")] #[metric(init = ProxyMetrics::new(thread_pool))] pub proxy: ProxyMetrics, + #[metric(namespace = "alloc")] + #[metric(init = alloc)] + pub alloc: Option<&'static Alloc>, + #[metric(namespace = "wake_compute_lock")] pub wake_compute_lock: ApiLockMetrics, } static SELF: OnceLock = OnceLock::new(); impl Metrics { - pub fn install(thread_pool: Arc) { - let mut metrics = Metrics::new(thread_pool); + pub fn install(thread_pool: Arc, alloc: Option<&'static Alloc>) { + let mut metrics = Metrics::new(thread_pool, alloc); metrics.proxy.errors_total.init_all_dense(); metrics.proxy.redis_errors_total.init_all_dense(); @@ -45,7 +58,7 @@ impl Metrics { pub fn get() -> &'static Self { #[cfg(test)] - return SELF.get_or_init(|| Metrics::new(Arc::new(ThreadPoolMetrics::new(0)))); + return SELF.get_or_init(|| Metrics::new(Arc::new(ThreadPoolMetrics::new(0)), None)); #[cfg(not(test))] SELF.get() @@ -660,3 +673,9 @@ pub struct ThreadPoolMetrics { #[metric(init = CounterVec::with_label_set(ThreadPoolWorkers(workers)))] pub worker_task_skips_total: CounterVec, } + +#[derive(FixedCardinalityLabel, Clone, Copy, Debug)] +#[label(singleton = "context")] +pub enum MemoryContext { + Root, +}
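The proxy currently defines only the `Root` context, so the new `alloc` metric namespace starts with a single label value. A hypothetical follow-up (not part of this patch; the metric and label names below are inferred from the `alloc` namespace and the `context` singleton) would add a variant and attribute a code path to it:

    // Hypothetical extension, mirroring the derive already used above:
    #[derive(FixedCardinalityLabel, Clone, Copy, Debug)]
    #[label(singleton = "context")]
    pub enum MemoryContext {
        Root,
        Parsing,
    }

    // At a call site with access to the global allocator (e.g. the GLOBAL static
    // in proxy/src/bin/proxy.rs), allocations made while the guard is alive would
    // be attributed to the new context, surfacing roughly as
    // alloc_allocated_bytes{context="parsing"} / alloc_deallocated_bytes{context="parsing"}:
    //
    //     let _mem = GLOBAL.scope(MemoryContext::Parsing);
    //     // ... parse the incoming startup packet ...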