simplify error handling for query encoding

This commit is contained in:
Conrad Ludgate
2025-05-21 13:37:57 +01:00
parent f3c9d0adf4
commit 13d41b51a2
18 changed files with 246 additions and 313 deletions

View File

@@ -9,12 +9,14 @@ use sha2::digest::FixedOutput;
use sha2::{Digest, Sha256};
use tokio::task::yield_now;
use crate::CSafeStr;
const NONCE_LENGTH: usize = 24;
/// The identifier of the SCRAM-SHA-256 SASL authentication mechanism.
pub const SCRAM_SHA_256: &str = "SCRAM-SHA-256";
pub const SCRAM_SHA_256: &CSafeStr = CSafeStr::from_cstr(c"SCRAM-SHA-256");
/// The identifier of the SCRAM-SHA-256-PLUS SASL authentication mechanism.
pub const SCRAM_SHA_256_PLUS: &str = "SCRAM-SHA-256-PLUS";
pub const SCRAM_SHA_256_PLUS: &CSafeStr = CSafeStr::from_cstr(c"SCRAM-SHA-256-PLUS");
// since postgres passwords are not required to exclude saslprep-prohibited
// characters or even be valid UTF8, we run saslprep if possible and otherwise

View File

@@ -11,7 +11,7 @@
//! set to `UTF8`. It will most likely not behave properly if that is not the case.
#![warn(missing_docs, clippy::all)]
use std::io;
use std::{ffi::CStr, io};
use byteorder::{BigEndian, ByteOrder};
use bytes::{BufMut, BytesMut};
@@ -36,20 +36,67 @@ pub enum IsNull {
No,
}
fn write_nullable<F, E>(serializer: F, buf: &mut BytesMut) -> Result<(), E>
/// A [`std::ffi::CStr`] but without the null byte.
#[repr(transparent)]
#[derive(PartialEq, Eq, Hash, Debug)]
pub struct CSafeStr([u8]);
impl CSafeStr {
/// Create a new `CSafeStr`, erroring if the bytes contains a null.
pub fn new(bytes: &[u8]) -> Result<&Self, io::Error> {
let nul_pos = memchr::memchr(0, bytes);
match nul_pos {
Some(nul_pos) => Err(io::Error::other(format!(
"unexpected null byte at position {nul_pos}"
))),
None => {
// Safety: CSafeStr is transparent over [u8].
Ok(unsafe { std::mem::transmute(bytes) })
}
}
}
/// Create a new `CSafeStr` up until the next null.
pub fn take<'a>(bytes: &mut &'a [u8]) -> &'a Self {
let nul_pos = memchr::memchr(0, bytes).unwrap_or(bytes.len());
let bytes = bytes
.split_off(..nul_pos)
.expect("nul_pos should be in-bounds");
// Safety: CSafeStr is transparent over [u8].
unsafe { std::mem::transmute(bytes) }
}
/// Get the bytes of this CSafeStr.
pub const fn as_bytes(&self) -> &[u8] {
&self.0
}
/// Create a new `CSafeStr`
pub const fn from_cstr(s: &CStr) -> &CSafeStr {
// Safety: CSafeStr is transparent over [u8].
unsafe { std::mem::transmute(s.to_bytes()) }
}
}
impl<'a> From<&'a CStr> for &'a CSafeStr {
fn from(s: &'a CStr) -> &'a CSafeStr {
CSafeStr::from_cstr(s)
}
}
fn write_nullable<F>(serializer: F, buf: &mut BytesMut)
where
F: FnOnce(&mut BytesMut) -> Result<IsNull, E>,
E: From<io::Error>,
F: FnOnce(&mut BytesMut) -> IsNull,
{
let base = buf.len();
buf.put_i32(0);
let size = match serializer(buf)? {
IsNull::No => i32::from_usize(buf.len() - base - 4)?,
let size = match serializer(buf) {
// this is an unreasonable enough case that I think a panic is acceptable.
IsNull::No => i32::from_usize(buf.len() - base - 4)
.expect("buffer size should not be larger than i32::MAX"),
IsNull::Yes => -1,
};
BigEndian::write_i32(&mut buf[base..], size);
Ok(())
}
trait FromUsize: Sized {

View File

@@ -9,7 +9,7 @@ use bytes::{Bytes, BytesMut};
use fallible_iterator::FallibleIterator;
use memchr::memchr;
use crate::Oid;
use crate::{CSafeStr, Oid};
// top-level message tags
const PARSE_COMPLETE_TAG: u8 = b'1';
@@ -332,25 +332,24 @@ impl AuthenticationSaslBody {
pub struct SaslMechanisms<'a>(&'a [u8]);
impl<'a> FallibleIterator for SaslMechanisms<'a> {
type Item = &'a str;
type Item = &'a CSafeStr;
type Error = io::Error;
#[inline]
fn next(&mut self) -> io::Result<Option<&'a str>> {
let value_end = find_null(self.0, 0)?;
if value_end == 0 {
if self.0.len() != 1 {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
"invalid message length: expected to be at end of iterator for sasl",
));
}
Ok(None)
} else {
let value = get_str(&self.0[..value_end])?;
self.0 = &self.0[value_end + 1..];
Ok(Some(value))
fn next(&mut self) -> io::Result<Option<&'a CSafeStr>> {
if self.0 == b"0" {
return Ok(None);
}
let value = CSafeStr::take(&mut self.0);
if value.as_bytes().len() == 0 {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
"invalid message length: expected to be at end of iterator for sasl",
));
}
Ok(Some(value))
}
}

