diff --git a/pageserver/src/page_service.rs b/pageserver/src/page_service.rs index 12c728f173..3a77af46bb 100644 --- a/pageserver/src/page_service.rs +++ b/pageserver/src/page_service.rs @@ -10,6 +10,7 @@ // *callmemaybe $url* -- ask pageserver to start walreceiver on $url // +use anyhow::{anyhow, bail}; use byteorder::{ReadBytesExt, WriteBytesExt, BE}; use bytes::{Buf, BufMut, Bytes, BytesMut}; use log::*; @@ -30,8 +31,6 @@ use crate::walreceiver; use crate::PageServerConf; use crate::ZTimelineId; -type Result = std::result::Result; - #[derive(Debug)] enum FeMessage { StartupMessage(FeStartupMessage), @@ -112,7 +111,7 @@ enum StartupRequestCode { } impl FeStartupMessage { - pub fn read(stream: &mut dyn std::io::Read) -> Result> { + pub fn read(stream: &mut dyn std::io::Read) -> anyhow::Result> { const MAX_STARTUP_PACKET_LENGTH: u32 = 10000; const CANCEL_REQUEST_CODE: u32 = (1234 << 16) | 5678; const NEGOTIATE_SSL_CODE: u32 = (1234 << 16) | 5679; @@ -124,19 +123,12 @@ impl FeStartupMessage { // in the log if the client opens connection but closes it immediately. let len = match stream.read_u32::() { Ok(len) => len, - Err(err) => { - if err.kind() == std::io::ErrorKind::UnexpectedEof { - return Ok(None); - } else { - return Err(err); - } - } + Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => return Ok(None), + Err(e) => return Err(e.into()), }; - if len < 4 || len as u32 > MAX_STARTUP_PACKET_LENGTH { - return Err(io::Error::new( - io::ErrorKind::InvalidData, - "invalid message length", - )); + + if len < 4 || len > MAX_STARTUP_PACKET_LENGTH { + bail!("invalid message length"); } let bodylen = len - 4; @@ -181,15 +173,12 @@ struct FeParseMessage { query_string: Bytes, } -fn read_null_terminated(buf: &mut Bytes) -> Result { +fn read_null_terminated(buf: &mut Bytes) -> anyhow::Result { let mut result = BytesMut::new(); loop { if !buf.has_remaining() { - return Err(io::Error::new( - io::ErrorKind::InvalidInput, - "no null-terminator in string", - )); + bail!("no null-terminator in string"); } let byte = buf.get_u8(); @@ -203,7 +192,7 @@ fn read_null_terminated(buf: &mut Bytes) -> Result { } impl FeParseMessage { - pub fn parse(mut buf: Bytes) -> Result { + pub fn parse(mut buf: Bytes) -> anyhow::Result { let _pstmt_name = read_null_terminated(&mut buf)?; let query_string = read_null_terminated(&mut buf)?; let nparams = buf.get_i16(); @@ -222,10 +211,7 @@ impl FeParseMessage { */ if nparams != 0 { - return Err(io::Error::new( - io::ErrorKind::InvalidInput, - "query params not implemented", - )); + bail!("query params not implemented"); } Ok(FeMessage::Parse(FeParseMessage { query_string })) @@ -239,7 +225,7 @@ struct FeDescribeMessage { } impl FeDescribeMessage { - pub fn parse(mut buf: Bytes) -> Result { + pub fn parse(mut buf: Bytes) -> anyhow::Result { let kind = buf.get_u8(); let _pstmt_name = read_null_terminated(&mut buf)?; @@ -254,10 +240,7 @@ impl FeDescribeMessage { */ if kind != b'S' { - return Err(io::Error::new( - io::ErrorKind::InvalidInput, - "only prepared statmement Describe is implemented", - )); + bail!("only prepared statmement Describe is implemented"); } Ok(FeMessage::Describe(FeDescribeMessage { kind })) @@ -272,22 +255,16 @@ struct FeExecuteMessage { } impl FeExecuteMessage { - pub fn parse(mut buf: Bytes) -> Result { + pub fn parse(mut buf: Bytes) -> anyhow::Result { let portal_name = read_null_terminated(&mut buf)?; let maxrows = buf.get_i32(); if !portal_name.is_empty() { - return Err(io::Error::new( - io::ErrorKind::InvalidInput, - "named portals not implemented", - )); + bail!("named portals not implemented"); } if maxrows != 0 { - return Err(io::Error::new( - io::ErrorKind::InvalidInput, - "row limit in Execute message not supported", - )); + bail!("row limit in Execute message not supported"); } Ok(FeMessage::Execute(FeExecuteMessage { maxrows })) @@ -299,15 +276,12 @@ impl FeExecuteMessage { struct FeBindMessage {} impl FeBindMessage { - pub fn parse(mut buf: Bytes) -> Result { + pub fn parse(mut buf: Bytes) -> anyhow::Result { let portal_name = read_null_terminated(&mut buf)?; let _pstmt_name = read_null_terminated(&mut buf)?; if !portal_name.is_empty() { - return Err(io::Error::new( - io::ErrorKind::InvalidInput, - "named portals not implemented", - )); + bail!("named portals not implemented"); } // FIXME: see FeParseMessage::parse @@ -329,7 +303,7 @@ impl FeBindMessage { struct FeCloseMessage {} impl FeCloseMessage { - pub fn parse(mut buf: Bytes) -> Result { + pub fn parse(mut buf: Bytes) -> anyhow::Result { let _kind = buf.get_u8(); let _pstmt_or_portal_name = read_null_terminated(&mut buf)?; @@ -340,28 +314,20 @@ impl FeCloseMessage { } impl FeMessage { - pub fn read(stream: &mut dyn Read) -> Result> { + pub fn read(stream: &mut dyn Read) -> anyhow::Result> { // Each libpq message begins with a message type byte, followed by message length // If the client closes the connection, return None. But if the client closes the // connection in the middle of a message, we will return an error. let tag = match stream.read_u8() { Ok(b) => b, - Err(err) => { - if err.kind() == std::io::ErrorKind::UnexpectedEof { - return Ok(None); - } else { - return Err(err); - } - } + Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => return Ok(None), + Err(e) => return Err(e.into()), }; let len = stream.read_u32::()?; // The message length includes itself, so it better be at least 4 if len < 4 { - return Err(io::Error::new( - io::ErrorKind::InvalidInput, - "invalid message length: parsing u32", - )); + bail!("invalid message length: parsing u32"); } let bodylen = len - 4; @@ -398,16 +364,14 @@ impl FeMessage { 0 => Ok(Some(FeMessage::ZenithExistsRequest(zreq))), 1 => Ok(Some(FeMessage::ZenithNblocksRequest(zreq))), 2 => Ok(Some(FeMessage::ZenithReadRequest(zreq))), - _ => Err(io::Error::new( - io::ErrorKind::InvalidInput, - format!("unknown smgr message tag: {},'{:?}'", smgr_tag, body), + _ => Err(anyhow!( + "unknown smgr message tag: {},'{:?}'", + smgr_tag, + body )), } } - tag => Err(io::Error::new( - io::ErrorKind::InvalidInput, - format!("unknown message tag: {},'{:?}'", tag, body), - )), + tag => Err(anyhow!("unknown message tag: {},'{:?}'", tag, body)), } } } @@ -442,7 +406,6 @@ pub fn thread_main(conf: &PageServerConf) { struct Connection { stream_in: BufReader, stream: BufWriter, - buffer: BytesMut, init_done: bool, conf: PageServerConf, } @@ -452,7 +415,6 @@ impl Connection { Connection { stream_in: BufReader::new(socket.try_clone().unwrap()), stream: BufWriter::new(socket), - buffer: BytesMut::with_capacity(10 * 1024), init_done: false, conf, } @@ -461,7 +423,7 @@ impl Connection { // // Read full message or return None if connection is closed // - fn read_message(&mut self) -> Result> { + fn read_message(&mut self) -> anyhow::Result> { if !self.init_done { FeStartupMessage::read(&mut self.stream_in) } else { @@ -596,7 +558,7 @@ impl Connection { self.stream.flush() } - fn run(&mut self) -> Result<()> { + fn run(&mut self) -> anyhow::Result<()> { let mut unnamed_query_string = Bytes::new(); loop { let msg = self.read_message()?; @@ -650,8 +612,7 @@ impl Connection { break; } x => { - error!("unexpected message type : {:?}", x); - return Err(io::Error::new(io::ErrorKind::Other, "unexpected message")); + bail!("unexpected message type : {:?}", x); } } } @@ -659,7 +620,7 @@ impl Connection { Ok(()) } - fn process_query(&mut self, query_string: Bytes) -> Result<()> { + fn process_query(&mut self, query_string: Bytes) -> anyhow::Result<()> { debug!("process query {:?}", query_string); // remove null terminator, if any @@ -669,78 +630,78 @@ impl Connection { } if query_string.starts_with(b"controlfile") { - self.handle_controlfile() + self.handle_controlfile()?; } else if query_string.starts_with(b"pagestream ") { let (_l, r) = query_string.split_at("pagestream ".len()); - let timelineid_str = String::from_utf8(r.to_vec()).unwrap(); - let timelineid = ZTimelineId::from_str(&timelineid_str).unwrap(); + let timelineid_str = String::from_utf8(r.to_vec())?; + let timelineid = ZTimelineId::from_str(&timelineid_str)?; - self.handle_pagerequests(timelineid) + self.handle_pagerequests(timelineid)?; } else if query_string.starts_with(b"basebackup ") { let (_l, r) = query_string.split_at("basebackup ".len()); let r = r.to_vec(); - let timelineid_str = String::from(String::from_utf8(r).unwrap().trim_end()); + let timelineid_str = String::from(String::from_utf8(r)?.trim_end()); info!("got basebackup command: \"{}\"", timelineid_str); - let timelineid = ZTimelineId::from_str(&timelineid_str).unwrap(); + let timelineid = ZTimelineId::from_str(&timelineid_str)?; // Check that the timeline exists self.handle_basebackup_request(timelineid)?; self.write_message_noflush(&BeMessage::CommandComplete)?; - self.write_message(&BeMessage::ReadyForQuery) + self.write_message(&BeMessage::ReadyForQuery)?; } else if query_string.starts_with(b"callmemaybe ") { - let query_str = String::from_utf8(query_string.to_vec()).unwrap(); + let query_str = String::from_utf8(query_string.to_vec())?; // callmemaybe + // TODO lazy static let re = Regex::new(r"^callmemaybe ([[:xdigit:]]+) (.*)$").unwrap(); - let caps = re.captures(&query_str); - let caps = caps.unwrap(); + let caps = re + .captures(&query_str) + .ok_or_else(|| anyhow!("invalid callmemaybe: '{}'", query_str))?; - let timelineid = ZTimelineId::from_str(caps.get(1).unwrap().as_str()).unwrap(); + let timelineid = ZTimelineId::from_str(caps.get(1).unwrap().as_str())?; let connstr: String = String::from(caps.get(2).unwrap().as_str()); // Check that the timeline exists let repository = page_cache::get_repository(); - let timeline = repository.get_or_restore_timeline(timelineid); - if timeline.is_err() { - return Err(io::Error::new( - io::ErrorKind::InvalidInput, - format!("client requested callmemaybe on timeline {} which does not exist in page server", timelineid))); + if repository.get_or_restore_timeline(timelineid).is_err() { + bail!("client requested callmemaybe on timeline {} which does not exist in page server", timelineid); } walreceiver::launch_wal_receiver(&self.conf, timelineid, &connstr); self.write_message_noflush(&BeMessage::CommandComplete)?; - self.write_message(&BeMessage::ReadyForQuery) + self.write_message(&BeMessage::ReadyForQuery)?; } else if query_string.starts_with(b"status") { self.write_message_noflush(&BeMessage::RowDescription)?; self.write_message_noflush(&BeMessage::DataRow)?; self.write_message_noflush(&BeMessage::CommandComplete)?; - self.write_message(&BeMessage::ReadyForQuery) + self.write_message(&BeMessage::ReadyForQuery)?; } else { self.write_message_noflush(&BeMessage::RowDescription)?; self.write_message_noflush(&BeMessage::DataRow)?; self.write_message_noflush(&BeMessage::CommandComplete)?; - self.write_message(&BeMessage::ReadyForQuery) + self.write_message(&BeMessage::ReadyForQuery)?; } + + Ok(()) } - fn handle_controlfile(&mut self) -> Result<()> { + fn handle_controlfile(&mut self) -> io::Result<()> { self.write_message_noflush(&BeMessage::RowDescription)?; self.write_message_noflush(&BeMessage::ControlFile)?; self.write_message_noflush(&BeMessage::CommandComplete)?; self.write_message(&BeMessage::ReadyForQuery) } - fn handle_pagerequests(&mut self, timelineid: ZTimelineId) -> Result<()> { + fn handle_pagerequests(&mut self, timelineid: ZTimelineId) -> anyhow::Result<()> { // Check that the timeline exists let repository = page_cache::get_repository(); - let timeline = repository - .get_timeline(timelineid) - .map_err(|_| { - io::Error::new( - io::ErrorKind::InvalidInput, - format!("client requested pagestream on timeline {} which does not exist in page server", timelineid)) - })?; + let timeline = repository.get_timeline(timelineid).map_err(|_| { + anyhow!( + "client requested pagestream on timeline {} which does not exist in page server", + timelineid + ) + })?; /* switch client to COPYBOTH */ self.stream.write_u8(b'W')?; @@ -827,14 +788,14 @@ impl Connection { } } - fn handle_basebackup_request(&mut self, timelineid: ZTimelineId) -> Result<()> { + fn handle_basebackup_request(&mut self, timelineid: ZTimelineId) -> anyhow::Result<()> { // check that the timeline exists let repository = page_cache::get_repository(); - let timeline = repository.get_or_restore_timeline(timelineid); - if timeline.is_err() { - return Err(io::Error::new( - io::ErrorKind::InvalidInput, - format!("client requested basebackup on timeline {} which does not exist in page server", timelineid))); + if repository.get_or_restore_timeline(timelineid).is_err() { + bail!( + "client requested basebackup on timeline {} which does not exist in page server", + timelineid + ); } /* switch client to COPYOUT */ @@ -876,8 +837,8 @@ struct CopyDataSink<'a> { stream: &'a mut BufWriter, } -impl<'a> std::io::Write for CopyDataSink<'a> { - fn write(&mut self, data: &[u8]) -> std::result::Result { +impl<'a> io::Write for CopyDataSink<'a> { + fn write(&mut self, data: &[u8]) -> io::Result { // CopyData // FIXME: if the input is large, we should split it into multiple messages. // Not sure what the threshold should be, but the ultimate hard limit is that @@ -893,7 +854,7 @@ impl<'a> std::io::Write for CopyDataSink<'a> { Ok(data.len()) } - fn flush(&mut self) -> std::result::Result<(), std::io::Error> { + fn flush(&mut self) -> io::Result<()> { // no-op Ok(()) }