mirror of
https://github.com/neondatabase/neon.git
synced 2026-05-25 17:10:38 +00:00
simplify error handling for query encoding
This commit is contained in:
@@ -219,7 +219,7 @@ impl Client {
|
||||
}
|
||||
|
||||
let buf = self.client.inner().with_buf(|buf| {
|
||||
frontend::query("ROLLBACK", buf).unwrap();
|
||||
frontend::query(c"ROLLBACK".into(), buf);
|
||||
buf.split().freeze()
|
||||
});
|
||||
let _ = self
|
||||
|
||||
@@ -4,6 +4,7 @@ use std::net::IpAddr;
|
||||
use std::time::Duration;
|
||||
use std::{fmt, str};
|
||||
|
||||
use postgres_protocol2::CSafeStr;
|
||||
pub use postgres_protocol2::authentication::sasl::ScramKeys;
|
||||
use postgres_protocol2::message::frontend::StartupMessageParams;
|
||||
use serde::{Deserialize, Serialize};
|
||||
@@ -162,7 +163,10 @@ impl Config {
|
||||
self.username = true;
|
||||
}
|
||||
|
||||
self.server_params.insert(name, value);
|
||||
self.server_params.insert(
|
||||
CSafeStr::new(name.as_bytes()).expect("param name should not contain a null"),
|
||||
CSafeStr::new(value.as_bytes()).expect("param name should not contain a null"),
|
||||
);
|
||||
self
|
||||
}
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ use std::task::{Context, Poll};
|
||||
use bytes::BytesMut;
|
||||
use fallible_iterator::FallibleIterator;
|
||||
use futures_util::{Sink, SinkExt, Stream, TryStreamExt, ready};
|
||||
use postgres_protocol2::CSafeStr;
|
||||
use postgres_protocol2::authentication::sasl;
|
||||
use postgres_protocol2::authentication::sasl::ScramSha256;
|
||||
use postgres_protocol2::message::backend::{AuthenticationSaslBody, Message, NoticeResponseBody};
|
||||
@@ -122,7 +123,7 @@ where
|
||||
T: AsyncRead + AsyncWrite + Unpin,
|
||||
{
|
||||
let mut buf = BytesMut::new();
|
||||
frontend::startup_message(&config.server_params, &mut buf).map_err(Error::encode)?;
|
||||
frontend::startup_message(&config.server_params, &mut buf);
|
||||
|
||||
stream
|
||||
.send(FrontendMessage::Raw(buf.freeze()))
|
||||
@@ -193,7 +194,7 @@ where
|
||||
T: AsyncRead + AsyncWrite + Unpin,
|
||||
{
|
||||
let mut buf = BytesMut::new();
|
||||
frontend::password_message(password, &mut buf).map_err(Error::encode)?;
|
||||
frontend::password_message(CSafeStr::new(password).map_err(Error::encode)?, &mut buf);
|
||||
|
||||
stream
|
||||
.send(FrontendMessage::Raw(buf.freeze()))
|
||||
@@ -214,10 +215,10 @@ where
|
||||
let mut has_scram_plus = false;
|
||||
let mut mechanisms = body.mechanisms();
|
||||
while let Some(mechanism) = mechanisms.next().map_err(Error::parse)? {
|
||||
match mechanism {
|
||||
sasl::SCRAM_SHA_256 => has_scram = true,
|
||||
sasl::SCRAM_SHA_256_PLUS => has_scram_plus = true,
|
||||
_ => {}
|
||||
if mechanism == sasl::SCRAM_SHA_256 {
|
||||
has_scram = true;
|
||||
} else if mechanism == sasl::SCRAM_SHA_256_PLUS {
|
||||
has_scram_plus = true;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -256,7 +257,7 @@ where
|
||||
};
|
||||
|
||||
let mut buf = BytesMut::new();
|
||||
frontend::sasl_initial_response(mechanism, scram.message(), &mut buf).map_err(Error::encode)?;
|
||||
frontend::sasl_initial_response(mechanism, scram.message(), &mut buf);
|
||||
stream
|
||||
.send(FrontendMessage::Raw(buf.freeze()))
|
||||
.await
|
||||
@@ -275,7 +276,7 @@ where
|
||||
.map_err(|e| Error::authentication(e.into()))?;
|
||||
|
||||
let mut buf = BytesMut::new();
|
||||
frontend::sasl_response(scram.message(), &mut buf).map_err(Error::encode)?;
|
||||
frontend::sasl_response(scram.message(), &mut buf);
|
||||
stream
|
||||
.send(FrontendMessage::Raw(buf.freeze()))
|
||||
.await
|
||||
|
||||
@@ -123,6 +123,6 @@ pub enum SimpleQueryMessage {
|
||||
|
||||
fn slice_iter<'a>(
|
||||
s: &'a [&'a (dyn ToSql + Sync)],
|
||||
) -> impl ExactSizeIterator<Item = &'a (dyn ToSql + Sync)> + 'a {
|
||||
) -> impl ExactSizeIterator<Item = &'a (dyn ToSql + Sync)> + Clone + 'a {
|
||||
s.iter().map(|s| *s as _)
|
||||
}
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
use std::ffi::CStr;
|
||||
use std::future::Future;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
@@ -5,9 +6,9 @@ use std::sync::Arc;
|
||||
use bytes::Bytes;
|
||||
use fallible_iterator::FallibleIterator;
|
||||
use futures_util::{TryStreamExt, pin_mut};
|
||||
use postgres_protocol2::CSafeStr;
|
||||
use postgres_protocol2::message::backend::Message;
|
||||
use postgres_protocol2::message::frontend;
|
||||
use tracing::debug;
|
||||
|
||||
use crate::client::{CachedTypeInfo, InnerClient};
|
||||
use crate::codec::FrontendMessage;
|
||||
@@ -15,7 +16,7 @@ use crate::connection::RequestMessages;
|
||||
use crate::types::{Kind, Oid, Type};
|
||||
use crate::{Column, Error, Statement, query, slice_iter};
|
||||
|
||||
pub(crate) const TYPEINFO_QUERY: &str = "\
|
||||
pub(crate) const TYPEINFO_QUERY: &CStr = c"\
|
||||
SELECT t.typname, t.typtype, t.typelem, r.rngsubtype, t.typbasetype, n.nspname, t.typrelid
|
||||
FROM pg_catalog.pg_type t
|
||||
LEFT OUTER JOIN pg_catalog.pg_range r ON r.rngtypid = t.oid
|
||||
@@ -25,11 +26,11 @@ WHERE t.oid = $1
|
||||
|
||||
async fn prepare_typecheck(
|
||||
client: &Arc<InnerClient>,
|
||||
name: &'static str,
|
||||
query: &str,
|
||||
name: &'static CStr,
|
||||
query: &CSafeStr,
|
||||
types: &[Type],
|
||||
) -> Result<Statement, Error> {
|
||||
let buf = encode(client, name, query, types)?;
|
||||
let buf = encode(client, name.into(), query, types)?;
|
||||
let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?;
|
||||
|
||||
match responses.next().await? {
|
||||
@@ -68,16 +69,21 @@ async fn prepare_typecheck(
|
||||
Ok(Statement::new(client, name, parameters, columns))
|
||||
}
|
||||
|
||||
fn encode(client: &InnerClient, name: &str, query: &str, types: &[Type]) -> Result<Bytes, Error> {
|
||||
if types.is_empty() {
|
||||
debug!("preparing query {}: {}", name, query);
|
||||
} else {
|
||||
debug!("preparing query {} with types {:?}: {}", name, types, query);
|
||||
}
|
||||
fn encode(
|
||||
client: &InnerClient,
|
||||
name: &CSafeStr,
|
||||
query: &CSafeStr,
|
||||
types: &[Type],
|
||||
) -> Result<Bytes, Error> {
|
||||
// if types.is_empty() {
|
||||
// debug!("preparing query {}: {}", name, query);
|
||||
// } else {
|
||||
// debug!("preparing query {} with types {:?}: {}", name, types, query);
|
||||
// }
|
||||
|
||||
client.with_buf(|buf| {
|
||||
frontend::parse(name, query, types.iter().map(Type::oid), buf).map_err(Error::encode)?;
|
||||
frontend::describe(b'S', name, buf).map_err(Error::encode)?;
|
||||
frontend::parse(name, query, types.iter().map(Type::oid), buf);
|
||||
frontend::describe(b'S', name, buf);
|
||||
frontend::sync(buf);
|
||||
Ok(buf.split().freeze())
|
||||
})
|
||||
@@ -154,8 +160,8 @@ async fn typeinfo_statement(
|
||||
return Ok(stmt.clone());
|
||||
}
|
||||
|
||||
let typeinfo = "neon_proxy_typeinfo";
|
||||
let stmt = prepare_typecheck(client, typeinfo, TYPEINFO_QUERY, &[]).await?;
|
||||
let typeinfo = c"neon_proxy_typeinfo";
|
||||
let stmt = prepare_typecheck(client, typeinfo, TYPEINFO_QUERY.into(), &[]).await?;
|
||||
|
||||
typecache.typeinfo = Some(stmt.clone());
|
||||
Ok(stmt)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use std::fmt;
|
||||
use std::iter;
|
||||
use std::marker::PhantomPinned;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
@@ -8,25 +8,16 @@ use bytes::{BufMut, Bytes, BytesMut};
|
||||
use fallible_iterator::FallibleIterator;
|
||||
use futures_util::{Stream, ready};
|
||||
use pin_project_lite::pin_project;
|
||||
use postgres_protocol2::CSafeStr;
|
||||
use postgres_protocol2::message::backend::Message;
|
||||
use postgres_protocol2::message::frontend;
|
||||
use postgres_types2::{Format, ToSql, Type};
|
||||
use tracing::debug;
|
||||
|
||||
use crate::client::{InnerClient, Responses};
|
||||
use crate::codec::FrontendMessage;
|
||||
use crate::connection::RequestMessages;
|
||||
use crate::types::IsNull;
|
||||
use crate::{Column, Error, ReadyForQueryStatus, Row, Statement};
|
||||
|
||||
struct BorrowToSqlParamsDebug<'a>(&'a [&'a (dyn ToSql + Sync)]);
|
||||
|
||||
impl fmt::Debug for BorrowToSqlParamsDebug<'_> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.debug_list().entries(self.0.iter()).finish()
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn query<'a, I>(
|
||||
client: &InnerClient,
|
||||
statement: Statement,
|
||||
@@ -34,19 +25,9 @@ pub async fn query<'a, I>(
|
||||
) -> Result<RowStream, Error>
|
||||
where
|
||||
I: IntoIterator<Item = &'a (dyn ToSql + Sync)>,
|
||||
I::IntoIter: ExactSizeIterator,
|
||||
I::IntoIter: ExactSizeIterator + Clone,
|
||||
{
|
||||
let buf = if tracing::enabled!(tracing::Level::DEBUG) {
|
||||
let params = params.into_iter().collect::<Vec<_>>();
|
||||
debug!(
|
||||
"executing statement {} with parameters: {:?}",
|
||||
statement.name(),
|
||||
BorrowToSqlParamsDebug(params.as_slice()),
|
||||
);
|
||||
encode(client, &statement, params)?
|
||||
} else {
|
||||
encode(client, &statement, params)?
|
||||
};
|
||||
let buf = encode(client, &statement, params)?;
|
||||
let responses = start(client, buf).await?;
|
||||
Ok(RowStream {
|
||||
statement,
|
||||
@@ -68,45 +49,45 @@ where
|
||||
I: IntoIterator<Item = Option<S>>,
|
||||
I::IntoIter: ExactSizeIterator,
|
||||
{
|
||||
let query = CSafeStr::new(query.as_bytes()).map_err(Error::encode)?;
|
||||
let params = params.into_iter();
|
||||
|
||||
let portal = c"".into(); // unnamed portal
|
||||
let statement = c"".into(); // unnamed prepared statement
|
||||
|
||||
let buf = client.with_buf(|buf| {
|
||||
frontend::parse(
|
||||
"", // unnamed prepared statement
|
||||
query, // query to parse
|
||||
std::iter::empty(), // give no type info
|
||||
statement,
|
||||
query, // query to parse
|
||||
iter::empty(), // give no type info
|
||||
buf,
|
||||
)
|
||||
.map_err(Error::encode)?;
|
||||
frontend::describe(b'S', "", buf).map_err(Error::encode)?;
|
||||
// Bind, pass params as text, retrieve as binary
|
||||
match frontend::bind(
|
||||
"", // empty string selects the unnamed portal
|
||||
"", // unnamed prepared statement
|
||||
std::iter::empty(), // all parameters use the default format (text)
|
||||
);
|
||||
frontend::describe(b'S', statement, buf);
|
||||
|
||||
// Bind, pass params as text, retrieve as test
|
||||
frontend::bind(
|
||||
portal,
|
||||
statement,
|
||||
iter::empty(), // all parameters use the default format (text)
|
||||
params,
|
||||
|param, buf| match param {
|
||||
Some(param) => {
|
||||
buf.put_slice(param.as_ref().as_bytes());
|
||||
Ok(postgres_protocol2::IsNull::No)
|
||||
postgres_protocol2::IsNull::No
|
||||
}
|
||||
None => Ok(postgres_protocol2::IsNull::Yes),
|
||||
None => postgres_protocol2::IsNull::Yes,
|
||||
},
|
||||
Some(0), // all text
|
||||
buf,
|
||||
) {
|
||||
Ok(()) => Ok(()),
|
||||
Err(frontend::BindError::Conversion(e)) => Err(Error::to_sql(e, 0)),
|
||||
Err(frontend::BindError::Serialization(e)) => Err(Error::encode(e)),
|
||||
}?;
|
||||
);
|
||||
|
||||
// Execute
|
||||
frontend::execute("", 0, buf).map_err(Error::encode)?;
|
||||
frontend::execute(portal, 0, buf);
|
||||
// Sync
|
||||
frontend::sync(buf);
|
||||
|
||||
Ok(buf.split().freeze())
|
||||
})?;
|
||||
buf.split().freeze()
|
||||
});
|
||||
|
||||
// now read the responses
|
||||
let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?;
|
||||
@@ -173,11 +154,13 @@ async fn start(client: &InnerClient, buf: Bytes) -> Result<Responses, Error> {
|
||||
pub fn encode<'a, I>(client: &InnerClient, statement: &Statement, params: I) -> Result<Bytes, Error>
|
||||
where
|
||||
I: IntoIterator<Item = &'a (dyn ToSql + Sync)>,
|
||||
I::IntoIter: ExactSizeIterator,
|
||||
I::IntoIter: ExactSizeIterator + Clone,
|
||||
{
|
||||
let portal = c"".into(); // unnamed portal
|
||||
|
||||
client.with_buf(|buf| {
|
||||
encode_bind(statement, params, "", buf)?;
|
||||
frontend::execute("", 0, buf).map_err(Error::encode)?;
|
||||
encode_bind(statement, params, portal, buf)?;
|
||||
frontend::execute(portal, 0, buf);
|
||||
frontend::sync(buf);
|
||||
Ok(buf.split().freeze())
|
||||
})
|
||||
@@ -186,15 +169,15 @@ where
|
||||
pub fn encode_bind<'a, I>(
|
||||
statement: &Statement,
|
||||
params: I,
|
||||
portal: &str,
|
||||
portal: &CSafeStr,
|
||||
buf: &mut BytesMut,
|
||||
) -> Result<(), Error>
|
||||
where
|
||||
I: IntoIterator<Item = &'a (dyn ToSql + Sync)>,
|
||||
I::IntoIter: ExactSizeIterator,
|
||||
I::IntoIter: ExactSizeIterator + Clone,
|
||||
{
|
||||
let param_types = statement.params();
|
||||
let params = params.into_iter();
|
||||
let params = iter::zip(params.into_iter(), param_types);
|
||||
|
||||
assert!(
|
||||
param_types.len() == params.len(),
|
||||
@@ -203,35 +186,24 @@ where
|
||||
params.len()
|
||||
);
|
||||
|
||||
let (param_formats, params): (Vec<_>, Vec<_>) = params
|
||||
.zip(param_types.iter())
|
||||
.map(|(p, ty)| (p.encode_format(ty) as i16, p))
|
||||
.unzip();
|
||||
// check encodings
|
||||
for (i, (p, ty)) in params.clone().enumerate() {
|
||||
p.check(ty).map_err(|e| Error::to_sql(Box::new(e), i))?
|
||||
}
|
||||
|
||||
let params = params.into_iter();
|
||||
let param_formats = params.clone().map(|(p, ty)| p.encode_format(ty) as i16);
|
||||
|
||||
let mut error_idx = 0;
|
||||
let r = frontend::bind(
|
||||
frontend::bind(
|
||||
portal,
|
||||
statement.name(),
|
||||
param_formats,
|
||||
params.zip(param_types).enumerate(),
|
||||
|(idx, (param, ty)), buf| match param.to_sql_checked(ty, buf) {
|
||||
Ok(IsNull::No) => Ok(postgres_protocol2::IsNull::No),
|
||||
Ok(IsNull::Yes) => Ok(postgres_protocol2::IsNull::Yes),
|
||||
Err(e) => {
|
||||
error_idx = idx;
|
||||
Err(e)
|
||||
}
|
||||
},
|
||||
params,
|
||||
|(param, ty), buf| param.to_sql(ty, buf),
|
||||
Some(1),
|
||||
buf,
|
||||
);
|
||||
match r {
|
||||
Ok(()) => Ok(()),
|
||||
Err(frontend::BindError::Conversion(e)) => Err(Error::to_sql(e, error_idx)),
|
||||
Err(frontend::BindError::Serialization(e)) => Err(Error::encode(e)),
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pin_project! {
|
||||
|
||||
@@ -7,6 +7,7 @@ use bytes::Bytes;
|
||||
use fallible_iterator::FallibleIterator;
|
||||
use futures_util::{Stream, ready};
|
||||
use pin_project_lite::pin_project;
|
||||
use postgres_protocol2::CSafeStr;
|
||||
use postgres_protocol2::message::backend::Message;
|
||||
use postgres_protocol2::message::frontend;
|
||||
use tracing::debug;
|
||||
@@ -69,8 +70,9 @@ pub async fn batch_execute(
|
||||
}
|
||||
|
||||
pub(crate) fn encode(client: &InnerClient, query: &str) -> Result<Bytes, Error> {
|
||||
let query = CSafeStr::new(query.as_bytes()).map_err(Error::encode)?;
|
||||
client.with_buf(|buf| {
|
||||
frontend::query(query, buf).map_err(Error::encode)?;
|
||||
frontend::query(query, buf);
|
||||
Ok(buf.split().freeze())
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
use std::ffi::CStr;
|
||||
use std::fmt;
|
||||
use std::sync::{Arc, Weak};
|
||||
|
||||
use postgres_protocol2::Oid;
|
||||
use postgres_protocol2::message::backend::Field;
|
||||
use postgres_protocol2::message::frontend;
|
||||
use postgres_protocol2::{CSafeStr, Oid};
|
||||
|
||||
use crate::client::InnerClient;
|
||||
use crate::codec::FrontendMessage;
|
||||
@@ -12,7 +13,7 @@ use crate::types::Type;
|
||||
|
||||
struct StatementInner {
|
||||
client: Weak<InnerClient>,
|
||||
name: &'static str,
|
||||
name: &'static CStr,
|
||||
params: Vec<Type>,
|
||||
columns: Vec<Column>,
|
||||
}
|
||||
@@ -21,7 +22,7 @@ impl Drop for StatementInner {
|
||||
fn drop(&mut self) {
|
||||
if let Some(client) = self.client.upgrade() {
|
||||
let buf = client.with_buf(|buf| {
|
||||
frontend::close(b'S', self.name, buf).unwrap();
|
||||
frontend::close(b'S', self.name.into(), buf);
|
||||
frontend::sync(buf);
|
||||
buf.split().freeze()
|
||||
});
|
||||
@@ -39,7 +40,7 @@ pub struct Statement(Arc<StatementInner>);
|
||||
impl Statement {
|
||||
pub(crate) fn new(
|
||||
inner: &Arc<InnerClient>,
|
||||
name: &'static str,
|
||||
name: &'static CStr,
|
||||
params: Vec<Type>,
|
||||
columns: Vec<Column>,
|
||||
) -> Statement {
|
||||
@@ -54,14 +55,14 @@ impl Statement {
|
||||
pub(crate) fn new_anonymous(params: Vec<Type>, columns: Vec<Column>) -> Statement {
|
||||
Statement(Arc::new(StatementInner {
|
||||
client: Weak::new(),
|
||||
name: "<anonymous>",
|
||||
name: c"<anonymous>",
|
||||
params,
|
||||
columns,
|
||||
}))
|
||||
}
|
||||
|
||||
pub(crate) fn name(&self) -> &str {
|
||||
self.0.name
|
||||
pub(crate) fn name(&self) -> &CSafeStr {
|
||||
self.0.name.into()
|
||||
}
|
||||
|
||||
/// Returns the expected types of the statement's parameters.
|
||||
|
||||
@@ -21,7 +21,7 @@ impl Drop for Transaction<'_> {
|
||||
}
|
||||
|
||||
let buf = self.client.inner().with_buf(|buf| {
|
||||
frontend::query("ROLLBACK", buf).unwrap();
|
||||
frontend::query(c"ROLLBACK".into(), buf);
|
||||
buf.split().freeze()
|
||||
});
|
||||
let _ = self
|
||||
|
||||
Reference in New Issue
Block a user