This commit is contained in:
Conrad Ludgate
2023-07-27 13:12:16 +01:00
parent 4fb5cdbdb8
commit 17627e8023
12 changed files with 451 additions and 376 deletions

1
Cargo.lock generated
View File

@@ -3074,6 +3074,7 @@ dependencies = [
"thiserror",
"tls-listener",
"tokio",
"tokio-native-tls",
"tokio-postgres",
"tokio-postgres-rustls",
"tokio-rustls 0.23.4",

View File

@@ -65,6 +65,7 @@ webpki-roots.workspace = true
x509-parser.workspace = true
native-tls.workspace = true
postgres-native-tls.workspace = true
tokio-native-tls = "0.3.1"
workspace_hack.workspace = true
tokio-util.workspace = true

View File

@@ -26,18 +26,14 @@ use tokio_postgres::RowStream;
use tokio_postgres::Statement;
use url::Url;
use crate::http::sql_over_http::codec::FrontendMessage;
use crate::http::sql_over_http::connection::RequestMessages;
use crate::pg_client;
use crate::pg_client::codec::FrontendMessage;
use crate::pg_client::connection;
use crate::pg_client::connection::RequestMessages;
use super::conn_pool::ConnInfo;
use super::conn_pool::GlobalConnPool;
mod codec;
mod connection;
mod error;
// mod prepare;
// mod pg_type;
#[derive(serde::Deserialize)]
struct QueryData {
query: String,
@@ -374,75 +370,20 @@ async fn query_raw_txt<'a, St, T>(
conn: &mut connection::Connection<St, T>,
query: String,
params: Vec<Option<String>>,
) -> Result<Vec<Column>, error::Error>
) -> Result<Vec<Column>, pg_client::error::Error>
where
St: AsyncRead + AsyncWrite + Unpin,
T: AsyncRead + AsyncWrite + Unpin,
{
use postgres_protocol::message::backend::Message;
use postgres_protocol::message::frontend;
let params_len = params.len();
let params = params.into_iter();
{
let buf = &mut conn.buf;
// Parse, anonymous portal
frontend::parse("", query.as_str(), std::iter::empty(), buf)
.map_err(error::Error::encode)?;
// Bind, pass params as text, retrieve as binary
match frontend::bind(
"", // empty string selects the unnamed portal
"", // empty string selects the unnamed prepared statement
std::iter::empty(), // all parameters use the default format (text)
params,
|param, buf| match param {
Some(param) => {
buf.put_slice(param.as_bytes());
Ok(postgres_protocol::IsNull::No)
}
None => Ok(postgres_protocol::IsNull::Yes),
},
Some(0), // all text
buf,
) {
Ok(()) => Ok(()),
Err(frontend::BindError::Conversion(e)) => Err(error::Error::encode(
std::io::Error::new(ErrorKind::Other, e),
)),
Err(frontend::BindError::Serialization(e)) => Err(error::Error::encode(e)),
}?;
// Describe portal to typecast results
frontend::describe(b'P', "", buf).map_err(error::Error::encode)?;
// Execute
frontend::execute("", 0, buf).map_err(error::Error::encode)?;
// Sync
frontend::sync(buf);
}
conn.send().await?;
// now read the responses
match conn.next_message().await? {
Message::ParseComplete => {}
_ => return Err(error::Error::unexpected_message()),
}
match conn.next_message().await? {
Message::BindComplete => {}
_ => return Err(error::Error::unexpected_message()),
}
let row_description = match conn.next_message().await? {
Message::RowDescription(body) => Some(body),
Message::NoData => None,
_ => return Err(error::Error::unexpected_message()),
};
conn.prepare_and_execute("", "", query.as_str(), params)?;
conn.sync().await?;
let mut columns = vec![];
if let Some(row_description) = row_description {
let mut it = row_description.fields();
while let Some(field) = it.next().map_err(error::Error::parse)? {
if let Some((desc, rows)) = conn.stream_query_results().await? {
let mut it = desc.fields();
while let Some(field) = it.next().map_err(pg_client::error::Error::parse)? {
let type_ = Type::from_oid(field.type_oid());
// let column = Column::new(field.name().to_string(), type_, field);
columns.push(Column {
@@ -452,8 +393,6 @@ where
}
}
// let statement = Statement::new_text(&self.inner, "".to_owned(), parameters, columns);
Ok(columns)
}
@@ -515,7 +454,7 @@ fn pg_text_row_to_json2(
let pg_value = range
.map(|r| {
std::str::from_utf8(&row.buffer()[r])
.map_err(|e| error::Error::from_sql(e.into(), i))
.map_err(|e| pg_client::error::Error::from_sql(e.into(), i))
})
.transpose()?;
// let pg_value = row.as_text(i)?;

View File

@@ -1,108 +0,0 @@
use bytes::{Bytes, BytesMut};
use fallible_iterator::FallibleIterator;
use postgres_protocol::message::backend;
use std::io;
use tokio_util::codec::{Decoder, Encoder};
pub enum FrontendMessage {
Raw(Bytes),
// CopyData(CopyData<Box<dyn Buf + Send>>),
}
pub enum BackendMessage {
Normal {
messages: BackendMessages,
request_complete: bool,
},
Async(backend::Message),
}
pub struct BackendMessages(pub BytesMut);
impl BackendMessages {
pub fn empty() -> BackendMessages {
BackendMessages(BytesMut::new())
}
}
impl FallibleIterator for BackendMessages {
type Item = backend::Message;
type Error = io::Error;
fn next(&mut self) -> io::Result<Option<backend::Message>> {
backend::Message::parse(&mut self.0)
}
}
pub struct PostgresCodec {
pub max_message_size: Option<usize>,
}
impl Encoder<FrontendMessage> for PostgresCodec {
type Error = io::Error;
fn encode(&mut self, item: FrontendMessage, dst: &mut BytesMut) -> io::Result<()> {
match item {
FrontendMessage::Raw(buf) => dst.extend_from_slice(&buf),
// FrontendMessage::CopyData(data) => data.write(dst),
}
Ok(())
}
}
impl Decoder for PostgresCodec {
type Item = BackendMessage;
type Error = io::Error;
fn decode(&mut self, src: &mut BytesMut) -> Result<Option<BackendMessage>, io::Error> {
let mut idx = 0;
let mut request_complete = false;
while let Some(header) = backend::Header::parse(&src[idx..])? {
let len = header.len() as usize + 1;
if src[idx..].len() < len {
break;
}
if let Some(max) = self.max_message_size {
if len > max {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"message too large",
));
}
}
match header.tag() {
backend::NOTICE_RESPONSE_TAG
| backend::NOTIFICATION_RESPONSE_TAG
| backend::PARAMETER_STATUS_TAG => {
if idx == 0 {
let message = backend::Message::parse(src)?.unwrap();
return Ok(Some(BackendMessage::Async(message)));
} else {
break;
}
}
_ => {}
}
idx += len;
if header.tag() == backend::READY_FOR_QUERY_TAG {
request_complete = true;
break;
}
}
if idx == 0 {
Ok(None)
} else {
Ok(Some(BackendMessage::Normal {
messages: BackendMessages(src.split_to(idx)),
request_complete,
}))
}
}
}

View File

@@ -1,194 +0,0 @@
use super::codec::{BackendMessage, BackendMessages, FrontendMessage, PostgresCodec};
use super::error::Error;
use bytes::BytesMut;
use fallible_iterator::FallibleIterator;
use futures::channel::mpsc;
use futures::SinkExt;
use futures::{Sink, StreamExt};
use postgres_protocol::message::backend::Message;
use std::collections::{HashMap, VecDeque};
use std::future::poll_fn;
use std::pin::Pin;
use std::task::{ready, Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_postgres::maybe_tls_stream::MaybeTlsStream;
use tokio_util::codec::Framed;
use tracing::trace;
pub enum RequestMessages {
Single(FrontendMessage),
}
pub struct Request {
pub messages: RequestMessages,
pub sender: mpsc::Sender<BackendMessages>,
}
pub struct Response {
sender: mpsc::Sender<BackendMessages>,
}
// #[derive(PartialEq, Debug)]
// enum State {
// Active,
// Terminating,
// Closing,
// }
/// A connection to a PostgreSQL database.
///
/// This is one half of what is returned when a new connection is established. It performs the actual IO with the
/// server, and should generally be spawned off onto an executor to run in the background.
///
/// `Connection` implements `Future`, and only resolves when the connection is closed, either because a fatal error has
/// occurred, or because its associated `Client` has dropped and all outstanding work has completed.
#[must_use = "futures do nothing unless polled"]
pub struct Connection<S, T> {
/// HACK: we need this in the Neon Proxy.
pub stream: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
/// HACK: we need this in the Neon Proxy to forward params.
pub parameters: HashMap<String, String>,
// receiver: mpsc::UnboundedReceiver<Request>,
pending_request: Option<RequestMessages>,
pending_responses: VecDeque<(BackendMessages, bool)>,
pub buf: BytesMut,
// responses: VecDeque<Response>,
// state: State,
}
impl<S, T> Connection<S, T>
where
S: AsyncRead + AsyncWrite + Unpin,
T: AsyncRead + AsyncWrite + Unpin,
{
pub(crate) fn new(
stream: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
pending_responses: VecDeque<(BackendMessages, bool)>,
parameters: HashMap<String, String>,
// receiver: mpsc::UnboundedReceiver<Request>,
) -> Connection<S, T> {
Connection {
stream,
parameters,
// receiver,
pending_request: None,
pending_responses,
buf: BytesMut::new(),
// responses: VecDeque::new(),
// state: State::Active,
}
}
pub async fn send(&mut self) -> Result<(), Error> {
poll_fn(|cx| self.poll_send(cx)).await?;
let request = FrontendMessage::Raw(self.buf.split().freeze());
self.stream.start_send_unpin(request).map_err(Error::io)?;
poll_fn(|cx| self.poll_flush(cx)).await
}
pub async fn next_response(&mut self) -> Result<(BackendMessages, bool), Error> {
match self.pending_responses.pop_front() {
Some((a, b)) => Ok((a, b)),
None => poll_fn(|cx| self.poll_read(cx)).await,
}
}
pub async fn next_message(&mut self) -> Result<Message, Error> {
loop {
let (mut messages, complete) = self.next_response().await?;
if let Some(message) = messages.next().map_err(Error::parse)? {
self.pending_responses.push_front((messages, complete));
break Ok(message);
}
if complete {
break Err(Error::unexpected_message());
}
}
}
fn poll_response(
&mut self,
cx: &mut Context<'_>,
) -> Poll<Option<Result<BackendMessage, Error>>> {
self.stream
.poll_next_unpin(cx)
.map(|o| o.map(|r| r.map_err(Error::io)))
}
fn poll_read(&mut self, cx: &mut Context<'_>) -> Poll<Result<(BackendMessages, bool), Error>> {
loop {
let message = match ready!(self.poll_response(cx)?) {
Some(message) => message,
None => return Poll::Ready(Err(Error::closed())),
};
match message {
BackendMessage::Async(Message::NoticeResponse(body)) => {
// TODO: log this
// let error = DbError::parse(&mut body.fields()).map_err(Error::parse)?;
// return Ok(Some(AsyncMessage::Notice(error)));
continue;
}
BackendMessage::Async(Message::NotificationResponse(body)) => {
// TODO: log this
// let notification = Notification {
// process_id: body.process_id(),
// channel: body.channel().map_err(Error::parse)?.to_string(),
// payload: body.message().map_err(Error::parse)?.to_string(),
// };
// return Ok(Some(AsyncMessage::Notification(notification)));
continue;
}
BackendMessage::Async(Message::ParameterStatus(body)) => {
self.parameters.insert(
body.name().map_err(Error::parse)?.to_string(),
body.value().map_err(Error::parse)?.to_string(),
);
continue;
}
BackendMessage::Async(_) => unreachable!(),
BackendMessage::Normal {
messages,
request_complete,
} => return Poll::Ready(Ok((messages, request_complete))),
};
}
}
fn poll_shutdown(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
match Pin::new(&mut self.stream)
.poll_close(cx)
.map_err(Error::io)?
{
Poll::Ready(()) => {
trace!("poll_shutdown: complete");
Poll::Ready(Ok(()))
}
Poll::Pending => {
trace!("poll_shutdown: waiting on socket");
Poll::Pending
}
}
}
/// Returns the value of a runtime parameter for this connection.
pub fn parameter(&self, name: &str) -> Option<&str> {
self.parameters.get(name).map(|s| &**s)
}
fn poll_send(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
if let Poll::Ready(msg) = self.poll_read(cx)? {
self.pending_responses.push_back(msg);
};
self.stream.poll_ready_unpin(cx).map_err(Error::io)
}
fn poll_flush(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
if let Poll::Ready(msg) = self.poll_read(cx)? {
self.pending_responses.push_back(msg);
};
self.stream.poll_flush_unpin(cx).map_err(Error::io)
}
}

View File

@@ -22,6 +22,7 @@ pub mod scram;
pub mod stream;
pub mod url;
pub mod waiters;
pub mod pg_client;
/// Handle unix signals appropriately.
pub async fn handle_signals(token: CancellationToken) -> anyhow::Result<Infallible> {

View File

@@ -0,0 +1,43 @@
use bytes::{Bytes, BytesMut};
use fallible_iterator::FallibleIterator;
use postgres_protocol::message::backend::{self, Message};
use std::io;
use tokio_util::codec::{Decoder, Encoder};
pub struct FrontendMessage(pub Bytes);
pub struct BackendMessages(pub BytesMut);
impl BackendMessages {
pub fn empty() -> BackendMessages {
BackendMessages(BytesMut::new())
}
}
impl FallibleIterator for BackendMessages {
type Item = backend::Message;
type Error = io::Error;
fn next(&mut self) -> io::Result<Option<backend::Message>> {
backend::Message::parse(&mut self.0)
}
}
pub struct PostgresCodec;
impl Encoder<FrontendMessage> for PostgresCodec {
type Error = io::Error;
fn encode(&mut self, item: FrontendMessage, dst: &mut BytesMut) -> io::Result<()> {
dst.extend_from_slice(&item.0);
Ok(())
}
}
impl Decoder for PostgresCodec {
type Item = Message;
type Error = io::Error;
fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Message>, io::Error> {
Message::parse(src)
}
}

View File

@@ -0,0 +1,369 @@
use super::codec::{BackendMessages, FrontendMessage, PostgresCodec};
use super::error::Error;
use bytes::{BufMut, BytesMut};
use fallible_iterator::FallibleIterator;
use futures::channel::mpsc;
use futures::{Sink, StreamExt};
use futures::{SinkExt, Stream};
use postgres_protocol::authentication;
use postgres_protocol::message::backend::{
BackendKeyDataBody, DataRowBody, Message, ReadyForQueryBody, RowDescriptionBody,
};
use postgres_protocol::message::frontend;
use std::collections::{HashMap, VecDeque};
use std::future::poll_fn;
use std::pin::Pin;
use std::task::{ready, Context, Poll};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio::net::TcpStream;
use tokio_native_tls::{native_tls, TlsConnector, TlsStream};
use tokio_postgres::maybe_tls_stream::MaybeTlsStream;
use tokio_util::codec::Framed;
pub enum RequestMessages {
Single(FrontendMessage),
}
pub struct Request {
pub messages: RequestMessages,
pub sender: mpsc::Sender<BackendMessages>,
}
pub struct Response {
sender: mpsc::Sender<BackendMessages>,
}
/// A connection to a PostgreSQL database.
pub struct RawConnection<S, T> {
stream: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
pending_responses: VecDeque<Message>,
pub buf: BytesMut,
}
// enum MaybeTlsStream {
// NoTls(TcpStream),
// Tls(TlsStream<TcpStream>),
// }
// impl Unpin for MaybeTlsStream {}
// impl AsyncRead for MaybeTlsStream {
// fn poll_read(
// self: Pin<&mut Self>,
// cx: &mut Context<'_>,
// buf: &mut tokio::io::ReadBuf<'_>,
// ) -> Poll<std::io::Result<()>> {
// match self.get_mut() {
// MaybeTlsStream::NoTls(no_tls) => Pin::new(no_tls).poll_read(cx, buf),
// MaybeTlsStream::Tls(tls) => Pin::new(tls).poll_read(cx, buf),
// }
// }
// }
// impl AsyncWrite for MaybeTlsStream {
// fn poll_write(
// self: Pin<&mut Self>,
// cx: &mut Context<'_>,
// buf: &[u8],
// ) -> Poll<Result<usize, std::io::Error>> {
// match self.get_mut() {
// MaybeTlsStream::NoTls(no_tls) => Pin::new(no_tls).poll_write(cx, buf),
// MaybeTlsStream::Tls(tls) => Pin::new(tls).poll_write(cx, buf),
// }
// }
// fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
// match self.get_mut() {
// MaybeTlsStream::NoTls(no_tls) => Pin::new(no_tls).poll_flush(cx),
// MaybeTlsStream::Tls(tls) => Pin::new(tls).poll_flush(cx),
// }
// }
// fn poll_shutdown(
// self: Pin<&mut Self>,
// cx: &mut Context<'_>,
// ) -> Poll<Result<(), std::io::Error>> {
// match self.get_mut() {
// MaybeTlsStream::NoTls(no_tls) => Pin::new(no_tls).poll_shutdown(cx),
// MaybeTlsStream::Tls(tls) => Pin::new(tls).poll_shutdown(cx),
// }
// }
// fn poll_write_vectored(
// self: Pin<&mut Self>,
// cx: &mut Context<'_>,
// bufs: &[std::io::IoSlice<'_>],
// ) -> Poll<Result<usize, std::io::Error>> {
// match self.get_mut() {
// MaybeTlsStream::NoTls(no_tls) => Pin::new(no_tls).poll_write_vectored(cx, bufs),
// MaybeTlsStream::Tls(tls) => Pin::new(tls).poll_write_vectored(cx, bufs),
// }
// }
// fn is_write_vectored(&self) -> bool {
// match self {
// MaybeTlsStream::NoTls(no_tls) => no_tls.is_write_vectored(),
// MaybeTlsStream::Tls(tls) => tls.is_write_vectored(),
// }
// }
// }
impl<S: AsyncRead + AsyncWrite + Unpin, T: AsyncRead + AsyncWrite + Unpin> RawConnection<S, T> {
// pub(crate) async fn connect(
// mut stream: TcpStream,
// tls_domain: Option<&str>,
// ) -> Result<RawConnection<S, T>, Error> {
// let mut buf = BytesMut::new();
// let stream = if let Some(tls_domain) = tls_domain {
// frontend::ssl_request(&mut buf);
// stream
// .write_all_buf(&mut buf.split().freeze())
// .await
// .unwrap();
// let bit = stream.read_u8().await.map_err(Error::io)?;
// if bit != b'S' {
// return Err(Error::closed());
// }
// let tls = native_tls::TlsConnector::new().map_err(Error::tls)?;
// let tls = TlsConnector::from(tls)
// .connect(tls_domain, stream)
// .await
// .map_err(Error::tls)?;
// MaybeTlsStream::Tls(tls)
// } else {
// MaybeTlsStream::Raw(stream)
// };
// Ok(RawConnection::new(Framed::new(stream, PostgresCodec), buf))
// }
pub fn new(
stream: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
buf: BytesMut,
) -> RawConnection<S, T> {
RawConnection {
stream,
pending_responses: VecDeque::new(),
buf,
}
}
pub async fn send(&mut self) -> Result<(), Error> {
poll_fn(|cx| self.poll_send(cx)).await?;
let request = FrontendMessage(self.buf.split().freeze());
self.stream.start_send_unpin(request).map_err(Error::io)?;
poll_fn(|cx| self.poll_flush(cx)).await
}
pub async fn next_message(&mut self) -> Result<Message, Error> {
match self.pending_responses.pop_front() {
Some(message) => Ok(message),
None => poll_fn(|cx| self.poll_read(cx)).await,
}
}
fn poll_read(&mut self, cx: &mut Context<'_>) -> Poll<Result<Message, Error>> {
let message = match ready!(self.stream.poll_next_unpin(cx)?) {
Some(message) => message,
None => return Poll::Ready(Err(Error::closed())),
};
Poll::Ready(Ok(message))
}
fn poll_shutdown(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
Pin::new(&mut self.stream).poll_close(cx).map_err(Error::io)
}
fn poll_send(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
if let Poll::Ready(msg) = self.poll_read(cx)? {
self.pending_responses.push_back(msg);
};
self.stream.poll_ready_unpin(cx).map_err(Error::io)
}
fn poll_flush(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
if let Poll::Ready(msg) = self.poll_read(cx)? {
self.pending_responses.push_back(msg);
};
self.stream.poll_flush_unpin(cx).map_err(Error::io)
}
}
pub struct Connection<S, T> {
raw: RawConnection<S, T>,
key: BackendKeyDataBody,
}
impl<S: AsyncRead + AsyncWrite + Unpin, T: AsyncRead + AsyncWrite + Unpin> Connection<S, T> {
// pub async fn auth_sasl_scram<'a, I>(
// mut raw: RawConnection<S, T>,
// params: I,
// password: &[u8],
// ) -> Result<Self, Error>
// where
// I: IntoIterator<Item = (&'a str, &'a str)>,
// {
// // send a startup message
// frontend::startup_message(params, &mut raw.buf).unwrap();
// raw.send().await?;
// // expect sasl authentication message
// let Message::AuthenticationSasl(body) = raw.next_message().await? else { return Err(Error::expecting("sasl authentication")) };
// // expect support for SCRAM_SHA_256
// if body
// .mechanisms()
// .find(|&x| Ok(x == authentication::sasl::SCRAM_SHA_256))?
// .is_none()
// {
// return Err(Error::expecting("SCRAM-SHA-256 auth"));
// }
// // initiate SCRAM_SHA_256 authentication without channel binding
// let auth = authentication::sasl::ChannelBinding::unrequested();
// let mut scram = authentication::sasl::ScramSha256::new(password, auth);
// frontend::sasl_initial_response(
// authentication::sasl::SCRAM_SHA_256,
// scram.message(),
// &mut raw.buf,
// )
// .unwrap();
// raw.send().await?;
// // expect sasl continue
// let Message::AuthenticationSaslContinue(b) = raw.next_message().await? else { return Err(Error::expecting("auth continue")) };
// scram.update(b.data()).unwrap();
// // continue sasl
// frontend::sasl_response(scram.message(), &mut raw.buf).unwrap();
// raw.send().await?;
// // expect sasl final
// let Message::AuthenticationSaslFinal(b) = raw.next_message().await? else { return Err(Error::expecting("auth final")) };
// scram.finish(b.data()).unwrap();
// // expect auth ok
// let Message::AuthenticationOk = raw.next_message().await? else { return Err(Error::expecting("auth ok")) };
// // expect connection accepted
// let key = loop {
// match raw.next_message().await? {
// Message::BackendKeyData(key) => break key,
// Message::ParameterStatus(_) => {}
// _ => return Err(Error::expecting("backend ready")),
// }
// };
// let Message::ReadyForQuery(b) = raw.next_message().await? else { return Err(Error::expecting("ready for query")) };
// // assert_eq!(b.status(), b'I');
// Ok(Self { raw, key })
// }
pub fn prepare_and_execute(
&mut self,
portal: &str,
name: &str,
query: &str,
params: impl IntoIterator<Item = Option<impl AsRef<str>>>,
) -> std::io::Result<()> {
self.prepare(name, query)?;
self.execute(portal, name, params)
}
pub fn prepare(&mut self, name: &str, query: &str) -> std::io::Result<()> {
frontend::parse(name, query, std::iter::empty(), &mut self.raw.buf)
}
pub fn execute(
&mut self,
portal: &str,
name: &str,
params: impl IntoIterator<Item = Option<impl AsRef<str>>>,
) -> std::io::Result<()> {
frontend::bind(
portal,
name,
std::iter::empty(), // all parameters use the default format (text)
params,
|param, buf| match param {
Some(param) => {
buf.put_slice(param.as_ref().as_bytes());
Ok(postgres_protocol::IsNull::No)
}
None => Ok(postgres_protocol::IsNull::Yes),
},
Some(0), // all text
&mut self.raw.buf,
)
.map_err(|e| match e {
frontend::BindError::Conversion(e) => std::io::Error::new(std::io::ErrorKind::Other, e),
frontend::BindError::Serialization(io) => io,
})?;
frontend::describe(b'P', portal, &mut self.raw.buf)?;
frontend::execute(portal, 0, &mut self.raw.buf)
}
pub async fn sync(&mut self) -> Result<(), Error> {
frontend::sync(&mut self.raw.buf);
self.raw.send().await
}
/// returns None if there's no row data
/// returns Some with the row description and a row stream if there is row data
pub async fn stream_query_results(
&mut self,
) -> Result<
Option<(
RowDescriptionBody,
impl Stream<Item = Result<DataRowBody, Error>> + '_,
)>,
Error,
> {
let Message::ParseComplete = self.raw.next_message().await? else { return Err(Error::expecting("parse")) };
let Message::BindComplete = self.raw.next_message().await? else { return Err(Error::expecting("bind")) };
match self.raw.next_message().await? {
Message::RowDescription(desc) => {
struct RowStream<'a, S, T> {
raw: &'a mut RawConnection<S, T>,
}
impl<S, T> Unpin for RowStream<'_, S, T> {}
impl<S: AsyncRead + AsyncWrite + Unpin, T: AsyncRead + AsyncWrite + Unpin> Stream
for RowStream<'_, S, T>
{
type Item = Result<DataRowBody, Error>;
fn poll_next(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Self::Item>> {
match ready!(self.raw.poll_read(cx)?) {
Message::DataRow(row) => Poll::Ready(Some(Ok(row))),
Message::CommandComplete(_) => Poll::Ready(None),
_ => Poll::Ready(Some(Err(Error::expecting("command completion")))),
}
}
}
Ok(Some((desc, RowStream { raw: &mut self.raw })))
}
Message::NoData => {
let Message::CommandComplete(_) = self.raw.next_message().await? else { return Err(Error::expecting("command completion")) };
Ok(None)
}
_ => Err(Error::expecting("query results")),
}
}
pub async fn wait_for_ready(&mut self) -> Result<ReadyForQueryBody, Error> {
loop {
match self.raw.next_message().await.unwrap() {
Message::ReadyForQuery(b) => break Ok(b),
_ => continue,
}
}
}
}

View File

@@ -2,11 +2,13 @@ use std::{error, fmt, io};
use fallible_iterator::FallibleIterator;
use postgres_protocol::message::backend::{ErrorFields, ErrorResponseBody};
use tokio_native_tls::native_tls;
use tokio_postgres::error::{ErrorPosition, SqlState};
#[derive(Debug, PartialEq)]
enum Kind {
Io,
Tls,
UnexpectedMessage,
FromSql(usize),
Closed,
@@ -21,7 +23,7 @@ struct ErrorInner {
}
/// An error communicating with the Postgres server.
pub struct Error(Box<ErrorInner>);
pub struct Error(ErrorInner);
impl fmt::Debug for Error {
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
@@ -36,6 +38,7 @@ impl fmt::Display for Error {
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
match &self.0.kind {
Kind::Io => fmt.write_str("error communicating with the server")?,
Kind::Tls => fmt.write_str("error establishing tls")?,
Kind::UnexpectedMessage => fmt.write_str("unexpected message from server")?,
Kind::FromSql(idx) => write!(fmt, "error deserializing column {}", idx)?,
Kind::Closed => fmt.write_str("connection closed")?,
@@ -56,6 +59,12 @@ impl error::Error for Error {
}
}
impl From<io::Error> for Error {
fn from(value: io::Error) -> Self {
Self::io(value)
}
}
impl Error {
/// Consumes the error, returning its cause.
pub fn into_source(self) -> Option<Box<dyn error::Error + Sync + Send>> {
@@ -82,7 +91,7 @@ impl Error {
}
fn new(kind: Kind, cause: Option<Box<dyn error::Error + Sync + Send>>) -> Error {
Error(Box::new(ErrorInner { kind, cause }))
Error(ErrorInner { kind, cause })
}
#[allow(clippy::needless_pass_by_value)]
@@ -105,6 +114,10 @@ impl Error {
Error::new(Kind::UnexpectedMessage, None)
}
pub(crate) fn expecting(expected: &str) -> Error {
Error::new(Kind::UnexpectedMessage, Some(expected.into()))
}
pub(crate) fn parse(e: io::Error) -> Error {
Error::new(Kind::Parse, Some(Box::new(e)))
}
@@ -116,6 +129,10 @@ impl Error {
pub(crate) fn io(e: io::Error) -> Error {
Error::new(Kind::Io, Some(Box::new(e)))
}
pub(crate) fn tls(e: native_tls::Error) -> Error {
Error::new(Kind::Tls, Some(Box::new(e)))
}
}
/// The severity of a Postgres error or notice.

View File

@@ -0,0 +1,6 @@
pub mod codec;
pub mod connection;
pub mod error;
// mod prepare;
// mod pg_type;