mirror of
https://github.com/neondatabase/neon.git
synced 2026-01-14 00:42:54 +00:00
simplify error handling for query encoding
This commit is contained in:
@@ -9,12 +9,14 @@ use sha2::digest::FixedOutput;
|
||||
use sha2::{Digest, Sha256};
|
||||
use tokio::task::yield_now;
|
||||
|
||||
use crate::CSafeStr;
|
||||
|
||||
const NONCE_LENGTH: usize = 24;
|
||||
|
||||
/// The identifier of the SCRAM-SHA-256 SASL authentication mechanism.
|
||||
pub const SCRAM_SHA_256: &str = "SCRAM-SHA-256";
|
||||
pub const SCRAM_SHA_256: &CSafeStr = CSafeStr::from_cstr(c"SCRAM-SHA-256");
|
||||
/// The identifier of the SCRAM-SHA-256-PLUS SASL authentication mechanism.
|
||||
pub const SCRAM_SHA_256_PLUS: &str = "SCRAM-SHA-256-PLUS";
|
||||
pub const SCRAM_SHA_256_PLUS: &CSafeStr = CSafeStr::from_cstr(c"SCRAM-SHA-256-PLUS");
|
||||
|
||||
// since postgres passwords are not required to exclude saslprep-prohibited
|
||||
// characters or even be valid UTF8, we run saslprep if possible and otherwise
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
//! set to `UTF8`. It will most likely not behave properly if that is not the case.
|
||||
#![warn(missing_docs, clippy::all)]
|
||||
|
||||
use std::io;
|
||||
use std::{ffi::CStr, io};
|
||||
|
||||
use byteorder::{BigEndian, ByteOrder};
|
||||
use bytes::{BufMut, BytesMut};
|
||||
@@ -36,20 +36,67 @@ pub enum IsNull {
|
||||
No,
|
||||
}
|
||||
|
||||
fn write_nullable<F, E>(serializer: F, buf: &mut BytesMut) -> Result<(), E>
|
||||
/// A [`std::ffi::CStr`] but without the null byte.
|
||||
#[repr(transparent)]
|
||||
#[derive(PartialEq, Eq, Hash, Debug)]
|
||||
pub struct CSafeStr([u8]);
|
||||
|
||||
impl CSafeStr {
|
||||
/// Create a new `CSafeStr`, erroring if the bytes contains a null.
|
||||
pub fn new(bytes: &[u8]) -> Result<&Self, io::Error> {
|
||||
let nul_pos = memchr::memchr(0, bytes);
|
||||
match nul_pos {
|
||||
Some(nul_pos) => Err(io::Error::other(format!(
|
||||
"unexpected null byte at position {nul_pos}"
|
||||
))),
|
||||
None => {
|
||||
// Safety: CSafeStr is transparent over [u8].
|
||||
Ok(unsafe { std::mem::transmute(bytes) })
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new `CSafeStr` up until the next null.
|
||||
pub fn take<'a>(bytes: &mut &'a [u8]) -> &'a Self {
|
||||
let nul_pos = memchr::memchr(0, bytes).unwrap_or(bytes.len());
|
||||
let bytes = bytes
|
||||
.split_off(..nul_pos)
|
||||
.expect("nul_pos should be in-bounds");
|
||||
// Safety: CSafeStr is transparent over [u8].
|
||||
unsafe { std::mem::transmute(bytes) }
|
||||
}
|
||||
|
||||
/// Get the bytes of this CSafeStr.
|
||||
pub const fn as_bytes(&self) -> &[u8] {
|
||||
&self.0
|
||||
}
|
||||
|
||||
/// Create a new `CSafeStr`
|
||||
pub const fn from_cstr(s: &CStr) -> &CSafeStr {
|
||||
// Safety: CSafeStr is transparent over [u8].
|
||||
unsafe { std::mem::transmute(s.to_bytes()) }
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> From<&'a CStr> for &'a CSafeStr {
|
||||
fn from(s: &'a CStr) -> &'a CSafeStr {
|
||||
CSafeStr::from_cstr(s)
|
||||
}
|
||||
}
|
||||
|
||||
fn write_nullable<F>(serializer: F, buf: &mut BytesMut)
|
||||
where
|
||||
F: FnOnce(&mut BytesMut) -> Result<IsNull, E>,
|
||||
E: From<io::Error>,
|
||||
F: FnOnce(&mut BytesMut) -> IsNull,
|
||||
{
|
||||
let base = buf.len();
|
||||
buf.put_i32(0);
|
||||
let size = match serializer(buf)? {
|
||||
IsNull::No => i32::from_usize(buf.len() - base - 4)?,
|
||||
let size = match serializer(buf) {
|
||||
// this is an unreasonable enough case that I think a panic is acceptable.
|
||||
IsNull::No => i32::from_usize(buf.len() - base - 4)
|
||||
.expect("buffer size should not be larger than i32::MAX"),
|
||||
IsNull::Yes => -1,
|
||||
};
|
||||
BigEndian::write_i32(&mut buf[base..], size);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
trait FromUsize: Sized {
|
||||
|
||||
@@ -9,7 +9,7 @@ use bytes::{Bytes, BytesMut};
|
||||
use fallible_iterator::FallibleIterator;
|
||||
use memchr::memchr;
|
||||
|
||||
use crate::Oid;
|
||||
use crate::{CSafeStr, Oid};
|
||||
|
||||
// top-level message tags
|
||||
const PARSE_COMPLETE_TAG: u8 = b'1';
|
||||
@@ -332,25 +332,24 @@ impl AuthenticationSaslBody {
|
||||
pub struct SaslMechanisms<'a>(&'a [u8]);
|
||||
|
||||
impl<'a> FallibleIterator for SaslMechanisms<'a> {
|
||||
type Item = &'a str;
|
||||
type Item = &'a CSafeStr;
|
||||
type Error = io::Error;
|
||||
|
||||
#[inline]
|
||||
fn next(&mut self) -> io::Result<Option<&'a str>> {
|
||||
let value_end = find_null(self.0, 0)?;
|
||||
if value_end == 0 {
|
||||
if self.0.len() != 1 {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
"invalid message length: expected to be at end of iterator for sasl",
|
||||
));
|
||||
}
|
||||
Ok(None)
|
||||
} else {
|
||||
let value = get_str(&self.0[..value_end])?;
|
||||
self.0 = &self.0[value_end + 1..];
|
||||
Ok(Some(value))
|
||||
fn next(&mut self) -> io::Result<Option<&'a CSafeStr>> {
|
||||
if self.0 == b"0" {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let value = CSafeStr::take(&mut self.0);
|
||||
if value.as_bytes().len() == 0 {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
"invalid message length: expected to be at end of iterator for sasl",
|
||||
));
|
||||
}
|
||||
|
||||
Ok(Some(value))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,114 +1,73 @@
|
||||
//! Frontend message serialization.
|
||||
#![allow(missing_docs)]
|
||||
|
||||
use std::error::Error;
|
||||
use std::{io, marker};
|
||||
use std::io;
|
||||
|
||||
use byteorder::{BigEndian, ByteOrder};
|
||||
use bytes::{Buf, BufMut, BytesMut};
|
||||
|
||||
use crate::{FromUsize, IsNull, Oid, write_nullable};
|
||||
use crate::{CSafeStr, FromUsize, IsNull, Oid, write_nullable};
|
||||
|
||||
#[inline]
|
||||
fn write_body<F, E>(buf: &mut BytesMut, f: F) -> Result<(), E>
|
||||
fn write_body<F>(buf: &mut BytesMut, f: F)
|
||||
where
|
||||
F: FnOnce(&mut BytesMut) -> Result<(), E>,
|
||||
E: From<io::Error>,
|
||||
F: FnOnce(&mut BytesMut),
|
||||
{
|
||||
let base = buf.len();
|
||||
buf.extend_from_slice(&[0; 4]);
|
||||
|
||||
f(buf)?;
|
||||
f(buf);
|
||||
|
||||
let size = i32::from_usize(buf.len() - base)?;
|
||||
let size =
|
||||
i32::from_usize(buf.len() - base).expect("buffer size should not be larger than i32::MAX");
|
||||
BigEndian::write_i32(&mut buf[base..], size);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub enum BindError {
|
||||
Conversion(Box<dyn Error + marker::Sync + Send>),
|
||||
Serialization(io::Error),
|
||||
}
|
||||
|
||||
impl From<Box<dyn Error + marker::Sync + Send>> for BindError {
|
||||
#[inline]
|
||||
fn from(e: Box<dyn Error + marker::Sync + Send>) -> BindError {
|
||||
BindError::Conversion(e)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<io::Error> for BindError {
|
||||
#[inline]
|
||||
fn from(e: io::Error) -> BindError {
|
||||
BindError::Serialization(e)
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn bind<I, J, F, T, K>(
|
||||
portal: &str,
|
||||
statement: &str,
|
||||
portal: &CSafeStr,
|
||||
statement: &CSafeStr,
|
||||
formats: I,
|
||||
values: J,
|
||||
mut serializer: F,
|
||||
result_formats: K,
|
||||
buf: &mut BytesMut,
|
||||
) -> Result<(), BindError>
|
||||
where
|
||||
) where
|
||||
I: IntoIterator<Item = i16>,
|
||||
J: IntoIterator<Item = T>,
|
||||
F: FnMut(T, &mut BytesMut) -> Result<IsNull, Box<dyn Error + marker::Sync + Send>>,
|
||||
F: FnMut(T, &mut BytesMut) -> IsNull,
|
||||
K: IntoIterator<Item = i16>,
|
||||
{
|
||||
buf.put_u8(b'B');
|
||||
|
||||
write_body(buf, |buf| {
|
||||
write_cstr(portal.as_bytes(), buf)?;
|
||||
write_cstr(statement.as_bytes(), buf)?;
|
||||
write_counted(
|
||||
formats,
|
||||
|f, buf| {
|
||||
buf.put_i16(f);
|
||||
Ok::<_, io::Error>(())
|
||||
},
|
||||
buf,
|
||||
)?;
|
||||
write_cstr(portal, buf);
|
||||
write_cstr(statement, buf);
|
||||
write_counted(formats, |f, buf| buf.put_i16(f), buf);
|
||||
write_counted(
|
||||
values,
|
||||
|v, buf| write_nullable(|buf| serializer(v, buf), buf),
|
||||
buf,
|
||||
)?;
|
||||
write_counted(
|
||||
result_formats,
|
||||
|f, buf| {
|
||||
buf.put_i16(f);
|
||||
Ok::<_, io::Error>(())
|
||||
},
|
||||
buf,
|
||||
)?;
|
||||
|
||||
Ok(())
|
||||
);
|
||||
write_counted(result_formats, |f, buf| buf.put_i16(f), buf);
|
||||
})
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn write_counted<I, T, F, E>(items: I, mut serializer: F, buf: &mut BytesMut) -> Result<(), E>
|
||||
fn write_counted<I, T, F>(items: I, mut serializer: F, buf: &mut BytesMut)
|
||||
where
|
||||
I: IntoIterator<Item = T>,
|
||||
F: FnMut(T, &mut BytesMut) -> Result<(), E>,
|
||||
E: From<io::Error>,
|
||||
F: FnMut(T, &mut BytesMut),
|
||||
{
|
||||
let base = buf.len();
|
||||
buf.extend_from_slice(&[0; 2]);
|
||||
let mut count = 0;
|
||||
for item in items {
|
||||
serializer(item, buf)?;
|
||||
serializer(item, buf);
|
||||
count += 1;
|
||||
}
|
||||
let count = i16::from_usize(count)?;
|
||||
let count = i16::from_usize(count).expect("list should not exceed 32767 items");
|
||||
BigEndian::write_i16(&mut buf[base..], count);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[inline]
|
||||
@@ -117,17 +76,15 @@ pub fn cancel_request(process_id: i32, secret_key: i32, buf: &mut BytesMut) {
|
||||
buf.put_i32(80_877_102);
|
||||
buf.put_i32(process_id);
|
||||
buf.put_i32(secret_key);
|
||||
Ok::<_, io::Error>(())
|
||||
})
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn close(variant: u8, name: &str, buf: &mut BytesMut) -> io::Result<()> {
|
||||
pub fn close(variant: u8, name: &CSafeStr, buf: &mut BytesMut) {
|
||||
buf.put_u8(b'C');
|
||||
write_body(buf, |buf| {
|
||||
buf.put_u8(variant);
|
||||
write_cstr(name.as_bytes(), buf)
|
||||
write_cstr(name, buf)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -162,85 +119,75 @@ where
|
||||
#[inline]
|
||||
pub fn copy_done(buf: &mut BytesMut) {
|
||||
buf.put_u8(b'c');
|
||||
write_body(buf, |_| Ok::<(), io::Error>(())).unwrap();
|
||||
write_body(buf, |_| {});
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn copy_fail(message: &str, buf: &mut BytesMut) -> io::Result<()> {
|
||||
pub fn copy_fail(message: &CSafeStr, buf: &mut BytesMut) {
|
||||
buf.put_u8(b'f');
|
||||
write_body(buf, |buf| write_cstr(message.as_bytes(), buf))
|
||||
write_body(buf, |buf| write_cstr(message, buf))
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn describe(variant: u8, name: &str, buf: &mut BytesMut) -> io::Result<()> {
|
||||
pub fn describe(variant: u8, name: &CSafeStr, buf: &mut BytesMut) {
|
||||
buf.put_u8(b'D');
|
||||
write_body(buf, |buf| {
|
||||
buf.put_u8(variant);
|
||||
write_cstr(name.as_bytes(), buf)
|
||||
write_cstr(name, buf)
|
||||
})
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn execute(portal: &str, max_rows: i32, buf: &mut BytesMut) -> io::Result<()> {
|
||||
pub fn execute(portal: &CSafeStr, max_rows: i32, buf: &mut BytesMut) {
|
||||
buf.put_u8(b'E');
|
||||
write_body(buf, |buf| {
|
||||
write_cstr(portal.as_bytes(), buf)?;
|
||||
write_cstr(portal, buf);
|
||||
buf.put_i32(max_rows);
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn parse<I>(name: &str, query: &str, param_types: I, buf: &mut BytesMut) -> io::Result<()>
|
||||
pub fn parse<I>(name: &CSafeStr, query: &CSafeStr, param_types: I, buf: &mut BytesMut)
|
||||
where
|
||||
I: IntoIterator<Item = Oid>,
|
||||
{
|
||||
buf.put_u8(b'P');
|
||||
write_body(buf, |buf| {
|
||||
write_cstr(name.as_bytes(), buf)?;
|
||||
write_cstr(query.as_bytes(), buf)?;
|
||||
write_counted(
|
||||
param_types,
|
||||
|t, buf| {
|
||||
buf.put_u32(t);
|
||||
Ok::<_, io::Error>(())
|
||||
},
|
||||
buf,
|
||||
)?;
|
||||
Ok(())
|
||||
write_cstr(name, buf);
|
||||
write_cstr(query, buf);
|
||||
write_counted(param_types, |t, buf| buf.put_u32(t), buf);
|
||||
})
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn password_message(password: &[u8], buf: &mut BytesMut) -> io::Result<()> {
|
||||
pub fn password_message(password: &CSafeStr, buf: &mut BytesMut) {
|
||||
buf.put_u8(b'p');
|
||||
write_body(buf, |buf| write_cstr(password, buf))
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn query(query: &str, buf: &mut BytesMut) -> io::Result<()> {
|
||||
pub fn query(query: &CSafeStr, buf: &mut BytesMut) {
|
||||
buf.put_u8(b'Q');
|
||||
write_body(buf, |buf| write_cstr(query.as_bytes(), buf))
|
||||
write_body(buf, |buf| write_cstr(query, buf))
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn sasl_initial_response(mechanism: &str, data: &[u8], buf: &mut BytesMut) -> io::Result<()> {
|
||||
pub fn sasl_initial_response(mechanism: &CSafeStr, data: &[u8], buf: &mut BytesMut) {
|
||||
buf.put_u8(b'p');
|
||||
write_body(buf, |buf| {
|
||||
write_cstr(mechanism.as_bytes(), buf)?;
|
||||
let len = i32::from_usize(data.len())?;
|
||||
write_cstr(mechanism, buf);
|
||||
let len =
|
||||
i32::from_usize(data.len()).expect("sasl data should not be larger than i32::MAX");
|
||||
buf.put_i32(len);
|
||||
buf.put_slice(data);
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn sasl_response(data: &[u8], buf: &mut BytesMut) -> io::Result<()> {
|
||||
pub fn sasl_response(data: &[u8], buf: &mut BytesMut) {
|
||||
buf.put_u8(b'p');
|
||||
write_body(buf, |buf| {
|
||||
buf.put_slice(data);
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
||||
@@ -248,19 +195,16 @@ pub fn sasl_response(data: &[u8], buf: &mut BytesMut) -> io::Result<()> {
|
||||
pub fn ssl_request(buf: &mut BytesMut) {
|
||||
write_body(buf, |buf| {
|
||||
buf.put_i32(80_877_103);
|
||||
Ok::<_, io::Error>(())
|
||||
})
|
||||
.unwrap();
|
||||
});
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn startup_message(parameters: &StartupMessageParams, buf: &mut BytesMut) -> io::Result<()> {
|
||||
pub fn startup_message(parameters: &StartupMessageParams, buf: &mut BytesMut) {
|
||||
write_body(buf, |buf| {
|
||||
// postgres protocol version 3.0(196608) in bigger-endian
|
||||
buf.put_i32(0x00_03_00_00);
|
||||
buf.put_slice(¶meters.params);
|
||||
buf.put_u8(0);
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
||||
@@ -271,10 +215,7 @@ pub struct StartupMessageParams {
|
||||
|
||||
impl StartupMessageParams {
|
||||
/// Set parameter's value by its name.
|
||||
pub fn insert(&mut self, name: &str, value: &str) {
|
||||
if name.contains('\0') || value.contains('\0') {
|
||||
panic!("startup parameter name or value contained a null")
|
||||
}
|
||||
pub fn insert(&mut self, name: &CSafeStr, value: &CSafeStr) {
|
||||
self.params.put_slice(name.as_bytes());
|
||||
self.params.put_u8(0);
|
||||
self.params.put_slice(value.as_bytes());
|
||||
@@ -285,24 +226,17 @@ impl StartupMessageParams {
|
||||
#[inline]
|
||||
pub fn sync(buf: &mut BytesMut) {
|
||||
buf.put_u8(b'S');
|
||||
write_body(buf, |_| Ok::<(), io::Error>(())).unwrap();
|
||||
write_body(buf, |_| {});
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn terminate(buf: &mut BytesMut) {
|
||||
buf.put_u8(b'X');
|
||||
write_body(buf, |_| Ok::<(), io::Error>(())).unwrap();
|
||||
write_body(buf, |_| {});
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn write_cstr(s: &[u8], buf: &mut BytesMut) -> Result<(), io::Error> {
|
||||
if s.contains(&0) {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::InvalidInput,
|
||||
"string contains embedded null",
|
||||
));
|
||||
}
|
||||
buf.put_slice(s);
|
||||
fn write_cstr(s: &CSafeStr, buf: &mut BytesMut) {
|
||||
buf.put_slice(s.as_bytes());
|
||||
buf.put_u8(0);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -13,7 +13,7 @@ use bytes::BytesMut;
|
||||
use fallible_iterator::FallibleIterator;
|
||||
#[doc(inline)]
|
||||
pub use postgres_protocol2::Oid;
|
||||
use postgres_protocol2::types;
|
||||
use postgres_protocol2::{IsNull, types};
|
||||
|
||||
use crate::type_gen::{Inner, Other};
|
||||
|
||||
@@ -32,36 +32,15 @@ macro_rules! accepts {
|
||||
/// All `ToSql` implementations should use this macro.
|
||||
macro_rules! to_sql_checked {
|
||||
() => {
|
||||
fn to_sql_checked(
|
||||
&self,
|
||||
ty: &$crate::Type,
|
||||
out: &mut $crate::private::BytesMut,
|
||||
) -> ::std::result::Result<
|
||||
$crate::IsNull,
|
||||
Box<dyn ::std::error::Error + ::std::marker::Sync + ::std::marker::Send>,
|
||||
> {
|
||||
$crate::__to_sql_checked(self, ty, out)
|
||||
fn check(&self, ty: &Type) -> ::std::result::Result<(), $crate::WrongType> {
|
||||
if !<Self as $crate::ToSql>::accepts(ty) {
|
||||
return Err($crate::WrongType::new::<Self>(ty.clone()));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
// WARNING: this function is not considered part of this crate's public API.
|
||||
// It is subject to change at any time.
|
||||
#[doc(hidden)]
|
||||
pub fn __to_sql_checked<T>(
|
||||
v: &T,
|
||||
ty: &Type,
|
||||
out: &mut BytesMut,
|
||||
) -> Result<IsNull, Box<dyn Error + Sync + Send>>
|
||||
where
|
||||
T: ToSql,
|
||||
{
|
||||
if !T::accepts(ty) {
|
||||
return Err(Box::new(WrongType::new::<T>(ty.clone())));
|
||||
}
|
||||
v.to_sql(ty, out)
|
||||
}
|
||||
|
||||
// mod pg_lsn;
|
||||
#[doc(hidden)]
|
||||
pub mod private;
|
||||
@@ -369,14 +348,6 @@ macro_rules! simple_from {
|
||||
simple_from!(i8, char_from_sql, CHAR);
|
||||
simple_from!(u32, oid_from_sql, OID);
|
||||
|
||||
/// An enum representing the nullability of a Postgres value.
|
||||
pub enum IsNull {
|
||||
/// The value is NULL.
|
||||
Yes,
|
||||
/// The value is not NULL.
|
||||
No,
|
||||
}
|
||||
|
||||
/// A trait for types that can be converted into Postgres values.
|
||||
pub trait ToSql: fmt::Debug {
|
||||
/// Converts the value of `self` into the binary format of the specified
|
||||
@@ -388,9 +359,7 @@ pub trait ToSql: fmt::Debug {
|
||||
/// The return value indicates if this value should be represented as
|
||||
/// `NULL`. If this is the case, implementations **must not** write
|
||||
/// anything to `out`.
|
||||
fn to_sql(&self, ty: &Type, out: &mut BytesMut) -> Result<IsNull, Box<dyn Error + Sync + Send>>
|
||||
where
|
||||
Self: Sized;
|
||||
fn to_sql(&self, ty: &Type, out: &mut BytesMut) -> IsNull;
|
||||
|
||||
/// Determines if a value of this type can be converted to the specified
|
||||
/// Postgres `Type`.
|
||||
@@ -402,11 +371,7 @@ pub trait ToSql: fmt::Debug {
|
||||
///
|
||||
/// *All* implementations of this method should be generated by the
|
||||
/// `to_sql_checked!()` macro.
|
||||
fn to_sql_checked(
|
||||
&self,
|
||||
ty: &Type,
|
||||
out: &mut BytesMut,
|
||||
) -> Result<IsNull, Box<dyn Error + Sync + Send>>;
|
||||
fn check(&self, ty: &Type) -> Result<(), WrongType>;
|
||||
|
||||
/// Specify the encode format
|
||||
fn encode_format(&self, _ty: &Type) -> Format {
|
||||
@@ -426,14 +391,14 @@ pub enum Format {
|
||||
}
|
||||
|
||||
impl ToSql for &str {
|
||||
fn to_sql(&self, ty: &Type, w: &mut BytesMut) -> Result<IsNull, Box<dyn Error + Sync + Send>> {
|
||||
fn to_sql(&self, ty: &Type, w: &mut BytesMut) -> IsNull {
|
||||
match *ty {
|
||||
ref ty if ty.name() == "ltree" => types::ltree_to_sql(self, w),
|
||||
ref ty if ty.name() == "lquery" => types::lquery_to_sql(self, w),
|
||||
ref ty if ty.name() == "ltxtquery" => types::ltxtquery_to_sql(self, w),
|
||||
_ => types::text_to_sql(self, w),
|
||||
}
|
||||
Ok(IsNull::No)
|
||||
IsNull::No
|
||||
}
|
||||
|
||||
fn accepts(ty: &Type) -> bool {
|
||||
@@ -457,12 +422,9 @@ impl ToSql for &str {
|
||||
macro_rules! simple_to {
|
||||
($t:ty, $f:ident, $($expected:ident),+) => {
|
||||
impl ToSql for $t {
|
||||
fn to_sql(&self,
|
||||
_: &Type,
|
||||
w: &mut BytesMut)
|
||||
-> Result<IsNull, Box<dyn Error + Sync + Send>> {
|
||||
fn to_sql(&self, _: &Type, w: &mut BytesMut) -> IsNull {
|
||||
types::$f(*self, w);
|
||||
Ok(IsNull::No)
|
||||
IsNull::No
|
||||
}
|
||||
|
||||
accepts!($($expected),+);
|
||||
|
||||
@@ -219,7 +219,7 @@ impl Client {
|
||||
}
|
||||
|
||||
let buf = self.client.inner().with_buf(|buf| {
|
||||
frontend::query("ROLLBACK", buf).unwrap();
|
||||
frontend::query(c"ROLLBACK".into(), buf);
|
||||
buf.split().freeze()
|
||||
});
|
||||
let _ = self
|
||||
|
||||
@@ -4,6 +4,7 @@ use std::net::IpAddr;
|
||||
use std::time::Duration;
|
||||
use std::{fmt, str};
|
||||
|
||||
use postgres_protocol2::CSafeStr;
|
||||
pub use postgres_protocol2::authentication::sasl::ScramKeys;
|
||||
use postgres_protocol2::message::frontend::StartupMessageParams;
|
||||
use serde::{Deserialize, Serialize};
|
||||
@@ -162,7 +163,10 @@ impl Config {
|
||||
self.username = true;
|
||||
}
|
||||
|
||||
self.server_params.insert(name, value);
|
||||
self.server_params.insert(
|
||||
CSafeStr::new(name.as_bytes()).expect("param name should not contain a null"),
|
||||
CSafeStr::new(value.as_bytes()).expect("param name should not contain a null"),
|
||||
);
|
||||
self
|
||||
}
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ use std::task::{Context, Poll};
|
||||
use bytes::BytesMut;
|
||||
use fallible_iterator::FallibleIterator;
|
||||
use futures_util::{Sink, SinkExt, Stream, TryStreamExt, ready};
|
||||
use postgres_protocol2::CSafeStr;
|
||||
use postgres_protocol2::authentication::sasl;
|
||||
use postgres_protocol2::authentication::sasl::ScramSha256;
|
||||
use postgres_protocol2::message::backend::{AuthenticationSaslBody, Message, NoticeResponseBody};
|
||||
@@ -122,7 +123,7 @@ where
|
||||
T: AsyncRead + AsyncWrite + Unpin,
|
||||
{
|
||||
let mut buf = BytesMut::new();
|
||||
frontend::startup_message(&config.server_params, &mut buf).map_err(Error::encode)?;
|
||||
frontend::startup_message(&config.server_params, &mut buf);
|
||||
|
||||
stream
|
||||
.send(FrontendMessage::Raw(buf.freeze()))
|
||||
@@ -193,7 +194,7 @@ where
|
||||
T: AsyncRead + AsyncWrite + Unpin,
|
||||
{
|
||||
let mut buf = BytesMut::new();
|
||||
frontend::password_message(password, &mut buf).map_err(Error::encode)?;
|
||||
frontend::password_message(CSafeStr::new(password).map_err(Error::encode)?, &mut buf);
|
||||
|
||||
stream
|
||||
.send(FrontendMessage::Raw(buf.freeze()))
|
||||
@@ -214,10 +215,10 @@ where
|
||||
let mut has_scram_plus = false;
|
||||
let mut mechanisms = body.mechanisms();
|
||||
while let Some(mechanism) = mechanisms.next().map_err(Error::parse)? {
|
||||
match mechanism {
|
||||
sasl::SCRAM_SHA_256 => has_scram = true,
|
||||
sasl::SCRAM_SHA_256_PLUS => has_scram_plus = true,
|
||||
_ => {}
|
||||
if mechanism == sasl::SCRAM_SHA_256 {
|
||||
has_scram = true;
|
||||
} else if mechanism == sasl::SCRAM_SHA_256_PLUS {
|
||||
has_scram_plus = true;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -256,7 +257,7 @@ where
|
||||
};
|
||||
|
||||
let mut buf = BytesMut::new();
|
||||
frontend::sasl_initial_response(mechanism, scram.message(), &mut buf).map_err(Error::encode)?;
|
||||
frontend::sasl_initial_response(mechanism, scram.message(), &mut buf);
|
||||
stream
|
||||
.send(FrontendMessage::Raw(buf.freeze()))
|
||||
.await
|
||||
@@ -275,7 +276,7 @@ where
|
||||
.map_err(|e| Error::authentication(e.into()))?;
|
||||
|
||||
let mut buf = BytesMut::new();
|
||||
frontend::sasl_response(scram.message(), &mut buf).map_err(Error::encode)?;
|
||||
frontend::sasl_response(scram.message(), &mut buf);
|
||||
stream
|
||||
.send(FrontendMessage::Raw(buf.freeze()))
|
||||
.await
|
||||
|
||||
@@ -123,6 +123,6 @@ pub enum SimpleQueryMessage {
|
||||
|
||||
fn slice_iter<'a>(
|
||||
s: &'a [&'a (dyn ToSql + Sync)],
|
||||
) -> impl ExactSizeIterator<Item = &'a (dyn ToSql + Sync)> + 'a {
|
||||
) -> impl ExactSizeIterator<Item = &'a (dyn ToSql + Sync)> + Clone + 'a {
|
||||
s.iter().map(|s| *s as _)
|
||||
}
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
use std::ffi::CStr;
|
||||
use std::future::Future;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
@@ -5,9 +6,9 @@ use std::sync::Arc;
|
||||
use bytes::Bytes;
|
||||
use fallible_iterator::FallibleIterator;
|
||||
use futures_util::{TryStreamExt, pin_mut};
|
||||
use postgres_protocol2::CSafeStr;
|
||||
use postgres_protocol2::message::backend::Message;
|
||||
use postgres_protocol2::message::frontend;
|
||||
use tracing::debug;
|
||||
|
||||
use crate::client::{CachedTypeInfo, InnerClient};
|
||||
use crate::codec::FrontendMessage;
|
||||
@@ -15,7 +16,7 @@ use crate::connection::RequestMessages;
|
||||
use crate::types::{Kind, Oid, Type};
|
||||
use crate::{Column, Error, Statement, query, slice_iter};
|
||||
|
||||
pub(crate) const TYPEINFO_QUERY: &str = "\
|
||||
pub(crate) const TYPEINFO_QUERY: &CStr = c"\
|
||||
SELECT t.typname, t.typtype, t.typelem, r.rngsubtype, t.typbasetype, n.nspname, t.typrelid
|
||||
FROM pg_catalog.pg_type t
|
||||
LEFT OUTER JOIN pg_catalog.pg_range r ON r.rngtypid = t.oid
|
||||
@@ -25,11 +26,11 @@ WHERE t.oid = $1
|
||||
|
||||
async fn prepare_typecheck(
|
||||
client: &Arc<InnerClient>,
|
||||
name: &'static str,
|
||||
query: &str,
|
||||
name: &'static CStr,
|
||||
query: &CSafeStr,
|
||||
types: &[Type],
|
||||
) -> Result<Statement, Error> {
|
||||
let buf = encode(client, name, query, types)?;
|
||||
let buf = encode(client, name.into(), query, types)?;
|
||||
let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?;
|
||||
|
||||
match responses.next().await? {
|
||||
@@ -68,16 +69,21 @@ async fn prepare_typecheck(
|
||||
Ok(Statement::new(client, name, parameters, columns))
|
||||
}
|
||||
|
||||
fn encode(client: &InnerClient, name: &str, query: &str, types: &[Type]) -> Result<Bytes, Error> {
|
||||
if types.is_empty() {
|
||||
debug!("preparing query {}: {}", name, query);
|
||||
} else {
|
||||
debug!("preparing query {} with types {:?}: {}", name, types, query);
|
||||
}
|
||||
fn encode(
|
||||
client: &InnerClient,
|
||||
name: &CSafeStr,
|
||||
query: &CSafeStr,
|
||||
types: &[Type],
|
||||
) -> Result<Bytes, Error> {
|
||||
// if types.is_empty() {
|
||||
// debug!("preparing query {}: {}", name, query);
|
||||
// } else {
|
||||
// debug!("preparing query {} with types {:?}: {}", name, types, query);
|
||||
// }
|
||||
|
||||
client.with_buf(|buf| {
|
||||
frontend::parse(name, query, types.iter().map(Type::oid), buf).map_err(Error::encode)?;
|
||||
frontend::describe(b'S', name, buf).map_err(Error::encode)?;
|
||||
frontend::parse(name, query, types.iter().map(Type::oid), buf);
|
||||
frontend::describe(b'S', name, buf);
|
||||
frontend::sync(buf);
|
||||
Ok(buf.split().freeze())
|
||||
})
|
||||
@@ -154,8 +160,8 @@ async fn typeinfo_statement(
|
||||
return Ok(stmt.clone());
|
||||
}
|
||||
|
||||
let typeinfo = "neon_proxy_typeinfo";
|
||||
let stmt = prepare_typecheck(client, typeinfo, TYPEINFO_QUERY, &[]).await?;
|
||||
let typeinfo = c"neon_proxy_typeinfo";
|
||||
let stmt = prepare_typecheck(client, typeinfo, TYPEINFO_QUERY.into(), &[]).await?;
|
||||
|
||||
typecache.typeinfo = Some(stmt.clone());
|
||||
Ok(stmt)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use std::fmt;
|
||||
use std::iter;
|
||||
use std::marker::PhantomPinned;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
@@ -8,25 +8,16 @@ use bytes::{BufMut, Bytes, BytesMut};
|
||||
use fallible_iterator::FallibleIterator;
|
||||
use futures_util::{Stream, ready};
|
||||
use pin_project_lite::pin_project;
|
||||
use postgres_protocol2::CSafeStr;
|
||||
use postgres_protocol2::message::backend::Message;
|
||||
use postgres_protocol2::message::frontend;
|
||||
use postgres_types2::{Format, ToSql, Type};
|
||||
use tracing::debug;
|
||||
|
||||
use crate::client::{InnerClient, Responses};
|
||||
use crate::codec::FrontendMessage;
|
||||
use crate::connection::RequestMessages;
|
||||
use crate::types::IsNull;
|
||||
use crate::{Column, Error, ReadyForQueryStatus, Row, Statement};
|
||||
|
||||
struct BorrowToSqlParamsDebug<'a>(&'a [&'a (dyn ToSql + Sync)]);
|
||||
|
||||
impl fmt::Debug for BorrowToSqlParamsDebug<'_> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.debug_list().entries(self.0.iter()).finish()
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn query<'a, I>(
|
||||
client: &InnerClient,
|
||||
statement: Statement,
|
||||
@@ -34,19 +25,9 @@ pub async fn query<'a, I>(
|
||||
) -> Result<RowStream, Error>
|
||||
where
|
||||
I: IntoIterator<Item = &'a (dyn ToSql + Sync)>,
|
||||
I::IntoIter: ExactSizeIterator,
|
||||
I::IntoIter: ExactSizeIterator + Clone,
|
||||
{
|
||||
let buf = if tracing::enabled!(tracing::Level::DEBUG) {
|
||||
let params = params.into_iter().collect::<Vec<_>>();
|
||||
debug!(
|
||||
"executing statement {} with parameters: {:?}",
|
||||
statement.name(),
|
||||
BorrowToSqlParamsDebug(params.as_slice()),
|
||||
);
|
||||
encode(client, &statement, params)?
|
||||
} else {
|
||||
encode(client, &statement, params)?
|
||||
};
|
||||
let buf = encode(client, &statement, params)?;
|
||||
let responses = start(client, buf).await?;
|
||||
Ok(RowStream {
|
||||
statement,
|
||||
@@ -68,45 +49,45 @@ where
|
||||
I: IntoIterator<Item = Option<S>>,
|
||||
I::IntoIter: ExactSizeIterator,
|
||||
{
|
||||
let query = CSafeStr::new(query.as_bytes()).map_err(Error::encode)?;
|
||||
let params = params.into_iter();
|
||||
|
||||
let portal = c"".into(); // unnamed portal
|
||||
let statement = c"".into(); // unnamed prepared statement
|
||||
|
||||
let buf = client.with_buf(|buf| {
|
||||
frontend::parse(
|
||||
"", // unnamed prepared statement
|
||||
query, // query to parse
|
||||
std::iter::empty(), // give no type info
|
||||
statement,
|
||||
query, // query to parse
|
||||
iter::empty(), // give no type info
|
||||
buf,
|
||||
)
|
||||
.map_err(Error::encode)?;
|
||||
frontend::describe(b'S', "", buf).map_err(Error::encode)?;
|
||||
// Bind, pass params as text, retrieve as binary
|
||||
match frontend::bind(
|
||||
"", // empty string selects the unnamed portal
|
||||
"", // unnamed prepared statement
|
||||
std::iter::empty(), // all parameters use the default format (text)
|
||||
);
|
||||
frontend::describe(b'S', statement, buf);
|
||||
|
||||
// Bind, pass params as text, retrieve as test
|
||||
frontend::bind(
|
||||
portal,
|
||||
statement,
|
||||
iter::empty(), // all parameters use the default format (text)
|
||||
params,
|
||||
|param, buf| match param {
|
||||
Some(param) => {
|
||||
buf.put_slice(param.as_ref().as_bytes());
|
||||
Ok(postgres_protocol2::IsNull::No)
|
||||
postgres_protocol2::IsNull::No
|
||||
}
|
||||
None => Ok(postgres_protocol2::IsNull::Yes),
|
||||
None => postgres_protocol2::IsNull::Yes,
|
||||
},
|
||||
Some(0), // all text
|
||||
buf,
|
||||
) {
|
||||
Ok(()) => Ok(()),
|
||||
Err(frontend::BindError::Conversion(e)) => Err(Error::to_sql(e, 0)),
|
||||
Err(frontend::BindError::Serialization(e)) => Err(Error::encode(e)),
|
||||
}?;
|
||||
);
|
||||
|
||||
// Execute
|
||||
frontend::execute("", 0, buf).map_err(Error::encode)?;
|
||||
frontend::execute(portal, 0, buf);
|
||||
// Sync
|
||||
frontend::sync(buf);
|
||||
|
||||
Ok(buf.split().freeze())
|
||||
})?;
|
||||
buf.split().freeze()
|
||||
});
|
||||
|
||||
// now read the responses
|
||||
let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?;
|
||||
@@ -173,11 +154,13 @@ async fn start(client: &InnerClient, buf: Bytes) -> Result<Responses, Error> {
|
||||
pub fn encode<'a, I>(client: &InnerClient, statement: &Statement, params: I) -> Result<Bytes, Error>
|
||||
where
|
||||
I: IntoIterator<Item = &'a (dyn ToSql + Sync)>,
|
||||
I::IntoIter: ExactSizeIterator,
|
||||
I::IntoIter: ExactSizeIterator + Clone,
|
||||
{
|
||||
let portal = c"".into(); // unnamed portal
|
||||
|
||||
client.with_buf(|buf| {
|
||||
encode_bind(statement, params, "", buf)?;
|
||||
frontend::execute("", 0, buf).map_err(Error::encode)?;
|
||||
encode_bind(statement, params, portal, buf)?;
|
||||
frontend::execute(portal, 0, buf);
|
||||
frontend::sync(buf);
|
||||
Ok(buf.split().freeze())
|
||||
})
|
||||
@@ -186,15 +169,15 @@ where
|
||||
pub fn encode_bind<'a, I>(
|
||||
statement: &Statement,
|
||||
params: I,
|
||||
portal: &str,
|
||||
portal: &CSafeStr,
|
||||
buf: &mut BytesMut,
|
||||
) -> Result<(), Error>
|
||||
where
|
||||
I: IntoIterator<Item = &'a (dyn ToSql + Sync)>,
|
||||
I::IntoIter: ExactSizeIterator,
|
||||
I::IntoIter: ExactSizeIterator + Clone,
|
||||
{
|
||||
let param_types = statement.params();
|
||||
let params = params.into_iter();
|
||||
let params = iter::zip(params.into_iter(), param_types);
|
||||
|
||||
assert!(
|
||||
param_types.len() == params.len(),
|
||||
@@ -203,35 +186,24 @@ where
|
||||
params.len()
|
||||
);
|
||||
|
||||
let (param_formats, params): (Vec<_>, Vec<_>) = params
|
||||
.zip(param_types.iter())
|
||||
.map(|(p, ty)| (p.encode_format(ty) as i16, p))
|
||||
.unzip();
|
||||
// check encodings
|
||||
for (i, (p, ty)) in params.clone().enumerate() {
|
||||
p.check(ty).map_err(|e| Error::to_sql(Box::new(e), i))?
|
||||
}
|
||||
|
||||
let params = params.into_iter();
|
||||
let param_formats = params.clone().map(|(p, ty)| p.encode_format(ty) as i16);
|
||||
|
||||
let mut error_idx = 0;
|
||||
let r = frontend::bind(
|
||||
frontend::bind(
|
||||
portal,
|
||||
statement.name(),
|
||||
param_formats,
|
||||
params.zip(param_types).enumerate(),
|
||||
|(idx, (param, ty)), buf| match param.to_sql_checked(ty, buf) {
|
||||
Ok(IsNull::No) => Ok(postgres_protocol2::IsNull::No),
|
||||
Ok(IsNull::Yes) => Ok(postgres_protocol2::IsNull::Yes),
|
||||
Err(e) => {
|
||||
error_idx = idx;
|
||||
Err(e)
|
||||
}
|
||||
},
|
||||
params,
|
||||
|(param, ty), buf| param.to_sql(ty, buf),
|
||||
Some(1),
|
||||
buf,
|
||||
);
|
||||
match r {
|
||||
Ok(()) => Ok(()),
|
||||
Err(frontend::BindError::Conversion(e)) => Err(Error::to_sql(e, error_idx)),
|
||||
Err(frontend::BindError::Serialization(e)) => Err(Error::encode(e)),
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pin_project! {
|
||||
|
||||
@@ -7,6 +7,7 @@ use bytes::Bytes;
|
||||
use fallible_iterator::FallibleIterator;
|
||||
use futures_util::{Stream, ready};
|
||||
use pin_project_lite::pin_project;
|
||||
use postgres_protocol2::CSafeStr;
|
||||
use postgres_protocol2::message::backend::Message;
|
||||
use postgres_protocol2::message::frontend;
|
||||
use tracing::debug;
|
||||
@@ -69,8 +70,9 @@ pub async fn batch_execute(
|
||||
}
|
||||
|
||||
pub(crate) fn encode(client: &InnerClient, query: &str) -> Result<Bytes, Error> {
|
||||
let query = CSafeStr::new(query.as_bytes()).map_err(Error::encode)?;
|
||||
client.with_buf(|buf| {
|
||||
frontend::query(query, buf).map_err(Error::encode)?;
|
||||
frontend::query(query, buf);
|
||||
Ok(buf.split().freeze())
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
use std::ffi::CStr;
|
||||
use std::fmt;
|
||||
use std::sync::{Arc, Weak};
|
||||
|
||||
use postgres_protocol2::Oid;
|
||||
use postgres_protocol2::message::backend::Field;
|
||||
use postgres_protocol2::message::frontend;
|
||||
use postgres_protocol2::{CSafeStr, Oid};
|
||||
|
||||
use crate::client::InnerClient;
|
||||
use crate::codec::FrontendMessage;
|
||||
@@ -12,7 +13,7 @@ use crate::types::Type;
|
||||
|
||||
struct StatementInner {
|
||||
client: Weak<InnerClient>,
|
||||
name: &'static str,
|
||||
name: &'static CStr,
|
||||
params: Vec<Type>,
|
||||
columns: Vec<Column>,
|
||||
}
|
||||
@@ -21,7 +22,7 @@ impl Drop for StatementInner {
|
||||
fn drop(&mut self) {
|
||||
if let Some(client) = self.client.upgrade() {
|
||||
let buf = client.with_buf(|buf| {
|
||||
frontend::close(b'S', self.name, buf).unwrap();
|
||||
frontend::close(b'S', self.name.into(), buf);
|
||||
frontend::sync(buf);
|
||||
buf.split().freeze()
|
||||
});
|
||||
@@ -39,7 +40,7 @@ pub struct Statement(Arc<StatementInner>);
|
||||
impl Statement {
|
||||
pub(crate) fn new(
|
||||
inner: &Arc<InnerClient>,
|
||||
name: &'static str,
|
||||
name: &'static CStr,
|
||||
params: Vec<Type>,
|
||||
columns: Vec<Column>,
|
||||
) -> Statement {
|
||||
@@ -54,14 +55,14 @@ impl Statement {
|
||||
pub(crate) fn new_anonymous(params: Vec<Type>, columns: Vec<Column>) -> Statement {
|
||||
Statement(Arc::new(StatementInner {
|
||||
client: Weak::new(),
|
||||
name: "<anonymous>",
|
||||
name: c"<anonymous>",
|
||||
params,
|
||||
columns,
|
||||
}))
|
||||
}
|
||||
|
||||
pub(crate) fn name(&self) -> &str {
|
||||
self.0.name
|
||||
pub(crate) fn name(&self) -> &CSafeStr {
|
||||
self.0.name.into()
|
||||
}
|
||||
|
||||
/// Returns the expected types of the statement's parameters.
|
||||
|
||||
@@ -21,7 +21,7 @@ impl Drop for Transaction<'_> {
|
||||
}
|
||||
|
||||
let buf = self.client.inner().with_buf(|buf| {
|
||||
frontend::query("ROLLBACK", buf).unwrap();
|
||||
frontend::query(c"ROLLBACK".into(), buf);
|
||||
buf.split().freeze()
|
||||
});
|
||||
let _ = self
|
||||
|
||||
@@ -536,7 +536,8 @@ mod tests {
|
||||
use control_plane::AuthSecret;
|
||||
use fallible_iterator::FallibleIterator;
|
||||
use once_cell::sync::Lazy;
|
||||
use postgres_protocol::authentication::sasl::{ChannelBinding, ScramSha256};
|
||||
use postgres_protocol::CSafeStr;
|
||||
use postgres_protocol::authentication::sasl::{ChannelBinding, SCRAM_SHA_256, ScramSha256};
|
||||
use postgres_protocol::message::backend::Message as PgMessage;
|
||||
use postgres_protocol::message::frontend;
|
||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt};
|
||||
@@ -714,15 +715,15 @@ mod tests {
|
||||
// server should offer scram
|
||||
match read_message(&mut client, &mut read).await {
|
||||
PgMessage::AuthenticationSasl(a) => {
|
||||
let options: Vec<&str> = a.mechanisms().collect().unwrap();
|
||||
assert_eq!(options, ["SCRAM-SHA-256"]);
|
||||
let options: Vec<&CSafeStr> = a.mechanisms().collect().unwrap();
|
||||
assert_eq!(options, [SCRAM_SHA_256]);
|
||||
}
|
||||
_ => panic!("wrong message"),
|
||||
}
|
||||
|
||||
// client sends client-first-message
|
||||
let mut write = BytesMut::new();
|
||||
frontend::sasl_initial_response("SCRAM-SHA-256", scram.message(), &mut write).unwrap();
|
||||
frontend::sasl_initial_response(SCRAM_SHA_256, scram.message(), &mut write);
|
||||
client.write_all(&write).await.unwrap();
|
||||
|
||||
// server response with server-first-message
|
||||
@@ -735,7 +736,7 @@ mod tests {
|
||||
|
||||
// client response with client-final-message
|
||||
write.clear();
|
||||
frontend::sasl_response(scram.message(), &mut write).unwrap();
|
||||
frontend::sasl_response(scram.message(), &mut write);
|
||||
client.write_all(&write).await.unwrap();
|
||||
|
||||
// server response with server-final-message
|
||||
@@ -800,7 +801,7 @@ mod tests {
|
||||
|
||||
// client responds with password
|
||||
write.clear();
|
||||
frontend::password_message(b"my-secret-password", &mut write).unwrap();
|
||||
frontend::password_message(c"my-secret-password".into(), &mut write);
|
||||
client.write_all(&write).await.unwrap();
|
||||
});
|
||||
let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new_with_shards(
|
||||
@@ -853,8 +854,10 @@ mod tests {
|
||||
|
||||
// client responds with password
|
||||
let mut write = BytesMut::new();
|
||||
frontend::password_message(b"endpoint=my-endpoint;my-secret-password", &mut write)
|
||||
.unwrap();
|
||||
frontend::password_message(
|
||||
c"endpoint=my-endpoint;my-secret-password".into(),
|
||||
&mut write,
|
||||
);
|
||||
client.write_all(&write).await.unwrap();
|
||||
});
|
||||
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
use std::io;
|
||||
use std::sync::Arc;
|
||||
|
||||
use postgres_protocol::authentication::sasl::{SCRAM_SHA_256, SCRAM_SHA_256_PLUS};
|
||||
use pq_proto::{BeAuthenticationSaslMessage, BeMessage, BeMessage as Be};
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tracing::info;
|
||||
@@ -174,8 +173,10 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, Scram<'_>> {
|
||||
}
|
||||
|
||||
match sasl.method {
|
||||
SCRAM_SHA_256 => ctx.set_auth_method(crate::context::AuthMethod::ScramSha256),
|
||||
SCRAM_SHA_256_PLUS => ctx.set_auth_method(crate::context::AuthMethod::ScramSha256Plus),
|
||||
scram::SCRAM_SHA_256 => ctx.set_auth_method(crate::context::AuthMethod::ScramSha256),
|
||||
scram::SCRAM_SHA_256_PLUS => {
|
||||
ctx.set_auth_method(crate::context::AuthMethod::ScramSha256Plus)
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ use std::fmt::Debug;
|
||||
use bytes::{Bytes, BytesMut};
|
||||
use futures::{SinkExt, StreamExt};
|
||||
use postgres_client::tls::TlsConnect;
|
||||
use postgres_protocol::message::frontend;
|
||||
use postgres_protocol::{authentication::sasl::SCRAM_SHA_256, message::frontend};
|
||||
use tokio::io::{AsyncReadExt, DuplexStream};
|
||||
use tokio_util::codec::{Decoder, Encoder};
|
||||
|
||||
@@ -60,8 +60,7 @@ async fn proxy_mitm(
|
||||
params: startup.params.into(),
|
||||
},
|
||||
&mut buf,
|
||||
)
|
||||
.unwrap();
|
||||
);
|
||||
end_server.send(buf.freeze()).await.unwrap();
|
||||
|
||||
// proxy messages between end_client and end_server
|
||||
@@ -90,7 +89,7 @@ async fn proxy_mitm(
|
||||
new_message.extend_from_slice(sasl_message.strip_prefix(b"p=tls-server-end-point,,").unwrap());
|
||||
|
||||
let mut buf = BytesMut::new();
|
||||
frontend::sasl_initial_response("SCRAM-SHA-256", &new_message, &mut buf).unwrap();
|
||||
frontend::sasl_initial_response(SCRAM_SHA_256, &new_message, &mut buf);
|
||||
|
||||
end_server.send(buf.freeze()).await.unwrap();
|
||||
continue;
|
||||
|
||||
@@ -21,8 +21,8 @@ pub(crate) use key::ScramKey;
|
||||
pub(crate) use secret::ServerSecret;
|
||||
use sha2::{Digest, Sha256};
|
||||
|
||||
const SCRAM_SHA_256: &str = "SCRAM-SHA-256";
|
||||
const SCRAM_SHA_256_PLUS: &str = "SCRAM-SHA-256-PLUS";
|
||||
pub(crate) const SCRAM_SHA_256: &str = "SCRAM-SHA-256";
|
||||
pub(crate) const SCRAM_SHA_256_PLUS: &str = "SCRAM-SHA-256-PLUS";
|
||||
|
||||
/// A list of supported SCRAM methods.
|
||||
pub(crate) const METHODS: &[&str] = &[SCRAM_SHA_256_PLUS, SCRAM_SHA_256];
|
||||
|
||||
Reference in New Issue
Block a user