chore(proxy): vendor a subset of rust-postgres (#9930)

Our rust-postgres fork is getting messy. Mostly because proxy wants more
control over the raw protocol than tokio-postgres provides. As such,
it's diverging more and more. Storage and compute also make use of
rust-postgres, but in more normal usage, thus they don't need our crazy
changes.

Idea: 
* proxy maintains their subset
* other teams use a minimal patch set against upstream rust-postgres

Reviewing this code will be difficult. To implement it, I
1. Copied tokio-postgres, postgres-protocol and postgres-types from
00940fcdb5
2. Updated their package names with the `2` suffix to make them compile
in the workspace.
3. Updated proxy to use those packages
4. Copied in the code from tokio-postgres-rustls 0.13 (with some patches
applied https://github.com/jbg/tokio-postgres-rustls/pull/32
https://github.com/jbg/tokio-postgres-rustls/pull/33)
5. Removed as much dead code as I could find in the vendored libraries
6. Updated the tokio-postgres-rustls code to use our existing channel
binding implementation
This commit is contained in:
Conrad Ludgate
2024-11-29 11:08:01 +00:00
committed by GitHub
parent 3ffe6de0b9
commit 1d642d6a57
58 changed files with 11199 additions and 26 deletions

View File

@@ -0,0 +1,21 @@
[package]
name = "postgres-protocol2"
version = "0.1.0"
edition = "2018"
license = "MIT/Apache-2.0"
[dependencies]
base64 = "0.20"
byteorder.workspace = true
bytes.workspace = true
fallible-iterator.workspace = true
hmac.workspace = true
md-5 = "0.10"
memchr = "2.0"
rand.workspace = true
sha2.workspace = true
stringprep = "0.1"
tokio = { workspace = true, features = ["rt"] }
[dev-dependencies]
tokio = { workspace = true, features = ["full"] }

View File

@@ -0,0 +1,37 @@
//! Authentication protocol support.
use md5::{Digest, Md5};
pub mod sasl;
/// Hashes authentication information in a way suitable for use in response
/// to an `AuthenticationMd5Password` message.
///
/// The resulting string should be sent back to the database in a
/// `PasswordMessage` message.
#[inline]
pub fn md5_hash(username: &[u8], password: &[u8], salt: [u8; 4]) -> String {
let mut md5 = Md5::new();
md5.update(password);
md5.update(username);
let output = md5.finalize_reset();
md5.update(format!("{:x}", output));
md5.update(salt);
format!("md5{:x}", md5.finalize())
}
#[cfg(test)]
mod test {
use super::*;
#[test]
fn md5() {
let username = b"md5_user";
let password = b"password";
let salt = [0x2a, 0x3d, 0x8f, 0xe0];
assert_eq!(
md5_hash(username, password, salt),
"md562af4dd09bbb41884907a838a3233294"
);
}
}

View File

@@ -0,0 +1,516 @@
//! SASL-based authentication support.
use hmac::{Hmac, Mac};
use rand::{self, Rng};
use sha2::digest::FixedOutput;
use sha2::{Digest, Sha256};
use std::fmt::Write;
use std::io;
use std::iter;
use std::mem;
use std::str;
use tokio::task::yield_now;
const NONCE_LENGTH: usize = 24;
/// The identifier of the SCRAM-SHA-256 SASL authentication mechanism.
pub const SCRAM_SHA_256: &str = "SCRAM-SHA-256";
/// The identifier of the SCRAM-SHA-256-PLUS SASL authentication mechanism.
pub const SCRAM_SHA_256_PLUS: &str = "SCRAM-SHA-256-PLUS";
// since postgres passwords are not required to exclude saslprep-prohibited
// characters or even be valid UTF8, we run saslprep if possible and otherwise
// return the raw password.
fn normalize(pass: &[u8]) -> Vec<u8> {
let pass = match str::from_utf8(pass) {
Ok(pass) => pass,
Err(_) => return pass.to_vec(),
};
match stringprep::saslprep(pass) {
Ok(pass) => pass.into_owned().into_bytes(),
Err(_) => pass.as_bytes().to_vec(),
}
}
pub(crate) async fn hi(str: &[u8], salt: &[u8], iterations: u32) -> [u8; 32] {
let mut hmac =
Hmac::<Sha256>::new_from_slice(str).expect("HMAC is able to accept all key sizes");
hmac.update(salt);
hmac.update(&[0, 0, 0, 1]);
let mut prev = hmac.finalize().into_bytes();
let mut hi = prev;
for i in 1..iterations {
let mut hmac = Hmac::<Sha256>::new_from_slice(str).expect("already checked above");
hmac.update(&prev);
prev = hmac.finalize().into_bytes();
for (hi, prev) in hi.iter_mut().zip(prev) {
*hi ^= prev;
}
// yield every ~250us
// hopefully reduces tail latencies
if i % 1024 == 0 {
yield_now().await
}
}
hi.into()
}
enum ChannelBindingInner {
Unrequested,
Unsupported,
TlsServerEndPoint(Vec<u8>),
}
/// The channel binding configuration for a SCRAM authentication exchange.
pub struct ChannelBinding(ChannelBindingInner);
impl ChannelBinding {
/// The server did not request channel binding.
pub fn unrequested() -> ChannelBinding {
ChannelBinding(ChannelBindingInner::Unrequested)
}
/// The server requested channel binding but the client is unable to provide it.
pub fn unsupported() -> ChannelBinding {
ChannelBinding(ChannelBindingInner::Unsupported)
}
/// The server requested channel binding and the client will use the `tls-server-end-point`
/// method.
pub fn tls_server_end_point(signature: Vec<u8>) -> ChannelBinding {
ChannelBinding(ChannelBindingInner::TlsServerEndPoint(signature))
}
fn gs2_header(&self) -> &'static str {
match self.0 {
ChannelBindingInner::Unrequested => "y,,",
ChannelBindingInner::Unsupported => "n,,",
ChannelBindingInner::TlsServerEndPoint(_) => "p=tls-server-end-point,,",
}
}
fn cbind_data(&self) -> &[u8] {
match self.0 {
ChannelBindingInner::Unrequested | ChannelBindingInner::Unsupported => &[],
ChannelBindingInner::TlsServerEndPoint(ref buf) => buf,
}
}
}
/// A pair of keys for the SCRAM-SHA-256 mechanism.
/// See <https://datatracker.ietf.org/doc/html/rfc5802#section-3> for details.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ScramKeys<const N: usize> {
/// Used by server to authenticate client.
pub client_key: [u8; N],
/// Used by client to verify server's signature.
pub server_key: [u8; N],
}
/// Password or keys which were derived from it.
enum Credentials<const N: usize> {
/// A regular password as a vector of bytes.
Password(Vec<u8>),
/// A precomputed pair of keys.
Keys(Box<ScramKeys<N>>),
}
enum State {
Update {
nonce: String,
password: Credentials<32>,
channel_binding: ChannelBinding,
},
Finish {
server_key: [u8; 32],
auth_message: String,
},
Done,
}
/// A type which handles the client side of the SCRAM-SHA-256/SCRAM-SHA-256-PLUS authentication
/// process.
///
/// During the authentication process, if the backend sends an `AuthenticationSASL` message which
/// includes `SCRAM-SHA-256` as an authentication mechanism, this type can be used.
///
/// After a `ScramSha256` is constructed, the buffer returned by the `message()` method should be
/// sent to the backend in a `SASLInitialResponse` message along with the mechanism name.
///
/// The server will reply with an `AuthenticationSASLContinue` message. Its contents should be
/// passed to the `update()` method, after which the buffer returned by the `message()` method
/// should be sent to the backend in a `SASLResponse` message.
///
/// The server will reply with an `AuthenticationSASLFinal` message. Its contents should be passed
/// to the `finish()` method, after which the authentication process is complete.
pub struct ScramSha256 {
message: String,
state: State,
}
fn nonce() -> String {
// rand 0.5's ThreadRng is cryptographically secure
let mut rng = rand::thread_rng();
(0..NONCE_LENGTH)
.map(|_| {
let mut v = rng.gen_range(0x21u8..0x7e);
if v == 0x2c {
v = 0x7e
}
v as char
})
.collect()
}
impl ScramSha256 {
/// Constructs a new instance which will use the provided password for authentication.
pub fn new(password: &[u8], channel_binding: ChannelBinding) -> ScramSha256 {
let password = Credentials::Password(normalize(password));
ScramSha256::new_inner(password, channel_binding, nonce())
}
/// Constructs a new instance which will use the provided key pair for authentication.
pub fn new_with_keys(keys: ScramKeys<32>, channel_binding: ChannelBinding) -> ScramSha256 {
let password = Credentials::Keys(keys.into());
ScramSha256::new_inner(password, channel_binding, nonce())
}
fn new_inner(
password: Credentials<32>,
channel_binding: ChannelBinding,
nonce: String,
) -> ScramSha256 {
ScramSha256 {
message: format!("{}n=,r={}", channel_binding.gs2_header(), nonce),
state: State::Update {
nonce,
password,
channel_binding,
},
}
}
/// Returns the message which should be sent to the backend in an `SASLResponse` message.
pub fn message(&self) -> &[u8] {
if let State::Done = self.state {
panic!("invalid SCRAM state");
}
self.message.as_bytes()
}
/// Updates the state machine with the response from the backend.
///
/// This should be called when an `AuthenticationSASLContinue` message is received.
pub async fn update(&mut self, message: &[u8]) -> io::Result<()> {
let (client_nonce, password, channel_binding) =
match mem::replace(&mut self.state, State::Done) {
State::Update {
nonce,
password,
channel_binding,
} => (nonce, password, channel_binding),
_ => return Err(io::Error::new(io::ErrorKind::Other, "invalid SCRAM state")),
};
let message =
str::from_utf8(message).map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?;
let parsed = Parser::new(message).server_first_message()?;
if !parsed.nonce.starts_with(&client_nonce) {
return Err(io::Error::new(io::ErrorKind::InvalidInput, "invalid nonce"));
}
let (client_key, server_key) = match password {
Credentials::Password(password) => {
let salt = match base64::decode(parsed.salt) {
Ok(salt) => salt,
Err(e) => return Err(io::Error::new(io::ErrorKind::InvalidInput, e)),
};
let salted_password = hi(&password, &salt, parsed.iteration_count).await;
let make_key = |name| {
let mut hmac = Hmac::<Sha256>::new_from_slice(&salted_password)
.expect("HMAC is able to accept all key sizes");
hmac.update(name);
let mut key = [0u8; 32];
key.copy_from_slice(hmac.finalize().into_bytes().as_slice());
key
};
(make_key(b"Client Key"), make_key(b"Server Key"))
}
Credentials::Keys(keys) => (keys.client_key, keys.server_key),
};
let mut hash = Sha256::default();
hash.update(client_key);
let stored_key = hash.finalize_fixed();
let mut cbind_input = vec![];
cbind_input.extend(channel_binding.gs2_header().as_bytes());
cbind_input.extend(channel_binding.cbind_data());
let cbind_input = base64::encode(&cbind_input);
self.message.clear();
write!(&mut self.message, "c={},r={}", cbind_input, parsed.nonce).unwrap();
let auth_message = format!("n=,r={},{},{}", client_nonce, message, self.message);
let mut hmac = Hmac::<Sha256>::new_from_slice(&stored_key)
.expect("HMAC is able to accept all key sizes");
hmac.update(auth_message.as_bytes());
let client_signature = hmac.finalize().into_bytes();
let mut client_proof = client_key;
for (proof, signature) in client_proof.iter_mut().zip(client_signature) {
*proof ^= signature;
}
write!(&mut self.message, ",p={}", base64::encode(client_proof)).unwrap();
self.state = State::Finish {
server_key,
auth_message,
};
Ok(())
}
/// Finalizes the authentication process.
///
/// This should be called when the backend sends an `AuthenticationSASLFinal` message.
/// Authentication has only succeeded if this method returns `Ok(())`.
pub fn finish(&mut self, message: &[u8]) -> io::Result<()> {
let (server_key, auth_message) = match mem::replace(&mut self.state, State::Done) {
State::Finish {
server_key,
auth_message,
} => (server_key, auth_message),
_ => return Err(io::Error::new(io::ErrorKind::Other, "invalid SCRAM state")),
};
let message =
str::from_utf8(message).map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?;
let parsed = Parser::new(message).server_final_message()?;
let verifier = match parsed {
ServerFinalMessage::Error(e) => {
return Err(io::Error::new(
io::ErrorKind::Other,
format!("SCRAM error: {}", e),
));
}
ServerFinalMessage::Verifier(verifier) => verifier,
};
let verifier = match base64::decode(verifier) {
Ok(verifier) => verifier,
Err(e) => return Err(io::Error::new(io::ErrorKind::InvalidInput, e)),
};
let mut hmac = Hmac::<Sha256>::new_from_slice(&server_key)
.expect("HMAC is able to accept all key sizes");
hmac.update(auth_message.as_bytes());
hmac.verify_slice(&verifier)
.map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "SCRAM verification error"))
}
}
struct Parser<'a> {
s: &'a str,
it: iter::Peekable<str::CharIndices<'a>>,
}
impl<'a> Parser<'a> {
fn new(s: &'a str) -> Parser<'a> {
Parser {
s,
it: s.char_indices().peekable(),
}
}
fn eat(&mut self, target: char) -> io::Result<()> {
match self.it.next() {
Some((_, c)) if c == target => Ok(()),
Some((i, c)) => {
let m = format!(
"unexpected character at byte {}: expected `{}` but got `{}",
i, target, c
);
Err(io::Error::new(io::ErrorKind::InvalidInput, m))
}
None => Err(io::Error::new(
io::ErrorKind::UnexpectedEof,
"unexpected EOF",
)),
}
}
fn take_while<F>(&mut self, f: F) -> io::Result<&'a str>
where
F: Fn(char) -> bool,
{
let start = match self.it.peek() {
Some(&(i, _)) => i,
None => return Ok(""),
};
loop {
match self.it.peek() {
Some(&(_, c)) if f(c) => {
self.it.next();
}
Some(&(i, _)) => return Ok(&self.s[start..i]),
None => return Ok(&self.s[start..]),
}
}
}
fn printable(&mut self) -> io::Result<&'a str> {
self.take_while(|c| matches!(c, '\x21'..='\x2b' | '\x2d'..='\x7e'))
}
fn nonce(&mut self) -> io::Result<&'a str> {
self.eat('r')?;
self.eat('=')?;
self.printable()
}
fn base64(&mut self) -> io::Result<&'a str> {
self.take_while(|c| matches!(c, 'a'..='z' | 'A'..='Z' | '0'..='9' | '/' | '+' | '='))
}
fn salt(&mut self) -> io::Result<&'a str> {
self.eat('s')?;
self.eat('=')?;
self.base64()
}
fn posit_number(&mut self) -> io::Result<u32> {
let n = self.take_while(|c| c.is_ascii_digit())?;
n.parse()
.map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))
}
fn iteration_count(&mut self) -> io::Result<u32> {
self.eat('i')?;
self.eat('=')?;
self.posit_number()
}
fn eof(&mut self) -> io::Result<()> {
match self.it.peek() {
Some(&(i, _)) => Err(io::Error::new(
io::ErrorKind::InvalidInput,
format!("unexpected trailing data at byte {}", i),
)),
None => Ok(()),
}
}
fn server_first_message(&mut self) -> io::Result<ServerFirstMessage<'a>> {
let nonce = self.nonce()?;
self.eat(',')?;
let salt = self.salt()?;
self.eat(',')?;
let iteration_count = self.iteration_count()?;
self.eof()?;
Ok(ServerFirstMessage {
nonce,
salt,
iteration_count,
})
}
fn value(&mut self) -> io::Result<&'a str> {
self.take_while(|c| matches!(c, '\0' | '=' | ','))
}
fn server_error(&mut self) -> io::Result<Option<&'a str>> {
match self.it.peek() {
Some(&(_, 'e')) => {}
_ => return Ok(None),
}
self.eat('e')?;
self.eat('=')?;
self.value().map(Some)
}
fn verifier(&mut self) -> io::Result<&'a str> {
self.eat('v')?;
self.eat('=')?;
self.base64()
}
fn server_final_message(&mut self) -> io::Result<ServerFinalMessage<'a>> {
let message = match self.server_error()? {
Some(error) => ServerFinalMessage::Error(error),
None => ServerFinalMessage::Verifier(self.verifier()?),
};
self.eof()?;
Ok(message)
}
}
struct ServerFirstMessage<'a> {
nonce: &'a str,
salt: &'a str,
iteration_count: u32,
}
enum ServerFinalMessage<'a> {
Error(&'a str),
Verifier(&'a str),
}
#[cfg(test)]
mod test {
use super::*;
#[test]
fn parse_server_first_message() {
let message = "r=fyko+d2lbbFgONRv9qkxdawL3rfcNHYJY1ZVvWVs7j,s=QSXCR+Q6sek8bf92,i=4096";
let message = Parser::new(message).server_first_message().unwrap();
assert_eq!(message.nonce, "fyko+d2lbbFgONRv9qkxdawL3rfcNHYJY1ZVvWVs7j");
assert_eq!(message.salt, "QSXCR+Q6sek8bf92");
assert_eq!(message.iteration_count, 4096);
}
// recorded auth exchange from psql
#[tokio::test]
async fn exchange() {
let password = "foobar";
let nonce = "9IZ2O01zb9IgiIZ1WJ/zgpJB";
let client_first = "n,,n=,r=9IZ2O01zb9IgiIZ1WJ/zgpJB";
let server_first =
"r=9IZ2O01zb9IgiIZ1WJ/zgpJBjx/oIRLs02gGSHcw1KEty3eY,s=fs3IXBy7U7+IvVjZ,i\
=4096";
let client_final =
"c=biws,r=9IZ2O01zb9IgiIZ1WJ/zgpJBjx/oIRLs02gGSHcw1KEty3eY,p=AmNKosjJzS3\
1NTlQYNs5BTeQjdHdk7lOflDo5re2an8=";
let server_final = "v=U+ppxD5XUKtradnv8e2MkeupiA8FU87Sg8CXzXHDAzw=";
let mut scram = ScramSha256::new_inner(
Credentials::Password(normalize(password.as_bytes())),
ChannelBinding::unsupported(),
nonce.to_string(),
);
assert_eq!(str::from_utf8(scram.message()).unwrap(), client_first);
scram.update(server_first.as_bytes()).await.unwrap();
assert_eq!(str::from_utf8(scram.message()).unwrap(), client_final);
scram.finish(server_final.as_bytes()).unwrap();
}
}

View File

@@ -0,0 +1,93 @@
//! Provides functions for escaping literals and identifiers for use
//! in SQL queries.
//!
//! Prefer parameterized queries where possible. Do not escape
//! parameters in a parameterized query.
#[cfg(test)]
mod test;
/// Escape a literal and surround result with single quotes. Not
/// recommended in most cases.
///
/// If input contains backslashes, result will be of the form `
/// E'...'` so it is safe to use regardless of the setting of
/// standard_conforming_strings.
pub fn escape_literal(input: &str) -> String {
escape_internal(input, false)
}
/// Escape an identifier and surround result with double quotes.
pub fn escape_identifier(input: &str) -> String {
escape_internal(input, true)
}
// Translation of PostgreSQL libpq's PQescapeInternal(). Does not
// require a connection because input string is known to be valid
// UTF-8.
//
// Escape arbitrary strings. If as_ident is true, we escape the
// result as an identifier; if false, as a literal. The result is
// returned in a newly allocated buffer. If we fail due to an
// encoding violation or out of memory condition, we return NULL,
// storing an error message into conn.
fn escape_internal(input: &str, as_ident: bool) -> String {
let mut num_backslashes = 0;
let mut num_quotes = 0;
let quote_char = if as_ident { '"' } else { '\'' };
// Scan the string for characters that must be escaped.
for ch in input.chars() {
if ch == quote_char {
num_quotes += 1;
} else if ch == '\\' {
num_backslashes += 1;
}
}
// Allocate output String.
let mut result_size = input.len() + num_quotes + 3; // two quotes, plus a NUL
if !as_ident && num_backslashes > 0 {
result_size += num_backslashes + 2;
}
let mut output = String::with_capacity(result_size);
// If we are escaping a literal that contains backslashes, we use
// the escape string syntax so that the result is correct under
// either value of standard_conforming_strings. We also emit a
// leading space in this case, to guard against the possibility
// that the result might be interpolated immediately following an
// identifier.
if !as_ident && num_backslashes > 0 {
output.push(' ');
output.push('E');
}
// Opening quote.
output.push(quote_char);
// Use fast path if possible.
//
// We've already verified that the input string is well-formed in
// the current encoding. If it contains no quotes and, in the
// case of literal-escaping, no backslashes, then we can just copy
// it directly to the output buffer, adding the necessary quotes.
//
// If not, we must rescan the input and process each character
// individually.
if num_quotes == 0 && (num_backslashes == 0 || as_ident) {
output.push_str(input);
} else {
for ch in input.chars() {
if ch == quote_char || (!as_ident && ch == '\\') {
output.push(ch);
}
output.push(ch);
}
}
output.push(quote_char);
output
}

View File

@@ -0,0 +1,17 @@
use crate::escape::{escape_identifier, escape_literal};
#[test]
fn test_escape_idenifier() {
assert_eq!(escape_identifier("foo"), String::from("\"foo\""));
assert_eq!(escape_identifier("f\\oo"), String::from("\"f\\oo\""));
assert_eq!(escape_identifier("f'oo"), String::from("\"f'oo\""));
assert_eq!(escape_identifier("f\"oo"), String::from("\"f\"\"oo\""));
}
#[test]
fn test_escape_literal() {
assert_eq!(escape_literal("foo"), String::from("'foo'"));
assert_eq!(escape_literal("f\\oo"), String::from(" E'f\\\\oo'"));
assert_eq!(escape_literal("f'oo"), String::from("'f''oo'"));
assert_eq!(escape_literal("f\"oo"), String::from("'f\"oo'"));
}

View File

@@ -0,0 +1,78 @@
//! Low level Postgres protocol APIs.
//!
//! This crate implements the low level components of Postgres's communication
//! protocol, including message and value serialization and deserialization.
//! It is designed to be used as a building block by higher level APIs such as
//! `rust-postgres`, and should not typically be used directly.
//!
//! # Note
//!
//! This library assumes that the `client_encoding` backend parameter has been
//! set to `UTF8`. It will most likely not behave properly if that is not the case.
#![doc(html_root_url = "https://docs.rs/postgres-protocol/0.6")]
#![warn(missing_docs, rust_2018_idioms, clippy::all)]
use byteorder::{BigEndian, ByteOrder};
use bytes::{BufMut, BytesMut};
use std::io;
pub mod authentication;
pub mod escape;
pub mod message;
pub mod password;
pub mod types;
/// A Postgres OID.
pub type Oid = u32;
/// A Postgres Log Sequence Number (LSN).
pub type Lsn = u64;
/// An enum indicating if a value is `NULL` or not.
pub enum IsNull {
/// The value is `NULL`.
Yes,
/// The value is not `NULL`.
No,
}
fn write_nullable<F, E>(serializer: F, buf: &mut BytesMut) -> Result<(), E>
where
F: FnOnce(&mut BytesMut) -> Result<IsNull, E>,
E: From<io::Error>,
{
let base = buf.len();
buf.put_i32(0);
let size = match serializer(buf)? {
IsNull::No => i32::from_usize(buf.len() - base - 4)?,
IsNull::Yes => -1,
};
BigEndian::write_i32(&mut buf[base..], size);
Ok(())
}
trait FromUsize: Sized {
fn from_usize(x: usize) -> Result<Self, io::Error>;
}
macro_rules! from_usize {
($t:ty) => {
impl FromUsize for $t {
#[inline]
fn from_usize(x: usize) -> io::Result<$t> {
if x > <$t>::MAX as usize {
Err(io::Error::new(
io::ErrorKind::InvalidInput,
"value too large to transmit",
))
} else {
Ok(x as $t)
}
}
}
};
}
from_usize!(i16);
from_usize!(i32);

View File

@@ -0,0 +1,766 @@
#![allow(missing_docs)]
use byteorder::{BigEndian, ByteOrder, ReadBytesExt};
use bytes::{Bytes, BytesMut};
use fallible_iterator::FallibleIterator;
use memchr::memchr;
use std::cmp;
use std::io::{self, Read};
use std::ops::Range;
use std::str;
use crate::Oid;
// top-level message tags
const PARSE_COMPLETE_TAG: u8 = b'1';
const BIND_COMPLETE_TAG: u8 = b'2';
const CLOSE_COMPLETE_TAG: u8 = b'3';
pub const NOTIFICATION_RESPONSE_TAG: u8 = b'A';
const COPY_DONE_TAG: u8 = b'c';
const COMMAND_COMPLETE_TAG: u8 = b'C';
const COPY_DATA_TAG: u8 = b'd';
const DATA_ROW_TAG: u8 = b'D';
const ERROR_RESPONSE_TAG: u8 = b'E';
const COPY_IN_RESPONSE_TAG: u8 = b'G';
const COPY_OUT_RESPONSE_TAG: u8 = b'H';
const COPY_BOTH_RESPONSE_TAG: u8 = b'W';
const EMPTY_QUERY_RESPONSE_TAG: u8 = b'I';
const BACKEND_KEY_DATA_TAG: u8 = b'K';
pub const NO_DATA_TAG: u8 = b'n';
pub const NOTICE_RESPONSE_TAG: u8 = b'N';
const AUTHENTICATION_TAG: u8 = b'R';
const PORTAL_SUSPENDED_TAG: u8 = b's';
pub const PARAMETER_STATUS_TAG: u8 = b'S';
const PARAMETER_DESCRIPTION_TAG: u8 = b't';
const ROW_DESCRIPTION_TAG: u8 = b'T';
pub const READY_FOR_QUERY_TAG: u8 = b'Z';
#[derive(Debug, Copy, Clone)]
pub struct Header {
tag: u8,
len: i32,
}
#[allow(clippy::len_without_is_empty)]
impl Header {
#[inline]
pub fn parse(buf: &[u8]) -> io::Result<Option<Header>> {
if buf.len() < 5 {
return Ok(None);
}
let tag = buf[0];
let len = BigEndian::read_i32(&buf[1..]);
if len < 4 {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
"invalid message length: header length < 4",
));
}
Ok(Some(Header { tag, len }))
}
#[inline]
pub fn tag(self) -> u8 {
self.tag
}
#[inline]
pub fn len(self) -> i32 {
self.len
}
}
/// An enum representing Postgres backend messages.
#[non_exhaustive]
pub enum Message {
AuthenticationCleartextPassword,
AuthenticationGss,
AuthenticationKerberosV5,
AuthenticationMd5Password(AuthenticationMd5PasswordBody),
AuthenticationOk,
AuthenticationScmCredential,
AuthenticationSspi,
AuthenticationGssContinue,
AuthenticationSasl(AuthenticationSaslBody),
AuthenticationSaslContinue(AuthenticationSaslContinueBody),
AuthenticationSaslFinal(AuthenticationSaslFinalBody),
BackendKeyData(BackendKeyDataBody),
BindComplete,
CloseComplete,
CommandComplete(CommandCompleteBody),
CopyData,
CopyDone,
CopyInResponse,
CopyOutResponse,
CopyBothResponse,
DataRow(DataRowBody),
EmptyQueryResponse,
ErrorResponse(ErrorResponseBody),
NoData,
NoticeResponse(NoticeResponseBody),
NotificationResponse(NotificationResponseBody),
ParameterDescription(ParameterDescriptionBody),
ParameterStatus(ParameterStatusBody),
ParseComplete,
PortalSuspended,
ReadyForQuery(ReadyForQueryBody),
RowDescription(RowDescriptionBody),
}
impl Message {
#[inline]
pub fn parse(buf: &mut BytesMut) -> io::Result<Option<Message>> {
if buf.len() < 5 {
let to_read = 5 - buf.len();
buf.reserve(to_read);
return Ok(None);
}
let tag = buf[0];
let len = (&buf[1..5]).read_u32::<BigEndian>().unwrap();
if len < 4 {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"invalid message length: parsing u32",
));
}
let total_len = len as usize + 1;
if buf.len() < total_len {
let to_read = total_len - buf.len();
buf.reserve(to_read);
return Ok(None);
}
let mut buf = Buffer {
bytes: buf.split_to(total_len).freeze(),
idx: 5,
};
let message = match tag {
PARSE_COMPLETE_TAG => Message::ParseComplete,
BIND_COMPLETE_TAG => Message::BindComplete,
CLOSE_COMPLETE_TAG => Message::CloseComplete,
NOTIFICATION_RESPONSE_TAG => {
let process_id = buf.read_i32::<BigEndian>()?;
let channel = buf.read_cstr()?;
let message = buf.read_cstr()?;
Message::NotificationResponse(NotificationResponseBody {
process_id,
channel,
message,
})
}
COPY_DONE_TAG => Message::CopyDone,
COMMAND_COMPLETE_TAG => {
let tag = buf.read_cstr()?;
Message::CommandComplete(CommandCompleteBody { tag })
}
COPY_DATA_TAG => Message::CopyData,
DATA_ROW_TAG => {
let len = buf.read_u16::<BigEndian>()?;
let storage = buf.read_all();
Message::DataRow(DataRowBody { storage, len })
}
ERROR_RESPONSE_TAG => {
let storage = buf.read_all();
Message::ErrorResponse(ErrorResponseBody { storage })
}
COPY_IN_RESPONSE_TAG => Message::CopyInResponse,
COPY_OUT_RESPONSE_TAG => Message::CopyOutResponse,
COPY_BOTH_RESPONSE_TAG => Message::CopyBothResponse,
EMPTY_QUERY_RESPONSE_TAG => Message::EmptyQueryResponse,
BACKEND_KEY_DATA_TAG => {
let process_id = buf.read_i32::<BigEndian>()?;
let secret_key = buf.read_i32::<BigEndian>()?;
Message::BackendKeyData(BackendKeyDataBody {
process_id,
secret_key,
})
}
NO_DATA_TAG => Message::NoData,
NOTICE_RESPONSE_TAG => {
let storage = buf.read_all();
Message::NoticeResponse(NoticeResponseBody { storage })
}
AUTHENTICATION_TAG => match buf.read_i32::<BigEndian>()? {
0 => Message::AuthenticationOk,
2 => Message::AuthenticationKerberosV5,
3 => Message::AuthenticationCleartextPassword,
5 => {
let mut salt = [0; 4];
buf.read_exact(&mut salt)?;
Message::AuthenticationMd5Password(AuthenticationMd5PasswordBody { salt })
}
6 => Message::AuthenticationScmCredential,
7 => Message::AuthenticationGss,
8 => Message::AuthenticationGssContinue,
9 => Message::AuthenticationSspi,
10 => {
let storage = buf.read_all();
Message::AuthenticationSasl(AuthenticationSaslBody(storage))
}
11 => {
let storage = buf.read_all();
Message::AuthenticationSaslContinue(AuthenticationSaslContinueBody(storage))
}
12 => {
let storage = buf.read_all();
Message::AuthenticationSaslFinal(AuthenticationSaslFinalBody(storage))
}
tag => {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
format!("unknown authentication tag `{}`", tag),
));
}
},
PORTAL_SUSPENDED_TAG => Message::PortalSuspended,
PARAMETER_STATUS_TAG => {
let name = buf.read_cstr()?;
let value = buf.read_cstr()?;
Message::ParameterStatus(ParameterStatusBody { name, value })
}
PARAMETER_DESCRIPTION_TAG => {
let len = buf.read_u16::<BigEndian>()?;
let storage = buf.read_all();
Message::ParameterDescription(ParameterDescriptionBody { storage, len })
}
ROW_DESCRIPTION_TAG => {
let len = buf.read_u16::<BigEndian>()?;
let storage = buf.read_all();
Message::RowDescription(RowDescriptionBody { storage, len })
}
READY_FOR_QUERY_TAG => {
let status = buf.read_u8()?;
Message::ReadyForQuery(ReadyForQueryBody { status })
}
tag => {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
format!("unknown message tag `{}`", tag),
));
}
};
if !buf.is_empty() {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"invalid message length: expected buffer to be empty",
));
}
Ok(Some(message))
}
}
struct Buffer {
bytes: Bytes,
idx: usize,
}
impl Buffer {
#[inline]
fn slice(&self) -> &[u8] {
&self.bytes[self.idx..]
}
#[inline]
fn is_empty(&self) -> bool {
self.slice().is_empty()
}
#[inline]
fn read_cstr(&mut self) -> io::Result<Bytes> {
match memchr(0, self.slice()) {
Some(pos) => {
let start = self.idx;
let end = start + pos;
let cstr = self.bytes.slice(start..end);
self.idx = end + 1;
Ok(cstr)
}
None => Err(io::Error::new(
io::ErrorKind::UnexpectedEof,
"unexpected EOF",
)),
}
}
#[inline]
fn read_all(&mut self) -> Bytes {
let buf = self.bytes.slice(self.idx..);
self.idx = self.bytes.len();
buf
}
}
impl Read for Buffer {
#[inline]
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
let len = {
let slice = self.slice();
let len = cmp::min(slice.len(), buf.len());
buf[..len].copy_from_slice(&slice[..len]);
len
};
self.idx += len;
Ok(len)
}
}
pub struct AuthenticationMd5PasswordBody {
salt: [u8; 4],
}
impl AuthenticationMd5PasswordBody {
#[inline]
pub fn salt(&self) -> [u8; 4] {
self.salt
}
}
pub struct AuthenticationSaslBody(Bytes);
impl AuthenticationSaslBody {
#[inline]
pub fn mechanisms(&self) -> SaslMechanisms<'_> {
SaslMechanisms(&self.0)
}
}
pub struct SaslMechanisms<'a>(&'a [u8]);
impl<'a> FallibleIterator for SaslMechanisms<'a> {
type Item = &'a str;
type Error = io::Error;
#[inline]
fn next(&mut self) -> io::Result<Option<&'a str>> {
let value_end = find_null(self.0, 0)?;
if value_end == 0 {
if self.0.len() != 1 {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
"invalid message length: expected to be at end of iterator for sasl",
));
}
Ok(None)
} else {
let value = get_str(&self.0[..value_end])?;
self.0 = &self.0[value_end + 1..];
Ok(Some(value))
}
}
}
pub struct AuthenticationSaslContinueBody(Bytes);
impl AuthenticationSaslContinueBody {
#[inline]
pub fn data(&self) -> &[u8] {
&self.0
}
}
pub struct AuthenticationSaslFinalBody(Bytes);
impl AuthenticationSaslFinalBody {
#[inline]
pub fn data(&self) -> &[u8] {
&self.0
}
}
pub struct BackendKeyDataBody {
process_id: i32,
secret_key: i32,
}
impl BackendKeyDataBody {
#[inline]
pub fn process_id(&self) -> i32 {
self.process_id
}
#[inline]
pub fn secret_key(&self) -> i32 {
self.secret_key
}
}
pub struct CommandCompleteBody {
tag: Bytes,
}
impl CommandCompleteBody {
#[inline]
pub fn tag(&self) -> io::Result<&str> {
get_str(&self.tag)
}
}
#[derive(Debug)]
pub struct DataRowBody {
storage: Bytes,
len: u16,
}
impl DataRowBody {
#[inline]
pub fn ranges(&self) -> DataRowRanges<'_> {
DataRowRanges {
buf: &self.storage,
len: self.storage.len(),
remaining: self.len,
}
}
#[inline]
pub fn buffer(&self) -> &[u8] {
&self.storage
}
}
pub struct DataRowRanges<'a> {
buf: &'a [u8],
len: usize,
remaining: u16,
}
impl FallibleIterator for DataRowRanges<'_> {
type Item = Option<Range<usize>>;
type Error = io::Error;
#[inline]
fn next(&mut self) -> io::Result<Option<Option<Range<usize>>>> {
if self.remaining == 0 {
if self.buf.is_empty() {
return Ok(None);
} else {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"invalid message length: datarowrange is not empty",
));
}
}
self.remaining -= 1;
let len = self.buf.read_i32::<BigEndian>()?;
if len < 0 {
Ok(Some(None))
} else {
let len = len as usize;
if self.buf.len() < len {
return Err(io::Error::new(
io::ErrorKind::UnexpectedEof,
"unexpected EOF",
));
}
let base = self.len - self.buf.len();
self.buf = &self.buf[len..];
Ok(Some(Some(base..base + len)))
}
}
#[inline]
fn size_hint(&self) -> (usize, Option<usize>) {
let len = self.remaining as usize;
(len, Some(len))
}
}
pub struct ErrorResponseBody {
storage: Bytes,
}
impl ErrorResponseBody {
#[inline]
pub fn fields(&self) -> ErrorFields<'_> {
ErrorFields { buf: &self.storage }
}
}
pub struct ErrorFields<'a> {
buf: &'a [u8],
}
impl<'a> FallibleIterator for ErrorFields<'a> {
type Item = ErrorField<'a>;
type Error = io::Error;
#[inline]
fn next(&mut self) -> io::Result<Option<ErrorField<'a>>> {
let type_ = self.buf.read_u8()?;
if type_ == 0 {
if self.buf.is_empty() {
return Ok(None);
} else {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"invalid message length: error fields is not drained",
));
}
}
let value_end = find_null(self.buf, 0)?;
let value = get_str(&self.buf[..value_end])?;
self.buf = &self.buf[value_end + 1..];
Ok(Some(ErrorField { type_, value }))
}
}
pub struct ErrorField<'a> {
type_: u8,
value: &'a str,
}
impl ErrorField<'_> {
#[inline]
pub fn type_(&self) -> u8 {
self.type_
}
#[inline]
pub fn value(&self) -> &str {
self.value
}
}
pub struct NoticeResponseBody {
storage: Bytes,
}
impl NoticeResponseBody {
#[inline]
pub fn fields(&self) -> ErrorFields<'_> {
ErrorFields { buf: &self.storage }
}
}
pub struct NotificationResponseBody {
process_id: i32,
channel: Bytes,
message: Bytes,
}
impl NotificationResponseBody {
#[inline]
pub fn process_id(&self) -> i32 {
self.process_id
}
#[inline]
pub fn channel(&self) -> io::Result<&str> {
get_str(&self.channel)
}
#[inline]
pub fn message(&self) -> io::Result<&str> {
get_str(&self.message)
}
}
pub struct ParameterDescriptionBody {
storage: Bytes,
len: u16,
}
impl ParameterDescriptionBody {
#[inline]
pub fn parameters(&self) -> Parameters<'_> {
Parameters {
buf: &self.storage,
remaining: self.len,
}
}
}
pub struct Parameters<'a> {
buf: &'a [u8],
remaining: u16,
}
impl FallibleIterator for Parameters<'_> {
type Item = Oid;
type Error = io::Error;
#[inline]
fn next(&mut self) -> io::Result<Option<Oid>> {
if self.remaining == 0 {
if self.buf.is_empty() {
return Ok(None);
} else {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"invalid message length: parameters is not drained",
));
}
}
self.remaining -= 1;
self.buf.read_u32::<BigEndian>().map(Some)
}
#[inline]
fn size_hint(&self) -> (usize, Option<usize>) {
let len = self.remaining as usize;
(len, Some(len))
}
}
pub struct ParameterStatusBody {
name: Bytes,
value: Bytes,
}
impl ParameterStatusBody {
#[inline]
pub fn name(&self) -> io::Result<&str> {
get_str(&self.name)
}
#[inline]
pub fn value(&self) -> io::Result<&str> {
get_str(&self.value)
}
}
pub struct ReadyForQueryBody {
status: u8,
}
impl ReadyForQueryBody {
#[inline]
pub fn status(&self) -> u8 {
self.status
}
}
pub struct RowDescriptionBody {
storage: Bytes,
len: u16,
}
impl RowDescriptionBody {
#[inline]
pub fn fields(&self) -> Fields<'_> {
Fields {
buf: &self.storage,
remaining: self.len,
}
}
}
pub struct Fields<'a> {
buf: &'a [u8],
remaining: u16,
}
impl<'a> FallibleIterator for Fields<'a> {
type Item = Field<'a>;
type Error = io::Error;
#[inline]
fn next(&mut self) -> io::Result<Option<Field<'a>>> {
if self.remaining == 0 {
if self.buf.is_empty() {
return Ok(None);
} else {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"invalid message length: field is not drained",
));
}
}
self.remaining -= 1;
let name_end = find_null(self.buf, 0)?;
let name = get_str(&self.buf[..name_end])?;
self.buf = &self.buf[name_end + 1..];
let table_oid = self.buf.read_u32::<BigEndian>()?;
let column_id = self.buf.read_i16::<BigEndian>()?;
let type_oid = self.buf.read_u32::<BigEndian>()?;
let type_size = self.buf.read_i16::<BigEndian>()?;
let type_modifier = self.buf.read_i32::<BigEndian>()?;
let format = self.buf.read_i16::<BigEndian>()?;
Ok(Some(Field {
name,
table_oid,
column_id,
type_oid,
type_size,
type_modifier,
format,
}))
}
}
pub struct Field<'a> {
name: &'a str,
table_oid: Oid,
column_id: i16,
type_oid: Oid,
type_size: i16,
type_modifier: i32,
format: i16,
}
impl<'a> Field<'a> {
#[inline]
pub fn name(&self) -> &'a str {
self.name
}
#[inline]
pub fn table_oid(&self) -> Oid {
self.table_oid
}
#[inline]
pub fn column_id(&self) -> i16 {
self.column_id
}
#[inline]
pub fn type_oid(&self) -> Oid {
self.type_oid
}
#[inline]
pub fn type_size(&self) -> i16 {
self.type_size
}
#[inline]
pub fn type_modifier(&self) -> i32 {
self.type_modifier
}
#[inline]
pub fn format(&self) -> i16 {
self.format
}
}
#[inline]
fn find_null(buf: &[u8], start: usize) -> io::Result<usize> {
match memchr(0, &buf[start..]) {
Some(pos) => Ok(pos + start),
None => Err(io::Error::new(
io::ErrorKind::UnexpectedEof,
"unexpected EOF",
)),
}
}
#[inline]
fn get_str(buf: &[u8]) -> io::Result<&str> {
str::from_utf8(buf).map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))
}

