From b33047df7eb9f5d67c4ded02bab1d15fc53a3139 Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Mon, 21 Jul 2025 10:09:10 +0100 Subject: [PATCH] cleanup code a little --- Cargo.lock | 7 +- Cargo.toml | 4 + libs/alloc-metrics/Cargo.toml | 1 + libs/alloc-metrics/src/counters.rs | 48 ++ libs/alloc-metrics/src/lib.rs | 73 +++- libs/alloc-metrics/src/metric_vec.rs | 137 +----- libs/alloc-metrics/src/thread_local.rs | 581 ------------------------- libs/metrics/src/lib.rs | 4 +- 8 files changed, 125 insertions(+), 730 deletions(-) create mode 100644 libs/alloc-metrics/src/counters.rs delete mode 100644 libs/alloc-metrics/src/thread_local.rs diff --git a/Cargo.lock b/Cargo.lock index f22732e4d4..6c19fd37df 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -68,6 +68,7 @@ dependencies = [ "criterion", "measured", "metrics", + "thread_local", "tikv-jemallocator", ] @@ -7343,12 +7344,10 @@ dependencies = [ [[package]] name = "thread_local" -version = "1.1.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3fdd6f064ccff2d6567adcb3873ca630700f00b5ad3f060c25b5dcfd9a4ce152" +version = "1.1.9" +source = "git+https://github.com/conradludgate/thread_local-rs?branch=no-tls-destructor-get#f9ca3d375745c14a632ae3ffe6a7a646dc8421a0" dependencies = [ "cfg-if", - "once_cell", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 012e3fe4f4..66986bed1c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -195,6 +195,7 @@ sync_wrapper = "0.1.2" tar = "0.4" test-context = "0.3" thiserror = "1.0" +thread_local = "1.1.9" tikv-jemallocator = { version = "0.6", features = ["profiling", "stats", "unprefixed_malloc_on_supported_platforms"] } tikv-jemalloc-ctl = { version = "0.6", features = ["stats"] } tokio = { version = "1.43.1", features = ["macros"] } @@ -303,6 +304,9 @@ tonic-build = "0.13.1" # Needed to get `tokio-postgres-rustls` to depend on our fork. tokio-postgres = { git = "https://github.com/neondatabase/rust-postgres.git", branch = "neon" } +# Needed to fix a bug in alloc-metrics +thread_local = { git = "https://github.com/conradludgate/thread_local-rs", branch = "no-tls-destructor-get" } + ################# Binary contents sections [profile.release] diff --git a/libs/alloc-metrics/Cargo.toml b/libs/alloc-metrics/Cargo.toml index c682ec4e43..9ea001249d 100644 --- a/libs/alloc-metrics/Cargo.toml +++ b/libs/alloc-metrics/Cargo.toml @@ -7,6 +7,7 @@ license.workspace = true [dependencies] metrics.workspace = true measured.workspace = true +thread_local.workspace = true [dev-dependencies] criterion.workspace = true diff --git a/libs/alloc-metrics/src/counters.rs b/libs/alloc-metrics/src/counters.rs new file mode 100644 index 0000000000..852a83b49f --- /dev/null +++ b/libs/alloc-metrics/src/counters.rs @@ -0,0 +1,48 @@ +use std::marker::PhantomData; + +use measured::{ + FixedCardinalityLabel, LabelGroup, label::StaticLabelSet, metric::MetricFamilyEncoding, +}; +use metrics::{CounterPairAssoc, Dec, Inc, MeasuredCounterPairState}; + +use crate::metric_vec::DenseMetricVec; + +pub struct DenseCounterPairVec< + A: CounterPairAssoc>, + L: FixedCardinalityLabel + LabelGroup, +> { + pub vec: DenseMetricVec, + pub _marker: PhantomData, +} + +impl>, L: FixedCardinalityLabel + LabelGroup> + Default for DenseCounterPairVec +{ + fn default() -> Self { + Self { + vec: DenseMetricVec::new(), + _marker: PhantomData, + } + } +} + +impl ::measured::metric::group::MetricGroup for DenseCounterPairVec +where + T: ::measured::metric::group::Encoding, + ::measured::metric::counter::CounterState: ::measured::metric::MetricEncoding, + A: CounterPairAssoc>, + L: FixedCardinalityLabel + LabelGroup, +{ + fn collect_group_into(&self, enc: &mut T) -> Result<(), T::Err> { + // write decrement first to avoid a race condition where inc - dec < 0 + T::write_help(enc, A::DEC_NAME, A::DEC_HELP)?; + self.vec + .collect_family_into(A::DEC_NAME, &mut Dec(&mut *enc))?; + + T::write_help(enc, A::INC_NAME, A::INC_HELP)?; + self.vec + .collect_family_into(A::INC_NAME, &mut Inc(&mut *enc))?; + + Ok(()) + } +} diff --git a/libs/alloc-metrics/src/lib.rs b/libs/alloc-metrics/src/lib.rs index 3dde4a11a1..d524c7ddfd 100644 --- a/libs/alloc-metrics/src/lib.rs +++ b/libs/alloc-metrics/src/lib.rs @@ -1,7 +1,7 @@ //! Tagged allocator measurements. +mod counters; mod metric_vec; -mod thread_local; use std::{ alloc::{GlobalAlloc, Layout}, @@ -21,9 +21,7 @@ use measured::{ use metrics::{CounterPairAssoc, MeasuredCounterPairState}; use thread_local::ThreadLocal; -use crate::metric_vec::DenseCounterPairVec; - -type AllocCounter = DenseCounterPairVec, T>; +type AllocCounter = counters::DenseCounterPairVec, T>; pub struct TrackedAllocator { inner: A, @@ -82,8 +80,8 @@ where self.thread_state .get_or_init(ThreadLocal::new) .get_or(|| ThreadState { - counters: DenseCounterPairVec::default(), - global: self.global.get_or_init(DenseCounterPairVec::default), + counters: AllocCounter::default(), + global: self.global.get_or_init(AllocCounter::default), }); self.thread_scope @@ -321,7 +319,7 @@ where CounterState: MetricEncoding, { fn collect_group_into(&self, enc: &mut Enc) -> Result<(), Enc::Err> { - let global = self.global.get_or_init(DenseCounterPairVec::default); + let global = self.global.get_or_init(AllocCounter::default); // iterate over all counter threads for s in self.thread_state.get().into_iter().flat_map(|s| s.iter()) { @@ -401,6 +399,67 @@ deallocated_bytes{memory_context="test"} 8 # TYPE allocated_bytes counter allocated_bytes{memory_context="root"} 4 allocated_bytes{memory_context="test"} 8 +"# + ); + } + + #[test] + fn unregistered_thread() { + // Safety: `MemoryContext` upholds the safety requirements. + static GLOBAL: TrackedAllocator = + unsafe { TrackedAllocator::new(System, MemoryContext::Root) }; + + GLOBAL.register_thread(); + + // unregistered thread + std::thread::spawn(|| { + let ptr = unsafe { GLOBAL.alloc(Layout::for_value(&[0_i32])) }; + unsafe { GLOBAL.dealloc(ptr, Layout::for_value(&[0_i32])) }; + }) + .join() + .unwrap(); + + let mut text = BufferedTextEncoder::new(); + GLOBAL.collect_group_into(&mut text).unwrap(); + let text = String::from_utf8(text.finish().into()).unwrap(); + assert_eq!( + text, + r#"# HELP deallocated_bytes total number of bytes deallocated +# TYPE deallocated_bytes counter +deallocated_bytes{memory_context="root"} 4 +deallocated_bytes{memory_context="test"} 0 + +# HELP allocated_bytes total number of bytes allocated +# TYPE allocated_bytes counter +allocated_bytes{memory_context="root"} 4 +allocated_bytes{memory_context="test"} 0 +"# + ); + } + + #[test] + fn fully_unregistered() { + // Safety: `MemoryContext` upholds the safety requirements. + static GLOBAL: TrackedAllocator = + unsafe { TrackedAllocator::new(System, MemoryContext::Root) }; + + let ptr = unsafe { GLOBAL.alloc(Layout::for_value(&[0_i32])) }; + unsafe { GLOBAL.dealloc(ptr, Layout::for_value(&[0_i32])) }; + + let mut text = BufferedTextEncoder::new(); + GLOBAL.collect_group_into(&mut text).unwrap(); + let text = String::from_utf8(text.finish().into()).unwrap(); + assert_eq!( + text, + r#"# HELP deallocated_bytes total number of bytes deallocated +# TYPE deallocated_bytes counter +deallocated_bytes{memory_context="root"} 4 +deallocated_bytes{memory_context="test"} 0 + +# HELP allocated_bytes total number of bytes allocated +# TYPE allocated_bytes counter +allocated_bytes{memory_context="root"} 4 +allocated_bytes{memory_context="test"} 0 "# ); } diff --git a/libs/alloc-metrics/src/metric_vec.rs b/libs/alloc-metrics/src/metric_vec.rs index 84fd928ce4..fe8cd84826 100644 --- a/libs/alloc-metrics/src/metric_vec.rs +++ b/libs/alloc-metrics/src/metric_vec.rs @@ -1,16 +1,12 @@ //! Dense metric vec -use std::marker::PhantomData; - use measured::{ FixedCardinalityLabel, LabelGroup, label::{LabelGroupSet, StaticLabelSet}, metric::{ - MetricEncoding, MetricFamilyEncoding, MetricType, counter::CounterState, group::Encoding, - name::MetricNameEncoder, + MetricEncoding, MetricFamilyEncoding, MetricType, group::Encoding, name::MetricNameEncoder, }, }; -use metrics::{CounterPairAssoc, MeasuredCounterPairState}; pub struct DenseMetricVec { metrics: VecInner, @@ -73,11 +69,6 @@ impl DenseMetricVec } } - // /// View the metric metadata - // pub fn metadata(&self) -> &M::Metadata { - // &self.metadata - // } - /// Get an identifier for the specific metric identified by this label group /// /// # Panics @@ -132,129 +123,3 @@ impl, L: FixedCardinalityLabel + LabelGroup, T: Encoding> Ok(()) } } - -pub struct DenseCounterPairVec< - A: CounterPairAssoc>, - L: FixedCardinalityLabel + LabelGroup, -> { - pub vec: DenseMetricVec, - pub _marker: PhantomData, -} - -impl>, L: FixedCardinalityLabel + LabelGroup> - Default for DenseCounterPairVec -{ - fn default() -> Self { - Self { - vec: DenseMetricVec::new(), - _marker: PhantomData, - } - } -} - -// impl>, L: FixedCardinalityLabel + LabelGroup> -// DenseCounterPairVec -// { -// #[inline] -// pub fn inc(&self, labels: ::Group<'_>) { -// let id = self.vec.with_labels(labels); -// self.vec.get_metric(id).inc.inc(); -// } - -// #[inline] -// pub fn dec(&self, labels: ::Group<'_>) { -// let id = self.vec.with_labels(labels); -// self.vec.get_metric(id).dec.inc(); -// } - -// #[inline] -// pub fn inc_by(&self, labels: ::Group<'_>, x: u64) { -// let id = self.vec.with_labels(labels); -// self.vec.get_metric(id).inc.inc_by(x); -// } - -// #[inline] -// pub fn dec_by(&self, labels: ::Group<'_>, x: u64) { -// let id = self.vec.with_labels(labels); -// self.vec.get_metric(id).dec.inc_by(x); -// } -// } - -impl ::measured::metric::group::MetricGroup for DenseCounterPairVec -where - T: ::measured::metric::group::Encoding, - ::measured::metric::counter::CounterState: ::measured::metric::MetricEncoding, - A: CounterPairAssoc>, - L: FixedCardinalityLabel + LabelGroup, -{ - fn collect_group_into(&self, enc: &mut T) -> Result<(), T::Err> { - // write decrement first to avoid a race condition where inc - dec < 0 - T::write_help(enc, A::DEC_NAME, A::DEC_HELP)?; - self.vec - .collect_family_into(A::DEC_NAME, &mut Dec(&mut *enc))?; - - T::write_help(enc, A::INC_NAME, A::INC_HELP)?; - self.vec - .collect_family_into(A::INC_NAME, &mut Inc(&mut *enc))?; - - Ok(()) - } -} - -/// [`MetricEncoding`] for [`MeasuredCounterPairState`] that only writes the inc counter to the inner encoder. -struct Inc(T); -/// [`MetricEncoding`] for [`MeasuredCounterPairState`] that only writes the dec counter to the inner encoder. -struct Dec(T); - -impl Encoding for Inc { - type Err = T::Err; - - fn write_help(&mut self, name: impl MetricNameEncoder, help: &str) -> Result<(), Self::Err> { - self.0.write_help(name, help) - } -} - -impl MetricEncoding> for MeasuredCounterPairState -where - CounterState: MetricEncoding, -{ - fn write_type(name: impl MetricNameEncoder, enc: &mut Inc) -> Result<(), T::Err> { - CounterState::write_type(name, &mut enc.0) - } - fn collect_into( - &self, - metadata: &(), - labels: impl LabelGroup, - name: impl MetricNameEncoder, - enc: &mut Inc, - ) -> Result<(), T::Err> { - self.inc.collect_into(metadata, labels, name, &mut enc.0) - } -} - -impl Encoding for Dec { - type Err = T::Err; - - fn write_help(&mut self, name: impl MetricNameEncoder, help: &str) -> Result<(), Self::Err> { - self.0.write_help(name, help) - } -} - -/// Write the dec counter to the encoder -impl MetricEncoding> for MeasuredCounterPairState -where - CounterState: MetricEncoding, -{ - fn write_type(name: impl MetricNameEncoder, enc: &mut Dec) -> Result<(), T::Err> { - CounterState::write_type(name, &mut enc.0) - } - fn collect_into( - &self, - metadata: &(), - labels: impl LabelGroup, - name: impl MetricNameEncoder, - enc: &mut Dec, - ) -> Result<(), T::Err> { - self.dec.collect_into(metadata, labels, name, &mut enc.0) - } -} diff --git a/libs/alloc-metrics/src/thread_local.rs b/libs/alloc-metrics/src/thread_local.rs deleted file mode 100644 index 49c37b1b50..0000000000 --- a/libs/alloc-metrics/src/thread_local.rs +++ /dev/null @@ -1,581 +0,0 @@ -//! A vendoring of `thread_local` 1.1.9 with some changes needed for TLS destructors - -// Copyright 2017 Amanieu d'Antras -// -// Licensed under the Apache License, Version 2.0, or the MIT license , at your option. This file may not be -// copied, modified, or distributed except according to those terms. - -use std::cell::UnsafeCell; -use std::fmt; -use std::iter::FusedIterator; -use std::mem::MaybeUninit; -use std::panic::UnwindSafe; -use std::ptr; -use std::sync::atomic::{AtomicBool, AtomicPtr, AtomicUsize, Ordering}; - -// Use usize::BITS once it has stabilized and the MSRV has been bumped. -#[cfg(target_pointer_width = "16")] -const POINTER_WIDTH: u8 = 16; -#[cfg(target_pointer_width = "32")] -const POINTER_WIDTH: u8 = 32; -#[cfg(target_pointer_width = "64")] -const POINTER_WIDTH: u8 = 64; - -/// The total number of buckets stored in each thread local. -/// All buckets combined can hold up to `usize::MAX - 1` entries. -const BUCKETS: usize = (POINTER_WIDTH - 1) as usize; - -/// Thread-local variable wrapper -/// -/// See the [module-level documentation](index.html) for more. -pub(crate) struct ThreadLocal { - /// The buckets in the thread local. The nth bucket contains `2^n` - /// elements. Each bucket is lazily allocated. - buckets: [AtomicPtr>; BUCKETS], - - /// The number of values in the thread local. This can be less than the real number of values, - /// but is never more. - values: AtomicUsize, -} - -struct Entry { - present: AtomicBool, - value: UnsafeCell>, -} - -impl Drop for Entry { - fn drop(&mut self) { - if *self.present.get_mut() { - // safety: If present is true, then this value is allocated - // and we cannot touch it ever again after this function - unsafe { ptr::drop_in_place((*self.value.get()).as_mut_ptr()) } - } - } -} - -// Safety: ThreadLocal is always Sync, even if T isn't -unsafe impl Sync for ThreadLocal {} - -impl Default for ThreadLocal { - fn default() -> ThreadLocal { - ThreadLocal::new() - } -} - -impl Drop for ThreadLocal { - fn drop(&mut self) { - // Free each non-null bucket - for (i, bucket) in self.buckets.iter_mut().enumerate() { - let bucket_ptr = *bucket.get_mut(); - - let this_bucket_size = 1 << i; - - if bucket_ptr.is_null() { - continue; - } - - // Safety: the bucket_ptr is allocated and the bucket size is correct. - unsafe { deallocate_bucket(bucket_ptr, this_bucket_size) }; - } - } -} - -impl ThreadLocal { - /// Creates a new empty `ThreadLocal`. - pub const fn new() -> ThreadLocal { - Self { - buckets: [const { AtomicPtr::new(ptr::null_mut()) }; BUCKETS], - values: AtomicUsize::new(0), - } - } - - /// Returns the element for the current thread, if it exists. - pub fn get(&self) -> Option<&T> { - thread_id::get().and_then(|t| self.get_inner(t)) - } - - /// Returns the element for the current thread, or creates it if it doesn't - /// exist. - pub fn get_or(&self, create: F) -> &T - where - F: FnOnce() -> T, - { - let thread = thread_id::get_or_init(); - if let Some(val) = self.get_inner(thread) { - return val; - } - - self.insert(thread, create()) - } - - fn get_inner(&self, thread: thread_id::Thread) -> Option<&T> { - // Safety: There would need to be isize::MAX+1 threads for the bucket to overflow. - let bucket_ptr = unsafe { self.buckets.get_unchecked(thread.bucket) }; - let bucket_ptr = bucket_ptr.load(Ordering::Acquire); - if bucket_ptr.is_null() { - return None; - } - // Safety: the bucket always has enough capacity for this index. - // Safety: If present is true, then this entry is allocated and it is safe to read. - unsafe { - let entry = &*bucket_ptr.add(thread.index); - if entry.present.load(Ordering::Relaxed) { - Some(&*(&*entry.value.get()).as_ptr()) - } else { - None - } - } - } - - #[cold] - fn insert(&self, thread: thread_id::Thread, data: T) -> &T { - // Safety: There would need to be isize::MAX+1 threads for the bucket to overflow. - let bucket_atomic_ptr = unsafe { self.buckets.get_unchecked(thread.bucket) }; - let bucket_ptr: *const _ = bucket_atomic_ptr.load(Ordering::Acquire); - - // If the bucket doesn't already exist, we need to allocate it - let bucket_ptr = if bucket_ptr.is_null() { - let new_bucket = allocate_bucket(thread.bucket_size()); - - match bucket_atomic_ptr.compare_exchange( - ptr::null_mut(), - new_bucket, - Ordering::AcqRel, - Ordering::Acquire, - ) { - Ok(_) => new_bucket, - // If the bucket value changed (from null), that means - // another thread stored a new bucket before we could, - // and we can free our bucket and use that one instead - Err(bucket_ptr) => { - // Safety: the bucket_ptr is allocated and the bucket size is correct. - unsafe { deallocate_bucket(new_bucket, thread.bucket_size()) } - bucket_ptr - } - } - } else { - bucket_ptr - }; - - // Insert the new element into the bucket - // Safety: the bucket always has enough capacity for this index. - let entry = unsafe { &*bucket_ptr.add(thread.index) }; - let value_ptr = entry.value.get(); - // Safety: present is false, so no other threads will be reading this, - // and it is owned by our thread, so no other threads will be writing to this. - unsafe { value_ptr.write(MaybeUninit::new(data)) }; - entry.present.store(true, Ordering::Release); - - self.values.fetch_add(1, Ordering::Release); - - // Safety: present is true, so it is now safe to read. - unsafe { &*(&*value_ptr).as_ptr() } - } - - /// Returns an iterator over the local values of all threads in unspecified - /// order. - /// - /// This call can be done safely, as `T` is required to implement [`Sync`]. - pub fn iter(&self) -> Iter<'_, T> - where - T: Sync, - { - Iter { - thread_local: self, - yielded: 0, - bucket: 0, - bucket_size: 1, - index: 0, - } - } -} - -impl<'a, T: Send + Sync> IntoIterator for &'a ThreadLocal { - type Item = &'a T; - type IntoIter = Iter<'a, T>; - - fn into_iter(self) -> Self::IntoIter { - self.iter() - } -} - -impl fmt::Debug for ThreadLocal { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "ThreadLocal {{ local_data: {:?} }}", self.get()) - } -} - -impl UnwindSafe for ThreadLocal {} - -/// Iterator over the contents of a `ThreadLocal`. -#[derive(Debug)] -pub struct Iter<'a, T: Send + Sync> { - thread_local: &'a ThreadLocal, - yielded: usize, - bucket: usize, - bucket_size: usize, - index: usize, -} - -impl<'a, T: Send + Sync> Iterator for Iter<'a, T> { - type Item = &'a T; - fn next(&mut self) -> Option { - while self.bucket < BUCKETS { - let bucket = self.thread_local.buckets[self.bucket].load(Ordering::Acquire); - - if !bucket.is_null() { - while self.index < self.bucket_size { - // Safety: the bucket always has enough capacity for this index. - let entry = unsafe { &*bucket.add(self.index) }; - self.index += 1; - if entry.present.load(Ordering::Acquire) { - self.yielded += 1; - // Safety: If present is true, then this entry is allocated and it is safe to read. - return Some(unsafe { &*(&*entry.value.get()).as_ptr() }); - } - } - } - - self.bucket_size <<= 1; - self.bucket += 1; - self.index = 0; - } - - None - } - - fn size_hint(&self) -> (usize, Option) { - let total = self.thread_local.values.load(Ordering::Acquire); - (total - self.yielded, None) - } -} -impl FusedIterator for Iter<'_, T> {} - -fn allocate_bucket(size: usize) -> *mut Entry { - Box::into_raw( - (0..size) - .map(|_| Entry:: { - present: AtomicBool::new(false), - value: UnsafeCell::new(MaybeUninit::uninit()), - }) - .collect(), - ) as *mut _ -} - -unsafe fn deallocate_bucket(bucket: *mut Entry, size: usize) { - // Safety: caller must guarantee that bucket is allocated and of the correct size. - let _ = unsafe { Box::from_raw(std::slice::from_raw_parts_mut(bucket, size)) }; -} - -mod thread_id { - // Copyright 2017 Amanieu d'Antras - // - // Licensed under the Apache License, Version 2.0, or the MIT license , at your option. This file may not be - // copied, modified, or distributed except according to those terms. - - use super::POINTER_WIDTH; - use std::cell::Cell; - use std::cmp::Reverse; - use std::collections::BinaryHeap; - use std::sync::Mutex; - - /// Thread ID manager which allocates thread IDs. It attempts to aggressively - /// reuse thread IDs where possible to avoid cases where a ThreadLocal grows - /// indefinitely when it is used by many short-lived threads. - struct ThreadIdManager { - free_from: usize, - free_list: Option>>, - } - - impl ThreadIdManager { - const fn new() -> Self { - Self { - free_from: 0, - free_list: None, - } - } - - fn alloc(&mut self) -> usize { - if let Some(id) = self.free_list.as_mut().and_then(|heap| heap.pop()) { - id.0 - } else { - // `free_from` can't overflow as each thread takes up at least 2 bytes of memory and - // thus we can't even have `usize::MAX / 2 + 1` threads. - - let id = self.free_from; - self.free_from += 1; - id - } - } - - fn free(&mut self, id: usize) { - self.free_list - .get_or_insert_with(BinaryHeap::new) - .push(Reverse(id)); - } - } - - static THREAD_ID_MANAGER: Mutex = Mutex::new(ThreadIdManager::new()); - - /// Data which is unique to the current thread while it is running. - /// A thread ID may be reused after a thread exits. - #[derive(Clone, Copy)] - pub(super) struct Thread { - /// The bucket this thread's local storage will be in. - pub(super) bucket: usize, - /// The index into the bucket this thread's local storage is in. - pub(super) index: usize, - } - - impl Thread { - pub(super) fn new(id: usize) -> Self { - let bucket = usize::from(POINTER_WIDTH) - ((id + 1).leading_zeros() as usize) - 1; - let bucket_size = 1 << bucket; - let index = id - (bucket_size - 1); - - Self { bucket, index } - } - - /// The size of the bucket this thread's local storage will be in. - pub(super) fn bucket_size(&self) -> usize { - 1 << self.bucket - } - } - - // This is split into 2 thread-local variables so that we can check whether the - // thread is initialized without having to register a thread-local destructor. - // - // This makes the fast path smaller, and it is necessary for GlobalAlloc as we are not allowed - // to use thread locals with destructors during alloc. - thread_local! { static THREAD: Cell> = const { Cell::new(None) }; } - thread_local! { static THREAD_GUARD: ThreadGuard = const { ThreadGuard { id: Cell::new(0) } }; } - - // Guard to ensure the thread ID is released on thread exit. - struct ThreadGuard { - // We keep a copy of the thread ID in the ThreadGuard: we can't - // reliably access THREAD in our Drop impl due to the unpredictable - // order of TLS destructors. - id: Cell, - } - - impl Drop for ThreadGuard { - fn drop(&mut self) { - // Release the thread ID. Any further accesses to the thread ID - // will go through get_slow which will either panic or - // initialize a new ThreadGuard. - let _ = THREAD.try_with(|thread| thread.set(None)); - THREAD_ID_MANAGER.lock().unwrap().free(self.id.get()); - } - } - - /// Returns a thread ID for the current thread. - #[inline] - pub(crate) fn get() -> Option { - THREAD.with(|thread| thread.get()) - } - - /// Returns a thread ID for the current thread, allocating one if needed. - #[inline] - pub(crate) fn get_or_init() -> Thread { - THREAD.with(|thread| { - if let Some(thread) = thread.get() { - thread - } else { - get_slow(thread) - } - }) - } - - /// Out-of-line slow path for allocating a thread ID. - #[cold] - fn get_slow(thread: &Cell>) -> Thread { - let id = THREAD_ID_MANAGER.lock().unwrap().alloc(); - let new = Thread::new(id); - thread.set(Some(new)); - THREAD_GUARD.with(|guard| guard.id.set(id)); - new - } - - #[test] - fn test_thread() { - let thread = Thread::new(0); - assert_eq!(thread.bucket, 0); - assert_eq!(thread.bucket_size(), 1); - assert_eq!(thread.index, 0); - - let thread = Thread::new(1); - assert_eq!(thread.bucket, 1); - assert_eq!(thread.bucket_size(), 2); - assert_eq!(thread.index, 0); - - let thread = Thread::new(2); - assert_eq!(thread.bucket, 1); - assert_eq!(thread.bucket_size(), 2); - assert_eq!(thread.index, 1); - - let thread = Thread::new(3); - assert_eq!(thread.bucket, 2); - assert_eq!(thread.bucket_size(), 4); - assert_eq!(thread.index, 0); - - let thread = Thread::new(19); - assert_eq!(thread.bucket, 4); - assert_eq!(thread.bucket_size(), 16); - assert_eq!(thread.index, 4); - } -} - -#[cfg(test)] -mod tests { - use super::*; - - use std::cell::RefCell; - use std::sync::Arc; - use std::sync::atomic::AtomicUsize; - use std::sync::atomic::Ordering::Relaxed; - use std::thread; - - fn make_create() -> Arc usize + Send + Sync> { - let count = AtomicUsize::new(0); - Arc::new(move || count.fetch_add(1, Relaxed)) - } - - #[test] - fn same_thread() { - let create = make_create(); - let tls = ThreadLocal::new(); - assert_eq!(None, tls.get()); - assert_eq!("ThreadLocal { local_data: None }", format!("{:?}", &tls)); - assert_eq!(0, *tls.get_or(|| create())); - assert_eq!(Some(&0), tls.get()); - assert_eq!(0, *tls.get_or(|| create())); - assert_eq!(Some(&0), tls.get()); - assert_eq!(0, *tls.get_or(|| create())); - assert_eq!(Some(&0), tls.get()); - assert_eq!("ThreadLocal { local_data: Some(0) }", format!("{:?}", &tls)); - // tls.clear(); - // assert_eq!(None, tls.get()); - } - - #[test] - fn different_thread() { - let create = make_create(); - let tls = Arc::new(ThreadLocal::new()); - assert_eq!(None, tls.get()); - assert_eq!(0, *tls.get_or(|| create())); - assert_eq!(Some(&0), tls.get()); - - let tls2 = tls.clone(); - let create2 = create.clone(); - thread::spawn(move || { - assert_eq!(None, tls2.get()); - assert_eq!(1, *tls2.get_or(|| create2())); - assert_eq!(Some(&1), tls2.get()); - }) - .join() - .unwrap(); - - assert_eq!(Some(&0), tls.get()); - assert_eq!(0, *tls.get_or(|| create())); - } - - #[test] - fn iter() { - let tls = Arc::new(ThreadLocal::new()); - tls.get_or(|| Box::new(1)); - - let tls2 = tls.clone(); - thread::spawn(move || { - tls2.get_or(|| Box::new(2)); - let tls3 = tls2.clone(); - thread::spawn(move || { - tls3.get_or(|| Box::new(3)); - }) - .join() - .unwrap(); - drop(tls2); - }) - .join() - .unwrap(); - - let tls = Arc::try_unwrap(tls).unwrap(); - - let mut v = tls.iter().map(|x| **x).collect::>(); - v.sort_unstable(); - assert_eq!(vec![1, 2, 3], v); - } - - #[test] - fn miri_iter_soundness_check() { - let tls = Arc::new(ThreadLocal::new()); - let _local = tls.get_or(|| Box::new(1)); - - let tls2 = tls.clone(); - let join_1 = thread::spawn(move || { - let _tls = tls2.get_or(|| Box::new(2)); - let iter = tls2.iter(); - for item in iter { - println!("{item:?}"); - } - }); - - let iter = tls.iter(); - for item in iter { - println!("{item:?}"); - } - - join_1.join().ok(); - } - - #[test] - fn test_drop() { - let local = ThreadLocal::new(); - struct Dropped(Arc); - impl Drop for Dropped { - fn drop(&mut self) { - self.0.fetch_add(1, Relaxed); - } - } - - let dropped = Arc::new(AtomicUsize::new(0)); - local.get_or(|| Dropped(dropped.clone())); - assert_eq!(dropped.load(Relaxed), 0); - drop(local); - assert_eq!(dropped.load(Relaxed), 1); - } - - #[test] - fn test_earlyreturn_buckets() { - struct Dropped(Arc); - impl Drop for Dropped { - fn drop(&mut self) { - self.0.fetch_add(1, Relaxed); - } - } - let dropped = Arc::new(AtomicUsize::new(0)); - - // We use a high `id` here to guarantee that a lazily allocated bucket somewhere in the middle is used. - // Neither iteration nor `Drop` must early-return on `null` buckets that are used for lower `buckets`. - let thread = thread_id::Thread::new(1234); - assert!(thread.bucket > 1); - - let local = ThreadLocal::new(); - local.insert(thread, Dropped(dropped.clone())); - - let item = local.iter().next().unwrap(); - assert_eq!(item.0.load(Relaxed), 0); - drop(local); - assert_eq!(dropped.load(Relaxed), 1); - } - - #[test] - fn is_sync() { - fn foo() {} - foo::>(); - foo::>>(); - } -} diff --git a/libs/metrics/src/lib.rs b/libs/metrics/src/lib.rs index 621d1c6e0a..f8867d49f1 100644 --- a/libs/metrics/src/lib.rs +++ b/libs/metrics/src/lib.rs @@ -619,9 +619,9 @@ impl Drop for MeasuredCounterPairGuard<'_, A> { } /// [`MetricEncoding`] for [`MeasuredCounterPairState`] that only writes the inc counter to the inner encoder. -struct Inc(T); +pub struct Inc(pub T); /// [`MetricEncoding`] for [`MeasuredCounterPairState`] that only writes the dec counter to the inner encoder. -struct Dec(T); +pub struct Dec(pub T); impl Encoding for Inc { type Err = T::Err;