mirror of
https://github.com/neondatabase/neon.git
synced 2026-01-15 09:22:55 +00:00
refactor connections
This commit is contained in:
@@ -2,6 +2,7 @@ use std::io::ErrorKind;
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::bail;
|
||||
use bytes::BufMut;
|
||||
use fallible_iterator::FallibleIterator;
|
||||
use futures::pin_mut;
|
||||
use futures::StreamExt;
|
||||
@@ -10,9 +11,12 @@ use hyper::body::HttpBody;
|
||||
use hyper::http::HeaderName;
|
||||
use hyper::http::HeaderValue;
|
||||
use hyper::{Body, HeaderMap, Request};
|
||||
use postgres_protocol::message::backend::DataRowBody;
|
||||
use serde_json::json;
|
||||
use serde_json::Map;
|
||||
use serde_json::Value;
|
||||
use tokio::io::AsyncRead;
|
||||
use tokio::io::AsyncWrite;
|
||||
use tokio_postgres::types::Kind;
|
||||
use tokio_postgres::types::Type;
|
||||
use tokio_postgres::GenericClient;
|
||||
@@ -364,21 +368,25 @@ async fn query_to_json<T: GenericClient>(
|
||||
|
||||
/// Pass text directly to the Postgres backend to allow it to sort out typing itself and
|
||||
/// to save a roundtrip
|
||||
pub async fn query_raw_txt<'a, S, I>(&self, query: S, params: I) -> Result<RowStream, error::Error>
|
||||
async fn query_raw_txt<'a, St, T>(
|
||||
conn: &mut connection::Connection<St, T>,
|
||||
query: String,
|
||||
params: Vec<Option<String>>,
|
||||
) -> Result<Vec<Column>, error::Error>
|
||||
where
|
||||
S: AsRef<str>,
|
||||
I: IntoIterator<Item = Option<S>>,
|
||||
I::IntoIter: ExactSizeIterator,
|
||||
St: AsyncRead + AsyncWrite + Unpin,
|
||||
T: AsyncRead + AsyncWrite + Unpin,
|
||||
{
|
||||
use postgres_protocol::message::backend::Message;
|
||||
use postgres_protocol::message::frontend;
|
||||
|
||||
let params = params.into_iter();
|
||||
let params_len = params.len();
|
||||
let params = params.into_iter();
|
||||
|
||||
let buf = self.inner.with_buf(|buf| {
|
||||
{
|
||||
let buf = &mut conn.buf;
|
||||
// Parse, anonymous portal
|
||||
frontend::parse("", query.as_ref(), std::iter::empty(), buf)
|
||||
frontend::parse("", query.as_str(), std::iter::empty(), buf)
|
||||
.map_err(error::Error::encode)?;
|
||||
// Bind, pass params as text, retrieve as binary
|
||||
match frontend::bind(
|
||||
@@ -388,7 +396,7 @@ where
|
||||
params,
|
||||
|param, buf| match param {
|
||||
Some(param) => {
|
||||
buf.put_slice(param.as_ref().as_bytes());
|
||||
buf.put_slice(param.as_bytes());
|
||||
Ok(postgres_protocol::IsNull::No)
|
||||
}
|
||||
None => Ok(postgres_protocol::IsNull::Yes),
|
||||
@@ -409,49 +417,48 @@ where
|
||||
frontend::execute("", 0, buf).map_err(error::Error::encode)?;
|
||||
// Sync
|
||||
frontend::sync(buf);
|
||||
}
|
||||
|
||||
Ok(buf.split().freeze())
|
||||
})?;
|
||||
|
||||
let mut responses = self
|
||||
.inner
|
||||
.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?;
|
||||
conn.send().await?;
|
||||
conn.flush().await?;
|
||||
|
||||
// now read the responses
|
||||
|
||||
match responses.next().await? {
|
||||
match conn.next_message().await? {
|
||||
Message::ParseComplete => {}
|
||||
_ => return Err(error::Error::unexpected_message()),
|
||||
}
|
||||
match responses.next().await? {
|
||||
match conn.next_message().await? {
|
||||
Message::BindComplete => {}
|
||||
_ => return Err(error::Error::unexpected_message()),
|
||||
}
|
||||
let row_description = match responses.next().await? {
|
||||
let row_description = match conn.next_message().await? {
|
||||
Message::RowDescription(body) => Some(body),
|
||||
Message::NoData => None,
|
||||
_ => return Err(error::Error::unexpected_message()),
|
||||
};
|
||||
|
||||
// construct statement object
|
||||
|
||||
let parameters = vec![Type::UNKNOWN; params_len];
|
||||
|
||||
let mut columns = vec![];
|
||||
if let Some(row_description) = row_description {
|
||||
let mut it = row_description.fields();
|
||||
while let Some(field) = it.next().map_err(Error::parse)? {
|
||||
// NB: for some types that function may send a query to the server. At least in
|
||||
// raw text mode we don't need that info and can skip this.
|
||||
let type_ = get_type(&self.inner, field.type_oid()).await?;
|
||||
let column = Column::new(field.name().to_string(), type_, field);
|
||||
columns.push(column);
|
||||
while let Some(field) = it.next().map_err(error::Error::parse)? {
|
||||
let type_ = Type::from_oid(field.type_oid());
|
||||
// let column = Column::new(field.name().to_string(), type_, field);
|
||||
columns.push(Column {
|
||||
name: field.name().to_string(),
|
||||
type_,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
let statement = Statement::new_text(&self.inner, "".to_owned(), parameters, columns);
|
||||
// let statement = Statement::new_text(&self.inner, "".to_owned(), parameters, columns);
|
||||
|
||||
Ok(RowStream::new(statement, responses))
|
||||
Ok(columns)
|
||||
}
|
||||
|
||||
struct Column {
|
||||
name: String,
|
||||
type_: Option<Type>,
|
||||
}
|
||||
|
||||
//
|
||||
@@ -471,9 +478,9 @@ pub fn pg_text_row_to_json(
|
||||
None => Value::Null,
|
||||
}
|
||||
} else {
|
||||
pg_text_to_json(pg_value, column.type_())?
|
||||
pg_text_to_json(pg_value, Some(column.type_()))?
|
||||
};
|
||||
Ok((name.to_string(), json_value))
|
||||
Ok((name, json_value))
|
||||
});
|
||||
|
||||
if array_mode {
|
||||
@@ -483,7 +490,55 @@ pub fn pg_text_row_to_json(
|
||||
.collect::<Result<Vec<Value>, anyhow::Error>>()?;
|
||||
Ok(Value::Array(arr))
|
||||
} else {
|
||||
let obj = iter.collect::<Result<Map<String, Value>, anyhow::Error>>()?;
|
||||
let obj = iter
|
||||
.map(|r| r.map(|(key, val)| (key.to_owned(), val)))
|
||||
.collect::<Result<Map<String, Value>, anyhow::Error>>()?;
|
||||
Ok(Value::Object(obj))
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
// Convert postgres row with text-encoded values to JSON object
|
||||
//
|
||||
fn pg_text_row_to_json2(
|
||||
row: &DataRowBody,
|
||||
columns: &[Column],
|
||||
raw_output: bool,
|
||||
array_mode: bool,
|
||||
) -> Result<Value, anyhow::Error> {
|
||||
let ranges: Vec<Option<std::ops::Range<usize>>> = row.ranges().collect()?;
|
||||
let iter = std::iter::zip(ranges, columns)
|
||||
.enumerate()
|
||||
.map(|(i, (range, column))| {
|
||||
let name = &column.name;
|
||||
let pg_value = range
|
||||
.map(|r| {
|
||||
std::str::from_utf8(&row.buffer()[r])
|
||||
.map_err(|e| error::Error::from_sql(e.into(), i))
|
||||
})
|
||||
.transpose()?;
|
||||
// let pg_value = row.as_text(i)?;
|
||||
let json_value = if raw_output {
|
||||
match pg_value {
|
||||
Some(v) => Value::String(v.to_string()),
|
||||
None => Value::Null,
|
||||
}
|
||||
} else {
|
||||
pg_text_to_json(pg_value, column.type_.as_ref())?
|
||||
};
|
||||
Ok((name, json_value))
|
||||
});
|
||||
|
||||
if array_mode {
|
||||
// drop keys and aggregate into array
|
||||
let arr = iter
|
||||
.map(|r| r.map(|(_key, val)| val))
|
||||
.collect::<Result<Vec<Value>, anyhow::Error>>()?;
|
||||
Ok(Value::Array(arr))
|
||||
} else {
|
||||
let obj = iter
|
||||
.map(|r| r.map(|(key, val)| (key.to_owned(), val)))
|
||||
.collect::<Result<Map<String, Value>, anyhow::Error>>()?;
|
||||
Ok(Value::Object(obj))
|
||||
}
|
||||
}
|
||||
@@ -491,19 +546,22 @@ pub fn pg_text_row_to_json(
|
||||
//
|
||||
// Convert postgres text-encoded value to JSON value
|
||||
//
|
||||
pub fn pg_text_to_json(pg_value: Option<&str>, pg_type: &Type) -> Result<Value, anyhow::Error> {
|
||||
pub fn pg_text_to_json(
|
||||
pg_value: Option<&str>,
|
||||
pg_type: Option<&Type>,
|
||||
) -> Result<Value, anyhow::Error> {
|
||||
if let Some(val) = pg_value {
|
||||
if let Kind::Array(elem_type) = pg_type.kind() {
|
||||
return pg_array_parse(val, elem_type);
|
||||
if let Some(Kind::Array(elem_type)) = pg_type.map(|t| t.kind()) {
|
||||
return pg_array_parse(val, Some(elem_type));
|
||||
}
|
||||
|
||||
match *pg_type {
|
||||
Type::BOOL => Ok(Value::Bool(val == "t")),
|
||||
Type::INT2 | Type::INT4 => {
|
||||
match pg_type {
|
||||
Some(&Type::BOOL) => Ok(Value::Bool(val == "t")),
|
||||
Some(&Type::INT2 | &Type::INT4) => {
|
||||
let val = val.parse::<i32>()?;
|
||||
Ok(Value::Number(serde_json::Number::from(val)))
|
||||
}
|
||||
Type::FLOAT4 | Type::FLOAT8 => {
|
||||
Some(&Type::FLOAT4 | &Type::FLOAT8) => {
|
||||
let fval = val.parse::<f64>()?;
|
||||
let num = serde_json::Number::from_f64(fval);
|
||||
if let Some(num) = num {
|
||||
@@ -515,7 +573,7 @@ pub fn pg_text_to_json(pg_value: Option<&str>, pg_type: &Type) -> Result<Value,
|
||||
Ok(Value::String(val.to_string()))
|
||||
}
|
||||
}
|
||||
Type::JSON | Type::JSONB => Ok(serde_json::from_str(val)?),
|
||||
Some(&Type::JSON | &Type::JSONB) => Ok(serde_json::from_str(val)?),
|
||||
_ => Ok(Value::String(val.to_string())),
|
||||
}
|
||||
} else {
|
||||
@@ -530,13 +588,13 @@ pub fn pg_text_to_json(pg_value: Option<&str>, pg_type: &Type) -> Result<Value,
|
||||
// values. Unlike postgres we don't check that all nested arrays have the same
|
||||
// dimensions, we just return them as is.
|
||||
//
|
||||
fn pg_array_parse(pg_array: &str, elem_type: &Type) -> Result<Value, anyhow::Error> {
|
||||
fn pg_array_parse(pg_array: &str, elem_type: Option<&Type>) -> Result<Value, anyhow::Error> {
|
||||
_pg_array_parse(pg_array, elem_type, false).map(|(v, _)| v)
|
||||
}
|
||||
|
||||
fn _pg_array_parse(
|
||||
pg_array: &str,
|
||||
elem_type: &Type,
|
||||
elem_type: Option<&Type>,
|
||||
nested: bool,
|
||||
) -> Result<(Value, usize), anyhow::Error> {
|
||||
let mut pg_array_chr = pg_array.char_indices();
|
||||
@@ -557,7 +615,7 @@ fn _pg_array_parse(
|
||||
fn push_checked(
|
||||
entry: &mut String,
|
||||
entries: &mut Vec<Value>,
|
||||
elem_type: &Type,
|
||||
elem_type: Option<&Type>,
|
||||
) -> Result<(), anyhow::Error> {
|
||||
if !entry.is_empty() {
|
||||
// While in usual postgres response we get nulls as None and everything else
|
||||
@@ -683,34 +741,43 @@ mod tests {
|
||||
#[test]
|
||||
fn test_atomic_types_parse() {
|
||||
assert_eq!(
|
||||
pg_text_to_json(Some("foo"), &Type::TEXT).unwrap(),
|
||||
pg_text_to_json(Some("foo"), Some(&Type::TEXT)).unwrap(),
|
||||
json!("foo")
|
||||
);
|
||||
assert_eq!(pg_text_to_json(None, &Type::TEXT).unwrap(), json!(null));
|
||||
assert_eq!(pg_text_to_json(Some("42"), &Type::INT4).unwrap(), json!(42));
|
||||
assert_eq!(pg_text_to_json(Some("42"), &Type::INT2).unwrap(), json!(42));
|
||||
assert_eq!(
|
||||
pg_text_to_json(Some("42"), &Type::INT8).unwrap(),
|
||||
pg_text_to_json(None, Some(&Type::TEXT)).unwrap(),
|
||||
json!(null)
|
||||
);
|
||||
assert_eq!(
|
||||
pg_text_to_json(Some("42"), Some(&Type::INT4)).unwrap(),
|
||||
json!(42)
|
||||
);
|
||||
assert_eq!(
|
||||
pg_text_to_json(Some("42"), Some(&Type::INT2)).unwrap(),
|
||||
json!(42)
|
||||
);
|
||||
assert_eq!(
|
||||
pg_text_to_json(Some("42"), Some(&Type::INT8)).unwrap(),
|
||||
json!("42")
|
||||
);
|
||||
assert_eq!(
|
||||
pg_text_to_json(Some("42.42"), &Type::FLOAT8).unwrap(),
|
||||
pg_text_to_json(Some("42.42"), Some(&Type::FLOAT8)).unwrap(),
|
||||
json!(42.42)
|
||||
);
|
||||
assert_eq!(
|
||||
pg_text_to_json(Some("42.42"), &Type::FLOAT4).unwrap(),
|
||||
pg_text_to_json(Some("42.42"), Some(&Type::FLOAT4)).unwrap(),
|
||||
json!(42.42)
|
||||
);
|
||||
assert_eq!(
|
||||
pg_text_to_json(Some("NaN"), &Type::FLOAT4).unwrap(),
|
||||
pg_text_to_json(Some("NaN"), Some(&Type::FLOAT4)).unwrap(),
|
||||
json!("NaN")
|
||||
);
|
||||
assert_eq!(
|
||||
pg_text_to_json(Some("Infinity"), &Type::FLOAT4).unwrap(),
|
||||
pg_text_to_json(Some("Infinity"), Some(&Type::FLOAT4)).unwrap(),
|
||||
json!("Infinity")
|
||||
);
|
||||
assert_eq!(
|
||||
pg_text_to_json(Some("-Infinity"), &Type::FLOAT4).unwrap(),
|
||||
pg_text_to_json(Some("-Infinity"), Some(&Type::FLOAT4)).unwrap(),
|
||||
json!("-Infinity")
|
||||
);
|
||||
|
||||
@@ -720,7 +787,7 @@ mod tests {
|
||||
assert_eq!(
|
||||
pg_text_to_json(
|
||||
Some(r#"{"s":"str","n":42,"f":4.2,"a":[null,3,"a"]}"#),
|
||||
&Type::JSONB
|
||||
Some(&Type::JSONB)
|
||||
)
|
||||
.unwrap(),
|
||||
json
|
||||
@@ -730,7 +797,7 @@ mod tests {
|
||||
#[test]
|
||||
fn test_pg_array_parse_text() {
|
||||
fn pt(pg_arr: &str) -> Value {
|
||||
pg_array_parse(pg_arr, &Type::TEXT).unwrap()
|
||||
pg_array_parse(pg_arr, Some(&Type::TEXT)).unwrap()
|
||||
}
|
||||
assert_eq!(
|
||||
pt(r#"{"aa\"\\\,a",cha,"bbbb"}"#),
|
||||
@@ -753,7 +820,7 @@ mod tests {
|
||||
#[test]
|
||||
fn test_pg_array_parse_bool() {
|
||||
fn pb(pg_arr: &str) -> Value {
|
||||
pg_array_parse(pg_arr, &Type::BOOL).unwrap()
|
||||
pg_array_parse(pg_arr, Some(&Type::BOOL)).unwrap()
|
||||
}
|
||||
assert_eq!(pb(r#"{t,f,t}"#), json!([true, false, true]));
|
||||
assert_eq!(pb(r#"{{t,f,t}}"#), json!([[true, false, true]]));
|
||||
@@ -770,7 +837,7 @@ mod tests {
|
||||
#[test]
|
||||
fn test_pg_array_parse_numbers() {
|
||||
fn pn(pg_arr: &str, ty: &Type) -> Value {
|
||||
pg_array_parse(pg_arr, ty).unwrap()
|
||||
pg_array_parse(pg_arr, Some(ty)).unwrap()
|
||||
}
|
||||
assert_eq!(pn(r#"{1,2,3}"#, &Type::INT4), json!([1, 2, 3]));
|
||||
assert_eq!(pn(r#"{1,2,3}"#, &Type::INT2), json!([1, 2, 3]));
|
||||
@@ -798,7 +865,7 @@ mod tests {
|
||||
#[test]
|
||||
fn test_pg_array_with_decoration() {
|
||||
fn p(pg_arr: &str) -> Value {
|
||||
pg_array_parse(pg_arr, &Type::INT2).unwrap()
|
||||
pg_array_parse(pg_arr, Some(&Type::INT2)).unwrap()
|
||||
}
|
||||
assert_eq!(
|
||||
p(r#"[1:1][-2:-1][3:5]={{{1,2,3},{4,5,6}}}"#),
|
||||
|
||||
@@ -17,7 +17,7 @@ pub enum BackendMessage {
|
||||
Async(backend::Message),
|
||||
}
|
||||
|
||||
pub struct BackendMessages(BytesMut);
|
||||
pub struct BackendMessages(pub BytesMut);
|
||||
|
||||
impl BackendMessages {
|
||||
pub fn empty() -> BackendMessages {
|
||||
|
||||
@@ -3,12 +3,13 @@ use super::error::Error;
|
||||
use bytes::BytesMut;
|
||||
use fallible_iterator::FallibleIterator;
|
||||
use futures::channel::mpsc;
|
||||
use futures::{stream::FusedStream, Sink, Stream, StreamExt};
|
||||
use futures::SinkExt;
|
||||
use futures::{Sink, StreamExt};
|
||||
use postgres_protocol::message::backend::Message;
|
||||
use postgres_protocol::message::frontend;
|
||||
use std::collections::{HashMap, VecDeque};
|
||||
use std::future::poll_fn;
|
||||
use std::pin::Pin;
|
||||
use std::task::{Context, Poll};
|
||||
use std::task::{ready, Context, Poll};
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tokio_postgres::maybe_tls_stream::MaybeTlsStream;
|
||||
use tokio_util::codec::Framed;
|
||||
@@ -16,8 +17,6 @@ use tracing::trace;
|
||||
|
||||
pub enum RequestMessages {
|
||||
Single(FrontendMessage),
|
||||
// CopyIn(CopyInReceiver),
|
||||
// CopyBoth(CopyBothReceiver),
|
||||
}
|
||||
|
||||
pub struct Request {
|
||||
@@ -29,12 +28,12 @@ pub struct Response {
|
||||
sender: mpsc::Sender<BackendMessages>,
|
||||
}
|
||||
|
||||
#[derive(PartialEq, Debug)]
|
||||
enum State {
|
||||
Active,
|
||||
Terminating,
|
||||
Closing,
|
||||
}
|
||||
// #[derive(PartialEq, Debug)]
|
||||
// enum State {
|
||||
// Active,
|
||||
// Terminating,
|
||||
// Closing,
|
||||
// }
|
||||
|
||||
/// A connection to a PostgreSQL database.
|
||||
///
|
||||
@@ -49,11 +48,12 @@ pub struct Connection<S, T> {
|
||||
pub stream: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
|
||||
/// HACK: we need this in the Neon Proxy to forward params.
|
||||
pub parameters: HashMap<String, String>,
|
||||
receiver: mpsc::UnboundedReceiver<Request>,
|
||||
// receiver: mpsc::UnboundedReceiver<Request>,
|
||||
pending_request: Option<RequestMessages>,
|
||||
pending_responses: VecDeque<BackendMessage>,
|
||||
responses: VecDeque<Response>,
|
||||
state: State,
|
||||
pending_responses: VecDeque<(BackendMessages, bool)>,
|
||||
pub buf: BytesMut,
|
||||
// responses: VecDeque<Response>,
|
||||
// state: State,
|
||||
}
|
||||
|
||||
impl<S, T> Connection<S, T>
|
||||
@@ -63,18 +63,49 @@ where
|
||||
{
|
||||
pub(crate) fn new(
|
||||
stream: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
|
||||
pending_responses: VecDeque<BackendMessage>,
|
||||
pending_responses: VecDeque<(BackendMessages, bool)>,
|
||||
parameters: HashMap<String, String>,
|
||||
receiver: mpsc::UnboundedReceiver<Request>,
|
||||
// receiver: mpsc::UnboundedReceiver<Request>,
|
||||
) -> Connection<S, T> {
|
||||
Connection {
|
||||
stream,
|
||||
parameters,
|
||||
receiver,
|
||||
// receiver,
|
||||
pending_request: None,
|
||||
pending_responses,
|
||||
responses: VecDeque::new(),
|
||||
state: State::Active,
|
||||
buf: BytesMut::new(),
|
||||
// responses: VecDeque::new(),
|
||||
// state: State::Active,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn send(&mut self) -> Result<(), Error> {
|
||||
poll_fn(|cx| self.poll_send(cx)).await?;
|
||||
let request = FrontendMessage::Raw(self.buf.split().freeze());
|
||||
self.stream.start_send_unpin(request).map_err(Error::io)
|
||||
}
|
||||
|
||||
pub async fn flush(&mut self) -> Result<(), Error> {
|
||||
poll_fn(|cx| self.poll_flush(cx)).await
|
||||
}
|
||||
|
||||
pub async fn next_response(&mut self) -> Result<(BackendMessages, bool), Error> {
|
||||
match self.pending_responses.pop_front() {
|
||||
Some((a, b)) => Ok((a, b)),
|
||||
None => poll_fn(|cx| self.poll_read(cx)).await,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn next_message(&mut self) -> Result<Message, Error> {
|
||||
loop {
|
||||
let (mut messages, complete) = self.next_response().await?;
|
||||
if let Some(message) = messages.next().map_err(Error::parse)? {
|
||||
self.pending_responses.push_front((messages, complete));
|
||||
break Ok(message);
|
||||
}
|
||||
if complete {
|
||||
break Err(Error::unexpected_message());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -82,33 +113,19 @@ where
|
||||
&mut self,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<Option<Result<BackendMessage, Error>>> {
|
||||
if let Some(message) = self.pending_responses.pop_front() {
|
||||
trace!("retrying pending response");
|
||||
return Poll::Ready(Some(Ok(message)));
|
||||
}
|
||||
|
||||
Pin::new(&mut self.stream)
|
||||
.poll_next(cx)
|
||||
self.stream
|
||||
.poll_next_unpin(cx)
|
||||
.map(|o| o.map(|r| r.map_err(Error::io)))
|
||||
}
|
||||
|
||||
fn poll_read(&mut self, cx: &mut Context<'_>) -> Result<Option<()>, Error> {
|
||||
if self.state != State::Active {
|
||||
trace!("poll_read: done");
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
fn poll_read(&mut self, cx: &mut Context<'_>) -> Poll<Result<(BackendMessages, bool), Error>> {
|
||||
loop {
|
||||
let message = match self.poll_response(cx)? {
|
||||
Poll::Ready(Some(message)) => message,
|
||||
Poll::Ready(None) => return Err(Error::closed()),
|
||||
Poll::Pending => {
|
||||
trace!("poll_read: waiting on response");
|
||||
return Ok(None);
|
||||
}
|
||||
let message = match ready!(self.poll_response(cx)?) {
|
||||
Some(message) => message,
|
||||
None => return Poll::Ready(Err(Error::closed())),
|
||||
};
|
||||
|
||||
let (mut messages, request_complete) = match message {
|
||||
match message {
|
||||
BackendMessage::Async(Message::NoticeResponse(body)) => {
|
||||
// TODO: log this
|
||||
|
||||
@@ -138,169 +155,12 @@ where
|
||||
BackendMessage::Normal {
|
||||
messages,
|
||||
request_complete,
|
||||
} => (messages, request_complete),
|
||||
} => return Poll::Ready(Ok((messages, request_complete))),
|
||||
};
|
||||
|
||||
let mut response = match self.responses.pop_front() {
|
||||
Some(response) => response,
|
||||
None => match messages.next().map_err(Error::parse)? {
|
||||
Some(Message::ErrorResponse(error)) => return Err(Error::db(error)),
|
||||
_ => return Err(Error::unexpected_message()),
|
||||
},
|
||||
};
|
||||
|
||||
match response.sender.poll_ready(cx) {
|
||||
Poll::Ready(Ok(())) => {
|
||||
let _ = response.sender.start_send(messages);
|
||||
if !request_complete {
|
||||
self.responses.push_front(response);
|
||||
}
|
||||
}
|
||||
Poll::Ready(Err(_)) => {
|
||||
// we need to keep paging through the rest of the messages even if the receiver's hung up
|
||||
if !request_complete {
|
||||
self.responses.push_front(response);
|
||||
}
|
||||
}
|
||||
Poll::Pending => {
|
||||
self.responses.push_front(response);
|
||||
self.pending_responses.push_back(BackendMessage::Normal {
|
||||
messages,
|
||||
request_complete,
|
||||
});
|
||||
trace!("poll_read: waiting on sender");
|
||||
return Ok(None);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_request(&mut self, cx: &mut Context<'_>) -> Poll<Option<RequestMessages>> {
|
||||
if let Some(messages) = self.pending_request.take() {
|
||||
trace!("retrying pending request");
|
||||
return Poll::Ready(Some(messages));
|
||||
}
|
||||
|
||||
if self.receiver.is_terminated() {
|
||||
return Poll::Ready(None);
|
||||
}
|
||||
|
||||
match self.receiver.poll_next_unpin(cx) {
|
||||
Poll::Ready(Some(request)) => {
|
||||
trace!("polled new request");
|
||||
self.responses.push_back(Response {
|
||||
sender: request.sender,
|
||||
});
|
||||
Poll::Ready(Some(request.messages))
|
||||
}
|
||||
Poll::Ready(None) => Poll::Ready(None),
|
||||
Poll::Pending => Poll::Pending,
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_write(&mut self, cx: &mut Context<'_>) -> Result<bool, Error> {
|
||||
loop {
|
||||
if self.state == State::Closing {
|
||||
trace!("poll_write: done");
|
||||
return Ok(false);
|
||||
}
|
||||
|
||||
if Pin::new(&mut self.stream)
|
||||
.poll_ready(cx)
|
||||
.map_err(Error::io)?
|
||||
.is_pending()
|
||||
{
|
||||
trace!("poll_write: waiting on socket");
|
||||
return Ok(false);
|
||||
}
|
||||
|
||||
let request = match self.poll_request(cx) {
|
||||
Poll::Ready(Some(request)) => request,
|
||||
Poll::Ready(None) if self.responses.is_empty() && self.state == State::Active => {
|
||||
trace!("poll_write: at eof, terminating");
|
||||
self.state = State::Terminating;
|
||||
let mut request = BytesMut::new();
|
||||
frontend::terminate(&mut request);
|
||||
RequestMessages::Single(FrontendMessage::Raw(request.freeze()))
|
||||
}
|
||||
Poll::Ready(None) => {
|
||||
trace!(
|
||||
"poll_write: at eof, pending responses {}",
|
||||
self.responses.len()
|
||||
);
|
||||
return Ok(true);
|
||||
}
|
||||
Poll::Pending => {
|
||||
trace!("poll_write: waiting on request");
|
||||
return Ok(true);
|
||||
}
|
||||
};
|
||||
|
||||
match request {
|
||||
RequestMessages::Single(request) => {
|
||||
Pin::new(&mut self.stream)
|
||||
.start_send(request)
|
||||
.map_err(Error::io)?;
|
||||
if self.state == State::Terminating {
|
||||
trace!("poll_write: sent eof, closing");
|
||||
self.state = State::Closing;
|
||||
}
|
||||
} // RequestMessages::CopyIn(mut receiver) => {
|
||||
// let message = match receiver.poll_next_unpin(cx) {
|
||||
// Poll::Ready(Some(message)) => message,
|
||||
// Poll::Ready(None) => {
|
||||
// trace!("poll_write: finished copy_in request");
|
||||
// continue;
|
||||
// }
|
||||
// Poll::Pending => {
|
||||
// trace!("poll_write: waiting on copy_in stream");
|
||||
// self.pending_request = Some(RequestMessages::CopyIn(receiver));
|
||||
// return Ok(true);
|
||||
// }
|
||||
// };
|
||||
// Pin::new(&mut self.stream)
|
||||
// .start_send(message)
|
||||
// .map_err(Error::io)?;
|
||||
// self.pending_request = Some(RequestMessages::CopyIn(receiver));
|
||||
// }
|
||||
// RequestMessages::CopyBoth(mut receiver) => {
|
||||
// let message = match receiver.poll_next_unpin(cx) {
|
||||
// Poll::Ready(Some(message)) => message,
|
||||
// Poll::Ready(None) => {
|
||||
// trace!("poll_write: finished copy_both request");
|
||||
// continue;
|
||||
// }
|
||||
// Poll::Pending => {
|
||||
// trace!("poll_write: waiting on copy_both stream");
|
||||
// self.pending_request = Some(RequestMessages::CopyBoth(receiver));
|
||||
// return Ok(true);
|
||||
// }
|
||||
// };
|
||||
// Pin::new(&mut self.stream)
|
||||
// .start_send(message)
|
||||
// .map_err(Error::io)?;
|
||||
// self.pending_request = Some(RequestMessages::CopyBoth(receiver));
|
||||
// }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_flush(&mut self, cx: &mut Context<'_>) -> Result<(), Error> {
|
||||
match Pin::new(&mut self.stream)
|
||||
.poll_flush(cx)
|
||||
.map_err(Error::io)?
|
||||
{
|
||||
Poll::Ready(()) => trace!("poll_flush: flushed"),
|
||||
Poll::Pending => trace!("poll_flush: waiting on socket"),
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn poll_shutdown(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
|
||||
if self.state != State::Closing {
|
||||
return Poll::Pending;
|
||||
}
|
||||
|
||||
match Pin::new(&mut self.stream)
|
||||
.poll_close(cx)
|
||||
.map_err(Error::io)?
|
||||
@@ -321,40 +181,17 @@ where
|
||||
self.parameters.get(name).map(|s| &**s)
|
||||
}
|
||||
|
||||
/// Polls for asynchronous messages from the server.
|
||||
///
|
||||
/// The server can send notices as well as notifications asynchronously to the client. Applications that wish to
|
||||
/// examine those messages should use this method to drive the connection rather than its `Future` implementation.
|
||||
pub fn poll_message(&mut self, cx: &mut Context<'_>) -> Poll<Option<Result<(), Error>>> {
|
||||
let message = self.poll_read(cx)?;
|
||||
let want_flush = self.poll_write(cx)?;
|
||||
if want_flush {
|
||||
self.poll_flush(cx)?;
|
||||
}
|
||||
match message {
|
||||
Some(message) => Poll::Ready(Some(Ok(message))),
|
||||
None => match self.poll_shutdown(cx) {
|
||||
Poll::Ready(Ok(())) => Poll::Ready(None),
|
||||
Poll::Ready(Err(e)) => Poll::Ready(Some(Err(e))),
|
||||
Poll::Pending => Poll::Pending,
|
||||
},
|
||||
}
|
||||
fn poll_send(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
|
||||
if let Poll::Ready(msg) = self.poll_read(cx)? {
|
||||
self.pending_responses.push_back(msg);
|
||||
};
|
||||
self.stream.poll_ready_unpin(cx).map_err(Error::io)
|
||||
}
|
||||
|
||||
fn poll_flush(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
|
||||
if let Poll::Ready(msg) = self.poll_read(cx)? {
|
||||
self.pending_responses.push_back(msg);
|
||||
};
|
||||
self.stream.poll_flush_unpin(cx).map_err(Error::io)
|
||||
}
|
||||
}
|
||||
|
||||
// impl<S, T> Future for Connection<S, T>
|
||||
// where
|
||||
// S: AsyncRead + AsyncWrite + Unpin,
|
||||
// T: AsyncRead + AsyncWrite + Unpin,
|
||||
// {
|
||||
// type Output = Result<(), Error>;
|
||||
|
||||
// fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
|
||||
// while let Some(message) = ready!(self.poll_message(cx)?) {
|
||||
// if let AsyncMessage::Notice(notice) = message {
|
||||
// info!("{}: {}", notice.severity(), notice.message());
|
||||
// }
|
||||
// }
|
||||
// Poll::Ready(Ok(()))
|
||||
// }
|
||||
// }
|
||||
|
||||
@@ -1,13 +1,14 @@
|
||||
use std::{error, fmt, io};
|
||||
|
||||
use postgres_protocol::message::backend::{ErrorFields, ErrorResponseBody};
|
||||
use tokio_postgres::error::{SqlState, ErrorPosition};
|
||||
use fallible_iterator::FallibleIterator;
|
||||
use postgres_protocol::message::backend::{ErrorFields, ErrorResponseBody};
|
||||
use tokio_postgres::error::{ErrorPosition, SqlState};
|
||||
|
||||
#[derive(Debug, PartialEq)]
|
||||
enum Kind {
|
||||
Io,
|
||||
UnexpectedMessage,
|
||||
FromSql(usize),
|
||||
Closed,
|
||||
Db,
|
||||
Parse,
|
||||
@@ -36,6 +37,7 @@ impl fmt::Display for Error {
|
||||
match &self.0.kind {
|
||||
Kind::Io => fmt.write_str("error communicating with the server")?,
|
||||
Kind::UnexpectedMessage => fmt.write_str("unexpected message from server")?,
|
||||
Kind::FromSql(idx) => write!(fmt, "error deserializing column {}", idx)?,
|
||||
Kind::Closed => fmt.write_str("connection closed")?,
|
||||
Kind::Db => fmt.write_str("db error")?,
|
||||
Kind::Parse => fmt.write_str("error parsing response from server")?,
|
||||
@@ -91,6 +93,10 @@ impl Error {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn from_sql(e: Box<dyn error::Error + Sync + Send>, idx: usize) -> Error {
|
||||
Error::new(Kind::FromSql(idx), Some(e))
|
||||
}
|
||||
|
||||
pub(crate) fn closed() -> Error {
|
||||
Error::new(Kind::Closed, None)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user