diff --git a/safekeeper/src/send_interpreted_wal.rs b/safekeeper/src/send_interpreted_wal.rs index fb06339604..0662bb9518 100644 --- a/safekeeper/src/send_interpreted_wal.rs +++ b/safekeeper/src/send_interpreted_wal.rs @@ -100,7 +100,12 @@ struct ShardSenderState { /// State of [`InterpretedWalReader`] visible outside of the task running it. #[derive(Debug)] pub(crate) enum InterpretedWalReaderState { - Running { current_position: Lsn }, + Running { + current_position: Lsn, + /// Tracks the start of the PG WAL LSN from which the current batch of + /// interpreted records originated. + current_batch_wal_start: Option, + }, Done, } @@ -122,14 +127,21 @@ pub enum InterpretedWalReaderError { } enum CurrentPositionUpdate { - Reset(Lsn), + Reset { from: Lsn, to: Lsn }, NotReset(Lsn), } impl CurrentPositionUpdate { fn current_position(&self) -> Lsn { match self { - CurrentPositionUpdate::Reset(lsn) => *lsn, + CurrentPositionUpdate::Reset { from: _, to } => *to, + CurrentPositionUpdate::NotReset(lsn) => *lsn, + } + } + + fn previous_position(&self) -> Lsn { + match self { + CurrentPositionUpdate::Reset { from, to: _ } => *from, CurrentPositionUpdate::NotReset(lsn) => *lsn, } } @@ -145,16 +157,33 @@ impl InterpretedWalReaderState { } } + #[cfg(test)] + fn current_batch_wal_start(&self) -> Option { + match self { + InterpretedWalReaderState::Running { + current_batch_wal_start, + .. + } => *current_batch_wal_start, + InterpretedWalReaderState::Done => None, + } + } + // Reset the current position of the WAL reader if the requested starting position // of the new shard is smaller than the current value. fn maybe_reset(&mut self, new_shard_start_pos: Lsn) -> CurrentPositionUpdate { match self { InterpretedWalReaderState::Running { - current_position, .. + current_position, + current_batch_wal_start, } => { if new_shard_start_pos < *current_position { + let from = *current_position; *current_position = new_shard_start_pos; - CurrentPositionUpdate::Reset(*current_position) + *current_batch_wal_start = None; + CurrentPositionUpdate::Reset { + from, + to: *current_position, + } } else { CurrentPositionUpdate::NotReset(*current_position) } @@ -164,6 +193,47 @@ impl InterpretedWalReaderState { } } } + + fn update_current_batch_wal_start(&mut self, lsn: Lsn) { + match self { + InterpretedWalReaderState::Running { + current_batch_wal_start, + .. + } => { + if current_batch_wal_start.is_none() { + *current_batch_wal_start = Some(lsn); + } + } + InterpretedWalReaderState::Done => { + panic!("update_current_batch_wal_start called on finished reader") + } + } + } + + fn take_current_batch_wal_start(&mut self) -> Lsn { + match self { + InterpretedWalReaderState::Running { + current_batch_wal_start, + .. + } => current_batch_wal_start.take().unwrap(), + InterpretedWalReaderState::Done => { + panic!("take_current_batch_wal_start called on finished reader") + } + } + } + + fn update_current_position(&mut self, lsn: Lsn) { + match self { + InterpretedWalReaderState::Running { + current_position, .. + } => { + *current_position = lsn; + } + InterpretedWalReaderState::Done => { + panic!("update_current_position called on finished reader") + } + } + } } pub(crate) struct AttachShardNotification { @@ -184,6 +254,7 @@ impl InterpretedWalReader { ) -> InterpretedWalReaderHandle { let state = Arc::new(std::sync::RwLock::new(InterpretedWalReaderState::Running { current_position: start_pos, + current_batch_wal_start: None, })); let (shard_notification_tx, shard_notification_rx) = tokio::sync::mpsc::unbounded_channel(); @@ -237,9 +308,13 @@ impl InterpretedWalReader { tx: tokio::sync::mpsc::Sender, shard: ShardIdentity, pg_version: u32, + shard_notification_rx: Option< + tokio::sync::mpsc::UnboundedReceiver, + >, ) -> InterpretedWalReader { let state = Arc::new(std::sync::RwLock::new(InterpretedWalReaderState::Running { current_position: start_pos, + current_batch_wal_start: None, })); InterpretedWalReader { @@ -252,7 +327,7 @@ impl InterpretedWalReader { next_record_lsn: start_pos, }], )]), - shard_notification_rx: None, + shard_notification_rx, state: state.clone(), pg_version, } @@ -295,10 +370,6 @@ impl InterpretedWalReader { let mut wal_decoder = WalStreamDecoder::new(start_pos, self.pg_version); - // Tracks the start of the PG WAL LSN from which the current batch of - // interpreted records originated. - let mut current_batch_wal_start_lsn: Option = None; - loop { tokio::select! { // Main branch for reading WAL and forwarding it @@ -319,11 +390,7 @@ impl InterpretedWalReader { } }; - // We will already have a value if the previous chunks of WAL - // did not decode into anything useful. - if current_batch_wal_start_lsn.is_none() { - current_batch_wal_start_lsn = Some(wal_start_lsn); - } + self.state.write().unwrap().update_current_batch_wal_start(wal_start_lsn); wal_decoder.feed_bytes(&wal); @@ -380,16 +447,11 @@ impl InterpretedWalReader { // Update the current position such that new receivers can decide // whether to attach to us or spawn a new WAL reader. - match &mut *self.state.write().unwrap() { - InterpretedWalReaderState::Running { current_position, .. } => { - *current_position = max_next_record_lsn; - }, - InterpretedWalReaderState::Done => { - unreachable!() - } - } - - let batch_wal_start_lsn = current_batch_wal_start_lsn.take().unwrap(); + let batch_wal_start_lsn = { + let mut guard = self.state.write().unwrap(); + guard.update_current_position(max_next_record_lsn); + guard.take_current_batch_wal_start() + }; // Send interpreted records downstream. Anything that has already been seen // by a shard is filtered out. @@ -480,7 +542,7 @@ impl InterpretedWalReader { // anything outside the select statement. let position_reset = self.state.write().unwrap().maybe_reset(start_pos); match position_reset { - CurrentPositionUpdate::Reset(to) => { + CurrentPositionUpdate::Reset { from: _, to } => { self.wal_stream.reset(to).await; wal_decoder = WalStreamDecoder::new(to, self.pg_version); }, @@ -488,14 +550,22 @@ impl InterpretedWalReader { }; tracing::info!( - "Added shard sender {} with start_pos={} current_pos={}", - ShardSenderId::new(shard_id, new_sender_id), start_pos, position_reset.current_position() + "Added shard sender {} with start_pos={} previous_pos={} current_pos={}", + ShardSenderId::new(shard_id, new_sender_id), + start_pos, + position_reset.previous_position(), + position_reset.current_position(), ); } } } } } + + #[cfg(test)] + fn state(&self) -> Arc> { + self.state.clone() + } } impl InterpretedWalReaderHandle { @@ -633,7 +703,7 @@ mod tests { }; use crate::{ - send_interpreted_wal::{Batch, InterpretedWalReader}, + send_interpreted_wal::{AttachShardNotification, Batch, InterpretedWalReader}, test_utils::Env, wal_reader_stream::StreamingWalReader, }; @@ -913,4 +983,123 @@ mod tests { assert_eq!(sender.received_next_record_lsns, expected); } } + + #[tokio::test] + async fn test_batch_start_tracking_on_reset() { + // When the WAL stream is reset to an older LSN, + // the current batch start LSN should be invalidated. + // This test constructs such a scenario: + // 1. Shard 0 is reading somewhere ahead + // 2. Reader reads some WAL, but does not decode a full record (partial read) + // 3. Shard 1 attaches to the reader and resets it to an older LSN + // 4. Shard 1 should get the correct batch WAL start LSN + let _ = env_logger::builder().is_test(true).try_init(); + + const SIZE: usize = 64 * 1024; + const MSG_COUNT: usize = 10; + const PG_VERSION: u32 = 17; + const SHARD_COUNT: u8 = 2; + const WAL_READER_BATCH_SIZE: usize = 8192; + + let start_lsn = Lsn::from_str("0/149FD18").unwrap(); + let shard_0_start_lsn = Lsn::from_str("0/14AFE10").unwrap(); + let env = Env::new(true).unwrap(); + let tli = env + .make_timeline(NodeId(1), TenantTimelineId::generate(), start_lsn) + .await + .unwrap(); + + let resident_tli = tli.wal_residence_guard().await.unwrap(); + let end_watch = Env::write_wal(tli, start_lsn, SIZE, MSG_COUNT, None) + .await + .unwrap(); + let end_pos = end_watch.get(); + + let streaming_wal_reader = StreamingWalReader::new( + resident_tli, + None, + shard_0_start_lsn, + end_pos, + end_watch, + WAL_READER_BATCH_SIZE, + ); + + let shard_0 = ShardIdentity::new( + ShardNumber(0), + ShardCount(SHARD_COUNT), + ShardStripeSize::default(), + ) + .unwrap(); + + let shard_1 = ShardIdentity::new( + ShardNumber(1), + ShardCount(SHARD_COUNT), + ShardStripeSize::default(), + ) + .unwrap(); + + let mut shards = HashMap::new(); + + for shard_number in 0..SHARD_COUNT { + let shard_id = ShardIdentity::new( + ShardNumber(shard_number), + ShardCount(SHARD_COUNT), + ShardStripeSize::default(), + ) + .unwrap(); + let (tx, rx) = tokio::sync::mpsc::channel::(MSG_COUNT * 2); + shards.insert(shard_id, (Some(tx), Some(rx))); + } + + let shard_0_tx = shards.get_mut(&shard_0).unwrap().0.take().unwrap(); + + let (shard_notification_tx, shard_notification_rx) = tokio::sync::mpsc::unbounded_channel(); + + let reader = InterpretedWalReader::new( + streaming_wal_reader, + shard_0_start_lsn, + shard_0_tx, + shard_0, + PG_VERSION, + Some(shard_notification_rx), + ); + + let reader_state = reader.state(); + let mut reader_fut = std::pin::pin!(reader.run(start_lsn, &None)); + loop { + let poll = futures::poll!(reader_fut.as_mut()); + assert!(poll.is_pending()); + + let guard = reader_state.read().unwrap(); + if guard.current_batch_wal_start().is_some() { + break; + } + } + + shard_notification_tx + .send(AttachShardNotification { + shard_id: shard_1, + sender: shards.get_mut(&shard_1).unwrap().0.take().unwrap(), + start_pos: start_lsn, + }) + .unwrap(); + + let mut shard_1_rx = shards.get_mut(&shard_1).unwrap().1.take().unwrap(); + loop { + let poll = futures::poll!(reader_fut.as_mut()); + assert!(poll.is_pending()); + + let try_recv_res = shard_1_rx.try_recv(); + match try_recv_res { + Ok(batch) => { + assert_eq!(batch.records.raw_wal_start_lsn.unwrap(), start_lsn); + break; + } + Err(tokio::sync::mpsc::error::TryRecvError::Empty) => {} + Err(tokio::sync::mpsc::error::TryRecvError::Disconnected) => { + unreachable!(); + } + } + } + } } diff --git a/safekeeper/src/send_wal.rs b/safekeeper/src/send_wal.rs index 4a4a74a0fd..72b1fd9fc3 100644 --- a/safekeeper/src/send_wal.rs +++ b/safekeeper/src/send_wal.rs @@ -624,8 +624,9 @@ impl SafekeeperPostgresHandler { MAX_SEND_SIZE, ); - let reader = - InterpretedWalReader::new(wal_reader, start_pos, tx, shard, pg_version); + let reader = InterpretedWalReader::new( + wal_reader, start_pos, tx, shard, pg_version, None, + ); let sender = InterpretedWalSender { format,