View File

@@ -1,114 +1,73 @@
//! Frontend message serialization.
#![allow(missing_docs)]
use std::error::Error;
use std::{io, marker};
use std::io;
use byteorder::{BigEndian, ByteOrder};
use bytes::{Buf, BufMut, BytesMut};
use crate::{FromUsize, IsNull, Oid, write_nullable};
use crate::{CSafeStr, FromUsize, IsNull, Oid, write_nullable};
#[inline]
fn write_body<F, E>(buf: &mut BytesMut, f: F) -> Result<(), E>
fn write_body<F>(buf: &mut BytesMut, f: F)
where
F: FnOnce(&mut BytesMut) -> Result<(), E>,
E: From<io::Error>,
F: FnOnce(&mut BytesMut),
{
let base = buf.len();
buf.extend_from_slice(&[0; 4]);
f(buf)?;
f(buf);
let size = i32::from_usize(buf.len() - base)?;
let size =
i32::from_usize(buf.len() - base).expect("buffer size should not be larger than i32::MAX");
BigEndian::write_i32(&mut buf[base..], size);
Ok(())
}
pub enum BindError {
Conversion(Box<dyn Error + marker::Sync + Send>),
Serialization(io::Error),
}
impl From<Box<dyn Error + marker::Sync + Send>> for BindError {
#[inline]
fn from(e: Box<dyn Error + marker::Sync + Send>) -> BindError {
BindError::Conversion(e)
}
}
impl From<io::Error> for BindError {
#[inline]
fn from(e: io::Error) -> BindError {
BindError::Serialization(e)
}
}
#[inline]
pub fn bind<I, J, F, T, K>(
portal: &str,
statement: &str,
portal: &CSafeStr,
statement: &CSafeStr,
formats: I,
values: J,
mut serializer: F,
result_formats: K,
buf: &mut BytesMut,
) -> Result<(), BindError>
where
) where
I: IntoIterator<Item = i16>,
J: IntoIterator<Item = T>,
F: FnMut(T, &mut BytesMut) -> Result<IsNull, Box<dyn Error + marker::Sync + Send>>,
F: FnMut(T, &mut BytesMut) -> IsNull,
K: IntoIterator<Item = i16>,
{
buf.put_u8(b'B');
write_body(buf, |buf| {
write_cstr(portal.as_bytes(), buf)?;
write_cstr(statement.as_bytes(), buf)?;
write_counted(
formats,
|f, buf| {
buf.put_i16(f);
Ok::<_, io::Error>(())
},
buf,
)?;
write_cstr(portal, buf);
write_cstr(statement, buf);
write_counted(formats, |f, buf| buf.put_i16(f), buf);
write_counted(
values,
|v, buf| write_nullable(|buf| serializer(v, buf), buf),
buf,
)?;
write_counted(
result_formats,
|f, buf| {
buf.put_i16(f);
Ok::<_, io::Error>(())
},
buf,
)?;
Ok(())
);
write_counted(result_formats, |f, buf| buf.put_i16(f), buf);
})
}
#[inline]
fn write_counted<I, T, F, E>(items: I, mut serializer: F, buf: &mut BytesMut) -> Result<(), E>
fn write_counted<I, T, F>(items: I, mut serializer: F, buf: &mut BytesMut)
where
I: IntoIterator<Item = T>,
F: FnMut(T, &mut BytesMut) -> Result<(), E>,
E: From<io::Error>,
F: FnMut(T, &mut BytesMut),
{
let base = buf.len();
buf.extend_from_slice(&[0; 2]);
let mut count = 0;
for item in items {
serializer(item, buf)?;
serializer(item, buf);
count += 1;
}
let count = i16::from_usize(count)?;
let count = i16::from_usize(count).expect("list should not exceed 32767 items");
BigEndian::write_i16(&mut buf[base..], count);
Ok(())
}
#[inline]
@@ -117,17 +76,15 @@ pub fn cancel_request(process_id: i32, secret_key: i32, buf: &mut BytesMut) {
buf.put_i32(80_877_102);
buf.put_i32(process_id);
buf.put_i32(secret_key);
Ok::<_, io::Error>(())
})
.unwrap();
}
#[inline]
pub fn close(variant: u8, name: &str, buf: &mut BytesMut) -> io::Result<()> {
pub fn close(variant: u8, name: &CSafeStr, buf: &mut BytesMut) {
buf.put_u8(b'C');
write_body(buf, |buf| {
buf.put_u8(variant);
write_cstr(name.as_bytes(), buf)
write_cstr(name, buf)
})
}
@@ -162,85 +119,75 @@ where
#[inline]
pub fn copy_done(buf: &mut BytesMut) {
buf.put_u8(b'c');
write_body(buf, |_| Ok::<(), io::Error>(())).unwrap();
write_body(buf, |_| {});
}
#[inline]
pub fn copy_fail(message: &str, buf: &mut BytesMut) -> io::Result<()> {
pub fn copy_fail(message: &CSafeStr, buf: &mut BytesMut) {
buf.put_u8(b'f');
write_body(buf, |buf| write_cstr(message.as_bytes(), buf))
write_body(buf, |buf| write_cstr(message, buf))
}
#[inline]
pub fn describe(variant: u8, name: &str, buf: &mut BytesMut) -> io::Result<()> {
pub fn describe(variant: u8, name: &CSafeStr, buf: &mut BytesMut) {
buf.put_u8(b'D');
write_body(buf, |buf| {
buf.put_u8(variant);
write_cstr(name.as_bytes(), buf)
write_cstr(name, buf)
})
}
#[inline]
pub fn execute(portal: &str, max_rows: i32, buf: &mut BytesMut) -> io::Result<()> {
pub fn execute(portal: &CSafeStr, max_rows: i32, buf: &mut BytesMut) {
buf.put_u8(b'E');
write_body(buf, |buf| {
write_cstr(portal.as_bytes(), buf)?;
write_cstr(portal, buf);
buf.put_i32(max_rows);
Ok(())
})
}
#[inline]
pub fn parse<I>(name: &str, query: &str, param_types: I, buf: &mut BytesMut) -> io::Result<()>
pub fn parse<I>(name: &CSafeStr, query: &CSafeStr, param_types: I, buf: &mut BytesMut)
where
I: IntoIterator<Item = Oid>,
{
buf.put_u8(b'P');
write_body(buf, |buf| {
write_cstr(name.as_bytes(), buf)?;
write_cstr(query.as_bytes(), buf)?;
write_counted(
param_types,
|t, buf| {
buf.put_u32(t);
Ok::<_, io::Error>(())
},
buf,
)?;
Ok(())
write_cstr(name, buf);
write_cstr(query, buf);
write_counted(param_types, |t, buf| buf.put_u32(t), buf);
})
}
#[inline]
pub fn password_message(password: &[u8], buf: &mut BytesMut) -> io::Result<()> {
pub fn password_message(password: &CSafeStr, buf: &mut BytesMut) {
buf.put_u8(b'p');
write_body(buf, |buf| write_cstr(password, buf))
}
#[inline]
pub fn query(query: &str, buf: &mut BytesMut) -> io::Result<()> {
pub fn query(query: &CSafeStr, buf: &mut BytesMut) {
buf.put_u8(b'Q');
write_body(buf, |buf| write_cstr(query.as_bytes(), buf))
write_body(buf, |buf| write_cstr(query, buf))
}
#[inline]
pub fn sasl_initial_response(mechanism: &str, data: &[u8], buf: &mut BytesMut) -> io::Result<()> {
pub fn sasl_initial_response(mechanism: &CSafeStr, data: &[u8], buf: &mut BytesMut) {
buf.put_u8(b'p');
write_body(buf, |buf| {
write_cstr(mechanism.as_bytes(), buf)?;
let len = i32::from_usize(data.len())?;
write_cstr(mechanism, buf);
let len =
i32::from_usize(data.len()).expect("sasl data should not be larger than i32::MAX");
buf.put_i32(len);
buf.put_slice(data);
Ok(())
})
}
#[inline]
pub fn sasl_response(data: &[u8], buf: &mut BytesMut) -> io::Result<()> {
pub fn sasl_response(data: &[u8], buf: &mut BytesMut) {
buf.put_u8(b'p');
write_body(buf, |buf| {
buf.put_slice(data);
Ok(())
})
}
@@ -248,19 +195,16 @@ pub fn sasl_response(data: &[u8], buf: &mut BytesMut) -> io::Result<()> {
pub fn ssl_request(buf: &mut BytesMut) {
write_body(buf, |buf| {
buf.put_i32(80_877_103);
Ok::<_, io::Error>(())
})
.unwrap();
});
}
#[inline]
pub fn startup_message(parameters: &StartupMessageParams, buf: &mut BytesMut) -> io::Result<()> {
pub fn startup_message(parameters: &StartupMessageParams, buf: &mut BytesMut) {
write_body(buf, |buf| {
// postgres protocol version 3.0(196608) in bigger-endian
buf.put_i32(0x00_03_00_00);
buf.put_slice(&parameters.params);
buf.put_u8(0);
Ok(())
})
}
@@ -271,10 +215,7 @@ pub struct StartupMessageParams {
impl StartupMessageParams {
/// Set parameter's value by its name.
pub fn insert(&mut self, name: &str, value: &str) {
if name.contains('\0') || value.contains('\0') {
panic!("startup parameter name or value contained a null")
}
pub fn insert(&mut self, name: &CSafeStr, value: &CSafeStr) {
self.params.put_slice(name.as_bytes());
self.params.put_u8(0);
self.params.put_slice(value.as_bytes());
@@ -285,24 +226,17 @@ impl StartupMessageParams {
#[inline]
pub fn sync(buf: &mut BytesMut) {
buf.put_u8(b'S');
write_body(buf, |_| Ok::<(), io::Error>(())).unwrap();
write_body(buf, |_| {});
}
#[inline]
pub fn terminate(buf: &mut BytesMut) {
buf.put_u8(b'X');
write_body(buf, |_| Ok::<(), io::Error>(())).unwrap();
write_body(buf, |_| {});
}
#[inline]
fn write_cstr(s: &[u8], buf: &mut BytesMut) -> Result<(), io::Error> {
if s.contains(&0) {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"string contains embedded null",
));
}
buf.put_slice(s);
fn write_cstr(s: &CSafeStr, buf: &mut BytesMut) {
buf.put_slice(s.as_bytes());
buf.put_u8(0);
Ok(())
}

View File

@@ -13,7 +13,7 @@ use bytes::BytesMut;
use fallible_iterator::FallibleIterator;
#[doc(inline)]
pub use postgres_protocol2::Oid;
use postgres_protocol2::types;
use postgres_protocol2::{IsNull, types};
use crate::type_gen::{Inner, Other};
@@ -32,36 +32,15 @@ macro_rules! accepts {
/// All `ToSql` implementations should use this macro.
macro_rules! to_sql_checked {
() => {
fn to_sql_checked(
&self,
ty: &$crate::Type,
out: &mut $crate::private::BytesMut,
) -> ::std::result::Result<
$crate::IsNull,
Box<dyn ::std::error::Error + ::std::marker::Sync + ::std::marker::Send>,
> {
$crate::__to_sql_checked(self, ty, out)
fn check(&self, ty: &Type) -> ::std::result::Result<(), $crate::WrongType> {
if !<Self as $crate::ToSql>::accepts(ty) {
return Err($crate::WrongType::new::<Self>(ty.clone()));
}
Ok(())
}
};
}
// WARNING: this function is not considered part of this crate's public API.
// It is subject to change at any time.
#[doc(hidden)]
pub fn __to_sql_checked<T>(
v: &T,
ty: &Type,
out: &mut BytesMut,
) -> Result<IsNull, Box<dyn Error + Sync + Send>>
where
T: ToSql,
{
if !T::accepts(ty) {
return Err(Box::new(WrongType::new::<T>(ty.clone())));
}
v.to_sql(ty, out)
}
// mod pg_lsn;
#[doc(hidden)]
pub mod private;
@@ -369,14 +348,6 @@ macro_rules! simple_from {
simple_from!(i8, char_from_sql, CHAR);
simple_from!(u32, oid_from_sql, OID);
/// An enum representing the nullability of a Postgres value.
pub enum IsNull {
/// The value is NULL.
Yes,
/// The value is not NULL.
No,
}
/// A trait for types that can be converted into Postgres values.
pub trait ToSql: fmt::Debug {
/// Converts the value of `self` into the binary format of the specified
@@ -388,9 +359,7 @@ pub trait ToSql: fmt::Debug {
/// The return value indicates if this value should be represented as
/// `NULL`. If this is the case, implementations **must not** write
/// anything to `out`.
fn to_sql(&self, ty: &Type, out: &mut BytesMut) -> Result<IsNull, Box<dyn Error + Sync + Send>>
where
Self: Sized;
fn to_sql(&self, ty: &Type, out: &mut BytesMut) -> IsNull;
/// Determines if a value of this type can be converted to the specified
/// Postgres `Type`.
@@ -402,11 +371,7 @@ pub trait ToSql: fmt::Debug {
///
/// *All* implementations of this method should be generated by the
/// `to_sql_checked!()` macro.
fn to_sql_checked(
&self,
ty: &Type,
out: &mut BytesMut,
) -> Result<IsNull, Box<dyn Error + Sync + Send>>;
fn check(&self, ty: &Type) -> Result<(), WrongType>;
/// Specify the encode format
fn encode_format(&self, _ty: &Type) -> Format {
@@ -426,14 +391,14 @@ pub enum Format {
}
impl ToSql for &str {
fn to_sql(&self, ty: &Type, w: &mut BytesMut) -> Result<IsNull, Box<dyn Error + Sync + Send>> {
fn to_sql(&self, ty: &Type, w: &mut BytesMut) -> IsNull {
match *ty {
ref ty if ty.name() == "ltree" => types::ltree_to_sql(self, w),
ref ty if ty.name() == "lquery" => types::lquery_to_sql(self, w),
ref ty if ty.name() == "ltxtquery" => types::ltxtquery_to_sql(self, w),
_ => types::text_to_sql(self, w),
}
Ok(IsNull::No)
IsNull::No
}
fn accepts(ty: &Type) -> bool {
@@ -457,12 +422,9 @@ impl ToSql for &str {
macro_rules! simple_to {
($t:ty, $f:ident, $($expected:ident),+) => {
impl ToSql for $t {
fn to_sql(&self,
_: &Type,
w: &mut BytesMut)
-> Result<IsNull, Box<dyn Error + Sync + Send>> {
fn to_sql(&self, _: &Type, w: &mut BytesMut) -> IsNull {
types::$f(*self, w);
Ok(IsNull::No)
IsNull::No
}
accepts!($($expected),+);

View File

@@ -219,7 +219,7 @@ impl Client {
}
let buf = self.client.inner().with_buf(|buf| {
frontend::query("ROLLBACK", buf).unwrap();
frontend::query(c"ROLLBACK".into(), buf);
buf.split().freeze()
});
let _ = self

View File

@@ -4,6 +4,7 @@ use std::net::IpAddr;
use std::time::Duration;
use std::{fmt, str};
use postgres_protocol2::CSafeStr;
pub use postgres_protocol2::authentication::sasl::ScramKeys;
use postgres_protocol2::message::frontend::StartupMessageParams;
use serde::{Deserialize, Serialize};
@@ -162,7 +163,10 @@ impl Config {
self.username = true;
}
self.server_params.insert(name, value);
self.server_params.insert(
CSafeStr::new(name.as_bytes()).expect("param name should not contain a null"),
CSafeStr::new(value.as_bytes()).expect("param name should not contain a null"),
);
self
}

View File

@@ -6,6 +6,7 @@ use std::task::{Context, Poll};
use bytes::BytesMut;
use fallible_iterator::FallibleIterator;
use futures_util::{Sink, SinkExt, Stream, TryStreamExt, ready};
use postgres_protocol2::CSafeStr;
use postgres_protocol2::authentication::sasl;
use postgres_protocol2::authentication::sasl::ScramSha256;
use postgres_protocol2::message::backend::{AuthenticationSaslBody, Message, NoticeResponseBody};
@@ -122,7 +123,7 @@ where
T: AsyncRead + AsyncWrite + Unpin,
{
let mut buf = BytesMut::new();
frontend::startup_message(&config.server_params, &mut buf).map_err(Error::encode)?;
frontend::startup_message(&config.server_params, &mut buf);
stream
.send(FrontendMessage::Raw(buf.freeze()))
@@ -193,7 +194,7 @@ where
T: AsyncRead + AsyncWrite + Unpin,
{
let mut buf = BytesMut::new();
frontend::password_message(password, &mut buf).map_err(Error::encode)?;
frontend::password_message(CSafeStr::new(password).map_err(Error::encode)?, &mut buf);
stream
.send(FrontendMessage::Raw(buf.freeze()))
@@ -214,10 +215,10 @@ where
let mut has_scram_plus = false;
let mut mechanisms = body.mechanisms();
while let Some(mechanism) = mechanisms.next().map_err(Error::parse)? {
match mechanism {
sasl::SCRAM_SHA_256 => has_scram = true,
sasl::SCRAM_SHA_256_PLUS => has_scram_plus = true,
_ => {}
if mechanism == sasl::SCRAM_SHA_256 {
has_scram = true;
} else if mechanism == sasl::SCRAM_SHA_256_PLUS {
has_scram_plus = true;
}
}
@@ -256,7 +257,7 @@ where
};
let mut buf = BytesMut::new();
frontend::sasl_initial_response(mechanism, scram.message(), &mut buf).map_err(Error::encode)?;
frontend::sasl_initial_response(mechanism, scram.message(), &mut buf);
stream
.send(FrontendMessage::Raw(buf.freeze()))
.await
@@ -275,7 +276,7 @@ where
.map_err(|e| Error::authentication(e.into()))?;
let mut buf = BytesMut::new();
frontend::sasl_response(scram.message(), &mut buf).map_err(Error::encode)?;
frontend::sasl_response(scram.message(), &mut buf);
stream
.send(FrontendMessage::Raw(buf.freeze()))
.await

View File

@@ -123,6 +123,6 @@ pub enum SimpleQueryMessage {
fn slice_iter<'a>(
s: &'a [&'a (dyn ToSql + Sync)],
) -> impl ExactSizeIterator<Item = &'a (dyn ToSql + Sync)> + 'a {
) -> impl ExactSizeIterator<Item = &'a (dyn ToSql + Sync)> + Clone + 'a {
s.iter().map(|s| *s as _)
}

View File

@@ -1,3 +1,4 @@
use std::ffi::CStr;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
@@ -5,9 +6,9 @@ use std::sync::Arc;
use bytes::Bytes;
use fallible_iterator::FallibleIterator;
use futures_util::{TryStreamExt, pin_mut};
use postgres_protocol2::CSafeStr;
use postgres_protocol2::message::backend::Message;
use postgres_protocol2::message::frontend;
use tracing::debug;
use crate::client::{CachedTypeInfo, InnerClient};
use crate::codec::FrontendMessage;
@@ -15,7 +16,7 @@ use crate::connection::RequestMessages;
use crate::types::{Kind, Oid, Type};
use crate::{Column, Error, Statement, query, slice_iter};
pub(crate) const TYPEINFO_QUERY: &str = "\
pub(crate) const TYPEINFO_QUERY: &CStr = c"\
SELECT t.typname, t.typtype, t.typelem, r.rngsubtype, t.typbasetype, n.nspname, t.typrelid
FROM pg_catalog.pg_type t
LEFT OUTER JOIN pg_catalog.pg_range r ON r.rngtypid = t.oid
@@ -25,11 +26,11 @@ WHERE t.oid = $1
async fn prepare_typecheck(
client: &Arc<InnerClient>,
name: &'static str,
query: &str,
name: &'static CStr,
query: &CSafeStr,
types: &[Type],
) -> Result<Statement, Error> {
let buf = encode(client, name, query, types)?;
let buf = encode(client, name.into(), query, types)?;
let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?;
match responses.next().await? {
@@ -68,16 +69,21 @@ async fn prepare_typecheck(
Ok(Statement::new(client, name, parameters, columns))
}
fn encode(client: &InnerClient, name: &str, query: &str, types: &[Type]) -> Result<Bytes, Error> {
if types.is_empty() {
debug!("preparing query {}: {}", name, query);
} else {
debug!("preparing query {} with types {:?}: {}", name, types, query);
}
fn encode(
client: &InnerClient,
name: &CSafeStr,
query: &CSafeStr,
types: &[Type],
) -> Result<Bytes, Error> {
// if types.is_empty() {
// debug!("preparing query {}: {}", name, query);
// } else {
// debug!("preparing query {} with types {:?}: {}", name, types, query);
// }
client.with_buf(|buf| {
frontend::parse(name, query, types.iter().map(Type::oid), buf).map_err(Error::encode)?;
frontend::describe(b'S', name, buf).map_err(Error::encode)?;
frontend::parse(name, query, types.iter().map(Type::oid), buf);
frontend::describe(b'S', name, buf);
frontend::sync(buf);
Ok(buf.split().freeze())
})
@@ -154,8 +160,8 @@ async fn typeinfo_statement(
return Ok(stmt.clone());
}
let typeinfo = "neon_proxy_typeinfo";
let stmt = prepare_typecheck(client, typeinfo, TYPEINFO_QUERY, &[]).await?;
let typeinfo = c"neon_proxy_typeinfo";
let stmt = prepare_typecheck(client, typeinfo, TYPEINFO_QUERY.into(), &[]).await?;
typecache.typeinfo = Some(stmt.clone());
Ok(stmt)

View File

@@ -1,4 +1,4 @@
use std::fmt;
use std::iter;
use std::marker::PhantomPinned;
use std::pin::Pin;
use std::sync::Arc;
@@ -8,25 +8,16 @@ use bytes::{BufMut, Bytes, BytesMut};
use fallible_iterator::FallibleIterator;
use futures_util::{Stream, ready};
use pin_project_lite::pin_project;
use postgres_protocol2::CSafeStr;
use postgres_protocol2::message::backend::Message;
use postgres_protocol2::message::frontend;
use postgres_types2::{Format, ToSql, Type};
use tracing::debug;
use crate::client::{InnerClient, Responses};
use crate::codec::FrontendMessage;
use crate::connection::RequestMessages;
use crate::types::IsNull;
use crate::{Column, Error, ReadyForQueryStatus, Row, Statement};
struct BorrowToSqlParamsDebug<'a>(&'a [&'a (dyn ToSql + Sync)]);
impl fmt::Debug for BorrowToSqlParamsDebug<'_> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_list().entries(self.0.iter()).finish()
}
}
pub async fn query<'a, I>(
client: &InnerClient,
statement: Statement,
@@ -34,19 +25,9 @@ pub async fn query<'a, I>(
) -> Result<RowStream, Error>
where
I: IntoIterator<Item = &'a (dyn ToSql + Sync)>,
I::IntoIter: ExactSizeIterator,
I::IntoIter: ExactSizeIterator + Clone,
{
let buf = if tracing::enabled!(tracing::Level::DEBUG) {
let params = params.into_iter().collect::<Vec<_>>();
debug!(
"executing statement {} with parameters: {:?}",
statement.name(),
BorrowToSqlParamsDebug(params.as_slice()),
);
encode(client, &statement, params)?
} else {
encode(client, &statement, params)?
};
let buf = encode(client, &statement, params)?;
let responses = start(client, buf).await?;
Ok(RowStream {
statement,
@@ -68,45 +49,45 @@ where
I: IntoIterator<Item = Option<S>>,
I::IntoIter: ExactSizeIterator,
{
let query = CSafeStr::new(query.as_bytes()).map_err(Error::encode)?;
let params = params.into_iter();
let portal = c"".into(); // unnamed portal
let statement = c"".into(); // unnamed prepared statement
let buf = client.with_buf(|buf| {
frontend::parse(
"", // unnamed prepared statement
query, // query to parse
std::iter::empty(), // give no type info
statement,
query, // query to parse
iter::empty(), // give no type info
buf,
)
.map_err(Error::encode)?;
frontend::describe(b'S', "", buf).map_err(Error::encode)?;
// Bind, pass params as text, retrieve as binary
match frontend::bind(
"", // empty string selects the unnamed portal
"", // unnamed prepared statement
std::iter::empty(), // all parameters use the default format (text)
);
frontend::describe(b'S', statement, buf);
// Bind, pass params as text, retrieve as test
frontend::bind(
portal,
statement,
iter::empty(), // all parameters use the default format (text)
params,
|param, buf| match param {
Some(param) => {
buf.put_slice(param.as_ref().as_bytes());
Ok(postgres_protocol2::IsNull::No)
postgres_protocol2::IsNull::No
}
None => Ok(postgres_protocol2::IsNull::Yes),
None => postgres_protocol2::IsNull::Yes,
},
Some(0), // all text
buf,
) {
Ok(()) => Ok(()),
Err(frontend::BindError::Conversion(e)) => Err(Error::to_sql(e, 0)),
Err(frontend::BindError::Serialization(e)) => Err(Error::encode(e)),
}?;
);
// Execute
frontend::execute("", 0, buf).map_err(Error::encode)?;
frontend::execute(portal, 0, buf);
// Sync
frontend::sync(buf);
Ok(buf.split().freeze())
})?;
buf.split().freeze()
});
// now read the responses
let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?;
@@ -173,11 +154,13 @@ async fn start(client: &InnerClient, buf: Bytes) -> Result<Responses, Error> {
pub fn encode<'a, I>(client: &InnerClient, statement: &Statement, params: I) -> Result<Bytes, Error>
where
I: IntoIterator<Item = &'a (dyn ToSql + Sync)>,
I::IntoIter: ExactSizeIterator,
I::IntoIter: ExactSizeIterator + Clone,
{
let portal = c"".into(); // unnamed portal
client.with_buf(|buf| {
encode_bind(statement, params, "", buf)?;
frontend::execute("", 0, buf).map_err(Error::encode)?;
encode_bind(statement, params, portal, buf)?;
frontend::execute(portal, 0, buf);
frontend::sync(buf);
Ok(buf.split().freeze())
})
@@ -186,15 +169,15 @@ where
pub fn encode_bind<'a, I>(
statement: &Statement,
params: I,
portal: &str,
portal: &CSafeStr,
buf: &mut BytesMut,
) -> Result<(), Error>
where
I: IntoIterator<Item = &'a (dyn ToSql + Sync)>,
I::IntoIter: ExactSizeIterator,
I::IntoIter: ExactSizeIterator + Clone,
{
let param_types = statement.params();
let params = params.into_iter();
let params = iter::zip(params.into_iter(), param_types);
assert!(
param_types.len() == params.len(),
@@ -203,35 +186,24 @@ where
params.len()
);
let (param_formats, params): (Vec<_>, Vec<_>) = params
.zip(param_types.iter())
.map(|(p, ty)| (p.encode_format(ty) as i16, p))
.unzip();
// check encodings
for (i, (p, ty)) in params.clone().enumerate() {
p.check(ty).map_err(|e| Error::to_sql(Box::new(e), i))?
}
let params = params.into_iter();
let param_formats = params.clone().map(|(p, ty)| p.encode_format(ty) as i16);
let mut error_idx = 0;
let r = frontend::bind(
frontend::bind(
portal,
statement.name(),
param_formats,
params.zip(param_types).enumerate(),
|(idx, (param, ty)), buf| match param.to_sql_checked(ty, buf) {
Ok(IsNull::No) => Ok(postgres_protocol2::IsNull::No),
Ok(IsNull::Yes) => Ok(postgres_protocol2::IsNull::Yes),
Err(e) => {
error_idx = idx;
Err(e)
}
},
params,
|(param, ty), buf| param.to_sql(ty, buf),
Some(1),
buf,
);
match r {
Ok(()) => Ok(()),
Err(frontend::BindError::Conversion(e)) => Err(Error::to_sql(e, error_idx)),
Err(frontend::BindError::Serialization(e)) => Err(Error::encode(e)),
}
Ok(())
}
pin_project! {

View File

@@ -7,6 +7,7 @@ use bytes::Bytes;
use fallible_iterator::FallibleIterator;
use futures_util::{Stream, ready};
use pin_project_lite::pin_project;
use postgres_protocol2::CSafeStr;
use postgres_protocol2::message::backend::Message;
use postgres_protocol2::message::frontend;
use tracing::debug;
@@ -69,8 +70,9 @@ pub async fn batch_execute(
}
pub(crate) fn encode(client: &InnerClient, query: &str) -> Result<Bytes, Error> {
let query = CSafeStr::new(query.as_bytes()).map_err(Error::encode)?;
client.with_buf(|buf| {
frontend::query(query, buf).map_err(Error::encode)?;
frontend::query(query, buf);
Ok(buf.split().freeze())
})
}

View File

@@ -1,9 +1,10 @@
use std::ffi::CStr;
use std::fmt;
use std::sync::{Arc, Weak};
use postgres_protocol2::Oid;
use postgres_protocol2::message::backend::Field;
use postgres_protocol2::message::frontend;
use postgres_protocol2::{CSafeStr, Oid};
use crate::client::InnerClient;
use crate::codec::FrontendMessage;
@@ -12,7 +13,7 @@ use crate::types::Type;
struct StatementInner {
client: Weak<InnerClient>,
name: &'static str,
name: &'static CStr,
params: Vec<Type>,
columns: Vec<Column>,
}
@@ -21,7 +22,7 @@ impl Drop for StatementInner {
fn drop(&mut self) {
if let Some(client) = self.client.upgrade() {
let buf = client.with_buf(|buf| {
frontend::close(b'S', self.name, buf).unwrap();
frontend::close(b'S', self.name.into(), buf);
frontend::sync(buf);
buf.split().freeze()
});
@@ -39,7 +40,7 @@ pub struct Statement(Arc<StatementInner>);
impl Statement {
pub(crate) fn new(
inner: &Arc<InnerClient>,
name: &'static str,
name: &'static CStr,
params: Vec<Type>,
columns: Vec<Column>,
) -> Statement {
@@ -54,14 +55,14 @@ impl Statement {
pub(crate) fn new_anonymous(params: Vec<Type>, columns: Vec<Column>) -> Statement {
Statement(Arc::new(StatementInner {
client: Weak::new(),
name: "<anonymous>",
name: c"<anonymous>",
params,
columns,
}))
}
pub(crate) fn name(&self) -> &str {
self.0.name
pub(crate) fn name(&self) -> &CSafeStr {
self.0.name.into()
}
/// Returns the expected types of the statement's parameters.

View File

@@ -21,7 +21,7 @@ impl Drop for Transaction<'_> {
}
let buf = self.client.inner().with_buf(|buf| {
frontend::query("ROLLBACK", buf).unwrap();
frontend::query(c"ROLLBACK".into(), buf);
buf.split().freeze()
});
let _ = self

View File

@@ -536,7 +536,8 @@ mod tests {
use control_plane::AuthSecret;
use fallible_iterator::FallibleIterator;
use once_cell::sync::Lazy;
use postgres_protocol::authentication::sasl::{ChannelBinding, ScramSha256};
use postgres_protocol::CSafeStr;
use postgres_protocol::authentication::sasl::{ChannelBinding, SCRAM_SHA_256, ScramSha256};
use postgres_protocol::message::backend::Message as PgMessage;
use postgres_protocol::message::frontend;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt};
@@ -714,15 +715,15 @@ mod tests {
// server should offer scram
match read_message(&mut client, &mut read).await {
PgMessage::AuthenticationSasl(a) => {
let options: Vec<&str> = a.mechanisms().collect().unwrap();
assert_eq!(options, ["SCRAM-SHA-256"]);
let options: Vec<&CSafeStr> = a.mechanisms().collect().unwrap();
assert_eq!(options, [SCRAM_SHA_256]);
}
_ => panic!("wrong message"),
}
// client sends client-first-message
let mut write = BytesMut::new();
frontend::sasl_initial_response("SCRAM-SHA-256", scram.message(), &mut write).unwrap();
frontend::sasl_initial_response(SCRAM_SHA_256, scram.message(), &mut write);
client.write_all(&write).await.unwrap();
// server response with server-first-message
@@ -735,7 +736,7 @@ mod tests {
// client response with client-final-message
write.clear();
frontend::sasl_response(scram.message(), &mut write).unwrap();
frontend::sasl_response(scram.message(), &mut write);
client.write_all(&write).await.unwrap();
// server response with server-final-message
@@ -800,7 +801,7 @@ mod tests {
// client responds with password
write.clear();
frontend::password_message(b"my-secret-password", &mut write).unwrap();
frontend::password_message(c"my-secret-password".into(), &mut write);
client.write_all(&write).await.unwrap();
});
let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new_with_shards(
@@ -853,8 +854,10 @@ mod tests {
// client responds with password
let mut write = BytesMut::new();
frontend::password_message(b"endpoint=my-endpoint;my-secret-password", &mut write)
.unwrap();
frontend::password_message(
c"endpoint=my-endpoint;my-secret-password".into(),
&mut write,
);
client.write_all(&write).await.unwrap();
});

View File

@@ -3,7 +3,6 @@
use std::io;
use std::sync::Arc;
use postgres_protocol::authentication::sasl::{SCRAM_SHA_256, SCRAM_SHA_256_PLUS};
use pq_proto::{BeAuthenticationSaslMessage, BeMessage, BeMessage as Be};
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::info;
@@ -174,8 +173,10 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, Scram<'_>> {
}
match sasl.method {
SCRAM_SHA_256 => ctx.set_auth_method(crate::context::AuthMethod::ScramSha256),
SCRAM_SHA_256_PLUS => ctx.set_auth_method(crate::context::AuthMethod::ScramSha256Plus),
scram::SCRAM_SHA_256 => ctx.set_auth_method(crate::context::AuthMethod::ScramSha256),
scram::SCRAM_SHA_256_PLUS => {
ctx.set_auth_method(crate::context::AuthMethod::ScramSha256Plus)
}
_ => {}
}

View File

@@ -9,7 +9,7 @@ use std::fmt::Debug;
use bytes::{Bytes, BytesMut};
use futures::{SinkExt, StreamExt};
use postgres_client::tls::TlsConnect;
use postgres_protocol::message::frontend;
use postgres_protocol::{authentication::sasl::SCRAM_SHA_256, message::frontend};
use tokio::io::{AsyncReadExt, DuplexStream};
use tokio_util::codec::{Decoder, Encoder};
@@ -60,8 +60,7 @@ async fn proxy_mitm(
params: startup.params.into(),
},
&mut buf,
)
.unwrap();
);
end_server.send(buf.freeze()).await.unwrap();
// proxy messages between end_client and end_server
@@ -90,7 +89,7 @@ async fn proxy_mitm(
new_message.extend_from_slice(sasl_message.strip_prefix(b"p=tls-server-end-point,,").unwrap());
let mut buf = BytesMut::new();
frontend::sasl_initial_response("SCRAM-SHA-256", &new_message, &mut buf).unwrap();
frontend::sasl_initial_response(SCRAM_SHA_256, &new_message, &mut buf);
end_server.send(buf.freeze()).await.unwrap();
continue;

View File

@@ -21,8 +21,8 @@ pub(crate) use key::ScramKey;
pub(crate) use secret::ServerSecret;
use sha2::{Digest, Sha256};
const SCRAM_SHA_256: &str = "SCRAM-SHA-256";
const SCRAM_SHA_256_PLUS: &str = "SCRAM-SHA-256-PLUS";
pub(crate) const SCRAM_SHA_256: &str = "SCRAM-SHA-256";
pub(crate) const SCRAM_SHA_256_PLUS: &str = "SCRAM-SHA-256-PLUS";
/// A list of supported SCRAM methods.
pub(crate) const METHODS: &[&str] = &[SCRAM_SHA_256_PLUS, SCRAM_SHA_256];