simplify error handling for query encoding

This commit is contained in:
Conrad Ludgate
2025-05-21 13:37:57 +01:00
parent f3c9d0adf4
commit 13d41b51a2
18 changed files with 246 additions and 313 deletions

View File

@@ -219,7 +219,7 @@ impl Client {
}
let buf = self.client.inner().with_buf(|buf| {
frontend::query("ROLLBACK", buf).unwrap();
frontend::query(c"ROLLBACK".into(), buf);
buf.split().freeze()
});
let _ = self

View File

@@ -4,6 +4,7 @@ use std::net::IpAddr;
use std::time::Duration;
use std::{fmt, str};
use postgres_protocol2::CSafeStr;
pub use postgres_protocol2::authentication::sasl::ScramKeys;
use postgres_protocol2::message::frontend::StartupMessageParams;
use serde::{Deserialize, Serialize};
@@ -162,7 +163,10 @@ impl Config {
self.username = true;
}
self.server_params.insert(name, value);
self.server_params.insert(
CSafeStr::new(name.as_bytes()).expect("param name should not contain a null"),
CSafeStr::new(value.as_bytes()).expect("param name should not contain a null"),
);
self
}

View File

@@ -6,6 +6,7 @@ use std::task::{Context, Poll};
use bytes::BytesMut;
use fallible_iterator::FallibleIterator;
use futures_util::{Sink, SinkExt, Stream, TryStreamExt, ready};
use postgres_protocol2::CSafeStr;
use postgres_protocol2::authentication::sasl;
use postgres_protocol2::authentication::sasl::ScramSha256;
use postgres_protocol2::message::backend::{AuthenticationSaslBody, Message, NoticeResponseBody};
@@ -122,7 +123,7 @@ where
T: AsyncRead + AsyncWrite + Unpin,
{
let mut buf = BytesMut::new();
frontend::startup_message(&config.server_params, &mut buf).map_err(Error::encode)?;
frontend::startup_message(&config.server_params, &mut buf);
stream
.send(FrontendMessage::Raw(buf.freeze()))
@@ -193,7 +194,7 @@ where
T: AsyncRead + AsyncWrite + Unpin,
{
let mut buf = BytesMut::new();
frontend::password_message(password, &mut buf).map_err(Error::encode)?;
frontend::password_message(CSafeStr::new(password).map_err(Error::encode)?, &mut buf);
stream
.send(FrontendMessage::Raw(buf.freeze()))
@@ -214,10 +215,10 @@ where
let mut has_scram_plus = false;
let mut mechanisms = body.mechanisms();
while let Some(mechanism) = mechanisms.next().map_err(Error::parse)? {
match mechanism {
sasl::SCRAM_SHA_256 => has_scram = true,
sasl::SCRAM_SHA_256_PLUS => has_scram_plus = true,
_ => {}
if mechanism == sasl::SCRAM_SHA_256 {
has_scram = true;
} else if mechanism == sasl::SCRAM_SHA_256_PLUS {
has_scram_plus = true;
}
}
@@ -256,7 +257,7 @@ where
};
let mut buf = BytesMut::new();
frontend::sasl_initial_response(mechanism, scram.message(), &mut buf).map_err(Error::encode)?;
frontend::sasl_initial_response(mechanism, scram.message(), &mut buf);
stream
.send(FrontendMessage::Raw(buf.freeze()))
.await
@@ -275,7 +276,7 @@ where
.map_err(|e| Error::authentication(e.into()))?;
let mut buf = BytesMut::new();
frontend::sasl_response(scram.message(), &mut buf).map_err(Error::encode)?;
frontend::sasl_response(scram.message(), &mut buf);
stream
.send(FrontendMessage::Raw(buf.freeze()))
.await

View File

@@ -123,6 +123,6 @@ pub enum SimpleQueryMessage {
fn slice_iter<'a>(
s: &'a [&'a (dyn ToSql + Sync)],
) -> impl ExactSizeIterator<Item = &'a (dyn ToSql + Sync)> + 'a {
) -> impl ExactSizeIterator<Item = &'a (dyn ToSql + Sync)> + Clone + 'a {
s.iter().map(|s| *s as _)
}

View File

@@ -1,3 +1,4 @@
use std::ffi::CStr;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
@@ -5,9 +6,9 @@ use std::sync::Arc;
use bytes::Bytes;
use fallible_iterator::FallibleIterator;
use futures_util::{TryStreamExt, pin_mut};
use postgres_protocol2::CSafeStr;
use postgres_protocol2::message::backend::Message;
use postgres_protocol2::message::frontend;
use tracing::debug;
use crate::client::{CachedTypeInfo, InnerClient};
use crate::codec::FrontendMessage;
@@ -15,7 +16,7 @@ use crate::connection::RequestMessages;
use crate::types::{Kind, Oid, Type};
use crate::{Column, Error, Statement, query, slice_iter};
pub(crate) const TYPEINFO_QUERY: &str = "\
pub(crate) const TYPEINFO_QUERY: &CStr = c"\
SELECT t.typname, t.typtype, t.typelem, r.rngsubtype, t.typbasetype, n.nspname, t.typrelid
FROM pg_catalog.pg_type t
LEFT OUTER JOIN pg_catalog.pg_range r ON r.rngtypid = t.oid
@@ -25,11 +26,11 @@ WHERE t.oid = $1
async fn prepare_typecheck(
client: &Arc<InnerClient>,
name: &'static str,
query: &str,
name: &'static CStr,
query: &CSafeStr,
types: &[Type],
) -> Result<Statement, Error> {
let buf = encode(client, name, query, types)?;
let buf = encode(client, name.into(), query, types)?;
let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?;
match responses.next().await? {
@@ -68,16 +69,21 @@ async fn prepare_typecheck(
Ok(Statement::new(client, name, parameters, columns))
}
fn encode(client: &InnerClient, name: &str, query: &str, types: &[Type]) -> Result<Bytes, Error> {
if types.is_empty() {
debug!("preparing query {}: {}", name, query);
} else {
debug!("preparing query {} with types {:?}: {}", name, types, query);
}
fn encode(
client: &InnerClient,
name: &CSafeStr,
query: &CSafeStr,
types: &[Type],
) -> Result<Bytes, Error> {
// if types.is_empty() {
// debug!("preparing query {}: {}", name, query);
// } else {
// debug!("preparing query {} with types {:?}: {}", name, types, query);
// }
client.with_buf(|buf| {
frontend::parse(name, query, types.iter().map(Type::oid), buf).map_err(Error::encode)?;
frontend::describe(b'S', name, buf).map_err(Error::encode)?;
frontend::parse(name, query, types.iter().map(Type::oid), buf);
frontend::describe(b'S', name, buf);
frontend::sync(buf);
Ok(buf.split().freeze())
})
@@ -154,8 +160,8 @@ async fn typeinfo_statement(
return Ok(stmt.clone());
}
let typeinfo = "neon_proxy_typeinfo";
let stmt = prepare_typecheck(client, typeinfo, TYPEINFO_QUERY, &[]).await?;
let typeinfo = c"neon_proxy_typeinfo";
let stmt = prepare_typecheck(client, typeinfo, TYPEINFO_QUERY.into(), &[]).await?;
typecache.typeinfo = Some(stmt.clone());
Ok(stmt)

View File

@@ -1,4 +1,4 @@
use std::fmt;
use std::iter;
use std::marker::PhantomPinned;
use std::pin::Pin;
use std::sync::Arc;
@@ -8,25 +8,16 @@ use bytes::{BufMut, Bytes, BytesMut};
use fallible_iterator::FallibleIterator;
use futures_util::{Stream, ready};
use pin_project_lite::pin_project;
use postgres_protocol2::CSafeStr;
use postgres_protocol2::message::backend::Message;
use postgres_protocol2::message::frontend;
use postgres_types2::{Format, ToSql, Type};
use tracing::debug;
use crate::client::{InnerClient, Responses};
use crate::codec::FrontendMessage;
use crate::connection::RequestMessages;
use crate::types::IsNull;
use crate::{Column, Error, ReadyForQueryStatus, Row, Statement};
struct BorrowToSqlParamsDebug<'a>(&'a [&'a (dyn ToSql + Sync)]);
impl fmt::Debug for BorrowToSqlParamsDebug<'_> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_list().entries(self.0.iter()).finish()
}
}
pub async fn query<'a, I>(
client: &InnerClient,
statement: Statement,
@@ -34,19 +25,9 @@ pub async fn query<'a, I>(
) -> Result<RowStream, Error>
where
I: IntoIterator<Item = &'a (dyn ToSql + Sync)>,
I::IntoIter: ExactSizeIterator,
I::IntoIter: ExactSizeIterator + Clone,
{
let buf = if tracing::enabled!(tracing::Level::DEBUG) {
let params = params.into_iter().collect::<Vec<_>>();
debug!(
"executing statement {} with parameters: {:?}",
statement.name(),
BorrowToSqlParamsDebug(params.as_slice()),
);
encode(client, &statement, params)?
} else {
encode(client, &statement, params)?
};
let buf = encode(client, &statement, params)?;
let responses = start(client, buf).await?;
Ok(RowStream {
statement,
@@ -68,45 +49,45 @@ where
I: IntoIterator<Item = Option<S>>,
I::IntoIter: ExactSizeIterator,
{
let query = CSafeStr::new(query.as_bytes()).map_err(Error::encode)?;
let params = params.into_iter();
let portal = c"".into(); // unnamed portal
let statement = c"".into(); // unnamed prepared statement
let buf = client.with_buf(|buf| {
frontend::parse(
"", // unnamed prepared statement
query, // query to parse
std::iter::empty(), // give no type info
statement,
query, // query to parse
iter::empty(), // give no type info
buf,
)
.map_err(Error::encode)?;
frontend::describe(b'S', "", buf).map_err(Error::encode)?;
// Bind, pass params as text, retrieve as binary
match frontend::bind(
"", // empty string selects the unnamed portal
"", // unnamed prepared statement
std::iter::empty(), // all parameters use the default format (text)
);
frontend::describe(b'S', statement, buf);
// Bind, pass params as text, retrieve as test
frontend::bind(
portal,
statement,
iter::empty(), // all parameters use the default format (text)
params,
|param, buf| match param {
Some(param) => {
buf.put_slice(param.as_ref().as_bytes());
Ok(postgres_protocol2::IsNull::No)
postgres_protocol2::IsNull::No
}
None => Ok(postgres_protocol2::IsNull::Yes),
None => postgres_protocol2::IsNull::Yes,
},
Some(0), // all text
buf,
) {
Ok(()) => Ok(()),
Err(frontend::BindError::Conversion(e)) => Err(Error::to_sql(e, 0)),
Err(frontend::BindError::Serialization(e)) => Err(Error::encode(e)),
}?;
);
// Execute
frontend::execute("", 0, buf).map_err(Error::encode)?;
frontend::execute(portal, 0, buf);
// Sync
frontend::sync(buf);
Ok(buf.split().freeze())
})?;
buf.split().freeze()
});
// now read the responses
let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?;
@@ -173,11 +154,13 @@ async fn start(client: &InnerClient, buf: Bytes) -> Result<Responses, Error> {
pub fn encode<'a, I>(client: &InnerClient, statement: &Statement, params: I) -> Result<Bytes, Error>
where
I: IntoIterator<Item = &'a (dyn ToSql + Sync)>,
I::IntoIter: ExactSizeIterator,
I::IntoIter: ExactSizeIterator + Clone,
{
let portal = c"".into(); // unnamed portal
client.with_buf(|buf| {
encode_bind(statement, params, "", buf)?;
frontend::execute("", 0, buf).map_err(Error::encode)?;
encode_bind(statement, params, portal, buf)?;
frontend::execute(portal, 0, buf);
frontend::sync(buf);
Ok(buf.split().freeze())
})
@@ -186,15 +169,15 @@ where
pub fn encode_bind<'a, I>(
statement: &Statement,
params: I,
portal: &str,
portal: &CSafeStr,
buf: &mut BytesMut,
) -> Result<(), Error>
where
I: IntoIterator<Item = &'a (dyn ToSql + Sync)>,
I::IntoIter: ExactSizeIterator,
I::IntoIter: ExactSizeIterator + Clone,
{
let param_types = statement.params();
let params = params.into_iter();
let params = iter::zip(params.into_iter(), param_types);
assert!(
param_types.len() == params.len(),
@@ -203,35 +186,24 @@ where
params.len()
);
let (param_formats, params): (Vec<_>, Vec<_>) = params
.zip(param_types.iter())
.map(|(p, ty)| (p.encode_format(ty) as i16, p))
.unzip();
// check encodings
for (i, (p, ty)) in params.clone().enumerate() {
p.check(ty).map_err(|e| Error::to_sql(Box::new(e), i))?
}
let params = params.into_iter();
let param_formats = params.clone().map(|(p, ty)| p.encode_format(ty) as i16);
let mut error_idx = 0;
let r = frontend::bind(
frontend::bind(
portal,
statement.name(),
param_formats,
params.zip(param_types).enumerate(),
|(idx, (param, ty)), buf| match param.to_sql_checked(ty, buf) {
Ok(IsNull::No) => Ok(postgres_protocol2::IsNull::No),
Ok(IsNull::Yes) => Ok(postgres_protocol2::IsNull::Yes),
Err(e) => {
error_idx = idx;
Err(e)
}
},
params,
|(param, ty), buf| param.to_sql(ty, buf),
Some(1),
buf,
);
match r {
Ok(()) => Ok(()),
Err(frontend::BindError::Conversion(e)) => Err(Error::to_sql(e, error_idx)),
Err(frontend::BindError::Serialization(e)) => Err(Error::encode(e)),
}
Ok(())
}
pin_project! {

View File

@@ -7,6 +7,7 @@ use bytes::Bytes;
use fallible_iterator::FallibleIterator;
use futures_util::{Stream, ready};
use pin_project_lite::pin_project;
use postgres_protocol2::CSafeStr;
use postgres_protocol2::message::backend::Message;
use postgres_protocol2::message::frontend;
use tracing::debug;
@@ -69,8 +70,9 @@ pub async fn batch_execute(
}
pub(crate) fn encode(client: &InnerClient, query: &str) -> Result<Bytes, Error> {
let query = CSafeStr::new(query.as_bytes()).map_err(Error::encode)?;
client.with_buf(|buf| {
frontend::query(query, buf).map_err(Error::encode)?;
frontend::query(query, buf);
Ok(buf.split().freeze())
})
}

View File

@@ -1,9 +1,10 @@
use std::ffi::CStr;
use std::fmt;
use std::sync::{Arc, Weak};
use postgres_protocol2::Oid;
use postgres_protocol2::message::backend::Field;
use postgres_protocol2::message::frontend;
use postgres_protocol2::{CSafeStr, Oid};
use crate::client::InnerClient;
use crate::codec::FrontendMessage;
@@ -12,7 +13,7 @@ use crate::types::Type;
struct StatementInner {
client: Weak<InnerClient>,
name: &'static str,
name: &'static CStr,
params: Vec<Type>,
columns: Vec<Column>,
}
@@ -21,7 +22,7 @@ impl Drop for StatementInner {
fn drop(&mut self) {
if let Some(client) = self.client.upgrade() {
let buf = client.with_buf(|buf| {
frontend::close(b'S', self.name, buf).unwrap();
frontend::close(b'S', self.name.into(), buf);
frontend::sync(buf);
buf.split().freeze()
});
@@ -39,7 +40,7 @@ pub struct Statement(Arc<StatementInner>);
impl Statement {
pub(crate) fn new(
inner: &Arc<InnerClient>,
name: &'static str,
name: &'static CStr,
params: Vec<Type>,
columns: Vec<Column>,
) -> Statement {
@@ -54,14 +55,14 @@ impl Statement {
pub(crate) fn new_anonymous(params: Vec<Type>, columns: Vec<Column>) -> Statement {
Statement(Arc::new(StatementInner {
client: Weak::new(),
name: "<anonymous>",
name: c"<anonymous>",
params,
columns,
}))
}
pub(crate) fn name(&self) -> &str {
self.0.name
pub(crate) fn name(&self) -> &CSafeStr {
self.0.name.into()
}
/// Returns the expected types of the statement's parameters.

View File

@@ -21,7 +21,7 @@ impl Drop for Transaction<'_> {
}
let buf = self.client.inner().with_buf(|buf| {
frontend::query("ROLLBACK", buf).unwrap();
frontend::query(c"ROLLBACK".into(), buf);
buf.split().freeze()
});
let _ = self