diff --git a/safekeeper/src/send_interpreted_wal.rs b/safekeeper/src/send_interpreted_wal.rs
index ea09ce364d..b57cc8001d 100644
--- a/safekeeper/src/send_interpreted_wal.rs
+++ b/safekeeper/src/send_interpreted_wal.rs
@@ -120,6 +120,20 @@ pub enum InterpretedWalReaderError {
     WalStreamClosed,
 }
 
+enum CurrentPositionUpdate {
+    Reset(Lsn),
+    NotReset(Lsn),
+}
+
+impl CurrentPositionUpdate {
+    fn current_position(&self) -> Lsn {
+        match self {
+            CurrentPositionUpdate::Reset(lsn) => *lsn,
+            CurrentPositionUpdate::NotReset(lsn) => *lsn,
+        }
+    }
+}
+
 impl InterpretedWalReaderState {
     fn current_position(&self) -> Option<Lsn> {
         match self {
@@ -129,6 +143,26 @@
             InterpretedWalReaderState::Done => None,
         }
     }
+
+    // Reset the current position of the WAL reader if the requested starting position
+    // of the new shard is smaller than the current value.
+    fn maybe_reset(&mut self, new_shard_start_pos: Lsn) -> CurrentPositionUpdate {
+        match self {
+            InterpretedWalReaderState::Running {
+                current_position, ..
+            } => {
+                if new_shard_start_pos < *current_position {
+                    *current_position = new_shard_start_pos;
+                    CurrentPositionUpdate::Reset(*current_position)
+                } else {
+                    CurrentPositionUpdate::NotReset(*current_position)
+                }
+            }
+            InterpretedWalReaderState::Done => {
+                panic!("maybe_reset called on finished reader")
+            }
+        }
+    }
 }
 
 pub(crate) struct AttachShardNotification {
@@ -410,15 +444,24 @@ impl InterpretedWalReader {
 
                    };
                    senders.push(ShardSenderState { sender_id: new_sender_id, tx: sender, next_record_lsn: start_pos});
-                    let current_pos = self.state.read().unwrap().current_position().unwrap();
-                    if start_pos < current_pos {
-                        self.wal_stream.reset(start_pos).await;
-                        wal_decoder = WalStreamDecoder::new(start_pos, self.pg_version);
-                    }
+
+                    // If the shard is subscribing below the current position, then we need
+                    // to update the cursor that tracks where we are at in the WAL
+                    // ([`Self::state`]) and reset the WAL stream itself
+                    // ([`Self::wal_stream`]). This must be done atomically from the POV of
+                    // anything outside the select statement.
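+                    // [`InterpretedWalReaderState::maybe_reset`] makes that possible: it
+                    // performs the comparison and the position update under a single write
+                    // lock and returns the outcome, so the stream reset below acts on the
+                    // same snapshot that updated the state.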
+                    let position_reset = self.state.write().unwrap().maybe_reset(start_pos);
+                    match position_reset {
+                        CurrentPositionUpdate::Reset(to) => {
+                            self.wal_stream.reset(to).await;
+                            wal_decoder = WalStreamDecoder::new(to, self.pg_version);
+                        },
+                        CurrentPositionUpdate::NotReset(_) => {}
+                    };
 
                    tracing::info!(
                        "Added shard sender {} with start_pos={} current_pos={}",
-                        ShardSenderId::new(shard_id, new_sender_id), start_pos, current_pos
+                        ShardSenderId::new(shard_id, new_sender_id), start_pos, position_reset.current_position()
                    );
                }
            }
@@ -584,7 +627,7 @@ mod tests {
             .unwrap();
         let resident_tli = tli.wal_residence_guard().await.unwrap();
 
-        let end_watch = Env::write_wal(tli, start_lsn, SIZE, MSG_COUNT)
+        let end_watch = Env::write_wal(tli, start_lsn, SIZE, MSG_COUNT, None)
             .await
             .unwrap();
         let end_pos = end_watch.get();
@@ -715,7 +758,6 @@ mod tests {
         const MSG_COUNT: usize = 200;
         const PG_VERSION: u32 = 17;
         const SHARD_COUNT: u8 = 2;
-        const ATTACHED_SHARDS: u8 = 4;
 
         let start_lsn = Lsn::from_str("0/149FD18").unwrap();
         let env = Env::new(true).unwrap();
@@ -725,9 +767,11 @@ mod tests {
             .unwrap();
         let resident_tli = tli.wal_residence_guard().await.unwrap();
 
-        let end_watch = Env::write_wal(tli, start_lsn, SIZE, MSG_COUNT)
-            .await
-            .unwrap();
+        let mut next_record_lsns = Vec::default();
+        let end_watch =
+            Env::write_wal(tli, start_lsn, SIZE, MSG_COUNT, Some(&mut next_record_lsns))
+                .await
+                .unwrap();
         let end_pos = end_watch.get();
 
         let streaming_wal_reader = StreamingWalReader::new(
@@ -746,38 +790,71 @@ mod tests {
         )
         .unwrap();
 
-        let (tx, rx) = tokio::sync::mpsc::channel::<Batch>(MSG_COUNT * 2);
-        let mut batch_receivers = vec![rx];
+        struct Sender {
+            tx: Option<tokio::sync::mpsc::Sender<Batch>>,
+            rx: tokio::sync::mpsc::Receiver<Batch>,
+            shard: ShardIdentity,
+            start_lsn: Lsn,
+            received_next_record_lsns: Vec<Lsn>,
+        }
+
+        impl Sender {
+            fn new(start_lsn: Lsn, shard: ShardIdentity) -> Self {
+                let (tx, rx) = tokio::sync::mpsc::channel::<Batch>(MSG_COUNT * 2);
+                Self {
+                    tx: Some(tx),
+                    rx,
+                    shard,
+                    start_lsn,
+                    received_next_record_lsns: Vec::default(),
+                }
+            }
+        }
+
+        assert!(next_record_lsns.len() > 7);
+        let start_lsns = vec![
+            next_record_lsns[5],
+            next_record_lsns[1],
+            next_record_lsns[3],
+        ];
+        let mut senders = start_lsns
+            .into_iter()
+            .map(|lsn| Sender::new(lsn, shard_0))
+            .collect::<Vec<_>>();
 
+        let first_sender = senders.first_mut().unwrap();
         let handle = InterpretedWalReader::spawn(
             streaming_wal_reader,
-            start_lsn,
-            tx,
-            shard_0,
+            first_sender.start_lsn,
+            first_sender.tx.take().unwrap(),
+            first_sender.shard,
             PG_VERSION,
             &Some("pageserver".to_string()),
         );
 
-        for _ in 0..(ATTACHED_SHARDS - 1) {
-            let (tx, rx) = tokio::sync::mpsc::channel::<Batch>(MSG_COUNT * 2);
-            handle.fanout(shard_0, tx, start_lsn).unwrap();
-            batch_receivers.push(rx);
+        for sender in senders.iter_mut().skip(1) {
+            handle
+                .fanout(sender.shard, sender.tx.take().unwrap(), sender.start_lsn)
+                .unwrap();
         }
 
-        loop {
-            let batch = batch_receivers.first_mut().unwrap().recv().await.unwrap();
-            for rx in batch_receivers.iter_mut().skip(1) {
-                let other_batch = rx.recv().await.unwrap();
-
-                assert_eq!(batch.wal_end_lsn, other_batch.wal_end_lsn);
-                assert_eq!(
-                    batch.available_wal_end_lsn,
-                    other_batch.available_wal_end_lsn
+        for sender in senders.iter_mut() {
+            loop {
+                let batch = sender.rx.recv().await.unwrap();
+                tracing::info!(
+                    "Sender with start_lsn={} received batch ending at {} with {} records",
+                    sender.start_lsn,
+                    batch.wal_end_lsn,
+                    batch.records.records.len()
                 );
-            }
 
-            if batch.wal_end_lsn == batch.available_wal_end_lsn {
-                break;
+                for rec in batch.records.records {
+                    sender.received_next_record_lsns.push(rec.next_record_lsn);
+                }
+
+                if batch.wal_end_lsn == batch.available_wal_end_lsn {
+                    break;
+                }
             }
         }
 
@@ -792,5 +869,20 @@ mod tests {
         }
 
         assert!(done);
+
+        for sender in senders {
+            tracing::info!(
+                "Validating records received by sender with start_lsn={}",
+                sender.start_lsn
+            );
+
+            assert!(sender.received_next_record_lsns.is_sorted());
+            let expected = next_record_lsns
+                .iter()
+                .filter(|lsn| **lsn > sender.start_lsn)
+                .copied()
+                .collect::<Vec<_>>();
+            assert_eq!(sender.received_next_record_lsns, expected);
+        }
     }
 }
diff --git a/safekeeper/src/test_utils.rs b/safekeeper/src/test_utils.rs
index 4e851c5b3d..79ceddd366 100644
--- a/safekeeper/src/test_utils.rs
+++ b/safekeeper/src/test_utils.rs
@@ -122,6 +122,7 @@ impl Env {
         start_lsn: Lsn,
         msg_size: usize,
         msg_count: usize,
+        mut next_record_lsns: Option<&mut Vec<Lsn>>,
     ) -> anyhow::Result<EndWatch> {
         let (msg_tx, msg_rx) = tokio::sync::mpsc::channel(receive_wal::MSG_QUEUE_SIZE);
         let (reply_tx, mut reply_rx) = tokio::sync::mpsc::channel(receive_wal::REPLY_QUEUE_SIZE);
@@ -130,7 +131,7 @@ impl Env {
 
         WalAcceptor::spawn(tli.wal_residence_guard().await?, msg_rx, reply_tx, Some(0));
 
-        let prefix = c"p";
+        let prefix = c"neon-file:";
         let prefixlen = prefix.to_bytes_with_nul().len();
         assert!(msg_size >= prefixlen);
         let message = vec![0; msg_size - prefixlen];
@@ -139,6 +140,9 @@ impl Env {
             &mut WalGenerator::new(LogicalMessageGenerator::new(prefix, &message), start_lsn);
         for _ in 0..msg_count {
             let (lsn, record) = walgen.next().unwrap();
+            if let Some(ref mut lsns) = next_record_lsns {
+                lsns.push(lsn);
+            }
 
             let req = AppendRequest {
                 h: AppendRequestHeader {
diff --git a/safekeeper/src/wal_reader_stream.rs b/safekeeper/src/wal_reader_stream.rs
index adac6067da..a0dd571a34 100644
--- a/safekeeper/src/wal_reader_stream.rs
+++ b/safekeeper/src/wal_reader_stream.rs
@@ -246,7 +246,7 @@ mod tests {
             .unwrap();
         let resident_tli = tli.wal_residence_guard().await.unwrap();
 
-        let end_watch = Env::write_wal(tli, start_lsn, SIZE, MSG_COUNT)
+        let end_watch = Env::write_wal(tli, start_lsn, SIZE, MSG_COUNT, None)
             .await
             .unwrap();
         let end_pos = end_watch.get();