From 17627e80234506e87100d5c78182899d19404df2 Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Thu, 27 Jul 2023 13:12:16 +0100 Subject: [PATCH] stash2 --- Cargo.lock | 1 + proxy/Cargo.toml | 1 + proxy/src/http/sql_over_http.rs | 83 +--- proxy/src/http/sql_over_http/codec.rs | 108 ----- proxy/src/http/sql_over_http/connection.rs | 194 --------- proxy/src/lib.rs | 1 + proxy/src/pg_client/codec.rs | 43 ++ proxy/src/pg_client/connection.rs | 369 ++++++++++++++++++ .../sql_over_http => pg_client}/error.rs | 21 +- proxy/src/pg_client/mod.rs | 6 + .../sql_over_http => pg_client}/pg_type.rs | 0 .../sql_over_http => pg_client}/prepare.rs | 0 12 files changed, 451 insertions(+), 376 deletions(-) delete mode 100644 proxy/src/http/sql_over_http/codec.rs delete mode 100644 proxy/src/http/sql_over_http/connection.rs create mode 100644 proxy/src/pg_client/codec.rs create mode 100644 proxy/src/pg_client/connection.rs rename proxy/src/{http/sql_over_http => pg_client}/error.rs (96%) create mode 100644 proxy/src/pg_client/mod.rs rename proxy/src/{http/sql_over_http => pg_client}/pg_type.rs (100%) rename proxy/src/{http/sql_over_http => pg_client}/prepare.rs (100%) diff --git a/Cargo.lock b/Cargo.lock index 8847c2095e..19fd369125 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3074,6 +3074,7 @@ dependencies = [ "thiserror", "tls-listener", "tokio", + "tokio-native-tls", "tokio-postgres", "tokio-postgres-rustls", "tokio-rustls 0.23.4", diff --git a/proxy/Cargo.toml b/proxy/Cargo.toml index 8deeb44d2a..fbc41348e2 100644 --- a/proxy/Cargo.toml +++ b/proxy/Cargo.toml @@ -65,6 +65,7 @@ webpki-roots.workspace = true x509-parser.workspace = true native-tls.workspace = true postgres-native-tls.workspace = true +tokio-native-tls = "0.3.1" workspace_hack.workspace = true tokio-util.workspace = true diff --git a/proxy/src/http/sql_over_http.rs b/proxy/src/http/sql_over_http.rs index f12751537c..5423028df5 100644 --- a/proxy/src/http/sql_over_http.rs +++ b/proxy/src/http/sql_over_http.rs @@ -26,18 +26,14 @@ use tokio_postgres::RowStream; use tokio_postgres::Statement; use url::Url; -use crate::http::sql_over_http::codec::FrontendMessage; -use crate::http::sql_over_http::connection::RequestMessages; +use crate::pg_client; +use crate::pg_client::codec::FrontendMessage; +use crate::pg_client::connection; +use crate::pg_client::connection::RequestMessages; use super::conn_pool::ConnInfo; use super::conn_pool::GlobalConnPool; -mod codec; -mod connection; -mod error; -// mod prepare; -// mod pg_type; - #[derive(serde::Deserialize)] struct QueryData { query: String, @@ -374,75 +370,20 @@ async fn query_raw_txt<'a, St, T>( conn: &mut connection::Connection, query: String, params: Vec>, -) -> Result, error::Error> +) -> Result, pg_client::error::Error> where St: AsyncRead + AsyncWrite + Unpin, T: AsyncRead + AsyncWrite + Unpin, { - use postgres_protocol::message::backend::Message; - use postgres_protocol::message::frontend; - - let params_len = params.len(); let params = params.into_iter(); - { - let buf = &mut conn.buf; - // Parse, anonymous portal - frontend::parse("", query.as_str(), std::iter::empty(), buf) - .map_err(error::Error::encode)?; - // Bind, pass params as text, retrieve as binary - match frontend::bind( - "", // empty string selects the unnamed portal - "", // empty string selects the unnamed prepared statement - std::iter::empty(), // all parameters use the default format (text) - params, - |param, buf| match param { - Some(param) => { - buf.put_slice(param.as_bytes()); - Ok(postgres_protocol::IsNull::No) - } - None => Ok(postgres_protocol::IsNull::Yes), - }, - Some(0), // all text - buf, - ) { - Ok(()) => Ok(()), - Err(frontend::BindError::Conversion(e)) => Err(error::Error::encode( - std::io::Error::new(ErrorKind::Other, e), - )), - Err(frontend::BindError::Serialization(e)) => Err(error::Error::encode(e)), - }?; - - // Describe portal to typecast results - frontend::describe(b'P', "", buf).map_err(error::Error::encode)?; - // Execute - frontend::execute("", 0, buf).map_err(error::Error::encode)?; - // Sync - frontend::sync(buf); - } - - conn.send().await?; - - // now read the responses - - match conn.next_message().await? { - Message::ParseComplete => {} - _ => return Err(error::Error::unexpected_message()), - } - match conn.next_message().await? { - Message::BindComplete => {} - _ => return Err(error::Error::unexpected_message()), - } - let row_description = match conn.next_message().await? { - Message::RowDescription(body) => Some(body), - Message::NoData => None, - _ => return Err(error::Error::unexpected_message()), - }; + conn.prepare_and_execute("", "", query.as_str(), params)?; + conn.sync().await?; let mut columns = vec![]; - if let Some(row_description) = row_description { - let mut it = row_description.fields(); - while let Some(field) = it.next().map_err(error::Error::parse)? { + if let Some((desc, rows)) = conn.stream_query_results().await? { + let mut it = desc.fields(); + while let Some(field) = it.next().map_err(pg_client::error::Error::parse)? { let type_ = Type::from_oid(field.type_oid()); // let column = Column::new(field.name().to_string(), type_, field); columns.push(Column { @@ -452,8 +393,6 @@ where } } - // let statement = Statement::new_text(&self.inner, "".to_owned(), parameters, columns); - Ok(columns) } @@ -515,7 +454,7 @@ fn pg_text_row_to_json2( let pg_value = range .map(|r| { std::str::from_utf8(&row.buffer()[r]) - .map_err(|e| error::Error::from_sql(e.into(), i)) + .map_err(|e| pg_client::error::Error::from_sql(e.into(), i)) }) .transpose()?; // let pg_value = row.as_text(i)?; diff --git a/proxy/src/http/sql_over_http/codec.rs b/proxy/src/http/sql_over_http/codec.rs deleted file mode 100644 index 5419a89d2a..0000000000 --- a/proxy/src/http/sql_over_http/codec.rs +++ /dev/null @@ -1,108 +0,0 @@ -use bytes::{Bytes, BytesMut}; -use fallible_iterator::FallibleIterator; -use postgres_protocol::message::backend; -use std::io; -use tokio_util::codec::{Decoder, Encoder}; - -pub enum FrontendMessage { - Raw(Bytes), - // CopyData(CopyData>), -} - -pub enum BackendMessage { - Normal { - messages: BackendMessages, - request_complete: bool, - }, - Async(backend::Message), -} - -pub struct BackendMessages(pub BytesMut); - -impl BackendMessages { - pub fn empty() -> BackendMessages { - BackendMessages(BytesMut::new()) - } -} - -impl FallibleIterator for BackendMessages { - type Item = backend::Message; - type Error = io::Error; - - fn next(&mut self) -> io::Result> { - backend::Message::parse(&mut self.0) - } -} - -pub struct PostgresCodec { - pub max_message_size: Option, -} - -impl Encoder for PostgresCodec { - type Error = io::Error; - - fn encode(&mut self, item: FrontendMessage, dst: &mut BytesMut) -> io::Result<()> { - match item { - FrontendMessage::Raw(buf) => dst.extend_from_slice(&buf), - // FrontendMessage::CopyData(data) => data.write(dst), - } - - Ok(()) - } -} - -impl Decoder for PostgresCodec { - type Item = BackendMessage; - type Error = io::Error; - - fn decode(&mut self, src: &mut BytesMut) -> Result, io::Error> { - let mut idx = 0; - let mut request_complete = false; - - while let Some(header) = backend::Header::parse(&src[idx..])? { - let len = header.len() as usize + 1; - if src[idx..].len() < len { - break; - } - - if let Some(max) = self.max_message_size { - if len > max { - return Err(io::Error::new( - io::ErrorKind::InvalidInput, - "message too large", - )); - } - } - - match header.tag() { - backend::NOTICE_RESPONSE_TAG - | backend::NOTIFICATION_RESPONSE_TAG - | backend::PARAMETER_STATUS_TAG => { - if idx == 0 { - let message = backend::Message::parse(src)?.unwrap(); - return Ok(Some(BackendMessage::Async(message))); - } else { - break; - } - } - _ => {} - } - - idx += len; - - if header.tag() == backend::READY_FOR_QUERY_TAG { - request_complete = true; - break; - } - } - - if idx == 0 { - Ok(None) - } else { - Ok(Some(BackendMessage::Normal { - messages: BackendMessages(src.split_to(idx)), - request_complete, - })) - } - } -} diff --git a/proxy/src/http/sql_over_http/connection.rs b/proxy/src/http/sql_over_http/connection.rs deleted file mode 100644 index 9c9db55543..0000000000 --- a/proxy/src/http/sql_over_http/connection.rs +++ /dev/null @@ -1,194 +0,0 @@ -use super::codec::{BackendMessage, BackendMessages, FrontendMessage, PostgresCodec}; -use super::error::Error; -use bytes::BytesMut; -use fallible_iterator::FallibleIterator; -use futures::channel::mpsc; -use futures::SinkExt; -use futures::{Sink, StreamExt}; -use postgres_protocol::message::backend::Message; -use std::collections::{HashMap, VecDeque}; -use std::future::poll_fn; -use std::pin::Pin; -use std::task::{ready, Context, Poll}; -use tokio::io::{AsyncRead, AsyncWrite}; -use tokio_postgres::maybe_tls_stream::MaybeTlsStream; -use tokio_util::codec::Framed; -use tracing::trace; - -pub enum RequestMessages { - Single(FrontendMessage), -} - -pub struct Request { - pub messages: RequestMessages, - pub sender: mpsc::Sender, -} - -pub struct Response { - sender: mpsc::Sender, -} - -// #[derive(PartialEq, Debug)] -// enum State { -// Active, -// Terminating, -// Closing, -// } - -/// A connection to a PostgreSQL database. -/// -/// This is one half of what is returned when a new connection is established. It performs the actual IO with the -/// server, and should generally be spawned off onto an executor to run in the background. -/// -/// `Connection` implements `Future`, and only resolves when the connection is closed, either because a fatal error has -/// occurred, or because its associated `Client` has dropped and all outstanding work has completed. -#[must_use = "futures do nothing unless polled"] -pub struct Connection { - /// HACK: we need this in the Neon Proxy. - pub stream: Framed, PostgresCodec>, - /// HACK: we need this in the Neon Proxy to forward params. - pub parameters: HashMap, - // receiver: mpsc::UnboundedReceiver, - pending_request: Option, - pending_responses: VecDeque<(BackendMessages, bool)>, - pub buf: BytesMut, - // responses: VecDeque, - // state: State, -} - -impl Connection -where - S: AsyncRead + AsyncWrite + Unpin, - T: AsyncRead + AsyncWrite + Unpin, -{ - pub(crate) fn new( - stream: Framed, PostgresCodec>, - pending_responses: VecDeque<(BackendMessages, bool)>, - parameters: HashMap, - // receiver: mpsc::UnboundedReceiver, - ) -> Connection { - Connection { - stream, - parameters, - // receiver, - pending_request: None, - pending_responses, - buf: BytesMut::new(), - // responses: VecDeque::new(), - // state: State::Active, - } - } - - pub async fn send(&mut self) -> Result<(), Error> { - poll_fn(|cx| self.poll_send(cx)).await?; - let request = FrontendMessage::Raw(self.buf.split().freeze()); - self.stream.start_send_unpin(request).map_err(Error::io)?; - poll_fn(|cx| self.poll_flush(cx)).await - } - - pub async fn next_response(&mut self) -> Result<(BackendMessages, bool), Error> { - match self.pending_responses.pop_front() { - Some((a, b)) => Ok((a, b)), - None => poll_fn(|cx| self.poll_read(cx)).await, - } - } - - pub async fn next_message(&mut self) -> Result { - loop { - let (mut messages, complete) = self.next_response().await?; - if let Some(message) = messages.next().map_err(Error::parse)? { - self.pending_responses.push_front((messages, complete)); - break Ok(message); - } - if complete { - break Err(Error::unexpected_message()); - } - } - } - - fn poll_response( - &mut self, - cx: &mut Context<'_>, - ) -> Poll>> { - self.stream - .poll_next_unpin(cx) - .map(|o| o.map(|r| r.map_err(Error::io))) - } - - fn poll_read(&mut self, cx: &mut Context<'_>) -> Poll> { - loop { - let message = match ready!(self.poll_response(cx)?) { - Some(message) => message, - None => return Poll::Ready(Err(Error::closed())), - }; - - match message { - BackendMessage::Async(Message::NoticeResponse(body)) => { - // TODO: log this - - // let error = DbError::parse(&mut body.fields()).map_err(Error::parse)?; - // return Ok(Some(AsyncMessage::Notice(error))); - continue; - } - BackendMessage::Async(Message::NotificationResponse(body)) => { - // TODO: log this - - // let notification = Notification { - // process_id: body.process_id(), - // channel: body.channel().map_err(Error::parse)?.to_string(), - // payload: body.message().map_err(Error::parse)?.to_string(), - // }; - // return Ok(Some(AsyncMessage::Notification(notification))); - continue; - } - BackendMessage::Async(Message::ParameterStatus(body)) => { - self.parameters.insert( - body.name().map_err(Error::parse)?.to_string(), - body.value().map_err(Error::parse)?.to_string(), - ); - continue; - } - BackendMessage::Async(_) => unreachable!(), - BackendMessage::Normal { - messages, - request_complete, - } => return Poll::Ready(Ok((messages, request_complete))), - }; - } - } - - fn poll_shutdown(&mut self, cx: &mut Context<'_>) -> Poll> { - match Pin::new(&mut self.stream) - .poll_close(cx) - .map_err(Error::io)? - { - Poll::Ready(()) => { - trace!("poll_shutdown: complete"); - Poll::Ready(Ok(())) - } - Poll::Pending => { - trace!("poll_shutdown: waiting on socket"); - Poll::Pending - } - } - } - - /// Returns the value of a runtime parameter for this connection. - pub fn parameter(&self, name: &str) -> Option<&str> { - self.parameters.get(name).map(|s| &**s) - } - - fn poll_send(&mut self, cx: &mut Context<'_>) -> Poll> { - if let Poll::Ready(msg) = self.poll_read(cx)? { - self.pending_responses.push_back(msg); - }; - self.stream.poll_ready_unpin(cx).map_err(Error::io) - } - - fn poll_flush(&mut self, cx: &mut Context<'_>) -> Poll> { - if let Poll::Ready(msg) = self.poll_read(cx)? { - self.pending_responses.push_back(msg); - }; - self.stream.poll_flush_unpin(cx).map_err(Error::io) - } -} diff --git a/proxy/src/lib.rs b/proxy/src/lib.rs index 1e1e216bb7..2171396d0a 100644 --- a/proxy/src/lib.rs +++ b/proxy/src/lib.rs @@ -22,6 +22,7 @@ pub mod scram; pub mod stream; pub mod url; pub mod waiters; +pub mod pg_client; /// Handle unix signals appropriately. pub async fn handle_signals(token: CancellationToken) -> anyhow::Result { diff --git a/proxy/src/pg_client/codec.rs b/proxy/src/pg_client/codec.rs new file mode 100644 index 0000000000..fcde9725f7 --- /dev/null +++ b/proxy/src/pg_client/codec.rs @@ -0,0 +1,43 @@ +use bytes::{Bytes, BytesMut}; +use fallible_iterator::FallibleIterator; +use postgres_protocol::message::backend::{self, Message}; +use std::io; +use tokio_util::codec::{Decoder, Encoder}; + +pub struct FrontendMessage(pub Bytes); +pub struct BackendMessages(pub BytesMut); + +impl BackendMessages { + pub fn empty() -> BackendMessages { + BackendMessages(BytesMut::new()) + } +} + +impl FallibleIterator for BackendMessages { + type Item = backend::Message; + type Error = io::Error; + + fn next(&mut self) -> io::Result> { + backend::Message::parse(&mut self.0) + } +} + +pub struct PostgresCodec; + +impl Encoder for PostgresCodec { + type Error = io::Error; + + fn encode(&mut self, item: FrontendMessage, dst: &mut BytesMut) -> io::Result<()> { + dst.extend_from_slice(&item.0); + Ok(()) + } +} + +impl Decoder for PostgresCodec { + type Item = Message; + type Error = io::Error; + + fn decode(&mut self, src: &mut BytesMut) -> Result, io::Error> { + Message::parse(src) + } +} diff --git a/proxy/src/pg_client/connection.rs b/proxy/src/pg_client/connection.rs new file mode 100644 index 0000000000..7e5f0d88b3 --- /dev/null +++ b/proxy/src/pg_client/connection.rs @@ -0,0 +1,369 @@ +use super::codec::{BackendMessages, FrontendMessage, PostgresCodec}; +use super::error::Error; +use bytes::{BufMut, BytesMut}; +use fallible_iterator::FallibleIterator; +use futures::channel::mpsc; +use futures::{Sink, StreamExt}; +use futures::{SinkExt, Stream}; +use postgres_protocol::authentication; +use postgres_protocol::message::backend::{ + BackendKeyDataBody, DataRowBody, Message, ReadyForQueryBody, RowDescriptionBody, +}; +use postgres_protocol::message::frontend; +use std::collections::{HashMap, VecDeque}; +use std::future::poll_fn; +use std::pin::Pin; +use std::task::{ready, Context, Poll}; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; +use tokio::net::TcpStream; +use tokio_native_tls::{native_tls, TlsConnector, TlsStream}; +use tokio_postgres::maybe_tls_stream::MaybeTlsStream; +use tokio_util::codec::Framed; + +pub enum RequestMessages { + Single(FrontendMessage), +} + +pub struct Request { + pub messages: RequestMessages, + pub sender: mpsc::Sender, +} + +pub struct Response { + sender: mpsc::Sender, +} + +/// A connection to a PostgreSQL database. +pub struct RawConnection { + stream: Framed, PostgresCodec>, + pending_responses: VecDeque, + pub buf: BytesMut, +} + +// enum MaybeTlsStream { +// NoTls(TcpStream), +// Tls(TlsStream), +// } + +// impl Unpin for MaybeTlsStream {} + +// impl AsyncRead for MaybeTlsStream { +// fn poll_read( +// self: Pin<&mut Self>, +// cx: &mut Context<'_>, +// buf: &mut tokio::io::ReadBuf<'_>, +// ) -> Poll> { +// match self.get_mut() { +// MaybeTlsStream::NoTls(no_tls) => Pin::new(no_tls).poll_read(cx, buf), +// MaybeTlsStream::Tls(tls) => Pin::new(tls).poll_read(cx, buf), +// } +// } +// } +// impl AsyncWrite for MaybeTlsStream { +// fn poll_write( +// self: Pin<&mut Self>, +// cx: &mut Context<'_>, +// buf: &[u8], +// ) -> Poll> { +// match self.get_mut() { +// MaybeTlsStream::NoTls(no_tls) => Pin::new(no_tls).poll_write(cx, buf), +// MaybeTlsStream::Tls(tls) => Pin::new(tls).poll_write(cx, buf), +// } +// } + +// fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { +// match self.get_mut() { +// MaybeTlsStream::NoTls(no_tls) => Pin::new(no_tls).poll_flush(cx), +// MaybeTlsStream::Tls(tls) => Pin::new(tls).poll_flush(cx), +// } +// } + +// fn poll_shutdown( +// self: Pin<&mut Self>, +// cx: &mut Context<'_>, +// ) -> Poll> { +// match self.get_mut() { +// MaybeTlsStream::NoTls(no_tls) => Pin::new(no_tls).poll_shutdown(cx), +// MaybeTlsStream::Tls(tls) => Pin::new(tls).poll_shutdown(cx), +// } +// } + +// fn poll_write_vectored( +// self: Pin<&mut Self>, +// cx: &mut Context<'_>, +// bufs: &[std::io::IoSlice<'_>], +// ) -> Poll> { +// match self.get_mut() { +// MaybeTlsStream::NoTls(no_tls) => Pin::new(no_tls).poll_write_vectored(cx, bufs), +// MaybeTlsStream::Tls(tls) => Pin::new(tls).poll_write_vectored(cx, bufs), +// } +// } + +// fn is_write_vectored(&self) -> bool { +// match self { +// MaybeTlsStream::NoTls(no_tls) => no_tls.is_write_vectored(), +// MaybeTlsStream::Tls(tls) => tls.is_write_vectored(), +// } +// } +// } + +impl RawConnection { + // pub(crate) async fn connect( + // mut stream: TcpStream, + // tls_domain: Option<&str>, + // ) -> Result, Error> { + // let mut buf = BytesMut::new(); + + // let stream = if let Some(tls_domain) = tls_domain { + // frontend::ssl_request(&mut buf); + // stream + // .write_all_buf(&mut buf.split().freeze()) + // .await + // .unwrap(); + // let bit = stream.read_u8().await.map_err(Error::io)?; + // if bit != b'S' { + // return Err(Error::closed()); + // } + + // let tls = native_tls::TlsConnector::new().map_err(Error::tls)?; + // let tls = TlsConnector::from(tls) + // .connect(tls_domain, stream) + // .await + // .map_err(Error::tls)?; + + // MaybeTlsStream::Tls(tls) + // } else { + // MaybeTlsStream::Raw(stream) + // }; + + // Ok(RawConnection::new(Framed::new(stream, PostgresCodec), buf)) + // } + + pub fn new( + stream: Framed, PostgresCodec>, + buf: BytesMut, + ) -> RawConnection { + RawConnection { + stream, + pending_responses: VecDeque::new(), + buf, + } + } + + pub async fn send(&mut self) -> Result<(), Error> { + poll_fn(|cx| self.poll_send(cx)).await?; + let request = FrontendMessage(self.buf.split().freeze()); + self.stream.start_send_unpin(request).map_err(Error::io)?; + poll_fn(|cx| self.poll_flush(cx)).await + } + + pub async fn next_message(&mut self) -> Result { + match self.pending_responses.pop_front() { + Some(message) => Ok(message), + None => poll_fn(|cx| self.poll_read(cx)).await, + } + } + + fn poll_read(&mut self, cx: &mut Context<'_>) -> Poll> { + let message = match ready!(self.stream.poll_next_unpin(cx)?) { + Some(message) => message, + None => return Poll::Ready(Err(Error::closed())), + }; + Poll::Ready(Ok(message)) + } + + fn poll_shutdown(&mut self, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.stream).poll_close(cx).map_err(Error::io) + } + + fn poll_send(&mut self, cx: &mut Context<'_>) -> Poll> { + if let Poll::Ready(msg) = self.poll_read(cx)? { + self.pending_responses.push_back(msg); + }; + self.stream.poll_ready_unpin(cx).map_err(Error::io) + } + + fn poll_flush(&mut self, cx: &mut Context<'_>) -> Poll> { + if let Poll::Ready(msg) = self.poll_read(cx)? { + self.pending_responses.push_back(msg); + }; + self.stream.poll_flush_unpin(cx).map_err(Error::io) + } +} + +pub struct Connection { + raw: RawConnection, + key: BackendKeyDataBody, +} + +impl Connection { + // pub async fn auth_sasl_scram<'a, I>( + // mut raw: RawConnection, + // params: I, + // password: &[u8], + // ) -> Result + // where + // I: IntoIterator, + // { + // // send a startup message + // frontend::startup_message(params, &mut raw.buf).unwrap(); + // raw.send().await?; + + // // expect sasl authentication message + // let Message::AuthenticationSasl(body) = raw.next_message().await? else { return Err(Error::expecting("sasl authentication")) }; + // // expect support for SCRAM_SHA_256 + // if body + // .mechanisms() + // .find(|&x| Ok(x == authentication::sasl::SCRAM_SHA_256))? + // .is_none() + // { + // return Err(Error::expecting("SCRAM-SHA-256 auth")); + // } + + // // initiate SCRAM_SHA_256 authentication without channel binding + // let auth = authentication::sasl::ChannelBinding::unrequested(); + // let mut scram = authentication::sasl::ScramSha256::new(password, auth); + + // frontend::sasl_initial_response( + // authentication::sasl::SCRAM_SHA_256, + // scram.message(), + // &mut raw.buf, + // ) + // .unwrap(); + // raw.send().await?; + + // // expect sasl continue + // let Message::AuthenticationSaslContinue(b) = raw.next_message().await? else { return Err(Error::expecting("auth continue")) }; + // scram.update(b.data()).unwrap(); + + // // continue sasl + // frontend::sasl_response(scram.message(), &mut raw.buf).unwrap(); + // raw.send().await?; + + // // expect sasl final + // let Message::AuthenticationSaslFinal(b) = raw.next_message().await? else { return Err(Error::expecting("auth final")) }; + // scram.finish(b.data()).unwrap(); + + // // expect auth ok + // let Message::AuthenticationOk = raw.next_message().await? else { return Err(Error::expecting("auth ok")) }; + + // // expect connection accepted + // let key = loop { + // match raw.next_message().await? { + // Message::BackendKeyData(key) => break key, + // Message::ParameterStatus(_) => {} + // _ => return Err(Error::expecting("backend ready")), + // } + // }; + + // let Message::ReadyForQuery(b) = raw.next_message().await? else { return Err(Error::expecting("ready for query")) }; + // // assert_eq!(b.status(), b'I'); + + // Ok(Self { raw, key }) + // } + + pub fn prepare_and_execute( + &mut self, + portal: &str, + name: &str, + query: &str, + params: impl IntoIterator>>, + ) -> std::io::Result<()> { + self.prepare(name, query)?; + self.execute(portal, name, params) + } + + pub fn prepare(&mut self, name: &str, query: &str) -> std::io::Result<()> { + frontend::parse(name, query, std::iter::empty(), &mut self.raw.buf) + } + + pub fn execute( + &mut self, + portal: &str, + name: &str, + params: impl IntoIterator>>, + ) -> std::io::Result<()> { + frontend::bind( + portal, + name, + std::iter::empty(), // all parameters use the default format (text) + params, + |param, buf| match param { + Some(param) => { + buf.put_slice(param.as_ref().as_bytes()); + Ok(postgres_protocol::IsNull::No) + } + None => Ok(postgres_protocol::IsNull::Yes), + }, + Some(0), // all text + &mut self.raw.buf, + ) + .map_err(|e| match e { + frontend::BindError::Conversion(e) => std::io::Error::new(std::io::ErrorKind::Other, e), + frontend::BindError::Serialization(io) => io, + })?; + frontend::describe(b'P', portal, &mut self.raw.buf)?; + frontend::execute(portal, 0, &mut self.raw.buf) + } + + pub async fn sync(&mut self) -> Result<(), Error> { + frontend::sync(&mut self.raw.buf); + self.raw.send().await + } + + /// returns None if there's no row data + /// returns Some with the row description and a row stream if there is row data + pub async fn stream_query_results( + &mut self, + ) -> Result< + Option<( + RowDescriptionBody, + impl Stream> + '_, + )>, + Error, + > { + let Message::ParseComplete = self.raw.next_message().await? else { return Err(Error::expecting("parse")) }; + let Message::BindComplete = self.raw.next_message().await? else { return Err(Error::expecting("bind")) }; + match self.raw.next_message().await? { + Message::RowDescription(desc) => { + struct RowStream<'a, S, T> { + raw: &'a mut RawConnection, + } + impl Unpin for RowStream<'_, S, T> {} + + impl Stream + for RowStream<'_, S, T> + { + type Item = Result; + + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + match ready!(self.raw.poll_read(cx)?) { + Message::DataRow(row) => Poll::Ready(Some(Ok(row))), + Message::CommandComplete(_) => Poll::Ready(None), + _ => Poll::Ready(Some(Err(Error::expecting("command completion")))), + } + } + } + + Ok(Some((desc, RowStream { raw: &mut self.raw }))) + } + Message::NoData => { + let Message::CommandComplete(_) = self.raw.next_message().await? else { return Err(Error::expecting("command completion")) }; + Ok(None) + } + _ => Err(Error::expecting("query results")), + } + } + + pub async fn wait_for_ready(&mut self) -> Result { + loop { + match self.raw.next_message().await.unwrap() { + Message::ReadyForQuery(b) => break Ok(b), + _ => continue, + } + } + } +} diff --git a/proxy/src/http/sql_over_http/error.rs b/proxy/src/pg_client/error.rs similarity index 96% rename from proxy/src/http/sql_over_http/error.rs rename to proxy/src/pg_client/error.rs index 5fb9a5979a..684d9f4646 100644 --- a/proxy/src/http/sql_over_http/error.rs +++ b/proxy/src/pg_client/error.rs @@ -2,11 +2,13 @@ use std::{error, fmt, io}; use fallible_iterator::FallibleIterator; use postgres_protocol::message::backend::{ErrorFields, ErrorResponseBody}; +use tokio_native_tls::native_tls; use tokio_postgres::error::{ErrorPosition, SqlState}; #[derive(Debug, PartialEq)] enum Kind { Io, + Tls, UnexpectedMessage, FromSql(usize), Closed, @@ -21,7 +23,7 @@ struct ErrorInner { } /// An error communicating with the Postgres server. -pub struct Error(Box); +pub struct Error(ErrorInner); impl fmt::Debug for Error { fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { @@ -36,6 +38,7 @@ impl fmt::Display for Error { fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { match &self.0.kind { Kind::Io => fmt.write_str("error communicating with the server")?, + Kind::Tls => fmt.write_str("error establishing tls")?, Kind::UnexpectedMessage => fmt.write_str("unexpected message from server")?, Kind::FromSql(idx) => write!(fmt, "error deserializing column {}", idx)?, Kind::Closed => fmt.write_str("connection closed")?, @@ -56,6 +59,12 @@ impl error::Error for Error { } } +impl From for Error { + fn from(value: io::Error) -> Self { + Self::io(value) + } +} + impl Error { /// Consumes the error, returning its cause. pub fn into_source(self) -> Option> { @@ -82,7 +91,7 @@ impl Error { } fn new(kind: Kind, cause: Option>) -> Error { - Error(Box::new(ErrorInner { kind, cause })) + Error(ErrorInner { kind, cause }) } #[allow(clippy::needless_pass_by_value)] @@ -105,6 +114,10 @@ impl Error { Error::new(Kind::UnexpectedMessage, None) } + pub(crate) fn expecting(expected: &str) -> Error { + Error::new(Kind::UnexpectedMessage, Some(expected.into())) + } + pub(crate) fn parse(e: io::Error) -> Error { Error::new(Kind::Parse, Some(Box::new(e))) } @@ -116,6 +129,10 @@ impl Error { pub(crate) fn io(e: io::Error) -> Error { Error::new(Kind::Io, Some(Box::new(e))) } + + pub(crate) fn tls(e: native_tls::Error) -> Error { + Error::new(Kind::Tls, Some(Box::new(e))) + } } /// The severity of a Postgres error or notice. diff --git a/proxy/src/pg_client/mod.rs b/proxy/src/pg_client/mod.rs new file mode 100644 index 0000000000..c0d3c33a4a --- /dev/null +++ b/proxy/src/pg_client/mod.rs @@ -0,0 +1,6 @@ + +pub mod codec; +pub mod connection; +pub mod error; +// mod prepare; +// mod pg_type; diff --git a/proxy/src/http/sql_over_http/pg_type.rs b/proxy/src/pg_client/pg_type.rs similarity index 100% rename from proxy/src/http/sql_over_http/pg_type.rs rename to proxy/src/pg_client/pg_type.rs diff --git a/proxy/src/http/sql_over_http/prepare.rs b/proxy/src/pg_client/prepare.rs similarity index 100% rename from proxy/src/http/sql_over_http/prepare.rs rename to proxy/src/pg_client/prepare.rs