diff --git a/libs/proxy/tokio-postgres2/src/connection.rs b/libs/proxy/tokio-postgres2/src/connection.rs index 0aa5c77e22..f478717e0d 100644 --- a/libs/proxy/tokio-postgres2/src/connection.rs +++ b/libs/proxy/tokio-postgres2/src/connection.rs @@ -33,10 +33,14 @@ pub struct Response { #[derive(PartialEq, Debug)] enum State { Active, - Terminating, Closing, } +enum WriteReady { + Terminating, + WaitingOnRead, +} + /// A connection to a PostgreSQL database. /// /// This is one half of what is returned when a new connection is established. It performs the actual IO with the @@ -51,7 +55,6 @@ pub struct Connection { /// HACK: we need this in the Neon Proxy to forward params. pub parameters: HashMap, receiver: mpsc::UnboundedReceiver, - pending_request: Option, pending_responses: VecDeque, responses: VecDeque, state: State, @@ -72,7 +75,6 @@ where stream, parameters, receiver, - pending_request: None, pending_responses, responses: VecDeque::new(), state: State::Active, @@ -93,26 +95,23 @@ where .map(|o| o.map(|r| r.map_err(Error::io))) } - fn poll_read(&mut self, cx: &mut Context<'_>) -> Result, Error> { - if self.state != State::Active { - trace!("poll_read: done"); - return Ok(None); - } - + /// Read and process messages from the connection to postgres. + /// client <- postgres + fn poll_read(&mut self, cx: &mut Context<'_>) -> Poll> { loop { let message = match self.poll_response(cx)? { Poll::Ready(Some(message)) => message, - Poll::Ready(None) => return Err(Error::closed()), + Poll::Ready(None) => return Poll::Ready(Err(Error::closed())), Poll::Pending => { trace!("poll_read: waiting on response"); - return Ok(None); + return Poll::Pending; } }; let (mut messages, request_complete) = match message { BackendMessage::Async(Message::NoticeResponse(body)) => { let error = DbError::parse(&mut body.fields()).map_err(Error::parse)?; - return Ok(Some(AsyncMessage::Notice(error))); + return Poll::Ready(Ok(AsyncMessage::Notice(error))); } BackendMessage::Async(Message::NotificationResponse(body)) => { let notification = Notification { @@ -120,7 +119,7 @@ where channel: body.channel().map_err(Error::parse)?.to_string(), payload: body.message().map_err(Error::parse)?.to_string(), }; - return Ok(Some(AsyncMessage::Notification(notification))); + return Poll::Ready(Ok(AsyncMessage::Notification(notification))); } BackendMessage::Async(Message::ParameterStatus(body)) => { self.parameters.insert( @@ -139,8 +138,10 @@ where let mut response = match self.responses.pop_front() { Some(response) => response, None => match messages.next().map_err(Error::parse)? { - Some(Message::ErrorResponse(error)) => return Err(Error::db(error)), - _ => return Err(Error::unexpected_message()), + Some(Message::ErrorResponse(error)) => { + return Poll::Ready(Err(Error::db(error))) + } + _ => return Poll::Ready(Err(Error::unexpected_message())), }, }; @@ -164,18 +165,14 @@ where request_complete, }); trace!("poll_read: waiting on sender"); - return Ok(None); + return Poll::Pending; } } } } + /// Fetch the next client request and enqueue the response sender. fn poll_request(&mut self, cx: &mut Context<'_>) -> Poll> { - if let Some(messages) = self.pending_request.take() { - trace!("retrying pending request"); - return Poll::Ready(Some(messages)); - } - if self.receiver.is_closed() { return Poll::Ready(None); } @@ -193,74 +190,80 @@ where } } - fn poll_write(&mut self, cx: &mut Context<'_>) -> Result { + /// Process client requests and write them to the postgres connection, flushing if necessary. + /// client -> postgres + fn poll_write(&mut self, cx: &mut Context<'_>) -> Poll> { loop { - if self.state == State::Closing { - trace!("poll_write: done"); - return Ok(false); - } - if Pin::new(&mut self.stream) .poll_ready(cx) .map_err(Error::io)? .is_pending() { trace!("poll_write: waiting on socket"); - return Ok(false); + + // poll_ready is self-flushing. + return Poll::Pending; } - let request = match self.poll_request(cx) { - Poll::Ready(Some(request)) => request, - Poll::Ready(None) if self.responses.is_empty() && self.state == State::Active => { + match self.poll_request(cx) { + // send the message to postgres + Poll::Ready(Some(RequestMessages::Single(request))) => { + Pin::new(&mut self.stream) + .start_send(request) + .map_err(Error::io)?; + } + // No more messages from the client, and no more responses to wait for. + // Send a terminate message to postgres + Poll::Ready(None) if self.responses.is_empty() => { trace!("poll_write: at eof, terminating"); - self.state = State::Terminating; let mut request = BytesMut::new(); frontend::terminate(&mut request); - RequestMessages::Single(FrontendMessage::Raw(request.freeze())) + let request = FrontendMessage::Raw(request.freeze()); + + Pin::new(&mut self.stream) + .start_send(request) + .map_err(Error::io)?; + + trace!("poll_write: sent eof, closing"); + trace!("poll_write: done"); + return Poll::Ready(Ok(WriteReady::Terminating)); } + // No more messages from the client, but there are still some responses to wait for. Poll::Ready(None) => { trace!( "poll_write: at eof, pending responses {}", self.responses.len() ); - return Ok(true); + ready!(self.poll_flush(cx))?; + return Poll::Ready(Ok(WriteReady::WaitingOnRead)); } + // Still waiting for a message from the client. Poll::Pending => { trace!("poll_write: waiting on request"); - return Ok(true); - } - }; - - match request { - RequestMessages::Single(request) => { - Pin::new(&mut self.stream) - .start_send(request) - .map_err(Error::io)?; - if self.state == State::Terminating { - trace!("poll_write: sent eof, closing"); - self.state = State::Closing; - } + ready!(self.poll_flush(cx))?; + return Poll::Pending; } } } } - fn poll_flush(&mut self, cx: &mut Context<'_>) -> Result<(), Error> { + fn poll_flush(&mut self, cx: &mut Context<'_>) -> Poll> { match Pin::new(&mut self.stream) .poll_flush(cx) .map_err(Error::io)? { - Poll::Ready(()) => trace!("poll_flush: flushed"), - Poll::Pending => trace!("poll_flush: waiting on socket"), + Poll::Ready(()) => { + trace!("poll_flush: flushed"); + Poll::Ready(Ok(())) + } + Poll::Pending => { + trace!("poll_flush: waiting on socket"); + Poll::Pending + } } - Ok(()) } fn poll_shutdown(&mut self, cx: &mut Context<'_>) -> Poll> { - if self.state != State::Closing { - return Poll::Pending; - } - match Pin::new(&mut self.stream) .poll_close(cx) .map_err(Error::io)? @@ -289,18 +292,30 @@ where &mut self, cx: &mut Context<'_>, ) -> Poll>> { - let message = self.poll_read(cx)?; - let want_flush = self.poll_write(cx)?; - if want_flush { - self.poll_flush(cx)?; + if self.state != State::Closing { + // if the state is still active, try read from and write to postgres. + let message = self.poll_read(cx)?; + let closing = self.poll_write(cx)?; + if let Poll::Ready(WriteReady::Terminating) = closing { + self.state = State::Closing; + } + + if let Poll::Ready(message) = message { + return Poll::Ready(Some(Ok(message))); + } + + // poll_read returned Pending. + // poll_write returned Pending or Ready(WriteReady::WaitingOnRead). + // if poll_write returned Ready(WriteReady::WaitingOnRead), then we are waiting to read more data from postgres. + if self.state != State::Closing { + return Poll::Pending; + } } - match message { - Some(message) => Poll::Ready(Some(Ok(message))), - None => match self.poll_shutdown(cx) { - Poll::Ready(Ok(())) => Poll::Ready(None), - Poll::Ready(Err(e)) => Poll::Ready(Some(Err(e))), - Poll::Pending => Poll::Pending, - }, + + match self.poll_shutdown(cx) { + Poll::Ready(Ok(())) => Poll::Ready(None), + Poll::Ready(Err(e)) => Poll::Ready(Some(Err(e))), + Poll::Pending => Poll::Pending, } } }