View File

@@ -0,0 +1,297 @@
//! Frontend message serialization.
#![allow(missing_docs)]
use byteorder::{BigEndian, ByteOrder};
use bytes::{Buf, BufMut, BytesMut};
use std::convert::TryFrom;
use std::error::Error;
use std::io;
use std::marker;
use crate::{write_nullable, FromUsize, IsNull, Oid};
#[inline]
fn write_body<F, E>(buf: &mut BytesMut, f: F) -> Result<(), E>
where
F: FnOnce(&mut BytesMut) -> Result<(), E>,
E: From<io::Error>,
{
let base = buf.len();
buf.extend_from_slice(&[0; 4]);
f(buf)?;
let size = i32::from_usize(buf.len() - base)?;
BigEndian::write_i32(&mut buf[base..], size);
Ok(())
}
pub enum BindError {
Conversion(Box<dyn Error + marker::Sync + Send>),
Serialization(io::Error),
}
impl From<Box<dyn Error + marker::Sync + Send>> for BindError {
#[inline]
fn from(e: Box<dyn Error + marker::Sync + Send>) -> BindError {
BindError::Conversion(e)
}
}
impl From<io::Error> for BindError {
#[inline]
fn from(e: io::Error) -> BindError {
BindError::Serialization(e)
}
}
#[inline]
pub fn bind<I, J, F, T, K>(
portal: &str,
statement: &str,
formats: I,
values: J,
mut serializer: F,
result_formats: K,
buf: &mut BytesMut,
) -> Result<(), BindError>
where
I: IntoIterator<Item = i16>,
J: IntoIterator<Item = T>,
F: FnMut(T, &mut BytesMut) -> Result<IsNull, Box<dyn Error + marker::Sync + Send>>,
K: IntoIterator<Item = i16>,
{
buf.put_u8(b'B');
write_body(buf, |buf| {
write_cstr(portal.as_bytes(), buf)?;
write_cstr(statement.as_bytes(), buf)?;
write_counted(
formats,
|f, buf| {
buf.put_i16(f);
Ok::<_, io::Error>(())
},
buf,
)?;
write_counted(
values,
|v, buf| write_nullable(|buf| serializer(v, buf), buf),
buf,
)?;
write_counted(
result_formats,
|f, buf| {
buf.put_i16(f);
Ok::<_, io::Error>(())
},
buf,
)?;
Ok(())
})
}
#[inline]
fn write_counted<I, T, F, E>(items: I, mut serializer: F, buf: &mut BytesMut) -> Result<(), E>
where
I: IntoIterator<Item = T>,
F: FnMut(T, &mut BytesMut) -> Result<(), E>,
E: From<io::Error>,
{
let base = buf.len();
buf.extend_from_slice(&[0; 2]);
let mut count = 0;
for item in items {
serializer(item, buf)?;
count += 1;
}
let count = i16::from_usize(count)?;
BigEndian::write_i16(&mut buf[base..], count);
Ok(())
}
#[inline]
pub fn cancel_request(process_id: i32, secret_key: i32, buf: &mut BytesMut) {
write_body(buf, |buf| {
buf.put_i32(80_877_102);
buf.put_i32(process_id);
buf.put_i32(secret_key);
Ok::<_, io::Error>(())
})
.unwrap();
}
#[inline]
pub fn close(variant: u8, name: &str, buf: &mut BytesMut) -> io::Result<()> {
buf.put_u8(b'C');
write_body(buf, |buf| {
buf.put_u8(variant);
write_cstr(name.as_bytes(), buf)
})
}
pub struct CopyData<T> {
buf: T,
len: i32,
}
impl<T> CopyData<T>
where
T: Buf,
{
pub fn new(buf: T) -> io::Result<CopyData<T>> {
let len = buf
.remaining()
.checked_add(4)
.and_then(|l| i32::try_from(l).ok())
.ok_or_else(|| {
io::Error::new(io::ErrorKind::InvalidInput, "message length overflow")
})?;
Ok(CopyData { buf, len })
}
pub fn write(self, out: &mut BytesMut) {
out.put_u8(b'd');
out.put_i32(self.len);
out.put(self.buf);
}
}
#[inline]
pub fn copy_done(buf: &mut BytesMut) {
buf.put_u8(b'c');
write_body(buf, |_| Ok::<(), io::Error>(())).unwrap();
}
#[inline]
pub fn copy_fail(message: &str, buf: &mut BytesMut) -> io::Result<()> {
buf.put_u8(b'f');
write_body(buf, |buf| write_cstr(message.as_bytes(), buf))
}
#[inline]
pub fn describe(variant: u8, name: &str, buf: &mut BytesMut) -> io::Result<()> {
buf.put_u8(b'D');
write_body(buf, |buf| {
buf.put_u8(variant);
write_cstr(name.as_bytes(), buf)
})
}
#[inline]
pub fn execute(portal: &str, max_rows: i32, buf: &mut BytesMut) -> io::Result<()> {
buf.put_u8(b'E');
write_body(buf, |buf| {
write_cstr(portal.as_bytes(), buf)?;
buf.put_i32(max_rows);
Ok(())
})
}
#[inline]
pub fn parse<I>(name: &str, query: &str, param_types: I, buf: &mut BytesMut) -> io::Result<()>
where
I: IntoIterator<Item = Oid>,
{
buf.put_u8(b'P');
write_body(buf, |buf| {
write_cstr(name.as_bytes(), buf)?;
write_cstr(query.as_bytes(), buf)?;
write_counted(
param_types,
|t, buf| {
buf.put_u32(t);
Ok::<_, io::Error>(())
},
buf,
)?;
Ok(())
})
}
#[inline]
pub fn password_message(password: &[u8], buf: &mut BytesMut) -> io::Result<()> {
buf.put_u8(b'p');
write_body(buf, |buf| write_cstr(password, buf))
}
#[inline]
pub fn query(query: &str, buf: &mut BytesMut) -> io::Result<()> {
buf.put_u8(b'Q');
write_body(buf, |buf| write_cstr(query.as_bytes(), buf))
}
#[inline]
pub fn sasl_initial_response(mechanism: &str, data: &[u8], buf: &mut BytesMut) -> io::Result<()> {
buf.put_u8(b'p');
write_body(buf, |buf| {
write_cstr(mechanism.as_bytes(), buf)?;
let len = i32::from_usize(data.len())?;
buf.put_i32(len);
buf.put_slice(data);
Ok(())
})
}
#[inline]
pub fn sasl_response(data: &[u8], buf: &mut BytesMut) -> io::Result<()> {
buf.put_u8(b'p');
write_body(buf, |buf| {
buf.put_slice(data);
Ok(())
})
}
#[inline]
pub fn ssl_request(buf: &mut BytesMut) {
write_body(buf, |buf| {
buf.put_i32(80_877_103);
Ok::<_, io::Error>(())
})
.unwrap();
}
#[inline]
pub fn startup_message<'a, I>(parameters: I, buf: &mut BytesMut) -> io::Result<()>
where
I: IntoIterator<Item = (&'a str, &'a str)>,
{
write_body(buf, |buf| {
// postgres protocol version 3.0(196608) in bigger-endian
buf.put_i32(0x00_03_00_00);
for (key, value) in parameters {
write_cstr(key.as_bytes(), buf)?;
write_cstr(value.as_bytes(), buf)?;
}
buf.put_u8(0);
Ok(())
})
}
#[inline]
pub fn sync(buf: &mut BytesMut) {
buf.put_u8(b'S');
write_body(buf, |_| Ok::<(), io::Error>(())).unwrap();
}
#[inline]
pub fn terminate(buf: &mut BytesMut) {
buf.put_u8(b'X');
write_body(buf, |_| Ok::<(), io::Error>(())).unwrap();
}
#[inline]
fn write_cstr(s: &[u8], buf: &mut BytesMut) -> Result<(), io::Error> {
if s.contains(&0) {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"string contains embedded null",
));
}
buf.put_slice(s);
buf.put_u8(0);
Ok(())
}

