refactor connections

This commit is contained in:
Conrad Ludgate
2023-07-21 18:13:53 +01:00
parent 21c15c4285
commit 70b503f83b
4 changed files with 205 additions and 295 deletions

View File

@@ -2,6 +2,7 @@ use std::io::ErrorKind;
use std::sync::Arc;
use anyhow::bail;
use bytes::BufMut;
use fallible_iterator::FallibleIterator;
use futures::pin_mut;
use futures::StreamExt;
@@ -10,9 +11,12 @@ use hyper::body::HttpBody;
use hyper::http::HeaderName;
use hyper::http::HeaderValue;
use hyper::{Body, HeaderMap, Request};
use postgres_protocol::message::backend::DataRowBody;
use serde_json::json;
use serde_json::Map;
use serde_json::Value;
use tokio::io::AsyncRead;
use tokio::io::AsyncWrite;
use tokio_postgres::types::Kind;
use tokio_postgres::types::Type;
use tokio_postgres::GenericClient;
@@ -364,21 +368,25 @@ async fn query_to_json<T: GenericClient>(
/// Pass text directly to the Postgres backend to allow it to sort out typing itself and
/// to save a roundtrip
pub async fn query_raw_txt<'a, S, I>(&self, query: S, params: I) -> Result<RowStream, error::Error>
async fn query_raw_txt<'a, St, T>(
conn: &mut connection::Connection<St, T>,
query: String,
params: Vec<Option<String>>,
) -> Result<Vec<Column>, error::Error>
where
S: AsRef<str>,
I: IntoIterator<Item = Option<S>>,
I::IntoIter: ExactSizeIterator,
St: AsyncRead + AsyncWrite + Unpin,
T: AsyncRead + AsyncWrite + Unpin,
{
use postgres_protocol::message::backend::Message;
use postgres_protocol::message::frontend;
let params = params.into_iter();
let params_len = params.len();
let params = params.into_iter();
let buf = self.inner.with_buf(|buf| {
{
let buf = &mut conn.buf;
// Parse, anonymous portal
frontend::parse("", query.as_ref(), std::iter::empty(), buf)
frontend::parse("", query.as_str(), std::iter::empty(), buf)
.map_err(error::Error::encode)?;
// Bind, pass params as text, retrieve as binary
match frontend::bind(
@@ -388,7 +396,7 @@ where
params,
|param, buf| match param {
Some(param) => {
buf.put_slice(param.as_ref().as_bytes());
buf.put_slice(param.as_bytes());
Ok(postgres_protocol::IsNull::No)
}
None => Ok(postgres_protocol::IsNull::Yes),
@@ -409,49 +417,48 @@ where
frontend::execute("", 0, buf).map_err(error::Error::encode)?;
// Sync
frontend::sync(buf);
}
Ok(buf.split().freeze())
})?;
let mut responses = self
.inner
.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?;
conn.send().await?;
conn.flush().await?;
// now read the responses
match responses.next().await? {
match conn.next_message().await? {
Message::ParseComplete => {}
_ => return Err(error::Error::unexpected_message()),
}
match responses.next().await? {
match conn.next_message().await? {
Message::BindComplete => {}
_ => return Err(error::Error::unexpected_message()),
}
let row_description = match responses.next().await? {
let row_description = match conn.next_message().await? {
Message::RowDescription(body) => Some(body),
Message::NoData => None,
_ => return Err(error::Error::unexpected_message()),
};
// construct statement object
let parameters = vec![Type::UNKNOWN; params_len];
let mut columns = vec![];
if let Some(row_description) = row_description {
let mut it = row_description.fields();
while let Some(field) = it.next().map_err(Error::parse)? {
// NB: for some types that function may send a query to the server. At least in
// raw text mode we don't need that info and can skip this.
let type_ = get_type(&self.inner, field.type_oid()).await?;
let column = Column::new(field.name().to_string(), type_, field);
columns.push(column);
while let Some(field) = it.next().map_err(error::Error::parse)? {
let type_ = Type::from_oid(field.type_oid());
// let column = Column::new(field.name().to_string(), type_, field);
columns.push(Column {
name: field.name().to_string(),
type_,
});
}
}
let statement = Statement::new_text(&self.inner, "".to_owned(), parameters, columns);
// let statement = Statement::new_text(&self.inner, "".to_owned(), parameters, columns);
Ok(RowStream::new(statement, responses))
Ok(columns)
}
struct Column {
name: String,
type_: Option<Type>,
}
//
@@ -471,9 +478,9 @@ pub fn pg_text_row_to_json(
None => Value::Null,
}
} else {
pg_text_to_json(pg_value, column.type_())?
pg_text_to_json(pg_value, Some(column.type_()))?
};
Ok((name.to_string(), json_value))
Ok((name, json_value))
});
if array_mode {
@@ -483,7 +490,55 @@ pub fn pg_text_row_to_json(
.collect::<Result<Vec<Value>, anyhow::Error>>()?;
Ok(Value::Array(arr))
} else {
let obj = iter.collect::<Result<Map<String, Value>, anyhow::Error>>()?;
let obj = iter
.map(|r| r.map(|(key, val)| (key.to_owned(), val)))
.collect::<Result<Map<String, Value>, anyhow::Error>>()?;
Ok(Value::Object(obj))
}
}
//
// Convert postgres row with text-encoded values to JSON object
//
fn pg_text_row_to_json2(
row: &DataRowBody,
columns: &[Column],
raw_output: bool,
array_mode: bool,
) -> Result<Value, anyhow::Error> {
let ranges: Vec<Option<std::ops::Range<usize>>> = row.ranges().collect()?;
let iter = std::iter::zip(ranges, columns)
.enumerate()
.map(|(i, (range, column))| {
let name = &column.name;
let pg_value = range
.map(|r| {
std::str::from_utf8(&row.buffer()[r])
.map_err(|e| error::Error::from_sql(e.into(), i))
})
.transpose()?;
// let pg_value = row.as_text(i)?;
let json_value = if raw_output {
match pg_value {
Some(v) => Value::String(v.to_string()),
None => Value::Null,
}
} else {
pg_text_to_json(pg_value, column.type_.as_ref())?
};
Ok((name, json_value))
});
if array_mode {
// drop keys and aggregate into array
let arr = iter
.map(|r| r.map(|(_key, val)| val))
.collect::<Result<Vec<Value>, anyhow::Error>>()?;
Ok(Value::Array(arr))
} else {
let obj = iter
.map(|r| r.map(|(key, val)| (key.to_owned(), val)))
.collect::<Result<Map<String, Value>, anyhow::Error>>()?;
Ok(Value::Object(obj))
}
}
@@ -491,19 +546,22 @@ pub fn pg_text_row_to_json(
//
// Convert postgres text-encoded value to JSON value
//
pub fn pg_text_to_json(pg_value: Option<&str>, pg_type: &Type) -> Result<Value, anyhow::Error> {
pub fn pg_text_to_json(
pg_value: Option<&str>,
pg_type: Option<&Type>,
) -> Result<Value, anyhow::Error> {
if let Some(val) = pg_value {
if let Kind::Array(elem_type) = pg_type.kind() {
return pg_array_parse(val, elem_type);
if let Some(Kind::Array(elem_type)) = pg_type.map(|t| t.kind()) {
return pg_array_parse(val, Some(elem_type));
}
match *pg_type {
Type::BOOL => Ok(Value::Bool(val == "t")),
Type::INT2 | Type::INT4 => {
match pg_type {
Some(&Type::BOOL) => Ok(Value::Bool(val == "t")),
Some(&Type::INT2 | &Type::INT4) => {
let val = val.parse::<i32>()?;
Ok(Value::Number(serde_json::Number::from(val)))
}
Type::FLOAT4 | Type::FLOAT8 => {
Some(&Type::FLOAT4 | &Type::FLOAT8) => {
let fval = val.parse::<f64>()?;
let num = serde_json::Number::from_f64(fval);
if let Some(num) = num {
@@ -515,7 +573,7 @@ pub fn pg_text_to_json(pg_value: Option<&str>, pg_type: &Type) -> Result<Value,
Ok(Value::String(val.to_string()))
}
}
Type::JSON | Type::JSONB => Ok(serde_json::from_str(val)?),
Some(&Type::JSON | &Type::JSONB) => Ok(serde_json::from_str(val)?),
_ => Ok(Value::String(val.to_string())),
}
} else {
@@ -530,13 +588,13 @@ pub fn pg_text_to_json(pg_value: Option<&str>, pg_type: &Type) -> Result<Value,
// values. Unlike postgres we don't check that all nested arrays have the same
// dimensions, we just return them as is.
//
fn pg_array_parse(pg_array: &str, elem_type: &Type) -> Result<Value, anyhow::Error> {
fn pg_array_parse(pg_array: &str, elem_type: Option<&Type>) -> Result<Value, anyhow::Error> {
_pg_array_parse(pg_array, elem_type, false).map(|(v, _)| v)
}
fn _pg_array_parse(
pg_array: &str,
elem_type: &Type,
elem_type: Option<&Type>,
nested: bool,
) -> Result<(Value, usize), anyhow::Error> {
let mut pg_array_chr = pg_array.char_indices();
@@ -557,7 +615,7 @@ fn _pg_array_parse(
fn push_checked(
entry: &mut String,
entries: &mut Vec<Value>,
elem_type: &Type,
elem_type: Option<&Type>,
) -> Result<(), anyhow::Error> {
if !entry.is_empty() {
// While in usual postgres response we get nulls as None and everything else
@@ -683,34 +741,43 @@ mod tests {
#[test]
fn test_atomic_types_parse() {
assert_eq!(
pg_text_to_json(Some("foo"), &Type::TEXT).unwrap(),
pg_text_to_json(Some("foo"), Some(&Type::TEXT)).unwrap(),
json!("foo")
);
assert_eq!(pg_text_to_json(None, &Type::TEXT).unwrap(), json!(null));
assert_eq!(pg_text_to_json(Some("42"), &Type::INT4).unwrap(), json!(42));
assert_eq!(pg_text_to_json(Some("42"), &Type::INT2).unwrap(), json!(42));
assert_eq!(
pg_text_to_json(Some("42"), &Type::INT8).unwrap(),
pg_text_to_json(None, Some(&Type::TEXT)).unwrap(),
json!(null)
);
assert_eq!(
pg_text_to_json(Some("42"), Some(&Type::INT4)).unwrap(),
json!(42)
);
assert_eq!(
pg_text_to_json(Some("42"), Some(&Type::INT2)).unwrap(),
json!(42)
);
assert_eq!(
pg_text_to_json(Some("42"), Some(&Type::INT8)).unwrap(),
json!("42")
);
assert_eq!(
pg_text_to_json(Some("42.42"), &Type::FLOAT8).unwrap(),
pg_text_to_json(Some("42.42"), Some(&Type::FLOAT8)).unwrap(),
json!(42.42)
);
assert_eq!(
pg_text_to_json(Some("42.42"), &Type::FLOAT4).unwrap(),
pg_text_to_json(Some("42.42"), Some(&Type::FLOAT4)).unwrap(),
json!(42.42)
);
assert_eq!(
pg_text_to_json(Some("NaN"), &Type::FLOAT4).unwrap(),
pg_text_to_json(Some("NaN"), Some(&Type::FLOAT4)).unwrap(),
json!("NaN")
);
assert_eq!(
pg_text_to_json(Some("Infinity"), &Type::FLOAT4).unwrap(),
pg_text_to_json(Some("Infinity"), Some(&Type::FLOAT4)).unwrap(),
json!("Infinity")
);
assert_eq!(
pg_text_to_json(Some("-Infinity"), &Type::FLOAT4).unwrap(),
pg_text_to_json(Some("-Infinity"), Some(&Type::FLOAT4)).unwrap(),
json!("-Infinity")
);
@@ -720,7 +787,7 @@ mod tests {
assert_eq!(
pg_text_to_json(
Some(r#"{"s":"str","n":42,"f":4.2,"a":[null,3,"a"]}"#),
&Type::JSONB
Some(&Type::JSONB)
)
.unwrap(),
json
@@ -730,7 +797,7 @@ mod tests {
#[test]
fn test_pg_array_parse_text() {
fn pt(pg_arr: &str) -> Value {
pg_array_parse(pg_arr, &Type::TEXT).unwrap()
pg_array_parse(pg_arr, Some(&Type::TEXT)).unwrap()
}
assert_eq!(
pt(r#"{"aa\"\\\,a",cha,"bbbb"}"#),
@@ -753,7 +820,7 @@ mod tests {
#[test]
fn test_pg_array_parse_bool() {
fn pb(pg_arr: &str) -> Value {
pg_array_parse(pg_arr, &Type::BOOL).unwrap()
pg_array_parse(pg_arr, Some(&Type::BOOL)).unwrap()
}
assert_eq!(pb(r#"{t,f,t}"#), json!([true, false, true]));
assert_eq!(pb(r#"{{t,f,t}}"#), json!([[true, false, true]]));
@@ -770,7 +837,7 @@ mod tests {
#[test]
fn test_pg_array_parse_numbers() {
fn pn(pg_arr: &str, ty: &Type) -> Value {
pg_array_parse(pg_arr, ty).unwrap()
pg_array_parse(pg_arr, Some(ty)).unwrap()
}
assert_eq!(pn(r#"{1,2,3}"#, &Type::INT4), json!([1, 2, 3]));
assert_eq!(pn(r#"{1,2,3}"#, &Type::INT2), json!([1, 2, 3]));
@@ -798,7 +865,7 @@ mod tests {
#[test]
fn test_pg_array_with_decoration() {
fn p(pg_arr: &str) -> Value {
pg_array_parse(pg_arr, &Type::INT2).unwrap()
pg_array_parse(pg_arr, Some(&Type::INT2)).unwrap()
}
assert_eq!(
p(r#"[1:1][-2:-1][3:5]={{{1,2,3},{4,5,6}}}"#),

View File

@@ -17,7 +17,7 @@ pub enum BackendMessage {
Async(backend::Message),
}
pub struct BackendMessages(BytesMut);
pub struct BackendMessages(pub BytesMut);
impl BackendMessages {
pub fn empty() -> BackendMessages {

View File

@@ -3,12 +3,13 @@ use super::error::Error;
use bytes::BytesMut;
use fallible_iterator::FallibleIterator;
use futures::channel::mpsc;
use futures::{stream::FusedStream, Sink, Stream, StreamExt};
use futures::SinkExt;
use futures::{Sink, StreamExt};
use postgres_protocol::message::backend::Message;
use postgres_protocol::message::frontend;
use std::collections::{HashMap, VecDeque};
use std::future::poll_fn;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::task::{ready, Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_postgres::maybe_tls_stream::MaybeTlsStream;
use tokio_util::codec::Framed;
@@ -16,8 +17,6 @@ use tracing::trace;
pub enum RequestMessages {
Single(FrontendMessage),
// CopyIn(CopyInReceiver),
// CopyBoth(CopyBothReceiver),
}
pub struct Request {
@@ -29,12 +28,12 @@ pub struct Response {
sender: mpsc::Sender<BackendMessages>,
}
#[derive(PartialEq, Debug)]
enum State {
Active,
Terminating,
Closing,
}
// #[derive(PartialEq, Debug)]
// enum State {
// Active,
// Terminating,
// Closing,
// }
/// A connection to a PostgreSQL database.
///
@@ -49,11 +48,12 @@ pub struct Connection<S, T> {
pub stream: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
/// HACK: we need this in the Neon Proxy to forward params.
pub parameters: HashMap<String, String>,
receiver: mpsc::UnboundedReceiver<Request>,
// receiver: mpsc::UnboundedReceiver<Request>,
pending_request: Option<RequestMessages>,
pending_responses: VecDeque<BackendMessage>,
responses: VecDeque<Response>,
state: State,
pending_responses: VecDeque<(BackendMessages, bool)>,
pub buf: BytesMut,
// responses: VecDeque<Response>,
// state: State,
}
impl<S, T> Connection<S, T>
@@ -63,18 +63,49 @@ where
{
pub(crate) fn new(
stream: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
pending_responses: VecDeque<BackendMessage>,
pending_responses: VecDeque<(BackendMessages, bool)>,
parameters: HashMap<String, String>,
receiver: mpsc::UnboundedReceiver<Request>,
// receiver: mpsc::UnboundedReceiver<Request>,
) -> Connection<S, T> {
Connection {
stream,
parameters,
receiver,
// receiver,
pending_request: None,
pending_responses,
responses: VecDeque::new(),
state: State::Active,
buf: BytesMut::new(),
// responses: VecDeque::new(),
// state: State::Active,
}
}
pub async fn send(&mut self) -> Result<(), Error> {
poll_fn(|cx| self.poll_send(cx)).await?;
let request = FrontendMessage::Raw(self.buf.split().freeze());
self.stream.start_send_unpin(request).map_err(Error::io)
}
pub async fn flush(&mut self) -> Result<(), Error> {
poll_fn(|cx| self.poll_flush(cx)).await
}
pub async fn next_response(&mut self) -> Result<(BackendMessages, bool), Error> {
match self.pending_responses.pop_front() {
Some((a, b)) => Ok((a, b)),
None => poll_fn(|cx| self.poll_read(cx)).await,
}
}
pub async fn next_message(&mut self) -> Result<Message, Error> {
loop {
let (mut messages, complete) = self.next_response().await?;
if let Some(message) = messages.next().map_err(Error::parse)? {
self.pending_responses.push_front((messages, complete));
break Ok(message);
}
if complete {
break Err(Error::unexpected_message());
}
}
}
@@ -82,33 +113,19 @@ where
&mut self,
cx: &mut Context<'_>,
) -> Poll<Option<Result<BackendMessage, Error>>> {
if let Some(message) = self.pending_responses.pop_front() {
trace!("retrying pending response");
return Poll::Ready(Some(Ok(message)));
}
Pin::new(&mut self.stream)
.poll_next(cx)
self.stream
.poll_next_unpin(cx)
.map(|o| o.map(|r| r.map_err(Error::io)))
}
fn poll_read(&mut self, cx: &mut Context<'_>) -> Result<Option<()>, Error> {
if self.state != State::Active {
trace!("poll_read: done");
return Ok(None);
}
fn poll_read(&mut self, cx: &mut Context<'_>) -> Poll<Result<(BackendMessages, bool), Error>> {
loop {
let message = match self.poll_response(cx)? {
Poll::Ready(Some(message)) => message,
Poll::Ready(None) => return Err(Error::closed()),
Poll::Pending => {
trace!("poll_read: waiting on response");
return Ok(None);
}
let message = match ready!(self.poll_response(cx)?) {
Some(message) => message,
None => return Poll::Ready(Err(Error::closed())),
};
let (mut messages, request_complete) = match message {
match message {
BackendMessage::Async(Message::NoticeResponse(body)) => {
// TODO: log this
@@ -138,169 +155,12 @@ where
BackendMessage::Normal {
messages,
request_complete,
} => (messages, request_complete),
} => return Poll::Ready(Ok((messages, request_complete))),
};
let mut response = match self.responses.pop_front() {
Some(response) => response,
None => match messages.next().map_err(Error::parse)? {
Some(Message::ErrorResponse(error)) => return Err(Error::db(error)),
_ => return Err(Error::unexpected_message()),
},
};
match response.sender.poll_ready(cx) {
Poll::Ready(Ok(())) => {
let _ = response.sender.start_send(messages);
if !request_complete {
self.responses.push_front(response);
}
}
Poll::Ready(Err(_)) => {
// we need to keep paging through the rest of the messages even if the receiver's hung up
if !request_complete {
self.responses.push_front(response);
}
}
Poll::Pending => {
self.responses.push_front(response);
self.pending_responses.push_back(BackendMessage::Normal {
messages,
request_complete,
});
trace!("poll_read: waiting on sender");
return Ok(None);
}
}
}
}
fn poll_request(&mut self, cx: &mut Context<'_>) -> Poll<Option<RequestMessages>> {
if let Some(messages) = self.pending_request.take() {
trace!("retrying pending request");
return Poll::Ready(Some(messages));
}
if self.receiver.is_terminated() {
return Poll::Ready(None);
}
match self.receiver.poll_next_unpin(cx) {
Poll::Ready(Some(request)) => {
trace!("polled new request");
self.responses.push_back(Response {
sender: request.sender,
});
Poll::Ready(Some(request.messages))
}
Poll::Ready(None) => Poll::Ready(None),
Poll::Pending => Poll::Pending,
}
}
fn poll_write(&mut self, cx: &mut Context<'_>) -> Result<bool, Error> {
loop {
if self.state == State::Closing {
trace!("poll_write: done");
return Ok(false);
}
if Pin::new(&mut self.stream)
.poll_ready(cx)
.map_err(Error::io)?
.is_pending()
{
trace!("poll_write: waiting on socket");
return Ok(false);
}
let request = match self.poll_request(cx) {
Poll::Ready(Some(request)) => request,
Poll::Ready(None) if self.responses.is_empty() && self.state == State::Active => {
trace!("poll_write: at eof, terminating");
self.state = State::Terminating;
let mut request = BytesMut::new();
frontend::terminate(&mut request);
RequestMessages::Single(FrontendMessage::Raw(request.freeze()))
}
Poll::Ready(None) => {
trace!(
"poll_write: at eof, pending responses {}",
self.responses.len()
);
return Ok(true);
}
Poll::Pending => {
trace!("poll_write: waiting on request");
return Ok(true);
}
};
match request {
RequestMessages::Single(request) => {
Pin::new(&mut self.stream)
.start_send(request)
.map_err(Error::io)?;
if self.state == State::Terminating {
trace!("poll_write: sent eof, closing");
self.state = State::Closing;
}
} // RequestMessages::CopyIn(mut receiver) => {
// let message = match receiver.poll_next_unpin(cx) {
// Poll::Ready(Some(message)) => message,
// Poll::Ready(None) => {
// trace!("poll_write: finished copy_in request");
// continue;
// }
// Poll::Pending => {
// trace!("poll_write: waiting on copy_in stream");
// self.pending_request = Some(RequestMessages::CopyIn(receiver));
// return Ok(true);
// }
// };
// Pin::new(&mut self.stream)
// .start_send(message)
// .map_err(Error::io)?;
// self.pending_request = Some(RequestMessages::CopyIn(receiver));
// }
// RequestMessages::CopyBoth(mut receiver) => {
// let message = match receiver.poll_next_unpin(cx) {
// Poll::Ready(Some(message)) => message,
// Poll::Ready(None) => {
// trace!("poll_write: finished copy_both request");
// continue;
// }
// Poll::Pending => {
// trace!("poll_write: waiting on copy_both stream");
// self.pending_request = Some(RequestMessages::CopyBoth(receiver));
// return Ok(true);
// }
// };
// Pin::new(&mut self.stream)
// .start_send(message)
// .map_err(Error::io)?;
// self.pending_request = Some(RequestMessages::CopyBoth(receiver));
// }
}
}
}
fn poll_flush(&mut self, cx: &mut Context<'_>) -> Result<(), Error> {
match Pin::new(&mut self.stream)
.poll_flush(cx)
.map_err(Error::io)?
{
Poll::Ready(()) => trace!("poll_flush: flushed"),
Poll::Pending => trace!("poll_flush: waiting on socket"),
}
Ok(())
}
fn poll_shutdown(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
if self.state != State::Closing {
return Poll::Pending;
}
match Pin::new(&mut self.stream)
.poll_close(cx)
.map_err(Error::io)?
@@ -321,40 +181,17 @@ where
self.parameters.get(name).map(|s| &**s)
}
/// Polls for asynchronous messages from the server.
///
/// The server can send notices as well as notifications asynchronously to the client. Applications that wish to
/// examine those messages should use this method to drive the connection rather than its `Future` implementation.
pub fn poll_message(&mut self, cx: &mut Context<'_>) -> Poll<Option<Result<(), Error>>> {
let message = self.poll_read(cx)?;
let want_flush = self.poll_write(cx)?;
if want_flush {
self.poll_flush(cx)?;
}
match message {
Some(message) => Poll::Ready(Some(Ok(message))),
None => match self.poll_shutdown(cx) {
Poll::Ready(Ok(())) => Poll::Ready(None),
Poll::Ready(Err(e)) => Poll::Ready(Some(Err(e))),
Poll::Pending => Poll::Pending,
},
}
fn poll_send(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
if let Poll::Ready(msg) = self.poll_read(cx)? {
self.pending_responses.push_back(msg);
};
self.stream.poll_ready_unpin(cx).map_err(Error::io)
}
fn poll_flush(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
if let Poll::Ready(msg) = self.poll_read(cx)? {
self.pending_responses.push_back(msg);
};
self.stream.poll_flush_unpin(cx).map_err(Error::io)
}
}
// impl<S, T> Future for Connection<S, T>
// where
// S: AsyncRead + AsyncWrite + Unpin,
// T: AsyncRead + AsyncWrite + Unpin,
// {
// type Output = Result<(), Error>;
// fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
// while let Some(message) = ready!(self.poll_message(cx)?) {
// if let AsyncMessage::Notice(notice) = message {
// info!("{}: {}", notice.severity(), notice.message());
// }
// }
// Poll::Ready(Ok(()))
// }
// }

View File

@@ -1,13 +1,14 @@
use std::{error, fmt, io};
use postgres_protocol::message::backend::{ErrorFields, ErrorResponseBody};
use tokio_postgres::error::{SqlState, ErrorPosition};
use fallible_iterator::FallibleIterator;
use postgres_protocol::message::backend::{ErrorFields, ErrorResponseBody};
use tokio_postgres::error::{ErrorPosition, SqlState};
#[derive(Debug, PartialEq)]
enum Kind {
Io,
UnexpectedMessage,
FromSql(usize),
Closed,
Db,
Parse,
@@ -36,6 +37,7 @@ impl fmt::Display for Error {
match &self.0.kind {
Kind::Io => fmt.write_str("error communicating with the server")?,
Kind::UnexpectedMessage => fmt.write_str("unexpected message from server")?,
Kind::FromSql(idx) => write!(fmt, "error deserializing column {}", idx)?,
Kind::Closed => fmt.write_str("connection closed")?,
Kind::Db => fmt.write_str("db error")?,
Kind::Parse => fmt.write_str("error parsing response from server")?,
@@ -91,6 +93,10 @@ impl Error {
}
}
pub(crate) fn from_sql(e: Box<dyn error::Error + Sync + Send>, idx: usize) -> Error {
Error::new(Kind::FromSql(idx), Some(e))
}
pub(crate) fn closed() -> Error {
Error::new(Kind::Closed, None)
}