diff --git a/libs/proxy/tokio-postgres2/src/client.rs b/libs/proxy/tokio-postgres2/src/client.rs index 733f685fa9..9bae4d5727 100644 --- a/libs/proxy/tokio-postgres2/src/client.rs +++ b/libs/proxy/tokio-postgres2/src/client.rs @@ -14,7 +14,6 @@ use tokio::sync::mpsc; use crate::codec::{BackendMessages, FrontendMessage}; use crate::config::{Host, SslMode}; -use crate::connection::{Request, RequestMessages}; use crate::query::RowStream; use crate::simple_query::SimpleQueryStream; use crate::types::{Oid, Type}; @@ -24,19 +23,43 @@ use crate::{ }; pub struct Responses { + /// new messages from conn receiver: mpsc::Receiver, + /// current batch of messages cur: BackendMessages, + /// number of total queries sent. + waiting: usize, + /// number of ReadyForQuery messages received. + received: usize, } impl Responses { pub fn poll_next(&mut self, cx: &mut Context<'_>) -> Poll> { loop { - match self.cur.next().map_err(Error::parse)? { - Some(Message::ErrorResponse(body)) => return Poll::Ready(Err(Error::db(body))), - Some(message) => return Poll::Ready(Ok(message)), - None => {} + // get the next saved message + if let Some(message) = self.cur.next().map_err(Error::parse)? { + let received = self.received; + + // increase the query head if this is the last message. + if let Message::ReadyForQuery(_) = message { + self.received += 1; + } + + // check if the client has skipped this query. + if received + 1 < self.waiting { + // grab the next message. + continue; + } + + // convenience: turn the error messaage into a proper error. + let res = match message { + Message::ErrorResponse(body) => Err(Error::db(body)), + message => Ok(message), + }; + return Poll::Ready(res); } + // get the next back of messages. match ready!(self.receiver.poll_recv(cx)) { Some(messages) => self.cur = messages, None => return Poll::Ready(Err(Error::closed())), @@ -63,22 +86,18 @@ pub(crate) struct CachedTypeInfo { } pub struct InnerClient { - sender: mpsc::UnboundedSender, + sender: mpsc::UnboundedSender, + responses: Responses, /// A buffer to use when writing out postgres commands. buffer: BytesMut, } impl InnerClient { - pub fn send(&self, messages: RequestMessages) -> Result { - let (sender, receiver) = mpsc::channel(1); - let request = Request { messages, sender }; - self.sender.send(request).map_err(|_| Error::closed())?; - - Ok(Responses { - receiver, - cur: BackendMessages::empty(), - }) + pub fn send(&mut self, messages: FrontendMessage) -> Result<&mut Responses, Error> { + self.sender.send(messages).map_err(|_| Error::closed())?; + self.responses.waiting += 1; + Ok(&mut self.responses) } /// Call the given function with a buffer to be used when writing out @@ -123,16 +142,15 @@ impl Drop for Client { frontend::sync(buf); buf.split().freeze() }); - let _ = self - .inner - .send(RequestMessages::Single(FrontendMessage::Raw(buf))); + let _ = self.inner.send(FrontendMessage::Raw(buf)); } } } impl Client { pub(crate) fn new( - sender: mpsc::UnboundedSender, + sender: mpsc::UnboundedSender, + receiver: mpsc::Receiver, socket_config: SocketConfig, ssl_mode: SslMode, process_id: i32, @@ -141,6 +159,12 @@ impl Client { Client { inner: InnerClient { sender, + responses: Responses { + receiver, + cur: BackendMessages::empty(), + waiting: 0, + received: 0, + }, buffer: Default::default(), }, cached_typeinfo: Default::default(), @@ -241,10 +265,7 @@ impl Client { frontend::query("ROLLBACK", buf).unwrap(); buf.split().freeze() }); - let _ = self - .client - .inner() - .send(RequestMessages::Single(FrontendMessage::Raw(buf))); + let _ = self.client.inner().send(FrontendMessage::Raw(buf)); } } diff --git a/libs/proxy/tokio-postgres2/src/codec.rs b/libs/proxy/tokio-postgres2/src/codec.rs index f1fd9b47b3..daa5371426 100644 --- a/libs/proxy/tokio-postgres2/src/codec.rs +++ b/libs/proxy/tokio-postgres2/src/codec.rs @@ -1,21 +1,16 @@ use std::io; -use bytes::{Buf, Bytes, BytesMut}; +use bytes::{Bytes, BytesMut}; use fallible_iterator::FallibleIterator; use postgres_protocol2::message::backend; -use postgres_protocol2::message::frontend::CopyData; use tokio_util::codec::{Decoder, Encoder}; pub enum FrontendMessage { Raw(Bytes), - CopyData(CopyData>), } pub enum BackendMessage { - Normal { - messages: BackendMessages, - request_complete: bool, - }, + Normal { messages: BackendMessages }, Async(backend::Message), } @@ -44,7 +39,6 @@ impl Encoder for PostgresCodec { fn encode(&mut self, item: FrontendMessage, dst: &mut BytesMut) -> io::Result<()> { match item { FrontendMessage::Raw(buf) => dst.extend_from_slice(&buf), - FrontendMessage::CopyData(data) => data.write(dst), } Ok(()) @@ -57,7 +51,6 @@ impl Decoder for PostgresCodec { fn decode(&mut self, src: &mut BytesMut) -> Result, io::Error> { let mut idx = 0; - let mut request_complete = false; while let Some(header) = backend::Header::parse(&src[idx..])? { let len = header.len() as usize + 1; @@ -82,7 +75,6 @@ impl Decoder for PostgresCodec { idx += len; if header.tag() == backend::READY_FOR_QUERY_TAG { - request_complete = true; break; } } @@ -92,7 +84,6 @@ impl Decoder for PostgresCodec { } else { Ok(Some(BackendMessage::Normal { messages: BackendMessages(src.split_to(idx)), - request_complete, })) } } diff --git a/libs/proxy/tokio-postgres2/src/connect.rs b/libs/proxy/tokio-postgres2/src/connect.rs index 7c3a358bba..39a0a87c74 100644 --- a/libs/proxy/tokio-postgres2/src/connect.rs +++ b/libs/proxy/tokio-postgres2/src/connect.rs @@ -59,9 +59,11 @@ where connect_timeout: config.connect_timeout, }; - let (sender, receiver) = mpsc::unbounded_channel(); + let (client_tx, conn_rx) = mpsc::unbounded_channel(); + let (conn_tx, client_rx) = mpsc::channel(4); let client = Client::new( - sender, + client_tx, + client_rx, socket_config, config.ssl_mode, process_id, @@ -74,7 +76,7 @@ where .map(|m| BackendMessage::Async(Message::NoticeResponse(m))) .collect(); - let connection = Connection::new(stream, delayed, parameters, receiver); + let connection = Connection::new(stream, delayed, parameters, conn_tx, conn_rx); Ok((client, connection)) } diff --git a/libs/proxy/tokio-postgres2/src/connection.rs b/libs/proxy/tokio-postgres2/src/connection.rs index 99d6f3f8e2..fe0372b266 100644 --- a/libs/proxy/tokio-postgres2/src/connection.rs +++ b/libs/proxy/tokio-postgres2/src/connection.rs @@ -4,7 +4,6 @@ use std::pin::Pin; use std::task::{Context, Poll}; use bytes::BytesMut; -use fallible_iterator::FallibleIterator; use futures_util::{Sink, Stream, ready}; use postgres_protocol2::message::backend::Message; use postgres_protocol2::message::frontend; @@ -19,30 +18,12 @@ use crate::error::DbError; use crate::maybe_tls_stream::MaybeTlsStream; use crate::{AsyncMessage, Error, Notification}; -pub enum RequestMessages { - Single(FrontendMessage), -} - -pub struct Request { - pub messages: RequestMessages, - pub sender: mpsc::Sender, -} - -pub struct Response { - sender: PollSender, -} - #[derive(PartialEq, Debug)] enum State { Active, Closing, } -enum WriteReady { - Terminating, - WaitingOnRead, -} - /// A connection to a PostgreSQL database. /// /// This is one half of what is returned when a new connection is established. It performs the actual IO with the @@ -56,9 +37,11 @@ pub struct Connection { pub stream: Framed, PostgresCodec>, /// HACK: we need this in the Neon Proxy to forward params. pub parameters: HashMap, - receiver: mpsc::UnboundedReceiver, + + sender: PollSender, + receiver: mpsc::UnboundedReceiver, + pending_responses: VecDeque, - responses: VecDeque, state: State, } @@ -71,14 +54,15 @@ where stream: Framed, PostgresCodec>, pending_responses: VecDeque, parameters: HashMap, - receiver: mpsc::UnboundedReceiver, + sender: mpsc::Sender, + receiver: mpsc::UnboundedReceiver, ) -> Connection { Connection { stream, parameters, + sender: PollSender::new(sender), receiver, pending_responses, - responses: VecDeque::new(), state: State::Active, } } @@ -110,7 +94,7 @@ where } }; - let (mut messages, request_complete) = match message { + let messages = match message { BackendMessage::Async(Message::NoticeResponse(body)) => { let error = DbError::parse(&mut body.fields()).map_err(Error::parse)?; return Poll::Ready(Ok(AsyncMessage::Notice(error))); @@ -131,41 +115,19 @@ where continue; } BackendMessage::Async(_) => unreachable!(), - BackendMessage::Normal { - messages, - request_complete, - } => (messages, request_complete), + BackendMessage::Normal { messages } => messages, }; - let mut response = match self.responses.pop_front() { - Some(response) => response, - None => match messages.next().map_err(Error::parse)? { - Some(Message::ErrorResponse(error)) => { - return Poll::Ready(Err(Error::db(error))); - } - _ => return Poll::Ready(Err(Error::unexpected_message())), - }, - }; - - match response.sender.poll_reserve(cx) { + match self.sender.poll_reserve(cx) { Poll::Ready(Ok(())) => { - let _ = response.sender.send_item(messages); - if !request_complete { - self.responses.push_front(response); - } + let _ = self.sender.send_item(messages); } Poll::Ready(Err(_)) => { - // we need to keep paging through the rest of the messages even if the receiver's hung up - if !request_complete { - self.responses.push_front(response); - } + return Poll::Ready(Err(Error::closed())); } Poll::Pending => { - self.responses.push_front(response); - self.pending_responses.push_back(BackendMessage::Normal { - messages, - request_complete, - }); + self.pending_responses + .push_back(BackendMessage::Normal { messages }); trace!("poll_read: waiting on sender"); return Poll::Pending; } @@ -174,7 +136,7 @@ where } /// Fetch the next client request and enqueue the response sender. - fn poll_request(&mut self, cx: &mut Context<'_>) -> Poll> { + fn poll_request(&mut self, cx: &mut Context<'_>) -> Poll> { if self.receiver.is_closed() { return Poll::Ready(None); } @@ -182,10 +144,7 @@ where match self.receiver.poll_recv(cx) { Poll::Ready(Some(request)) => { trace!("polled new request"); - self.responses.push_back(Response { - sender: PollSender::new(request.sender), - }); - Poll::Ready(Some(request.messages)) + Poll::Ready(Some(request)) } Poll::Ready(None) => Poll::Ready(None), Poll::Pending => Poll::Pending, @@ -194,7 +153,7 @@ where /// Process client requests and write them to the postgres connection, flushing if necessary. /// client -> postgres - fn poll_write(&mut self, cx: &mut Context<'_>) -> Poll> { + fn poll_write(&mut self, cx: &mut Context<'_>) -> Poll> { loop { if Pin::new(&mut self.stream) .poll_ready(cx) @@ -209,14 +168,14 @@ where match self.poll_request(cx) { // send the message to postgres - Poll::Ready(Some(RequestMessages::Single(request))) => { + Poll::Ready(Some(request)) => { Pin::new(&mut self.stream) .start_send(request) .map_err(Error::io)?; } // No more messages from the client, and no more responses to wait for. // Send a terminate message to postgres - Poll::Ready(None) if self.responses.is_empty() => { + Poll::Ready(None) => { trace!("poll_write: at eof, terminating"); let mut request = BytesMut::new(); frontend::terminate(&mut request); @@ -228,16 +187,7 @@ where trace!("poll_write: sent eof, closing"); trace!("poll_write: done"); - return Poll::Ready(Ok(WriteReady::Terminating)); - } - // No more messages from the client, but there are still some responses to wait for. - Poll::Ready(None) => { - trace!( - "poll_write: at eof, pending responses {}", - self.responses.len() - ); - ready!(self.poll_flush(cx))?; - return Poll::Ready(Ok(WriteReady::WaitingOnRead)); + return Poll::Ready(Ok(())); } // Still waiting for a message from the client. Poll::Pending => { @@ -298,7 +248,7 @@ where // if the state is still active, try read from and write to postgres. let message = self.poll_read(cx)?; let closing = self.poll_write(cx)?; - if let Poll::Ready(WriteReady::Terminating) = closing { + if let Poll::Ready(()) = closing { self.state = State::Closing; } diff --git a/libs/proxy/tokio-postgres2/src/prepare.rs b/libs/proxy/tokio-postgres2/src/prepare.rs index f7789632e3..9390af095e 100644 --- a/libs/proxy/tokio-postgres2/src/prepare.rs +++ b/libs/proxy/tokio-postgres2/src/prepare.rs @@ -10,7 +10,6 @@ use tracing::debug; use crate::client::{CachedTypeInfo, InnerClient}; use crate::codec::FrontendMessage; -use crate::connection::RequestMessages; use crate::types::{Kind, Oid, Type}; use crate::{Column, Error, Statement, query, slice_iter}; @@ -29,7 +28,7 @@ async fn prepare_typecheck( types: &[Type], ) -> Result { let buf = encode(client, name, query, types)?; - let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?; + let responses = client.send(FrontendMessage::Raw(buf))?; match responses.next().await? { Message::ParseComplete => {} diff --git a/libs/proxy/tokio-postgres2/src/query.rs b/libs/proxy/tokio-postgres2/src/query.rs index bda8e74d7d..4e32fd320e 100644 --- a/libs/proxy/tokio-postgres2/src/query.rs +++ b/libs/proxy/tokio-postgres2/src/query.rs @@ -1,12 +1,10 @@ use std::fmt; -use std::marker::PhantomPinned; use std::pin::Pin; use std::task::{Context, Poll}; use bytes::{BufMut, Bytes, BytesMut}; use fallible_iterator::FallibleIterator; use futures_util::{Stream, ready}; -use pin_project_lite::pin_project; use postgres_protocol2::message::backend::Message; use postgres_protocol2::message::frontend; use postgres_types2::{Format, ToSql, Type}; @@ -14,7 +12,6 @@ use tracing::debug; use crate::client::{InnerClient, Responses}; use crate::codec::FrontendMessage; -use crate::connection::RequestMessages; use crate::types::IsNull; use crate::{Column, Error, ReadyForQueryStatus, Row, Statement}; @@ -48,20 +45,19 @@ where }; let responses = start(client, buf).await?; Ok(RowStream { - statement, responses, + statement, command_tag: None, status: ReadyForQueryStatus::Unknown, output_format: Format::Binary, - _p: PhantomPinned, }) } -pub async fn query_txt( - client: &mut InnerClient, +pub async fn query_txt<'a, S, I>( + client: &'a mut InnerClient, query: &str, params: I, -) -> Result +) -> Result, Error> where S: AsRef, I: IntoIterator>, @@ -108,7 +104,7 @@ where })?; // now read the responses - let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?; + let responses = client.send(FrontendMessage::Raw(buf))?; match responses.next().await? { Message::ParseComplete => {} @@ -149,17 +145,16 @@ where } Ok(RowStream { - statement: Statement::new_anonymous(parameters, columns), responses, + statement: Statement::new_anonymous(parameters, columns), command_tag: None, status: ReadyForQueryStatus::Unknown, output_format: Format::Text, - _p: PhantomPinned, }) } -async fn start(client: &mut InnerClient, buf: Bytes) -> Result { - let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?; +async fn start(client: &mut InnerClient, buf: Bytes) -> Result<&mut Responses, Error> { + let responses = client.send(FrontendMessage::Raw(buf))?; match responses.next().await? { Message::BindComplete => {} @@ -237,41 +232,37 @@ where } } -pin_project! { - /// A stream of table rows. - pub struct RowStream { - statement: Statement, - responses: Responses, - command_tag: Option, - output_format: Format, - status: ReadyForQueryStatus, - #[pin] - _p: PhantomPinned, - } +/// A stream of table rows. +pub struct RowStream<'a> { + responses: &'a mut Responses, + output_format: Format, + pub statement: Statement, + pub command_tag: Option, + pub status: ReadyForQueryStatus, } -impl Stream for RowStream { +impl Stream for RowStream<'_> { type Item = Result; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let this = self.project(); + let this = self.get_mut(); loop { match ready!(this.responses.poll_next(cx)?) { Message::DataRow(body) => { return Poll::Ready(Some(Ok(Row::new( this.statement.clone(), body, - *this.output_format, + this.output_format, )?))); } Message::EmptyQueryResponse | Message::PortalSuspended => {} Message::CommandComplete(body) => { if let Ok(tag) = body.tag() { - *this.command_tag = Some(tag.to_string()); + this.command_tag = Some(tag.to_string()); } } Message::ReadyForQuery(status) => { - *this.status = status.into(); + this.status = status.into(); return Poll::Ready(None); } _ => return Poll::Ready(Some(Err(Error::unexpected_message()))), @@ -279,24 +270,3 @@ impl Stream for RowStream { } } } - -impl RowStream { - /// Returns information about the columns of data in the row. - pub fn columns(&self) -> &[Column] { - self.statement.columns() - } - - /// Returns the command tag of this query. - /// - /// This is only available after the stream has been exhausted. - pub fn command_tag(&self) -> Option { - self.command_tag.clone() - } - - /// Returns if the connection is ready for querying, with the status of the connection. - /// - /// This might be available only after the stream has been exhausted. - pub fn ready_status(&self) -> ReadyForQueryStatus { - self.status - } -} diff --git a/libs/proxy/tokio-postgres2/src/simple_query.rs b/libs/proxy/tokio-postgres2/src/simple_query.rs index f4d85d4100..321b4c5b43 100644 --- a/libs/proxy/tokio-postgres2/src/simple_query.rs +++ b/libs/proxy/tokio-postgres2/src/simple_query.rs @@ -1,4 +1,3 @@ -use std::marker::PhantomPinned; use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; @@ -13,7 +12,6 @@ use tracing::debug; use crate::client::{InnerClient, Responses}; use crate::codec::FrontendMessage; -use crate::connection::RequestMessages; use crate::{Error, ReadyForQueryStatus, SimpleQueryMessage, SimpleQueryRow}; /// Information about a column of a single query row. @@ -33,20 +31,19 @@ impl SimpleColumn { } } -pub async fn simple_query( - client: &mut InnerClient, +pub async fn simple_query<'a>( + client: &'a mut InnerClient, query: &str, -) -> Result { +) -> Result, Error> { debug!("executing simple query: {}", query); let buf = encode(client, query)?; - let responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?; + let responses = client.send(FrontendMessage::Raw(buf))?; Ok(SimpleQueryStream { responses, columns: None, status: ReadyForQueryStatus::Unknown, - _p: PhantomPinned, }) } @@ -57,7 +54,7 @@ pub async fn batch_execute( debug!("executing statement batch: {}", query); let buf = encode(client, query)?; - let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?; + let responses = client.send(FrontendMessage::Raw(buf))?; loop { match responses.next().await? { @@ -80,16 +77,14 @@ pub(crate) fn encode(client: &mut InnerClient, query: &str) -> Result { + responses: &'a mut Responses, columns: Option>, status: ReadyForQueryStatus, - #[pin] - _p: PhantomPinned, } } -impl SimpleQueryStream { +impl SimpleQueryStream<'_> { /// Returns if the connection is ready for querying, with the status of the connection. /// /// This might be available only after the stream has been exhausted. @@ -98,7 +93,7 @@ impl SimpleQueryStream { } } -impl Stream for SimpleQueryStream { +impl Stream for SimpleQueryStream<'_> { type Item = Result; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { diff --git a/libs/proxy/tokio-postgres2/src/transaction.rs b/libs/proxy/tokio-postgres2/src/transaction.rs index ecb35cf60f..98a21ded5e 100644 --- a/libs/proxy/tokio-postgres2/src/transaction.rs +++ b/libs/proxy/tokio-postgres2/src/transaction.rs @@ -1,7 +1,6 @@ use postgres_protocol2::message::frontend; use crate::codec::FrontendMessage; -use crate::connection::RequestMessages; use crate::query::RowStream; use crate::{CancelToken, Client, Error, ReadyForQueryStatus}; @@ -24,10 +23,7 @@ impl Drop for Transaction<'_> { frontend::query("ROLLBACK", buf).unwrap(); buf.split().freeze() }); - let _ = self - .client - .inner() - .send(RequestMessages::Single(FrontendMessage::Raw(buf))); + let _ = self.client.inner().send(FrontendMessage::Raw(buf)); } } diff --git a/proxy/src/serverless/sql_over_http.rs b/proxy/src/serverless/sql_over_http.rs index dfaeedaeae..01dadaeb59 100644 --- a/proxy/src/serverless/sql_over_http.rs +++ b/proxy/src/serverless/sql_over_http.rs @@ -14,7 +14,9 @@ use hyper::http::{HeaderName, HeaderValue}; use hyper::{HeaderMap, Request, Response, StatusCode, header}; use indexmap::IndexMap; use postgres_client::error::{DbError, ErrorPosition, SqlState}; -use postgres_client::{GenericClient, IsolationLevel, NoTls, ReadyForQueryStatus, Transaction}; +use postgres_client::{ + GenericClient, IsolationLevel, NoTls, ReadyForQueryStatus, RowStream, Transaction, +}; use pq_proto::StartupMessageParamsBuilder; use serde::Serialize; use serde_json::Value; @@ -1092,12 +1094,10 @@ async fn query_to_json( let query_start = Instant::now(); let query_params = data.params; - let mut row_stream = std::pin::pin!( - client - .query_raw_txt(&data.query, query_params) - .await - .map_err(SqlOverHttpError::Postgres)? - ); + let mut row_stream = client + .query_raw_txt(&data.query, query_params) + .await + .map_err(SqlOverHttpError::Postgres)?; let query_acknowledged = Instant::now(); // Manually drain the stream into a vector to leave row_stream hanging @@ -1118,10 +1118,15 @@ async fn query_to_json( } let query_resp_end = Instant::now(); - let ready = row_stream.ready_status(); + let RowStream { + statement, + command_tag, + status: ready, + .. + } = row_stream; // grab the command tag and number of rows affected - let command_tag = row_stream.command_tag().unwrap_or_default(); + let command_tag = command_tag.unwrap_or_default(); let mut command_tag_split = command_tag.split(' '); let command_tag_name = command_tag_split.next().unwrap_or_default(); let command_tag_count = if command_tag_name == "INSERT" { @@ -1142,11 +1147,11 @@ async fn query_to_json( "finished executing query" ); - let columns_len = row_stream.columns().len(); + let columns_len = statement.columns().len(); let mut fields = Vec::with_capacity(columns_len); let mut columns = Vec::with_capacity(columns_len); - for c in row_stream.columns() { + for c in statement.columns() { fields.push(json!({ "name": c.name().to_owned(), "dataTypeID": c.type_().oid(),