View File

@@ -0,0 +1,8 @@
//! Postgres message protocol support.
//!
//! See [Postgres's documentation][docs] for more information on message flow.
//!
//! [docs]: https://www.postgresql.org/docs/9.5/static/protocol-flow.html
pub mod backend;
pub mod frontend;

View File

@@ -0,0 +1,107 @@
//! Functions to encrypt a password in the client.
//!
//! This is intended to be used by client applications that wish to
//! send commands like `ALTER USER joe PASSWORD 'pwd'`. The password
//! need not be sent in cleartext if it is encrypted on the client
//! side. This is good because it ensures the cleartext password won't
//! end up in logs pg_stat displays, etc.
use crate::authentication::sasl;
use hmac::{Hmac, Mac};
use md5::Md5;
use rand::RngCore;
use sha2::digest::FixedOutput;
use sha2::{Digest, Sha256};
#[cfg(test)]
mod test;
const SCRAM_DEFAULT_ITERATIONS: u32 = 4096;
const SCRAM_DEFAULT_SALT_LEN: usize = 16;
/// Hash password using SCRAM-SHA-256 with a randomly-generated
/// salt.
///
/// The client may assume the returned string doesn't contain any
/// special characters that would require escaping in an SQL command.
pub async fn scram_sha_256(password: &[u8]) -> String {
let mut salt: [u8; SCRAM_DEFAULT_SALT_LEN] = [0; SCRAM_DEFAULT_SALT_LEN];
let mut rng = rand::thread_rng();
rng.fill_bytes(&mut salt);
scram_sha_256_salt(password, salt).await
}
// Internal implementation of scram_sha_256 with a caller-provided
// salt. This is useful for testing.
pub(crate) async fn scram_sha_256_salt(
password: &[u8],
salt: [u8; SCRAM_DEFAULT_SALT_LEN],
) -> String {
// Prepare the password, per [RFC
// 4013](https://tools.ietf.org/html/rfc4013), if possible.
//
// Postgres treats passwords as byte strings (without embedded NUL
// bytes), but SASL expects passwords to be valid UTF-8.
//
// Follow the behavior of libpq's PQencryptPasswordConn(), and
// also the backend. If the password is not valid UTF-8, or if it
// contains prohibited characters (such as non-ASCII whitespace),
// just skip the SASLprep step and use the original byte
// sequence.
let prepared: Vec<u8> = match std::str::from_utf8(password) {
Ok(password_str) => {
match stringprep::saslprep(password_str) {
Ok(p) => p.into_owned().into_bytes(),
// contains invalid characters; skip saslprep
Err(_) => Vec::from(password),
}
}
// not valid UTF-8; skip saslprep
Err(_) => Vec::from(password),
};
// salt password
let salted_password = sasl::hi(&prepared, &salt, SCRAM_DEFAULT_ITERATIONS).await;
// client key
let mut hmac = Hmac::<Sha256>::new_from_slice(&salted_password)
.expect("HMAC is able to accept all key sizes");
hmac.update(b"Client Key");
let client_key = hmac.finalize().into_bytes();
// stored key
let mut hash = Sha256::default();
hash.update(client_key.as_slice());
let stored_key = hash.finalize_fixed();
// server key
let mut hmac = Hmac::<Sha256>::new_from_slice(&salted_password)
.expect("HMAC is able to accept all key sizes");
hmac.update(b"Server Key");
let server_key = hmac.finalize().into_bytes();
format!(
"SCRAM-SHA-256${}:{}${}:{}",
SCRAM_DEFAULT_ITERATIONS,
base64::encode(salt),
base64::encode(stored_key),
base64::encode(server_key)
)
}
/// **Not recommended, as MD5 is not considered to be secure.**
///
/// Hash password using MD5 with the username as the salt.
///
/// The client may assume the returned string doesn't contain any
/// special characters that would require escaping.
pub fn md5(password: &[u8], username: &str) -> String {
// salt password with username
let mut salted_password = Vec::from(password);
salted_password.extend_from_slice(username.as_bytes());
let mut hash = Md5::new();
hash.update(&salted_password);
let digest = hash.finalize();
format!("md5{:x}", digest)
}

