mirror of
https://github.com/neondatabase/neon.git
synced 2026-05-24 16:40:38 +00:00
stash2
This commit is contained in:
1
Cargo.lock
generated
1
Cargo.lock
generated
@@ -3074,6 +3074,7 @@ dependencies = [
|
||||
"thiserror",
|
||||
"tls-listener",
|
||||
"tokio",
|
||||
"tokio-native-tls",
|
||||
"tokio-postgres",
|
||||
"tokio-postgres-rustls",
|
||||
"tokio-rustls 0.23.4",
|
||||
|
||||
@@ -65,6 +65,7 @@ webpki-roots.workspace = true
|
||||
x509-parser.workspace = true
|
||||
native-tls.workspace = true
|
||||
postgres-native-tls.workspace = true
|
||||
tokio-native-tls = "0.3.1"
|
||||
|
||||
workspace_hack.workspace = true
|
||||
tokio-util.workspace = true
|
||||
|
||||
@@ -26,18 +26,14 @@ use tokio_postgres::RowStream;
|
||||
use tokio_postgres::Statement;
|
||||
use url::Url;
|
||||
|
||||
use crate::http::sql_over_http::codec::FrontendMessage;
|
||||
use crate::http::sql_over_http::connection::RequestMessages;
|
||||
use crate::pg_client;
|
||||
use crate::pg_client::codec::FrontendMessage;
|
||||
use crate::pg_client::connection;
|
||||
use crate::pg_client::connection::RequestMessages;
|
||||
|
||||
use super::conn_pool::ConnInfo;
|
||||
use super::conn_pool::GlobalConnPool;
|
||||
|
||||
mod codec;
|
||||
mod connection;
|
||||
mod error;
|
||||
// mod prepare;
|
||||
// mod pg_type;
|
||||
|
||||
#[derive(serde::Deserialize)]
|
||||
struct QueryData {
|
||||
query: String,
|
||||
@@ -374,75 +370,20 @@ async fn query_raw_txt<'a, St, T>(
|
||||
conn: &mut connection::Connection<St, T>,
|
||||
query: String,
|
||||
params: Vec<Option<String>>,
|
||||
) -> Result<Vec<Column>, error::Error>
|
||||
) -> Result<Vec<Column>, pg_client::error::Error>
|
||||
where
|
||||
St: AsyncRead + AsyncWrite + Unpin,
|
||||
T: AsyncRead + AsyncWrite + Unpin,
|
||||
{
|
||||
use postgres_protocol::message::backend::Message;
|
||||
use postgres_protocol::message::frontend;
|
||||
|
||||
let params_len = params.len();
|
||||
let params = params.into_iter();
|
||||
|
||||
{
|
||||
let buf = &mut conn.buf;
|
||||
// Parse, anonymous portal
|
||||
frontend::parse("", query.as_str(), std::iter::empty(), buf)
|
||||
.map_err(error::Error::encode)?;
|
||||
// Bind, pass params as text, retrieve as binary
|
||||
match frontend::bind(
|
||||
"", // empty string selects the unnamed portal
|
||||
"", // empty string selects the unnamed prepared statement
|
||||
std::iter::empty(), // all parameters use the default format (text)
|
||||
params,
|
||||
|param, buf| match param {
|
||||
Some(param) => {
|
||||
buf.put_slice(param.as_bytes());
|
||||
Ok(postgres_protocol::IsNull::No)
|
||||
}
|
||||
None => Ok(postgres_protocol::IsNull::Yes),
|
||||
},
|
||||
Some(0), // all text
|
||||
buf,
|
||||
) {
|
||||
Ok(()) => Ok(()),
|
||||
Err(frontend::BindError::Conversion(e)) => Err(error::Error::encode(
|
||||
std::io::Error::new(ErrorKind::Other, e),
|
||||
)),
|
||||
Err(frontend::BindError::Serialization(e)) => Err(error::Error::encode(e)),
|
||||
}?;
|
||||
|
||||
// Describe portal to typecast results
|
||||
frontend::describe(b'P', "", buf).map_err(error::Error::encode)?;
|
||||
// Execute
|
||||
frontend::execute("", 0, buf).map_err(error::Error::encode)?;
|
||||
// Sync
|
||||
frontend::sync(buf);
|
||||
}
|
||||
|
||||
conn.send().await?;
|
||||
|
||||
// now read the responses
|
||||
|
||||
match conn.next_message().await? {
|
||||
Message::ParseComplete => {}
|
||||
_ => return Err(error::Error::unexpected_message()),
|
||||
}
|
||||
match conn.next_message().await? {
|
||||
Message::BindComplete => {}
|
||||
_ => return Err(error::Error::unexpected_message()),
|
||||
}
|
||||
let row_description = match conn.next_message().await? {
|
||||
Message::RowDescription(body) => Some(body),
|
||||
Message::NoData => None,
|
||||
_ => return Err(error::Error::unexpected_message()),
|
||||
};
|
||||
conn.prepare_and_execute("", "", query.as_str(), params)?;
|
||||
conn.sync().await?;
|
||||
|
||||
let mut columns = vec![];
|
||||
if let Some(row_description) = row_description {
|
||||
let mut it = row_description.fields();
|
||||
while let Some(field) = it.next().map_err(error::Error::parse)? {
|
||||
if let Some((desc, rows)) = conn.stream_query_results().await? {
|
||||
let mut it = desc.fields();
|
||||
while let Some(field) = it.next().map_err(pg_client::error::Error::parse)? {
|
||||
let type_ = Type::from_oid(field.type_oid());
|
||||
// let column = Column::new(field.name().to_string(), type_, field);
|
||||
columns.push(Column {
|
||||
@@ -452,8 +393,6 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
// let statement = Statement::new_text(&self.inner, "".to_owned(), parameters, columns);
|
||||
|
||||
Ok(columns)
|
||||
}
|
||||
|
||||
@@ -515,7 +454,7 @@ fn pg_text_row_to_json2(
|
||||
let pg_value = range
|
||||
.map(|r| {
|
||||
std::str::from_utf8(&row.buffer()[r])
|
||||
.map_err(|e| error::Error::from_sql(e.into(), i))
|
||||
.map_err(|e| pg_client::error::Error::from_sql(e.into(), i))
|
||||
})
|
||||
.transpose()?;
|
||||
// let pg_value = row.as_text(i)?;
|
||||
|
||||
@@ -1,108 +0,0 @@
|
||||
use bytes::{Bytes, BytesMut};
|
||||
use fallible_iterator::FallibleIterator;
|
||||
use postgres_protocol::message::backend;
|
||||
use std::io;
|
||||
use tokio_util::codec::{Decoder, Encoder};
|
||||
|
||||
pub enum FrontendMessage {
|
||||
Raw(Bytes),
|
||||
// CopyData(CopyData<Box<dyn Buf + Send>>),
|
||||
}
|
||||
|
||||
pub enum BackendMessage {
|
||||
Normal {
|
||||
messages: BackendMessages,
|
||||
request_complete: bool,
|
||||
},
|
||||
Async(backend::Message),
|
||||
}
|
||||
|
||||
pub struct BackendMessages(pub BytesMut);
|
||||
|
||||
impl BackendMessages {
|
||||
pub fn empty() -> BackendMessages {
|
||||
BackendMessages(BytesMut::new())
|
||||
}
|
||||
}
|
||||
|
||||
impl FallibleIterator for BackendMessages {
|
||||
type Item = backend::Message;
|
||||
type Error = io::Error;
|
||||
|
||||
fn next(&mut self) -> io::Result<Option<backend::Message>> {
|
||||
backend::Message::parse(&mut self.0)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct PostgresCodec {
|
||||
pub max_message_size: Option<usize>,
|
||||
}
|
||||
|
||||
impl Encoder<FrontendMessage> for PostgresCodec {
|
||||
type Error = io::Error;
|
||||
|
||||
fn encode(&mut self, item: FrontendMessage, dst: &mut BytesMut) -> io::Result<()> {
|
||||
match item {
|
||||
FrontendMessage::Raw(buf) => dst.extend_from_slice(&buf),
|
||||
// FrontendMessage::CopyData(data) => data.write(dst),
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl Decoder for PostgresCodec {
|
||||
type Item = BackendMessage;
|
||||
type Error = io::Error;
|
||||
|
||||
fn decode(&mut self, src: &mut BytesMut) -> Result<Option<BackendMessage>, io::Error> {
|
||||
let mut idx = 0;
|
||||
let mut request_complete = false;
|
||||
|
||||
while let Some(header) = backend::Header::parse(&src[idx..])? {
|
||||
let len = header.len() as usize + 1;
|
||||
if src[idx..].len() < len {
|
||||
break;
|
||||
}
|
||||
|
||||
if let Some(max) = self.max_message_size {
|
||||
if len > max {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::InvalidInput,
|
||||
"message too large",
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
match header.tag() {
|
||||
backend::NOTICE_RESPONSE_TAG
|
||||
| backend::NOTIFICATION_RESPONSE_TAG
|
||||
| backend::PARAMETER_STATUS_TAG => {
|
||||
if idx == 0 {
|
||||
let message = backend::Message::parse(src)?.unwrap();
|
||||
return Ok(Some(BackendMessage::Async(message)));
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
idx += len;
|
||||
|
||||
if header.tag() == backend::READY_FOR_QUERY_TAG {
|
||||
request_complete = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if idx == 0 {
|
||||
Ok(None)
|
||||
} else {
|
||||
Ok(Some(BackendMessage::Normal {
|
||||
messages: BackendMessages(src.split_to(idx)),
|
||||
request_complete,
|
||||
}))
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,194 +0,0 @@
|
||||
use super::codec::{BackendMessage, BackendMessages, FrontendMessage, PostgresCodec};
|
||||
use super::error::Error;
|
||||
use bytes::BytesMut;
|
||||
use fallible_iterator::FallibleIterator;
|
||||
use futures::channel::mpsc;
|
||||
use futures::SinkExt;
|
||||
use futures::{Sink, StreamExt};
|
||||
use postgres_protocol::message::backend::Message;
|
||||
use std::collections::{HashMap, VecDeque};
|
||||
use std::future::poll_fn;
|
||||
use std::pin::Pin;
|
||||
use std::task::{ready, Context, Poll};
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tokio_postgres::maybe_tls_stream::MaybeTlsStream;
|
||||
use tokio_util::codec::Framed;
|
||||
use tracing::trace;
|
||||
|
||||
pub enum RequestMessages {
|
||||
Single(FrontendMessage),
|
||||
}
|
||||
|
||||
pub struct Request {
|
||||
pub messages: RequestMessages,
|
||||
pub sender: mpsc::Sender<BackendMessages>,
|
||||
}
|
||||
|
||||
pub struct Response {
|
||||
sender: mpsc::Sender<BackendMessages>,
|
||||
}
|
||||
|
||||
// #[derive(PartialEq, Debug)]
|
||||
// enum State {
|
||||
// Active,
|
||||
// Terminating,
|
||||
// Closing,
|
||||
// }
|
||||
|
||||
/// A connection to a PostgreSQL database.
|
||||
///
|
||||
/// This is one half of what is returned when a new connection is established. It performs the actual IO with the
|
||||
/// server, and should generally be spawned off onto an executor to run in the background.
|
||||
///
|
||||
/// `Connection` implements `Future`, and only resolves when the connection is closed, either because a fatal error has
|
||||
/// occurred, or because its associated `Client` has dropped and all outstanding work has completed.
|
||||
#[must_use = "futures do nothing unless polled"]
|
||||
pub struct Connection<S, T> {
|
||||
/// HACK: we need this in the Neon Proxy.
|
||||
pub stream: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
|
||||
/// HACK: we need this in the Neon Proxy to forward params.
|
||||
pub parameters: HashMap<String, String>,
|
||||
// receiver: mpsc::UnboundedReceiver<Request>,
|
||||
pending_request: Option<RequestMessages>,
|
||||
pending_responses: VecDeque<(BackendMessages, bool)>,
|
||||
pub buf: BytesMut,
|
||||
// responses: VecDeque<Response>,
|
||||
// state: State,
|
||||
}
|
||||
|
||||
impl<S, T> Connection<S, T>
|
||||
where
|
||||
S: AsyncRead + AsyncWrite + Unpin,
|
||||
T: AsyncRead + AsyncWrite + Unpin,
|
||||
{
|
||||
pub(crate) fn new(
|
||||
stream: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
|
||||
pending_responses: VecDeque<(BackendMessages, bool)>,
|
||||
parameters: HashMap<String, String>,
|
||||
// receiver: mpsc::UnboundedReceiver<Request>,
|
||||
) -> Connection<S, T> {
|
||||
Connection {
|
||||
stream,
|
||||
parameters,
|
||||
// receiver,
|
||||
pending_request: None,
|
||||
pending_responses,
|
||||
buf: BytesMut::new(),
|
||||
// responses: VecDeque::new(),
|
||||
// state: State::Active,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn send(&mut self) -> Result<(), Error> {
|
||||
poll_fn(|cx| self.poll_send(cx)).await?;
|
||||
let request = FrontendMessage::Raw(self.buf.split().freeze());
|
||||
self.stream.start_send_unpin(request).map_err(Error::io)?;
|
||||
poll_fn(|cx| self.poll_flush(cx)).await
|
||||
}
|
||||
|
||||
pub async fn next_response(&mut self) -> Result<(BackendMessages, bool), Error> {
|
||||
match self.pending_responses.pop_front() {
|
||||
Some((a, b)) => Ok((a, b)),
|
||||
None => poll_fn(|cx| self.poll_read(cx)).await,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn next_message(&mut self) -> Result<Message, Error> {
|
||||
loop {
|
||||
let (mut messages, complete) = self.next_response().await?;
|
||||
if let Some(message) = messages.next().map_err(Error::parse)? {
|
||||
self.pending_responses.push_front((messages, complete));
|
||||
break Ok(message);
|
||||
}
|
||||
if complete {
|
||||
break Err(Error::unexpected_message());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_response(
|
||||
&mut self,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<Option<Result<BackendMessage, Error>>> {
|
||||
self.stream
|
||||
.poll_next_unpin(cx)
|
||||
.map(|o| o.map(|r| r.map_err(Error::io)))
|
||||
}
|
||||
|
||||
fn poll_read(&mut self, cx: &mut Context<'_>) -> Poll<Result<(BackendMessages, bool), Error>> {
|
||||
loop {
|
||||
let message = match ready!(self.poll_response(cx)?) {
|
||||
Some(message) => message,
|
||||
None => return Poll::Ready(Err(Error::closed())),
|
||||
};
|
||||
|
||||
match message {
|
||||
BackendMessage::Async(Message::NoticeResponse(body)) => {
|
||||
// TODO: log this
|
||||
|
||||
// let error = DbError::parse(&mut body.fields()).map_err(Error::parse)?;
|
||||
// return Ok(Some(AsyncMessage::Notice(error)));
|
||||
continue;
|
||||
}
|
||||
BackendMessage::Async(Message::NotificationResponse(body)) => {
|
||||
// TODO: log this
|
||||
|
||||
// let notification = Notification {
|
||||
// process_id: body.process_id(),
|
||||
// channel: body.channel().map_err(Error::parse)?.to_string(),
|
||||
// payload: body.message().map_err(Error::parse)?.to_string(),
|
||||
// };
|
||||
// return Ok(Some(AsyncMessage::Notification(notification)));
|
||||
continue;
|
||||
}
|
||||
BackendMessage::Async(Message::ParameterStatus(body)) => {
|
||||
self.parameters.insert(
|
||||
body.name().map_err(Error::parse)?.to_string(),
|
||||
body.value().map_err(Error::parse)?.to_string(),
|
||||
);
|
||||
continue;
|
||||
}
|
||||
BackendMessage::Async(_) => unreachable!(),
|
||||
BackendMessage::Normal {
|
||||
messages,
|
||||
request_complete,
|
||||
} => return Poll::Ready(Ok((messages, request_complete))),
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_shutdown(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
|
||||
match Pin::new(&mut self.stream)
|
||||
.poll_close(cx)
|
||||
.map_err(Error::io)?
|
||||
{
|
||||
Poll::Ready(()) => {
|
||||
trace!("poll_shutdown: complete");
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
Poll::Pending => {
|
||||
trace!("poll_shutdown: waiting on socket");
|
||||
Poll::Pending
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the value of a runtime parameter for this connection.
|
||||
pub fn parameter(&self, name: &str) -> Option<&str> {
|
||||
self.parameters.get(name).map(|s| &**s)
|
||||
}
|
||||
|
||||
fn poll_send(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
|
||||
if let Poll::Ready(msg) = self.poll_read(cx)? {
|
||||
self.pending_responses.push_back(msg);
|
||||
};
|
||||
self.stream.poll_ready_unpin(cx).map_err(Error::io)
|
||||
}
|
||||
|
||||
fn poll_flush(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
|
||||
if let Poll::Ready(msg) = self.poll_read(cx)? {
|
||||
self.pending_responses.push_back(msg);
|
||||
};
|
||||
self.stream.poll_flush_unpin(cx).map_err(Error::io)
|
||||
}
|
||||
}
|
||||
@@ -22,6 +22,7 @@ pub mod scram;
|
||||
pub mod stream;
|
||||
pub mod url;
|
||||
pub mod waiters;
|
||||
pub mod pg_client;
|
||||
|
||||
/// Handle unix signals appropriately.
|
||||
pub async fn handle_signals(token: CancellationToken) -> anyhow::Result<Infallible> {
|
||||
|
||||
43
proxy/src/pg_client/codec.rs
Normal file
43
proxy/src/pg_client/codec.rs
Normal file
@@ -0,0 +1,43 @@
|
||||
use bytes::{Bytes, BytesMut};
|
||||
use fallible_iterator::FallibleIterator;
|
||||
use postgres_protocol::message::backend::{self, Message};
|
||||
use std::io;
|
||||
use tokio_util::codec::{Decoder, Encoder};
|
||||
|
||||
pub struct FrontendMessage(pub Bytes);
|
||||
pub struct BackendMessages(pub BytesMut);
|
||||
|
||||
impl BackendMessages {
|
||||
pub fn empty() -> BackendMessages {
|
||||
BackendMessages(BytesMut::new())
|
||||
}
|
||||
}
|
||||
|
||||
impl FallibleIterator for BackendMessages {
|
||||
type Item = backend::Message;
|
||||
type Error = io::Error;
|
||||
|
||||
fn next(&mut self) -> io::Result<Option<backend::Message>> {
|
||||
backend::Message::parse(&mut self.0)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct PostgresCodec;
|
||||
|
||||
impl Encoder<FrontendMessage> for PostgresCodec {
|
||||
type Error = io::Error;
|
||||
|
||||
fn encode(&mut self, item: FrontendMessage, dst: &mut BytesMut) -> io::Result<()> {
|
||||
dst.extend_from_slice(&item.0);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl Decoder for PostgresCodec {
|
||||
type Item = Message;
|
||||
type Error = io::Error;
|
||||
|
||||
fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Message>, io::Error> {
|
||||
Message::parse(src)
|
||||
}
|
||||
}
|
||||
369
proxy/src/pg_client/connection.rs
Normal file
369
proxy/src/pg_client/connection.rs
Normal file
@@ -0,0 +1,369 @@
|
||||
use super::codec::{BackendMessages, FrontendMessage, PostgresCodec};
|
||||
use super::error::Error;
|
||||
use bytes::{BufMut, BytesMut};
|
||||
use fallible_iterator::FallibleIterator;
|
||||
use futures::channel::mpsc;
|
||||
use futures::{Sink, StreamExt};
|
||||
use futures::{SinkExt, Stream};
|
||||
use postgres_protocol::authentication;
|
||||
use postgres_protocol::message::backend::{
|
||||
BackendKeyDataBody, DataRowBody, Message, ReadyForQueryBody, RowDescriptionBody,
|
||||
};
|
||||
use postgres_protocol::message::frontend;
|
||||
use std::collections::{HashMap, VecDeque};
|
||||
use std::future::poll_fn;
|
||||
use std::pin::Pin;
|
||||
use std::task::{ready, Context, Poll};
|
||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
||||
use tokio::net::TcpStream;
|
||||
use tokio_native_tls::{native_tls, TlsConnector, TlsStream};
|
||||
use tokio_postgres::maybe_tls_stream::MaybeTlsStream;
|
||||
use tokio_util::codec::Framed;
|
||||
|
||||
pub enum RequestMessages {
|
||||
Single(FrontendMessage),
|
||||
}
|
||||
|
||||
pub struct Request {
|
||||
pub messages: RequestMessages,
|
||||
pub sender: mpsc::Sender<BackendMessages>,
|
||||
}
|
||||
|
||||
pub struct Response {
|
||||
sender: mpsc::Sender<BackendMessages>,
|
||||
}
|
||||
|
||||
/// A connection to a PostgreSQL database.
|
||||
pub struct RawConnection<S, T> {
|
||||
stream: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
|
||||
pending_responses: VecDeque<Message>,
|
||||
pub buf: BytesMut,
|
||||
}
|
||||
|
||||
// enum MaybeTlsStream {
|
||||
// NoTls(TcpStream),
|
||||
// Tls(TlsStream<TcpStream>),
|
||||
// }
|
||||
|
||||
// impl Unpin for MaybeTlsStream {}
|
||||
|
||||
// impl AsyncRead for MaybeTlsStream {
|
||||
// fn poll_read(
|
||||
// self: Pin<&mut Self>,
|
||||
// cx: &mut Context<'_>,
|
||||
// buf: &mut tokio::io::ReadBuf<'_>,
|
||||
// ) -> Poll<std::io::Result<()>> {
|
||||
// match self.get_mut() {
|
||||
// MaybeTlsStream::NoTls(no_tls) => Pin::new(no_tls).poll_read(cx, buf),
|
||||
// MaybeTlsStream::Tls(tls) => Pin::new(tls).poll_read(cx, buf),
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// impl AsyncWrite for MaybeTlsStream {
|
||||
// fn poll_write(
|
||||
// self: Pin<&mut Self>,
|
||||
// cx: &mut Context<'_>,
|
||||
// buf: &[u8],
|
||||
// ) -> Poll<Result<usize, std::io::Error>> {
|
||||
// match self.get_mut() {
|
||||
// MaybeTlsStream::NoTls(no_tls) => Pin::new(no_tls).poll_write(cx, buf),
|
||||
// MaybeTlsStream::Tls(tls) => Pin::new(tls).poll_write(cx, buf),
|
||||
// }
|
||||
// }
|
||||
|
||||
// fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
|
||||
// match self.get_mut() {
|
||||
// MaybeTlsStream::NoTls(no_tls) => Pin::new(no_tls).poll_flush(cx),
|
||||
// MaybeTlsStream::Tls(tls) => Pin::new(tls).poll_flush(cx),
|
||||
// }
|
||||
// }
|
||||
|
||||
// fn poll_shutdown(
|
||||
// self: Pin<&mut Self>,
|
||||
// cx: &mut Context<'_>,
|
||||
// ) -> Poll<Result<(), std::io::Error>> {
|
||||
// match self.get_mut() {
|
||||
// MaybeTlsStream::NoTls(no_tls) => Pin::new(no_tls).poll_shutdown(cx),
|
||||
// MaybeTlsStream::Tls(tls) => Pin::new(tls).poll_shutdown(cx),
|
||||
// }
|
||||
// }
|
||||
|
||||
// fn poll_write_vectored(
|
||||
// self: Pin<&mut Self>,
|
||||
// cx: &mut Context<'_>,
|
||||
// bufs: &[std::io::IoSlice<'_>],
|
||||
// ) -> Poll<Result<usize, std::io::Error>> {
|
||||
// match self.get_mut() {
|
||||
// MaybeTlsStream::NoTls(no_tls) => Pin::new(no_tls).poll_write_vectored(cx, bufs),
|
||||
// MaybeTlsStream::Tls(tls) => Pin::new(tls).poll_write_vectored(cx, bufs),
|
||||
// }
|
||||
// }
|
||||
|
||||
// fn is_write_vectored(&self) -> bool {
|
||||
// match self {
|
||||
// MaybeTlsStream::NoTls(no_tls) => no_tls.is_write_vectored(),
|
||||
// MaybeTlsStream::Tls(tls) => tls.is_write_vectored(),
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
||||
impl<S: AsyncRead + AsyncWrite + Unpin, T: AsyncRead + AsyncWrite + Unpin> RawConnection<S, T> {
|
||||
// pub(crate) async fn connect(
|
||||
// mut stream: TcpStream,
|
||||
// tls_domain: Option<&str>,
|
||||
// ) -> Result<RawConnection<S, T>, Error> {
|
||||
// let mut buf = BytesMut::new();
|
||||
|
||||
// let stream = if let Some(tls_domain) = tls_domain {
|
||||
// frontend::ssl_request(&mut buf);
|
||||
// stream
|
||||
// .write_all_buf(&mut buf.split().freeze())
|
||||
// .await
|
||||
// .unwrap();
|
||||
// let bit = stream.read_u8().await.map_err(Error::io)?;
|
||||
// if bit != b'S' {
|
||||
// return Err(Error::closed());
|
||||
// }
|
||||
|
||||
// let tls = native_tls::TlsConnector::new().map_err(Error::tls)?;
|
||||
// let tls = TlsConnector::from(tls)
|
||||
// .connect(tls_domain, stream)
|
||||
// .await
|
||||
// .map_err(Error::tls)?;
|
||||
|
||||
// MaybeTlsStream::Tls(tls)
|
||||
// } else {
|
||||
// MaybeTlsStream::Raw(stream)
|
||||
// };
|
||||
|
||||
// Ok(RawConnection::new(Framed::new(stream, PostgresCodec), buf))
|
||||
// }
|
||||
|
||||
pub fn new(
|
||||
stream: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
|
||||
buf: BytesMut,
|
||||
) -> RawConnection<S, T> {
|
||||
RawConnection {
|
||||
stream,
|
||||
pending_responses: VecDeque::new(),
|
||||
buf,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn send(&mut self) -> Result<(), Error> {
|
||||
poll_fn(|cx| self.poll_send(cx)).await?;
|
||||
let request = FrontendMessage(self.buf.split().freeze());
|
||||
self.stream.start_send_unpin(request).map_err(Error::io)?;
|
||||
poll_fn(|cx| self.poll_flush(cx)).await
|
||||
}
|
||||
|
||||
pub async fn next_message(&mut self) -> Result<Message, Error> {
|
||||
match self.pending_responses.pop_front() {
|
||||
Some(message) => Ok(message),
|
||||
None => poll_fn(|cx| self.poll_read(cx)).await,
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_read(&mut self, cx: &mut Context<'_>) -> Poll<Result<Message, Error>> {
|
||||
let message = match ready!(self.stream.poll_next_unpin(cx)?) {
|
||||
Some(message) => message,
|
||||
None => return Poll::Ready(Err(Error::closed())),
|
||||
};
|
||||
Poll::Ready(Ok(message))
|
||||
}
|
||||
|
||||
fn poll_shutdown(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
|
||||
Pin::new(&mut self.stream).poll_close(cx).map_err(Error::io)
|
||||
}
|
||||
|
||||
fn poll_send(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
|
||||
if let Poll::Ready(msg) = self.poll_read(cx)? {
|
||||
self.pending_responses.push_back(msg);
|
||||
};
|
||||
self.stream.poll_ready_unpin(cx).map_err(Error::io)
|
||||
}
|
||||
|
||||
fn poll_flush(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
|
||||
if let Poll::Ready(msg) = self.poll_read(cx)? {
|
||||
self.pending_responses.push_back(msg);
|
||||
};
|
||||
self.stream.poll_flush_unpin(cx).map_err(Error::io)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Connection<S, T> {
|
||||
raw: RawConnection<S, T>,
|
||||
key: BackendKeyDataBody,
|
||||
}
|
||||
|
||||
impl<S: AsyncRead + AsyncWrite + Unpin, T: AsyncRead + AsyncWrite + Unpin> Connection<S, T> {
|
||||
// pub async fn auth_sasl_scram<'a, I>(
|
||||
// mut raw: RawConnection<S, T>,
|
||||
// params: I,
|
||||
// password: &[u8],
|
||||
// ) -> Result<Self, Error>
|
||||
// where
|
||||
// I: IntoIterator<Item = (&'a str, &'a str)>,
|
||||
// {
|
||||
// // send a startup message
|
||||
// frontend::startup_message(params, &mut raw.buf).unwrap();
|
||||
// raw.send().await?;
|
||||
|
||||
// // expect sasl authentication message
|
||||
// let Message::AuthenticationSasl(body) = raw.next_message().await? else { return Err(Error::expecting("sasl authentication")) };
|
||||
// // expect support for SCRAM_SHA_256
|
||||
// if body
|
||||
// .mechanisms()
|
||||
// .find(|&x| Ok(x == authentication::sasl::SCRAM_SHA_256))?
|
||||
// .is_none()
|
||||
// {
|
||||
// return Err(Error::expecting("SCRAM-SHA-256 auth"));
|
||||
// }
|
||||
|
||||
// // initiate SCRAM_SHA_256 authentication without channel binding
|
||||
// let auth = authentication::sasl::ChannelBinding::unrequested();
|
||||
// let mut scram = authentication::sasl::ScramSha256::new(password, auth);
|
||||
|
||||
// frontend::sasl_initial_response(
|
||||
// authentication::sasl::SCRAM_SHA_256,
|
||||
// scram.message(),
|
||||
// &mut raw.buf,
|
||||
// )
|
||||
// .unwrap();
|
||||
// raw.send().await?;
|
||||
|
||||
// // expect sasl continue
|
||||
// let Message::AuthenticationSaslContinue(b) = raw.next_message().await? else { return Err(Error::expecting("auth continue")) };
|
||||
// scram.update(b.data()).unwrap();
|
||||
|
||||
// // continue sasl
|
||||
// frontend::sasl_response(scram.message(), &mut raw.buf).unwrap();
|
||||
// raw.send().await?;
|
||||
|
||||
// // expect sasl final
|
||||
// let Message::AuthenticationSaslFinal(b) = raw.next_message().await? else { return Err(Error::expecting("auth final")) };
|
||||
// scram.finish(b.data()).unwrap();
|
||||
|
||||
// // expect auth ok
|
||||
// let Message::AuthenticationOk = raw.next_message().await? else { return Err(Error::expecting("auth ok")) };
|
||||
|
||||
// // expect connection accepted
|
||||
// let key = loop {
|
||||
// match raw.next_message().await? {
|
||||
// Message::BackendKeyData(key) => break key,
|
||||
// Message::ParameterStatus(_) => {}
|
||||
// _ => return Err(Error::expecting("backend ready")),
|
||||
// }
|
||||
// };
|
||||
|
||||
// let Message::ReadyForQuery(b) = raw.next_message().await? else { return Err(Error::expecting("ready for query")) };
|
||||
// // assert_eq!(b.status(), b'I');
|
||||
|
||||
// Ok(Self { raw, key })
|
||||
// }
|
||||
|
||||
pub fn prepare_and_execute(
|
||||
&mut self,
|
||||
portal: &str,
|
||||
name: &str,
|
||||
query: &str,
|
||||
params: impl IntoIterator<Item = Option<impl AsRef<str>>>,
|
||||
) -> std::io::Result<()> {
|
||||
self.prepare(name, query)?;
|
||||
self.execute(portal, name, params)
|
||||
}
|
||||
|
||||
pub fn prepare(&mut self, name: &str, query: &str) -> std::io::Result<()> {
|
||||
frontend::parse(name, query, std::iter::empty(), &mut self.raw.buf)
|
||||
}
|
||||
|
||||
pub fn execute(
|
||||
&mut self,
|
||||
portal: &str,
|
||||
name: &str,
|
||||
params: impl IntoIterator<Item = Option<impl AsRef<str>>>,
|
||||
) -> std::io::Result<()> {
|
||||
frontend::bind(
|
||||
portal,
|
||||
name,
|
||||
std::iter::empty(), // all parameters use the default format (text)
|
||||
params,
|
||||
|param, buf| match param {
|
||||
Some(param) => {
|
||||
buf.put_slice(param.as_ref().as_bytes());
|
||||
Ok(postgres_protocol::IsNull::No)
|
||||
}
|
||||
None => Ok(postgres_protocol::IsNull::Yes),
|
||||
},
|
||||
Some(0), // all text
|
||||
&mut self.raw.buf,
|
||||
)
|
||||
.map_err(|e| match e {
|
||||
frontend::BindError::Conversion(e) => std::io::Error::new(std::io::ErrorKind::Other, e),
|
||||
frontend::BindError::Serialization(io) => io,
|
||||
})?;
|
||||
frontend::describe(b'P', portal, &mut self.raw.buf)?;
|
||||
frontend::execute(portal, 0, &mut self.raw.buf)
|
||||
}
|
||||
|
||||
pub async fn sync(&mut self) -> Result<(), Error> {
|
||||
frontend::sync(&mut self.raw.buf);
|
||||
self.raw.send().await
|
||||
}
|
||||
|
||||
/// returns None if there's no row data
|
||||
/// returns Some with the row description and a row stream if there is row data
|
||||
pub async fn stream_query_results(
|
||||
&mut self,
|
||||
) -> Result<
|
||||
Option<(
|
||||
RowDescriptionBody,
|
||||
impl Stream<Item = Result<DataRowBody, Error>> + '_,
|
||||
)>,
|
||||
Error,
|
||||
> {
|
||||
let Message::ParseComplete = self.raw.next_message().await? else { return Err(Error::expecting("parse")) };
|
||||
let Message::BindComplete = self.raw.next_message().await? else { return Err(Error::expecting("bind")) };
|
||||
match self.raw.next_message().await? {
|
||||
Message::RowDescription(desc) => {
|
||||
struct RowStream<'a, S, T> {
|
||||
raw: &'a mut RawConnection<S, T>,
|
||||
}
|
||||
impl<S, T> Unpin for RowStream<'_, S, T> {}
|
||||
|
||||
impl<S: AsyncRead + AsyncWrite + Unpin, T: AsyncRead + AsyncWrite + Unpin> Stream
|
||||
for RowStream<'_, S, T>
|
||||
{
|
||||
type Item = Result<DataRowBody, Error>;
|
||||
|
||||
fn poll_next(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<Option<Self::Item>> {
|
||||
match ready!(self.raw.poll_read(cx)?) {
|
||||
Message::DataRow(row) => Poll::Ready(Some(Ok(row))),
|
||||
Message::CommandComplete(_) => Poll::Ready(None),
|
||||
_ => Poll::Ready(Some(Err(Error::expecting("command completion")))),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(Some((desc, RowStream { raw: &mut self.raw })))
|
||||
}
|
||||
Message::NoData => {
|
||||
let Message::CommandComplete(_) = self.raw.next_message().await? else { return Err(Error::expecting("command completion")) };
|
||||
Ok(None)
|
||||
}
|
||||
_ => Err(Error::expecting("query results")),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn wait_for_ready(&mut self) -> Result<ReadyForQueryBody, Error> {
|
||||
loop {
|
||||
match self.raw.next_message().await.unwrap() {
|
||||
Message::ReadyForQuery(b) => break Ok(b),
|
||||
_ => continue,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -2,11 +2,13 @@ use std::{error, fmt, io};
|
||||
|
||||
use fallible_iterator::FallibleIterator;
|
||||
use postgres_protocol::message::backend::{ErrorFields, ErrorResponseBody};
|
||||
use tokio_native_tls::native_tls;
|
||||
use tokio_postgres::error::{ErrorPosition, SqlState};
|
||||
|
||||
#[derive(Debug, PartialEq)]
|
||||
enum Kind {
|
||||
Io,
|
||||
Tls,
|
||||
UnexpectedMessage,
|
||||
FromSql(usize),
|
||||
Closed,
|
||||
@@ -21,7 +23,7 @@ struct ErrorInner {
|
||||
}
|
||||
|
||||
/// An error communicating with the Postgres server.
|
||||
pub struct Error(Box<ErrorInner>);
|
||||
pub struct Error(ErrorInner);
|
||||
|
||||
impl fmt::Debug for Error {
|
||||
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
@@ -36,6 +38,7 @@ impl fmt::Display for Error {
|
||||
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match &self.0.kind {
|
||||
Kind::Io => fmt.write_str("error communicating with the server")?,
|
||||
Kind::Tls => fmt.write_str("error establishing tls")?,
|
||||
Kind::UnexpectedMessage => fmt.write_str("unexpected message from server")?,
|
||||
Kind::FromSql(idx) => write!(fmt, "error deserializing column {}", idx)?,
|
||||
Kind::Closed => fmt.write_str("connection closed")?,
|
||||
@@ -56,6 +59,12 @@ impl error::Error for Error {
|
||||
}
|
||||
}
|
||||
|
||||
impl From<io::Error> for Error {
|
||||
fn from(value: io::Error) -> Self {
|
||||
Self::io(value)
|
||||
}
|
||||
}
|
||||
|
||||
impl Error {
|
||||
/// Consumes the error, returning its cause.
|
||||
pub fn into_source(self) -> Option<Box<dyn error::Error + Sync + Send>> {
|
||||
@@ -82,7 +91,7 @@ impl Error {
|
||||
}
|
||||
|
||||
fn new(kind: Kind, cause: Option<Box<dyn error::Error + Sync + Send>>) -> Error {
|
||||
Error(Box::new(ErrorInner { kind, cause }))
|
||||
Error(ErrorInner { kind, cause })
|
||||
}
|
||||
|
||||
#[allow(clippy::needless_pass_by_value)]
|
||||
@@ -105,6 +114,10 @@ impl Error {
|
||||
Error::new(Kind::UnexpectedMessage, None)
|
||||
}
|
||||
|
||||
pub(crate) fn expecting(expected: &str) -> Error {
|
||||
Error::new(Kind::UnexpectedMessage, Some(expected.into()))
|
||||
}
|
||||
|
||||
pub(crate) fn parse(e: io::Error) -> Error {
|
||||
Error::new(Kind::Parse, Some(Box::new(e)))
|
||||
}
|
||||
@@ -116,6 +129,10 @@ impl Error {
|
||||
pub(crate) fn io(e: io::Error) -> Error {
|
||||
Error::new(Kind::Io, Some(Box::new(e)))
|
||||
}
|
||||
|
||||
pub(crate) fn tls(e: native_tls::Error) -> Error {
|
||||
Error::new(Kind::Tls, Some(Box::new(e)))
|
||||
}
|
||||
}
|
||||
|
||||
/// The severity of a Postgres error or notice.
|
||||
6
proxy/src/pg_client/mod.rs
Normal file
6
proxy/src/pg_client/mod.rs
Normal file
@@ -0,0 +1,6 @@
|
||||
|
||||
pub mod codec;
|
||||
pub mod connection;
|
||||
pub mod error;
|
||||
// mod prepare;
|
||||
// mod pg_type;
|
||||
Reference in New Issue
Block a user