From dbebede7bf5ff0cefd303f266781457b7472d070 Mon Sep 17 00:00:00 2001 From: Vlad Lazar Date: Wed, 15 Jan 2025 15:33:54 +0000 Subject: [PATCH] safekeeper: fan out from single wal reader to multiple shards (#10190) ## Problem Safekeepers currently decode and interpret WAL for each shard separately. This is wasteful in terms of CPU and memory usage - we've seen this in profiles. ## Summary of changes Fan-out interpreted WAL to multiple shards. The basic idea is that wal decoding and interpretation happen in a separate tokio task and senders attach to it. Senders only receive batches concerning their shard and only past the Lsn they've last seen. Fan-out is gated behind the `wal_reader_fanout` safekeeper flag (disabled by default for now). When fan-out is enabled, it might be desirable to control the absolute delta between the current position and a new shard's desired position (i.e. how far behind or ahead a shard may be). `max_delta_for_fanout` is a new optional safekeeper flag which dictates whether to create a new WAL reader or attach to the existing one. By default, this behaviour is disabled. Let's consider enabling it if we spot the need for it in the field. ## Testing Tests passed [here](https://github.com/neondatabase/neon/pull/10301) with wal reader fanout enabled as of https://github.com/neondatabase/neon/pull/10190/commits/34f6a717182c431847bbd5b7828fd0f89027b2be. 
Related: https://github.com/neondatabase/neon/issues/9337 Epic: https://github.com/neondatabase/neon/issues/9329 --- Cargo.lock | 3 + libs/safekeeper_api/Cargo.toml | 1 + libs/safekeeper_api/src/models.rs | 20 +- libs/wal_decoder/src/models.rs | 2 +- libs/wal_decoder/src/serialized_batch.rs | 8 +- safekeeper/Cargo.toml | 3 + safekeeper/src/bin/safekeeper.rs | 9 + safekeeper/src/http/routes.rs | 2 +- safekeeper/src/lib.rs | 4 + safekeeper/src/metrics.rs | 32 +- safekeeper/src/send_interpreted_wal.rs | 765 ++++++++++++++++-- safekeeper/src/send_wal.rs | 354 ++++++-- safekeeper/src/test_utils.rs | 65 +- safekeeper/src/timeline.rs | 14 +- safekeeper/src/wal_reader_stream.rs | 396 ++++++--- .../tests/walproposer_sim/safekeeper.rs | 2 + 16 files changed, 1410 insertions(+), 270 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 0669899617..afe16ff848 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5655,6 +5655,7 @@ dependencies = [ "crc32c", "criterion", "desim", + "env_logger 0.10.2", "fail", "futures", "hex", @@ -5683,6 +5684,7 @@ dependencies = [ "serde", "serde_json", "sha2", + "smallvec", "storage_broker", "strum", "strum_macros", @@ -5709,6 +5711,7 @@ version = "0.1.0" dependencies = [ "anyhow", "const_format", + "pageserver_api", "postgres_ffi", "pq_proto", "serde", diff --git a/libs/safekeeper_api/Cargo.toml b/libs/safekeeper_api/Cargo.toml index 7652c3d413..6b72ace019 100644 --- a/libs/safekeeper_api/Cargo.toml +++ b/libs/safekeeper_api/Cargo.toml @@ -13,3 +13,4 @@ postgres_ffi.workspace = true pq_proto.workspace = true tokio.workspace = true utils.workspace = true +pageserver_api.workspace = true diff --git a/libs/safekeeper_api/src/models.rs b/libs/safekeeper_api/src/models.rs index a6f90154f4..b5fa903820 100644 --- a/libs/safekeeper_api/src/models.rs +++ b/libs/safekeeper_api/src/models.rs @@ -1,5 +1,6 @@ //! Types used in safekeeper http API. Many of them are also reused internally. 
+use pageserver_api::shard::ShardIdentity; use postgres_ffi::TimestampTz; use serde::{Deserialize, Serialize}; use std::net::SocketAddr; @@ -146,7 +147,13 @@ pub type ConnectionId = u32; /// Serialize is used only for json'ing in API response. Also used internally. #[derive(Debug, Clone, Serialize, Deserialize)] -pub struct WalSenderState { +pub enum WalSenderState { + Vanilla(VanillaWalSenderState), + Interpreted(InterpretedWalSenderState), +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct VanillaWalSenderState { pub ttid: TenantTimelineId, pub addr: SocketAddr, pub conn_id: ConnectionId, @@ -155,6 +162,17 @@ pub struct WalSenderState { pub feedback: ReplicationFeedback, } +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct InterpretedWalSenderState { + pub ttid: TenantTimelineId, + pub shard: ShardIdentity, + pub addr: SocketAddr, + pub conn_id: ConnectionId, + // postgres application_name + pub appname: Option, + pub feedback: ReplicationFeedback, +} + #[derive(Debug, Clone, Serialize, Deserialize)] pub struct WalReceiverState { /// None means it is recovery initiated by us (this safekeeper). diff --git a/libs/wal_decoder/src/models.rs b/libs/wal_decoder/src/models.rs index 8bfa48faac..c2f9125b21 100644 --- a/libs/wal_decoder/src/models.rs +++ b/libs/wal_decoder/src/models.rs @@ -64,7 +64,7 @@ pub struct InterpretedWalRecords { } /// An interpreted Postgres WAL record, ready to be handled by the pageserver -#[derive(Serialize, Deserialize)] +#[derive(Serialize, Deserialize, Clone)] pub struct InterpretedWalRecord { /// Optional metadata record - may cause writes to metadata keys /// in the storage engine diff --git a/libs/wal_decoder/src/serialized_batch.rs b/libs/wal_decoder/src/serialized_batch.rs index c70ff05b8e..d76f75f51f 100644 --- a/libs/wal_decoder/src/serialized_batch.rs +++ b/libs/wal_decoder/src/serialized_batch.rs @@ -32,7 +32,7 @@ static ZERO_PAGE: Bytes = Bytes::from_static(&[0u8; BLCKSZ as usize]); /// relation sizes. 
In the case of "observed" values, we only need to know /// the key and LSN, so two types of metadata are supported to save on network /// bandwidth. -#[derive(Serialize, Deserialize)] +#[derive(Serialize, Deserialize, Clone)] pub enum ValueMeta { Serialized(SerializedValueMeta), Observed(ObservedValueMeta), @@ -79,7 +79,7 @@ impl PartialEq for OrderedValueMeta { impl Eq for OrderedValueMeta {} /// Metadata for a [`Value`] serialized into the batch. -#[derive(Serialize, Deserialize)] +#[derive(Serialize, Deserialize, Clone)] pub struct SerializedValueMeta { pub key: CompactKey, pub lsn: Lsn, @@ -91,14 +91,14 @@ pub struct SerializedValueMeta { } /// Metadata for a [`Value`] observed by the batch -#[derive(Serialize, Deserialize)] +#[derive(Serialize, Deserialize, Clone)] pub struct ObservedValueMeta { pub key: CompactKey, pub lsn: Lsn, } /// Batch of serialized [`Value`]s. -#[derive(Serialize, Deserialize)] +#[derive(Serialize, Deserialize, Clone)] pub struct SerializedValueBatch { /// [`Value`]s serialized in EphemeralFile's native format, /// ready for disk write by the pageserver diff --git a/safekeeper/Cargo.toml b/safekeeper/Cargo.toml index 3ebb7097f2..0eb511f1cc 100644 --- a/safekeeper/Cargo.toml +++ b/safekeeper/Cargo.toml @@ -26,6 +26,7 @@ hex.workspace = true humantime.workspace = true http.workspace = true hyper0.workspace = true +itertools.workspace = true futures.workspace = true once_cell.workspace = true parking_lot.workspace = true @@ -39,6 +40,7 @@ scopeguard.workspace = true reqwest = { workspace = true, features = ["json"] } serde.workspace = true serde_json.workspace = true +smallvec.workspace = true strum.workspace = true strum_macros.workspace = true thiserror.workspace = true @@ -63,6 +65,7 @@ storage_broker.workspace = true tokio-stream.workspace = true utils.workspace = true wal_decoder.workspace = true +env_logger.workspace = true workspace_hack.workspace = true diff --git a/safekeeper/src/bin/safekeeper.rs 
b/safekeeper/src/bin/safekeeper.rs index bc7af02185..6cc53e0d23 100644 --- a/safekeeper/src/bin/safekeeper.rs +++ b/safekeeper/src/bin/safekeeper.rs @@ -207,6 +207,13 @@ struct Args { /// Also defines interval for eviction retries. #[arg(long, value_parser = humantime::parse_duration, default_value = DEFAULT_EVICTION_MIN_RESIDENT)] eviction_min_resident: Duration, + /// Enable fanning out WAL to different shards from the same reader + #[arg(long)] + wal_reader_fanout: bool, + /// Only fan out the WAL reader if the absoulte delta between the new requested position + /// and the current position of the reader is smaller than this value. + #[arg(long)] + max_delta_for_fanout: Option, } // Like PathBufValueParser, but allows empty string. @@ -370,6 +377,8 @@ async fn main() -> anyhow::Result<()> { control_file_save_interval: args.control_file_save_interval, partial_backup_concurrency: args.partial_backup_concurrency, eviction_min_resident: args.eviction_min_resident, + wal_reader_fanout: args.wal_reader_fanout, + max_delta_for_fanout: args.max_delta_for_fanout, }); // initialize sentry if SENTRY_DSN is provided diff --git a/safekeeper/src/http/routes.rs b/safekeeper/src/http/routes.rs index 5ecde4b125..4b9fb9eb67 100644 --- a/safekeeper/src/http/routes.rs +++ b/safekeeper/src/http/routes.rs @@ -195,7 +195,7 @@ async fn timeline_status_handler(request: Request) -> Result, } impl SafeKeeperConf { @@ -150,6 +152,8 @@ impl SafeKeeperConf { control_file_save_interval: Duration::from_secs(1), partial_backup_concurrency: 1, eviction_min_resident: Duration::ZERO, + wal_reader_fanout: false, + max_delta_for_fanout: None, } } } diff --git a/safekeeper/src/metrics.rs b/safekeeper/src/metrics.rs index 5883f402c7..3ea9e3d674 100644 --- a/safekeeper/src/metrics.rs +++ b/safekeeper/src/metrics.rs @@ -12,9 +12,9 @@ use metrics::{ pow2_buckets, proto::MetricFamily, register_histogram, register_histogram_vec, register_int_counter, register_int_counter_pair, - 
register_int_counter_pair_vec, register_int_counter_vec, register_int_gauge, Gauge, GaugeVec, - Histogram, HistogramVec, IntCounter, IntCounterPair, IntCounterPairVec, IntCounterVec, - IntGauge, IntGaugeVec, DISK_FSYNC_SECONDS_BUCKETS, + register_int_counter_pair_vec, register_int_counter_vec, register_int_gauge, + register_int_gauge_vec, Gauge, GaugeVec, Histogram, HistogramVec, IntCounter, IntCounterPair, + IntCounterPairVec, IntCounterVec, IntGauge, IntGaugeVec, DISK_FSYNC_SECONDS_BUCKETS, }; use once_cell::sync::Lazy; use postgres_ffi::XLogSegNo; @@ -211,6 +211,14 @@ pub static WAL_RECEIVERS: Lazy = Lazy::new(|| { ) .expect("Failed to register safekeeper_wal_receivers") }); +pub static WAL_READERS: Lazy = Lazy::new(|| { + register_int_gauge_vec!( + "safekeeper_wal_readers", + "Number of active WAL readers (may serve pageservers or other safekeepers)", + &["kind", "target"] + ) + .expect("Failed to register safekeeper_wal_receivers") +}); pub static WAL_RECEIVER_QUEUE_DEPTH: Lazy = Lazy::new(|| { // Use powers of two buckets, but add a bucket at 0 and the max queue size to track empty and // full queues respectively. 
@@ -443,6 +451,7 @@ pub struct FullTimelineInfo { pub timeline_is_active: bool, pub num_computes: u32, pub last_removed_segno: XLogSegNo, + pub interpreted_wal_reader_tasks: usize, pub epoch_start_lsn: Lsn, pub mem_state: TimelineMemState, @@ -472,6 +481,7 @@ pub struct TimelineCollector { disk_usage: GenericGaugeVec, acceptor_term: GenericGaugeVec, written_wal_bytes: GenericGaugeVec, + interpreted_wal_reader_tasks: GenericGaugeVec, written_wal_seconds: GaugeVec, flushed_wal_seconds: GaugeVec, collect_timeline_metrics: Gauge, @@ -670,6 +680,16 @@ impl TimelineCollector { .unwrap(); descs.extend(active_timelines_count.desc().into_iter().cloned()); + let interpreted_wal_reader_tasks = GenericGaugeVec::new( + Opts::new( + "safekeeper_interpreted_wal_reader_tasks", + "Number of active interpreted wal reader tasks, grouped by timeline", + ), + &["tenant_id", "timeline_id"], + ) + .unwrap(); + descs.extend(interpreted_wal_reader_tasks.desc().into_iter().cloned()); + TimelineCollector { global_timelines, descs, @@ -693,6 +713,7 @@ impl TimelineCollector { collect_timeline_metrics, timelines_count, active_timelines_count, + interpreted_wal_reader_tasks, } } } @@ -721,6 +742,7 @@ impl Collector for TimelineCollector { self.disk_usage.reset(); self.acceptor_term.reset(); self.written_wal_bytes.reset(); + self.interpreted_wal_reader_tasks.reset(); self.written_wal_seconds.reset(); self.flushed_wal_seconds.reset(); @@ -782,6 +804,9 @@ impl Collector for TimelineCollector { self.written_wal_bytes .with_label_values(labels) .set(tli.wal_storage.write_wal_bytes); + self.interpreted_wal_reader_tasks + .with_label_values(labels) + .set(tli.interpreted_wal_reader_tasks as u64); self.written_wal_seconds .with_label_values(labels) .set(tli.wal_storage.write_wal_seconds); @@ -834,6 +859,7 @@ impl Collector for TimelineCollector { mfs.extend(self.disk_usage.collect()); mfs.extend(self.acceptor_term.collect()); mfs.extend(self.written_wal_bytes.collect()); + 
mfs.extend(self.interpreted_wal_reader_tasks.collect()); mfs.extend(self.written_wal_seconds.collect()); mfs.extend(self.flushed_wal_seconds.collect()); diff --git a/safekeeper/src/send_interpreted_wal.rs b/safekeeper/src/send_interpreted_wal.rs index a718c16a6a..ea09ce364d 100644 --- a/safekeeper/src/send_interpreted_wal.rs +++ b/safekeeper/src/send_interpreted_wal.rs @@ -1,100 +1,330 @@ +use std::collections::HashMap; +use std::fmt::Display; +use std::sync::Arc; use std::time::Duration; -use anyhow::Context; +use anyhow::{anyhow, Context}; +use futures::future::Either; use futures::StreamExt; use pageserver_api::shard::ShardIdentity; use postgres_backend::{CopyStreamHandlerEnd, PostgresBackend}; -use postgres_ffi::MAX_SEND_SIZE; +use postgres_ffi::waldecoder::WalDecodeError; use postgres_ffi::{get_current_timestamp, waldecoder::WalStreamDecoder}; use pq_proto::{BeMessage, InterpretedWalRecordsBody, WalSndKeepAlive}; use tokio::io::{AsyncRead, AsyncWrite}; +use tokio::sync::mpsc::error::SendError; +use tokio::task::JoinHandle; use tokio::time::MissedTickBehavior; +use tracing::{info_span, Instrument}; use utils::lsn::Lsn; use utils::postgres_client::Compression; use utils::postgres_client::InterpretedFormat; use wal_decoder::models::{InterpretedWalRecord, InterpretedWalRecords}; use wal_decoder::wire_format::ToWireFormat; -use crate::send_wal::EndWatchView; -use crate::wal_reader_stream::{WalBytes, WalReaderStreamBuilder}; +use crate::metrics::WAL_READERS; +use crate::send_wal::{EndWatchView, WalSenderGuard}; +use crate::timeline::WalResidentTimeline; +use crate::wal_reader_stream::{StreamingWalReader, WalBytes}; -/// Shard-aware interpreted record sender. -/// This is used for sending WAL to the pageserver. Said WAL -/// is pre-interpreted and filtered for the shard. 
-pub(crate) struct InterpretedWalSender<'a, IO> { - pub(crate) format: InterpretedFormat, - pub(crate) compression: Option, - pub(crate) pgb: &'a mut PostgresBackend, - pub(crate) wal_stream_builder: WalReaderStreamBuilder, - pub(crate) end_watch_view: EndWatchView, - pub(crate) shard: ShardIdentity, - pub(crate) pg_version: u32, - pub(crate) appname: Option, +/// Identifier used to differentiate between senders of the same +/// shard. +/// +/// In the steady state there's only one, but two pageservers may +/// temporarily have the same shard attached and attempt to ingest +/// WAL for it. See also [`ShardSenderId`]. +#[derive(Hash, Eq, PartialEq, Copy, Clone)] +struct SenderId(u8); + +impl SenderId { + fn first() -> Self { + SenderId(0) + } + + fn next(&self) -> Self { + SenderId(self.0.checked_add(1).expect("few senders")) + } } -struct Batch { +#[derive(Hash, Eq, PartialEq)] +struct ShardSenderId { + shard: ShardIdentity, + sender_id: SenderId, +} + +impl Display for ShardSenderId { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}{}", self.sender_id.0, self.shard.shard_slug()) + } +} + +impl ShardSenderId { + fn new(shard: ShardIdentity, sender_id: SenderId) -> Self { + ShardSenderId { shard, sender_id } + } + + fn shard(&self) -> ShardIdentity { + self.shard + } +} + +/// Shard-aware fan-out interpreted record reader. +/// Reads WAL from disk, decodes it, intepretets it, and sends +/// it to any [`InterpretedWalSender`] connected to it. +/// Each [`InterpretedWalSender`] corresponds to one shard +/// and gets interpreted records concerning that shard only. +pub(crate) struct InterpretedWalReader { + wal_stream: StreamingWalReader, + shard_senders: HashMap>, + shard_notification_rx: Option>, + state: Arc>, + pg_version: u32, +} + +/// A handle for [`InterpretedWalReader`] which allows for interacting with it +/// when it runs as a separate tokio task. 
+#[derive(Debug)] +pub(crate) struct InterpretedWalReaderHandle { + join_handle: JoinHandle>, + state: Arc>, + shard_notification_tx: tokio::sync::mpsc::UnboundedSender, +} + +struct ShardSenderState { + sender_id: SenderId, + tx: tokio::sync::mpsc::Sender, + next_record_lsn: Lsn, +} + +/// State of [`InterpretedWalReader`] visible outside of the task running it. +#[derive(Debug)] +pub(crate) enum InterpretedWalReaderState { + Running { current_position: Lsn }, + Done, +} + +pub(crate) struct Batch { wal_end_lsn: Lsn, available_wal_end_lsn: Lsn, records: InterpretedWalRecords, } -impl InterpretedWalSender<'_, IO> { - /// Send interpreted WAL to a receiver. - /// Stops when an error occurs or the receiver is caught up and there's no active compute. - /// - /// Err(CopyStreamHandlerEnd) is always returned; Result is used only for ? - /// convenience. - pub(crate) async fn run(self) -> Result<(), CopyStreamHandlerEnd> { - let mut wal_position = self.wal_stream_builder.start_pos(); - let mut wal_decoder = - WalStreamDecoder::new(self.wal_stream_builder.start_pos(), self.pg_version); +#[derive(thiserror::Error, Debug)] +pub enum InterpretedWalReaderError { + /// Handler initiates the end of streaming. + #[error("decode error: {0}")] + Decode(#[from] WalDecodeError), + #[error("read or interpret error: {0}")] + ReadOrInterpret(#[from] anyhow::Error), + #[error("wal stream closed")] + WalStreamClosed, +} - let stream = self.wal_stream_builder.build(MAX_SEND_SIZE).await?; - let mut stream = std::pin::pin!(stream); +impl InterpretedWalReaderState { + fn current_position(&self) -> Option { + match self { + InterpretedWalReaderState::Running { + current_position, .. 
+ } => Some(*current_position), + InterpretedWalReaderState::Done => None, + } + } +} - let mut keepalive_ticker = tokio::time::interval(Duration::from_secs(1)); - keepalive_ticker.set_missed_tick_behavior(MissedTickBehavior::Skip); - keepalive_ticker.reset(); +pub(crate) struct AttachShardNotification { + shard_id: ShardIdentity, + sender: tokio::sync::mpsc::Sender, + start_pos: Lsn, +} - let (tx, mut rx) = tokio::sync::mpsc::channel::(2); - let shard = vec![self.shard]; +impl InterpretedWalReader { + /// Spawn the reader in a separate tokio task and return a handle + pub(crate) fn spawn( + wal_stream: StreamingWalReader, + start_pos: Lsn, + tx: tokio::sync::mpsc::Sender, + shard: ShardIdentity, + pg_version: u32, + appname: &Option, + ) -> InterpretedWalReaderHandle { + let state = Arc::new(std::sync::RwLock::new(InterpretedWalReaderState::Running { + current_position: start_pos, + })); + + let (shard_notification_tx, shard_notification_rx) = tokio::sync::mpsc::unbounded_channel(); + + let reader = InterpretedWalReader { + wal_stream, + shard_senders: HashMap::from([( + shard, + smallvec::smallvec![ShardSenderState { + sender_id: SenderId::first(), + tx, + next_record_lsn: start_pos, + }], + )]), + shard_notification_rx: Some(shard_notification_rx), + state: state.clone(), + pg_version, + }; + + let metric = WAL_READERS + .get_metric_with_label_values(&["task", appname.as_deref().unwrap_or("safekeeper")]) + .unwrap(); + + let join_handle = tokio::task::spawn( + async move { + metric.inc(); + scopeguard::defer! { + metric.dec(); + } + + let res = reader.run_impl(start_pos).await; + if let Err(ref err) = res { + tracing::error!("Task finished with error: {err}"); + } + res + } + .instrument(info_span!("interpreted wal reader")), + ); + + InterpretedWalReaderHandle { + join_handle, + state, + shard_notification_tx, + } + } + + /// Construct the reader without spawning anything + /// Callers should drive the future returned by [`Self::run`]. 
+ pub(crate) fn new( + wal_stream: StreamingWalReader, + start_pos: Lsn, + tx: tokio::sync::mpsc::Sender, + shard: ShardIdentity, + pg_version: u32, + ) -> InterpretedWalReader { + let state = Arc::new(std::sync::RwLock::new(InterpretedWalReaderState::Running { + current_position: start_pos, + })); + + InterpretedWalReader { + wal_stream, + shard_senders: HashMap::from([( + shard, + smallvec::smallvec![ShardSenderState { + sender_id: SenderId::first(), + tx, + next_record_lsn: start_pos, + }], + )]), + shard_notification_rx: None, + state: state.clone(), + pg_version, + } + } + + /// Entry point for future (polling) based wal reader. + pub(crate) async fn run( + self, + start_pos: Lsn, + appname: &Option, + ) -> Result<(), CopyStreamHandlerEnd> { + let metric = WAL_READERS + .get_metric_with_label_values(&["future", appname.as_deref().unwrap_or("safekeeper")]) + .unwrap(); + + metric.inc(); + scopeguard::defer! { + metric.dec(); + } + + let res = self.run_impl(start_pos).await; + if let Err(err) = res { + tracing::error!("Interpreted wal reader encountered error: {err}"); + } else { + tracing::info!("Interpreted wal reader exiting"); + } + + Err(CopyStreamHandlerEnd::Other(anyhow!( + "interpreted wal reader finished" + ))) + } + + /// Send interpreted WAL to one or more [`InterpretedWalSender`]s + /// Stops when an error is encountered or when the [`InterpretedWalReaderHandle`] + /// goes out of scope. + async fn run_impl(mut self, start_pos: Lsn) -> Result<(), InterpretedWalReaderError> { + let defer_state = self.state.clone(); + scopeguard::defer! { + *defer_state.write().unwrap() = InterpretedWalReaderState::Done; + } + + let mut wal_decoder = WalStreamDecoder::new(start_pos, self.pg_version); loop { tokio::select! { - // Get some WAL from the stream and then: decode, interpret and push it down the - // pipeline. 
- wal = stream.next(), if tx.capacity() > 0 => { - let WalBytes { wal, wal_start_lsn: _, wal_end_lsn, available_wal_end_lsn } = match wal { - Some(some) => some?, - None => { break; } + // Main branch for reading WAL and forwarding it + wal_or_reset = self.wal_stream.next() => { + let wal = wal_or_reset.map(|wor| wor.get_wal().expect("reset handled in select branch below")); + let WalBytes { + wal, + wal_start_lsn: _, + wal_end_lsn, + available_wal_end_lsn, + } = match wal { + Some(some) => some.map_err(InterpretedWalReaderError::ReadOrInterpret)?, + None => { + // [`StreamingWalReader::next`] is an endless stream of WAL. + // It shouldn't ever finish unless it panicked or became internally + // inconsistent. + return Result::Err(InterpretedWalReaderError::WalStreamClosed); + } }; - wal_position = wal_end_lsn; wal_decoder.feed_bytes(&wal); - let mut records = Vec::new(); + // Deserialize and interpret WAL records from this batch of WAL. + // Interpreted records for each shard are collected separately. + let shard_ids = self.shard_senders.keys().copied().collect::>(); + let mut records_by_sender: HashMap> = HashMap::new(); let mut max_next_record_lsn = None; - while let Some((next_record_lsn, recdata)) = wal_decoder - .poll_decode() - .with_context(|| "Failed to decode WAL")? + while let Some((next_record_lsn, recdata)) = wal_decoder.poll_decode()? { assert!(next_record_lsn.is_aligned()); max_next_record_lsn = Some(next_record_lsn); - - // Deserialize and interpret WAL record let interpreted = InterpretedWalRecord::from_bytes_filtered( recdata, - &shard, + &shard_ids, next_record_lsn, self.pg_version, ) - .with_context(|| "Failed to interpret WAL")? 
- .remove(&self.shard) - .unwrap(); + .with_context(|| "Failed to interpret WAL")?; - if !interpreted.is_empty() { - records.push(interpreted); + for (shard, record) in interpreted { + if record.is_empty() { + continue; + } + + let mut states_iter = self.shard_senders + .get(&shard) + .expect("keys collected above") + .iter() + .filter(|state| record.next_record_lsn > state.next_record_lsn) + .peekable(); + while let Some(state) = states_iter.next() { + let shard_sender_id = ShardSenderId::new(shard, state.sender_id); + + // The most commont case is one sender per shard. Peek and break to avoid the + // clone in that situation. + if states_iter.peek().is_none() { + records_by_sender.entry(shard_sender_id).or_default().push(record); + break; + } else { + records_by_sender.entry(shard_sender_id).or_default().push(record.clone()); + } + } } } @@ -103,20 +333,170 @@ impl InterpretedWalSender<'_, IO> { None => { continue; } }; - let batch = InterpretedWalRecords { - records, - next_record_lsn: Some(max_next_record_lsn), - }; + // Update the current position such that new receivers can decide + // whether to attach to us or spawn a new WAL reader. + match &mut *self.state.write().unwrap() { + InterpretedWalReaderState::Running { current_position, .. } => { + *current_position = max_next_record_lsn; + }, + InterpretedWalReaderState::Done => { + unreachable!() + } + } - tx.send(Batch {wal_end_lsn, available_wal_end_lsn, records: batch}).await.unwrap(); + // Send interpreted records downstream. Anything that has already been seen + // by a shard is filtered out. 
+ let mut shard_senders_to_remove = Vec::new(); + for (shard, states) in &mut self.shard_senders { + for state in states { + if max_next_record_lsn <= state.next_record_lsn { + continue; + } + + let shard_sender_id = ShardSenderId::new(*shard, state.sender_id); + let records = records_by_sender.remove(&shard_sender_id).unwrap_or_default(); + + let batch = InterpretedWalRecords { + records, + next_record_lsn: Some(max_next_record_lsn), + }; + + let res = state.tx.send(Batch { + wal_end_lsn, + available_wal_end_lsn, + records: batch, + }).await; + + if res.is_err() { + shard_senders_to_remove.push(shard_sender_id); + } else { + state.next_record_lsn = max_next_record_lsn; + } + } + } + + // Clean up any shard senders that have dropped out. + // This is inefficient, but such events are rare (connection to PS termination) + // and the number of subscriptions on the same shards very small (only one + // for the steady state). + for to_remove in shard_senders_to_remove { + let shard_senders = self.shard_senders.get_mut(&to_remove.shard()).expect("saw it above"); + if let Some(idx) = shard_senders.iter().position(|s| s.sender_id == to_remove.sender_id) { + shard_senders.remove(idx); + tracing::info!("Removed shard sender {}", to_remove); + } + + if shard_senders.is_empty() { + self.shard_senders.remove(&to_remove.shard()); + } + } }, - // For a previously interpreted batch, serialize it and push it down the wire. - batch = rx.recv() => { + // Listen for new shards that want to attach to this reader. + // If the reader is not running as a task, then this is not supported + // (see the pending branch below). + notification = match self.shard_notification_rx.as_mut() { + Some(rx) => Either::Left(rx.recv()), + None => Either::Right(std::future::pending()) + } => { + if let Some(n) = notification { + let AttachShardNotification { shard_id, sender, start_pos } = n; + + // Update internal and external state, then reset the WAL stream + // if required. 
+ let senders = self.shard_senders.entry(shard_id).or_default(); + let new_sender_id = match senders.last() { + Some(sender) => sender.sender_id.next(), + None => SenderId::first() + }; + + senders.push(ShardSenderState { sender_id: new_sender_id, tx: sender, next_record_lsn: start_pos}); + let current_pos = self.state.read().unwrap().current_position().unwrap(); + if start_pos < current_pos { + self.wal_stream.reset(start_pos).await; + wal_decoder = WalStreamDecoder::new(start_pos, self.pg_version); + } + + tracing::info!( + "Added shard sender {} with start_pos={} current_pos={}", + ShardSenderId::new(shard_id, new_sender_id), start_pos, current_pos + ); + } + } + } + } + } +} + +impl InterpretedWalReaderHandle { + /// Fan-out the reader by attaching a new shard to it + pub(crate) fn fanout( + &self, + shard_id: ShardIdentity, + sender: tokio::sync::mpsc::Sender, + start_pos: Lsn, + ) -> Result<(), SendError> { + self.shard_notification_tx.send(AttachShardNotification { + shard_id, + sender, + start_pos, + }) + } + + /// Get the current WAL position of the reader + pub(crate) fn current_position(&self) -> Option { + self.state.read().unwrap().current_position() + } + + pub(crate) fn abort(&self) { + self.join_handle.abort() + } +} + +impl Drop for InterpretedWalReaderHandle { + fn drop(&mut self) { + tracing::info!("Aborting interpreted wal reader"); + self.abort() + } +} + +pub(crate) struct InterpretedWalSender<'a, IO> { + pub(crate) format: InterpretedFormat, + pub(crate) compression: Option, + pub(crate) appname: Option, + + pub(crate) tli: WalResidentTimeline, + pub(crate) start_lsn: Lsn, + + pub(crate) pgb: &'a mut PostgresBackend, + pub(crate) end_watch_view: EndWatchView, + pub(crate) wal_sender_guard: Arc, + pub(crate) rx: tokio::sync::mpsc::Receiver, +} + +impl InterpretedWalSender<'_, IO> { + /// Send interpreted WAL records over the network. + /// Also manages keep-alives if nothing was sent for a while. 
+ pub(crate) async fn run(mut self) -> Result<(), CopyStreamHandlerEnd> { + let mut keepalive_ticker = tokio::time::interval(Duration::from_secs(1)); + keepalive_ticker.set_missed_tick_behavior(MissedTickBehavior::Skip); + keepalive_ticker.reset(); + + let mut wal_position = self.start_lsn; + + loop { + tokio::select! { + batch = self.rx.recv() => { let batch = match batch { Some(b) => b, - None => { break; } + None => { + return Result::Err( + CopyStreamHandlerEnd::Other(anyhow!("Interpreted WAL reader exited early")) + ); + } }; + wal_position = batch.wal_end_lsn; + let buf = batch .records .to_wire(self.format, self.compression) @@ -136,7 +516,21 @@ impl InterpretedWalSender<'_, IO> { })).await?; } // Send a periodic keep alive when the connection has been idle for a while. + // Since we've been idle, also check if we can stop streaming. _ = keepalive_ticker.tick() => { + if let Some(remote_consistent_lsn) = self.wal_sender_guard + .walsenders() + .get_ws_remote_consistent_lsn(self.wal_sender_guard.id()) + { + if self.tli.should_walsender_stop(remote_consistent_lsn).await { + // Stop streaming if the receivers are caught up and + // there's no active compute. This causes the loop in + // [`crate::send_interpreted_wal::InterpretedWalSender::run`] + // to exit and terminate the WAL stream. + break; + } + } + self.pgb .write_message(&BeMessage::KeepAlive(WalSndKeepAlive { wal_end: self.end_watch_view.get().0, @@ -144,14 +538,259 @@ impl InterpretedWalSender<'_, IO> { request_reply: true, })) .await?; - } + }, } } - // The loop above ends when the receiver is caught up and there's no more WAL to send. 
Err(CopyStreamHandlerEnd::ServerInitiated(format!( "ending streaming to {:?} at {}, receiver is caughtup and there is no computes", self.appname, wal_position, ))) } } +#[cfg(test)] +mod tests { + use std::{collections::HashMap, str::FromStr, time::Duration}; + + use pageserver_api::shard::{ShardIdentity, ShardStripeSize}; + use postgres_ffi::MAX_SEND_SIZE; + use tokio::sync::mpsc::error::TryRecvError; + use utils::{ + id::{NodeId, TenantTimelineId}, + lsn::Lsn, + shard::{ShardCount, ShardNumber}, + }; + + use crate::{ + send_interpreted_wal::{Batch, InterpretedWalReader}, + test_utils::Env, + wal_reader_stream::StreamingWalReader, + }; + + #[tokio::test] + async fn test_interpreted_wal_reader_fanout() { + let _ = env_logger::builder().is_test(true).try_init(); + + const SIZE: usize = 8 * 1024; + const MSG_COUNT: usize = 200; + const PG_VERSION: u32 = 17; + const SHARD_COUNT: u8 = 2; + + let start_lsn = Lsn::from_str("0/149FD18").unwrap(); + let env = Env::new(true).unwrap(); + let tli = env + .make_timeline(NodeId(1), TenantTimelineId::generate(), start_lsn) + .await + .unwrap(); + + let resident_tli = tli.wal_residence_guard().await.unwrap(); + let end_watch = Env::write_wal(tli, start_lsn, SIZE, MSG_COUNT) + .await + .unwrap(); + let end_pos = end_watch.get(); + + tracing::info!("Doing first round of reads ..."); + + let streaming_wal_reader = StreamingWalReader::new( + resident_tli, + None, + start_lsn, + end_pos, + end_watch, + MAX_SEND_SIZE, + ); + + let shard_0 = ShardIdentity::new( + ShardNumber(0), + ShardCount(SHARD_COUNT), + ShardStripeSize::default(), + ) + .unwrap(); + + let shard_1 = ShardIdentity::new( + ShardNumber(1), + ShardCount(SHARD_COUNT), + ShardStripeSize::default(), + ) + .unwrap(); + + let mut shards = HashMap::new(); + + for shard_number in 0..SHARD_COUNT { + let shard_id = ShardIdentity::new( + ShardNumber(shard_number), + ShardCount(SHARD_COUNT), + ShardStripeSize::default(), + ) + .unwrap(); + let (tx, rx) = 
tokio::sync::mpsc::channel::(MSG_COUNT * 2); + shards.insert(shard_id, (Some(tx), Some(rx))); + } + + let shard_0_tx = shards.get_mut(&shard_0).unwrap().0.take().unwrap(); + let mut shard_0_rx = shards.get_mut(&shard_0).unwrap().1.take().unwrap(); + + let handle = InterpretedWalReader::spawn( + streaming_wal_reader, + start_lsn, + shard_0_tx, + shard_0, + PG_VERSION, + &Some("pageserver".to_string()), + ); + + tracing::info!("Reading all WAL with only shard 0 attached ..."); + + let mut shard_0_interpreted_records = Vec::new(); + while let Some(batch) = shard_0_rx.recv().await { + shard_0_interpreted_records.push(batch.records); + if batch.wal_end_lsn == batch.available_wal_end_lsn { + break; + } + } + + let shard_1_tx = shards.get_mut(&shard_1).unwrap().0.take().unwrap(); + let mut shard_1_rx = shards.get_mut(&shard_1).unwrap().1.take().unwrap(); + + tracing::info!("Attaching shard 1 to the reader at start of WAL"); + handle.fanout(shard_1, shard_1_tx, start_lsn).unwrap(); + + tracing::info!("Reading all WAL with shard 0 and shard 1 attached ..."); + + let mut shard_1_interpreted_records = Vec::new(); + while let Some(batch) = shard_1_rx.recv().await { + shard_1_interpreted_records.push(batch.records); + if batch.wal_end_lsn == batch.available_wal_end_lsn { + break; + } + } + + // This test uses logical messages. Those only go to shard 0. Check that the + // filtering worked and shard 1 did not get any. + assert!(shard_1_interpreted_records + .iter() + .all(|recs| recs.records.is_empty())); + + // Shard 0 should not receive anything more since the reader is + // going through wal that it has already processed. + let res = shard_0_rx.try_recv(); + if let Ok(ref ok) = res { + tracing::error!( + "Shard 0 received batch: wal_end_lsn={} available_wal_end_lsn={}", + ok.wal_end_lsn, + ok.available_wal_end_lsn + ); + } + assert!(matches!(res, Err(TryRecvError::Empty))); + + // Check that the next records lsns received by the two shards match up. 
+ let shard_0_next_lsns = shard_0_interpreted_records + .iter() + .map(|recs| recs.next_record_lsn) + .collect::>(); + let shard_1_next_lsns = shard_1_interpreted_records + .iter() + .map(|recs| recs.next_record_lsn) + .collect::>(); + assert_eq!(shard_0_next_lsns, shard_1_next_lsns); + + handle.abort(); + let mut done = false; + for _ in 0..5 { + if handle.current_position().is_none() { + done = true; + break; + } + tokio::time::sleep(Duration::from_millis(1)).await; + } + + assert!(done); + } + + #[tokio::test] + async fn test_interpreted_wal_reader_same_shard_fanout() { + let _ = env_logger::builder().is_test(true).try_init(); + + const SIZE: usize = 8 * 1024; + const MSG_COUNT: usize = 200; + const PG_VERSION: u32 = 17; + const SHARD_COUNT: u8 = 2; + const ATTACHED_SHARDS: u8 = 4; + + let start_lsn = Lsn::from_str("0/149FD18").unwrap(); + let env = Env::new(true).unwrap(); + let tli = env + .make_timeline(NodeId(1), TenantTimelineId::generate(), start_lsn) + .await + .unwrap(); + + let resident_tli = tli.wal_residence_guard().await.unwrap(); + let end_watch = Env::write_wal(tli, start_lsn, SIZE, MSG_COUNT) + .await + .unwrap(); + let end_pos = end_watch.get(); + + let streaming_wal_reader = StreamingWalReader::new( + resident_tli, + None, + start_lsn, + end_pos, + end_watch, + MAX_SEND_SIZE, + ); + + let shard_0 = ShardIdentity::new( + ShardNumber(0), + ShardCount(SHARD_COUNT), + ShardStripeSize::default(), + ) + .unwrap(); + + let (tx, rx) = tokio::sync::mpsc::channel::(MSG_COUNT * 2); + let mut batch_receivers = vec![rx]; + + let handle = InterpretedWalReader::spawn( + streaming_wal_reader, + start_lsn, + tx, + shard_0, + PG_VERSION, + &Some("pageserver".to_string()), + ); + + for _ in 0..(ATTACHED_SHARDS - 1) { + let (tx, rx) = tokio::sync::mpsc::channel::(MSG_COUNT * 2); + handle.fanout(shard_0, tx, start_lsn).unwrap(); + batch_receivers.push(rx); + } + + loop { + let batch = batch_receivers.first_mut().unwrap().recv().await.unwrap(); + for rx in 
batch_receivers.iter_mut().skip(1) { + let other_batch = rx.recv().await.unwrap(); + + assert_eq!(batch.wal_end_lsn, other_batch.wal_end_lsn); + assert_eq!( + batch.available_wal_end_lsn, + other_batch.available_wal_end_lsn + ); + } + + if batch.wal_end_lsn == batch.available_wal_end_lsn { + break; + } + } + + handle.abort(); + let mut done = false; + for _ in 0..5 { + if handle.current_position().is_none() { + done = true; + break; + } + tokio::time::sleep(Duration::from_millis(1)).await; + } + + assert!(done); + } +} diff --git a/safekeeper/src/send_wal.rs b/safekeeper/src/send_wal.rs index 8463221998..4a4a74a0fd 100644 --- a/safekeeper/src/send_wal.rs +++ b/safekeeper/src/send_wal.rs @@ -2,16 +2,18 @@ //! with the "START_REPLICATION" message, and registry of walsenders. use crate::handler::SafekeeperPostgresHandler; -use crate::metrics::RECEIVED_PS_FEEDBACKS; +use crate::metrics::{RECEIVED_PS_FEEDBACKS, WAL_READERS}; use crate::receive_wal::WalReceivers; use crate::safekeeper::TermLsn; -use crate::send_interpreted_wal::InterpretedWalSender; +use crate::send_interpreted_wal::{ + Batch, InterpretedWalReader, InterpretedWalReaderHandle, InterpretedWalSender, +}; use crate::timeline::WalResidentTimeline; -use crate::wal_reader_stream::WalReaderStreamBuilder; +use crate::wal_reader_stream::StreamingWalReader; use crate::wal_storage::WalReader; use anyhow::{bail, Context as AnyhowContext}; use bytes::Bytes; -use futures::future::Either; +use futures::FutureExt; use parking_lot::Mutex; use postgres_backend::PostgresBackend; use postgres_backend::{CopyStreamHandlerEnd, PostgresBackendReader, QueryError}; @@ -19,16 +21,16 @@ use postgres_ffi::get_current_timestamp; use postgres_ffi::{TimestampTz, MAX_SEND_SIZE}; use pq_proto::{BeMessage, WalSndKeepAlive, XLogDataBody}; use safekeeper_api::models::{ - ConnectionId, HotStandbyFeedback, ReplicationFeedback, StandbyFeedback, StandbyReply, - WalSenderState, INVALID_FULL_TRANSACTION_ID, + HotStandbyFeedback, 
ReplicationFeedback, StandbyFeedback, StandbyReply, + INVALID_FULL_TRANSACTION_ID, }; use safekeeper_api::Term; use tokio::io::{AsyncRead, AsyncWrite}; use utils::failpoint_support; -use utils::id::TenantTimelineId; use utils::pageserver_feedback::PageserverFeedback; use utils::postgres_client::PostgresClientProtocol; +use itertools::Itertools; use std::cmp::{max, min}; use std::net::SocketAddr; use std::sync::Arc; @@ -50,6 +52,12 @@ pub struct WalSenders { walreceivers: Arc, } +pub struct WalSendersTimelineMetricValues { + pub ps_feedback_counter: u64, + pub last_ps_feedback: PageserverFeedback, + pub interpreted_wal_reader_tasks: usize, +} + impl WalSenders { pub fn new(walreceivers: Arc) -> Arc { Arc::new(WalSenders { @@ -60,21 +68,8 @@ impl WalSenders { /// Register new walsender. Returned guard provides access to the slot and /// automatically deregisters in Drop. - fn register( - self: &Arc, - ttid: TenantTimelineId, - addr: SocketAddr, - conn_id: ConnectionId, - appname: Option, - ) -> WalSenderGuard { + fn register(self: &Arc, walsender_state: WalSenderState) -> WalSenderGuard { let slots = &mut self.mutex.lock().slots; - let walsender_state = WalSenderState { - ttid, - addr, - conn_id, - appname, - feedback: ReplicationFeedback::Pageserver(PageserverFeedback::empty()), - }; // find empty slot or create new one let pos = if let Some(pos) = slots.iter().position(|s| s.is_none()) { slots[pos] = Some(walsender_state); @@ -90,9 +85,79 @@ impl WalSenders { } } + fn create_or_update_interpreted_reader< + FUp: FnOnce(&Arc) -> anyhow::Result<()>, + FNew: FnOnce() -> InterpretedWalReaderHandle, + >( + self: &Arc, + id: WalSenderId, + start_pos: Lsn, + max_delta_for_fanout: Option, + update: FUp, + create: FNew, + ) -> anyhow::Result<()> { + let state = &mut self.mutex.lock(); + + let mut selected_interpreted_reader = None; + for slot in state.slots.iter().flatten() { + if let WalSenderState::Interpreted(slot_state) = slot { + if let Some(ref interpreted_reader) = 
slot_state.interpreted_wal_reader { + let select = match (interpreted_reader.current_position(), max_delta_for_fanout) + { + (Some(pos), Some(max_delta)) => { + let delta = pos.0.abs_diff(start_pos.0); + delta <= max_delta + } + // Reader is not active + (None, _) => false, + // Gating fanout by max delta is disabled. + // Attach to any active reader. + (_, None) => true, + }; + + if select { + selected_interpreted_reader = Some(interpreted_reader.clone()); + break; + } + } + } + } + + let slot = state.get_slot_mut(id); + let slot_state = match slot { + WalSenderState::Interpreted(s) => s, + WalSenderState::Vanilla(_) => unreachable!(), + }; + + let selected_or_new = match selected_interpreted_reader { + Some(selected) => { + update(&selected)?; + selected + } + None => Arc::new(create()), + }; + + slot_state.interpreted_wal_reader = Some(selected_or_new); + + Ok(()) + } + /// Get state of all walsenders. - pub fn get_all(self: &Arc) -> Vec { - self.mutex.lock().slots.iter().flatten().cloned().collect() + pub fn get_all_public(self: &Arc) -> Vec { + self.mutex + .lock() + .slots + .iter() + .flatten() + .map(|state| match state { + WalSenderState::Vanilla(s) => { + safekeeper_api::models::WalSenderState::Vanilla(s.clone()) + } + WalSenderState::Interpreted(s) => { + safekeeper_api::models::WalSenderState::Interpreted(s.public_state.clone()) + } + }) + .collect() } /// Get LSN of the most lagging pageserver receiver. Return None if there are no @@ -103,7 +168,7 @@ impl WalSenders { .slots .iter() .flatten() - .filter_map(|s| match s.feedback { + .filter_map(|s| match s.get_feedback() { ReplicationFeedback::Pageserver(feedback) => Some(feedback.last_received_lsn), ReplicationFeedback::Standby(_) => None, }) @@ -111,9 +176,25 @@ impl WalSenders { } /// Returns total counter of pageserver feedbacks received and last feedback. 
- pub fn get_ps_feedback_stats(self: &Arc) -> (u64, PageserverFeedback) { + pub fn info_for_metrics(self: &Arc) -> WalSendersTimelineMetricValues { let shared = self.mutex.lock(); - (shared.ps_feedback_counter, shared.last_ps_feedback) + + let interpreted_wal_reader_tasks = shared + .slots + .iter() + .filter_map(|ss| match ss { + Some(WalSenderState::Interpreted(int)) => int.interpreted_wal_reader.as_ref(), + Some(WalSenderState::Vanilla(_)) => None, + None => None, + }) + .unique_by(|reader| Arc::as_ptr(reader)) + .count(); + + WalSendersTimelineMetricValues { + ps_feedback_counter: shared.ps_feedback_counter, + last_ps_feedback: shared.last_ps_feedback, + interpreted_wal_reader_tasks, + } } /// Get aggregated hot standby feedback (we send it to compute). @@ -124,7 +205,7 @@ impl WalSenders { /// Record new pageserver feedback, update aggregated values. fn record_ps_feedback(self: &Arc, id: WalSenderId, feedback: &PageserverFeedback) { let mut shared = self.mutex.lock(); - shared.get_slot_mut(id).feedback = ReplicationFeedback::Pageserver(*feedback); + *shared.get_slot_mut(id).get_mut_feedback() = ReplicationFeedback::Pageserver(*feedback); shared.last_ps_feedback = *feedback; shared.ps_feedback_counter += 1; drop(shared); @@ -143,10 +224,10 @@ impl WalSenders { "Record standby reply: ts={} apply_lsn={}", reply.reply_ts, reply.apply_lsn ); - match &mut slot.feedback { + match &mut slot.get_mut_feedback() { ReplicationFeedback::Standby(sf) => sf.reply = *reply, ReplicationFeedback::Pageserver(_) => { - slot.feedback = ReplicationFeedback::Standby(StandbyFeedback { + *slot.get_mut_feedback() = ReplicationFeedback::Standby(StandbyFeedback { reply: *reply, hs_feedback: HotStandbyFeedback::empty(), }) @@ -158,10 +239,10 @@ impl WalSenders { fn record_hs_feedback(self: &Arc, id: WalSenderId, feedback: &HotStandbyFeedback) { let mut shared = self.mutex.lock(); let slot = shared.get_slot_mut(id); - match &mut slot.feedback { + match &mut slot.get_mut_feedback() { 
ReplicationFeedback::Standby(sf) => sf.hs_feedback = *feedback, ReplicationFeedback::Pageserver(_) => { - slot.feedback = ReplicationFeedback::Standby(StandbyFeedback { + *slot.get_mut_feedback() = ReplicationFeedback::Standby(StandbyFeedback { reply: StandbyReply::empty(), hs_feedback: *feedback, }) @@ -175,7 +256,7 @@ impl WalSenders { pub fn get_ws_remote_consistent_lsn(self: &Arc, id: WalSenderId) -> Option { let shared = self.mutex.lock(); let slot = shared.get_slot(id); - match slot.feedback { + match slot.get_feedback() { ReplicationFeedback::Pageserver(feedback) => Some(feedback.remote_consistent_lsn), _ => None, } @@ -199,6 +280,47 @@ struct WalSendersShared { slots: Vec>, } +/// Safekeeper internal definitions of wal sender state +/// +/// As opposed to [`safekeeper_api::models::WalSenderState`] these structs may +/// include state that we do not wish to expose to the public api. +#[derive(Debug, Clone)] +pub(crate) enum WalSenderState { + Vanilla(VanillaWalSenderInternalState), + Interpreted(InterpretedWalSenderInternalState), +} + +type VanillaWalSenderInternalState = safekeeper_api::models::VanillaWalSenderState; + +#[derive(Debug, Clone)] +pub(crate) struct InterpretedWalSenderInternalState { + public_state: safekeeper_api::models::InterpretedWalSenderState, + interpreted_wal_reader: Option>, +} + +impl WalSenderState { + fn get_addr(&self) -> &SocketAddr { + match self { + WalSenderState::Vanilla(state) => &state.addr, + WalSenderState::Interpreted(state) => &state.public_state.addr, + } + } + + fn get_feedback(&self) -> &ReplicationFeedback { + match self { + WalSenderState::Vanilla(state) => &state.feedback, + WalSenderState::Interpreted(state) => &state.public_state.feedback, + } + } + + fn get_mut_feedback(&mut self) -> &mut ReplicationFeedback { + match self { + WalSenderState::Vanilla(state) => &mut state.feedback, + WalSenderState::Interpreted(state) => &mut state.public_state.feedback, + } + } +} + impl WalSendersShared { fn new() -> Self {
WalSendersShared { @@ -225,7 +347,7 @@ impl WalSendersShared { let mut agg = HotStandbyFeedback::empty(); let mut reply_agg = StandbyReply::empty(); for ws_state in self.slots.iter().flatten() { - if let ReplicationFeedback::Standby(standby_feedback) = ws_state.feedback { + if let ReplicationFeedback::Standby(standby_feedback) = ws_state.get_feedback() { let hs_feedback = standby_feedback.hs_feedback; // doing Option math like op1.iter().chain(op2.iter()).min() // would be nicer, but we serialize/deserialize this struct @@ -317,7 +439,7 @@ impl SafekeeperPostgresHandler { /// Wrapper around handle_start_replication_guts handling result. Error is /// handled here while we're still in walsender ttid span; with API /// extension, this can probably be moved into postgres_backend. - pub async fn handle_start_replication( + pub async fn handle_start_replication( &mut self, pgb: &mut PostgresBackend, start_pos: Lsn, @@ -342,7 +464,7 @@ impl SafekeeperPostgresHandler { Ok(()) } - pub async fn handle_start_replication_guts( + pub async fn handle_start_replication_guts( &mut self, pgb: &mut PostgresBackend, start_pos: Lsn, @@ -352,12 +474,30 @@ impl SafekeeperPostgresHandler { let appname = self.appname.clone(); // Use a guard object to remove our entry from the timeline when we are done. - let ws_guard = Arc::new(tli.get_walsenders().register( - self.ttid, - *pgb.get_peer_addr(), - self.conn_id, - self.appname.clone(), - )); + let ws_guard = match self.protocol() { + PostgresClientProtocol::Vanilla => Arc::new(tli.get_walsenders().register( + WalSenderState::Vanilla(VanillaWalSenderInternalState { + ttid: self.ttid, + addr: *pgb.get_peer_addr(), + conn_id: self.conn_id, + appname: self.appname.clone(), + feedback: ReplicationFeedback::Pageserver(PageserverFeedback::empty()), + }), + )), + PostgresClientProtocol::Interpreted { .. 
} => Arc::new(tli.get_walsenders().register( + WalSenderState::Interpreted(InterpretedWalSenderInternalState { + public_state: safekeeper_api::models::InterpretedWalSenderState { + ttid: self.ttid, + shard: self.shard.unwrap(), + addr: *pgb.get_peer_addr(), + conn_id: self.conn_id, + appname: self.appname.clone(), + feedback: ReplicationFeedback::Pageserver(PageserverFeedback::empty()), + }, + interpreted_wal_reader: None, + }), + )), + }; // Walsender can operate in one of two modes which we select by // application_name: give only committed WAL (used by pageserver) or all @@ -403,7 +543,7 @@ impl SafekeeperPostgresHandler { pgb, // should succeed since we're already holding another guard tli: tli.wal_residence_guard().await?, - appname, + appname: appname.clone(), start_pos, end_pos, term, @@ -413,7 +553,7 @@ impl SafekeeperPostgresHandler { send_buf: vec![0u8; MAX_SEND_SIZE], }; - Either::Left(sender.run()) + FutureExt::boxed(sender.run()) } PostgresClientProtocol::Interpreted { format, @@ -421,27 +561,96 @@ impl SafekeeperPostgresHandler { } => { let pg_version = tli.tli.get_state().await.1.server.pg_version / 10000; let end_watch_view = end_watch.view(); - let wal_stream_builder = WalReaderStreamBuilder { - tli: tli.wal_residence_guard().await?, - start_pos, - end_pos, - term, - end_watch, - wal_sender_guard: ws_guard.clone(), - }; + let wal_residence_guard = tli.wal_residence_guard().await?; + let (tx, rx) = tokio::sync::mpsc::channel::(2); + let shard = self.shard.unwrap(); - let sender = InterpretedWalSender { - format, - compression, - pgb, - wal_stream_builder, - end_watch_view, - shard: self.shard.unwrap(), - pg_version, - appname, - }; + if self.conf.wal_reader_fanout && !shard.is_unsharded() { + let ws_id = ws_guard.id(); + ws_guard.walsenders().create_or_update_interpreted_reader( + ws_id, + start_pos, + self.conf.max_delta_for_fanout, + { + let tx = tx.clone(); + |reader| { + tracing::info!( + "Fanning out interpreted wal reader at {}", + start_pos + 
); + reader + .fanout(shard, tx, start_pos) + .with_context(|| "Failed to fan out reader") + } + }, + || { + tracing::info!("Spawning interpreted wal reader at {}", start_pos); - Either::Right(sender.run()) + let wal_stream = StreamingWalReader::new( + wal_residence_guard, + term, + start_pos, + end_pos, + end_watch, + MAX_SEND_SIZE, + ); + + InterpretedWalReader::spawn( + wal_stream, start_pos, tx, shard, pg_version, &appname, + ) + }, + )?; + + let sender = InterpretedWalSender { + format, + compression, + appname, + tli: tli.wal_residence_guard().await?, + start_lsn: start_pos, + pgb, + end_watch_view, + wal_sender_guard: ws_guard.clone(), + rx, + }; + + FutureExt::boxed(sender.run()) + } else { + let wal_reader = StreamingWalReader::new( + wal_residence_guard, + term, + start_pos, + end_pos, + end_watch, + MAX_SEND_SIZE, + ); + + let reader = + InterpretedWalReader::new(wal_reader, start_pos, tx, shard, pg_version); + + let sender = InterpretedWalSender { + format, + compression, + appname: appname.clone(), + tli: tli.wal_residence_guard().await?, + start_lsn: start_pos, + pgb, + end_watch_view, + wal_sender_guard: ws_guard.clone(), + rx, + }; + + FutureExt::boxed(async move { + // Sender returns an Err on all code paths. + // If the sender finishes first, we will drop the reader future. + // If the reader finishes first, the sender will finish too since + // the wal sender has dropped. + let res = tokio::try_join!(sender.run(), reader.run(start_pos, &appname)); + match res.map(|_| ()) { + Ok(_) => unreachable!("sender finishes with Err by convention"), + err_res => err_res, + } + }) + } } }; @@ -470,7 +679,8 @@ impl SafekeeperPostgresHandler { .clone(); info!( "finished streaming to {}, feedback={:?}", - ws_state.addr, ws_state.feedback, + ws_state.get_addr(), + ws_state.get_feedback(), ); // Join pg backend back. @@ -578,6 +788,18 @@ impl WalSender<'_, IO> { /// Err(CopyStreamHandlerEnd) is always returned; Result is used only for ? /// convenience. 
async fn run(mut self) -> Result<(), CopyStreamHandlerEnd> { + let metric = WAL_READERS + .get_metric_with_label_values(&[ + "future", + self.appname.as_deref().unwrap_or("safekeeper"), + ]) + .unwrap(); + + metric.inc(); + scopeguard::defer! { + metric.dec(); + } + loop { // Wait for the next portion if it is not there yet, or just // update our end of WAL available for sending value, we @@ -813,7 +1035,7 @@ impl ReplyReader { #[cfg(test)] mod tests { use safekeeper_api::models::FullTransactionId; - use utils::id::{TenantId, TimelineId}; + use utils::id::{TenantId, TenantTimelineId, TimelineId}; use super::*; @@ -830,13 +1052,13 @@ mod tests { // add to wss specified feedback setting other fields to dummy values fn push_feedback(wss: &mut WalSendersShared, feedback: ReplicationFeedback) { - let walsender_state = WalSenderState { + let walsender_state = WalSenderState::Vanilla(VanillaWalSenderInternalState { ttid: mock_ttid(), addr: mock_addr(), conn_id: 1, appname: None, feedback, - }; + }); wss.slots.push(Some(walsender_state)) } diff --git a/safekeeper/src/test_utils.rs b/safekeeper/src/test_utils.rs index c40a8bae5a..4e851c5b3d 100644 --- a/safekeeper/src/test_utils.rs +++ b/safekeeper/src/test_utils.rs @@ -1,13 +1,19 @@ use std::sync::Arc; use crate::rate_limit::RateLimiter; -use crate::safekeeper::{ProposerAcceptorMessage, ProposerElected, SafeKeeper, TermHistory}; +use crate::receive_wal::WalAcceptor; +use crate::safekeeper::{ + AcceptorProposerMessage, AppendRequest, AppendRequestHeader, ProposerAcceptorMessage, + ProposerElected, SafeKeeper, TermHistory, +}; +use crate::send_wal::EndWatch; use crate::state::{TimelinePersistentState, TimelineState}; use crate::timeline::{get_timeline_dir, SharedState, StateSK, Timeline}; use crate::timelines_set::TimelinesSet; use crate::wal_backup::remote_timeline_path; -use crate::{control_file, wal_storage, SafeKeeperConf}; +use crate::{control_file, receive_wal, wal_storage, SafeKeeperConf}; use 
camino_tempfile::Utf8TempDir; +use postgres_ffi::v17::wal_generator::{LogicalMessageGenerator, WalGenerator}; use tokio::fs::create_dir_all; use utils::id::{NodeId, TenantTimelineId}; use utils::lsn::Lsn; @@ -107,4 +113,59 @@ impl Env { ); Ok(timeline) } + + // This will be dead code when building a non-benchmark target with the + // benchmarking feature enabled. + #[allow(dead_code)] + pub(crate) async fn write_wal( + tli: Arc, + start_lsn: Lsn, + msg_size: usize, + msg_count: usize, + ) -> anyhow::Result { + let (msg_tx, msg_rx) = tokio::sync::mpsc::channel(receive_wal::MSG_QUEUE_SIZE); + let (reply_tx, mut reply_rx) = tokio::sync::mpsc::channel(receive_wal::REPLY_QUEUE_SIZE); + + let end_watch = EndWatch::Commit(tli.get_commit_lsn_watch_rx()); + + WalAcceptor::spawn(tli.wal_residence_guard().await?, msg_rx, reply_tx, Some(0)); + + let prefix = c"p"; + let prefixlen = prefix.to_bytes_with_nul().len(); + assert!(msg_size >= prefixlen); + let message = vec![0; msg_size - prefixlen]; + + let walgen = + &mut WalGenerator::new(LogicalMessageGenerator::new(prefix, &message), start_lsn); + for _ in 0..msg_count { + let (lsn, record) = walgen.next().unwrap(); + + let req = AppendRequest { + h: AppendRequestHeader { + term: 1, + term_start_lsn: start_lsn, + begin_lsn: lsn, + end_lsn: lsn + record.len() as u64, + commit_lsn: lsn, + truncate_lsn: Lsn(0), + proposer_uuid: [0; 16], + }, + wal_data: record, + }; + + let end_lsn = req.h.end_lsn; + + let msg = ProposerAcceptorMessage::AppendRequest(req); + msg_tx.send(msg).await?; + while let Some(reply) = reply_rx.recv().await { + if let AcceptorProposerMessage::AppendResponse(resp) = reply { + if resp.flush_lsn >= end_lsn { + break; + } + } + } + } + + Ok(end_watch) + } } diff --git a/safekeeper/src/timeline.rs b/safekeeper/src/timeline.rs index 2882391074..5eb0bd7146 100644 --- a/safekeeper/src/timeline.rs +++ b/safekeeper/src/timeline.rs @@ -35,7 +35,7 @@ use crate::control_file; use crate::rate_limit::RateLimiter; use 
crate::receive_wal::WalReceivers; use crate::safekeeper::{AcceptorProposerMessage, ProposerAcceptorMessage, SafeKeeper, TermLsn}; -use crate::send_wal::WalSenders; +use crate::send_wal::{WalSenders, WalSendersTimelineMetricValues}; use crate::state::{EvictionState, TimelineMemState, TimelinePersistentState, TimelineState}; use crate::timeline_guard::ResidenceGuard; use crate::timeline_manager::{AtomicStatus, ManagerCtl}; @@ -712,16 +712,22 @@ impl Timeline { return None; } - let (ps_feedback_count, last_ps_feedback) = self.walsenders.get_ps_feedback_stats(); + let WalSendersTimelineMetricValues { + ps_feedback_counter, + last_ps_feedback, + interpreted_wal_reader_tasks, + } = self.walsenders.info_for_metrics(); + let state = self.read_shared_state().await; Some(FullTimelineInfo { ttid: self.ttid, - ps_feedback_count, + ps_feedback_count: ps_feedback_counter, last_ps_feedback, wal_backup_active: self.wal_backup_active.load(Ordering::Relaxed), timeline_is_active: self.broker_active.load(Ordering::Relaxed), num_computes: self.walreceivers.get_num() as u32, last_removed_segno: self.last_removed_segno.load(Ordering::Relaxed), + interpreted_wal_reader_tasks, epoch_start_lsn: state.sk.term_start_lsn(), mem_state: state.sk.state().inmem.clone(), persisted_state: TimelinePersistentState::clone(state.sk.state()), @@ -740,7 +746,7 @@ impl Timeline { debug_dump::Memory { is_cancelled: self.is_cancelled(), peers_info_len: state.peers_info.0.len(), - walsenders: self.walsenders.get_all(), + walsenders: self.walsenders.get_all_public(), wal_backup_active: self.wal_backup_active.load(Ordering::Relaxed), active: self.broker_active.load(Ordering::Relaxed), num_computes: self.walreceivers.get_num() as u32, diff --git a/safekeeper/src/wal_reader_stream.rs b/safekeeper/src/wal_reader_stream.rs index aea628c208..adac6067da 100644 --- a/safekeeper/src/wal_reader_stream.rs +++ b/safekeeper/src/wal_reader_stream.rs @@ -1,34 +1,16 @@ -use std::sync::Arc; - -use async_stream::try_stream; 
-use bytes::Bytes; -use futures::Stream; -use postgres_backend::CopyStreamHandlerEnd; -use safekeeper_api::Term; -use std::time::Duration; -use tokio::time::timeout; -use utils::lsn::Lsn; - -use crate::{ - send_wal::{EndWatch, WalSenderGuard}, - timeline::WalResidentTimeline, +use std::{ + pin::Pin, + task::{Context, Poll}, }; -pub(crate) struct WalReaderStreamBuilder { - pub(crate) tli: WalResidentTimeline, - pub(crate) start_pos: Lsn, - pub(crate) end_pos: Lsn, - pub(crate) term: Option, - pub(crate) end_watch: EndWatch, - pub(crate) wal_sender_guard: Arc, -} +use bytes::Bytes; +use futures::{stream::BoxStream, Stream, StreamExt}; +use utils::lsn::Lsn; -impl WalReaderStreamBuilder { - pub(crate) fn start_pos(&self) -> Lsn { - self.start_pos - } -} +use crate::{send_wal::EndWatch, timeline::WalResidentTimeline, wal_storage::WalReader}; +use safekeeper_api::Term; +#[derive(PartialEq, Eq, Debug)] pub(crate) struct WalBytes { /// Raw PG WAL pub(crate) wal: Bytes, @@ -44,106 +26,270 @@ pub(crate) struct WalBytes { pub(crate) available_wal_end_lsn: Lsn, } -impl WalReaderStreamBuilder { - /// Builds a stream of Postgres WAL starting from [`Self::start_pos`]. - /// The stream terminates when the receiver (pageserver) is fully caught up - /// and there's no active computes. - pub(crate) async fn build( - self, - buffer_size: usize, - ) -> anyhow::Result>> { - // TODO(vlad): The code below duplicates functionality from [`crate::send_wal`]. - // We can make the raw WAL sender use this stream too and remove the duplication. 
- let Self { - tli, - mut start_pos, - mut end_pos, - term, - mut end_watch, - wal_sender_guard, - } = self; - let mut wal_reader = tli.get_walreader(start_pos).await?; - let mut buffer = vec![0; buffer_size]; +struct PositionedWalReader { + start: Lsn, + end: Lsn, + reader: Option, +} - const POLL_STATE_TIMEOUT: Duration = Duration::from_secs(1); +/// A streaming WAL reader wrapper which can be reset while running +pub(crate) struct StreamingWalReader { + stream: BoxStream<'static, WalOrReset>, + start_changed_tx: tokio::sync::watch::Sender, +} - Ok(try_stream! { - loop { - let have_something_to_send = end_pos > start_pos; +pub(crate) enum WalOrReset { + Wal(anyhow::Result), + Reset(Lsn), +} - if !have_something_to_send { - // wait for lsn - let res = timeout(POLL_STATE_TIMEOUT, end_watch.wait_for_lsn(start_pos, term)).await; - match res { - Ok(ok) => { - end_pos = ok?; - }, - Err(_) => { - if let EndWatch::Commit(_) = end_watch { - if let Some(remote_consistent_lsn) = wal_sender_guard - .walsenders() - .get_ws_remote_consistent_lsn(wal_sender_guard.id()) - { - if tli.should_walsender_stop(remote_consistent_lsn).await { - // Stop streaming if the receivers are caught up and - // there's no active compute. This causes the loop in - // [`crate::send_interpreted_wal::InterpretedWalSender::run`] - // to exit and terminate the WAL stream. - return; - } - } - } - - continue; - } - } - } - - - assert!( - end_pos > start_pos, - "nothing to send after waiting for WAL" - ); - - // try to send as much as available, capped by the buffer size - let mut chunk_end_pos = start_pos + buffer_size as u64; - // if we went behind available WAL, back off - if chunk_end_pos >= end_pos { - chunk_end_pos = end_pos; - } else { - // If sending not up to end pos, round down to page boundary to - // avoid breaking WAL record not at page boundary, as protocol - // demands. See walsender.c (XLogSendPhysical). 
- chunk_end_pos = chunk_end_pos - .checked_sub(chunk_end_pos.block_offset()) - .unwrap(); - } - let send_size = (chunk_end_pos.0 - start_pos.0) as usize; - let buffer = &mut buffer[..send_size]; - let send_size: usize; - { - // If uncommitted part is being pulled, check that the term is - // still the expected one. - let _term_guard = if let Some(t) = term { - Some(tli.acquire_term(t).await?) - } else { - None - }; - // Read WAL into buffer. send_size can be additionally capped to - // segment boundary here. - send_size = wal_reader.read(buffer).await? - }; - let wal = Bytes::copy_from_slice(&buffer[..send_size]); - - yield WalBytes { - wal, - wal_start_lsn: start_pos, - wal_end_lsn: start_pos + send_size as u64, - available_wal_end_lsn: end_pos - }; - - start_pos += send_size as u64; - } - }) +impl WalOrReset { + pub(crate) fn get_wal(self) -> Option> { + match self { + WalOrReset::Wal(wal) => Some(wal), + WalOrReset::Reset(_) => None, + } + } +} + +impl StreamingWalReader { + pub(crate) fn new( + tli: WalResidentTimeline, + term: Option, + start: Lsn, + end: Lsn, + end_watch: EndWatch, + buffer_size: usize, + ) -> Self { + let (start_changed_tx, start_changed_rx) = tokio::sync::watch::channel(start); + + let state = WalReaderStreamState { + tli, + wal_reader: PositionedWalReader { + start, + end, + reader: None, + }, + term, + end_watch, + buffer: vec![0; buffer_size], + buffer_size, + }; + + // When a change notification is received while polling the internal + // reader, stop polling the read future and service the change. + let stream = futures::stream::unfold( + (state, start_changed_rx), + |(mut state, mut rx)| async move { + let wal_or_reset = tokio::select! 
{ + read_res = state.read() => { WalOrReset::Wal(read_res) }, + changed_res = rx.changed() => { + if changed_res.is_err() { + return None; + } + + let new_start_pos = rx.borrow_and_update(); + WalOrReset::Reset(*new_start_pos) + } + }; + + if let WalOrReset::Reset(lsn) = wal_or_reset { + state.wal_reader.start = lsn; + state.wal_reader.reader = None; + } + + Some((wal_or_reset, (state, rx))) + }, + ) + .boxed(); + + Self { + stream, + start_changed_tx, + } + } + + /// Reset the stream to a given position. + pub(crate) async fn reset(&mut self, start: Lsn) { + self.start_changed_tx.send(start).unwrap(); + while let Some(wal_or_reset) = self.stream.next().await { + match wal_or_reset { + WalOrReset::Reset(at) => { + // Stream confirmed the reset. + // There may only be one ongoing reset at any given time, + // hence the assertion. + assert_eq!(at, start); + break; + } + WalOrReset::Wal(_) => { + // Ignore wal generated before reset was handled + } + } + } + } +} + +impl Stream for StreamingWalReader { + type Item = WalOrReset; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.stream).poll_next(cx) + } +} + +struct WalReaderStreamState { + tli: WalResidentTimeline, + wal_reader: PositionedWalReader, + term: Option, + end_watch: EndWatch, + buffer: Vec, + buffer_size: usize, +} + +impl WalReaderStreamState { + async fn read(&mut self) -> anyhow::Result { + // Create reader if needed + if self.wal_reader.reader.is_none() { + self.wal_reader.reader = Some(self.tli.get_walreader(self.wal_reader.start).await?); + } + + let have_something_to_send = self.wal_reader.end > self.wal_reader.start; + if !have_something_to_send { + tracing::debug!( + "Waiting for wal: start={}, end={}", + self.wal_reader.end, + self.wal_reader.start + ); + self.wal_reader.end = self + .end_watch + .wait_for_lsn(self.wal_reader.start, self.term) + .await?; + tracing::debug!( + "Done waiting for wal: start={}, end={}", + self.wal_reader.end, +
self.wal_reader.start + ); + } + + assert!( + self.wal_reader.end > self.wal_reader.start, + "nothing to send after waiting for WAL" + ); + + // Calculate chunk size + let mut chunk_end_pos = self.wal_reader.start + self.buffer_size as u64; + if chunk_end_pos >= self.wal_reader.end { + chunk_end_pos = self.wal_reader.end; + } else { + chunk_end_pos = chunk_end_pos + .checked_sub(chunk_end_pos.block_offset()) + .unwrap(); + } + + let send_size = (chunk_end_pos.0 - self.wal_reader.start.0) as usize; + let buffer = &mut self.buffer[..send_size]; + + // Read WAL + let send_size = { + let _term_guard = if let Some(t) = self.term { + Some(self.tli.acquire_term(t).await?) + } else { + None + }; + self.wal_reader + .reader + .as_mut() + .unwrap() + .read(buffer) + .await? + }; + + let wal = Bytes::copy_from_slice(&buffer[..send_size]); + let result = WalBytes { + wal, + wal_start_lsn: self.wal_reader.start, + wal_end_lsn: self.wal_reader.start + send_size as u64, + available_wal_end_lsn: self.wal_reader.end, + }; + + self.wal_reader.start += send_size as u64; + + Ok(result) + } +} + +#[cfg(test)] +mod tests { + use std::str::FromStr; + + use futures::StreamExt; + use postgres_ffi::MAX_SEND_SIZE; + use utils::{ + id::{NodeId, TenantTimelineId}, + lsn::Lsn, + }; + + use crate::{test_utils::Env, wal_reader_stream::StreamingWalReader}; + + #[tokio::test] + async fn test_streaming_wal_reader_reset() { + let _ = env_logger::builder().is_test(true).try_init(); + + const SIZE: usize = 8 * 1024; + const MSG_COUNT: usize = 200; + + let start_lsn = Lsn::from_str("0/149FD18").unwrap(); + let env = Env::new(true).unwrap(); + let tli = env + .make_timeline(NodeId(1), TenantTimelineId::generate(), start_lsn) + .await + .unwrap(); + + let resident_tli = tli.wal_residence_guard().await.unwrap(); + let end_watch = Env::write_wal(tli, start_lsn, SIZE, MSG_COUNT) + .await + .unwrap(); + let end_pos = end_watch.get(); + + tracing::info!("Doing first round of reads ..."); + + let mut 
streaming_wal_reader = StreamingWalReader::new( + resident_tli, + None, + start_lsn, + end_pos, + end_watch, + MAX_SEND_SIZE, + ); + + let mut before_reset = Vec::new(); + while let Some(wor) = streaming_wal_reader.next().await { + let wal = wor.get_wal().unwrap().unwrap(); + let stop = wal.available_wal_end_lsn == wal.wal_end_lsn; + before_reset.push(wal); + + if stop { + break; + } + } + + tracing::info!("Resetting the WAL stream ..."); + + streaming_wal_reader.reset(start_lsn).await; + + tracing::info!("Doing second round of reads ..."); + + let mut after_reset = Vec::new(); + while let Some(wor) = streaming_wal_reader.next().await { + let wal = wor.get_wal().unwrap().unwrap(); + let stop = wal.available_wal_end_lsn == wal.wal_end_lsn; + after_reset.push(wal); + + if stop { + break; + } + } + + assert_eq!(before_reset, after_reset); } } diff --git a/safekeeper/tests/walproposer_sim/safekeeper.rs b/safekeeper/tests/walproposer_sim/safekeeper.rs index a99de71a04..e0d593851e 100644 --- a/safekeeper/tests/walproposer_sim/safekeeper.rs +++ b/safekeeper/tests/walproposer_sim/safekeeper.rs @@ -178,6 +178,8 @@ pub fn run_server(os: NodeOs, disk: Arc) -> Result<()> { control_file_save_interval: Duration::from_secs(1), partial_backup_concurrency: 1, eviction_min_resident: Duration::ZERO, + wal_reader_fanout: false, + max_delta_for_fanout: None, }; let mut global = GlobalMap::new(disk, conf.clone())?;