page_service - use anyhow for error handling

This commit is contained in:
Patrick Insinger
2021-05-11 12:36:33 -04:00
committed by Patrick Insinger
parent d5bfe84d9e
commit d8e509d29e

View File

@@ -10,6 +10,7 @@
// *callmemaybe <zenith timelineid> $url* -- ask pageserver to start walreceiver on $url
//
use anyhow::{anyhow, bail};
use byteorder::{ReadBytesExt, WriteBytesExt, BE};
use bytes::{Buf, BufMut, Bytes, BytesMut};
use log::*;
@@ -30,8 +31,6 @@ use crate::walreceiver;
use crate::PageServerConf;
use crate::ZTimelineId;
type Result<T> = std::result::Result<T, io::Error>;
#[derive(Debug)]
enum FeMessage {
StartupMessage(FeStartupMessage),
@@ -112,7 +111,7 @@ enum StartupRequestCode {
}
impl FeStartupMessage {
pub fn read(stream: &mut dyn std::io::Read) -> Result<Option<FeMessage>> {
pub fn read(stream: &mut dyn std::io::Read) -> anyhow::Result<Option<FeMessage>> {
const MAX_STARTUP_PACKET_LENGTH: u32 = 10000;
const CANCEL_REQUEST_CODE: u32 = (1234 << 16) | 5678;
const NEGOTIATE_SSL_CODE: u32 = (1234 << 16) | 5679;
@@ -124,19 +123,12 @@ impl FeStartupMessage {
// in the log if the client opens connection but closes it immediately.
let len = match stream.read_u32::<BE>() {
Ok(len) => len,
Err(err) => {
if err.kind() == std::io::ErrorKind::UnexpectedEof {
return Ok(None);
} else {
return Err(err);
}
}
Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => return Ok(None),
Err(e) => return Err(e.into()),
};
if len < 4 || len as u32 > MAX_STARTUP_PACKET_LENGTH {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
"invalid message length",
));
if len < 4 || len > MAX_STARTUP_PACKET_LENGTH {
bail!("invalid message length");
}
let bodylen = len - 4;
@@ -181,15 +173,12 @@ struct FeParseMessage {
query_string: Bytes,
}
fn read_null_terminated(buf: &mut Bytes) -> Result<Bytes> {
fn read_null_terminated(buf: &mut Bytes) -> anyhow::Result<Bytes> {
let mut result = BytesMut::new();
loop {
if !buf.has_remaining() {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"no null-terminator in string",
));
bail!("no null-terminator in string");
}
let byte = buf.get_u8();
@@ -203,7 +192,7 @@ fn read_null_terminated(buf: &mut Bytes) -> Result<Bytes> {
}
impl FeParseMessage {
pub fn parse(mut buf: Bytes) -> Result<FeMessage> {
pub fn parse(mut buf: Bytes) -> anyhow::Result<FeMessage> {
let _pstmt_name = read_null_terminated(&mut buf)?;
let query_string = read_null_terminated(&mut buf)?;
let nparams = buf.get_i16();
@@ -222,10 +211,7 @@ impl FeParseMessage {
*/
if nparams != 0 {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"query params not implemented",
));
bail!("query params not implemented");
}
Ok(FeMessage::Parse(FeParseMessage { query_string }))
@@ -239,7 +225,7 @@ struct FeDescribeMessage {
}
impl FeDescribeMessage {
pub fn parse(mut buf: Bytes) -> Result<FeMessage> {
pub fn parse(mut buf: Bytes) -> anyhow::Result<FeMessage> {
let kind = buf.get_u8();
let _pstmt_name = read_null_terminated(&mut buf)?;
@@ -254,10 +240,7 @@ impl FeDescribeMessage {
*/
if kind != b'S' {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"only prepared statmement Describe is implemented",
));
bail!("only prepared statmement Describe is implemented");
}
Ok(FeMessage::Describe(FeDescribeMessage { kind }))
@@ -272,22 +255,16 @@ struct FeExecuteMessage {
}
impl FeExecuteMessage {
pub fn parse(mut buf: Bytes) -> Result<FeMessage> {
pub fn parse(mut buf: Bytes) -> anyhow::Result<FeMessage> {
let portal_name = read_null_terminated(&mut buf)?;
let maxrows = buf.get_i32();
if !portal_name.is_empty() {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"named portals not implemented",
));
bail!("named portals not implemented");
}
if maxrows != 0 {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"row limit in Execute message not supported",
));
bail!("row limit in Execute message not supported");
}
Ok(FeMessage::Execute(FeExecuteMessage { maxrows }))
@@ -299,15 +276,12 @@ impl FeExecuteMessage {
struct FeBindMessage {}
impl FeBindMessage {
pub fn parse(mut buf: Bytes) -> Result<FeMessage> {
pub fn parse(mut buf: Bytes) -> anyhow::Result<FeMessage> {
let portal_name = read_null_terminated(&mut buf)?;
let _pstmt_name = read_null_terminated(&mut buf)?;
if !portal_name.is_empty() {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"named portals not implemented",
));
bail!("named portals not implemented");
}
// FIXME: see FeParseMessage::parse
@@ -329,7 +303,7 @@ impl FeBindMessage {
struct FeCloseMessage {}
impl FeCloseMessage {
pub fn parse(mut buf: Bytes) -> Result<FeMessage> {
pub fn parse(mut buf: Bytes) -> anyhow::Result<FeMessage> {
let _kind = buf.get_u8();
let _pstmt_or_portal_name = read_null_terminated(&mut buf)?;
@@ -340,28 +314,20 @@ impl FeCloseMessage {
}
impl FeMessage {
pub fn read(stream: &mut dyn Read) -> Result<Option<FeMessage>> {
pub fn read(stream: &mut dyn Read) -> anyhow::Result<Option<FeMessage>> {
// Each libpq message begins with a message type byte, followed by message length
// If the client closes the connection, return None. But if the client closes the
// connection in the middle of a message, we will return an error.
let tag = match stream.read_u8() {
Ok(b) => b,
Err(err) => {
if err.kind() == std::io::ErrorKind::UnexpectedEof {
return Ok(None);
} else {
return Err(err);
}
}
Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => return Ok(None),
Err(e) => return Err(e.into()),
};
let len = stream.read_u32::<BE>()?;
// The message length includes itself, so it better be at least 4
if len < 4 {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"invalid message length: parsing u32",
));
bail!("invalid message length: parsing u32");
}
let bodylen = len - 4;
@@ -398,16 +364,14 @@ impl FeMessage {
0 => Ok(Some(FeMessage::ZenithExistsRequest(zreq))),
1 => Ok(Some(FeMessage::ZenithNblocksRequest(zreq))),
2 => Ok(Some(FeMessage::ZenithReadRequest(zreq))),
_ => Err(io::Error::new(
io::ErrorKind::InvalidInput,
format!("unknown smgr message tag: {},'{:?}'", smgr_tag, body),
_ => Err(anyhow!(
"unknown smgr message tag: {},'{:?}'",
smgr_tag,
body
)),
}
}
tag => Err(io::Error::new(
io::ErrorKind::InvalidInput,
format!("unknown message tag: {},'{:?}'", tag, body),
)),
tag => Err(anyhow!("unknown message tag: {},'{:?}'", tag, body)),
}
}
}
@@ -442,7 +406,6 @@ pub fn thread_main(conf: &PageServerConf) {
struct Connection {
stream_in: BufReader<TcpStream>,
stream: BufWriter<TcpStream>,
buffer: BytesMut,
init_done: bool,
conf: PageServerConf,
}
@@ -452,7 +415,6 @@ impl Connection {
Connection {
stream_in: BufReader::new(socket.try_clone().unwrap()),
stream: BufWriter::new(socket),
buffer: BytesMut::with_capacity(10 * 1024),
init_done: false,
conf,
}
@@ -461,7 +423,7 @@ impl Connection {
//
// Read full message or return None if connection is closed
//
fn read_message(&mut self) -> Result<Option<FeMessage>> {
fn read_message(&mut self) -> anyhow::Result<Option<FeMessage>> {
if !self.init_done {
FeStartupMessage::read(&mut self.stream_in)
} else {
@@ -596,7 +558,7 @@ impl Connection {
self.stream.flush()
}
fn run(&mut self) -> Result<()> {
fn run(&mut self) -> anyhow::Result<()> {
let mut unnamed_query_string = Bytes::new();
loop {
let msg = self.read_message()?;
@@ -650,8 +612,7 @@ impl Connection {
break;
}
x => {
error!("unexpected message type : {:?}", x);
return Err(io::Error::new(io::ErrorKind::Other, "unexpected message"));
bail!("unexpected message type : {:?}", x);
}
}
}
@@ -659,7 +620,7 @@ impl Connection {
Ok(())
}
fn process_query(&mut self, query_string: Bytes) -> Result<()> {
fn process_query(&mut self, query_string: Bytes) -> anyhow::Result<()> {
debug!("process query {:?}", query_string);
// remove null terminator, if any
@@ -669,78 +630,78 @@ impl Connection {
}
if query_string.starts_with(b"controlfile") {
self.handle_controlfile()
self.handle_controlfile()?;
} else if query_string.starts_with(b"pagestream ") {
let (_l, r) = query_string.split_at("pagestream ".len());
let timelineid_str = String::from_utf8(r.to_vec()).unwrap();
let timelineid = ZTimelineId::from_str(&timelineid_str).unwrap();
let timelineid_str = String::from_utf8(r.to_vec())?;
let timelineid = ZTimelineId::from_str(&timelineid_str)?;
self.handle_pagerequests(timelineid)
self.handle_pagerequests(timelineid)?;
} else if query_string.starts_with(b"basebackup ") {
let (_l, r) = query_string.split_at("basebackup ".len());
let r = r.to_vec();
let timelineid_str = String::from(String::from_utf8(r).unwrap().trim_end());
let timelineid_str = String::from(String::from_utf8(r)?.trim_end());
info!("got basebackup command: \"{}\"", timelineid_str);
let timelineid = ZTimelineId::from_str(&timelineid_str).unwrap();
let timelineid = ZTimelineId::from_str(&timelineid_str)?;
// Check that the timeline exists
self.handle_basebackup_request(timelineid)?;
self.write_message_noflush(&BeMessage::CommandComplete)?;
self.write_message(&BeMessage::ReadyForQuery)
self.write_message(&BeMessage::ReadyForQuery)?;
} else if query_string.starts_with(b"callmemaybe ") {
let query_str = String::from_utf8(query_string.to_vec()).unwrap();
let query_str = String::from_utf8(query_string.to_vec())?;
// callmemaybe <zenith timelineid as hex string> <connstr>
// TODO lazy static
let re = Regex::new(r"^callmemaybe ([[:xdigit:]]+) (.*)$").unwrap();
let caps = re.captures(&query_str);
let caps = caps.unwrap();
let caps = re
.captures(&query_str)
.ok_or_else(|| anyhow!("invalid callmemaybe: '{}'", query_str))?;
let timelineid = ZTimelineId::from_str(caps.get(1).unwrap().as_str()).unwrap();
let timelineid = ZTimelineId::from_str(caps.get(1).unwrap().as_str())?;
let connstr: String = String::from(caps.get(2).unwrap().as_str());
// Check that the timeline exists
let repository = page_cache::get_repository();
let timeline = repository.get_or_restore_timeline(timelineid);
if timeline.is_err() {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
format!("client requested callmemaybe on timeline {} which does not exist in page server", timelineid)));
if repository.get_or_restore_timeline(timelineid).is_err() {
bail!("client requested callmemaybe on timeline {} which does not exist in page server", timelineid);
}
walreceiver::launch_wal_receiver(&self.conf, timelineid, &connstr);
self.write_message_noflush(&BeMessage::CommandComplete)?;
self.write_message(&BeMessage::ReadyForQuery)
self.write_message(&BeMessage::ReadyForQuery)?;
} else if query_string.starts_with(b"status") {
self.write_message_noflush(&BeMessage::RowDescription)?;
self.write_message_noflush(&BeMessage::DataRow)?;
self.write_message_noflush(&BeMessage::CommandComplete)?;
self.write_message(&BeMessage::ReadyForQuery)
self.write_message(&BeMessage::ReadyForQuery)?;
} else {
self.write_message_noflush(&BeMessage::RowDescription)?;
self.write_message_noflush(&BeMessage::DataRow)?;
self.write_message_noflush(&BeMessage::CommandComplete)?;
self.write_message(&BeMessage::ReadyForQuery)
self.write_message(&BeMessage::ReadyForQuery)?;
}
Ok(())
}
fn handle_controlfile(&mut self) -> Result<()> {
fn handle_controlfile(&mut self) -> io::Result<()> {
self.write_message_noflush(&BeMessage::RowDescription)?;
self.write_message_noflush(&BeMessage::ControlFile)?;
self.write_message_noflush(&BeMessage::CommandComplete)?;
self.write_message(&BeMessage::ReadyForQuery)
}
fn handle_pagerequests(&mut self, timelineid: ZTimelineId) -> Result<()> {
fn handle_pagerequests(&mut self, timelineid: ZTimelineId) -> anyhow::Result<()> {
// Check that the timeline exists
let repository = page_cache::get_repository();
let timeline = repository
.get_timeline(timelineid)
.map_err(|_| {
io::Error::new(
io::ErrorKind::InvalidInput,
format!("client requested pagestream on timeline {} which does not exist in page server", timelineid))
})?;
let timeline = repository.get_timeline(timelineid).map_err(|_| {
anyhow!(
"client requested pagestream on timeline {} which does not exist in page server",
timelineid
)
})?;
/* switch client to COPYBOTH */
self.stream.write_u8(b'W')?;
@@ -827,14 +788,14 @@ impl Connection {
}
}
fn handle_basebackup_request(&mut self, timelineid: ZTimelineId) -> Result<()> {
fn handle_basebackup_request(&mut self, timelineid: ZTimelineId) -> anyhow::Result<()> {
// check that the timeline exists
let repository = page_cache::get_repository();
let timeline = repository.get_or_restore_timeline(timelineid);
if timeline.is_err() {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
format!("client requested basebackup on timeline {} which does not exist in page server", timelineid)));
if repository.get_or_restore_timeline(timelineid).is_err() {
bail!(
"client requested basebackup on timeline {} which does not exist in page server",
timelineid
);
}
/* switch client to COPYOUT */
@@ -876,8 +837,8 @@ struct CopyDataSink<'a> {
stream: &'a mut BufWriter<TcpStream>,
}
impl<'a> std::io::Write for CopyDataSink<'a> {
fn write(&mut self, data: &[u8]) -> std::result::Result<usize, std::io::Error> {
impl<'a> io::Write for CopyDataSink<'a> {
fn write(&mut self, data: &[u8]) -> io::Result<usize> {
// CopyData
// FIXME: if the input is large, we should split it into multiple messages.
// Not sure what the threshold should be, but the ultimate hard limit is that
@@ -893,7 +854,7 @@ impl<'a> std::io::Write for CopyDataSink<'a> {
Ok(data.len())
}
fn flush(&mut self) -> std::result::Result<(), std::io::Error> {
fn flush(&mut self) -> io::Result<()> {
// no-op
Ok(())
}