diff --git a/libs/proxy/postgres-protocol2/src/authentication/sasl.rs b/libs/proxy/postgres-protocol2/src/authentication/sasl.rs index 2daf9a80d4..a4e5c3ad7d 100644 --- a/libs/proxy/postgres-protocol2/src/authentication/sasl.rs +++ b/libs/proxy/postgres-protocol2/src/authentication/sasl.rs @@ -9,12 +9,14 @@ use sha2::digest::FixedOutput; use sha2::{Digest, Sha256}; use tokio::task::yield_now; +use crate::CSafeStr; + const NONCE_LENGTH: usize = 24; /// The identifier of the SCRAM-SHA-256 SASL authentication mechanism. -pub const SCRAM_SHA_256: &str = "SCRAM-SHA-256"; +pub const SCRAM_SHA_256: &CSafeStr = CSafeStr::from_cstr(c"SCRAM-SHA-256"); /// The identifier of the SCRAM-SHA-256-PLUS SASL authentication mechanism. -pub const SCRAM_SHA_256_PLUS: &str = "SCRAM-SHA-256-PLUS"; +pub const SCRAM_SHA_256_PLUS: &CSafeStr = CSafeStr::from_cstr(c"SCRAM-SHA-256-PLUS"); // since postgres passwords are not required to exclude saslprep-prohibited // characters or even be valid UTF8, we run saslprep if possible and otherwise diff --git a/libs/proxy/postgres-protocol2/src/lib.rs b/libs/proxy/postgres-protocol2/src/lib.rs index afbd1e92bd..3914b4e12c 100644 --- a/libs/proxy/postgres-protocol2/src/lib.rs +++ b/libs/proxy/postgres-protocol2/src/lib.rs @@ -11,7 +11,7 @@ //! set to `UTF8`. It will most likely not behave properly if that is not the case. #![warn(missing_docs, clippy::all)] -use std::io; +use std::{ffi::CStr, io}; use byteorder::{BigEndian, ByteOrder}; use bytes::{BufMut, BytesMut}; @@ -36,20 +36,67 @@ pub enum IsNull { No, } -fn write_nullable(serializer: F, buf: &mut BytesMut) -> Result<(), E> +/// A [`std::ffi::CStr`] but without the null byte. +#[repr(transparent)] +#[derive(PartialEq, Eq, Hash, Debug)] +pub struct CSafeStr([u8]); + +impl CSafeStr { + /// Create a new `CSafeStr`, erroring if the bytes contains a null. + pub fn new(bytes: &[u8]) -> Result<&Self, io::Error> { + let nul_pos = memchr::memchr(0, bytes); + match nul_pos { + Some(nul_pos) => Err(io::Error::other(format!( + "unexpected null byte at position {nul_pos}" + ))), + None => { + // Safety: CSafeStr is transparent over [u8]. + Ok(unsafe { std::mem::transmute(bytes) }) + } + } + } + + /// Create a new `CSafeStr` up until the next null. + pub fn take<'a>(bytes: &mut &'a [u8]) -> &'a Self { + let nul_pos = memchr::memchr(0, bytes).unwrap_or(bytes.len()); + let bytes = bytes + .split_off(..nul_pos) + .expect("nul_pos should be in-bounds"); + // Safety: CSafeStr is transparent over [u8]. + unsafe { std::mem::transmute(bytes) } + } + + /// Get the bytes of this CSafeStr. + pub const fn as_bytes(&self) -> &[u8] { + &self.0 + } + + /// Create a new `CSafeStr` + pub const fn from_cstr(s: &CStr) -> &CSafeStr { + // Safety: CSafeStr is transparent over [u8]. + unsafe { std::mem::transmute(s.to_bytes()) } + } +} + +impl<'a> From<&'a CStr> for &'a CSafeStr { + fn from(s: &'a CStr) -> &'a CSafeStr { + CSafeStr::from_cstr(s) + } +} + +fn write_nullable(serializer: F, buf: &mut BytesMut) where - F: FnOnce(&mut BytesMut) -> Result, - E: From, + F: FnOnce(&mut BytesMut) -> IsNull, { let base = buf.len(); buf.put_i32(0); - let size = match serializer(buf)? { - IsNull::No => i32::from_usize(buf.len() - base - 4)?, + let size = match serializer(buf) { + // this is an unreasonable enough case that I think a panic is acceptable. + IsNull::No => i32::from_usize(buf.len() - base - 4) + .expect("buffer size should not be larger than i32::MAX"), IsNull::Yes => -1, }; BigEndian::write_i32(&mut buf[base..], size); - - Ok(()) } trait FromUsize: Sized { diff --git a/libs/proxy/postgres-protocol2/src/message/backend.rs b/libs/proxy/postgres-protocol2/src/message/backend.rs index d7eaef9509..7ddf9a573f 100644 --- a/libs/proxy/postgres-protocol2/src/message/backend.rs +++ b/libs/proxy/postgres-protocol2/src/message/backend.rs @@ -9,7 +9,7 @@ use bytes::{Bytes, BytesMut}; use fallible_iterator::FallibleIterator; use memchr::memchr; -use crate::Oid; +use crate::{CSafeStr, Oid}; // top-level message tags const PARSE_COMPLETE_TAG: u8 = b'1'; @@ -332,25 +332,24 @@ impl AuthenticationSaslBody { pub struct SaslMechanisms<'a>(&'a [u8]); impl<'a> FallibleIterator for SaslMechanisms<'a> { - type Item = &'a str; + type Item = &'a CSafeStr; type Error = io::Error; #[inline] - fn next(&mut self) -> io::Result> { - let value_end = find_null(self.0, 0)?; - if value_end == 0 { - if self.0.len() != 1 { - return Err(io::Error::new( - io::ErrorKind::InvalidData, - "invalid message length: expected to be at end of iterator for sasl", - )); - } - Ok(None) - } else { - let value = get_str(&self.0[..value_end])?; - self.0 = &self.0[value_end + 1..]; - Ok(Some(value)) + fn next(&mut self) -> io::Result> { + if self.0 == b"0" { + return Ok(None); } + + let value = CSafeStr::take(&mut self.0); + if value.as_bytes().len() == 0 { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "invalid message length: expected to be at end of iterator for sasl", + )); + } + + Ok(Some(value)) } } diff --git a/libs/proxy/postgres-protocol2/src/message/frontend.rs b/libs/proxy/postgres-protocol2/src/message/frontend.rs index b447290ea8..29aef994d1 100644 --- a/libs/proxy/postgres-protocol2/src/message/frontend.rs +++ b/libs/proxy/postgres-protocol2/src/message/frontend.rs @@ -1,114 +1,73 @@ //! Frontend message serialization. #![allow(missing_docs)] -use std::error::Error; -use std::{io, marker}; +use std::io; use byteorder::{BigEndian, ByteOrder}; use bytes::{Buf, BufMut, BytesMut}; -use crate::{FromUsize, IsNull, Oid, write_nullable}; +use crate::{CSafeStr, FromUsize, IsNull, Oid, write_nullable}; #[inline] -fn write_body(buf: &mut BytesMut, f: F) -> Result<(), E> +fn write_body(buf: &mut BytesMut, f: F) where - F: FnOnce(&mut BytesMut) -> Result<(), E>, - E: From, + F: FnOnce(&mut BytesMut), { let base = buf.len(); buf.extend_from_slice(&[0; 4]); - f(buf)?; + f(buf); - let size = i32::from_usize(buf.len() - base)?; + let size = + i32::from_usize(buf.len() - base).expect("buffer size should not be larger than i32::MAX"); BigEndian::write_i32(&mut buf[base..], size); - Ok(()) -} - -pub enum BindError { - Conversion(Box), - Serialization(io::Error), -} - -impl From> for BindError { - #[inline] - fn from(e: Box) -> BindError { - BindError::Conversion(e) - } -} - -impl From for BindError { - #[inline] - fn from(e: io::Error) -> BindError { - BindError::Serialization(e) - } } #[inline] pub fn bind( - portal: &str, - statement: &str, + portal: &CSafeStr, + statement: &CSafeStr, formats: I, values: J, mut serializer: F, result_formats: K, buf: &mut BytesMut, -) -> Result<(), BindError> -where +) where I: IntoIterator, J: IntoIterator, - F: FnMut(T, &mut BytesMut) -> Result>, + F: FnMut(T, &mut BytesMut) -> IsNull, K: IntoIterator, { buf.put_u8(b'B'); write_body(buf, |buf| { - write_cstr(portal.as_bytes(), buf)?; - write_cstr(statement.as_bytes(), buf)?; - write_counted( - formats, - |f, buf| { - buf.put_i16(f); - Ok::<_, io::Error>(()) - }, - buf, - )?; + write_cstr(portal, buf); + write_cstr(statement, buf); + write_counted(formats, |f, buf| buf.put_i16(f), buf); write_counted( values, |v, buf| write_nullable(|buf| serializer(v, buf), buf), buf, - )?; - write_counted( - result_formats, - |f, buf| { - buf.put_i16(f); - Ok::<_, io::Error>(()) - }, - buf, - )?; - - Ok(()) + ); + write_counted(result_formats, |f, buf| buf.put_i16(f), buf); }) } #[inline] -fn write_counted(items: I, mut serializer: F, buf: &mut BytesMut) -> Result<(), E> +fn write_counted(items: I, mut serializer: F, buf: &mut BytesMut) where I: IntoIterator, - F: FnMut(T, &mut BytesMut) -> Result<(), E>, - E: From, + F: FnMut(T, &mut BytesMut), { let base = buf.len(); buf.extend_from_slice(&[0; 2]); let mut count = 0; for item in items { - serializer(item, buf)?; + serializer(item, buf); count += 1; } - let count = i16::from_usize(count)?; + let count = i16::from_usize(count).expect("list should not exceed 32767 items"); BigEndian::write_i16(&mut buf[base..], count); - - Ok(()) } #[inline] @@ -117,17 +76,15 @@ pub fn cancel_request(process_id: i32, secret_key: i32, buf: &mut BytesMut) { buf.put_i32(80_877_102); buf.put_i32(process_id); buf.put_i32(secret_key); - Ok::<_, io::Error>(()) }) - .unwrap(); } #[inline] -pub fn close(variant: u8, name: &str, buf: &mut BytesMut) -> io::Result<()> { +pub fn close(variant: u8, name: &CSafeStr, buf: &mut BytesMut) { buf.put_u8(b'C'); write_body(buf, |buf| { buf.put_u8(variant); - write_cstr(name.as_bytes(), buf) + write_cstr(name, buf) }) } @@ -162,85 +119,75 @@ where #[inline] pub fn copy_done(buf: &mut BytesMut) { buf.put_u8(b'c'); - write_body(buf, |_| Ok::<(), io::Error>(())).unwrap(); + write_body(buf, |_| {}); } #[inline] -pub fn copy_fail(message: &str, buf: &mut BytesMut) -> io::Result<()> { +pub fn copy_fail(message: &CSafeStr, buf: &mut BytesMut) { buf.put_u8(b'f'); - write_body(buf, |buf| write_cstr(message.as_bytes(), buf)) + write_body(buf, |buf| write_cstr(message, buf)) } #[inline] -pub fn describe(variant: u8, name: &str, buf: &mut BytesMut) -> io::Result<()> { +pub fn describe(variant: u8, name: &CSafeStr, buf: &mut BytesMut) { buf.put_u8(b'D'); write_body(buf, |buf| { buf.put_u8(variant); - write_cstr(name.as_bytes(), buf) + write_cstr(name, buf) }) } #[inline] -pub fn execute(portal: &str, max_rows: i32, buf: &mut BytesMut) -> io::Result<()> { +pub fn execute(portal: &CSafeStr, max_rows: i32, buf: &mut BytesMut) { buf.put_u8(b'E'); write_body(buf, |buf| { - write_cstr(portal.as_bytes(), buf)?; + write_cstr(portal, buf); buf.put_i32(max_rows); - Ok(()) }) } #[inline] -pub fn parse(name: &str, query: &str, param_types: I, buf: &mut BytesMut) -> io::Result<()> +pub fn parse(name: &CSafeStr, query: &CSafeStr, param_types: I, buf: &mut BytesMut) where I: IntoIterator, { buf.put_u8(b'P'); write_body(buf, |buf| { - write_cstr(name.as_bytes(), buf)?; - write_cstr(query.as_bytes(), buf)?; - write_counted( - param_types, - |t, buf| { - buf.put_u32(t); - Ok::<_, io::Error>(()) - }, - buf, - )?; - Ok(()) + write_cstr(name, buf); + write_cstr(query, buf); + write_counted(param_types, |t, buf| buf.put_u32(t), buf); }) } #[inline] -pub fn password_message(password: &[u8], buf: &mut BytesMut) -> io::Result<()> { +pub fn password_message(password: &CSafeStr, buf: &mut BytesMut) { buf.put_u8(b'p'); write_body(buf, |buf| write_cstr(password, buf)) } #[inline] -pub fn query(query: &str, buf: &mut BytesMut) -> io::Result<()> { +pub fn query(query: &CSafeStr, buf: &mut BytesMut) { buf.put_u8(b'Q'); - write_body(buf, |buf| write_cstr(query.as_bytes(), buf)) + write_body(buf, |buf| write_cstr(query, buf)) } #[inline] -pub fn sasl_initial_response(mechanism: &str, data: &[u8], buf: &mut BytesMut) -> io::Result<()> { +pub fn sasl_initial_response(mechanism: &CSafeStr, data: &[u8], buf: &mut BytesMut) { buf.put_u8(b'p'); write_body(buf, |buf| { - write_cstr(mechanism.as_bytes(), buf)?; - let len = i32::from_usize(data.len())?; + write_cstr(mechanism, buf); + let len = + i32::from_usize(data.len()).expect("sasl data should not be larger than i32::MAX"); buf.put_i32(len); buf.put_slice(data); - Ok(()) }) } #[inline] -pub fn sasl_response(data: &[u8], buf: &mut BytesMut) -> io::Result<()> { +pub fn sasl_response(data: &[u8], buf: &mut BytesMut) { buf.put_u8(b'p'); write_body(buf, |buf| { buf.put_slice(data); - Ok(()) }) } @@ -248,19 +195,16 @@ pub fn sasl_response(data: &[u8], buf: &mut BytesMut) -> io::Result<()> { pub fn ssl_request(buf: &mut BytesMut) { write_body(buf, |buf| { buf.put_i32(80_877_103); - Ok::<_, io::Error>(()) - }) - .unwrap(); + }); } #[inline] -pub fn startup_message(parameters: &StartupMessageParams, buf: &mut BytesMut) -> io::Result<()> { +pub fn startup_message(parameters: &StartupMessageParams, buf: &mut BytesMut) { write_body(buf, |buf| { // postgres protocol version 3.0(196608) in bigger-endian buf.put_i32(0x00_03_00_00); buf.put_slice(¶meters.params); buf.put_u8(0); - Ok(()) }) } @@ -271,10 +215,7 @@ pub struct StartupMessageParams { impl StartupMessageParams { /// Set parameter's value by its name. - pub fn insert(&mut self, name: &str, value: &str) { - if name.contains('\0') || value.contains('\0') { - panic!("startup parameter name or value contained a null") - } + pub fn insert(&mut self, name: &CSafeStr, value: &CSafeStr) { self.params.put_slice(name.as_bytes()); self.params.put_u8(0); self.params.put_slice(value.as_bytes()); @@ -285,24 +226,17 @@ impl StartupMessageParams { #[inline] pub fn sync(buf: &mut BytesMut) { buf.put_u8(b'S'); - write_body(buf, |_| Ok::<(), io::Error>(())).unwrap(); + write_body(buf, |_| {}); } #[inline] pub fn terminate(buf: &mut BytesMut) { buf.put_u8(b'X'); - write_body(buf, |_| Ok::<(), io::Error>(())).unwrap(); + write_body(buf, |_| {}); } #[inline] -fn write_cstr(s: &[u8], buf: &mut BytesMut) -> Result<(), io::Error> { - if s.contains(&0) { - return Err(io::Error::new( - io::ErrorKind::InvalidInput, - "string contains embedded null", - )); - } - buf.put_slice(s); +fn write_cstr(s: &CSafeStr, buf: &mut BytesMut) { + buf.put_slice(s.as_bytes()); buf.put_u8(0); - Ok(()) } diff --git a/libs/proxy/postgres-types2/src/lib.rs b/libs/proxy/postgres-types2/src/lib.rs index b6bcabc922..663d9e5166 100644 --- a/libs/proxy/postgres-types2/src/lib.rs +++ b/libs/proxy/postgres-types2/src/lib.rs @@ -13,7 +13,7 @@ use bytes::BytesMut; use fallible_iterator::FallibleIterator; #[doc(inline)] pub use postgres_protocol2::Oid; -use postgres_protocol2::types; +use postgres_protocol2::{IsNull, types}; use crate::type_gen::{Inner, Other}; @@ -32,36 +32,15 @@ macro_rules! accepts { /// All `ToSql` implementations should use this macro. macro_rules! to_sql_checked { () => { - fn to_sql_checked( - &self, - ty: &$crate::Type, - out: &mut $crate::private::BytesMut, - ) -> ::std::result::Result< - $crate::IsNull, - Box, - > { - $crate::__to_sql_checked(self, ty, out) + fn check(&self, ty: &Type) -> ::std::result::Result<(), $crate::WrongType> { + if !::accepts(ty) { + return Err($crate::WrongType::new::(ty.clone())); + } + Ok(()) } }; } -// WARNING: this function is not considered part of this crate's public API. -// It is subject to change at any time. -#[doc(hidden)] -pub fn __to_sql_checked( - v: &T, - ty: &Type, - out: &mut BytesMut, -) -> Result> -where - T: ToSql, -{ - if !T::accepts(ty) { - return Err(Box::new(WrongType::new::(ty.clone()))); - } - v.to_sql(ty, out) -} - // mod pg_lsn; #[doc(hidden)] pub mod private; @@ -369,14 +348,6 @@ macro_rules! simple_from { simple_from!(i8, char_from_sql, CHAR); simple_from!(u32, oid_from_sql, OID); -/// An enum representing the nullability of a Postgres value. -pub enum IsNull { - /// The value is NULL. - Yes, - /// The value is not NULL. - No, -} - /// A trait for types that can be converted into Postgres values. pub trait ToSql: fmt::Debug { /// Converts the value of `self` into the binary format of the specified @@ -388,9 +359,7 @@ pub trait ToSql: fmt::Debug { /// The return value indicates if this value should be represented as /// `NULL`. If this is the case, implementations **must not** write /// anything to `out`. - fn to_sql(&self, ty: &Type, out: &mut BytesMut) -> Result> - where - Self: Sized; + fn to_sql(&self, ty: &Type, out: &mut BytesMut) -> IsNull; /// Determines if a value of this type can be converted to the specified /// Postgres `Type`. @@ -402,11 +371,7 @@ pub trait ToSql: fmt::Debug { /// /// *All* implementations of this method should be generated by the /// `to_sql_checked!()` macro. - fn to_sql_checked( - &self, - ty: &Type, - out: &mut BytesMut, - ) -> Result>; + fn check(&self, ty: &Type) -> Result<(), WrongType>; /// Specify the encode format fn encode_format(&self, _ty: &Type) -> Format { @@ -426,14 +391,14 @@ pub enum Format { } impl ToSql for &str { - fn to_sql(&self, ty: &Type, w: &mut BytesMut) -> Result> { + fn to_sql(&self, ty: &Type, w: &mut BytesMut) -> IsNull { match *ty { ref ty if ty.name() == "ltree" => types::ltree_to_sql(self, w), ref ty if ty.name() == "lquery" => types::lquery_to_sql(self, w), ref ty if ty.name() == "ltxtquery" => types::ltxtquery_to_sql(self, w), _ => types::text_to_sql(self, w), } - Ok(IsNull::No) + IsNull::No } fn accepts(ty: &Type) -> bool { @@ -457,12 +422,9 @@ impl ToSql for &str { macro_rules! simple_to { ($t:ty, $f:ident, $($expected:ident),+) => { impl ToSql for $t { - fn to_sql(&self, - _: &Type, - w: &mut BytesMut) - -> Result> { + fn to_sql(&self, _: &Type, w: &mut BytesMut) -> IsNull { types::$f(*self, w); - Ok(IsNull::No) + IsNull::No } accepts!($($expected),+); diff --git a/libs/proxy/tokio-postgres2/src/client.rs b/libs/proxy/tokio-postgres2/src/client.rs index 186eb07000..68bac0706c 100644 --- a/libs/proxy/tokio-postgres2/src/client.rs +++ b/libs/proxy/tokio-postgres2/src/client.rs @@ -219,7 +219,7 @@ impl Client { } let buf = self.client.inner().with_buf(|buf| { - frontend::query("ROLLBACK", buf).unwrap(); + frontend::query(c"ROLLBACK".into(), buf); buf.split().freeze() }); let _ = self diff --git a/libs/proxy/tokio-postgres2/src/config.rs b/libs/proxy/tokio-postgres2/src/config.rs index 978d348741..4cf01c86d6 100644 --- a/libs/proxy/tokio-postgres2/src/config.rs +++ b/libs/proxy/tokio-postgres2/src/config.rs @@ -4,6 +4,7 @@ use std::net::IpAddr; use std::time::Duration; use std::{fmt, str}; +use postgres_protocol2::CSafeStr; pub use postgres_protocol2::authentication::sasl::ScramKeys; use postgres_protocol2::message::frontend::StartupMessageParams; use serde::{Deserialize, Serialize}; @@ -162,7 +163,10 @@ impl Config { self.username = true; } - self.server_params.insert(name, value); + self.server_params.insert( + CSafeStr::new(name.as_bytes()).expect("param name should not contain a null"), + CSafeStr::new(value.as_bytes()).expect("param name should not contain a null"), + ); self } diff --git a/libs/proxy/tokio-postgres2/src/connect_raw.rs b/libs/proxy/tokio-postgres2/src/connect_raw.rs index 20dc538cf2..c8e2592afd 100644 --- a/libs/proxy/tokio-postgres2/src/connect_raw.rs +++ b/libs/proxy/tokio-postgres2/src/connect_raw.rs @@ -6,6 +6,7 @@ use std::task::{Context, Poll}; use bytes::BytesMut; use fallible_iterator::FallibleIterator; use futures_util::{Sink, SinkExt, Stream, TryStreamExt, ready}; +use postgres_protocol2::CSafeStr; use postgres_protocol2::authentication::sasl; use postgres_protocol2::authentication::sasl::ScramSha256; use postgres_protocol2::message::backend::{AuthenticationSaslBody, Message, NoticeResponseBody}; @@ -122,7 +123,7 @@ where T: AsyncRead + AsyncWrite + Unpin, { let mut buf = BytesMut::new(); - frontend::startup_message(&config.server_params, &mut buf).map_err(Error::encode)?; + frontend::startup_message(&config.server_params, &mut buf); stream .send(FrontendMessage::Raw(buf.freeze())) @@ -193,7 +194,7 @@ where T: AsyncRead + AsyncWrite + Unpin, { let mut buf = BytesMut::new(); - frontend::password_message(password, &mut buf).map_err(Error::encode)?; + frontend::password_message(CSafeStr::new(password).map_err(Error::encode)?, &mut buf); stream .send(FrontendMessage::Raw(buf.freeze())) @@ -214,10 +215,10 @@ where let mut has_scram_plus = false; let mut mechanisms = body.mechanisms(); while let Some(mechanism) = mechanisms.next().map_err(Error::parse)? { - match mechanism { - sasl::SCRAM_SHA_256 => has_scram = true, - sasl::SCRAM_SHA_256_PLUS => has_scram_plus = true, - _ => {} + if mechanism == sasl::SCRAM_SHA_256 { + has_scram = true; + } else if mechanism == sasl::SCRAM_SHA_256_PLUS { + has_scram_plus = true; } } @@ -256,7 +257,7 @@ where }; let mut buf = BytesMut::new(); - frontend::sasl_initial_response(mechanism, scram.message(), &mut buf).map_err(Error::encode)?; + frontend::sasl_initial_response(mechanism, scram.message(), &mut buf); stream .send(FrontendMessage::Raw(buf.freeze())) .await @@ -275,7 +276,7 @@ where .map_err(|e| Error::authentication(e.into()))?; let mut buf = BytesMut::new(); - frontend::sasl_response(scram.message(), &mut buf).map_err(Error::encode)?; + frontend::sasl_response(scram.message(), &mut buf); stream .send(FrontendMessage::Raw(buf.freeze())) .await diff --git a/libs/proxy/tokio-postgres2/src/lib.rs b/libs/proxy/tokio-postgres2/src/lib.rs index c8ebba5487..ee952d0ab3 100644 --- a/libs/proxy/tokio-postgres2/src/lib.rs +++ b/libs/proxy/tokio-postgres2/src/lib.rs @@ -123,6 +123,6 @@ pub enum SimpleQueryMessage { fn slice_iter<'a>( s: &'a [&'a (dyn ToSql + Sync)], -) -> impl ExactSizeIterator + 'a { +) -> impl ExactSizeIterator + Clone + 'a { s.iter().map(|s| *s as _) } diff --git a/libs/proxy/tokio-postgres2/src/prepare.rs b/libs/proxy/tokio-postgres2/src/prepare.rs index b27eabcb0e..061f322f20 100644 --- a/libs/proxy/tokio-postgres2/src/prepare.rs +++ b/libs/proxy/tokio-postgres2/src/prepare.rs @@ -1,3 +1,4 @@ +use std::ffi::CStr; use std::future::Future; use std::pin::Pin; use std::sync::Arc; @@ -5,9 +6,9 @@ use std::sync::Arc; use bytes::Bytes; use fallible_iterator::FallibleIterator; use futures_util::{TryStreamExt, pin_mut}; +use postgres_protocol2::CSafeStr; use postgres_protocol2::message::backend::Message; use postgres_protocol2::message::frontend; -use tracing::debug; use crate::client::{CachedTypeInfo, InnerClient}; use crate::codec::FrontendMessage; @@ -15,7 +16,7 @@ use crate::connection::RequestMessages; use crate::types::{Kind, Oid, Type}; use crate::{Column, Error, Statement, query, slice_iter}; -pub(crate) const TYPEINFO_QUERY: &str = "\ +pub(crate) const TYPEINFO_QUERY: &CStr = c"\ SELECT t.typname, t.typtype, t.typelem, r.rngsubtype, t.typbasetype, n.nspname, t.typrelid FROM pg_catalog.pg_type t LEFT OUTER JOIN pg_catalog.pg_range r ON r.rngtypid = t.oid @@ -25,11 +26,11 @@ WHERE t.oid = $1 async fn prepare_typecheck( client: &Arc, - name: &'static str, - query: &str, + name: &'static CStr, + query: &CSafeStr, types: &[Type], ) -> Result { - let buf = encode(client, name, query, types)?; + let buf = encode(client, name.into(), query, types)?; let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?; match responses.next().await? { @@ -68,16 +69,21 @@ async fn prepare_typecheck( Ok(Statement::new(client, name, parameters, columns)) } -fn encode(client: &InnerClient, name: &str, query: &str, types: &[Type]) -> Result { - if types.is_empty() { - debug!("preparing query {}: {}", name, query); - } else { - debug!("preparing query {} with types {:?}: {}", name, types, query); - } +fn encode( + client: &InnerClient, + name: &CSafeStr, + query: &CSafeStr, + types: &[Type], +) -> Result { + // if types.is_empty() { + // debug!("preparing query {}: {}", name, query); + // } else { + // debug!("preparing query {} with types {:?}: {}", name, types, query); + // } client.with_buf(|buf| { - frontend::parse(name, query, types.iter().map(Type::oid), buf).map_err(Error::encode)?; - frontend::describe(b'S', name, buf).map_err(Error::encode)?; + frontend::parse(name, query, types.iter().map(Type::oid), buf); + frontend::describe(b'S', name, buf); frontend::sync(buf); Ok(buf.split().freeze()) }) @@ -154,8 +160,8 @@ async fn typeinfo_statement( return Ok(stmt.clone()); } - let typeinfo = "neon_proxy_typeinfo"; - let stmt = prepare_typecheck(client, typeinfo, TYPEINFO_QUERY, &[]).await?; + let typeinfo = c"neon_proxy_typeinfo"; + let stmt = prepare_typecheck(client, typeinfo, TYPEINFO_QUERY.into(), &[]).await?; typecache.typeinfo = Some(stmt.clone()); Ok(stmt) diff --git a/libs/proxy/tokio-postgres2/src/query.rs b/libs/proxy/tokio-postgres2/src/query.rs index 106bc69d49..41fbd58b43 100644 --- a/libs/proxy/tokio-postgres2/src/query.rs +++ b/libs/proxy/tokio-postgres2/src/query.rs @@ -1,4 +1,4 @@ -use std::fmt; +use std::iter; use std::marker::PhantomPinned; use std::pin::Pin; use std::sync::Arc; @@ -8,25 +8,16 @@ use bytes::{BufMut, Bytes, BytesMut}; use fallible_iterator::FallibleIterator; use futures_util::{Stream, ready}; use pin_project_lite::pin_project; +use postgres_protocol2::CSafeStr; use postgres_protocol2::message::backend::Message; use postgres_protocol2::message::frontend; use postgres_types2::{Format, ToSql, Type}; -use tracing::debug; use crate::client::{InnerClient, Responses}; use crate::codec::FrontendMessage; use crate::connection::RequestMessages; -use crate::types::IsNull; use crate::{Column, Error, ReadyForQueryStatus, Row, Statement}; -struct BorrowToSqlParamsDebug<'a>(&'a [&'a (dyn ToSql + Sync)]); - -impl fmt::Debug for BorrowToSqlParamsDebug<'_> { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_list().entries(self.0.iter()).finish() - } -} - pub async fn query<'a, I>( client: &InnerClient, statement: Statement, @@ -34,19 +25,9 @@ pub async fn query<'a, I>( ) -> Result where I: IntoIterator, - I::IntoIter: ExactSizeIterator, + I::IntoIter: ExactSizeIterator + Clone, { - let buf = if tracing::enabled!(tracing::Level::DEBUG) { - let params = params.into_iter().collect::>(); - debug!( - "executing statement {} with parameters: {:?}", - statement.name(), - BorrowToSqlParamsDebug(params.as_slice()), - ); - encode(client, &statement, params)? - } else { - encode(client, &statement, params)? - }; + let buf = encode(client, &statement, params)?; let responses = start(client, buf).await?; Ok(RowStream { statement, @@ -68,45 +49,45 @@ where I: IntoIterator>, I::IntoIter: ExactSizeIterator, { + let query = CSafeStr::new(query.as_bytes()).map_err(Error::encode)?; let params = params.into_iter(); + let portal = c"".into(); // unnamed portal + let statement = c"".into(); // unnamed prepared statement + let buf = client.with_buf(|buf| { frontend::parse( - "", // unnamed prepared statement - query, // query to parse - std::iter::empty(), // give no type info + statement, + query, // query to parse + iter::empty(), // give no type info buf, - ) - .map_err(Error::encode)?; - frontend::describe(b'S', "", buf).map_err(Error::encode)?; - // Bind, pass params as text, retrieve as binary - match frontend::bind( - "", // empty string selects the unnamed portal - "", // unnamed prepared statement - std::iter::empty(), // all parameters use the default format (text) + ); + frontend::describe(b'S', statement, buf); + + // Bind, pass params as text, retrieve as test + frontend::bind( + portal, + statement, + iter::empty(), // all parameters use the default format (text) params, |param, buf| match param { Some(param) => { buf.put_slice(param.as_ref().as_bytes()); - Ok(postgres_protocol2::IsNull::No) + postgres_protocol2::IsNull::No } - None => Ok(postgres_protocol2::IsNull::Yes), + None => postgres_protocol2::IsNull::Yes, }, Some(0), // all text buf, - ) { - Ok(()) => Ok(()), - Err(frontend::BindError::Conversion(e)) => Err(Error::to_sql(e, 0)), - Err(frontend::BindError::Serialization(e)) => Err(Error::encode(e)), - }?; + ); // Execute - frontend::execute("", 0, buf).map_err(Error::encode)?; + frontend::execute(portal, 0, buf); // Sync frontend::sync(buf); - Ok(buf.split().freeze()) - })?; + buf.split().freeze() + }); // now read the responses let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?; @@ -173,11 +154,13 @@ async fn start(client: &InnerClient, buf: Bytes) -> Result { pub fn encode<'a, I>(client: &InnerClient, statement: &Statement, params: I) -> Result where I: IntoIterator, - I::IntoIter: ExactSizeIterator, + I::IntoIter: ExactSizeIterator + Clone, { + let portal = c"".into(); // unnamed portal + client.with_buf(|buf| { - encode_bind(statement, params, "", buf)?; - frontend::execute("", 0, buf).map_err(Error::encode)?; + encode_bind(statement, params, portal, buf)?; + frontend::execute(portal, 0, buf); frontend::sync(buf); Ok(buf.split().freeze()) }) @@ -186,15 +169,15 @@ where pub fn encode_bind<'a, I>( statement: &Statement, params: I, - portal: &str, + portal: &CSafeStr, buf: &mut BytesMut, ) -> Result<(), Error> where I: IntoIterator, - I::IntoIter: ExactSizeIterator, + I::IntoIter: ExactSizeIterator + Clone, { let param_types = statement.params(); - let params = params.into_iter(); + let params = iter::zip(params.into_iter(), param_types); assert!( param_types.len() == params.len(), @@ -203,35 +186,24 @@ where params.len() ); - let (param_formats, params): (Vec<_>, Vec<_>) = params - .zip(param_types.iter()) - .map(|(p, ty)| (p.encode_format(ty) as i16, p)) - .unzip(); + // check encodings + for (i, (p, ty)) in params.clone().enumerate() { + p.check(ty).map_err(|e| Error::to_sql(Box::new(e), i))? + } - let params = params.into_iter(); + let param_formats = params.clone().map(|(p, ty)| p.encode_format(ty) as i16); - let mut error_idx = 0; - let r = frontend::bind( + frontend::bind( portal, statement.name(), param_formats, - params.zip(param_types).enumerate(), - |(idx, (param, ty)), buf| match param.to_sql_checked(ty, buf) { - Ok(IsNull::No) => Ok(postgres_protocol2::IsNull::No), - Ok(IsNull::Yes) => Ok(postgres_protocol2::IsNull::Yes), - Err(e) => { - error_idx = idx; - Err(e) - } - }, + params, + |(param, ty), buf| param.to_sql(ty, buf), Some(1), buf, ); - match r { - Ok(()) => Ok(()), - Err(frontend::BindError::Conversion(e)) => Err(Error::to_sql(e, error_idx)), - Err(frontend::BindError::Serialization(e)) => Err(Error::encode(e)), - } + + Ok(()) } pin_project! { diff --git a/libs/proxy/tokio-postgres2/src/simple_query.rs b/libs/proxy/tokio-postgres2/src/simple_query.rs index 2cf17188cf..8f3fd3c085 100644 --- a/libs/proxy/tokio-postgres2/src/simple_query.rs +++ b/libs/proxy/tokio-postgres2/src/simple_query.rs @@ -7,6 +7,7 @@ use bytes::Bytes; use fallible_iterator::FallibleIterator; use futures_util::{Stream, ready}; use pin_project_lite::pin_project; +use postgres_protocol2::CSafeStr; use postgres_protocol2::message::backend::Message; use postgres_protocol2::message::frontend; use tracing::debug; @@ -69,8 +70,9 @@ pub async fn batch_execute( } pub(crate) fn encode(client: &InnerClient, query: &str) -> Result { + let query = CSafeStr::new(query.as_bytes()).map_err(Error::encode)?; client.with_buf(|buf| { - frontend::query(query, buf).map_err(Error::encode)?; + frontend::query(query, buf); Ok(buf.split().freeze()) }) } diff --git a/libs/proxy/tokio-postgres2/src/statement.rs b/libs/proxy/tokio-postgres2/src/statement.rs index e4828db712..d855687583 100644 --- a/libs/proxy/tokio-postgres2/src/statement.rs +++ b/libs/proxy/tokio-postgres2/src/statement.rs @@ -1,9 +1,10 @@ +use std::ffi::CStr; use std::fmt; use std::sync::{Arc, Weak}; -use postgres_protocol2::Oid; use postgres_protocol2::message::backend::Field; use postgres_protocol2::message::frontend; +use postgres_protocol2::{CSafeStr, Oid}; use crate::client::InnerClient; use crate::codec::FrontendMessage; @@ -12,7 +13,7 @@ use crate::types::Type; struct StatementInner { client: Weak, - name: &'static str, + name: &'static CStr, params: Vec, columns: Vec, } @@ -21,7 +22,7 @@ impl Drop for StatementInner { fn drop(&mut self) { if let Some(client) = self.client.upgrade() { let buf = client.with_buf(|buf| { - frontend::close(b'S', self.name, buf).unwrap(); + frontend::close(b'S', self.name.into(), buf); frontend::sync(buf); buf.split().freeze() }); @@ -39,7 +40,7 @@ pub struct Statement(Arc); impl Statement { pub(crate) fn new( inner: &Arc, - name: &'static str, + name: &'static CStr, params: Vec, columns: Vec, ) -> Statement { @@ -54,14 +55,14 @@ impl Statement { pub(crate) fn new_anonymous(params: Vec, columns: Vec) -> Statement { Statement(Arc::new(StatementInner { client: Weak::new(), - name: "", + name: c"", params, columns, })) } - pub(crate) fn name(&self) -> &str { - self.0.name + pub(crate) fn name(&self) -> &CSafeStr { + self.0.name.into() } /// Returns the expected types of the statement's parameters. diff --git a/libs/proxy/tokio-postgres2/src/transaction.rs b/libs/proxy/tokio-postgres2/src/transaction.rs index f32603470f..884d2b327e 100644 --- a/libs/proxy/tokio-postgres2/src/transaction.rs +++ b/libs/proxy/tokio-postgres2/src/transaction.rs @@ -21,7 +21,7 @@ impl Drop for Transaction<'_> { } let buf = self.client.inner().with_buf(|buf| { - frontend::query("ROLLBACK", buf).unwrap(); + frontend::query(c"ROLLBACK".into(), buf); buf.split().freeze() }); let _ = self diff --git a/proxy/src/auth/backend/mod.rs b/proxy/src/auth/backend/mod.rs index 83feed5094..07d9418c81 100644 --- a/proxy/src/auth/backend/mod.rs +++ b/proxy/src/auth/backend/mod.rs @@ -536,7 +536,8 @@ mod tests { use control_plane::AuthSecret; use fallible_iterator::FallibleIterator; use once_cell::sync::Lazy; - use postgres_protocol::authentication::sasl::{ChannelBinding, ScramSha256}; + use postgres_protocol::CSafeStr; + use postgres_protocol::authentication::sasl::{ChannelBinding, SCRAM_SHA_256, ScramSha256}; use postgres_protocol::message::backend::Message as PgMessage; use postgres_protocol::message::frontend; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt}; @@ -714,15 +715,15 @@ mod tests { // server should offer scram match read_message(&mut client, &mut read).await { PgMessage::AuthenticationSasl(a) => { - let options: Vec<&str> = a.mechanisms().collect().unwrap(); - assert_eq!(options, ["SCRAM-SHA-256"]); + let options: Vec<&CSafeStr> = a.mechanisms().collect().unwrap(); + assert_eq!(options, [SCRAM_SHA_256]); } _ => panic!("wrong message"), } // client sends client-first-message let mut write = BytesMut::new(); - frontend::sasl_initial_response("SCRAM-SHA-256", scram.message(), &mut write).unwrap(); + frontend::sasl_initial_response(SCRAM_SHA_256, scram.message(), &mut write); client.write_all(&write).await.unwrap(); // server response with server-first-message @@ -735,7 +736,7 @@ mod tests { // client response with client-final-message write.clear(); - frontend::sasl_response(scram.message(), &mut write).unwrap(); + frontend::sasl_response(scram.message(), &mut write); client.write_all(&write).await.unwrap(); // server response with server-final-message @@ -800,7 +801,7 @@ mod tests { // client responds with password write.clear(); - frontend::password_message(b"my-secret-password", &mut write).unwrap(); + frontend::password_message(c"my-secret-password".into(), &mut write); client.write_all(&write).await.unwrap(); }); let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new_with_shards( @@ -853,8 +854,10 @@ mod tests { // client responds with password let mut write = BytesMut::new(); - frontend::password_message(b"endpoint=my-endpoint;my-secret-password", &mut write) - .unwrap(); + frontend::password_message( + c"endpoint=my-endpoint;my-secret-password".into(), + &mut write, + ); client.write_all(&write).await.unwrap(); }); diff --git a/proxy/src/auth/flow.rs b/proxy/src/auth/flow.rs index 0992c6d875..66c7c0f6ad 100644 --- a/proxy/src/auth/flow.rs +++ b/proxy/src/auth/flow.rs @@ -3,7 +3,6 @@ use std::io; use std::sync::Arc; -use postgres_protocol::authentication::sasl::{SCRAM_SHA_256, SCRAM_SHA_256_PLUS}; use pq_proto::{BeAuthenticationSaslMessage, BeMessage, BeMessage as Be}; use tokio::io::{AsyncRead, AsyncWrite}; use tracing::info; @@ -174,8 +173,10 @@ impl AuthFlow<'_, S, Scram<'_>> { } match sasl.method { - SCRAM_SHA_256 => ctx.set_auth_method(crate::context::AuthMethod::ScramSha256), - SCRAM_SHA_256_PLUS => ctx.set_auth_method(crate::context::AuthMethod::ScramSha256Plus), + scram::SCRAM_SHA_256 => ctx.set_auth_method(crate::context::AuthMethod::ScramSha256), + scram::SCRAM_SHA_256_PLUS => { + ctx.set_auth_method(crate::context::AuthMethod::ScramSha256Plus) + } _ => {} } diff --git a/proxy/src/proxy/tests/mitm.rs b/proxy/src/proxy/tests/mitm.rs index 59c9ac27b8..92c42b1d2b 100644 --- a/proxy/src/proxy/tests/mitm.rs +++ b/proxy/src/proxy/tests/mitm.rs @@ -9,7 +9,7 @@ use std::fmt::Debug; use bytes::{Bytes, BytesMut}; use futures::{SinkExt, StreamExt}; use postgres_client::tls::TlsConnect; -use postgres_protocol::message::frontend; +use postgres_protocol::{authentication::sasl::SCRAM_SHA_256, message::frontend}; use tokio::io::{AsyncReadExt, DuplexStream}; use tokio_util::codec::{Decoder, Encoder}; @@ -60,8 +60,7 @@ async fn proxy_mitm( params: startup.params.into(), }, &mut buf, - ) - .unwrap(); + ); end_server.send(buf.freeze()).await.unwrap(); // proxy messages between end_client and end_server @@ -90,7 +89,7 @@ async fn proxy_mitm( new_message.extend_from_slice(sasl_message.strip_prefix(b"p=tls-server-end-point,,").unwrap()); let mut buf = BytesMut::new(); - frontend::sasl_initial_response("SCRAM-SHA-256", &new_message, &mut buf).unwrap(); + frontend::sasl_initial_response(SCRAM_SHA_256, &new_message, &mut buf); end_server.send(buf.freeze()).await.unwrap(); continue; diff --git a/proxy/src/scram/mod.rs b/proxy/src/scram/mod.rs index 4f764c6087..db74870e9b 100644 --- a/proxy/src/scram/mod.rs +++ b/proxy/src/scram/mod.rs @@ -21,8 +21,8 @@ pub(crate) use key::ScramKey; pub(crate) use secret::ServerSecret; use sha2::{Digest, Sha256}; -const SCRAM_SHA_256: &str = "SCRAM-SHA-256"; -const SCRAM_SHA_256_PLUS: &str = "SCRAM-SHA-256-PLUS"; +pub(crate) const SCRAM_SHA_256: &str = "SCRAM-SHA-256"; +pub(crate) const SCRAM_SHA_256_PLUS: &str = "SCRAM-SHA-256-PLUS"; /// A list of supported SCRAM methods. pub(crate) const METHODS: &[&str] = &[SCRAM_SHA_256_PLUS, SCRAM_SHA_256];