postgres_ffi/waldecoder: introduce explicit enum State

Previously the decoder's state was emulated with a combination of nullable fields (`contlen`, `padlen`, `recordbuf`).
This change should make the logic more readable.
This commit is contained in:
Egor Suvorov
2022-06-25 02:40:42 +03:00
committed by Arseny Sher
parent 07bb7a2afe
commit a7bf60631f

View File

@@ -18,19 +18,25 @@ use bytes::{Buf, BufMut, Bytes, BytesMut};
use crc32c::*;
use log::*;
use std::cmp::min;
use std::num::NonZeroU32;
use thiserror::Error;
use utils::lsn::Lsn;
/// Explicit state of the `WalStreamDecoder` state machine.
enum State {
/// No record is in progress: the next bytes in the input are expected to
/// start a new record (beginning with its `xl_tot_len` field).
WaitingForRecord,
/// A record that did not fit the fast path is being reassembled from pieces.
ReassemblingRecord {
/// Bytes of the record collected so far.
recordbuf: BytesMut,
/// Number of bytes of the record that still have to be read. Never zero:
/// once it would reach zero the record is complete and this state is left.
/// Page headers seen in this state must carry the same value in `xlp_rem_len`.
contlen: NonZeroU32,
},
/// Input is discarded without interpretation until `skip_until_lsn`.
/// Entered after a record has been completed, to skip up to the LSN of the
/// next record.
SkippingEverything {
/// LSN at which skipping stops and the decoder returns to `WaitingForRecord`.
skip_until_lsn: Lsn,
},
}
pub struct WalStreamDecoder {
lsn: Lsn,
contlen: u32,
padlen: u32,
inputbuf: BytesMut,
/// buffer used to reassemble records that cross page boundaries.
recordbuf: BytesMut,
state: State,
}
#[derive(Error, Debug, Clone)]
@@ -48,12 +54,8 @@ impl WalStreamDecoder {
pub fn new(lsn: Lsn) -> WalStreamDecoder {
WalStreamDecoder {
lsn,
contlen: 0,
padlen: 0,
inputbuf: BytesMut::new(),
recordbuf: BytesMut::new(),
state: State::WaitingForRecord,
}
}
@@ -80,26 +82,39 @@ impl WalStreamDecoder {
hdr.xlp_pageaddr, self.lsn
));
}
if self.contlen == 0 {
if hdr.xlp_info & XLP_FIRST_IS_CONTRECORD != 0 {
return Err(
"invalid xlog page header: unexpected XLP_FIRST_IS_CONTRECORD".into(),
);
match self.state {
State::WaitingForRecord => {
if hdr.xlp_info & XLP_FIRST_IS_CONTRECORD != 0 {
return Err(
"invalid xlog page header: unexpected XLP_FIRST_IS_CONTRECORD".into(),
);
}
if hdr.xlp_rem_len != 0 {
return Err(format!(
"invalid xlog page header: xlp_rem_len={}, but it's not a contrecord",
hdr.xlp_rem_len
));
}
}
} else {
if hdr.xlp_info & XLP_FIRST_IS_CONTRECORD == 0 {
return Err(
"invalid xlog page header: XLP_FIRST_IS_CONTRECORD expected, not found"
.into(),
);
State::ReassemblingRecord { contlen, .. } => {
if hdr.xlp_info & XLP_FIRST_IS_CONTRECORD == 0 {
return Err(
"invalid xlog page header: XLP_FIRST_IS_CONTRECORD expected, not found"
.into(),
);
}
if hdr.xlp_rem_len != contlen.get() {
return Err(format!(
"invalid xlog page header: xlp_rem_len={}, expected {}",
hdr.xlp_rem_len,
contlen.get()
));
}
}
}
if hdr.xlp_rem_len != self.contlen {
return Err(format!(
"invalid xlog page header: xlp_rem_len={}, expected {}",
hdr.xlp_rem_len, self.contlen
));
}
State::SkippingEverything { .. } => {
panic!("Should not be validating page header in the SkippingEverything state");
}
};
Ok(())
};
validate_impl().map_err(|msg| WalDecodeError { msg, lsn: self.lsn })
@@ -114,115 +129,121 @@ impl WalStreamDecoder {
/// Err(WalDecodeError): an error occurred while decoding, meaning the input was invalid.
///
pub fn poll_decode(&mut self) -> Result<Option<(Lsn, Bytes)>, WalDecodeError> {
let recordbuf;
// Run state machine that validates page headers, and reassembles records
// that cross page boundaries.
loop {
// parse and verify page boundaries as we go
if self.padlen > 0 {
// We should first skip padding, as we may have to skip some page headers if we're processing the XLOG_SWITCH record.
if self.inputbuf.remaining() < self.padlen as usize {
return Ok(None);
}
// However, we may have to skip some page headers if we're processing the XLOG_SWITCH record or skipping padding for whatever reason.
match self.state {
State::WaitingForRecord | State::ReassemblingRecord { .. } => {
if self.lsn.segment_offset(pg_constants::WAL_SEGMENT_SIZE) == 0 {
// parse long header
// skip padding
self.inputbuf.advance(self.padlen as usize);
self.lsn += self.padlen as u64;
self.padlen = 0;
} else if self.lsn.segment_offset(pg_constants::WAL_SEGMENT_SIZE) == 0 {
// parse long header
if self.inputbuf.remaining() < XLOG_SIZE_OF_XLOG_LONG_PHD {
return Ok(None);
}
if self.inputbuf.remaining() < XLOG_SIZE_OF_XLOG_LONG_PHD {
return Ok(None);
}
let hdr = XLogLongPageHeaderData::from_bytes(&mut self.inputbuf).map_err(
|e| WalDecodeError {
msg: format!("long header deserialization failed {}", e),
lsn: self.lsn,
},
)?;
let hdr = XLogLongPageHeaderData::from_bytes(&mut self.inputbuf).map_err(|e| {
WalDecodeError {
msg: format!("long header deserialization failed {}", e),
lsn: self.lsn,
self.validate_page_header(&hdr.std)?;
self.lsn += XLOG_SIZE_OF_XLOG_LONG_PHD as u64;
} else if self.lsn.block_offset() == 0 {
if self.inputbuf.remaining() < XLOG_SIZE_OF_XLOG_SHORT_PHD {
return Ok(None);
}
let hdr =
XLogPageHeaderData::from_bytes(&mut self.inputbuf).map_err(|e| {
WalDecodeError {
msg: format!("header deserialization failed {}", e),
lsn: self.lsn,
}
})?;
self.validate_page_header(&hdr)?;
self.lsn += XLOG_SIZE_OF_XLOG_SHORT_PHD as u64;
}
})?;
self.validate_page_header(&hdr.std)?;
self.lsn += XLOG_SIZE_OF_XLOG_LONG_PHD as u64;
continue;
} else if self.lsn.block_offset() == 0 {
if self.inputbuf.remaining() < XLOG_SIZE_OF_XLOG_SHORT_PHD {
return Ok(None);
}
let hdr = XLogPageHeaderData::from_bytes(&mut self.inputbuf).map_err(|e| {
WalDecodeError {
msg: format!("header deserialization failed {}", e),
lsn: self.lsn,
State::SkippingEverything { .. } => {}
}
match &mut self.state {
State::WaitingForRecord => {
// need to have at least the xl_tot_len field
if self.inputbuf.remaining() < 4 {
return Ok(None);
}
})?;
self.validate_page_header(&hdr)?;
self.lsn += XLOG_SIZE_OF_XLOG_SHORT_PHD as u64;
continue;
} else if self.contlen == 0 {
assert!(self.recordbuf.is_empty());
// need to have at least the xl_tot_len field
if self.inputbuf.remaining() < 4 {
return Ok(None);
// peek xl_tot_len at the beginning of the record.
// FIXME: assumes little-endian
let xl_tot_len = (&self.inputbuf[0..4]).get_u32_le();
if (xl_tot_len as usize) < XLOG_SIZE_OF_XLOG_RECORD {
return Err(WalDecodeError {
msg: format!("invalid xl_tot_len {}", xl_tot_len),
lsn: self.lsn,
});
}
// Fast path for the common case that the whole record fits on the page.
let pageleft = self.lsn.remaining_in_block() as u32;
if self.inputbuf.remaining() >= xl_tot_len as usize && xl_tot_len <= pageleft {
self.lsn += xl_tot_len as u64;
let recordbuf = self.inputbuf.copy_to_bytes(xl_tot_len as usize);
return Ok(Some(self.complete_record(recordbuf)?));
} else {
// Need to assemble the record from pieces. Remember the size of the
// record, and loop back. On the next iteration, we will reach the
// `State::ReassemblingRecord` arm below, and copy the part of the record
// that was on this page to 'recordbuf'. Subsequent iterations will skip
// page headers, and append the continuations from the next pages to 'recordbuf'.
self.state = State::ReassemblingRecord {
recordbuf: BytesMut::with_capacity(xl_tot_len as usize),
contlen: NonZeroU32::new(xl_tot_len).unwrap(),
}
}
}
State::ReassemblingRecord { recordbuf, contlen } => {
// we're continuing a record, possibly from previous page.
let pageleft = self.lsn.remaining_in_block() as u32;
// peek xl_tot_len at the beginning of the record.
// FIXME: assumes little-endian
let xl_tot_len = (&self.inputbuf[0..4]).get_u32_le();
if (xl_tot_len as usize) < XLOG_SIZE_OF_XLOG_RECORD {
return Err(WalDecodeError {
msg: format!("invalid xl_tot_len {}", xl_tot_len),
lsn: self.lsn,
});
// read the rest of the record, or as much as fits on this page.
let n = min(contlen.get(), pageleft) as usize;
if self.inputbuf.remaining() < n {
return Ok(None);
}
recordbuf.put(self.inputbuf.split_to(n));
self.lsn += n as u64;
*contlen = match NonZeroU32::new(contlen.get() - n as u32) {
Some(x) => x,
None => {
// The record is now complete.
let recordbuf = std::mem::replace(recordbuf, BytesMut::new()).freeze();
return Ok(Some(self.complete_record(recordbuf)?));
}
}
}
// Fast path for the common case that the whole record fits on the page.
let pageleft = self.lsn.remaining_in_block() as u32;
if self.inputbuf.remaining() >= xl_tot_len as usize && xl_tot_len <= pageleft {
// Take the record from the 'inputbuf', and validate it.
recordbuf = self.inputbuf.copy_to_bytes(xl_tot_len as usize);
self.lsn += xl_tot_len as u64;
break;
} else {
// Need to assemble the record from pieces. Remember the size of the
// record, and loop back. On next iteration, we will reach the 'else'
// branch below, and copy the part of the record that was on this page
// to 'recordbuf'. Subsequent iterations will skip page headers, and
// append the continuations from the next pages to 'recordbuf'.
self.recordbuf.reserve(xl_tot_len as usize);
self.contlen = xl_tot_len;
continue;
State::SkippingEverything { skip_until_lsn } => {
assert!(*skip_until_lsn >= self.lsn);
let n = skip_until_lsn.0 - self.lsn.0;
if self.inputbuf.remaining() < n as usize {
return Ok(None);
}
self.inputbuf.advance(n as usize);
self.lsn += n;
self.state = State::WaitingForRecord;
}
} else {
// we're continuing a record, possibly from previous page.
let pageleft = self.lsn.remaining_in_block() as u32;
// read the rest of the record, or as much as fits on this page.
let n = min(self.contlen, pageleft) as usize;
if self.inputbuf.remaining() < n {
return Ok(None);
}
self.recordbuf.put(self.inputbuf.split_to(n));
self.lsn += n as u64;
self.contlen -= n as u32;
if self.contlen == 0 {
// The record is now complete.
recordbuf = std::mem::replace(&mut self.recordbuf, BytesMut::new()).freeze();
break;
}
continue;
}
}
}
fn complete_record(&mut self, recordbuf: Bytes) -> Result<(Lsn, Bytes), WalDecodeError> {
// We now have a record in the 'recordbuf' local variable.
let xlogrec =
XLogRecord::from_slice(&recordbuf[0..XLOG_SIZE_OF_XLOG_RECORD]).map_err(|e| {
@@ -244,18 +265,20 @@ impl WalStreamDecoder {
// XLOG_SWITCH records are special. If we see one, we need to skip
// to the next WAL segment.
if xlogrec.is_xlog_switch_record() {
let next_lsn = if xlogrec.is_xlog_switch_record() {
trace!("saw xlog switch record at {}", self.lsn);
self.padlen = self.lsn.calc_padding(pg_constants::WAL_SEGMENT_SIZE as u64) as u32;
self.lsn + self.lsn.calc_padding(pg_constants::WAL_SEGMENT_SIZE as u64)
} else {
// Pad to an 8-byte boundary
self.padlen = self.lsn.calc_padding(8u32) as u32;
}
self.lsn.align()
};
self.state = State::SkippingEverything {
skip_until_lsn: next_lsn,
};
// We should return LSN of the next record, not the last byte of this record or
// the byte immediately after. Note that this handles both XLOG_SWITCH and usual
// records, the former "spans" until the next WAL segment (see test_xlog_switch).
let result = (self.lsn + self.padlen as u64, recordbuf);
Ok(Some(result))
Ok((next_lsn, recordbuf))
}
}