Compare commits


7 Commits

Author SHA1 Message Date
Christian Schwarz 71bf9cf8ae origin/problame/page-cache-forward-progress/3: trace spans and events only for tests 2023-11-29 11:50:17 +00:00
Christian Schwarz fd97c98dd9 move into library 2023-11-29 11:50:16 +00:00
Christian Schwarz 05dbff7a18 commented out the check for just-once-polled, works now, don't understand why though 2023-11-29 11:48:22 +00:00
Christian Schwarz 31632502aa fixes 2023-11-29 11:48:22 +00:00
Christian Schwarz 76d3e44588 hand-roll it instead 2023-11-29 11:48:20 +00:00
Christian Schwarz a5912dcc1b page_cache: find_victim: prevent starvation 2023-11-29 11:45:33 +00:00
Christian Schwarz da9a88a882 page_cache: ensure forward progress on cache miss 2023-11-29 11:43:28 +00:00
48 changed files with 794 additions and 1384 deletions

View File

@@ -404,7 +404,7 @@ jobs:
uses: ./.github/actions/save-coverage-data
regress-tests:
needs: [ check-permissions, build-neon, tag ]
needs: [ check-permissions, build-neon ]
runs-on: [ self-hosted, gen3, large ]
container:
image: 369495373322.dkr.ecr.eu-central-1.amazonaws.com/rust:pinned
@@ -436,7 +436,6 @@ jobs:
env:
TEST_RESULT_CONNSTR: ${{ secrets.REGRESS_TEST_RESULT_CONNSTR_NEW }}
CHECK_ONDISK_DATA_COMPATIBILITY: nonempty
BUILD_TAG: ${{ needs.tag.outputs.build-tag }}
- name: Merge and upload coverage data
if: matrix.build_type == 'debug' && matrix.pg_version == 'v14'

Cargo.lock (generated, 14 changed lines)
View File

@@ -2610,6 +2610,17 @@ dependencies = [
"minimal-lexical",
]
[[package]]
name = "nostarve_queue"
version = "0.1.0"
dependencies = [
"futures",
"rand 0.8.5",
"scopeguard",
"tokio",
"tracing",
]
[[package]]
name = "notify"
version = "5.2.0"
@@ -2951,6 +2962,7 @@ dependencies = [
"itertools",
"metrics",
"nix 0.26.2",
"nostarve_queue",
"num-traits",
"num_cpus",
"once_cell",
@@ -3011,7 +3023,6 @@ dependencies = [
"serde_with",
"strum",
"strum_macros",
"thiserror",
"utils",
"workspace_hack",
]
@@ -3506,7 +3517,6 @@ dependencies = [
"pbkdf2",
"pin-project-lite",
"postgres-native-tls",
"postgres-protocol",
"postgres_backend",
"pq_proto",
"prometheus",

View File

@@ -27,6 +27,7 @@ members = [
"libs/postgres_ffi/wal_craft",
"libs/vm_monitor",
"libs/walproposer",
"libs/nostarve_queue",
]
[workspace.package]
@@ -37,6 +38,7 @@ license = "Apache-2.0"
[workspace.dependencies]
anyhow = { version = "1.0", features = ["backtrace"] }
arc-swap = "1.6"
async-channel = "1.9.0"
async-compression = { version = "0.4.0", features = ["tokio", "gzip", "zstd"] }
azure_core = "0.16"
azure_identity = "0.16"
@@ -191,6 +193,7 @@ tracing-utils = { version = "0.1", path = "./libs/tracing-utils/" }
utils = { version = "0.1", path = "./libs/utils/" }
vm_monitor = { version = "0.1", path = "./libs/vm_monitor/" }
walproposer = { version = "0.1", path = "./libs/walproposer/" }
nostarve_queue = { path = "./libs/nostarve_queue" }
## Common library dependency
workspace_hack = { version = "0.1", path = "./workspace_hack/" }

View File

@@ -714,24 +714,6 @@ RUN wget https://github.com/pksunkara/pgx_ulid/archive/refs/tags/v0.1.3.tar.gz -
cargo pgrx install --release && \
echo "trusted = true" >> /usr/local/pgsql/share/extension/ulid.control
#########################################################################################
#
# Layer "wal2json-build"
# Compile "wal2json" extension
#
#########################################################################################
FROM build-deps AS wal2json-pg-build
COPY --from=pg-build /usr/local/pgsql/ /usr/local/pgsql/
ENV PATH "/usr/local/pgsql/bin/:$PATH"
RUN wget https://github.com/eulerto/wal2json/archive/refs/tags/wal2json_2_5.tar.gz && \
echo "b516653575541cf221b99cf3f8be9b6821f6dbcfc125675c85f35090f824f00e wal2json_2_5.tar.gz" | sha256sum --check && \
mkdir wal2json-src && cd wal2json-src && tar xvzf ../wal2json_2_5.tar.gz --strip-components=1 -C . && \
make -j $(getconf _NPROCESSORS_ONLN) && \
make -j $(getconf _NPROCESSORS_ONLN) install && \
echo 'trusted = true' >> /usr/local/pgsql/share/extension/wal2json.control
#########################################################################################
#
# Layer "neon-pg-ext-build"
@@ -768,7 +750,6 @@ COPY --from=rdkit-pg-build /usr/local/pgsql/ /usr/local/pgsql/
COPY --from=pg-uuidv7-pg-build /usr/local/pgsql/ /usr/local/pgsql/
COPY --from=pg-roaringbitmap-pg-build /usr/local/pgsql/ /usr/local/pgsql/
COPY --from=pg-embedding-pg-build /usr/local/pgsql/ /usr/local/pgsql/
COPY --from=wal2json-pg-build /usr/local/pgsql /usr/local/pgsql
COPY pgxn/ pgxn/
RUN make -j $(getconf _NPROCESSORS_ONLN) \

View File

@@ -687,9 +687,6 @@ pub fn handle_extension_neon(client: &mut Client) -> Result<()> {
info!("create neon extension with query: {}", query);
client.simple_query(query)?;
query = "UPDATE pg_extension SET extrelocatable = true WHERE extname = 'neon'";
client.simple_query(query)?;
query = "ALTER EXTENSION neon SET SCHEMA neon";
info!("alter neon extension schema with query: {}", query);
client.simple_query(query)?;

View File

@@ -21,7 +21,7 @@ use pageserver_api::models::{
use pageserver_api::shard::TenantShardId;
use postgres_backend::AuthType;
use postgres_connection::{parse_host_port, PgConnectionConfig};
use reqwest::blocking::{Client, RequestBuilder, Response};
use reqwest::blocking::{Client, ClientBuilder, RequestBuilder, Response};
use reqwest::{IntoUrl, Method};
use thiserror::Error;
use utils::auth::{Claims, Scope};
@@ -99,7 +99,7 @@ impl PageServerNode {
pg_connection_config: PgConnectionConfig::new_host_port(host, port),
conf: conf.clone(),
env: env.clone(),
http_client: Client::new(),
http_client: ClientBuilder::new().timeout(None).build().unwrap(),
http_base_url: format!("http://{}/v1", conf.listen_http_addr),
}
}

View File

@@ -14,6 +14,7 @@ use pageserver_api::models::{
use std::collections::HashMap;
use std::time::Duration;
use utils::{
generation::Generation,
id::{TenantId, TimelineId},
lsn::Lsn,
};
@@ -92,22 +93,6 @@ pub fn migrate_tenant(
// Get a new generation
let attachment_service = AttachmentService::from_env(env);
fn build_location_config(
mode: LocationConfigMode,
generation: Option<u32>,
secondary_conf: Option<LocationConfigSecondary>,
) -> LocationConfig {
LocationConfig {
mode,
generation,
secondary_conf,
tenant_conf: TenantConfig::default(),
shard_number: 0,
shard_count: 0,
shard_stripe_size: 0,
}
}
let previous = attachment_service.inspect(tenant_id)?;
let mut baseline_lsns = None;
if let Some((generation, origin_ps_id)) = &previous {
@@ -116,7 +101,12 @@ pub fn migrate_tenant(
if origin_ps_id == &dest_ps.conf.id {
println!("🔁 Already attached to {origin_ps_id}, freshening...");
let gen = attachment_service.attach_hook(tenant_id, dest_ps.conf.id)?;
let dest_conf = build_location_config(LocationConfigMode::AttachedSingle, gen, None);
let dest_conf = LocationConfig {
mode: LocationConfigMode::AttachedSingle,
generation: gen.map(Generation::new),
secondary_conf: None,
tenant_conf: TenantConfig::default(),
};
dest_ps.location_config(tenant_id, dest_conf)?;
println!("✅ Migration complete");
return Ok(());
@@ -124,15 +114,24 @@ pub fn migrate_tenant(
println!("🔁 Switching origin pageserver {origin_ps_id} to stale mode");
let stale_conf =
build_location_config(LocationConfigMode::AttachedStale, Some(*generation), None);
let stale_conf = LocationConfig {
mode: LocationConfigMode::AttachedStale,
generation: Some(Generation::new(*generation)),
secondary_conf: None,
tenant_conf: TenantConfig::default(),
};
origin_ps.location_config(tenant_id, stale_conf)?;
baseline_lsns = Some(get_lsns(tenant_id, &origin_ps)?);
}
let gen = attachment_service.attach_hook(tenant_id, dest_ps.conf.id)?;
let dest_conf = build_location_config(LocationConfigMode::AttachedMulti, gen, None);
let dest_conf = LocationConfig {
mode: LocationConfigMode::AttachedMulti,
generation: gen.map(Generation::new),
secondary_conf: None,
tenant_conf: TenantConfig::default(),
};
println!("🔁 Attaching to pageserver {}", dest_ps.conf.id);
dest_ps.location_config(tenant_id, dest_conf)?;
@@ -171,11 +170,12 @@ pub fn migrate_tenant(
}
// Downgrade to a secondary location
let secondary_conf = build_location_config(
LocationConfigMode::Secondary,
None,
Some(LocationConfigSecondary { warm: true }),
);
let secondary_conf = LocationConfig {
mode: LocationConfigMode::Secondary,
generation: None,
secondary_conf: Some(LocationConfigSecondary { warm: true }),
tenant_conf: TenantConfig::default(),
};
println!(
"💤 Switching to secondary mode on pageserver {}",
@@ -188,7 +188,12 @@ pub fn migrate_tenant(
"🔁 Switching to AttachedSingle mode on pageserver {}",
dest_ps.conf.id
);
let dest_conf = build_location_config(LocationConfigMode::AttachedSingle, gen, None);
let dest_conf = LocationConfig {
mode: LocationConfigMode::AttachedSingle,
generation: gen.map(Generation::new),
secondary_conf: None,
tenant_conf: TenantConfig::default(),
};
dest_ps.location_config(tenant_id, dest_conf)?;
println!("✅ Migration complete");

View File

@@ -0,0 +1,14 @@
[package]
name = "nostarve_queue"
version = "0.1.0"
edition.workspace = true
license.workspace = true
[dependencies]
scopeguard.workspace = true
tracing.workspace = true
[dev-dependencies]
futures.workspace = true
rand.workspace = true
tokio = { workspace = true, features = ["rt", "rt-multi-thread", "time"] }

View File

@@ -0,0 +1,316 @@
//! Synchronization primitive to prevent starvation among concurrent tasks that do the same work.
use std::{
collections::VecDeque,
fmt,
future::poll_fn,
sync::Mutex,
task::{Poll, Waker},
};
pub struct Queue<T> {
inner: Mutex<Inner<T>>,
}
struct Inner<T> {
waiters: VecDeque<usize>,
free: VecDeque<usize>,
slots: Vec<Option<(Option<Waker>, Option<T>)>>,
}
#[derive(Clone, Copy)]
pub struct Position<'q, T> {
idx: usize,
queue: &'q Queue<T>,
}
impl<T> fmt::Debug for Position<'_, T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Position").field("idx", &self.idx).finish()
}
}
impl<T> Inner<T> {
#[cfg(not(test))]
#[inline]
fn integrity_check(&self) {}
#[cfg(test)]
fn integrity_check(&self) {
use std::collections::HashSet;
let waiters = self.waiters.iter().copied().collect::<HashSet<_>>();
let free = self.free.iter().copied().collect::<HashSet<_>>();
for (slot_idx, slot) in self.slots.iter().enumerate() {
match slot {
None => {
assert!(!waiters.contains(&slot_idx));
assert!(free.contains(&slot_idx));
}
Some((None, None)) => {
assert!(waiters.contains(&slot_idx));
assert!(!free.contains(&slot_idx));
}
Some((Some(_), Some(_))) => {
assert!(!waiters.contains(&slot_idx));
assert!(!free.contains(&slot_idx));
}
Some((Some(_), None)) => {
assert!(waiters.contains(&slot_idx));
assert!(!free.contains(&slot_idx));
}
Some((None, Some(_))) => {
assert!(!waiters.contains(&slot_idx));
assert!(!free.contains(&slot_idx));
}
}
}
}
}
impl<T> Queue<T> {
pub fn new(size: usize) -> Self {
Queue {
inner: Mutex::new(Inner {
waiters: VecDeque::new(),
free: (0..size).collect(),
slots: {
let mut v = Vec::with_capacity(size);
v.resize_with(size, || None);
v
},
}),
}
}
pub fn begin(&self) -> Result<Position<T>, ()> {
#[cfg(test)]
tracing::trace!("get in line locking inner");
let mut inner = self.inner.lock().unwrap();
inner.integrity_check();
let my_waitslot_idx = inner
.free
.pop_front()
.expect("can't happen, len(slots) = len(waiters");
inner.waiters.push_back(my_waitslot_idx);
let prev = inner.slots[my_waitslot_idx].replace((None, None));
assert!(prev.is_none());
inner.integrity_check();
Ok(Position {
idx: my_waitslot_idx,
queue: &self,
})
}
}
impl<'q, T> Position<'q, T> {
pub fn complete_and_wait(self, datum: T) -> impl std::future::Future<Output = T> + 'q {
#[cfg(test)]
tracing::trace!("found victim locking waiters");
let mut inner = self.queue.inner.lock().unwrap();
inner.integrity_check();
let winner_idx = inner.waiters.pop_front().expect("we put ourselves in");
#[cfg(test)]
tracing::trace!(winner_idx, "putting victim into next waiters slot");
let winner_slot = inner.slots[winner_idx].as_mut().unwrap();
let prev = winner_slot.1.replace(datum);
assert!(
prev.is_none(),
"ensure we didn't mess up this simple ring buffer structure"
);
if let Some(waker) = winner_slot.0.take() {
#[cfg(test)]
tracing::trace!(winner_idx, "waking up winner");
waker.wake()
}
inner.integrity_check();
drop(inner); // the poll_fn locks it again
let mut poll_num = 0;
let mut drop_guard = Some(scopeguard::guard((), |()| {
panic!("must not drop this future until Ready");
}));
// take the victim that was found by someone else
poll_fn(move |cx| {
let my_waitslot_idx = self.idx;
poll_num += 1;
#[cfg(test)]
tracing::trace!(poll_num, "poll_fn locking waiters");
let mut inner = self.queue.inner.lock().unwrap();
inner.integrity_check();
let my_waitslot = inner.slots[self.idx].as_mut().unwrap();
// assert!(
// poll_num <= 2,
// "once we place the waker in the slot, next wakeup should have a result: {}",
// my_waitslot.1.is_some()
// );
if let Some(res) = my_waitslot.1.take() {
#[cfg(test)]
tracing::trace!(poll_num, "have cache slot");
// above .take() resets the waiters slot to None
debug_assert!(my_waitslot.0.is_none());
debug_assert!(my_waitslot.1.is_none());
inner.slots[my_waitslot_idx] = None;
inner.free.push_back(my_waitslot_idx);
let _ = scopeguard::ScopeGuard::into_inner(drop_guard.take().unwrap());
inner.integrity_check();
return Poll::Ready(res);
}
// assert_eq!(poll_num, 1);
if !my_waitslot
.0
.as_ref()
.map(|existing| cx.waker().will_wake(existing))
.unwrap_or(false)
{
let prev = my_waitslot.0.replace(cx.waker().clone());
#[cfg(test)]
tracing::trace!(poll_num, prev_is_some = prev.is_some(), "updating waker");
}
inner.integrity_check();
#[cfg(test)]
tracing::trace!(poll_num, "waiting to be woken up");
Poll::Pending
})
}
}
#[cfg(test)]
mod test {
use std::{
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
task::Poll,
time::Duration,
};
use rand::RngCore;
#[tokio::test]
async fn in_order_completion_and_wait() {
let queue = super::Queue::new(2);
let q1 = queue.begin().unwrap();
let q2 = queue.begin().unwrap();
assert_eq!(q1.complete_and_wait(23).await, 23);
assert_eq!(q2.complete_and_wait(42).await, 42);
}
#[tokio::test]
async fn out_of_order_completion_and_wait() {
let queue = super::Queue::new(2);
let q1 = queue.begin().unwrap();
let q2 = queue.begin().unwrap();
let mut q2compfut = q2.complete_and_wait(23);
match futures::poll!(&mut q2compfut) {
Poll::Pending => {}
Poll::Ready(_) => panic!("should not be ready yet, it's queued after q1"),
}
let q1res = q1.complete_and_wait(42).await;
assert_eq!(q1res, 23);
let q2res = q2compfut.await;
assert_eq!(q2res, 42);
}
#[tokio::test]
async fn in_order_completion_out_of_order_wait() {
let queue = super::Queue::new(2);
let q1 = queue.begin().unwrap();
let q2 = queue.begin().unwrap();
let mut q1compfut = q1.complete_and_wait(23);
let mut q2compfut = q2.complete_and_wait(42);
match futures::poll!(&mut q2compfut) {
Poll::Pending => {
unreachable!("q2 should be ready, it wasn't first but q1 is serviced already")
}
Poll::Ready(x) => assert_eq!(x, 42),
}
assert_eq!(futures::poll!(&mut q1compfut), Poll::Ready(23));
}
#[tokio::test(flavor = "multi_thread")]
async fn stress() {
let ntasks = 8;
let queue_size = 8;
let queue = Arc::new(super::Queue::new(queue_size));
let stop = Arc::new(AtomicBool::new(false));
let mut tasks = vec![];
for i in 0..ntasks {
let jh = tokio::spawn({
let queue = Arc::clone(&queue);
let stop = Arc::clone(&stop);
async move {
while !stop.load(Ordering::Relaxed) {
let q = queue.begin().unwrap();
for _ in 0..(rand::thread_rng().next_u32() % 10_000) {
std::hint::spin_loop();
}
q.complete_and_wait(i).await;
tokio::task::yield_now().await;
}
}
});
tasks.push(jh);
}
tokio::time::sleep(Duration::from_secs(10)).await;
stop.store(true, Ordering::Relaxed);
for t in tasks {
t.await.unwrap();
}
}
#[test]
fn stress_two_runtimes_shared_queue() {
std::thread::scope(|s| {
let ntasks = 8;
let queue_size = 8;
let queue = Arc::new(super::Queue::new(queue_size));
let stop = Arc::new(AtomicBool::new(false));
for i in 0..ntasks {
s.spawn({
let queue = Arc::clone(&queue);
let stop = Arc::clone(&stop);
move || {
let rt = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.unwrap();
rt.block_on(async move {
while !stop.load(Ordering::Relaxed) {
let q = queue.begin().unwrap();
for _ in 0..(rand::thread_rng().next_u32() % 10_000) {
std::hint::spin_loop();
}
q.complete_and_wait(i).await;
tokio::task::yield_now().await;
}
});
}
});
}
std::thread::sleep(Duration::from_secs(10));
stop.store(true, Ordering::Relaxed);
});
}
}
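
The new library is easiest to read through its call pattern: every task that is about to do the shared work first takes a Position with begin(), and whichever task finishes first hands its result to the longest-waiting Position via complete_and_wait(), then parks until another finisher services it in turn. Below is a minimal usage sketch, not part of the patch, assuming the new nostarve_queue crate and the futures crate (already a dev-dependency here) are available; it mirrors the out_of_order_completion_and_wait test above.

// Minimal usage sketch (illustration only, not part of this change).
use nostarve_queue::Queue;

fn main() {
    futures::executor::block_on(async {
        // Size the queue to the number of concurrent participants, just as
        // the page cache sizes it to the pinned-slots semaphore.
        let queue = Queue::new(2);

        // Both participants get in line *before* doing the shared work.
        let first = queue.begin().unwrap();
        let second = queue.begin().unwrap();

        // `second` finishes the shared work first, but its result (23) is
        // handed to `first`, the longest waiter; `second` keeps waiting.
        let second_fut = second.complete_and_wait(23);

        // `first` finishes next: its result (42) goes to `second`, and
        // `first` immediately receives the 23 that was queued for it.
        assert_eq!(first.complete_and_wait(42).await, 23);
        assert_eq!(second_fut.await, 42);
    });
}

Note that the future returned by complete_and_wait must be driven to completion: the scopeguard in the implementation deliberately panics if the future is dropped before it resolves.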

View File

@@ -18,7 +18,6 @@ enum-map.workspace = true
strum.workspace = true
strum_macros.workspace = true
hex.workspace = true
thiserror.workspace = true
workspace_hack.workspace = true

View File

@@ -10,6 +10,7 @@ use serde_with::serde_as;
use strum_macros;
use utils::{
completion,
generation::Generation,
history_buffer::HistoryBufferWithDropCounter,
id::{NodeId, TenantId, TimelineId},
lsn::Lsn,
@@ -261,19 +262,10 @@ pub struct LocationConfig {
pub mode: LocationConfigMode,
/// If attaching, in what generation?
#[serde(default)]
pub generation: Option<u32>,
pub generation: Option<Generation>,
#[serde(default)]
pub secondary_conf: Option<LocationConfigSecondary>,
// Shard parameters: if shard_count is nonzero, then other shard_* fields
// must be set accurately.
#[serde(default)]
pub shard_number: u8,
#[serde(default)]
pub shard_count: u8,
#[serde(default)]
pub shard_stripe_size: u32,
// If requesting mode `Secondary`, configuration for that.
// Custom storage configuration for the tenant, if any
pub tenant_conf: TenantConfig,

View File

@@ -2,7 +2,6 @@ use std::{ops::RangeInclusive, str::FromStr};
use hex::FromHex;
use serde::{Deserialize, Serialize};
use thiserror;
use utils::id::TenantId;
#[derive(Ord, PartialOrd, Eq, PartialEq, Clone, Copy, Serialize, Deserialize, Debug)]
@@ -140,89 +139,6 @@ impl From<[u8; 18]> for TenantShardId {
}
}
/// For use within the context of a particular tenant, when we need to know which
/// shard we're dealing with, but do not need to know the full ShardIdentity (because
/// we won't be doing any page->shard mapping), and do not need to know the fully qualified
/// TenantShardId.
#[derive(Eq, PartialEq, PartialOrd, Ord, Clone, Copy)]
pub struct ShardIndex {
pub shard_number: ShardNumber,
pub shard_count: ShardCount,
}
impl ShardIndex {
pub fn new(number: ShardNumber, count: ShardCount) -> Self {
Self {
shard_number: number,
shard_count: count,
}
}
pub fn unsharded() -> Self {
Self {
shard_number: ShardNumber(0),
shard_count: ShardCount(0),
}
}
pub fn is_unsharded(&self) -> bool {
self.shard_number == ShardNumber(0) && self.shard_count == ShardCount(0)
}
/// For use in constructing remote storage paths: concatenate this with a TenantId
/// to get a fully qualified TenantShardId.
///
/// Backward compat: this function returns an empty string if Self::is_unsharded, such
/// that the legacy pre-sharding remote key format is preserved.
pub fn get_suffix(&self) -> String {
if self.is_unsharded() {
"".to_string()
} else {
format!("-{:02x}{:02x}", self.shard_number.0, self.shard_count.0)
}
}
}
impl std::fmt::Display for ShardIndex {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{:02x}{:02x}", self.shard_number.0, self.shard_count.0)
}
}
impl std::fmt::Debug for ShardIndex {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
// Debug is the same as Display: the compact hex representation
write!(f, "{}", self)
}
}
impl std::str::FromStr for ShardIndex {
type Err = hex::FromHexError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
// Expect format: 1 byte shard number, 1 byte shard count
if s.len() == 4 {
let bytes = s.as_bytes();
let mut shard_parts: [u8; 2] = [0u8; 2];
hex::decode_to_slice(bytes, &mut shard_parts)?;
Ok(Self {
shard_number: ShardNumber(shard_parts[0]),
shard_count: ShardCount(shard_parts[1]),
})
} else {
Err(hex::FromHexError::InvalidStringLength)
}
}
}
impl From<[u8; 2]> for ShardIndex {
fn from(b: [u8; 2]) -> Self {
Self {
shard_number: ShardNumber(b[0]),
shard_count: ShardCount(b[1]),
}
}
}
impl Serialize for TenantShardId {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
@@ -293,151 +209,6 @@ impl<'de> Deserialize<'de> for TenantShardId {
}
}
/// Stripe size in number of pages
#[derive(Clone, Copy, Serialize, Deserialize, Eq, PartialEq, Debug)]
pub struct ShardStripeSize(pub u32);
/// Layout version: for future upgrades where we might change how the key->shard mapping works
#[derive(Clone, Copy, Serialize, Deserialize, Eq, PartialEq, Debug)]
pub struct ShardLayout(u8);
const LAYOUT_V1: ShardLayout = ShardLayout(1);
/// Default stripe size in pages: 256MiB divided by 8kiB page size.
const DEFAULT_STRIPE_SIZE: ShardStripeSize = ShardStripeSize(256 * 1024 / 8);
/// The ShardIdentity contains the information needed for one member of map
/// to resolve a key to a shard, and then check whether that shard is ==self.
#[derive(Clone, Copy, Serialize, Deserialize, Eq, PartialEq, Debug)]
pub struct ShardIdentity {
pub layout: ShardLayout,
pub number: ShardNumber,
pub count: ShardCount,
pub stripe_size: ShardStripeSize,
}
#[derive(thiserror::Error, Debug, PartialEq, Eq)]
pub enum ShardConfigError {
#[error("Invalid shard count")]
InvalidCount,
#[error("Invalid shard number")]
InvalidNumber,
#[error("Invalid stripe size")]
InvalidStripeSize,
}
impl ShardIdentity {
/// An identity with number=0 count=0 is a "none" identity, which represents legacy
/// tenants. Modern single-shard tenants should not use this: they should
/// have number=0 count=1.
pub fn unsharded() -> Self {
Self {
number: ShardNumber(0),
count: ShardCount(0),
layout: LAYOUT_V1,
stripe_size: DEFAULT_STRIPE_SIZE,
}
}
pub fn is_unsharded(&self) -> bool {
self.number == ShardNumber(0) && self.count == ShardCount(0)
}
/// Count must be nonzero, and number must be < count. To construct
/// the legacy case (count==0), use Self::unsharded instead.
pub fn new(
number: ShardNumber,
count: ShardCount,
stripe_size: ShardStripeSize,
) -> Result<Self, ShardConfigError> {
if count.0 == 0 {
Err(ShardConfigError::InvalidCount)
} else if number.0 > count.0 - 1 {
Err(ShardConfigError::InvalidNumber)
} else if stripe_size.0 == 0 {
Err(ShardConfigError::InvalidStripeSize)
} else {
Ok(Self {
number,
count,
layout: LAYOUT_V1,
stripe_size,
})
}
}
}
impl Serialize for ShardIndex {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
if serializer.is_human_readable() {
serializer.collect_str(self)
} else {
// Binary encoding is not used in index_part.json, but is included in anticipation of
// switching various structures (e.g. inter-process communication, remote metadata) to more
// compact binary encodings in future.
let mut packed: [u8; 2] = [0; 2];
packed[0] = self.shard_number.0;
packed[1] = self.shard_count.0;
packed.serialize(serializer)
}
}
}
impl<'de> Deserialize<'de> for ShardIndex {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
struct IdVisitor {
is_human_readable_deserializer: bool,
}
impl<'de> serde::de::Visitor<'de> for IdVisitor {
type Value = ShardIndex;
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
if self.is_human_readable_deserializer {
formatter.write_str("value in form of hex string")
} else {
formatter.write_str("value in form of integer array([u8; 2])")
}
}
fn visit_seq<A>(self, seq: A) -> Result<Self::Value, A::Error>
where
A: serde::de::SeqAccess<'de>,
{
let s = serde::de::value::SeqAccessDeserializer::new(seq);
let id: [u8; 2] = Deserialize::deserialize(s)?;
Ok(ShardIndex::from(id))
}
fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
ShardIndex::from_str(v).map_err(E::custom)
}
}
if deserializer.is_human_readable() {
deserializer.deserialize_str(IdVisitor {
is_human_readable_deserializer: true,
})
} else {
deserializer.deserialize_tuple(
2,
IdVisitor {
is_human_readable_deserializer: false,
},
)
}
}
}
#[cfg(test)]
mod tests {
use std::str::FromStr;
@@ -547,66 +318,4 @@ mod tests {
Ok(())
}
#[test]
fn shard_identity_validation() -> Result<(), ShardConfigError> {
// Happy cases
ShardIdentity::new(ShardNumber(0), ShardCount(1), DEFAULT_STRIPE_SIZE)?;
ShardIdentity::new(ShardNumber(0), ShardCount(1), ShardStripeSize(1))?;
ShardIdentity::new(ShardNumber(254), ShardCount(255), ShardStripeSize(1))?;
assert_eq!(
ShardIdentity::new(ShardNumber(0), ShardCount(0), DEFAULT_STRIPE_SIZE),
Err(ShardConfigError::InvalidCount)
);
assert_eq!(
ShardIdentity::new(ShardNumber(10), ShardCount(10), DEFAULT_STRIPE_SIZE),
Err(ShardConfigError::InvalidNumber)
);
assert_eq!(
ShardIdentity::new(ShardNumber(11), ShardCount(10), DEFAULT_STRIPE_SIZE),
Err(ShardConfigError::InvalidNumber)
);
assert_eq!(
ShardIdentity::new(ShardNumber(255), ShardCount(255), DEFAULT_STRIPE_SIZE),
Err(ShardConfigError::InvalidNumber)
);
assert_eq!(
ShardIdentity::new(ShardNumber(0), ShardCount(1), ShardStripeSize(0)),
Err(ShardConfigError::InvalidStripeSize)
);
Ok(())
}
#[test]
fn shard_index_human_encoding() -> Result<(), hex::FromHexError> {
let example = ShardIndex {
shard_number: ShardNumber(13),
shard_count: ShardCount(17),
};
let expected: String = "0d11".to_string();
let encoded = format!("{example}");
assert_eq!(&encoded, &expected);
let decoded = ShardIndex::from_str(&encoded)?;
assert_eq!(example, decoded);
Ok(())
}
#[test]
fn shard_index_binary_encoding() -> Result<(), hex::FromHexError> {
let example = ShardIndex {
shard_number: ShardNumber(13),
shard_count: ShardCount(17),
};
let expected: [u8; 2] = [0x0d, 0x11];
let encoded = bincode::serialize(&example).unwrap();
assert_eq!(Hex(&encoded), Hex(&expected));
let decoded = bincode::deserialize(&encoded).unwrap();
assert_eq!(example, decoded);
Ok(())
}
}

View File

@@ -37,6 +37,7 @@ humantime-serde.workspace = true
hyper.workspace = true
itertools.workspace = true
nix.workspace = true
nostarve_queue.workspace = true
# hack to get the number of worker threads tokio uses
num_cpus = { version = "1.15" }
num-traits.workspace = true

View File

@@ -10,7 +10,6 @@ use crate::control_plane_client::ControlPlaneGenerationsApi;
use crate::metrics;
use crate::tenant::remote_timeline_client::remote_layer_path;
use crate::tenant::remote_timeline_client::remote_timeline_path;
use crate::tenant::remote_timeline_client::LayerFileMetadata;
use crate::virtual_file::MaybeFatalIo;
use crate::virtual_file::VirtualFile;
use anyhow::Context;
@@ -510,19 +509,18 @@ impl DeletionQueueClient {
tenant_id: TenantId,
timeline_id: TimelineId,
current_generation: Generation,
layers: Vec<(LayerFileName, LayerFileMetadata)>,
layers: Vec<(LayerFileName, Generation)>,
) -> Result<(), DeletionQueueError> {
if current_generation.is_none() {
debug!("Enqueuing deletions in legacy mode, skipping queue");
let mut layer_paths = Vec::new();
for (layer, meta) in layers {
for (layer, generation) in layers {
layer_paths.push(remote_layer_path(
&tenant_id,
&timeline_id,
meta.shard,
&layer,
meta.generation,
generation,
));
}
self.push_immediate(layer_paths).await?;
@@ -542,7 +540,7 @@ impl DeletionQueueClient {
tenant_id: TenantId,
timeline_id: TimelineId,
current_generation: Generation,
layers: Vec<(LayerFileName, LayerFileMetadata)>,
layers: Vec<(LayerFileName, Generation)>,
) -> Result<(), DeletionQueueError> {
metrics::DELETION_QUEUE
.keys_submitted
@@ -753,7 +751,6 @@ impl DeletionQueue {
mod test {
use camino::Utf8Path;
use hex_literal::hex;
use pageserver_api::shard::ShardIndex;
use std::{io::ErrorKind, time::Duration};
use tracing::info;
@@ -993,8 +990,6 @@ mod test {
// we delete, and the generation of the running Tenant.
let layer_generation = Generation::new(0xdeadbeef);
let now_generation = Generation::new(0xfeedbeef);
let layer_metadata =
LayerFileMetadata::new(0xf00, layer_generation, ShardIndex::unsharded());
let remote_layer_file_name_1 =
format!("{}{}", layer_file_name_1, layer_generation.get_suffix());
@@ -1018,7 +1013,7 @@ mod test {
tenant_id,
TIMELINE_ID,
now_generation,
[(layer_file_name_1.clone(), layer_metadata)].to_vec(),
[(layer_file_name_1.clone(), layer_generation)].to_vec(),
)
.await?;
assert_remote_files(&[&remote_layer_file_name_1], &remote_timeline_path);
@@ -1057,8 +1052,6 @@ mod test {
let stale_generation = latest_generation.previous();
// Generation that our example layer file was written with
let layer_generation = stale_generation.previous();
let layer_metadata =
LayerFileMetadata::new(0xf00, layer_generation, ShardIndex::unsharded());
ctx.set_latest_generation(latest_generation);
@@ -1076,7 +1069,7 @@ mod test {
tenant_id,
TIMELINE_ID,
stale_generation,
[(EXAMPLE_LAYER_NAME.clone(), layer_metadata.clone())].to_vec(),
[(EXAMPLE_LAYER_NAME.clone(), layer_generation)].to_vec(),
)
.await?;
@@ -1091,7 +1084,7 @@ mod test {
tenant_id,
TIMELINE_ID,
latest_generation,
[(EXAMPLE_LAYER_NAME.clone(), layer_metadata.clone())].to_vec(),
[(EXAMPLE_LAYER_NAME.clone(), layer_generation)].to_vec(),
)
.await?;
@@ -1118,8 +1111,6 @@ mod test {
let layer_generation = Generation::new(0xdeadbeef);
let now_generation = Generation::new(0xfeedbeef);
let layer_metadata =
LayerFileMetadata::new(0xf00, layer_generation, ShardIndex::unsharded());
// Inject a deletion in the generation before generation_now: after restart,
// this deletion should _not_ get executed (only the immediately previous
@@ -1131,7 +1122,7 @@ mod test {
tenant_id,
TIMELINE_ID,
now_generation.previous(),
[(EXAMPLE_LAYER_NAME.clone(), layer_metadata.clone())].to_vec(),
[(EXAMPLE_LAYER_NAME.clone(), layer_generation)].to_vec(),
)
.await?;
@@ -1145,7 +1136,7 @@ mod test {
tenant_id,
TIMELINE_ID,
now_generation,
[(EXAMPLE_LAYER_NAME_ALT.clone(), layer_metadata.clone())].to_vec(),
[(EXAMPLE_LAYER_NAME_ALT.clone(), layer_generation)].to_vec(),
)
.await?;
@@ -1235,13 +1226,12 @@ pub(crate) mod mock {
match msg {
ListWriterQueueMessage::Delete(op) => {
let mut objects = op.objects;
for (layer, meta) in op.layers {
for (layer, generation) in op.layers {
objects.push(remote_layer_path(
&op.tenant_id,
&op.timeline_id,
meta.shard,
&layer,
meta.generation,
generation,
));
}

View File

@@ -33,7 +33,6 @@ use crate::config::PageServerConf;
use crate::deletion_queue::TEMP_SUFFIX;
use crate::metrics;
use crate::tenant::remote_timeline_client::remote_layer_path;
use crate::tenant::remote_timeline_client::LayerFileMetadata;
use crate::tenant::storage_layer::LayerFileName;
use crate::virtual_file::on_fatal_io_error;
use crate::virtual_file::MaybeFatalIo;
@@ -59,7 +58,7 @@ pub(super) struct DeletionOp {
// `layers` and `objects` are both just lists of objects. `layers` is used if you do not
// have a config object handy to project it to a remote key, and need the consuming worker
// to do it for you.
pub(super) layers: Vec<(LayerFileName, LayerFileMetadata)>,
pub(super) layers: Vec<(LayerFileName, Generation)>,
pub(super) objects: Vec<RemotePath>,
/// The _current_ generation of the Tenant attachment in which we are enqueuing
@@ -388,13 +387,12 @@ impl ListWriter {
);
let mut layer_paths = Vec::new();
for (layer, meta) in op.layers {
for (layer, generation) in op.layers {
layer_paths.push(remote_layer_path(
&op.tenant_id,
&op.timeline_id,
meta.shard,
&layer,
meta.generation,
generation,
));
}
layer_paths.extend(op.objects);

View File

@@ -314,7 +314,6 @@ static PAGE_CACHE_ERRORS: Lazy<IntCounterVec> = Lazy::new(|| {
#[strum(serialize_all = "kebab_case")]
pub(crate) enum PageCacheErrorKind {
AcquirePinnedSlotTimeout,
EvictIterLimit,
}
pub(crate) fn page_cache_errors_inc(error_kind: PageCacheErrorKind) {

View File

@@ -83,6 +83,7 @@ use std::{
use anyhow::Context;
use once_cell::sync::OnceCell;
use tracing::instrument;
use utils::{
id::{TenantId, TimelineId},
lsn::Lsn,
@@ -252,6 +253,9 @@ pub struct PageCache {
next_evict_slot: AtomicUsize,
size_metrics: &'static PageCacheSizeMetrics,
find_victim_waiters:
nostarve_queue::Queue<(usize, tokio::sync::RwLockWriteGuard<'static, SlotInner>)>,
}
struct PinnedSlotsPermit(tokio::sync::OwnedSemaphorePermit);
@@ -430,8 +434,9 @@ impl PageCache {
///
/// Store an image of the given page in the cache.
///
#[cfg_attr(test, instrument(skip_all, level = "trace", fields(%key, %lsn)))]
pub async fn memorize_materialized_page(
&self,
&'static self,
tenant_id: TenantId,
timeline_id: TimelineId,
key: Key,
@@ -522,8 +527,9 @@ impl PageCache {
// Section 1.2: Public interface functions for working with immutable file pages.
#[cfg_attr(test, instrument(skip_all, level = "trace", fields(?file_id, ?blkno)))]
pub async fn read_immutable_buf(
&self,
&'static self,
file_id: FileId,
blkno: u32,
ctx: &RequestContext,
@@ -629,7 +635,7 @@ impl PageCache {
/// ```
///
async fn lock_for_read(
&self,
&'static self,
cache_key: &mut CacheKey,
ctx: &RequestContext,
) -> anyhow::Result<ReadBufResult> {
@@ -851,10 +857,15 @@ impl PageCache {
///
/// On return, the slot is empty and write-locked.
async fn find_victim(
&self,
&'static self,
_permit_witness: &PinnedSlotsPermit,
) -> anyhow::Result<(usize, tokio::sync::RwLockWriteGuard<SlotInner>)> {
let iter_limit = self.slots.len() * 10;
let nostarve_position = self.find_victim_waiters.begin()
.expect("we initialize the nostarve queue to the same size as the slots semaphore, and the caller is presenting a permit");
let span = tracing::info_span!("find_victim", ?nostarve_position);
let _enter = span.enter();
let mut iters = 0;
loop {
iters += 1;
@@ -866,41 +877,8 @@ impl PageCache {
let mut inner = match slot.inner.try_write() {
Ok(inner) => inner,
Err(_err) => {
if iters > iter_limit {
// NB: Even with the permits, there's no hard guarantee that we will find a slot with
// any particular number of iterations: other threads might race ahead and acquire and
// release pins just as we're scanning the array.
//
// Imagine that nslots is 2, and as starting point, usage_count==1 on all
// slots. There are two threads running concurrently, A and B. A has just
// acquired the permit from the semaphore.
//
// A: Look at slot 1. Its usage_count == 1, so decrement it to zero, and continue the search
// B: Acquire permit.
// B: Look at slot 2, decrement its usage_count to zero and continue the search
// B: Look at slot 1. Its usage_count is zero, so pin it and bump up its usage_count to 1.
// B: Release pin and permit again
// B: Acquire permit.
// B: Look at slot 2. Its usage_count is zero, so pin it and bump up its usage_count to 1.
// B: Release pin and permit again
//
// Now we're back in the starting situation that both slots have
// usage_count 1, but A has now been through one iteration of the
// find_victim() loop. This can repeat indefinitely and on each
// iteration, A's iteration count increases by one.
//
// So, even though the semaphore for the permits is fair, the victim search
// itself happens in parallel and is not fair.
// Hence even with a permit, a task can theoretically be starved.
// To avoid this, we'd need tokio to give priority to tasks that are holding
// permits for longer.
// Note that just yielding to tokio during iteration without such
// priority boosting is likely counter-productive. We'd just give more opportunities
// for B to bump usage count, further starving A.
crate::metrics::page_cache_errors_inc(
crate::metrics::PageCacheErrorKind::EvictIterLimit,
);
anyhow::bail!("exceeded evict iter limit");
if iters > self.slots.len() * (MAX_USAGE_COUNT as usize) {
unreachable!("find_victim_waiters prevents starvation");
}
continue;
}
@@ -911,7 +889,8 @@ impl PageCache {
inner.key = None;
}
crate::metrics::PAGE_CACHE_FIND_VICTIMS_ITERS_TOTAL.inc_by(iters as u64);
return Ok((slot_idx, inner));
return Ok(nostarve_position.complete_and_wait((slot_idx, inner)).await);
}
}
}
@@ -955,6 +934,7 @@ impl PageCache {
next_evict_slot: AtomicUsize::new(0),
size_metrics,
pinned_slots: Arc::new(tokio::sync::Semaphore::new(num_pages)),
find_victim_waiters: ::nostarve_queue::Queue::new(num_pages),
}
}
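
Taken together with the library above, this page_cache change replaces the old iteration-limit bailout (and its EvictIterLimit error metric) with a hand-off protocol: every caller of find_victim registers in find_victim_waiters before scanning, and when a scanner manages to free a slot it passes (slot_idx, write guard) through complete_and_wait, so the freed slot goes to the longest-waiting caller rather than to whichever task happened to win the race. Because the queue stores RwLockWriteGuard<'static, SlotInner>, the cache has to be the 'static singleton, which is presumably why memorize_materialized_page, read_immutable_buf, lock_for_read and find_victim now take &'static self. A minimal sketch of that call shape follows, with an illustrative simplified payload in place of the real (usize, RwLockWriteGuard<'static, SlotInner>) and a made-up slot index:

// Sketch only: simplified stand-in for the page cache's victim hand-off.
use nostarve_queue::Queue;

pub struct Victim {
    pub slot_idx: usize,
}

pub async fn find_victim(waiters: &'static Queue<Victim>) -> Victim {
    // Get in line before scanning, so a slow scanner cannot be starved
    // by faster tasks that keep winning the race for free slots.
    let position = waiters
        .begin()
        .expect("queue is sized like the pinned-slots semaphore");

    // ... scan the slot array here (elided); suppose we freed slot 7 ...
    let freed = Victim { slot_idx: 7 };

    // Hand the freed slot to the longest waiter (possibly ourselves) and
    // wait for whichever victim some scanner hands to us.
    position.complete_and_wait(freed).await
}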
}

View File

@@ -2908,7 +2908,7 @@ impl Tenant {
};
// create a `tenant/{tenant_id}/timelines/basebackup-{timeline_id}.{TEMP_FILE_SUFFIX}/`
// temporary directory for basebackup files for the given timeline.
let pgdata_path = path_with_suffix_extension(
let initdb_path = path_with_suffix_extension(
self.conf
.timelines_path(&self.tenant_id)
.join(format!("basebackup-{timeline_id}")),
@@ -2917,25 +2917,26 @@ impl Tenant {
// an uninit mark was placed before, nothing else can access this timeline files
// current initdb was not run yet, so remove whatever was left from the previous runs
if pgdata_path.exists() {
fs::remove_dir_all(&pgdata_path).with_context(|| {
format!("Failed to remove already existing initdb directory: {pgdata_path}")
if initdb_path.exists() {
fs::remove_dir_all(&initdb_path).with_context(|| {
format!("Failed to remove already existing initdb directory: {initdb_path}")
})?;
}
// Init temporarily repo to get bootstrap data, this creates a directory in the `pgdata_path` path
run_initdb(self.conf, &pgdata_path, pg_version)?;
// Init temporarily repo to get bootstrap data, this creates a directory in the `initdb_path` path
run_initdb(self.conf, &initdb_path, pg_version)?;
// this new directory is very temporary, set to remove it immediately after bootstrap, we don't need it
scopeguard::defer! {
if let Err(e) = fs::remove_dir_all(&pgdata_path) {
if let Err(e) = fs::remove_dir_all(&initdb_path) {
// this is unlikely, but we will remove the directory on pageserver restart or another bootstrap call
error!("Failed to remove temporary initdb directory '{pgdata_path}': {e}");
error!("Failed to remove temporary initdb directory '{initdb_path}': {e}");
}
}
let pgdata_lsn = import_datadir::get_lsn_from_controlfile(&pgdata_path)?.align();
let pgdata_path = &initdb_path;
let pgdata_lsn = import_datadir::get_lsn_from_controlfile(pgdata_path)?.align();
// Upload the created data dir to S3
if let Some(storage) = &self.remote_storage {
let pgdata_zstd = import_datadir::create_tar_zst(&pgdata_path).await?;
let pgdata_zstd = import_datadir::create_tar_zst(pgdata_path).await?;
let pgdata_zstd = Bytes::from(pgdata_zstd);
backoff::retry(
|| async {
@@ -2985,7 +2986,7 @@ impl Tenant {
import_datadir::import_timeline_from_postgres_datadir(
unfinished_timeline,
&pgdata_path,
pgdata_path,
pgdata_lsn,
ctx,
)
@@ -3468,7 +3469,6 @@ pub async fn dump_layerfile_from_path(
pub(crate) mod harness {
use bytes::{Bytes, BytesMut};
use once_cell::sync::OnceCell;
use pageserver_api::shard::ShardIndex;
use std::fs;
use std::sync::Arc;
use utils::logging;
@@ -3535,7 +3535,6 @@ pub(crate) mod harness {
pub tenant_conf: TenantConf,
pub tenant_id: TenantId,
pub generation: Generation,
pub shard: ShardIndex,
pub remote_storage: GenericRemoteStorage,
pub remote_fs_dir: Utf8PathBuf,
pub deletion_queue: MockDeletionQueue,
@@ -3595,7 +3594,6 @@ pub(crate) mod harness {
tenant_conf,
tenant_id,
generation: Generation::new(0xdeadbeef),
shard: ShardIndex::unsharded(),
remote_storage,
remote_fs_dir,
deletion_queue,

View File

@@ -10,7 +10,6 @@
//!
use anyhow::Context;
use pageserver_api::models;
use pageserver_api::shard::{ShardCount, ShardIdentity, ShardNumber, ShardStripeSize};
use serde::{Deserialize, Serialize};
use std::num::NonZeroU64;
use std::time::Duration;
@@ -89,14 +88,6 @@ pub(crate) struct LocationConf {
/// The location-specific part of the configuration, describes the operating
/// mode of this pageserver for this tenant.
pub(crate) mode: LocationMode,
/// The detailed shard identity. This structure is already scoped within
/// a TenantShardId, but we need the full ShardIdentity to enable calculating
/// key->shard mappings.
#[serde(default = "ShardIdentity::unsharded")]
#[serde(skip_serializing_if = "ShardIdentity::is_unsharded")]
pub(crate) shard: ShardIdentity,
/// The pan-cluster tenant configuration, the same on all locations
pub(crate) tenant_conf: TenantConfOpt,
}
@@ -169,8 +160,6 @@ impl LocationConf {
generation,
attach_mode: AttachmentMode::Single,
}),
// Legacy configuration loads are always from tenants created before sharding existed.
shard: ShardIdentity::unsharded(),
tenant_conf,
}
}
@@ -198,7 +187,6 @@ impl LocationConf {
fn get_generation(conf: &'_ models::LocationConfig) -> Result<Generation, anyhow::Error> {
conf.generation
.map(Generation::new)
.ok_or_else(|| anyhow::anyhow!("Generation must be set when attaching"))
}
@@ -238,21 +226,7 @@ impl LocationConf {
}
};
let shard = if conf.shard_count == 0 {
ShardIdentity::unsharded()
} else {
ShardIdentity::new(
ShardNumber(conf.shard_number),
ShardCount(conf.shard_count),
ShardStripeSize(conf.shard_stripe_size),
)?
};
Ok(Self {
shard,
mode,
tenant_conf,
})
Ok(Self { mode, tenant_conf })
}
}
@@ -267,7 +241,6 @@ impl Default for LocationConf {
attach_mode: AttachmentMode::Single,
}),
tenant_conf: TenantConfOpt::default(),
shard: ShardIdentity::unsharded(),
}
}
}

View File

@@ -188,7 +188,6 @@ use anyhow::Context;
use camino::Utf8Path;
use chrono::{NaiveDateTime, Utc};
use pageserver_api::shard::ShardIndex;
use scopeguard::ScopeGuard;
use tokio_util::sync::CancellationToken;
pub(crate) use upload::upload_initdb_dir;
@@ -403,11 +402,6 @@ impl RemoteTimelineClient {
Ok(())
}
pub(crate) fn get_shard_index(&self) -> ShardIndex {
// TODO: carry this on the struct
ShardIndex::unsharded()
}
pub fn remote_consistent_lsn_projected(&self) -> Option<Lsn> {
match &mut *self.upload_queue.lock().unwrap() {
UploadQueue::Uninitialized => None,
@@ -471,7 +465,6 @@ impl RemoteTimelineClient {
&self.storage_impl,
&self.tenant_id,
&self.timeline_id,
self.get_shard_index(),
self.generation,
cancel,
)
@@ -664,10 +657,10 @@ impl RemoteTimelineClient {
let mut guard = self.upload_queue.lock().unwrap();
let upload_queue = guard.initialized_mut()?;
let with_metadata =
let with_generations =
self.schedule_unlinking_of_layers_from_index_part0(upload_queue, names.iter().cloned());
self.schedule_deletion_of_unlinked0(upload_queue, with_metadata);
self.schedule_deletion_of_unlinked0(upload_queue, with_generations);
// Launch the tasks immediately, if possible
self.launch_queued_tasks(upload_queue);
@@ -702,7 +695,7 @@ impl RemoteTimelineClient {
self: &Arc<Self>,
upload_queue: &mut UploadQueueInitialized,
names: I,
) -> Vec<(LayerFileName, LayerFileMetadata)>
) -> Vec<(LayerFileName, Generation)>
where
I: IntoIterator<Item = LayerFileName>,
{
@@ -710,17 +703,16 @@ impl RemoteTimelineClient {
// so we don't need to update it. Just serialize it.
let metadata = upload_queue.latest_metadata.clone();
// Decorate our list of names with each name's metadata, dropping
// names that are unexpectedly missing from our metadata. This metadata
// is later used when physically deleting layers, to construct key paths.
let with_metadata: Vec<_> = names
// Decorate our list of names with each name's generation, dropping
// names that are unexpectedly missing from our metadata.
let with_generations: Vec<_> = names
.into_iter()
.filter_map(|name| {
let meta = upload_queue.latest_files.remove(&name);
if let Some(meta) = meta {
upload_queue.latest_files_changes_since_metadata_upload_scheduled += 1;
Some((name, meta))
Some((name, meta.generation))
} else {
// This can only happen if we forgot to schedule the file upload
// before scheduling the delete. Log it because it is a rare/strange
@@ -733,10 +725,9 @@ impl RemoteTimelineClient {
.collect();
#[cfg(feature = "testing")]
for (name, metadata) in &with_metadata {
let gen = metadata.generation;
if let Some(unexpected) = upload_queue.dangling_files.insert(name.to_owned(), gen) {
if unexpected == gen {
for (name, gen) in &with_generations {
if let Some(unexpected) = upload_queue.dangling_files.insert(name.to_owned(), *gen) {
if &unexpected == gen {
tracing::error!("{name} was unlinked twice with same generation");
} else {
tracing::error!("{name} was unlinked twice with different generations {gen:?} and {unexpected:?}");
@@ -751,14 +742,14 @@ impl RemoteTimelineClient {
self.schedule_index_upload(upload_queue, metadata);
}
with_metadata
with_generations
}
/// Schedules deletion for layer files which have previously been unlinked from the
/// `index_part.json` with [`Self::schedule_gc_update`] or [`Self::schedule_compaction_update`].
pub(crate) fn schedule_deletion_of_unlinked(
self: &Arc<Self>,
layers: Vec<(LayerFileName, LayerFileMetadata)>,
layers: Vec<(LayerFileName, Generation)>,
) -> anyhow::Result<()> {
let mut guard = self.upload_queue.lock().unwrap();
let upload_queue = guard.initialized_mut()?;
@@ -771,22 +762,16 @@ impl RemoteTimelineClient {
fn schedule_deletion_of_unlinked0(
self: &Arc<Self>,
upload_queue: &mut UploadQueueInitialized,
with_metadata: Vec<(LayerFileName, LayerFileMetadata)>,
with_generations: Vec<(LayerFileName, Generation)>,
) {
for (name, meta) in &with_metadata {
info!(
"scheduling deletion of layer {}{} (shard {})",
name,
meta.generation.get_suffix(),
meta.shard
);
for (name, gen) in &with_generations {
info!("scheduling deletion of layer {}{}", name, gen.get_suffix());
}
#[cfg(feature = "testing")]
for (name, meta) in &with_metadata {
let gen = meta.generation;
for (name, gen) in &with_generations {
match upload_queue.dangling_files.remove(name) {
Some(same) if same == gen => { /* expected */ }
Some(same) if &same == gen => { /* expected */ }
Some(other) => {
tracing::error!("{name} was unlinked with {other:?} but deleted with {gen:?}");
}
@@ -798,7 +783,7 @@ impl RemoteTimelineClient {
// schedule the actual deletions
let op = UploadOp::Delete(Delete {
layers: with_metadata,
layers: with_generations,
});
self.calls_unfinished_metric_begin(&op);
upload_queue.queued_operations.push_back(op);
@@ -919,7 +904,6 @@ impl RemoteTimelineClient {
&self.storage_impl,
&self.tenant_id,
&self.timeline_id,
self.get_shard_index(),
self.generation,
&index_part_with_deleted_at,
)
@@ -978,7 +962,6 @@ impl RemoteTimelineClient {
remote_layer_path(
&self.tenant_id,
&self.timeline_id,
meta.shard,
&file_name,
meta.generation,
)
@@ -1027,12 +1010,7 @@ impl RemoteTimelineClient {
.unwrap_or(
// No generation-suffixed indices, assume we are dealing with
// a legacy index.
remote_index_path(
&self.tenant_id,
&self.timeline_id,
self.get_shard_index(),
Generation::none(),
),
remote_index_path(&self.tenant_id, &self.timeline_id, Generation::none()),
);
let remaining_layers: Vec<RemotePath> = remaining
@@ -1241,7 +1219,6 @@ impl RemoteTimelineClient {
&self.storage_impl,
&self.tenant_id,
&self.timeline_id,
self.get_shard_index(),
self.generation,
index_part,
)
@@ -1550,14 +1527,12 @@ pub fn remote_timeline_path(tenant_id: &TenantId, timeline_id: &TimelineId) -> R
pub fn remote_layer_path(
tenant_id: &TenantId,
timeline_id: &TimelineId,
shard: ShardIndex,
layer_file_name: &LayerFileName,
generation: Generation,
) -> RemotePath {
// Generation-aware key format
let path = format!(
"tenants/{tenant_id}{0}/{TIMELINES_SEGMENT_NAME}/{timeline_id}/{1}{2}",
shard.get_suffix(),
"tenants/{tenant_id}/{TIMELINES_SEGMENT_NAME}/{timeline_id}/{0}{1}",
layer_file_name.file_name(),
generation.get_suffix()
);
@@ -1575,12 +1550,10 @@ pub fn remote_initdb_archive_path(tenant_id: &TenantId, timeline_id: &TimelineId
pub fn remote_index_path(
tenant_id: &TenantId,
timeline_id: &TimelineId,
shard: ShardIndex,
generation: Generation,
) -> RemotePath {
RemotePath::from_string(&format!(
"tenants/{tenant_id}{0}/{TIMELINES_SEGMENT_NAME}/{timeline_id}/{1}{2}",
shard.get_suffix(),
"tenants/{tenant_id}/{TIMELINES_SEGMENT_NAME}/{timeline_id}/{0}{1}",
IndexPart::FILE_NAME,
generation.get_suffix()
))
@@ -1805,7 +1778,6 @@ mod tests {
println!("remote_timeline_dir: {remote_timeline_dir}");
let generation = harness.generation;
let shard = harness.shard;
// Create a couple of dummy files, schedule upload for them
@@ -1822,7 +1794,7 @@ mod tests {
harness.conf,
&timeline,
name,
LayerFileMetadata::new(contents.len() as u64, generation, shard),
LayerFileMetadata::new(contents.len() as u64, generation),
)
}).collect::<Vec<_>>();
@@ -1971,7 +1943,7 @@ mod tests {
harness.conf,
&timeline,
layer_file_name_1.clone(),
LayerFileMetadata::new(content_1.len() as u64, harness.generation, harness.shard),
LayerFileMetadata::new(content_1.len() as u64, harness.generation),
);
#[derive(Debug, PartialEq, Clone, Copy)]
@@ -2036,11 +2008,7 @@ mod tests {
assert_eq!(actual_c, expected_c);
}
async fn inject_index_part(
test_state: &TestSetup,
generation: Generation,
shard: ShardIndex,
) -> IndexPart {
async fn inject_index_part(test_state: &TestSetup, generation: Generation) -> IndexPart {
// An empty IndexPart, just sufficient to ensure deserialization will succeed
let example_metadata = TimelineMetadata::example();
let example_index_part = IndexPart::new(
@@ -2061,13 +2029,7 @@ mod tests {
std::fs::create_dir_all(remote_timeline_dir).expect("creating test dir should work");
let index_path = test_state.harness.remote_fs_dir.join(
remote_index_path(
&test_state.harness.tenant_id,
&TIMELINE_ID,
shard,
generation,
)
.get_path(),
remote_index_path(&test_state.harness.tenant_id, &TIMELINE_ID, generation).get_path(),
);
eprintln!("Writing {index_path}");
std::fs::write(&index_path, index_part_bytes).unwrap();
@@ -2104,12 +2066,7 @@ mod tests {
// Simple case: we are in generation N, load the index from generation N - 1
let generation_n = 5;
let injected = inject_index_part(
&test_state,
Generation::new(generation_n - 1),
ShardIndex::unsharded(),
)
.await;
let injected = inject_index_part(&test_state, Generation::new(generation_n - 1)).await;
assert_got_index_part(&test_state, Generation::new(generation_n), &injected).await;
@@ -2127,34 +2084,22 @@ mod tests {
// A generation-less IndexPart exists in the bucket, we should find it
let generation_n = 5;
let injected_none =
inject_index_part(&test_state, Generation::none(), ShardIndex::unsharded()).await;
let injected_none = inject_index_part(&test_state, Generation::none()).await;
assert_got_index_part(&test_state, Generation::new(generation_n), &injected_none).await;
// If a more recent-than-none generation exists, we should prefer to load that
let injected_1 =
inject_index_part(&test_state, Generation::new(1), ShardIndex::unsharded()).await;
let injected_1 = inject_index_part(&test_state, Generation::new(1)).await;
assert_got_index_part(&test_state, Generation::new(generation_n), &injected_1).await;
// If a more-recent-than-me generation exists, we should ignore it.
let _injected_10 =
inject_index_part(&test_state, Generation::new(10), ShardIndex::unsharded()).await;
let _injected_10 = inject_index_part(&test_state, Generation::new(10)).await;
assert_got_index_part(&test_state, Generation::new(generation_n), &injected_1).await;
// If a directly previous generation exists, _and_ an index exists in my own
// generation, I should prefer my own generation.
let _injected_prev = inject_index_part(
&test_state,
Generation::new(generation_n - 1),
ShardIndex::unsharded(),
)
.await;
let injected_current = inject_index_part(
&test_state,
Generation::new(generation_n),
ShardIndex::unsharded(),
)
.await;
let _injected_prev =
inject_index_part(&test_state, Generation::new(generation_n - 1)).await;
let injected_current = inject_index_part(&test_state, Generation::new(generation_n)).await;
assert_got_index_part(
&test_state,
Generation::new(generation_n),

View File

@@ -9,7 +9,6 @@ use std::time::Duration;
use anyhow::{anyhow, Context};
use camino::Utf8Path;
use pageserver_api::shard::ShardIndex;
use tokio::fs;
use tokio::io::AsyncWriteExt;
use tokio_util::sync::CancellationToken;
@@ -54,7 +53,6 @@ pub async fn download_layer_file<'a>(
let remote_path = remote_layer_path(
&tenant_id,
&timeline_id,
layer_metadata.shard,
layer_file_name,
layer_metadata.generation,
);
@@ -215,11 +213,10 @@ async fn do_download_index_part(
storage: &GenericRemoteStorage,
tenant_id: &TenantId,
timeline_id: &TimelineId,
shard: ShardIndex,
index_generation: Generation,
cancel: CancellationToken,
) -> Result<IndexPart, DownloadError> {
let remote_path = remote_index_path(tenant_id, timeline_id, shard, index_generation);
let remote_path = remote_index_path(tenant_id, timeline_id, index_generation);
let index_part_bytes = download_retry_forever(
|| async {
@@ -257,7 +254,6 @@ pub(super) async fn download_index_part(
storage: &GenericRemoteStorage,
tenant_id: &TenantId,
timeline_id: &TimelineId,
shard: ShardIndex,
my_generation: Generation,
cancel: CancellationToken,
) -> Result<IndexPart, DownloadError> {
@@ -265,15 +261,8 @@ pub(super) async fn download_index_part(
if my_generation.is_none() {
// Operating without generations: just fetch the generation-less path
return do_download_index_part(
storage,
tenant_id,
timeline_id,
shard,
my_generation,
cancel,
)
.await;
return do_download_index_part(storage, tenant_id, timeline_id, my_generation, cancel)
.await;
}
// Stale case: If we were intentionally attached in a stale generation, there may already be a remote
@@ -284,7 +273,6 @@ pub(super) async fn download_index_part(
storage,
tenant_id,
timeline_id,
shard,
my_generation,
cancel.clone(),
)
@@ -312,7 +300,6 @@ pub(super) async fn download_index_part(
storage,
tenant_id,
timeline_id,
shard,
my_generation.previous(),
cancel.clone(),
)
@@ -333,9 +320,8 @@ pub(super) async fn download_index_part(
}
// General case/fallback: if there is no index at my_generation or prev_generation, then list all index_part.json
// objects, and select the highest one with a generation <= my_generation. Constructing the prefix is equivalent
// to constructing a full index path with no generation, because the generation is a suffix.
let index_prefix = remote_index_path(tenant_id, timeline_id, shard, Generation::none());
// objects, and select the highest one with a generation <= my_generation.
let index_prefix = remote_index_path(tenant_id, timeline_id, Generation::none());
let indices = backoff::retry(
|| async { storage.list_files(Some(&index_prefix)).await },
|_| false,
@@ -361,21 +347,14 @@ pub(super) async fn download_index_part(
match max_previous_generation {
Some(g) => {
tracing::debug!("Found index_part in generation {g:?}");
do_download_index_part(storage, tenant_id, timeline_id, shard, g, cancel).await
do_download_index_part(storage, tenant_id, timeline_id, g, cancel).await
}
None => {
// Migration from legacy pre-generation state: we have a generation but no prior
// attached pageservers did. Try to load from a no-generation path.
tracing::info!("No index_part.json* found");
do_download_index_part(
storage,
tenant_id,
timeline_id,
shard,
Generation::none(),
cancel,
)
.await
do_download_index_part(storage, tenant_id, timeline_id, Generation::none(), cancel)
.await
}
}
}

View File

@@ -12,7 +12,6 @@ use crate::tenant::metadata::TimelineMetadata;
use crate::tenant::storage_layer::LayerFileName;
use crate::tenant::upload_queue::UploadQueueInitialized;
use crate::tenant::Generation;
use pageserver_api::shard::ShardIndex;
use utils::lsn::Lsn;
@@ -26,8 +25,6 @@ pub struct LayerFileMetadata {
file_size: u64,
pub(crate) generation: Generation,
pub(crate) shard: ShardIndex,
}
impl From<&'_ IndexLayerMetadata> for LayerFileMetadata {
@@ -35,17 +32,15 @@ impl From<&'_ IndexLayerMetadata> for LayerFileMetadata {
LayerFileMetadata {
file_size: other.file_size,
generation: other.generation,
shard: other.shard,
}
}
}
impl LayerFileMetadata {
pub fn new(file_size: u64, generation: Generation, shard: ShardIndex) -> Self {
pub fn new(file_size: u64, generation: Generation) -> Self {
LayerFileMetadata {
file_size,
generation,
shard,
}
}
@@ -166,10 +161,6 @@ pub struct IndexLayerMetadata {
#[serde(default = "Generation::none")]
#[serde(skip_serializing_if = "Generation::is_none")]
pub generation: Generation,
#[serde(default = "ShardIndex::unsharded")]
#[serde(skip_serializing_if = "ShardIndex::is_unsharded")]
pub shard: ShardIndex,
}
impl From<LayerFileMetadata> for IndexLayerMetadata {
@@ -177,7 +168,6 @@ impl From<LayerFileMetadata> for IndexLayerMetadata {
IndexLayerMetadata {
file_size: other.file_size,
generation: other.generation,
shard: other.shard,
}
}
}
@@ -205,15 +195,13 @@ mod tests {
layer_metadata: HashMap::from([
("000000000000000000000000000000000000-FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF__0000000001696070-00000000016960E9".parse().unwrap(), IndexLayerMetadata {
file_size: 25600000,
generation: Generation::none(),
shard: ShardIndex::unsharded()
generation: Generation::none()
}),
("000000000000000000000000000000000000-FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF__00000000016B59D8-00000000016B5A51".parse().unwrap(), IndexLayerMetadata {
// serde_json should always parse this but this might be a double with jq for
// example.
file_size: 9007199254741001,
generation: Generation::none(),
shard: ShardIndex::unsharded()
generation: Generation::none()
})
]),
disk_consistent_lsn: "0/16960E8".parse::<Lsn>().unwrap(),
@@ -245,15 +233,13 @@ mod tests {
layer_metadata: HashMap::from([
("000000000000000000000000000000000000-FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF__0000000001696070-00000000016960E9".parse().unwrap(), IndexLayerMetadata {
file_size: 25600000,
generation: Generation::none(),
shard: ShardIndex::unsharded()
generation: Generation::none()
}),
("000000000000000000000000000000000000-FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF__00000000016B59D8-00000000016B5A51".parse().unwrap(), IndexLayerMetadata {
// serde_json should always parse this but this might be a double with jq for
// example.
file_size: 9007199254741001,
generation: Generation::none(),
shard: ShardIndex::unsharded()
generation: Generation::none()
})
]),
disk_consistent_lsn: "0/16960E8".parse::<Lsn>().unwrap(),
@@ -286,15 +272,13 @@ mod tests {
layer_metadata: HashMap::from([
("000000000000000000000000000000000000-FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF__0000000001696070-00000000016960E9".parse().unwrap(), IndexLayerMetadata {
file_size: 25600000,
generation: Generation::none(),
shard: ShardIndex::unsharded()
generation: Generation::none()
}),
("000000000000000000000000000000000000-FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF__00000000016B59D8-00000000016B5A51".parse().unwrap(), IndexLayerMetadata {
// serde_json should always parse this but this might be a double with jq for
// example.
file_size: 9007199254741001,
generation: Generation::none(),
shard: ShardIndex::unsharded()
generation: Generation::none()
})
]),
disk_consistent_lsn: "0/16960E8".parse::<Lsn>().unwrap(),
@@ -370,21 +354,19 @@ mod tests {
layer_metadata: HashMap::from([
("000000000000000000000000000000000000-FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF__0000000001696070-00000000016960E9".parse().unwrap(), IndexLayerMetadata {
file_size: 25600000,
generation: Generation::none(),
shard: ShardIndex::unsharded()
generation: Generation::none()
}),
("000000000000000000000000000000000000-FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF__00000000016B59D8-00000000016B5A51".parse().unwrap(), IndexLayerMetadata {
// serde_json should always parse this but this might be a double with jq for
// example.
file_size: 9007199254741001,
generation: Generation::none(),
shard: ShardIndex::unsharded()
generation: Generation::none()
})
]),
disk_consistent_lsn: "0/16960E8".parse::<Lsn>().unwrap(),
metadata: TimelineMetadata::from_bytes(&[113,11,159,210,0,54,0,4,0,0,0,0,1,105,96,232,1,0,0,0,0,1,105,96,112,0,0,0,0,0,0,0,0,0,0,0,0,0,1,105,96,112,0,0,0,0,1,105,96,112,0,0,0,14,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0]).unwrap(),
deleted_at: Some(chrono::NaiveDateTime::parse_from_str(
"2023-07-31T09:00:00.123000000", "%Y-%m-%dT%H:%M:%S.%f").unwrap()),
"2023-07-31T09:00:00.123000000", "%Y-%m-%dT%H:%M:%S.%f").unwrap())
};
let part = IndexPart::from_s3_bytes(example.as_bytes()).unwrap();

View File

@@ -4,7 +4,6 @@ use anyhow::{bail, Context};
use bytes::Bytes;
use camino::Utf8Path;
use fail::fail_point;
use pageserver_api::shard::ShardIndex;
use std::io::ErrorKind;
use tokio::fs;
@@ -27,7 +26,6 @@ pub(super) async fn upload_index_part<'a>(
storage: &'a GenericRemoteStorage,
tenant_id: &TenantId,
timeline_id: &TimelineId,
shard: ShardIndex,
generation: Generation,
index_part: &'a IndexPart,
) -> anyhow::Result<()> {
@@ -44,7 +42,7 @@ pub(super) async fn upload_index_part<'a>(
let index_part_size = index_part_bytes.len();
let index_part_bytes = tokio::io::BufReader::new(std::io::Cursor::new(index_part_bytes));
let remote_path = remote_index_path(tenant_id, timeline_id, shard, generation);
let remote_path = remote_index_path(tenant_id, timeline_id, generation);
storage
.upload_storage_object(Box::new(index_part_bytes), index_part_size, &remote_path)
.await

View File

@@ -3,7 +3,6 @@ use camino::{Utf8Path, Utf8PathBuf};
use pageserver_api::models::{
HistoricLayerInfo, LayerAccessKind, LayerResidenceEventReason, LayerResidenceStatus,
};
use pageserver_api::shard::ShardIndex;
use std::ops::Range;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::{Arc, Weak};
@@ -97,7 +96,6 @@ impl Layer {
desc,
None,
metadata.generation,
metadata.shard,
)));
debug_assert!(owner.0.needs_download_blocking().unwrap().is_some());
@@ -138,7 +136,6 @@ impl Layer {
desc,
Some(inner),
metadata.generation,
metadata.shard,
)
}));
@@ -182,7 +179,6 @@ impl Layer {
desc,
Some(inner),
timeline.generation,
timeline.get_shard_index(),
)
}));
@@ -430,15 +426,6 @@ struct LayerInner {
/// For loaded layers (resident or evicted) this comes from [`LayerFileMetadata::generation`],
/// for created layers from [`Timeline::generation`].
generation: Generation,
/// The shard of this Layer.
///
/// For layers created in this process, this will always be the [`ShardIndex`] of the
/// current `ShardIdentity` (TODO: add link once it's introduced).
///
/// For loaded layers, this may be some other value if the tenant has undergone
/// a shard split since the layer was originally written.
shard: ShardIndex,
}
impl std::fmt::Display for LayerInner {
@@ -472,9 +459,9 @@ impl Drop for LayerInner {
let path = std::mem::take(&mut self.path);
let file_name = self.layer_desc().filename();
let gen = self.generation;
let file_size = self.layer_desc().file_size;
let timeline = self.timeline.clone();
let meta = self.metadata();
crate::task_mgr::BACKGROUND_RUNTIME.spawn_blocking(move || {
let _g = span.entered();
@@ -502,7 +489,7 @@ impl Drop for LayerInner {
timeline.metrics.resident_physical_size_sub(file_size);
}
if let Some(remote_client) = timeline.remote_client.as_ref() {
let res = remote_client.schedule_deletion_of_unlinked(vec![(file_name, meta)]);
let res = remote_client.schedule_deletion_of_unlinked(vec![(file_name, gen)]);
if let Err(e) = res {
// test_timeline_deletion_with_files_stuck_in_upload_queue is good at
@@ -536,7 +523,6 @@ impl LayerInner {
desc: PersistentLayerDesc,
downloaded: Option<Arc<DownloadedLayer>>,
generation: Generation,
shard: ShardIndex,
) -> Self {
let path = conf
.timeline_path(&timeline.tenant_id, &timeline.timeline_id)
@@ -564,7 +550,6 @@ impl LayerInner {
status: tokio::sync::broadcast::channel(1).0,
consecutive_failures: AtomicUsize::new(0),
generation,
shard,
}
}
@@ -1092,7 +1077,7 @@ impl LayerInner {
}
fn metadata(&self) -> LayerFileMetadata {
LayerFileMetadata::new(self.desc.file_size, self.generation, self.shard)
LayerFileMetadata::new(self.desc.file_size, self.generation)
}
}

View File

@@ -62,7 +62,6 @@ use crate::pgdatadir_mapping::{is_rel_fsm_block_key, is_rel_vm_block_key};
use crate::pgdatadir_mapping::{BlockNumber, CalculateLogicalSizeError};
use crate::tenant::config::{EvictionPolicy, TenantConfOpt};
use pageserver_api::reltag::RelTag;
use pageserver_api::shard::ShardIndex;
use postgres_connection::PgConnectionConfig;
use postgres_ffi::to_pg_timestamp;
@@ -1598,7 +1597,6 @@ impl Timeline {
// Copy to move into the task we're about to spawn
let generation = self.generation;
let shard = self.get_shard_index();
let this = self.myself.upgrade().expect("&self method holds the arc");
let (loaded_layers, needs_cleanup, total_physical_size) = tokio::task::spawn_blocking({
@@ -1647,7 +1645,6 @@ impl Timeline {
index_part.as_ref(),
disk_consistent_lsn,
generation,
shard,
);
let mut loaded_layers = Vec::new();
@@ -4367,11 +4364,6 @@ impl Timeline {
resident_layers,
}
}
pub(crate) fn get_shard_index(&self) -> ShardIndex {
// TODO: carry this on the struct
ShardIndex::unsharded()
}
}
type TraversalPathItem = (

View File

@@ -13,7 +13,6 @@ use crate::{
};
use anyhow::Context;
use camino::Utf8Path;
use pageserver_api::shard::ShardIndex;
use std::{collections::HashMap, str::FromStr};
use utils::lsn::Lsn;
@@ -108,7 +107,6 @@ pub(super) fn reconcile(
index_part: Option<&IndexPart>,
disk_consistent_lsn: Lsn,
generation: Generation,
shard: ShardIndex,
) -> Vec<(LayerFileName, Result<Decision, DismissedLayer>)> {
use Decision::*;
@@ -120,13 +118,10 @@ pub(super) fn reconcile(
.map(|(name, file_size)| {
(
name,
// The generation and shard here will be corrected to match IndexPart in the merge below, unless
// The generation here will be corrected to match IndexPart in the merge below, unless
// it is not in IndexPart, in which case using our current generation makes sense
// because it will be uploaded in this generation.
(
Some(LayerFileMetadata::new(file_size, generation, shard)),
None,
),
(Some(LayerFileMetadata::new(file_size, generation)), None),
)
})
.collect::<Collected>();

View File

@@ -1,5 +1,6 @@
use super::storage_layer::LayerFileName;
use super::storage_layer::ResidentLayer;
use super::Generation;
use crate::tenant::metadata::TimelineMetadata;
use crate::tenant::remote_timeline_client::index::IndexPart;
use crate::tenant::remote_timeline_client::index::LayerFileMetadata;
@@ -14,9 +15,6 @@ use utils::lsn::AtomicLsn;
use std::sync::atomic::AtomicU32;
use utils::lsn::Lsn;
#[cfg(feature = "testing")]
use utils::generation::Generation;
// clippy warns that Uninitialized is much smaller than Initialized, which wastes
// memory for Uninitialized variants. Doesn't matter in practice, there are not
// that many upload queues in a running pageserver, and most of them are initialized
@@ -234,7 +232,7 @@ pub(crate) struct UploadTask {
/// for timeline deletion, which skips this queue and goes directly to DeletionQueue.
#[derive(Debug)]
pub(crate) struct Delete {
pub(crate) layers: Vec<(LayerFileName, LayerFileMetadata)>,
pub(crate) layers: Vec<(LayerFileName, Generation)>,
}
#[derive(Debug)]

View File

@@ -21,7 +21,6 @@
#include "storage/buf_internals.h"
#include "storage/lwlock.h"
#include "storage/ipc.h"
#include "storage/pg_shmem.h"
#include "c.h"
#include "postmaster/interrupt.h"
@@ -88,12 +87,6 @@ bool (*old_redo_read_buffer_filter) (XLogReaderState *record, uint8 block_id) =
static bool pageserver_flush(void);
static void pageserver_disconnect(void);
static bool
PagestoreShmemIsValid()
{
return pagestore_shared && UsedShmemSegAddr;
}
static bool
CheckPageserverConnstring(char **newval, void **extra, GucSource source)
{
@@ -103,7 +96,7 @@ CheckPageserverConnstring(char **newval, void **extra, GucSource source)
static void
AssignPageserverConnstring(const char *newval, void *extra)
{
if(!PagestoreShmemIsValid())
if(!pagestore_shared)
return;
LWLockAcquire(pagestore_shared->lock, LW_EXCLUSIVE);
strlcpy(pagestore_shared->pageserver_connstring, newval, MAX_PAGESERVER_CONNSTRING_SIZE);
@@ -114,7 +107,7 @@ AssignPageserverConnstring(const char *newval, void *extra)
static bool
CheckConnstringUpdated()
{
if(!PagestoreShmemIsValid())
if(!pagestore_shared)
return false;
return pagestore_local_counter < pg_atomic_read_u64(&pagestore_shared->update_counter);
}
@@ -122,7 +115,7 @@ CheckConnstringUpdated()
static void
ReloadConnstring()
{
if(!PagestoreShmemIsValid())
if(!pagestore_shared)
return;
LWLockAcquire(pagestore_shared->lock, LW_SHARED);
strlcpy(local_pageserver_connstring, pagestore_shared->pageserver_connstring, sizeof(local_pageserver_connstring));

View File

@@ -2,4 +2,3 @@
comment = 'cloud storage for PostgreSQL'
default_version = '1.1'
module_pathname = '$libdir/neon'
relocatable = true

View File

@@ -76,4 +76,3 @@ tokio-util.workspace = true
rcgen.workspace = true
rstest.workspace = true
tokio-postgres-rustls.workspace = true
postgres-protocol.workspace = true

View File

@@ -6,7 +6,6 @@ pub use link::LinkAuthError;
use tokio_postgres::config::AuthKeys;
use crate::proxy::{handle_try_wake, retry_after, LatencyTimer};
use crate::stream::Stream;
use crate::{
auth::{self, ClientCredentials},
config::AuthenticationConfig,
@@ -132,7 +131,7 @@ async fn auth_quirks_creds(
api: &impl console::Api,
extra: &ConsoleReqExtra<'_>,
creds: &mut ClientCredentials<'_>,
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
client: &mut stream::PqStream<impl AsyncRead + AsyncWrite + Unpin>,
allow_cleartext: bool,
config: &'static AuthenticationConfig,
latency_timer: &mut LatencyTimer,
@@ -166,7 +165,7 @@ async fn auth_quirks(
api: &impl console::Api,
extra: &ConsoleReqExtra<'_>,
creds: &mut ClientCredentials<'_>,
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
client: &mut stream::PqStream<impl AsyncRead + AsyncWrite + Unpin>,
allow_cleartext: bool,
config: &'static AuthenticationConfig,
latency_timer: &mut LatencyTimer,
@@ -242,7 +241,7 @@ impl BackendType<'_, ClientCredentials<'_>> {
pub async fn authenticate(
&mut self,
extra: &ConsoleReqExtra<'_>,
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
client: &mut stream::PqStream<impl AsyncRead + AsyncWrite + Unpin>,
allow_cleartext: bool,
config: &'static AuthenticationConfig,
latency_timer: &mut LatencyTimer,

View File

@@ -6,7 +6,7 @@ use crate::{
console::{self, AuthInfo, ConsoleReqExtra},
proxy::LatencyTimer,
sasl, scram,
stream::{PqStream, Stream},
stream::PqStream,
};
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::{info, warn};
@@ -15,7 +15,7 @@ pub(super) async fn authenticate(
api: &impl console::Api,
extra: &ConsoleReqExtra<'_>,
creds: &ClientCredentials<'_>,
client: &mut PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
config: &'static AuthenticationConfig,
latency_timer: &mut LatencyTimer,
) -> auth::Result<AuthSuccess<ComputeCredentials>> {

View File

@@ -2,7 +2,7 @@ use super::{AuthSuccess, ComputeCredentials};
use crate::{
auth::{self, AuthFlow, ClientCredentials},
proxy::LatencyTimer,
stream::{self, Stream},
stream,
};
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::{info, warn};
@@ -12,7 +12,7 @@ use tracing::{info, warn};
/// These properties are beneficial for serverless JS workers, so we
/// use this mechanism for websocket connections.
pub async fn cleartext_hack(
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
client: &mut stream::PqStream<impl AsyncRead + AsyncWrite + Unpin>,
latency_timer: &mut LatencyTimer,
) -> auth::Result<AuthSuccess<ComputeCredentials>> {
warn!("cleartext auth flow override is enabled, proceeding");
@@ -37,7 +37,7 @@ pub async fn cleartext_hack(
/// Very similar to [`cleartext_hack`], but there's a specific password format.
pub async fn password_hack(
creds: &mut ClientCredentials<'_>,
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
client: &mut stream::PqStream<impl AsyncRead + AsyncWrite + Unpin>,
latency_timer: &mut LatencyTimer,
) -> auth::Result<AuthSuccess<ComputeCredentials>> {
warn!("project not specified, resorting to the password hack auth flow");

View File

@@ -1,21 +1,16 @@
//! Main authentication flow.
use super::{AuthErrorImpl, PasswordHackPayload};
use crate::{
config::TlsServerEndPoint,
sasl, scram,
stream::{PqStream, Stream},
};
use crate::{sasl, scram, stream::PqStream};
use pq_proto::{BeAuthenticationSaslMessage, BeMessage, BeMessage as Be};
use std::io;
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::info;
/// Every authentication selector is supposed to implement this trait.
pub trait AuthMethod {
/// Any authentication selector should provide initial backend message
/// containing auth method name and parameters, e.g. md5 salt.
fn first_message(&self, channel_binding: bool) -> BeMessage<'_>;
fn first_message(&self) -> BeMessage<'_>;
}
/// Initial state of [`AuthFlow`].
@@ -26,14 +21,8 @@ pub struct Scram<'a>(pub &'a scram::ServerSecret);
impl AuthMethod for Scram<'_> {
#[inline(always)]
fn first_message(&self, channel_binding: bool) -> BeMessage<'_> {
if channel_binding {
Be::AuthenticationSasl(BeAuthenticationSaslMessage::Methods(scram::METHODS))
} else {
Be::AuthenticationSasl(BeAuthenticationSaslMessage::Methods(
scram::METHODS_WITHOUT_PLUS,
))
}
fn first_message(&self) -> BeMessage<'_> {
Be::AuthenticationSasl(BeAuthenticationSaslMessage::Methods(scram::METHODS))
}
}
@@ -43,7 +32,7 @@ pub struct PasswordHack;
impl AuthMethod for PasswordHack {
#[inline(always)]
fn first_message(&self, _channel_binding: bool) -> BeMessage<'_> {
fn first_message(&self) -> BeMessage<'_> {
Be::AuthenticationCleartextPassword
}
}
@@ -54,44 +43,37 @@ pub struct CleartextPassword;
impl AuthMethod for CleartextPassword {
#[inline(always)]
fn first_message(&self, _channel_binding: bool) -> BeMessage<'_> {
fn first_message(&self) -> BeMessage<'_> {
Be::AuthenticationCleartextPassword
}
}
/// This wrapper for [`PqStream`] performs client authentication.
#[must_use]
pub struct AuthFlow<'a, S, State> {
pub struct AuthFlow<'a, Stream, State> {
/// The underlying stream which implements libpq's protocol.
stream: &'a mut PqStream<Stream<S>>,
stream: &'a mut PqStream<Stream>,
/// State might contain ancillary data (see [`Self::begin`]).
state: State,
tls_server_end_point: TlsServerEndPoint,
}
/// Initial state of the stream wrapper.
impl<'a, S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'a, S, Begin> {
impl<'a, S: AsyncWrite + Unpin> AuthFlow<'a, S, Begin> {
/// Create a new wrapper for client authentication.
pub fn new(stream: &'a mut PqStream<Stream<S>>) -> Self {
let tls_server_end_point = stream.get_ref().tls_server_end_point();
pub fn new(stream: &'a mut PqStream<S>) -> Self {
Self {
stream,
state: Begin,
tls_server_end_point,
}
}
/// Move to the next step by sending auth method's name & params to client.
pub async fn begin<M: AuthMethod>(self, method: M) -> io::Result<AuthFlow<'a, S, M>> {
self.stream
.write_message(&method.first_message(self.tls_server_end_point.supported()))
.await?;
self.stream.write_message(&method.first_message()).await?;
Ok(AuthFlow {
stream: self.stream,
state: method,
tls_server_end_point: self.tls_server_end_point,
})
}
}
@@ -141,15 +123,9 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, Scram<'_>> {
return Err(super::AuthError::bad_auth_method(sasl.method));
}
info!("client chooses {}", sasl.method);
let secret = self.state.0;
let outcome = sasl::SaslStream::new(self.stream, sasl.message)
.authenticate(scram::Exchange::new(
secret,
rand::random,
self.tls_server_end_point,
))
.authenticate(scram::Exchange::new(secret, rand::random, None))
.await?;
Ok(outcome)

View File

@@ -6,8 +6,6 @@
use std::{net::SocketAddr, sync::Arc};
use futures::future::Either;
use itertools::Itertools;
use proxy::config::TlsServerEndPoint;
use tokio::net::TcpListener;
use anyhow::{anyhow, bail, ensure, Context};
@@ -67,7 +65,7 @@ async fn main() -> anyhow::Result<()> {
let destination: String = args.get_one::<String>("dest").unwrap().parse()?;
// Configure TLS
let (tls_config, tls_server_end_point): (Arc<rustls::ServerConfig>, TlsServerEndPoint) = match (
let tls_config: Arc<rustls::ServerConfig> = match (
args.get_one::<String>("tls-key"),
args.get_one::<String>("tls-cert"),
) {
@@ -91,22 +89,16 @@ async fn main() -> anyhow::Result<()> {
))?
.into_iter()
.map(rustls::Certificate)
.collect_vec()
.collect()
};
// needed for channel bindings
let first_cert = cert_chain.first().context("missing certificate")?;
let tls_server_end_point = TlsServerEndPoint::new(first_cert)?;
let tls_config = rustls::ServerConfig::builder()
rustls::ServerConfig::builder()
.with_safe_default_cipher_suites()
.with_safe_default_kx_groups()
.with_protocol_versions(&[&rustls::version::TLS13, &rustls::version::TLS12])?
.with_no_client_auth()
.with_single_cert(cert_chain, key)?
.into();
(tls_config, tls_server_end_point)
.into()
}
_ => bail!("tls-key and tls-cert must be specified"),
};
@@ -121,7 +113,6 @@ async fn main() -> anyhow::Result<()> {
let main = tokio::spawn(task_main(
Arc::new(destination),
tls_config,
tls_server_end_point,
proxy_listener,
cancellation_token.clone(),
));
@@ -143,7 +134,6 @@ async fn main() -> anyhow::Result<()> {
async fn task_main(
dest_suffix: Arc<String>,
tls_config: Arc<rustls::ServerConfig>,
tls_server_end_point: TlsServerEndPoint,
listener: tokio::net::TcpListener,
cancellation_token: CancellationToken,
) -> anyhow::Result<()> {
@@ -169,7 +159,7 @@ async fn task_main(
.context("failed to set socket option")?;
info!(%peer_addr, "serving");
handle_client(dest_suffix, tls_config, tls_server_end_point, socket).await
handle_client(dest_suffix, tls_config, socket).await
}
.unwrap_or_else(|e| {
// Acknowledge that the task has finished with an error.
@@ -217,7 +207,6 @@ const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmod
async fn ssl_handshake<S: AsyncRead + AsyncWrite + Unpin>(
raw_stream: S,
tls_config: Arc<rustls::ServerConfig>,
tls_server_end_point: TlsServerEndPoint,
) -> anyhow::Result<Stream<S>> {
let mut stream = PqStream::new(Stream::from_raw(raw_stream));
@@ -242,11 +231,7 @@ async fn ssl_handshake<S: AsyncRead + AsyncWrite + Unpin>(
if !read_buf.is_empty() {
bail!("data is sent before server replied with EncryptionResponse");
}
Ok(Stream::Tls {
tls: Box::new(raw.upgrade(tls_config).await?),
tls_server_end_point,
})
Ok(raw.upgrade(tls_config).await?)
}
unexpected => {
info!(
@@ -261,10 +246,9 @@ async fn ssl_handshake<S: AsyncRead + AsyncWrite + Unpin>(
async fn handle_client(
dest_suffix: Arc<String>,
tls_config: Arc<rustls::ServerConfig>,
tls_server_end_point: TlsServerEndPoint,
stream: impl AsyncRead + AsyncWrite + Unpin,
) -> anyhow::Result<()> {
let tls_stream = ssl_handshake(stream, tls_config, tls_server_end_point).await?;
let tls_stream = ssl_handshake(stream, tls_config).await?;
// Cut off first part of the SNI domain
// We receive required destination details in the format of

View File

@@ -1,15 +1,12 @@
use crate::auth;
use anyhow::{bail, ensure, Context, Ok};
use rustls::{sign, Certificate, PrivateKey};
use sha2::{Digest, Sha256};
use rustls::sign;
use std::{
collections::{HashMap, HashSet},
str::FromStr,
sync::Arc,
time::Duration,
};
use tracing::{error, info};
use x509_parser::oid_registry;
pub struct ProxyConfig {
pub tls_config: Option<TlsConfig>,
@@ -30,7 +27,6 @@ pub struct MetricCollectionConfig {
pub struct TlsConfig {
pub config: Arc<rustls::ServerConfig>,
pub common_names: Option<HashSet<String>>,
pub cert_resolver: Arc<CertResolver>,
}
pub struct HttpConfig {
@@ -56,7 +52,7 @@ pub fn configure_tls(
let mut cert_resolver = CertResolver::new();
// add default certificate
cert_resolver.add_cert_path(key_path, cert_path, true)?;
cert_resolver.add_cert(key_path, cert_path, true)?;
// add extra certificates
if let Some(certs_dir) = certs_dir {
@@ -68,7 +64,7 @@ pub fn configure_tls(
let key_path = path.join("tls.key");
let cert_path = path.join("tls.crt");
if key_path.exists() && cert_path.exists() {
cert_resolver.add_cert_path(
cert_resolver.add_cert(
&key_path.to_string_lossy(),
&cert_path.to_string_lossy(),
false,
@@ -80,97 +76,35 @@ pub fn configure_tls(
let common_names = cert_resolver.get_common_names();
let cert_resolver = Arc::new(cert_resolver);
let config = rustls::ServerConfig::builder()
.with_safe_default_cipher_suites()
.with_safe_default_kx_groups()
// allow TLS 1.2 to be compatible with older client libraries
.with_protocol_versions(&[&rustls::version::TLS13, &rustls::version::TLS12])?
.with_no_client_auth()
.with_cert_resolver(cert_resolver.clone())
.with_cert_resolver(Arc::new(cert_resolver))
.into();
Ok(TlsConfig {
config,
common_names: Some(common_names),
cert_resolver,
})
}
/// Channel binding parameter
///
/// <https://www.rfc-editor.org/rfc/rfc5929#section-4>
/// Description: The hash of the TLS server's certificate as it
/// appears, octet for octet, in the server's Certificate message. Note
/// that the Certificate message contains a certificate_list, in which
/// the first element is the server's certificate.
///
/// The hash function is to be selected as follows:
///
/// * if the certificate's signatureAlgorithm uses a single hash
/// function, and that hash function is either MD5 or SHA-1, then use SHA-256;
///
/// * if the certificate's signatureAlgorithm uses a single hash
/// function and that hash function is neither MD5 nor SHA-1, then use
/// the hash function associated with the certificate's
/// signatureAlgorithm;
///
/// * if the certificate's signatureAlgorithm uses no hash functions or
/// uses multiple hash functions, then this channel binding type's
/// channel bindings are undefined at this time (updates to this channel
/// binding type may occur to address this issue if it ever arises).
#[derive(Debug, Clone, Copy)]
pub enum TlsServerEndPoint {
Sha256([u8; 32]),
Undefined,
}
impl TlsServerEndPoint {
pub fn new(cert: &Certificate) -> anyhow::Result<Self> {
let sha256_oids = [
// I'm explicitly not adding MD5 or SHA1 here... They're bad.
oid_registry::OID_SIG_ECDSA_WITH_SHA256,
oid_registry::OID_PKCS1_SHA256WITHRSA,
];
let pem = x509_parser::parse_x509_certificate(&cert.0)
.context("Failed to parse PEM object from cerficiate")?
.1;
info!(subject = %pem.subject, "parsing TLS certificate");
let reg = oid_registry::OidRegistry::default().with_all_crypto();
let oid = pem.signature_algorithm.oid();
let alg = reg.get(oid);
if sha256_oids.contains(oid) {
let tls_server_end_point: [u8; 32] =
Sha256::new().chain_update(&cert.0).finalize().into();
info!(subject = %pem.subject, signature_algorithm = alg.map(|a| a.description()), tls_server_end_point = %base64::encode(tls_server_end_point), "determined channel binding");
Ok(Self::Sha256(tls_server_end_point))
} else {
error!(subject = %pem.subject, signature_algorithm = alg.map(|a| a.description()), "unknown channel binding");
Ok(Self::Undefined)
}
}
pub fn supported(&self) -> bool {
!matches!(self, TlsServerEndPoint::Undefined)
}
}
#[derive(Default)]
pub struct CertResolver {
certs: HashMap<String, (Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint)>,
default: Option<(Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint)>,
struct CertResolver {
certs: HashMap<String, Arc<rustls::sign::CertifiedKey>>,
default: Option<Arc<rustls::sign::CertifiedKey>>,
}
impl CertResolver {
pub fn new() -> Self {
Self::default()
fn new() -> Self {
Self {
certs: HashMap::new(),
default: None,
}
}
fn add_cert_path(
fn add_cert(
&mut self,
key_path: &str,
cert_path: &str,
@@ -186,65 +120,57 @@ impl CertResolver {
keys.pop().map(rustls::PrivateKey).unwrap()
};
let key = sign::any_supported_type(&priv_key).context("invalid private key")?;
let cert_chain_bytes = std::fs::read(cert_path)
.context(format!("Failed to read TLS cert file at '{cert_path}.'"))?;
let cert_chain = {
rustls_pemfile::certs(&mut &cert_chain_bytes[..])
.with_context(|| {
format!(
.context(format!(
"Failed to read TLS certificate chain from bytes from file at '{cert_path}'."
)
})?
))?
.into_iter()
.map(rustls::Certificate)
.collect()
};
self.add_cert(priv_key, cert_chain, is_default)
}
let common_name = {
let pem = x509_parser::pem::parse_x509_pem(&cert_chain_bytes)
.context(format!(
"Failed to parse PEM object from bytes from file at '{cert_path}'."
))?
.1;
let common_name = pem.parse_x509()?.subject().to_string();
pub fn add_cert(
&mut self,
priv_key: PrivateKey,
cert_chain: Vec<Certificate>,
is_default: bool,
) -> anyhow::Result<()> {
let key = sign::any_supported_type(&priv_key).context("invalid private key")?;
let first_cert = &cert_chain[0];
let tls_server_end_point = TlsServerEndPoint::new(first_cert)?;
let pem = x509_parser::parse_x509_certificate(&first_cert.0)
.context("Failed to parse PEM object from cerficiate")?
.1;
let common_name = pem.subject().to_string();
// We only use non-wildcard certificates in link proxy so it seems okay to treat them the same as
// wildcard ones as we don't use SNI there. That treatment only affects certificate selection, so
// verify-full will still check wildcard match. Old coding here just ignored non-wildcard common names
// and passed None instead, which blows up number of cases downstream code should handle. Proper coding
// here should better avoid Option for common_names, and do wildcard-based certificate selection instead
// of cutting off '*.' parts.
let common_name = if common_name.starts_with("CN=*.") {
common_name.strip_prefix("CN=*.").map(|s| s.to_string())
} else {
common_name.strip_prefix("CN=").map(|s| s.to_string())
// We only use non-wildcard certificates in link proxy so it seems okay to treat them the same as
// wildcard ones as we don't use SNI there. That treatment only affects certificate selection, so
// verify-full will still check wildcard match. Old coding here just ignored non-wildcard common names
// and passed None instead, which blows up number of cases downstream code should handle. Proper coding
// here should better avoid Option for common_names, and do wildcard-based certificate selection instead
// of cutting off '*.' parts.
if common_name.starts_with("CN=*.") {
common_name.strip_prefix("CN=*.").map(|s| s.to_string())
} else {
common_name.strip_prefix("CN=").map(|s| s.to_string())
}
}
.context("Failed to parse common name from certificate")?;
.context(format!(
"Failed to parse common name from certificate at '{cert_path}'."
))?;
let cert = Arc::new(rustls::sign::CertifiedKey::new(cert_chain, key));
if is_default {
self.default = Some((cert.clone(), tls_server_end_point));
self.default = Some(cert.clone());
}
self.certs.insert(common_name, (cert, tls_server_end_point));
self.certs.insert(common_name, cert);
Ok(())
}
pub fn get_common_names(&self) -> HashSet<String> {
fn get_common_names(&self) -> HashSet<String> {
self.certs.keys().map(|s| s.to_string()).collect()
}
}
@@ -252,24 +178,15 @@ impl CertResolver {
impl rustls::server::ResolvesServerCert for CertResolver {
fn resolve(
&self,
client_hello: rustls::server::ClientHello,
_client_hello: rustls::server::ClientHello,
) -> Option<Arc<rustls::sign::CertifiedKey>> {
self.resolve(client_hello.server_name()).map(|x| x.0)
}
}
impl CertResolver {
pub fn resolve(
&self,
server_name: Option<&str>,
) -> Option<(Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint)> {
// loop here and cut off more and more subdomains until we find
// a match to get a proper wildcard support. OTOH, we now do not
// use nested domains, so keep this simple for now.
//
// With the current coding foo.com will match *.foo.com and that
// repeats behavior of the old code.
if let Some(mut sni_name) = server_name {
if let Some(mut sni_name) = _client_hello.server_name() {
loop {
if let Some(cert) = self.certs.get(sni_name) {
return Some(cert.clone());
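
For context on the `TlsServerEndPoint` type touched above: as its RFC 5929 comment explains, for a certificate whose signatureAlgorithm hashes with SHA-256 (and for MD5/SHA-1, which are upgraded to SHA-256), the binding value is simply the SHA-256 digest of the DER-encoded certificate. A standalone sketch of that computation, assuming the `sha2` crate as in the code above; the function name is illustrative and not part of the diff:

use sha2::{Digest, Sha256};

// Sketch only: RFC 5929 tls-server-end-point binding for a SHA-256-signed
// certificate, given its raw DER bytes.
fn tls_server_end_point_sha256(cert_der: &[u8]) -> [u8; 32] {
    Sha256::new().chain_update(cert_der).finalize().into()
}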

View File

@@ -470,17 +470,7 @@ async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
if !read_buf.is_empty() {
bail!("data is sent before server replied with EncryptionResponse");
}
let tls_stream = raw.upgrade(tls.to_server_config()).await?;
let (_, tls_server_end_point) = tls
.cert_resolver
.resolve(tls_stream.get_ref().1.server_name())
.context("missing certificate")?;
stream = PqStream::new(Stream::Tls {
tls: Box::new(tls_stream),
tls_server_end_point,
});
stream = PqStream::new(raw.upgrade(tls.to_server_config()).await?);
}
}
_ => bail!(ERR_PROTO_VIOLATION),
@@ -885,7 +875,7 @@ pub async fn proxy_pass(
/// Thin connection context.
struct Client<'a, S> {
/// The underlying libpq protocol stream.
stream: PqStream<Stream<S>>,
stream: PqStream<S>,
/// Client credentials that we care about.
creds: auth::BackendType<'a, auth::ClientCredentials<'a>>,
/// KV-dictionary with PostgreSQL connection params.
@@ -899,7 +889,7 @@ struct Client<'a, S> {
impl<'a, S> Client<'a, S> {
/// Construct a new connection context.
fn new(
stream: PqStream<Stream<S>>,
stream: PqStream<S>,
creds: auth::BackendType<'a, auth::ClientCredentials<'a>>,
params: &'a StartupMessageParams,
session_id: uuid::Uuid,

View File

@@ -1,23 +1,19 @@
//! A group of high-level tests for connection establishing logic and auth.
mod mitm;
//!
use super::*;
use crate::auth::backend::TestBackend;
use crate::auth::ClientCredentials;
use crate::config::CertResolver;
use crate::console::{CachedNodeInfo, NodeInfo};
use crate::{auth, http, sasl, scram};
use async_trait::async_trait;
use rstest::rstest;
use tokio_postgres::config::SslMode;
use tokio_postgres::tls::{MakeTlsConnect, NoTls};
use tokio_postgres_rustls::{MakeRustlsConnect, RustlsStream};
use tokio_postgres_rustls::MakeRustlsConnect;
/// Generate a set of TLS certificates: CA + server.
fn generate_certs(
hostname: &str,
common_name: &str,
) -> anyhow::Result<(rustls::Certificate, rustls::Certificate, rustls::PrivateKey)> {
let ca = rcgen::Certificate::from_params({
let mut params = rcgen::CertificateParams::default();
@@ -25,15 +21,7 @@ fn generate_certs(
params
})?;
let cert = rcgen::Certificate::from_params({
let mut params = rcgen::CertificateParams::new(vec![hostname.into()]);
params.distinguished_name = rcgen::DistinguishedName::new();
params
.distinguished_name
.push(rcgen::DnType::CommonName, common_name);
params
})?;
let cert = rcgen::generate_simple_self_signed(vec![hostname.into()])?;
Ok((
rustls::Certificate(ca.serialize_der()?),
rustls::Certificate(cert.serialize_der_with_signer(&ca)?),
@@ -49,14 +37,7 @@ struct ClientConfig<'a> {
impl ClientConfig<'_> {
fn make_tls_connect<S: AsyncRead + AsyncWrite + Unpin + Send + 'static>(
self,
) -> anyhow::Result<
impl tokio_postgres::tls::TlsConnect<
S,
Error = impl std::fmt::Debug,
Future = impl Send,
Stream = RustlsStream<S>,
>,
> {
) -> anyhow::Result<impl tokio_postgres::tls::TlsConnect<S>> {
let mut mk = MakeRustlsConnect::new(self.config);
let tls = MakeTlsConnect::<S>::make_tls_connect(&mut mk, self.hostname)?;
Ok(tls)
@@ -68,24 +49,20 @@ fn generate_tls_config<'a>(
hostname: &'a str,
common_name: &'a str,
) -> anyhow::Result<(ClientConfig<'a>, TlsConfig)> {
let (ca, cert, key) = generate_certs(hostname, common_name)?;
let (ca, cert, key) = generate_certs(hostname)?;
let tls_config = {
let config = rustls::ServerConfig::builder()
.with_safe_defaults()
.with_no_client_auth()
.with_single_cert(vec![cert.clone()], key.clone())?
.with_single_cert(vec![cert], key)?
.into();
let mut cert_resolver = CertResolver::new();
cert_resolver.add_cert(key, vec![cert], true)?;
let common_names = Some(cert_resolver.get_common_names());
let common_names = Some([common_name.to_owned()].iter().cloned().collect());
TlsConfig {
config,
common_names,
cert_resolver: Arc::new(cert_resolver),
}
};
@@ -276,7 +253,6 @@ async fn scram_auth_good(#[case] password: &str) -> anyhow::Result<()> {
));
let (_client, _conn) = tokio_postgres::Config::new()
.channel_binding(tokio_postgres::config::ChannelBinding::Require)
.user("user")
.dbname("db")
.password(password)
@@ -287,30 +263,6 @@ async fn scram_auth_good(#[case] password: &str) -> anyhow::Result<()> {
proxy.await?
}
#[tokio::test]
async fn scram_auth_disable_channel_binding() -> anyhow::Result<()> {
let (client, server) = tokio::io::duplex(1024);
let (client_config, server_config) =
generate_tls_config("generic-project-name.localhost", "localhost")?;
let proxy = tokio::spawn(dummy_proxy(
client,
Some(server_config),
Scram::new("password")?,
));
let (_client, _conn) = tokio_postgres::Config::new()
.channel_binding(tokio_postgres::config::ChannelBinding::Disable)
.user("user")
.dbname("db")
.password("password")
.ssl_mode(SslMode::Require)
.connect_raw(server, client_config.make_tls_connect()?)
.await?;
proxy.await?
}
#[tokio::test]
async fn scram_auth_mock() -> anyhow::Result<()> {
let (client, server) = tokio::io::duplex(1024);

View File

@@ -1,257 +0,0 @@
//! Man-in-the-middle tests
//!
//! Channel binding should prevent a proxy server
//! - that has access to create valid certificates -
//! from controlling the TLS connection.
use std::fmt::Debug;
use super::*;
use bytes::{Bytes, BytesMut};
use futures::{SinkExt, StreamExt};
use postgres_protocol::message::frontend;
use tokio::io::{AsyncReadExt, DuplexStream};
use tokio_postgres::config::SslMode;
use tokio_postgres::tls::TlsConnect;
use tokio_util::codec::{Decoder, Encoder};
enum Intercept {
None,
Methods,
SASLResponse,
}
async fn proxy_mitm(
intercept: Intercept,
) -> (DuplexStream, DuplexStream, ClientConfig<'static>, TlsConfig) {
let (end_server1, client1) = tokio::io::duplex(1024);
let (server2, end_client2) = tokio::io::duplex(1024);
let (client_config1, server_config1) =
generate_tls_config("generic-project-name.localhost", "localhost").unwrap();
let (client_config2, server_config2) =
generate_tls_config("generic-project-name.localhost", "localhost").unwrap();
tokio::spawn(async move {
// begin handshake with end_server
let end_server = connect_tls(server2, client_config2.make_tls_connect().unwrap()).await;
// process handshake with end_client
let (end_client, startup) =
handshake(client1, Some(&server_config1), &CancelMap::default())
.await
.unwrap()
.unwrap();
let mut end_server = tokio_util::codec::Framed::new(end_server, PgFrame);
let (end_client, buf) = end_client.framed.into_inner();
assert!(buf.is_empty());
let mut end_client = tokio_util::codec::Framed::new(end_client, PgFrame);
// give the end_server the startup parameters
let mut buf = BytesMut::new();
frontend::startup_message(startup.iter(), &mut buf).unwrap();
end_server.send(buf.freeze()).await.unwrap();
// proxy messages between end_client and end_server
loop {
tokio::select! {
message = end_server.next() => {
match message {
Some(Ok(message)) => {
// intercept SASL and return only SCRAM-SHA-256 ;)
if matches!(intercept, Intercept::Methods) && message.starts_with(b"R") && message[5..].starts_with(&[0,0,0,10]) {
end_client.send(Bytes::from_static(b"R\0\0\0\x17\0\0\0\x0aSCRAM-SHA-256\0\0")).await.unwrap();
continue;
}
end_client.send(message).await.unwrap()
}
_ => break,
}
}
message = end_client.next() => {
match message {
Some(Ok(message)) => {
// intercept SASL response and return SCRAM-SHA-256 with no channel binding ;)
if matches!(intercept, Intercept::SASLResponse) && message.starts_with(b"p") && message[5..].starts_with(b"SCRAM-SHA-256-PLUS\0") {
let sasl_message = &message[1+4+19+4..];
let mut new_message = b"n,,".to_vec();
new_message.extend_from_slice(sasl_message.strip_prefix(b"p=tls-server-end-point,,").unwrap());
let mut buf = BytesMut::new();
frontend::sasl_initial_response("SCRAM-SHA-256", &new_message, &mut buf).unwrap();
end_server.send(buf.freeze()).await.unwrap();
continue;
}
end_server.send(message).await.unwrap()
}
_ => break,
}
}
else => { break }
}
}
});
(end_server1, end_client2, client_config1, server_config2)
}
/// taken from tokio-postgres
pub async fn connect_tls<S, T>(mut stream: S, tls: T) -> T::Stream
where
S: AsyncRead + AsyncWrite + Unpin,
T: TlsConnect<S>,
T::Error: Debug,
{
let mut buf = BytesMut::new();
frontend::ssl_request(&mut buf);
stream.write_all(&buf).await.unwrap();
let mut buf = [0];
stream.read_exact(&mut buf).await.unwrap();
if buf[0] != b'S' {
panic!("ssl not supported by server");
}
tls.connect(stream).await.unwrap()
}
struct PgFrame;
impl Decoder for PgFrame {
type Item = Bytes;
type Error = io::Error;
fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
if src.len() < 5 {
src.reserve(5 - src.len());
return Ok(None);
}
let len = u32::from_be_bytes(src[1..5].try_into().unwrap()) as usize + 1;
if src.len() < len {
src.reserve(len - src.len());
return Ok(None);
}
Ok(Some(src.split_to(len).freeze()))
}
}
impl Encoder<Bytes> for PgFrame {
type Error = io::Error;
fn encode(&mut self, item: Bytes, dst: &mut BytesMut) -> Result<(), Self::Error> {
dst.extend_from_slice(&item);
Ok(())
}
}
/// If the client doesn't support channel bindings, it can be exploited.
#[tokio::test]
async fn scram_auth_disable_channel_binding() -> anyhow::Result<()> {
let (server, client, client_config, server_config) = proxy_mitm(Intercept::None).await;
let proxy = tokio::spawn(dummy_proxy(
client,
Some(server_config),
Scram::new("password")?,
));
let _client_err = tokio_postgres::Config::new()
.channel_binding(tokio_postgres::config::ChannelBinding::Disable)
.user("user")
.dbname("db")
.password("password")
.ssl_mode(SslMode::Require)
.connect_raw(server, client_config.make_tls_connect()?)
.await?;
proxy.await?
}
/// If the client chooses SCRAM-PLUS, it will fail
#[tokio::test]
async fn scram_auth_prefer_channel_binding() -> anyhow::Result<()> {
connect_failure(
Intercept::None,
tokio_postgres::config::ChannelBinding::Prefer,
)
.await
}
/// If the MITM pretends like SCRAM-PLUS isn't available, but the client supports it, it will fail
#[tokio::test]
async fn scram_auth_prefer_channel_binding_intercept() -> anyhow::Result<()> {
connect_failure(
Intercept::Methods,
tokio_postgres::config::ChannelBinding::Prefer,
)
.await
}
/// If the MITM pretends like the client doesn't support channel bindings, it will fail
#[tokio::test]
async fn scram_auth_prefer_channel_binding_intercept_response() -> anyhow::Result<()> {
connect_failure(
Intercept::SASLResponse,
tokio_postgres::config::ChannelBinding::Prefer,
)
.await
}
/// If the client chooses SCRAM-PLUS, it will fail
#[tokio::test]
async fn scram_auth_require_channel_binding() -> anyhow::Result<()> {
connect_failure(
Intercept::None,
tokio_postgres::config::ChannelBinding::Require,
)
.await
}
/// If the client requires SCRAM-PLUS, and it is spoofed to remove SCRAM-PLUS, it will fail
#[tokio::test]
async fn scram_auth_require_channel_binding_intercept() -> anyhow::Result<()> {
connect_failure(
Intercept::Methods,
tokio_postgres::config::ChannelBinding::Require,
)
.await
}
/// If the client requires SCRAM-PLUS, and it is spoofed to remove SCRAM-PLUS, it will fail
#[tokio::test]
async fn scram_auth_require_channel_binding_intercept_response() -> anyhow::Result<()> {
connect_failure(
Intercept::SASLResponse,
tokio_postgres::config::ChannelBinding::Require,
)
.await
}
async fn connect_failure(
intercept: Intercept,
channel_binding: tokio_postgres::config::ChannelBinding,
) -> anyhow::Result<()> {
let (server, client, client_config, server_config) = proxy_mitm(intercept).await;
let proxy = tokio::spawn(dummy_proxy(
client,
Some(server_config),
Scram::new("password")?,
));
let _client_err = tokio_postgres::Config::new()
.channel_binding(channel_binding)
.user("user")
.dbname("db")
.password("password")
.ssl_mode(SslMode::Require)
.connect_raw(server, client_config.make_tls_connect()?)
.await
.err()
.context("client shouldn't be able to connect")?;
let _server_err = proxy
.await?
.err()
.context("server shouldn't accept client")?;
Ok(())
}
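
The `PgFrame` codec in the removed `mitm.rs` above relies on standard PostgreSQL message framing: after the startup phase, every message is a one-byte tag followed by a four-byte big-endian length that counts itself but not the tag, so a complete frame occupies `length + 1` bytes. A minimal sketch of that calculation, with an illustrative function name:

// Sketch only: total bytes of a framed message, given its 5-byte header
// (tag byte + big-endian length that includes itself but not the tag).
fn frame_len(header: [u8; 5]) -> usize {
    u32::from_be_bytes([header[1], header[2], header[3], header[4]]) as usize + 1
}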

View File

@@ -36,9 +36,9 @@ impl<'a> ChannelBinding<&'a str> {
impl<T: std::fmt::Display> ChannelBinding<T> {
/// Encode channel binding data as base64 for subsequent checks.
pub fn encode<'a, E>(
pub fn encode<E>(
&self,
get_cbind_data: impl FnOnce(&T) -> Result<&'a [u8], E>,
get_cbind_data: impl FnOnce(&T) -> Result<String, E>,
) -> Result<std::borrow::Cow<'static, str>, E> {
use ChannelBinding::*;
Ok(match self {
@@ -51,11 +51,12 @@ impl<T: std::fmt::Display> ChannelBinding<T> {
"eSws".into()
}
Required(mode) => {
use std::io::Write;
let mut cbind_input = vec![];
write!(&mut cbind_input, "p={mode},,",).unwrap();
cbind_input.extend_from_slice(get_cbind_data(mode)?);
base64::encode(&cbind_input).into()
let msg = format!(
"p={mode},,{data}",
mode = mode,
data = get_cbind_data(mode)?
);
base64::encode(msg).into()
}
})
}
@@ -76,7 +77,7 @@ mod tests {
];
for (cb, input) in cases {
assert_eq!(cb.encode(|_| anyhow::Ok(b"bar"))?, input);
assert_eq!(cb.encode(|_| anyhow::Ok("bar".to_owned()))?, input);
}
Ok(())
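
The hunk above changes `encode` to take the channel-binding data as a `String` instead of a byte slice. For reference, a minimal sketch, not taken from the diff, of what the `Required` arm assembles: the GS2 header `p=<mode>,,` concatenated with the raw binding data (e.g. the tls-server-end-point certificate digest), then base64-encoded for comparison against the client's `c=` attribute. It assumes the `base64` crate used by the surrounding code:

// Sketch only: build the base64 cbind value for a required channel binding.
fn encode_required(mode: &str, cbind_data: &[u8]) -> String {
    let mut cbind_input = format!("p={mode},,").into_bytes();
    cbind_input.extend_from_slice(cbind_data);
    base64::encode(cbind_input)
}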

View File

@@ -22,12 +22,9 @@ pub use secret::ServerSecret;
use hmac::{Hmac, Mac};
use sha2::{Digest, Sha256};
const SCRAM_SHA_256: &str = "SCRAM-SHA-256";
const SCRAM_SHA_256_PLUS: &str = "SCRAM-SHA-256-PLUS";
// TODO: add SCRAM-SHA-256-PLUS
/// A list of supported SCRAM methods.
pub const METHODS: &[&str] = &[SCRAM_SHA_256_PLUS, SCRAM_SHA_256];
pub const METHODS_WITHOUT_PLUS: &[&str] = &[SCRAM_SHA_256];
pub const METHODS: &[&str] = &["SCRAM-SHA-256"];
/// Decode base64 into array without any heap allocations
fn base64_decode_array<const N: usize>(input: impl AsRef<[u8]>) -> Option<[u8; N]> {
@@ -83,11 +80,7 @@ mod tests {
const NONCE: [u8; 18] = [
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
];
let mut exchange = Exchange::new(
&secret,
|| NONCE,
crate::config::TlsServerEndPoint::Undefined,
);
let mut exchange = Exchange::new(&secret, || NONCE, None);
let client_first = "n,,n=user,r=rOprNGfwEbeRWgbNEkqO";
let client_final = "c=biws,r=rOprNGfwEbeRWgbNEkqOAQIDBAUGBwgJCgsMDQ4PEBES,p=rw1r5Kph5ThxmaUBC2GAQ6MfXbPnNkFiTIvdb/Rear0=";

View File

@@ -5,11 +5,9 @@ use super::messages::{
};
use super::secret::ServerSecret;
use super::signature::SignatureBuilder;
use crate::config;
use crate::sasl::{self, ChannelBinding, Error as SaslError};
/// The only channel binding mode we currently support.
#[derive(Debug)]
struct TlsServerEndPoint;
impl std::fmt::Display for TlsServerEndPoint {
@@ -45,20 +43,20 @@ pub struct Exchange<'a> {
state: ExchangeState,
secret: &'a ServerSecret,
nonce: fn() -> [u8; SCRAM_RAW_NONCE_LEN],
tls_server_end_point: config::TlsServerEndPoint,
cert_digest: Option<&'a [u8]>,
}
impl<'a> Exchange<'a> {
pub fn new(
secret: &'a ServerSecret,
nonce: fn() -> [u8; SCRAM_RAW_NONCE_LEN],
tls_server_end_point: config::TlsServerEndPoint,
cert_digest: Option<&'a [u8]>,
) -> Self {
Self {
state: ExchangeState::Initial,
secret,
nonce,
tls_server_end_point,
cert_digest,
}
}
}
@@ -73,14 +71,6 @@ impl sasl::Mechanism for Exchange<'_> {
let client_first_message = ClientFirstMessage::parse(input)
.ok_or(SaslError::BadClientMessage("invalid client-first-message"))?;
// If the flag is set to "y" and the server supports channel
// binding, the server MUST fail authentication
if client_first_message.cbind_flag == ChannelBinding::NotSupportedServer
&& self.tls_server_end_point.supported()
{
return Err(SaslError::ChannelBindingFailed("SCRAM-PLUS not used"));
}
let server_first_message = client_first_message.build_server_first_message(
&(self.nonce)(),
&self.secret.salt_base64,
@@ -104,11 +94,10 @@ impl sasl::Mechanism for Exchange<'_> {
let client_final_message = ClientFinalMessage::parse(input)
.ok_or(SaslError::BadClientMessage("invalid client-final-message"))?;
let channel_binding = cbind_flag.encode(|_| match &self.tls_server_end_point {
config::TlsServerEndPoint::Sha256(x) => Ok(x),
config::TlsServerEndPoint::Undefined => {
Err(SaslError::ChannelBindingFailed("no cert digest provided"))
}
let channel_binding = cbind_flag.encode(|_| {
self.cert_digest
.map(base64::encode)
.ok_or(SaslError::ChannelBindingFailed("no cert digest provided"))
})?;
// This might've been caused by a MITM attack

View File

@@ -1,8 +1,7 @@
use crate::config::TlsServerEndPoint;
use crate::error::UserFacingError;
use anyhow::bail;
use bytes::BytesMut;
use pin_project_lite::pin_project;
use pq_proto::framed::{ConnectionError, Framed};
use pq_proto::{BeMessage, FeMessage, FeStartupPacket, ProtocolError};
use rustls::ServerConfig;
@@ -18,7 +17,7 @@ use tokio_rustls::server::TlsStream;
/// or [`AsyncWrite`] to prevent subtle errors (e.g. trying
/// to pass random malformed bytes through the connection).
pub struct PqStream<S> {
pub(crate) framed: Framed<S>,
framed: Framed<S>,
}
impl<S> PqStream<S> {
@@ -119,21 +118,19 @@ impl<S: AsyncWrite + Unpin> PqStream<S> {
}
}
/// Wrapper for upgrading raw streams into secure streams.
pub enum Stream<S> {
/// We always begin with a raw stream,
/// which may then be upgraded into a secure stream.
Raw { raw: S },
Tls {
pin_project! {
/// Wrapper for upgrading raw streams into secure streams.
/// NOTE: it should be possible to decompose this object as necessary.
#[project = StreamProj]
pub enum Stream<S> {
/// We always begin with a raw stream,
/// which may then be upgraded into a secure stream.
Raw { #[pin] raw: S },
/// We box [`TlsStream`] since it can be quite large.
tls: Box<TlsStream<S>>,
/// Channel binding parameter
tls_server_end_point: TlsServerEndPoint,
},
Tls { #[pin] tls: Box<TlsStream<S>> },
}
}
impl<S: Unpin> Unpin for Stream<S> {}
impl<S> Stream<S> {
/// Construct a new instance from a raw stream.
pub fn from_raw(raw: S) -> Self {
@@ -144,17 +141,7 @@ impl<S> Stream<S> {
pub fn sni_hostname(&self) -> Option<&str> {
match self {
Stream::Raw { .. } => None,
Stream::Tls { tls, .. } => tls.get_ref().1.server_name(),
}
}
pub fn tls_server_end_point(&self) -> TlsServerEndPoint {
match self {
Stream::Raw { .. } => TlsServerEndPoint::Undefined,
Stream::Tls {
tls_server_end_point,
..
} => *tls_server_end_point,
Stream::Tls { tls } => tls.get_ref().1.server_name(),
}
}
}
@@ -171,9 +158,12 @@ pub enum StreamUpgradeError {
impl<S: AsyncRead + AsyncWrite + Unpin> Stream<S> {
/// If possible, upgrade raw stream into a secure TLS-based stream.
pub async fn upgrade(self, cfg: Arc<ServerConfig>) -> Result<TlsStream<S>, StreamUpgradeError> {
pub async fn upgrade(self, cfg: Arc<ServerConfig>) -> Result<Self, StreamUpgradeError> {
match self {
Stream::Raw { raw } => Ok(tokio_rustls::TlsAcceptor::from(cfg).accept(raw).await?),
Stream::Raw { raw } => {
let tls = Box::new(tokio_rustls::TlsAcceptor::from(cfg).accept(raw).await?);
Ok(Stream::Tls { tls })
}
Stream::Tls { .. } => Err(StreamUpgradeError::AlreadyTls),
}
}
@@ -181,46 +171,50 @@ impl<S: AsyncRead + AsyncWrite + Unpin> Stream<S> {
impl<S: AsyncRead + AsyncWrite + Unpin> AsyncRead for Stream<S> {
fn poll_read(
mut self: Pin<&mut Self>,
self: Pin<&mut Self>,
context: &mut task::Context<'_>,
buf: &mut ReadBuf<'_>,
) -> task::Poll<io::Result<()>> {
match &mut *self {
Self::Raw { raw } => Pin::new(raw).poll_read(context, buf),
Self::Tls { tls, .. } => Pin::new(tls).poll_read(context, buf),
use StreamProj::*;
match self.project() {
Raw { raw } => raw.poll_read(context, buf),
Tls { tls } => tls.poll_read(context, buf),
}
}
}
impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite for Stream<S> {
fn poll_write(
mut self: Pin<&mut Self>,
self: Pin<&mut Self>,
context: &mut task::Context<'_>,
buf: &[u8],
) -> task::Poll<io::Result<usize>> {
match &mut *self {
Self::Raw { raw } => Pin::new(raw).poll_write(context, buf),
Self::Tls { tls, .. } => Pin::new(tls).poll_write(context, buf),
use StreamProj::*;
match self.project() {
Raw { raw } => raw.poll_write(context, buf),
Tls { tls } => tls.poll_write(context, buf),
}
}
fn poll_flush(
mut self: Pin<&mut Self>,
self: Pin<&mut Self>,
context: &mut task::Context<'_>,
) -> task::Poll<io::Result<()>> {
match &mut *self {
Self::Raw { raw } => Pin::new(raw).poll_flush(context),
Self::Tls { tls, .. } => Pin::new(tls).poll_flush(context),
use StreamProj::*;
match self.project() {
Raw { raw } => raw.poll_flush(context),
Tls { tls } => tls.poll_flush(context),
}
}
fn poll_shutdown(
mut self: Pin<&mut Self>,
self: Pin<&mut Self>,
context: &mut task::Context<'_>,
) -> task::Poll<io::Result<()>> {
match &mut *self {
Self::Raw { raw } => Pin::new(raw).poll_shutdown(context),
Self::Tls { tls, .. } => Pin::new(tls).poll_shutdown(context),
use StreamProj::*;
match self.project() {
Raw { raw } => raw.poll_shutdown(context),
Tls { tls } => tls.poll_shutdown(context),
}
}
}
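
The rewritten `Stream` above uses `pin_project_lite` to generate a projection enum, so each variant's pinned inner stream can be polled directly instead of going through `Pin::new` on an `Unpin` inner value. The same pattern in isolation, with illustrative type names that are not part of the diff:

use pin_project_lite::pin_project;
use std::{io, pin::Pin, task::{Context, Poll}};
use tokio::io::{AsyncRead, ReadBuf};

pin_project! {
    // Sketch only: a two-variant stream wrapper with pinned inner fields.
    #[project = EitherProj]
    pub enum Either<A, B> {
        Left { #[pin] left: A },
        Right { #[pin] right: B },
    }
}

impl<A: AsyncRead, B: AsyncRead> AsyncRead for Either<A, B> {
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        // project() yields a pinned reference to whichever variant is active.
        match self.project() {
            EitherProj::Left { left } => left.poll_read(cx, buf),
            EitherProj::Right { right } => right.poll_read(cx, buf),
        }
    }
}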

View File

@@ -18,7 +18,6 @@ from datetime import datetime
from functools import cached_property
from itertools import chain, product
from pathlib import Path
from queue import SimpleQueue
from types import TracebackType
from typing import Any, Dict, Iterator, List, Optional, Tuple, Type, cast
from urllib.parse import urlparse
@@ -37,11 +36,8 @@ from _pytest.fixtures import FixtureRequest
from psycopg2.extensions import connection as PgConnection
from psycopg2.extensions import cursor as PgCursor
from psycopg2.extensions import make_dsn, parse_dsn
from pytest_httpserver import HTTPServer
from typing_extensions import Literal
from urllib3.util.retry import Retry
from werkzeug.wrappers.request import Request
from werkzeug.wrappers.response import Response
from fixtures.broker import NeonBroker
from fixtures.log_helper import log
@@ -1009,68 +1005,6 @@ def neon_env_builder(
yield builder
@pytest.fixture(scope="function")
def neon_env_and_metrics_server(
httpserver: HTTPServer,
neon_env_builder: NeonEnvBuilder,
httpserver_listen_address,
) -> Tuple[NeonEnv, HTTPServer, SimpleQueue[Any]]:
"""
Fixture to create a Neon environment and metrics server.
"""
(host, port) = httpserver_listen_address
metric_collection_endpoint = f"http://{host}:{port}/billing/api/v1/usage_events"
# this should be Union[str, Tuple[List[Any], bool]], but it will make unpacking much more verbose
uploads: SimpleQueue[Any] = SimpleQueue()
def metrics_handler(request: Request) -> Response:
if request.json is None:
return Response(status=400)
events = request.json["events"]
is_last = request.headers["pageserver-metrics-last-upload-in-batch"]
assert is_last in ["true", "false"]
uploads.put((events, is_last == "true"))
return Response(status=200)
# Require collecting metrics frequently, since we change
# the timeline and want something to be logged about it.
#
# Disable time-based pitr, we will use the manual GC calls
# to trigger remote storage operations in a controlled way
neon_env_builder.pageserver_config_override = f"""
metric_collection_interval="1s"
metric_collection_endpoint="{metric_collection_endpoint}"
cached_metric_collection_interval="0s"
synthetic_size_calculation_interval="3s"
"""
neon_env_builder.enable_pageserver_remote_storage(RemoteStorageKind.LOCAL_FS)
log.info(f"test_metric_collection endpoint is {metric_collection_endpoint}")
# mock http server that returns OK for the metrics
httpserver.expect_request("/billing/api/v1/usage_events", method="POST").respond_with_handler(
metrics_handler
)
# spin up neon, after http server is ready
env = neon_env_builder.init_start(initial_tenant_conf={"pitr_interval": "0 sec"})
# httpserver is shut down before pageserver during passing run
env.pageserver.allowed_errors.append(".*metrics endpoint refused the sent metrics*")
# we have a fast rate of calculation, these can happen at shutdown
env.pageserver.allowed_errors.append(
".*synthetic_size_worker:calculate_synthetic_size.*:gather_size_inputs.*: failed to calculate logical size at .*: cancelled.*"
)
env.pageserver.allowed_errors.append(
".*synthetic_size_worker: failed to calculate synthetic size for tenant .*: failed to calculate some logical_sizes"
)
return (env, httpserver, uploads)
@dataclass
class PageserverPort:
pg: int
@@ -1638,7 +1572,7 @@ class NeonAttachmentService:
self.running = False
return self
def attach_hook_issue(self, tenant_id: TenantId, pageserver_id: int) -> int:
def attach_hook(self, tenant_id: TenantId, pageserver_id: int) -> int:
response = requests.post(
f"{self.env.control_plane_api}/attach-hook",
json={"tenant_id": str(tenant_id), "node_id": pageserver_id},
@@ -1648,13 +1582,6 @@ class NeonAttachmentService:
assert isinstance(gen, int)
return gen
def attach_hook_drop(self, tenant_id: TenantId):
response = requests.post(
f"{self.env.control_plane_api}/attach-hook",
json={"tenant_id": str(tenant_id), "node_id": None},
)
response.raise_for_status()
def __enter__(self) -> "NeonAttachmentService":
return self
@@ -1854,20 +1781,13 @@ class NeonPageserver(PgProtocol):
to call into the pageserver HTTP client.
"""
if self.env.attachment_service is not None:
generation = self.env.attachment_service.attach_hook_issue(tenant_id, self.id)
generation = self.env.attachment_service.attach_hook(tenant_id, self.id)
else:
generation = None
client = self.http_client()
return client.tenant_attach(tenant_id, config, config_null, generation=generation)
def tenant_detach(self, tenant_id: TenantId):
if self.env.attachment_service is not None:
self.env.attachment_service.attach_hook_drop(tenant_id)
client = self.http_client()
return client.tenant_detach(tenant_id)
def append_pageserver_param_overrides(
params_to_update: List[str],

View File

@@ -1,8 +1,7 @@
import os
import shutil
from contextlib import closing
from pathlib import Path
from typing import Any, Dict
from typing import Any, Dict, List
import pytest
from fixtures.log_helper import log
@@ -15,33 +14,62 @@ from werkzeug.wrappers.request import Request
from werkzeug.wrappers.response import Response
# use neon_env_builder_local fixture to override the default neon_env_builder fixture
# and use a test-specific pg_install instead of shared one
# Check that the extension is not already in the share_dir_path_ext
# if it is, skip the test
#
# After the test is done, cleanup the control file and the extension directory
@pytest.fixture(scope="function")
def neon_env_builder_local(
neon_env_builder: NeonEnvBuilder,
test_output_dir: Path,
pg_distrib_dir: Path,
pg_version: PgVersion,
) -> NeonEnvBuilder:
test_local_pginstall = test_output_dir / "pg_install"
log.info(f"copy {pg_distrib_dir} to {test_local_pginstall}")
shutil.copytree(
pg_distrib_dir / pg_version.v_prefixed, test_local_pginstall / pg_version.v_prefixed
)
def ext_file_cleanup(pg_bin):
out = pg_bin.run_capture("pg_config --sharedir".split())
share_dir_path = Path(f"{out}.stdout").read_text().strip()
log.info(f"share_dir_path: {share_dir_path}")
share_dir_path_ext = os.path.join(share_dir_path, "extension")
neon_env_builder.pg_distrib_dir = test_local_pginstall
log.info(f"local neon_env_builder.pg_distrib_dir: {neon_env_builder.pg_distrib_dir}")
log.info(f"share_dir_path_ext: {share_dir_path_ext}")
return neon_env_builder
    # if the file is already in share_dir_path_ext, skip the test
if os.path.isfile(os.path.join(share_dir_path_ext, "anon.control")):
log.info("anon.control is already in the share_dir_path_ext, skipping the test")
yield False
return
else:
yield True
# cleanup the control file
if os.path.isfile(os.path.join(share_dir_path_ext, "anon.control")):
os.unlink(os.path.join(share_dir_path_ext, "anon.control"))
log.info("anon.control was removed from the share_dir_path_ext")
# remove the extension directory recursively
if os.path.isdir(os.path.join(share_dir_path_ext, "anon")):
directories_to_clean: List[Path] = []
for f in Path(os.path.join(share_dir_path_ext, "anon")).iterdir():
if f.is_file():
log.info(f"Removing file {f}")
f.unlink()
elif f.is_dir():
directories_to_clean.append(f)
for directory_to_clean in reversed(directories_to_clean):
if not os.listdir(directory_to_clean):
log.info(f"Removing empty directory {directory_to_clean}")
directory_to_clean.rmdir()
os.rmdir(os.path.join(share_dir_path_ext, "anon"))
log.info("anon directory was removed from the share_dir_path_ext")
def test_remote_extensions(
httpserver: HTTPServer,
neon_env_builder_local: NeonEnvBuilder,
neon_env_builder: NeonEnvBuilder,
httpserver_listen_address,
pg_version,
ext_file_cleanup,
):
if ext_file_cleanup is False:
log.info("test_remote_extensions skipped")
return
if pg_version == PgVersion.V16:
pytest.skip("TODO: PG16 extension building")
@@ -51,8 +79,7 @@ def test_remote_extensions(
(host, port) = httpserver_listen_address
extensions_endpoint = f"http://{host}:{port}/pg-ext-s3-gateway"
build_tag = os.environ.get("BUILD_TAG", "latest")
archive_path = f"{build_tag}/v{pg_version}/extensions/anon.tar.zst"
archive_path = f"latest/v{pg_version}/extensions/anon.tar.zst"
def endpoint_handler_build_tag(request: Request) -> Response:
log.info(f"request: {request}")
@@ -61,7 +88,6 @@ def test_remote_extensions(
file_path = f"test_runner/regress/data/extension_test/5670669815/v{pg_version}/extensions/anon.tar.zst"
file_size = os.path.getsize(file_path)
fh = open(file_path, "rb")
return Response(
fh,
mimetype="application/octet-stream",
@@ -78,10 +104,12 @@ def test_remote_extensions(
# Start a compute node with remote_extension spec
# and check that it can download the extensions and use them to CREATE EXTENSION.
env = neon_env_builder_local.init_start()
env.neon_cli.create_branch("test_remote_extensions")
env = neon_env_builder.init_start()
tenant_id, _ = env.neon_cli.create_tenant()
env.neon_cli.create_timeline("test_remote_extensions", tenant_id=tenant_id)
endpoint = env.endpoints.create(
"test_remote_extensions",
tenant_id=tenant_id,
config_lines=["log_min_messages=debug3"],
)
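
The hunk ends here; a hypothetical sketch of the verification step that would follow (the start parameter and helper names are assumptions, not shown in this diff):

# Hypothetical continuation, not part of the diff
endpoint.start(remote_ext_config=extensions_endpoint)
endpoint.safe_psql("CREATE EXTENSION anon")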

View File

@@ -282,7 +282,7 @@ def test_deferred_deletion(neon_env_builder: NeonEnvBuilder):
# Now advance the generation in the control plane: subsequent validations
# from the running pageserver will fail. No more deletions should happen.
env.attachment_service.attach_hook_issue(env.initial_tenant, some_other_pageserver)
env.attachment_service.attach_hook(env.initial_tenant, some_other_pageserver)
generate_uploads_and_deletions(env, init=False)
assert_deletion_queue(ps_http, lambda n: n > 0)
@@ -397,7 +397,7 @@ def test_deletion_queue_recovery(
if keep_attachment == KeepAttachment.LOSE:
some_other_pageserver = 101010
assert env.attachment_service is not None
env.attachment_service.attach_hook_issue(env.initial_tenant, some_other_pageserver)
env.attachment_service.attach_hook(env.initial_tenant, some_other_pageserver)
env.pageserver.start()

View File

@@ -3,21 +3,75 @@ import time
from dataclasses import dataclass
from pathlib import Path
from queue import SimpleQueue
from typing import Any, Dict, Set, Tuple
from typing import Any, Dict, Set
from fixtures.log_helper import log
from fixtures.neon_fixtures import (
NeonEnv,
NeonEnvBuilder,
wait_for_last_flush_lsn,
)
from fixtures.remote_storage import RemoteStorageKind
from fixtures.types import TenantId, TimelineId
from pytest_httpserver import HTTPServer
from werkzeug.wrappers.request import Request
from werkzeug.wrappers.response import Response
# TODO: collect all of the env setup *AFTER* removal of RemoteStorageKind.NOOP
def test_metric_collection(
neon_env_and_metrics_server: Tuple[NeonEnv, HTTPServer, SimpleQueue[Any]]
httpserver: HTTPServer,
neon_env_builder: NeonEnvBuilder,
httpserver_listen_address,
):
(env, httpserver, uploads) = neon_env_and_metrics_server
(host, port) = httpserver_listen_address
metric_collection_endpoint = f"http://{host}:{port}/billing/api/v1/usage_events"
# this should be Union[str, Tuple[List[Any], bool]], but it will make unpacking much more verbose
uploads: SimpleQueue[Any] = SimpleQueue()
def metrics_handler(request: Request) -> Response:
if request.json is None:
return Response(status=400)
events = request.json["events"]
is_last = request.headers["pageserver-metrics-last-upload-in-batch"]
assert is_last in ["true", "false"]
uploads.put((events, is_last == "true"))
return Response(status=200)
# Require collecting metrics frequently, since we change
# the timeline and want something to be logged about it.
#
    # Disable time-based pitr; we will use the manual GC calls
    # to trigger remote storage operations in a controlled way.
neon_env_builder.pageserver_config_override = f"""
metric_collection_interval="1s"
metric_collection_endpoint="{metric_collection_endpoint}"
cached_metric_collection_interval="0s"
synthetic_size_calculation_interval="3s"
"""
neon_env_builder.enable_pageserver_remote_storage(RemoteStorageKind.LOCAL_FS)
log.info(f"test_metric_collection endpoint is {metric_collection_endpoint}")
# mock http server that returns OK for the metrics
httpserver.expect_request("/billing/api/v1/usage_events", method="POST").respond_with_handler(
metrics_handler
)
# spin up neon, after http server is ready
env = neon_env_builder.init_start(initial_tenant_conf={"pitr_interval": "0 sec"})
    # httpserver is shut down before the pageserver during a passing run
    env.pageserver.allowed_errors.append(".*metrics endpoint refused the sent metrics*")
    # we have a fast rate of calculation; these can happen at shutdown
env.pageserver.allowed_errors.append(
".*synthetic_size_worker:calculate_synthetic_size.*:gather_size_inputs.*: failed to calculate logical size at .*: cancelled.*"
)
env.pageserver.allowed_errors.append(
".*synthetic_size_worker: failed to calculate synthetic size for tenant .*: failed to calculate some logical_sizes"
)
tenant_id = env.initial_tenant
timeline_id = env.initial_timeline
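
The hunk cuts off before the consuming side of the queue; a hypothetical sketch of draining the (events, is_last) tuples put by metrics_handler above (the timeout value is an assumption):

# Hypothetical consumer sketch, not part of the diff
events, is_last = uploads.get(timeout=30)
for event in events:
    log.info(f"received usage event: {event}")
if is_last:
    log.info("last upload in this batch")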
@@ -114,11 +168,59 @@ def test_metric_collection(
def test_metric_collection_cleans_up_tempfile(
neon_env_and_metrics_server: Tuple[NeonEnv, HTTPServer, SimpleQueue[Any]]
httpserver: HTTPServer,
neon_env_builder: NeonEnvBuilder,
httpserver_listen_address,
):
(env, httpserver, uploads) = neon_env_and_metrics_server
(host, port) = httpserver_listen_address
metric_collection_endpoint = f"http://{host}:{port}/billing/api/v1/usage_events"
# this should be Union[str, Tuple[List[Any], bool]], but it will make unpacking much more verbose
uploads: SimpleQueue[Any] = SimpleQueue()
def metrics_handler(request: Request) -> Response:
if request.json is None:
return Response(status=400)
events = request.json["events"]
is_last = request.headers["pageserver-metrics-last-upload-in-batch"]
assert is_last in ["true", "false"]
uploads.put((events, is_last == "true"))
return Response(status=200)
# Require collecting metrics frequently, since we change
# the timeline and want something to be logged about it.
#
    # Disable time-based pitr; we will use the manual GC calls
    # to trigger remote storage operations in a controlled way.
neon_env_builder.pageserver_config_override = f"""
metric_collection_interval="1s"
metric_collection_endpoint="{metric_collection_endpoint}"
cached_metric_collection_interval="0s"
synthetic_size_calculation_interval="3s"
"""
neon_env_builder.enable_pageserver_remote_storage(RemoteStorageKind.LOCAL_FS)
# mock http server that returns OK for the metrics
httpserver.expect_request("/billing/api/v1/usage_events", method="POST").respond_with_handler(
metrics_handler
)
# spin up neon, after http server is ready
env = neon_env_builder.init_start(initial_tenant_conf={"pitr_interval": "0 sec"})
pageserver_http = env.pageserver.http_client()
    # httpserver is shut down before the pageserver during a passing run
    env.pageserver.allowed_errors.append(".*metrics endpoint refused the sent metrics*")
    # we have a fast rate of calculation; these can happen at shutdown
env.pageserver.allowed_errors.append(
".*synthetic_size_worker:calculate_synthetic_size.*:gather_size_inputs.*: failed to calculate logical size at .*: cancelled.*"
)
env.pageserver.allowed_errors.append(
".*synthetic_size_worker: failed to calculate synthetic size for tenant .*: failed to calculate some logical_sizes"
)
tenant_id = env.initial_tenant
timeline_id = env.initial_timeline
endpoint = env.endpoints.create_start("main", tenant_id=tenant_id)
@@ -188,8 +290,6 @@ def test_metric_collection_cleans_up_tempfile(
), "only initial tempfile should had been removed"
assert initially.other.issuperset(later.other), "no other files should had been removed"
httpserver.check()
@dataclass
class PrefixPartitionedFiles:

View File

@@ -336,15 +336,10 @@ def test_live_reconfig_get_evictions_low_residence_duration_metric_threshold(
):
neon_env_builder.enable_pageserver_remote_storage(RemoteStorageKind.LOCAL_FS)
env = neon_env_builder.init_start(
initial_tenant_conf={
# disable compaction so that it will not download the layer for repartitioning
"compaction_period": "0s"
}
)
env = neon_env_builder.init_start()
assert isinstance(env.pageserver_remote_storage, LocalFsStorage)
(tenant_id, timeline_id) = env.initial_tenant, env.initial_timeline
(tenant_id, timeline_id) = env.neon_cli.create_tenant()
ps_http = env.pageserver.http_client()
def get_metric():