View File

@@ -0,0 +1,19 @@
use crate::password;
#[tokio::test]
async fn test_encrypt_scram_sha_256() {
// Specify the salt to make the test deterministic. Any bytes will do.
let salt: [u8; 16] = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16];
assert_eq!(
password::scram_sha_256_salt(b"secret", salt).await,
"SCRAM-SHA-256$4096:AQIDBAUGBwgJCgsMDQ4PEA==$8rrDg00OqaiWXJ7p+sCgHEIaBSHY89ZJl3mfIsf32oY=:05L1f+yZbiN8O0AnO40Og85NNRhvzTS57naKRWCcsIA="
);
}
#[test]
fn test_encrypt_md5() {
assert_eq!(
password::md5(b"secret", "foo"),
"md54ab2c5d00339c4b2a4e921d2dc4edec7"
);
}

View File

@@ -0,0 +1,294 @@
//! Conversions to and from Postgres's binary format for various types.
use byteorder::{BigEndian, ReadBytesExt};
use bytes::{BufMut, BytesMut};
use fallible_iterator::FallibleIterator;
use std::boxed::Box as StdBox;
use std::error::Error;
use std::str;
use crate::Oid;
#[cfg(test)]
mod test;
/// Serializes a `TEXT`, `VARCHAR`, `CHAR(n)`, `NAME`, or `CITEXT` value.
#[inline]
pub fn text_to_sql(v: &str, buf: &mut BytesMut) {
buf.put_slice(v.as_bytes());
}
/// Deserializes a `TEXT`, `VARCHAR`, `CHAR(n)`, `NAME`, or `CITEXT` value.
#[inline]
pub fn text_from_sql(buf: &[u8]) -> Result<&str, StdBox<dyn Error + Sync + Send>> {
Ok(str::from_utf8(buf)?)
}
/// Deserializes a `"char"` value.
#[inline]
pub fn char_from_sql(mut buf: &[u8]) -> Result<i8, StdBox<dyn Error + Sync + Send>> {
let v = buf.read_i8()?;
if !buf.is_empty() {
return Err("invalid buffer size".into());
}
Ok(v)
}
/// Serializes an `OID` value.
#[inline]
pub fn oid_to_sql(v: Oid, buf: &mut BytesMut) {
buf.put_u32(v);
}
/// Deserializes an `OID` value.
#[inline]
pub fn oid_from_sql(mut buf: &[u8]) -> Result<Oid, StdBox<dyn Error + Sync + Send>> {
let v = buf.read_u32::<BigEndian>()?;
if !buf.is_empty() {
return Err("invalid buffer size".into());
}
Ok(v)
}
/// A fallible iterator over `HSTORE` entries.
pub struct HstoreEntries<'a> {
remaining: i32,
buf: &'a [u8],
}
impl<'a> FallibleIterator for HstoreEntries<'a> {
type Item = (&'a str, Option<&'a str>);
type Error = StdBox<dyn Error + Sync + Send>;
#[inline]
#[allow(clippy::type_complexity)]
fn next(
&mut self,
) -> Result<Option<(&'a str, Option<&'a str>)>, StdBox<dyn Error + Sync + Send>> {
if self.remaining == 0 {
if !self.buf.is_empty() {
return Err("invalid buffer size".into());
}
return Ok(None);
}
self.remaining -= 1;
let key_len = self.buf.read_i32::<BigEndian>()?;
if key_len < 0 {
return Err("invalid key length".into());
}
let (key, buf) = self.buf.split_at(key_len as usize);
let key = str::from_utf8(key)?;
self.buf = buf;
let value_len = self.buf.read_i32::<BigEndian>()?;
let value = if value_len < 0 {
None
} else {
let (value, buf) = self.buf.split_at(value_len as usize);
let value = str::from_utf8(value)?;
self.buf = buf;
Some(value)
};
Ok(Some((key, value)))
}
#[inline]
fn size_hint(&self) -> (usize, Option<usize>) {
let len = self.remaining as usize;
(len, Some(len))
}
}
/// Deserializes an array value.
#[inline]
pub fn array_from_sql(mut buf: &[u8]) -> Result<Array<'_>, StdBox<dyn Error + Sync + Send>> {
let dimensions = buf.read_i32::<BigEndian>()?;
if dimensions < 0 {
return Err("invalid dimension count".into());
}
let mut r = buf;
let mut elements = 1i32;
for _ in 0..dimensions {
let len = r.read_i32::<BigEndian>()?;
if len < 0 {
return Err("invalid dimension size".into());
}
let _lower_bound = r.read_i32::<BigEndian>()?;
elements = match elements.checked_mul(len) {
Some(elements) => elements,
None => return Err("too many array elements".into()),
};
}
if dimensions == 0 {
elements = 0;
}
Ok(Array {
dimensions,
elements,
buf,
})
}
/// A Postgres array.
pub struct Array<'a> {
dimensions: i32,
elements: i32,
buf: &'a [u8],
}
impl<'a> Array<'a> {
/// Returns an iterator over the dimensions of the array.
#[inline]
pub fn dimensions(&self) -> ArrayDimensions<'a> {
ArrayDimensions(&self.buf[..self.dimensions as usize * 8])
}
/// Returns an iterator over the values of the array.
#[inline]
pub fn values(&self) -> ArrayValues<'a> {
ArrayValues {
remaining: self.elements,
buf: &self.buf[self.dimensions as usize * 8..],
}
}
}
/// An iterator over the dimensions of an array.
pub struct ArrayDimensions<'a>(&'a [u8]);
impl FallibleIterator for ArrayDimensions<'_> {
type Item = ArrayDimension;
type Error = StdBox<dyn Error + Sync + Send>;
#[inline]
fn next(&mut self) -> Result<Option<ArrayDimension>, StdBox<dyn Error + Sync + Send>> {
if self.0.is_empty() {
return Ok(None);
}
let len = self.0.read_i32::<BigEndian>()?;
let lower_bound = self.0.read_i32::<BigEndian>()?;
Ok(Some(ArrayDimension { len, lower_bound }))
}
#[inline]
fn size_hint(&self) -> (usize, Option<usize>) {
let len = self.0.len() / 8;
(len, Some(len))
}
}
/// Information about a dimension of an array.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub struct ArrayDimension {
/// The length of this dimension.
pub len: i32,
/// The base value used to index into this dimension.
pub lower_bound: i32,
}
/// An iterator over the values of an array, in row-major order.
pub struct ArrayValues<'a> {
remaining: i32,
buf: &'a [u8],
}
impl<'a> FallibleIterator for ArrayValues<'a> {
type Item = Option<&'a [u8]>;
type Error = StdBox<dyn Error + Sync + Send>;
#[inline]
fn next(&mut self) -> Result<Option<Option<&'a [u8]>>, StdBox<dyn Error + Sync + Send>> {
if self.remaining == 0 {
if !self.buf.is_empty() {
return Err("invalid message length: arrayvalue not drained".into());
}
return Ok(None);
}
self.remaining -= 1;
let len = self.buf.read_i32::<BigEndian>()?;
let val = if len < 0 {
None
} else {
if self.buf.len() < len as usize {
return Err("invalid value length".into());
}
let (val, buf) = self.buf.split_at(len as usize);
self.buf = buf;
Some(val)
};
Ok(Some(val))
}
fn size_hint(&self) -> (usize, Option<usize>) {
let len = self.remaining as usize;
(len, Some(len))
}
}
/// Serializes a Postgres ltree string
#[inline]
pub fn ltree_to_sql(v: &str, buf: &mut BytesMut) {
// A version number is prepended to an ltree string per spec
buf.put_u8(1);
// Append the rest of the query
buf.put_slice(v.as_bytes());
}
/// Deserialize a Postgres ltree string
#[inline]
pub fn ltree_from_sql(buf: &[u8]) -> Result<&str, StdBox<dyn Error + Sync + Send>> {
match buf {
// Remove the version number from the front of the ltree per spec
[1u8, rest @ ..] => Ok(str::from_utf8(rest)?),
_ => Err("ltree version 1 only supported".into()),
}
}
/// Serializes a Postgres lquery string
#[inline]
pub fn lquery_to_sql(v: &str, buf: &mut BytesMut) {
// A version number is prepended to an lquery string per spec
buf.put_u8(1);
// Append the rest of the query
buf.put_slice(v.as_bytes());
}
/// Deserialize a Postgres lquery string
#[inline]
pub fn lquery_from_sql(buf: &[u8]) -> Result<&str, StdBox<dyn Error + Sync + Send>> {
match buf {
// Remove the version number from the front of the lquery per spec
[1u8, rest @ ..] => Ok(str::from_utf8(rest)?),
_ => Err("lquery version 1 only supported".into()),
}
}
/// Serializes a Postgres ltxtquery string
#[inline]
pub fn ltxtquery_to_sql(v: &str, buf: &mut BytesMut) {
// A version number is prepended to an ltxtquery string per spec
buf.put_u8(1);
// Append the rest of the query
buf.put_slice(v.as_bytes());
}
/// Deserialize a Postgres ltxtquery string
#[inline]
pub fn ltxtquery_from_sql(buf: &[u8]) -> Result<&str, StdBox<dyn Error + Sync + Send>> {
match buf {
// Remove the version number from the front of the ltxtquery per spec
[1u8, rest @ ..] => Ok(str::from_utf8(rest)?),
_ => Err("ltxtquery version 1 only supported".into()),
}
}

