Minor cleanup in pq_proto (#2322)

This commit is contained in:
Dmitry Ivanov
2022-08-23 18:00:02 +03:00
committed by GitHub
parent 4013290508
commit 8e1d6dd848
2 changed files with 107 additions and 238 deletions

View File

@@ -163,14 +163,9 @@ pub fn is_socket_read_timed_out(error: &anyhow::Error) -> bool {
false
}
// Truncate 0 from C string in Bytes and stringify it (returns slice, no allocations)
// PG protocol strings are always C strings.
fn cstr_to_str(b: &Bytes) -> Result<&str> {
let without_null = if b.last() == Some(&0) {
&b[..b.len() - 1]
} else {
&b[..]
};
// Cast a byte slice to a string slice, dropping null terminator if there's one.
fn cstr_to_str(bytes: &[u8]) -> Result<&str> {
let without_null = bytes.strip_suffix(&[0]).unwrap_or(bytes);
std::str::from_utf8(without_null).map_err(|e| e.into())
}
@@ -423,9 +418,9 @@ impl PostgresBackend {
self.state = ProtoState::Established;
}
FeMessage::Query(m) => {
FeMessage::Query(body) => {
// remove null terminator
let query_string = cstr_to_str(&m.body)?;
let query_string = cstr_to_str(&body)?;
trace!("got query {:?}", query_string);
// xxx distinguish fatal and recoverable errors?

View File

@@ -25,8 +25,10 @@ pub const TEXT_OID: Oid = 25;
#[derive(Debug)]
pub enum FeMessage {
StartupPacket(FeStartupPacket),
Query(FeQueryMessage), // Simple query
Parse(FeParseMessage), // Extended query protocol
// Simple query.
Query(Bytes),
// Extended query protocol.
Parse(FeParseMessage),
Describe(FeDescribeMessage),
Bind(FeBindMessage),
Execute(FeExecuteMessage),
@@ -69,11 +71,6 @@ impl Distribution<CancelKeyData> for Standard {
}
}
#[derive(Debug)]
pub struct FeQueryMessage {
pub body: Bytes,
}
// We only support the simple case of Parse on unnamed prepared statement and
// no params
#[derive(Debug)]
@@ -89,7 +86,7 @@ pub struct FeDescribeMessage {
// we only support unnamed prepared stmt and portal
#[derive(Debug)]
pub struct FeBindMessage {}
pub struct FeBindMessage;
// we only support unnamed prepared stmt or portal
#[derive(Debug)]
@@ -100,7 +97,7 @@ pub struct FeExecuteMessage {
// we only support unnamed prepared stmt and portal
#[derive(Debug)]
pub struct FeCloseMessage {}
pub struct FeCloseMessage;
/// Retry a read on EINTR
///
@@ -163,22 +160,20 @@ impl FeMessage {
Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => return Ok(None),
Err(e) => return Err(e.into()),
};
let len = retry_read!(stream.read_u32().await)?;
// The message length includes itself, so it better be at least 4
let bodylen = len
// The message length includes itself, so it better be at least 4.
let len = retry_read!(stream.read_u32().await)?
.checked_sub(4)
.context("invalid message length: parsing u32")?;
.context("invalid message length")?;
// Read message body
let mut body_buf: Vec<u8> = vec![0; bodylen as usize];
stream.read_exact(&mut body_buf).await?;
let body = {
let mut buffer = vec![0u8; len as usize];
stream.read_exact(&mut buffer).await?;
Bytes::from(buffer)
};
let body = Bytes::from(body_buf);
// Parse it
match tag {
b'Q' => Ok(Some(FeMessage::Query(FeQueryMessage { body }))),
b'Q' => Ok(Some(FeMessage::Query(body))),
b'P' => Ok(Some(FeParseMessage::parse(body)?)),
b'D' => Ok(Some(FeDescribeMessage::parse(body)?)),
b'E' => Ok(Some(FeExecuteMessage::parse(body)?)),
@@ -302,124 +297,71 @@ impl FeStartupPacket {
}
impl FeParseMessage {
pub fn parse(mut buf: Bytes) -> anyhow::Result<FeMessage> {
let _pstmt_name = read_null_terminated(&mut buf)?;
let query_string = read_null_terminated(&mut buf)?;
let nparams = buf.get_i16();
fn parse(mut buf: Bytes) -> anyhow::Result<FeMessage> {
// FIXME: the rust-postgres driver uses a named prepared statement
// for copy_out(). We're not prepared to handle that correctly. For
// now, just ignore the statement name, assuming that the client never
// uses more than one prepared statement at a time.
/*
if !pstmt_name.is_empty() {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"named prepared statements not implemented in Parse",
));
}
*/
if nparams != 0 {
bail!("query params not implemented");
}
let _pstmt_name = read_cstr(&mut buf)?;
let query_string = read_cstr(&mut buf)?;
let nparams = buf.get_i16();
ensure!(nparams == 0, "query params not implemented");
Ok(FeMessage::Parse(FeParseMessage { query_string }))
}
}
impl FeDescribeMessage {
pub fn parse(mut buf: Bytes) -> anyhow::Result<FeMessage> {
fn parse(mut buf: Bytes) -> anyhow::Result<FeMessage> {
let kind = buf.get_u8();
let _pstmt_name = read_null_terminated(&mut buf)?;
let _pstmt_name = read_cstr(&mut buf)?;
// FIXME: see FeParseMessage::parse
/*
if !pstmt_name.is_empty() {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"named prepared statements not implemented in Describe",
));
}
*/
if kind != b'S' {
bail!("only prepared statmement Describe is implemented");
}
ensure!(
kind == b'S',
"only prepared statemement Describe is implemented"
);
Ok(FeMessage::Describe(FeDescribeMessage { kind }))
}
}
impl FeExecuteMessage {
pub fn parse(mut buf: Bytes) -> anyhow::Result<FeMessage> {
let portal_name = read_null_terminated(&mut buf)?;
fn parse(mut buf: Bytes) -> anyhow::Result<FeMessage> {
let portal_name = read_cstr(&mut buf)?;
let maxrows = buf.get_i32();
if !portal_name.is_empty() {
bail!("named portals not implemented");
}
if maxrows != 0 {
bail!("row limit in Execute message not supported");
}
ensure!(portal_name.is_empty(), "named portals not implemented");
ensure!(maxrows == 0, "row limit in Execute message not implemented");
Ok(FeMessage::Execute(FeExecuteMessage { maxrows }))
}
}
impl FeBindMessage {
pub fn parse(mut buf: Bytes) -> anyhow::Result<FeMessage> {
let portal_name = read_null_terminated(&mut buf)?;
let _pstmt_name = read_null_terminated(&mut buf)?;
if !portal_name.is_empty() {
bail!("named portals not implemented");
}
fn parse(mut buf: Bytes) -> anyhow::Result<FeMessage> {
let portal_name = read_cstr(&mut buf)?;
let _pstmt_name = read_cstr(&mut buf)?;
// FIXME: see FeParseMessage::parse
/*
if !pstmt_name.is_empty() {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"named prepared statements not implemented",
));
}
*/
ensure!(portal_name.is_empty(), "named portals not implemented");
Ok(FeMessage::Bind(FeBindMessage {}))
Ok(FeMessage::Bind(FeBindMessage))
}
}
impl FeCloseMessage {
pub fn parse(mut buf: Bytes) -> anyhow::Result<FeMessage> {
fn parse(mut buf: Bytes) -> anyhow::Result<FeMessage> {
let _kind = buf.get_u8();
let _pstmt_or_portal_name = read_null_terminated(&mut buf)?;
let _pstmt_or_portal_name = read_cstr(&mut buf)?;
// FIXME: we do nothing with Close
Ok(FeMessage::Close(FeCloseMessage {}))
Ok(FeMessage::Close(FeCloseMessage))
}
}
fn read_null_terminated(buf: &mut Bytes) -> anyhow::Result<Bytes> {
let mut result = BytesMut::new();
loop {
if !buf.has_remaining() {
bail!("no null-terminator in string");
}
let byte = buf.get_u8();
if byte == 0 {
break;
}
result.put_u8(byte);
}
Ok(result.freeze())
}
// Backend
#[derive(Debug)]
@@ -441,7 +383,7 @@ pub enum BeMessage<'a> {
// None means column is NULL
DataRow(&'a [Option<&'a [u8]>]),
ErrorResponse(&'a str),
// single byte - used in response to SSLRequest/GSSENCRequest
/// Single byte - used in response to SSLRequest/GSSENCRequest.
EncryptionResponse(bool),
NoData,
ParameterDescription,
@@ -554,49 +496,22 @@ pub static SINGLE_COL_ROWDESC: BeMessage = BeMessage::RowDescription(&[RowDescri
formatcode: 0,
}]);
// Safe usize -> i32|i16 conversion, from rust-postgres
trait FromUsize: Sized {
fn from_usize(x: usize) -> Result<Self, io::Error>;
}
macro_rules! from_usize {
($t:ty) => {
impl FromUsize for $t {
#[inline]
fn from_usize(x: usize) -> io::Result<$t> {
if x > <$t>::max_value() as usize {
Err(io::Error::new(
io::ErrorKind::InvalidInput,
"value too large to transmit",
))
} else {
Ok(x as $t)
}
}
}
};
}
from_usize!(i32);
/// Call f() to write body of the message and prepend it with 4-byte len as
/// prescribed by the protocol.
fn write_body<F>(buf: &mut BytesMut, f: F) -> io::Result<()>
where
F: FnOnce(&mut BytesMut) -> io::Result<()>,
{
fn write_body<R>(buf: &mut BytesMut, f: impl FnOnce(&mut BytesMut) -> R) -> R {
let base = buf.len();
buf.extend_from_slice(&[0; 4]);
f(buf)?;
let res = f(buf);
let size = i32::from_usize(buf.len() - base)?;
let size = i32::try_from(buf.len() - base).expect("message too big to transmit");
(&mut buf[base..]).put_slice(&size.to_be_bytes());
Ok(())
res
}
/// Safe write of s into buf as cstring (String in the protocol).
pub fn write_cstr(s: &[u8], buf: &mut BytesMut) -> Result<(), io::Error> {
fn write_cstr(s: &[u8], buf: &mut BytesMut) -> Result<(), io::Error> {
if s.contains(&0) {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
@@ -608,15 +523,11 @@ pub fn write_cstr(s: &[u8], buf: &mut BytesMut) -> Result<(), io::Error> {
Ok(())
}
// Truncate 0 from C string in Bytes and stringify it (returns slice, no allocations)
// PG protocol strings are always C strings.
fn cstr_to_str(b: &Bytes) -> Result<&str> {
let without_null = if b.last() == Some(&0) {
&b[..b.len() - 1]
} else {
&b[..]
};
std::str::from_utf8(without_null).map_err(|e| e.into())
fn read_cstr(buf: &mut Bytes) -> anyhow::Result<Bytes> {
let pos = buf.iter().position(|x| *x == 0);
let result = buf.split_to(pos.context("missing terminator")?);
buf.advance(1); // drop the null terminator
Ok(result)
}
impl<'a> BeMessage<'a> {
@@ -631,18 +542,14 @@ impl<'a> BeMessage<'a> {
buf.put_u8(b'R');
write_body(buf, |buf| {
buf.put_i32(0); // Specifies that the authentication was successful.
Ok::<_, io::Error>(())
})
.unwrap(); // write into BytesMut can't fail
});
}
BeMessage::AuthenticationCleartextPassword => {
buf.put_u8(b'R');
write_body(buf, |buf| {
buf.put_i32(3); // Specifies that clear text password is required.
Ok::<_, io::Error>(())
})
.unwrap(); // write into BytesMut can't fail
});
}
BeMessage::AuthenticationMD5Password(salt) => {
@@ -650,9 +557,7 @@ impl<'a> BeMessage<'a> {
write_body(buf, |buf| {
buf.put_i32(5); // Specifies that an MD5-encrypted password is required.
buf.put_slice(&salt[..]);
Ok::<_, io::Error>(())
})
.unwrap(); // write into BytesMut can't fail
});
}
BeMessage::AuthenticationSasl(msg) => {
@@ -677,8 +582,7 @@ impl<'a> BeMessage<'a> {
}
}
Ok::<_, io::Error>(())
})
.unwrap()
})?;
}
BeMessage::BackendKeyData(key_data) => {
@@ -686,77 +590,64 @@ impl<'a> BeMessage<'a> {
write_body(buf, |buf| {
buf.put_i32(key_data.backend_pid);
buf.put_i32(key_data.cancel_key);
Ok(())
})
.unwrap();
});
}
BeMessage::BindComplete => {
buf.put_u8(b'2');
write_body(buf, |_| Ok::<(), io::Error>(())).unwrap();
write_body(buf, |_| {});
}
BeMessage::CloseComplete => {
buf.put_u8(b'3');
write_body(buf, |_| Ok::<(), io::Error>(())).unwrap();
write_body(buf, |_| {});
}
BeMessage::CommandComplete(cmd) => {
buf.put_u8(b'C');
write_body(buf, |buf| {
write_cstr(cmd, buf)?;
Ok::<_, io::Error>(())
})?;
write_body(buf, |buf| write_cstr(cmd, buf))?;
}
BeMessage::CopyData(data) => {
buf.put_u8(b'd');
write_body(buf, |buf| {
buf.put_slice(data);
Ok::<_, io::Error>(())
})
.unwrap();
});
}
BeMessage::CopyDone => {
buf.put_u8(b'c');
write_body(buf, |_| Ok::<(), io::Error>(())).unwrap();
write_body(buf, |_| {});
}
BeMessage::CopyFail => {
buf.put_u8(b'f');
write_body(buf, |_| Ok::<(), io::Error>(())).unwrap();
write_body(buf, |_| {});
}
BeMessage::CopyInResponse => {
buf.put_u8(b'G');
write_body(buf, |buf| {
buf.put_u8(1); /* copy_is_binary */
buf.put_i16(0); /* numAttributes */
Ok::<_, io::Error>(())
})
.unwrap();
buf.put_u8(1); // copy_is_binary
buf.put_i16(0); // numAttributes
});
}
BeMessage::CopyOutResponse => {
buf.put_u8(b'H');
write_body(buf, |buf| {
buf.put_u8(0); /* copy_is_binary */
buf.put_i16(0); /* numAttributes */
Ok::<_, io::Error>(())
})
.unwrap();
buf.put_u8(0); // copy_is_binary
buf.put_i16(0); // numAttributes
});
}
BeMessage::CopyBothResponse => {
buf.put_u8(b'W');
write_body(buf, |buf| {
// doesn't matter, used only for replication
buf.put_u8(0); /* copy_is_binary */
buf.put_i16(0); /* numAttributes */
Ok::<_, io::Error>(())
})
.unwrap();
buf.put_u8(0); // copy_is_binary
buf.put_i16(0); // numAttributes
});
}
BeMessage::DataRow(vals) => {
@@ -771,9 +662,7 @@ impl<'a> BeMessage<'a> {
buf.put_i32(-1);
}
}
Ok::<_, io::Error>(())
})
.unwrap();
});
}
// ErrorResponse is a zero-terminated array of zero-terminated fields.
@@ -788,18 +677,17 @@ impl<'a> BeMessage<'a> {
buf.put_u8(b'E');
write_body(buf, |buf| {
buf.put_u8(b'S'); // severity
write_cstr(&Bytes::from("ERROR"), buf)?;
buf.put_slice(b"ERROR\0");
buf.put_u8(b'C'); // SQLSTATE error code
write_cstr(&Bytes::from("CXX000"), buf)?;
buf.put_slice(b"CXX000\0");
buf.put_u8(b'M'); // the message
write_cstr(error_msg.as_bytes(), buf)?;
buf.put_u8(0); // terminator
Ok::<_, io::Error>(())
})
.unwrap();
})?;
}
// NoticeResponse has the same format as ErrorResponse. From doc: "The frontend should display the
@@ -812,23 +700,22 @@ impl<'a> BeMessage<'a> {
buf.put_u8(b'N');
write_body(buf, |buf| {
buf.put_u8(b'S'); // severity
write_cstr(&Bytes::from("NOTICE"), buf)?;
buf.put_slice(b"NOTICE\0");
buf.put_u8(b'C'); // SQLSTATE error code
write_cstr(&Bytes::from("CXX000"), buf)?;
buf.put_slice(b"CXX000\0");
buf.put_u8(b'M'); // the message
write_cstr(error_msg.as_bytes(), buf)?;
buf.put_u8(0); // terminator
Ok::<_, io::Error>(())
})
.unwrap();
})?;
}
BeMessage::NoData => {
buf.put_u8(b'n');
write_body(buf, |_| Ok::<(), io::Error>(())).unwrap();
write_body(buf, |_| {});
}
BeMessage::EncryptionResponse(should_negotiate) => {
@@ -853,9 +740,7 @@ impl<'a> BeMessage<'a> {
buf.put_u8(b'S');
write_body(buf, |buf| {
buf.put_slice(&buffer[..cnt]);
Ok::<_, io::Error>(())
})
.unwrap();
});
}
BeMessage::ParameterDescription => {
@@ -863,23 +748,19 @@ impl<'a> BeMessage<'a> {
write_body(buf, |buf| {
// we don't support params, so always 0
buf.put_i16(0);
Ok::<_, io::Error>(())
})
.unwrap();
});
}
BeMessage::ParseComplete => {
buf.put_u8(b'1');
write_body(buf, |_| Ok::<(), io::Error>(())).unwrap();
write_body(buf, |_| {});
}
BeMessage::ReadyForQuery => {
buf.put_u8(b'Z');
write_body(buf, |buf| {
buf.put_u8(b'I');
Ok::<_, io::Error>(())
})
.unwrap();
});
}
BeMessage::RowDescription(rows) => {
@@ -907,9 +788,7 @@ impl<'a> BeMessage<'a> {
buf.put_u64(body.wal_end);
buf.put_i64(body.timestamp);
buf.put_slice(body.data);
Ok::<_, io::Error>(())
})
.unwrap();
});
}
BeMessage::KeepAlive(req) => {
@@ -918,10 +797,8 @@ impl<'a> BeMessage<'a> {
buf.put_u8(b'k');
buf.put_u64(req.sent_ptr);
buf.put_i64(req.timestamp);
buf.put_u8(if req.request_reply { 1u8 } else { 0u8 });
Ok::<_, io::Error>(())
})
.unwrap();
buf.put_u8(if req.request_reply { 1 } else { 0 });
});
}
}
Ok(())
@@ -968,17 +845,17 @@ impl ReplicationFeedback {
// value itself
pub fn serialize(&self, buf: &mut BytesMut) -> Result<()> {
buf.put_u8(REPLICATION_FEEDBACK_FIELDS_NUMBER); // # of keys
write_cstr(&Bytes::from("current_timeline_size"), buf)?;
buf.put_slice(b"current_timeline_size\0");
buf.put_i32(8);
buf.put_u64(self.current_timeline_size);
write_cstr(&Bytes::from("ps_writelsn"), buf)?;
buf.put_slice(b"ps_writelsn\0");
buf.put_i32(8);
buf.put_u64(self.ps_writelsn);
write_cstr(&Bytes::from("ps_flushlsn"), buf)?;
buf.put_slice(b"ps_flushlsn\0");
buf.put_i32(8);
buf.put_u64(self.ps_flushlsn);
write_cstr(&Bytes::from("ps_applylsn"), buf)?;
buf.put_slice(b"ps_applylsn\0");
buf.put_i32(8);
buf.put_u64(self.ps_applylsn);
@@ -988,7 +865,7 @@ impl ReplicationFeedback {
.expect("failed to serialize pg_replytime earlier than PG_EPOCH")
.as_micros() as i64;
write_cstr(&Bytes::from("ps_replytime"), buf)?;
buf.put_slice(b"ps_replytime\0");
buf.put_i32(8);
buf.put_i64(timestamp);
Ok(())
@@ -998,33 +875,30 @@ impl ReplicationFeedback {
pub fn parse(mut buf: Bytes) -> ReplicationFeedback {
let mut zf = ReplicationFeedback::empty();
let nfields = buf.get_u8();
let mut i = 0;
while i < nfields {
i += 1;
let key_cstr = read_null_terminated(&mut buf).unwrap();
let key = cstr_to_str(&key_cstr).unwrap();
match key {
"current_timeline_size" => {
for _ in 0..nfields {
let key = read_cstr(&mut buf).unwrap();
match key.as_ref() {
b"current_timeline_size" => {
let len = buf.get_i32();
assert_eq!(len, 8);
zf.current_timeline_size = buf.get_u64();
}
"ps_writelsn" => {
b"ps_writelsn" => {
let len = buf.get_i32();
assert_eq!(len, 8);
zf.ps_writelsn = buf.get_u64();
}
"ps_flushlsn" => {
b"ps_flushlsn" => {
let len = buf.get_i32();
assert_eq!(len, 8);
zf.ps_flushlsn = buf.get_u64();
}
"ps_applylsn" => {
b"ps_applylsn" => {
let len = buf.get_i32();
assert_eq!(len, 8);
zf.ps_applylsn = buf.get_u64();
}
"ps_replytime" => {
b"ps_replytime" => {
let len = buf.get_i32();
assert_eq!(len, 8);
let raw_time = buf.get_i64();
@@ -1037,8 +911,8 @@ impl ReplicationFeedback {
_ => {
let len = buf.get_i32();
warn!(
"ReplicationFeedback parse. unknown key {} of len {}. Skip it.",
key, len
"ReplicationFeedback parse. unknown key {} of len {len}. Skip it.",
String::from_utf8_lossy(key.as_ref())
);
buf.advance(len as usize);
}
@@ -1084,7 +958,7 @@ mod tests {
*first = REPLICATION_FEEDBACK_FIELDS_NUMBER + 1;
}
write_cstr(&Bytes::from("new_field_one"), &mut data).unwrap();
data.put_slice(b"new_field_one\0");
data.put_i32(8);
data.put_u64(42);