mirror of
https://github.com/neondatabase/neon.git
synced 2026-05-18 21:50:37 +00:00
simplify error handling for query encoding
This commit is contained in:
@@ -9,12 +9,14 @@ use sha2::digest::FixedOutput;
|
||||
use sha2::{Digest, Sha256};
|
||||
use tokio::task::yield_now;
|
||||
|
||||
use crate::CSafeStr;
|
||||
|
||||
const NONCE_LENGTH: usize = 24;
|
||||
|
||||
/// The identifier of the SCRAM-SHA-256 SASL authentication mechanism.
|
||||
pub const SCRAM_SHA_256: &str = "SCRAM-SHA-256";
|
||||
pub const SCRAM_SHA_256: &CSafeStr = CSafeStr::from_cstr(c"SCRAM-SHA-256");
|
||||
/// The identifier of the SCRAM-SHA-256-PLUS SASL authentication mechanism.
|
||||
pub const SCRAM_SHA_256_PLUS: &str = "SCRAM-SHA-256-PLUS";
|
||||
pub const SCRAM_SHA_256_PLUS: &CSafeStr = CSafeStr::from_cstr(c"SCRAM-SHA-256-PLUS");
|
||||
|
||||
// since postgres passwords are not required to exclude saslprep-prohibited
|
||||
// characters or even be valid UTF8, we run saslprep if possible and otherwise
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
//! set to `UTF8`. It will most likely not behave properly if that is not the case.
|
||||
#![warn(missing_docs, clippy::all)]
|
||||
|
||||
use std::io;
|
||||
use std::{ffi::CStr, io};
|
||||
|
||||
use byteorder::{BigEndian, ByteOrder};
|
||||
use bytes::{BufMut, BytesMut};
|
||||
@@ -36,20 +36,67 @@ pub enum IsNull {
|
||||
No,
|
||||
}
|
||||
|
||||
fn write_nullable<F, E>(serializer: F, buf: &mut BytesMut) -> Result<(), E>
|
||||
/// A [`std::ffi::CStr`] but without the null byte.
|
||||
#[repr(transparent)]
|
||||
#[derive(PartialEq, Eq, Hash, Debug)]
|
||||
pub struct CSafeStr([u8]);
|
||||
|
||||
impl CSafeStr {
|
||||
/// Create a new `CSafeStr`, erroring if the bytes contains a null.
|
||||
pub fn new(bytes: &[u8]) -> Result<&Self, io::Error> {
|
||||
let nul_pos = memchr::memchr(0, bytes);
|
||||
match nul_pos {
|
||||
Some(nul_pos) => Err(io::Error::other(format!(
|
||||
"unexpected null byte at position {nul_pos}"
|
||||
))),
|
||||
None => {
|
||||
// Safety: CSafeStr is transparent over [u8].
|
||||
Ok(unsafe { std::mem::transmute(bytes) })
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new `CSafeStr` up until the next null.
|
||||
pub fn take<'a>(bytes: &mut &'a [u8]) -> &'a Self {
|
||||
let nul_pos = memchr::memchr(0, bytes).unwrap_or(bytes.len());
|
||||
let bytes = bytes
|
||||
.split_off(..nul_pos)
|
||||
.expect("nul_pos should be in-bounds");
|
||||
// Safety: CSafeStr is transparent over [u8].
|
||||
unsafe { std::mem::transmute(bytes) }
|
||||
}
|
||||
|
||||
/// Get the bytes of this CSafeStr.
|
||||
pub const fn as_bytes(&self) -> &[u8] {
|
||||
&self.0
|
||||
}
|
||||
|
||||
/// Create a new `CSafeStr`
|
||||
pub const fn from_cstr(s: &CStr) -> &CSafeStr {
|
||||
// Safety: CSafeStr is transparent over [u8].
|
||||
unsafe { std::mem::transmute(s.to_bytes()) }
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> From<&'a CStr> for &'a CSafeStr {
|
||||
fn from(s: &'a CStr) -> &'a CSafeStr {
|
||||
CSafeStr::from_cstr(s)
|
||||
}
|
||||
}
|
||||
|
||||
fn write_nullable<F>(serializer: F, buf: &mut BytesMut)
|
||||
where
|
||||
F: FnOnce(&mut BytesMut) -> Result<IsNull, E>,
|
||||
E: From<io::Error>,
|
||||
F: FnOnce(&mut BytesMut) -> IsNull,
|
||||
{
|
||||
let base = buf.len();
|
||||
buf.put_i32(0);
|
||||
let size = match serializer(buf)? {
|
||||
IsNull::No => i32::from_usize(buf.len() - base - 4)?,
|
||||
let size = match serializer(buf) {
|
||||
// this is an unreasonable enough case that I think a panic is acceptable.
|
||||
IsNull::No => i32::from_usize(buf.len() - base - 4)
|
||||
.expect("buffer size should not be larger than i32::MAX"),
|
||||
IsNull::Yes => -1,
|
||||
};
|
||||
BigEndian::write_i32(&mut buf[base..], size);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
trait FromUsize: Sized {
|
||||
|
||||
@@ -9,7 +9,7 @@ use bytes::{Bytes, BytesMut};
|
||||
use fallible_iterator::FallibleIterator;
|
||||
use memchr::memchr;
|
||||
|
||||
use crate::Oid;
|
||||
use crate::{CSafeStr, Oid};
|
||||
|
||||
// top-level message tags
|
||||
const PARSE_COMPLETE_TAG: u8 = b'1';
|
||||
@@ -332,25 +332,24 @@ impl AuthenticationSaslBody {
|
||||
pub struct SaslMechanisms<'a>(&'a [u8]);
|
||||
|
||||
impl<'a> FallibleIterator for SaslMechanisms<'a> {
|
||||
type Item = &'a str;
|
||||
type Item = &'a CSafeStr;
|
||||
type Error = io::Error;
|
||||
|
||||
#[inline]
|
||||
fn next(&mut self) -> io::Result<Option<&'a str>> {
|
||||
let value_end = find_null(self.0, 0)?;
|
||||
if value_end == 0 {
|
||||
if self.0.len() != 1 {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
"invalid message length: expected to be at end of iterator for sasl",
|
||||
));
|
||||
}
|
||||
Ok(None)
|
||||
} else {
|
||||
let value = get_str(&self.0[..value_end])?;
|
||||
self.0 = &self.0[value_end + 1..];
|
||||
Ok(Some(value))
|
||||
fn next(&mut self) -> io::Result<Option<&'a CSafeStr>> {
|
||||
if self.0 == b"0" {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let value = CSafeStr::take(&mut self.0);
|
||||
if value.as_bytes().len() == 0 {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
"invalid message length: expected to be at end of iterator for sasl",
|
||||
));
|
||||
}
|
||||
|
||||
Ok(Some(value))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,114 +1,73 @@
|
||||
//! Frontend message serialization.
|
||||
#![allow(missing_docs)]
|
||||
|
||||
use std::error::Error;
|
||||
use std::{io, marker};
|
||||
use std::io;
|
||||
|
||||
use byteorder::{BigEndian, ByteOrder};
|
||||
use bytes::{Buf, BufMut, BytesMut};
|
||||
|
||||
use crate::{FromUsize, IsNull, Oid, write_nullable};
|
||||
use crate::{CSafeStr, FromUsize, IsNull, Oid, write_nullable};
|
||||
|
||||
#[inline]
|
||||
fn write_body<F, E>(buf: &mut BytesMut, f: F) -> Result<(), E>
|
||||
fn write_body<F>(buf: &mut BytesMut, f: F)
|
||||
where
|
||||
F: FnOnce(&mut BytesMut) -> Result<(), E>,
|
||||
E: From<io::Error>,
|
||||
F: FnOnce(&mut BytesMut),
|
||||
{
|
||||
let base = buf.len();
|
||||
buf.extend_from_slice(&[0; 4]);
|
||||
|
||||
f(buf)?;
|
||||
f(buf);
|
||||
|
||||
let size = i32::from_usize(buf.len() - base)?;
|
||||
let size =
|
||||
i32::from_usize(buf.len() - base).expect("buffer size should not be larger than i32::MAX");
|
||||
BigEndian::write_i32(&mut buf[base..], size);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub enum BindError {
|
||||
Conversion(Box<dyn Error + marker::Sync + Send>),
|
||||
Serialization(io::Error),
|
||||
}
|
||||
|
||||
impl From<Box<dyn Error + marker::Sync + Send>> for BindError {
|
||||
#[inline]
|
||||
fn from(e: Box<dyn Error + marker::Sync + Send>) -> BindError {
|
||||
BindError::Conversion(e)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<io::Error> for BindError {
|
||||
#[inline]
|
||||
fn from(e: io::Error) -> BindError {
|
||||
BindError::Serialization(e)
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn bind<I, J, F, T, K>(
|
||||
portal: &str,
|
||||
statement: &str,
|
||||
portal: &CSafeStr,
|
||||
statement: &CSafeStr,
|
||||
formats: I,
|
||||
values: J,
|
||||
mut serializer: F,
|
||||
result_formats: K,
|
||||
buf: &mut BytesMut,
|
||||
) -> Result<(), BindError>
|
||||
where
|
||||
) where
|
||||
I: IntoIterator<Item = i16>,
|
||||
J: IntoIterator<Item = T>,
|
||||
F: FnMut(T, &mut BytesMut) -> Result<IsNull, Box<dyn Error + marker::Sync + Send>>,
|
||||
F: FnMut(T, &mut BytesMut) -> IsNull,
|
||||
K: IntoIterator<Item = i16>,
|
||||
{
|
||||
buf.put_u8(b'B');
|
||||
|
||||
write_body(buf, |buf| {
|
||||
write_cstr(portal.as_bytes(), buf)?;
|
||||
write_cstr(statement.as_bytes(), buf)?;
|
||||
write_counted(
|
||||
formats,
|
||||
|f, buf| {
|
||||
buf.put_i16(f);
|
||||
Ok::<_, io::Error>(())
|
||||
},
|
||||
buf,
|
||||
)?;
|
||||
write_cstr(portal, buf);
|
||||
write_cstr(statement, buf);
|
||||
write_counted(formats, |f, buf| buf.put_i16(f), buf);
|
||||
write_counted(
|
||||
values,
|
||||
|v, buf| write_nullable(|buf| serializer(v, buf), buf),
|
||||
buf,
|
||||
)?;
|
||||
write_counted(
|
||||
result_formats,
|
||||
|f, buf| {
|
||||
buf.put_i16(f);
|
||||
Ok::<_, io::Error>(())
|
||||
},
|
||||
buf,
|
||||
)?;
|
||||
|
||||
Ok(())
|
||||
);
|
||||
write_counted(result_formats, |f, buf| buf.put_i16(f), buf);
|
||||
})
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn write_counted<I, T, F, E>(items: I, mut serializer: F, buf: &mut BytesMut) -> Result<(), E>
|
||||
fn write_counted<I, T, F>(items: I, mut serializer: F, buf: &mut BytesMut)
|
||||
where
|
||||
I: IntoIterator<Item = T>,
|
||||
F: FnMut(T, &mut BytesMut) -> Result<(), E>,
|
||||
E: From<io::Error>,
|
||||
F: FnMut(T, &mut BytesMut),
|
||||
{
|
||||
let base = buf.len();
|
||||
buf.extend_from_slice(&[0; 2]);
|
||||
let mut count = 0;
|
||||
for item in items {
|
||||
serializer(item, buf)?;
|
||||
serializer(item, buf);
|
||||
count += 1;
|
||||
}
|
||||
let count = i16::from_usize(count)?;
|
||||
let count = i16::from_usize(count).expect("list should not exceed 32767 items");
|
||||
BigEndian::write_i16(&mut buf[base..], count);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[inline]
|
||||
@@ -117,17 +76,15 @@ pub fn cancel_request(process_id: i32, secret_key: i32, buf: &mut BytesMut) {
|
||||
buf.put_i32(80_877_102);
|
||||
buf.put_i32(process_id);
|
||||
buf.put_i32(secret_key);
|
||||
Ok::<_, io::Error>(())
|
||||
})
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn close(variant: u8, name: &str, buf: &mut BytesMut) -> io::Result<()> {
|
||||
pub fn close(variant: u8, name: &CSafeStr, buf: &mut BytesMut) {
|
||||
buf.put_u8(b'C');
|
||||
write_body(buf, |buf| {
|
||||
buf.put_u8(variant);
|
||||
write_cstr(name.as_bytes(), buf)
|
||||
write_cstr(name, buf)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -162,85 +119,75 @@ where
|
||||
#[inline]
|
||||
pub fn copy_done(buf: &mut BytesMut) {
|
||||
buf.put_u8(b'c');
|
||||
write_body(buf, |_| Ok::<(), io::Error>(())).unwrap();
|
||||
write_body(buf, |_| {});
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn copy_fail(message: &str, buf: &mut BytesMut) -> io::Result<()> {
|
||||
pub fn copy_fail(message: &CSafeStr, buf: &mut BytesMut) {
|
||||
buf.put_u8(b'f');
|
||||
write_body(buf, |buf| write_cstr(message.as_bytes(), buf))
|
||||
write_body(buf, |buf| write_cstr(message, buf))
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn describe(variant: u8, name: &str, buf: &mut BytesMut) -> io::Result<()> {
|
||||
pub fn describe(variant: u8, name: &CSafeStr, buf: &mut BytesMut) {
|
||||
buf.put_u8(b'D');
|
||||
write_body(buf, |buf| {
|
||||
buf.put_u8(variant);
|
||||
write_cstr(name.as_bytes(), buf)
|
||||
write_cstr(name, buf)
|
||||
})
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn execute(portal: &str, max_rows: i32, buf: &mut BytesMut) -> io::Result<()> {
|
||||
pub fn execute(portal: &CSafeStr, max_rows: i32, buf: &mut BytesMut) {
|
||||
buf.put_u8(b'E');
|
||||
write_body(buf, |buf| {
|
||||
write_cstr(portal.as_bytes(), buf)?;
|
||||
write_cstr(portal, buf);
|
||||
buf.put_i32(max_rows);
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn parse<I>(name: &str, query: &str, param_types: I, buf: &mut BytesMut) -> io::Result<()>
|
||||
pub fn parse<I>(name: &CSafeStr, query: &CSafeStr, param_types: I, buf: &mut BytesMut)
|
||||
where
|
||||
I: IntoIterator<Item = Oid>,
|
||||
{
|
||||
buf.put_u8(b'P');
|
||||
write_body(buf, |buf| {
|
||||
write_cstr(name.as_bytes(), buf)?;
|
||||
write_cstr(query.as_bytes(), buf)?;
|
||||
write_counted(
|
||||
param_types,
|
||||
|t, buf| {
|
||||
buf.put_u32(t);
|
||||
Ok::<_, io::Error>(())
|
||||
},
|
||||
buf,
|
||||
)?;
|
||||
Ok(())
|
||||
write_cstr(name, buf);
|
||||
write_cstr(query, buf);
|
||||
write_counted(param_types, |t, buf| buf.put_u32(t), buf);
|
||||
})
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn password_message(password: &[u8], buf: &mut BytesMut) -> io::Result<()> {
|
||||
pub fn password_message(password: &CSafeStr, buf: &mut BytesMut) {
|
||||
buf.put_u8(b'p');
|
||||
write_body(buf, |buf| write_cstr(password, buf))
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn query(query: &str, buf: &mut BytesMut) -> io::Result<()> {
|
||||
pub fn query(query: &CSafeStr, buf: &mut BytesMut) {
|
||||
buf.put_u8(b'Q');
|
||||
write_body(buf, |buf| write_cstr(query.as_bytes(), buf))
|
||||
write_body(buf, |buf| write_cstr(query, buf))
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn sasl_initial_response(mechanism: &str, data: &[u8], buf: &mut BytesMut) -> io::Result<()> {
|
||||
pub fn sasl_initial_response(mechanism: &CSafeStr, data: &[u8], buf: &mut BytesMut) {
|
||||
buf.put_u8(b'p');
|
||||
write_body(buf, |buf| {
|
||||
write_cstr(mechanism.as_bytes(), buf)?;
|
||||
let len = i32::from_usize(data.len())?;
|
||||
write_cstr(mechanism, buf);
|
||||
let len =
|
||||
i32::from_usize(data.len()).expect("sasl data should not be larger than i32::MAX");
|
||||
buf.put_i32(len);
|
||||
buf.put_slice(data);
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn sasl_response(data: &[u8], buf: &mut BytesMut) -> io::Result<()> {
|
||||
pub fn sasl_response(data: &[u8], buf: &mut BytesMut) {
|
||||
buf.put_u8(b'p');
|
||||
write_body(buf, |buf| {
|
||||
buf.put_slice(data);
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
||||
@@ -248,19 +195,16 @@ pub fn sasl_response(data: &[u8], buf: &mut BytesMut) -> io::Result<()> {
|
||||
pub fn ssl_request(buf: &mut BytesMut) {
|
||||
write_body(buf, |buf| {
|
||||
buf.put_i32(80_877_103);
|
||||
Ok::<_, io::Error>(())
|
||||
})
|
||||
.unwrap();
|
||||
});
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn startup_message(parameters: &StartupMessageParams, buf: &mut BytesMut) -> io::Result<()> {
|
||||
pub fn startup_message(parameters: &StartupMessageParams, buf: &mut BytesMut) {
|
||||
write_body(buf, |buf| {
|
||||
// postgres protocol version 3.0(196608) in bigger-endian
|
||||
buf.put_i32(0x00_03_00_00);
|
||||
buf.put_slice(¶meters.params);
|
||||
buf.put_u8(0);
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
||||
@@ -271,10 +215,7 @@ pub struct StartupMessageParams {
|
||||
|
||||
impl StartupMessageParams {
|
||||
/// Set parameter's value by its name.
|
||||
pub fn insert(&mut self, name: &str, value: &str) {
|
||||
if name.contains('\0') || value.contains('\0') {
|
||||
panic!("startup parameter name or value contained a null")
|
||||
}
|
||||
pub fn insert(&mut self, name: &CSafeStr, value: &CSafeStr) {
|
||||
self.params.put_slice(name.as_bytes());
|
||||
self.params.put_u8(0);
|
||||
self.params.put_slice(value.as_bytes());
|
||||
@@ -285,24 +226,17 @@ impl StartupMessageParams {
|
||||
#[inline]
|
||||
pub fn sync(buf: &mut BytesMut) {
|
||||
buf.put_u8(b'S');
|
||||
write_body(buf, |_| Ok::<(), io::Error>(())).unwrap();
|
||||
write_body(buf, |_| {});
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn terminate(buf: &mut BytesMut) {
|
||||
buf.put_u8(b'X');
|
||||
write_body(buf, |_| Ok::<(), io::Error>(())).unwrap();
|
||||
write_body(buf, |_| {});
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn write_cstr(s: &[u8], buf: &mut BytesMut) -> Result<(), io::Error> {
|
||||
if s.contains(&0) {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::InvalidInput,
|
||||
"string contains embedded null",
|
||||
));
|
||||
}
|
||||
buf.put_slice(s);
|
||||
fn write_cstr(s: &CSafeStr, buf: &mut BytesMut) {
|
||||
buf.put_slice(s.as_bytes());
|
||||
buf.put_u8(0);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user