View File

@@ -0,0 +1,87 @@
use bytes::{Buf, BytesMut};
use super::*;
#[test]
fn ltree_sql() {
let mut query = vec![1u8];
query.extend_from_slice("A.B.C".as_bytes());
let mut buf = BytesMut::new();
ltree_to_sql("A.B.C", &mut buf);
assert_eq!(query.as_slice(), buf.chunk());
}
#[test]
fn ltree_str() {
let mut query = vec![1u8];
query.extend_from_slice("A.B.C".as_bytes());
assert!(ltree_from_sql(query.as_slice()).is_ok())
}
#[test]
fn ltree_wrong_version() {
let mut query = vec![2u8];
query.extend_from_slice("A.B.C".as_bytes());
assert!(ltree_from_sql(query.as_slice()).is_err())
}
#[test]
fn lquery_sql() {
let mut query = vec![1u8];
query.extend_from_slice("A.B.C".as_bytes());
let mut buf = BytesMut::new();
lquery_to_sql("A.B.C", &mut buf);
assert_eq!(query.as_slice(), buf.chunk());
}
#[test]
fn lquery_str() {
let mut query = vec![1u8];
query.extend_from_slice("A.B.C".as_bytes());
assert!(lquery_from_sql(query.as_slice()).is_ok())
}
#[test]
fn lquery_wrong_version() {
let mut query = vec![2u8];
query.extend_from_slice("A.B.C".as_bytes());
assert!(lquery_from_sql(query.as_slice()).is_err())
}
#[test]
fn ltxtquery_sql() {
let mut query = vec![1u8];
query.extend_from_slice("a & b*".as_bytes());
let mut buf = BytesMut::new();
ltree_to_sql("a & b*", &mut buf);
assert_eq!(query.as_slice(), buf.chunk());
}
#[test]
fn ltxtquery_str() {
let mut query = vec![1u8];
query.extend_from_slice("a & b*".as_bytes());
assert!(ltree_from_sql(query.as_slice()).is_ok())
}
#[test]
fn ltxtquery_wrong_version() {
let mut query = vec![2u8];
query.extend_from_slice("a & b*".as_bytes());
assert!(ltree_from_sql(query.as_slice()).is_err())
}