From 8e1d6dd848da5006f63a4a8088954ee39a3f5a05 Mon Sep 17 00:00:00 2001 From: Dmitry Ivanov Date: Tue, 23 Aug 2022 18:00:02 +0300 Subject: [PATCH] Minor cleanup in pq_proto (#2322) --- libs/utils/src/postgres_backend.rs | 15 +- libs/utils/src/pq_proto.rs | 330 +++++++++-------------------- 2 files changed, 107 insertions(+), 238 deletions(-) diff --git a/libs/utils/src/postgres_backend.rs b/libs/utils/src/postgres_backend.rs index 4d873bd5ac..604eb75aaf 100644 --- a/libs/utils/src/postgres_backend.rs +++ b/libs/utils/src/postgres_backend.rs @@ -163,14 +163,9 @@ pub fn is_socket_read_timed_out(error: &anyhow::Error) -> bool { false } -// Truncate 0 from C string in Bytes and stringify it (returns slice, no allocations) -// PG protocol strings are always C strings. -fn cstr_to_str(b: &Bytes) -> Result<&str> { - let without_null = if b.last() == Some(&0) { - &b[..b.len() - 1] - } else { - &b[..] - }; +// Cast a byte slice to a string slice, dropping null terminator if there's one. +fn cstr_to_str(bytes: &[u8]) -> Result<&str> { + let without_null = bytes.strip_suffix(&[0]).unwrap_or(bytes); std::str::from_utf8(without_null).map_err(|e| e.into()) } @@ -423,9 +418,9 @@ impl PostgresBackend { self.state = ProtoState::Established; } - FeMessage::Query(m) => { + FeMessage::Query(body) => { // remove null terminator - let query_string = cstr_to_str(&m.body)?; + let query_string = cstr_to_str(&body)?; trace!("got query {:?}", query_string); // xxx distinguish fatal and recoverable errors? diff --git a/libs/utils/src/pq_proto.rs b/libs/utils/src/pq_proto.rs index 3f14acd50d..2f8dcf31d3 100644 --- a/libs/utils/src/pq_proto.rs +++ b/libs/utils/src/pq_proto.rs @@ -25,8 +25,10 @@ pub const TEXT_OID: Oid = 25; #[derive(Debug)] pub enum FeMessage { StartupPacket(FeStartupPacket), - Query(FeQueryMessage), // Simple query - Parse(FeParseMessage), // Extended query protocol + // Simple query. + Query(Bytes), + // Extended query protocol. + Parse(FeParseMessage), Describe(FeDescribeMessage), Bind(FeBindMessage), Execute(FeExecuteMessage), @@ -69,11 +71,6 @@ impl Distribution for Standard { } } -#[derive(Debug)] -pub struct FeQueryMessage { - pub body: Bytes, -} - // We only support the simple case of Parse on unnamed prepared statement and // no params #[derive(Debug)] @@ -89,7 +86,7 @@ pub struct FeDescribeMessage { // we only support unnamed prepared stmt and portal #[derive(Debug)] -pub struct FeBindMessage {} +pub struct FeBindMessage; // we only support unnamed prepared stmt or portal #[derive(Debug)] @@ -100,7 +97,7 @@ pub struct FeExecuteMessage { // we only support unnamed prepared stmt and portal #[derive(Debug)] -pub struct FeCloseMessage {} +pub struct FeCloseMessage; /// Retry a read on EINTR /// @@ -163,22 +160,20 @@ impl FeMessage { Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => return Ok(None), Err(e) => return Err(e.into()), }; - let len = retry_read!(stream.read_u32().await)?; - // The message length includes itself, so it better be at least 4 - let bodylen = len + // The message length includes itself, so it better be at least 4. + let len = retry_read!(stream.read_u32().await)? .checked_sub(4) - .context("invalid message length: parsing u32")?; + .context("invalid message length")?; - // Read message body - let mut body_buf: Vec = vec![0; bodylen as usize]; - stream.read_exact(&mut body_buf).await?; + let body = { + let mut buffer = vec![0u8; len as usize]; + stream.read_exact(&mut buffer).await?; + Bytes::from(buffer) + }; - let body = Bytes::from(body_buf); - - // Parse it match tag { - b'Q' => Ok(Some(FeMessage::Query(FeQueryMessage { body }))), + b'Q' => Ok(Some(FeMessage::Query(body))), b'P' => Ok(Some(FeParseMessage::parse(body)?)), b'D' => Ok(Some(FeDescribeMessage::parse(body)?)), b'E' => Ok(Some(FeExecuteMessage::parse(body)?)), @@ -302,124 +297,71 @@ impl FeStartupPacket { } impl FeParseMessage { - pub fn parse(mut buf: Bytes) -> anyhow::Result { - let _pstmt_name = read_null_terminated(&mut buf)?; - let query_string = read_null_terminated(&mut buf)?; - let nparams = buf.get_i16(); - + fn parse(mut buf: Bytes) -> anyhow::Result { // FIXME: the rust-postgres driver uses a named prepared statement // for copy_out(). We're not prepared to handle that correctly. For // now, just ignore the statement name, assuming that the client never // uses more than one prepared statement at a time. - /* - if !pstmt_name.is_empty() { - return Err(io::Error::new( - io::ErrorKind::InvalidInput, - "named prepared statements not implemented in Parse", - )); - } - */ - if nparams != 0 { - bail!("query params not implemented"); - } + let _pstmt_name = read_cstr(&mut buf)?; + let query_string = read_cstr(&mut buf)?; + let nparams = buf.get_i16(); + + ensure!(nparams == 0, "query params not implemented"); Ok(FeMessage::Parse(FeParseMessage { query_string })) } } impl FeDescribeMessage { - pub fn parse(mut buf: Bytes) -> anyhow::Result { + fn parse(mut buf: Bytes) -> anyhow::Result { let kind = buf.get_u8(); - let _pstmt_name = read_null_terminated(&mut buf)?; + let _pstmt_name = read_cstr(&mut buf)?; // FIXME: see FeParseMessage::parse - /* - if !pstmt_name.is_empty() { - return Err(io::Error::new( - io::ErrorKind::InvalidInput, - "named prepared statements not implemented in Describe", - )); - } - */ - - if kind != b'S' { - bail!("only prepared statmement Describe is implemented"); - } + ensure!( + kind == b'S', + "only prepared statemement Describe is implemented" + ); Ok(FeMessage::Describe(FeDescribeMessage { kind })) } } impl FeExecuteMessage { - pub fn parse(mut buf: Bytes) -> anyhow::Result { - let portal_name = read_null_terminated(&mut buf)?; + fn parse(mut buf: Bytes) -> anyhow::Result { + let portal_name = read_cstr(&mut buf)?; let maxrows = buf.get_i32(); - if !portal_name.is_empty() { - bail!("named portals not implemented"); - } - - if maxrows != 0 { - bail!("row limit in Execute message not supported"); - } + ensure!(portal_name.is_empty(), "named portals not implemented"); + ensure!(maxrows == 0, "row limit in Execute message not implemented"); Ok(FeMessage::Execute(FeExecuteMessage { maxrows })) } } impl FeBindMessage { - pub fn parse(mut buf: Bytes) -> anyhow::Result { - let portal_name = read_null_terminated(&mut buf)?; - let _pstmt_name = read_null_terminated(&mut buf)?; - - if !portal_name.is_empty() { - bail!("named portals not implemented"); - } + fn parse(mut buf: Bytes) -> anyhow::Result { + let portal_name = read_cstr(&mut buf)?; + let _pstmt_name = read_cstr(&mut buf)?; // FIXME: see FeParseMessage::parse - /* - if !pstmt_name.is_empty() { - return Err(io::Error::new( - io::ErrorKind::InvalidInput, - "named prepared statements not implemented", - )); - } - */ + ensure!(portal_name.is_empty(), "named portals not implemented"); - Ok(FeMessage::Bind(FeBindMessage {})) + Ok(FeMessage::Bind(FeBindMessage)) } } impl FeCloseMessage { - pub fn parse(mut buf: Bytes) -> anyhow::Result { + fn parse(mut buf: Bytes) -> anyhow::Result { let _kind = buf.get_u8(); - let _pstmt_or_portal_name = read_null_terminated(&mut buf)?; + let _pstmt_or_portal_name = read_cstr(&mut buf)?; // FIXME: we do nothing with Close - - Ok(FeMessage::Close(FeCloseMessage {})) + Ok(FeMessage::Close(FeCloseMessage)) } } -fn read_null_terminated(buf: &mut Bytes) -> anyhow::Result { - let mut result = BytesMut::new(); - - loop { - if !buf.has_remaining() { - bail!("no null-terminator in string"); - } - - let byte = buf.get_u8(); - - if byte == 0 { - break; - } - result.put_u8(byte); - } - Ok(result.freeze()) -} - // Backend #[derive(Debug)] @@ -441,7 +383,7 @@ pub enum BeMessage<'a> { // None means column is NULL DataRow(&'a [Option<&'a [u8]>]), ErrorResponse(&'a str), - // single byte - used in response to SSLRequest/GSSENCRequest + /// Single byte - used in response to SSLRequest/GSSENCRequest. EncryptionResponse(bool), NoData, ParameterDescription, @@ -554,49 +496,22 @@ pub static SINGLE_COL_ROWDESC: BeMessage = BeMessage::RowDescription(&[RowDescri formatcode: 0, }]); -// Safe usize -> i32|i16 conversion, from rust-postgres -trait FromUsize: Sized { - fn from_usize(x: usize) -> Result; -} - -macro_rules! from_usize { - ($t:ty) => { - impl FromUsize for $t { - #[inline] - fn from_usize(x: usize) -> io::Result<$t> { - if x > <$t>::max_value() as usize { - Err(io::Error::new( - io::ErrorKind::InvalidInput, - "value too large to transmit", - )) - } else { - Ok(x as $t) - } - } - } - }; -} - -from_usize!(i32); - /// Call f() to write body of the message and prepend it with 4-byte len as /// prescribed by the protocol. -fn write_body(buf: &mut BytesMut, f: F) -> io::Result<()> -where - F: FnOnce(&mut BytesMut) -> io::Result<()>, -{ +fn write_body(buf: &mut BytesMut, f: impl FnOnce(&mut BytesMut) -> R) -> R { let base = buf.len(); buf.extend_from_slice(&[0; 4]); - f(buf)?; + let res = f(buf); - let size = i32::from_usize(buf.len() - base)?; + let size = i32::try_from(buf.len() - base).expect("message too big to transmit"); (&mut buf[base..]).put_slice(&size.to_be_bytes()); - Ok(()) + + res } /// Safe write of s into buf as cstring (String in the protocol). -pub fn write_cstr(s: &[u8], buf: &mut BytesMut) -> Result<(), io::Error> { +fn write_cstr(s: &[u8], buf: &mut BytesMut) -> Result<(), io::Error> { if s.contains(&0) { return Err(io::Error::new( io::ErrorKind::InvalidInput, @@ -608,15 +523,11 @@ pub fn write_cstr(s: &[u8], buf: &mut BytesMut) -> Result<(), io::Error> { Ok(()) } -// Truncate 0 from C string in Bytes and stringify it (returns slice, no allocations) -// PG protocol strings are always C strings. -fn cstr_to_str(b: &Bytes) -> Result<&str> { - let without_null = if b.last() == Some(&0) { - &b[..b.len() - 1] - } else { - &b[..] - }; - std::str::from_utf8(without_null).map_err(|e| e.into()) +fn read_cstr(buf: &mut Bytes) -> anyhow::Result { + let pos = buf.iter().position(|x| *x == 0); + let result = buf.split_to(pos.context("missing terminator")?); + buf.advance(1); // drop the null terminator + Ok(result) } impl<'a> BeMessage<'a> { @@ -631,18 +542,14 @@ impl<'a> BeMessage<'a> { buf.put_u8(b'R'); write_body(buf, |buf| { buf.put_i32(0); // Specifies that the authentication was successful. - Ok::<_, io::Error>(()) - }) - .unwrap(); // write into BytesMut can't fail + }); } BeMessage::AuthenticationCleartextPassword => { buf.put_u8(b'R'); write_body(buf, |buf| { buf.put_i32(3); // Specifies that clear text password is required. - Ok::<_, io::Error>(()) - }) - .unwrap(); // write into BytesMut can't fail + }); } BeMessage::AuthenticationMD5Password(salt) => { @@ -650,9 +557,7 @@ impl<'a> BeMessage<'a> { write_body(buf, |buf| { buf.put_i32(5); // Specifies that an MD5-encrypted password is required. buf.put_slice(&salt[..]); - Ok::<_, io::Error>(()) - }) - .unwrap(); // write into BytesMut can't fail + }); } BeMessage::AuthenticationSasl(msg) => { @@ -677,8 +582,7 @@ impl<'a> BeMessage<'a> { } } Ok::<_, io::Error>(()) - }) - .unwrap() + })?; } BeMessage::BackendKeyData(key_data) => { @@ -686,77 +590,64 @@ impl<'a> BeMessage<'a> { write_body(buf, |buf| { buf.put_i32(key_data.backend_pid); buf.put_i32(key_data.cancel_key); - Ok(()) - }) - .unwrap(); + }); } BeMessage::BindComplete => { buf.put_u8(b'2'); - write_body(buf, |_| Ok::<(), io::Error>(())).unwrap(); + write_body(buf, |_| {}); } BeMessage::CloseComplete => { buf.put_u8(b'3'); - write_body(buf, |_| Ok::<(), io::Error>(())).unwrap(); + write_body(buf, |_| {}); } BeMessage::CommandComplete(cmd) => { buf.put_u8(b'C'); - write_body(buf, |buf| { - write_cstr(cmd, buf)?; - Ok::<_, io::Error>(()) - })?; + write_body(buf, |buf| write_cstr(cmd, buf))?; } BeMessage::CopyData(data) => { buf.put_u8(b'd'); write_body(buf, |buf| { buf.put_slice(data); - Ok::<_, io::Error>(()) - }) - .unwrap(); + }); } BeMessage::CopyDone => { buf.put_u8(b'c'); - write_body(buf, |_| Ok::<(), io::Error>(())).unwrap(); + write_body(buf, |_| {}); } BeMessage::CopyFail => { buf.put_u8(b'f'); - write_body(buf, |_| Ok::<(), io::Error>(())).unwrap(); + write_body(buf, |_| {}); } BeMessage::CopyInResponse => { buf.put_u8(b'G'); write_body(buf, |buf| { - buf.put_u8(1); /* copy_is_binary */ - buf.put_i16(0); /* numAttributes */ - Ok::<_, io::Error>(()) - }) - .unwrap(); + buf.put_u8(1); // copy_is_binary + buf.put_i16(0); // numAttributes + }); } BeMessage::CopyOutResponse => { buf.put_u8(b'H'); write_body(buf, |buf| { - buf.put_u8(0); /* copy_is_binary */ - buf.put_i16(0); /* numAttributes */ - Ok::<_, io::Error>(()) - }) - .unwrap(); + buf.put_u8(0); // copy_is_binary + buf.put_i16(0); // numAttributes + }); } BeMessage::CopyBothResponse => { buf.put_u8(b'W'); write_body(buf, |buf| { // doesn't matter, used only for replication - buf.put_u8(0); /* copy_is_binary */ - buf.put_i16(0); /* numAttributes */ - Ok::<_, io::Error>(()) - }) - .unwrap(); + buf.put_u8(0); // copy_is_binary + buf.put_i16(0); // numAttributes + }); } BeMessage::DataRow(vals) => { @@ -771,9 +662,7 @@ impl<'a> BeMessage<'a> { buf.put_i32(-1); } } - Ok::<_, io::Error>(()) - }) - .unwrap(); + }); } // ErrorResponse is a zero-terminated array of zero-terminated fields. @@ -788,18 +677,17 @@ impl<'a> BeMessage<'a> { buf.put_u8(b'E'); write_body(buf, |buf| { buf.put_u8(b'S'); // severity - write_cstr(&Bytes::from("ERROR"), buf)?; + buf.put_slice(b"ERROR\0"); buf.put_u8(b'C'); // SQLSTATE error code - write_cstr(&Bytes::from("CXX000"), buf)?; + buf.put_slice(b"CXX000\0"); buf.put_u8(b'M'); // the message write_cstr(error_msg.as_bytes(), buf)?; buf.put_u8(0); // terminator Ok::<_, io::Error>(()) - }) - .unwrap(); + })?; } // NoticeResponse has the same format as ErrorResponse. From doc: "The frontend should display the @@ -812,23 +700,22 @@ impl<'a> BeMessage<'a> { buf.put_u8(b'N'); write_body(buf, |buf| { buf.put_u8(b'S'); // severity - write_cstr(&Bytes::from("NOTICE"), buf)?; + buf.put_slice(b"NOTICE\0"); buf.put_u8(b'C'); // SQLSTATE error code - write_cstr(&Bytes::from("CXX000"), buf)?; + buf.put_slice(b"CXX000\0"); buf.put_u8(b'M'); // the message write_cstr(error_msg.as_bytes(), buf)?; buf.put_u8(0); // terminator Ok::<_, io::Error>(()) - }) - .unwrap(); + })?; } BeMessage::NoData => { buf.put_u8(b'n'); - write_body(buf, |_| Ok::<(), io::Error>(())).unwrap(); + write_body(buf, |_| {}); } BeMessage::EncryptionResponse(should_negotiate) => { @@ -853,9 +740,7 @@ impl<'a> BeMessage<'a> { buf.put_u8(b'S'); write_body(buf, |buf| { buf.put_slice(&buffer[..cnt]); - Ok::<_, io::Error>(()) - }) - .unwrap(); + }); } BeMessage::ParameterDescription => { @@ -863,23 +748,19 @@ impl<'a> BeMessage<'a> { write_body(buf, |buf| { // we don't support params, so always 0 buf.put_i16(0); - Ok::<_, io::Error>(()) - }) - .unwrap(); + }); } BeMessage::ParseComplete => { buf.put_u8(b'1'); - write_body(buf, |_| Ok::<(), io::Error>(())).unwrap(); + write_body(buf, |_| {}); } BeMessage::ReadyForQuery => { buf.put_u8(b'Z'); write_body(buf, |buf| { buf.put_u8(b'I'); - Ok::<_, io::Error>(()) - }) - .unwrap(); + }); } BeMessage::RowDescription(rows) => { @@ -907,9 +788,7 @@ impl<'a> BeMessage<'a> { buf.put_u64(body.wal_end); buf.put_i64(body.timestamp); buf.put_slice(body.data); - Ok::<_, io::Error>(()) - }) - .unwrap(); + }); } BeMessage::KeepAlive(req) => { @@ -918,10 +797,8 @@ impl<'a> BeMessage<'a> { buf.put_u8(b'k'); buf.put_u64(req.sent_ptr); buf.put_i64(req.timestamp); - buf.put_u8(if req.request_reply { 1u8 } else { 0u8 }); - Ok::<_, io::Error>(()) - }) - .unwrap(); + buf.put_u8(if req.request_reply { 1 } else { 0 }); + }); } } Ok(()) @@ -968,17 +845,17 @@ impl ReplicationFeedback { // value itself pub fn serialize(&self, buf: &mut BytesMut) -> Result<()> { buf.put_u8(REPLICATION_FEEDBACK_FIELDS_NUMBER); // # of keys - write_cstr(&Bytes::from("current_timeline_size"), buf)?; + buf.put_slice(b"current_timeline_size\0"); buf.put_i32(8); buf.put_u64(self.current_timeline_size); - write_cstr(&Bytes::from("ps_writelsn"), buf)?; + buf.put_slice(b"ps_writelsn\0"); buf.put_i32(8); buf.put_u64(self.ps_writelsn); - write_cstr(&Bytes::from("ps_flushlsn"), buf)?; + buf.put_slice(b"ps_flushlsn\0"); buf.put_i32(8); buf.put_u64(self.ps_flushlsn); - write_cstr(&Bytes::from("ps_applylsn"), buf)?; + buf.put_slice(b"ps_applylsn\0"); buf.put_i32(8); buf.put_u64(self.ps_applylsn); @@ -988,7 +865,7 @@ impl ReplicationFeedback { .expect("failed to serialize pg_replytime earlier than PG_EPOCH") .as_micros() as i64; - write_cstr(&Bytes::from("ps_replytime"), buf)?; + buf.put_slice(b"ps_replytime\0"); buf.put_i32(8); buf.put_i64(timestamp); Ok(()) @@ -998,33 +875,30 @@ impl ReplicationFeedback { pub fn parse(mut buf: Bytes) -> ReplicationFeedback { let mut zf = ReplicationFeedback::empty(); let nfields = buf.get_u8(); - let mut i = 0; - while i < nfields { - i += 1; - let key_cstr = read_null_terminated(&mut buf).unwrap(); - let key = cstr_to_str(&key_cstr).unwrap(); - match key { - "current_timeline_size" => { + for _ in 0..nfields { + let key = read_cstr(&mut buf).unwrap(); + match key.as_ref() { + b"current_timeline_size" => { let len = buf.get_i32(); assert_eq!(len, 8); zf.current_timeline_size = buf.get_u64(); } - "ps_writelsn" => { + b"ps_writelsn" => { let len = buf.get_i32(); assert_eq!(len, 8); zf.ps_writelsn = buf.get_u64(); } - "ps_flushlsn" => { + b"ps_flushlsn" => { let len = buf.get_i32(); assert_eq!(len, 8); zf.ps_flushlsn = buf.get_u64(); } - "ps_applylsn" => { + b"ps_applylsn" => { let len = buf.get_i32(); assert_eq!(len, 8); zf.ps_applylsn = buf.get_u64(); } - "ps_replytime" => { + b"ps_replytime" => { let len = buf.get_i32(); assert_eq!(len, 8); let raw_time = buf.get_i64(); @@ -1037,8 +911,8 @@ impl ReplicationFeedback { _ => { let len = buf.get_i32(); warn!( - "ReplicationFeedback parse. unknown key {} of len {}. Skip it.", - key, len + "ReplicationFeedback parse. unknown key {} of len {len}. Skip it.", + String::from_utf8_lossy(key.as_ref()) ); buf.advance(len as usize); } @@ -1084,7 +958,7 @@ mod tests { *first = REPLICATION_FEEDBACK_FIELDS_NUMBER + 1; } - write_cstr(&Bytes::from("new_field_one"), &mut data).unwrap(); + data.put_slice(b"new_field_one\0"); data.put_i32(8); data.put_u64(42);