From 70b503f83b8b3fcb596f864bda465bebf78e1690 Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Fri, 21 Jul 2023 18:13:53 +0100 Subject: [PATCH] refactor connections --- proxy/src/http/sql_over_http.rs | 183 +++++++++---- proxy/src/http/sql_over_http/codec.rs | 2 +- proxy/src/http/sql_over_http/connection.rs | 305 +++++---------------- proxy/src/http/sql_over_http/error.rs | 10 +- 4 files changed, 205 insertions(+), 295 deletions(-) diff --git a/proxy/src/http/sql_over_http.rs b/proxy/src/http/sql_over_http.rs index 8a04f43516..c135e7c164 100644 --- a/proxy/src/http/sql_over_http.rs +++ b/proxy/src/http/sql_over_http.rs @@ -2,6 +2,7 @@ use std::io::ErrorKind; use std::sync::Arc; use anyhow::bail; +use bytes::BufMut; use fallible_iterator::FallibleIterator; use futures::pin_mut; use futures::StreamExt; @@ -10,9 +11,12 @@ use hyper::body::HttpBody; use hyper::http::HeaderName; use hyper::http::HeaderValue; use hyper::{Body, HeaderMap, Request}; +use postgres_protocol::message::backend::DataRowBody; use serde_json::json; use serde_json::Map; use serde_json::Value; +use tokio::io::AsyncRead; +use tokio::io::AsyncWrite; use tokio_postgres::types::Kind; use tokio_postgres::types::Type; use tokio_postgres::GenericClient; @@ -364,21 +368,25 @@ async fn query_to_json( /// Pass text directly to the Postgres backend to allow it to sort out typing itself and /// to save a roundtrip -pub async fn query_raw_txt<'a, S, I>(&self, query: S, params: I) -> Result +async fn query_raw_txt<'a, St, T>( + conn: &mut connection::Connection, + query: String, + params: Vec>, +) -> Result, error::Error> where - S: AsRef, - I: IntoIterator>, - I::IntoIter: ExactSizeIterator, + St: AsyncRead + AsyncWrite + Unpin, + T: AsyncRead + AsyncWrite + Unpin, { use postgres_protocol::message::backend::Message; use postgres_protocol::message::frontend; - let params = params.into_iter(); let params_len = params.len(); + let params = params.into_iter(); - let buf = self.inner.with_buf(|buf| { + { + let buf = &mut conn.buf; // Parse, anonymous portal - frontend::parse("", query.as_ref(), std::iter::empty(), buf) + frontend::parse("", query.as_str(), std::iter::empty(), buf) .map_err(error::Error::encode)?; // Bind, pass params as text, retrieve as binary match frontend::bind( @@ -388,7 +396,7 @@ where params, |param, buf| match param { Some(param) => { - buf.put_slice(param.as_ref().as_bytes()); + buf.put_slice(param.as_bytes()); Ok(postgres_protocol::IsNull::No) } None => Ok(postgres_protocol::IsNull::Yes), @@ -409,49 +417,48 @@ where frontend::execute("", 0, buf).map_err(error::Error::encode)?; // Sync frontend::sync(buf); + } - Ok(buf.split().freeze()) - })?; - - let mut responses = self - .inner - .send(RequestMessages::Single(FrontendMessage::Raw(buf)))?; + conn.send().await?; + conn.flush().await?; // now read the responses - match responses.next().await? { + match conn.next_message().await? { Message::ParseComplete => {} _ => return Err(error::Error::unexpected_message()), } - match responses.next().await? { + match conn.next_message().await? { Message::BindComplete => {} _ => return Err(error::Error::unexpected_message()), } - let row_description = match responses.next().await? { + let row_description = match conn.next_message().await? { Message::RowDescription(body) => Some(body), Message::NoData => None, _ => return Err(error::Error::unexpected_message()), }; - // construct statement object - - let parameters = vec![Type::UNKNOWN; params_len]; - let mut columns = vec![]; if let Some(row_description) = row_description { let mut it = row_description.fields(); - while let Some(field) = it.next().map_err(Error::parse)? { - // NB: for some types that function may send a query to the server. At least in - // raw text mode we don't need that info and can skip this. - let type_ = get_type(&self.inner, field.type_oid()).await?; - let column = Column::new(field.name().to_string(), type_, field); - columns.push(column); + while let Some(field) = it.next().map_err(error::Error::parse)? { + let type_ = Type::from_oid(field.type_oid()); + // let column = Column::new(field.name().to_string(), type_, field); + columns.push(Column { + name: field.name().to_string(), + type_, + }); } } - let statement = Statement::new_text(&self.inner, "".to_owned(), parameters, columns); + // let statement = Statement::new_text(&self.inner, "".to_owned(), parameters, columns); - Ok(RowStream::new(statement, responses)) + Ok(columns) +} + +struct Column { + name: String, + type_: Option, } // @@ -471,9 +478,9 @@ pub fn pg_text_row_to_json( None => Value::Null, } } else { - pg_text_to_json(pg_value, column.type_())? + pg_text_to_json(pg_value, Some(column.type_()))? }; - Ok((name.to_string(), json_value)) + Ok((name, json_value)) }); if array_mode { @@ -483,7 +490,55 @@ pub fn pg_text_row_to_json( .collect::, anyhow::Error>>()?; Ok(Value::Array(arr)) } else { - let obj = iter.collect::, anyhow::Error>>()?; + let obj = iter + .map(|r| r.map(|(key, val)| (key.to_owned(), val))) + .collect::, anyhow::Error>>()?; + Ok(Value::Object(obj)) + } +} + +// +// Convert postgres row with text-encoded values to JSON object +// +fn pg_text_row_to_json2( + row: &DataRowBody, + columns: &[Column], + raw_output: bool, + array_mode: bool, +) -> Result { + let ranges: Vec>> = row.ranges().collect()?; + let iter = std::iter::zip(ranges, columns) + .enumerate() + .map(|(i, (range, column))| { + let name = &column.name; + let pg_value = range + .map(|r| { + std::str::from_utf8(&row.buffer()[r]) + .map_err(|e| error::Error::from_sql(e.into(), i)) + }) + .transpose()?; + // let pg_value = row.as_text(i)?; + let json_value = if raw_output { + match pg_value { + Some(v) => Value::String(v.to_string()), + None => Value::Null, + } + } else { + pg_text_to_json(pg_value, column.type_.as_ref())? + }; + Ok((name, json_value)) + }); + + if array_mode { + // drop keys and aggregate into array + let arr = iter + .map(|r| r.map(|(_key, val)| val)) + .collect::, anyhow::Error>>()?; + Ok(Value::Array(arr)) + } else { + let obj = iter + .map(|r| r.map(|(key, val)| (key.to_owned(), val))) + .collect::, anyhow::Error>>()?; Ok(Value::Object(obj)) } } @@ -491,19 +546,22 @@ pub fn pg_text_row_to_json( // // Convert postgres text-encoded value to JSON value // -pub fn pg_text_to_json(pg_value: Option<&str>, pg_type: &Type) -> Result { +pub fn pg_text_to_json( + pg_value: Option<&str>, + pg_type: Option<&Type>, +) -> Result { if let Some(val) = pg_value { - if let Kind::Array(elem_type) = pg_type.kind() { - return pg_array_parse(val, elem_type); + if let Some(Kind::Array(elem_type)) = pg_type.map(|t| t.kind()) { + return pg_array_parse(val, Some(elem_type)); } - match *pg_type { - Type::BOOL => Ok(Value::Bool(val == "t")), - Type::INT2 | Type::INT4 => { + match pg_type { + Some(&Type::BOOL) => Ok(Value::Bool(val == "t")), + Some(&Type::INT2 | &Type::INT4) => { let val = val.parse::()?; Ok(Value::Number(serde_json::Number::from(val))) } - Type::FLOAT4 | Type::FLOAT8 => { + Some(&Type::FLOAT4 | &Type::FLOAT8) => { let fval = val.parse::()?; let num = serde_json::Number::from_f64(fval); if let Some(num) = num { @@ -515,7 +573,7 @@ pub fn pg_text_to_json(pg_value: Option<&str>, pg_type: &Type) -> Result Ok(serde_json::from_str(val)?), + Some(&Type::JSON | &Type::JSONB) => Ok(serde_json::from_str(val)?), _ => Ok(Value::String(val.to_string())), } } else { @@ -530,13 +588,13 @@ pub fn pg_text_to_json(pg_value: Option<&str>, pg_type: &Type) -> Result Result { +fn pg_array_parse(pg_array: &str, elem_type: Option<&Type>) -> Result { _pg_array_parse(pg_array, elem_type, false).map(|(v, _)| v) } fn _pg_array_parse( pg_array: &str, - elem_type: &Type, + elem_type: Option<&Type>, nested: bool, ) -> Result<(Value, usize), anyhow::Error> { let mut pg_array_chr = pg_array.char_indices(); @@ -557,7 +615,7 @@ fn _pg_array_parse( fn push_checked( entry: &mut String, entries: &mut Vec, - elem_type: &Type, + elem_type: Option<&Type>, ) -> Result<(), anyhow::Error> { if !entry.is_empty() { // While in usual postgres response we get nulls as None and everything else @@ -683,34 +741,43 @@ mod tests { #[test] fn test_atomic_types_parse() { assert_eq!( - pg_text_to_json(Some("foo"), &Type::TEXT).unwrap(), + pg_text_to_json(Some("foo"), Some(&Type::TEXT)).unwrap(), json!("foo") ); - assert_eq!(pg_text_to_json(None, &Type::TEXT).unwrap(), json!(null)); - assert_eq!(pg_text_to_json(Some("42"), &Type::INT4).unwrap(), json!(42)); - assert_eq!(pg_text_to_json(Some("42"), &Type::INT2).unwrap(), json!(42)); assert_eq!( - pg_text_to_json(Some("42"), &Type::INT8).unwrap(), + pg_text_to_json(None, Some(&Type::TEXT)).unwrap(), + json!(null) + ); + assert_eq!( + pg_text_to_json(Some("42"), Some(&Type::INT4)).unwrap(), + json!(42) + ); + assert_eq!( + pg_text_to_json(Some("42"), Some(&Type::INT2)).unwrap(), + json!(42) + ); + assert_eq!( + pg_text_to_json(Some("42"), Some(&Type::INT8)).unwrap(), json!("42") ); assert_eq!( - pg_text_to_json(Some("42.42"), &Type::FLOAT8).unwrap(), + pg_text_to_json(Some("42.42"), Some(&Type::FLOAT8)).unwrap(), json!(42.42) ); assert_eq!( - pg_text_to_json(Some("42.42"), &Type::FLOAT4).unwrap(), + pg_text_to_json(Some("42.42"), Some(&Type::FLOAT4)).unwrap(), json!(42.42) ); assert_eq!( - pg_text_to_json(Some("NaN"), &Type::FLOAT4).unwrap(), + pg_text_to_json(Some("NaN"), Some(&Type::FLOAT4)).unwrap(), json!("NaN") ); assert_eq!( - pg_text_to_json(Some("Infinity"), &Type::FLOAT4).unwrap(), + pg_text_to_json(Some("Infinity"), Some(&Type::FLOAT4)).unwrap(), json!("Infinity") ); assert_eq!( - pg_text_to_json(Some("-Infinity"), &Type::FLOAT4).unwrap(), + pg_text_to_json(Some("-Infinity"), Some(&Type::FLOAT4)).unwrap(), json!("-Infinity") ); @@ -720,7 +787,7 @@ mod tests { assert_eq!( pg_text_to_json( Some(r#"{"s":"str","n":42,"f":4.2,"a":[null,3,"a"]}"#), - &Type::JSONB + Some(&Type::JSONB) ) .unwrap(), json @@ -730,7 +797,7 @@ mod tests { #[test] fn test_pg_array_parse_text() { fn pt(pg_arr: &str) -> Value { - pg_array_parse(pg_arr, &Type::TEXT).unwrap() + pg_array_parse(pg_arr, Some(&Type::TEXT)).unwrap() } assert_eq!( pt(r#"{"aa\"\\\,a",cha,"bbbb"}"#), @@ -753,7 +820,7 @@ mod tests { #[test] fn test_pg_array_parse_bool() { fn pb(pg_arr: &str) -> Value { - pg_array_parse(pg_arr, &Type::BOOL).unwrap() + pg_array_parse(pg_arr, Some(&Type::BOOL)).unwrap() } assert_eq!(pb(r#"{t,f,t}"#), json!([true, false, true])); assert_eq!(pb(r#"{{t,f,t}}"#), json!([[true, false, true]])); @@ -770,7 +837,7 @@ mod tests { #[test] fn test_pg_array_parse_numbers() { fn pn(pg_arr: &str, ty: &Type) -> Value { - pg_array_parse(pg_arr, ty).unwrap() + pg_array_parse(pg_arr, Some(ty)).unwrap() } assert_eq!(pn(r#"{1,2,3}"#, &Type::INT4), json!([1, 2, 3])); assert_eq!(pn(r#"{1,2,3}"#, &Type::INT2), json!([1, 2, 3])); @@ -798,7 +865,7 @@ mod tests { #[test] fn test_pg_array_with_decoration() { fn p(pg_arr: &str) -> Value { - pg_array_parse(pg_arr, &Type::INT2).unwrap() + pg_array_parse(pg_arr, Some(&Type::INT2)).unwrap() } assert_eq!( p(r#"[1:1][-2:-1][3:5]={{{1,2,3},{4,5,6}}}"#), diff --git a/proxy/src/http/sql_over_http/codec.rs b/proxy/src/http/sql_over_http/codec.rs index fbfe7871bc..5419a89d2a 100644 --- a/proxy/src/http/sql_over_http/codec.rs +++ b/proxy/src/http/sql_over_http/codec.rs @@ -17,7 +17,7 @@ pub enum BackendMessage { Async(backend::Message), } -pub struct BackendMessages(BytesMut); +pub struct BackendMessages(pub BytesMut); impl BackendMessages { pub fn empty() -> BackendMessages { diff --git a/proxy/src/http/sql_over_http/connection.rs b/proxy/src/http/sql_over_http/connection.rs index 75636cffed..533b76c4b7 100644 --- a/proxy/src/http/sql_over_http/connection.rs +++ b/proxy/src/http/sql_over_http/connection.rs @@ -3,12 +3,13 @@ use super::error::Error; use bytes::BytesMut; use fallible_iterator::FallibleIterator; use futures::channel::mpsc; -use futures::{stream::FusedStream, Sink, Stream, StreamExt}; +use futures::SinkExt; +use futures::{Sink, StreamExt}; use postgres_protocol::message::backend::Message; -use postgres_protocol::message::frontend; use std::collections::{HashMap, VecDeque}; +use std::future::poll_fn; use std::pin::Pin; -use std::task::{Context, Poll}; +use std::task::{ready, Context, Poll}; use tokio::io::{AsyncRead, AsyncWrite}; use tokio_postgres::maybe_tls_stream::MaybeTlsStream; use tokio_util::codec::Framed; @@ -16,8 +17,6 @@ use tracing::trace; pub enum RequestMessages { Single(FrontendMessage), - // CopyIn(CopyInReceiver), - // CopyBoth(CopyBothReceiver), } pub struct Request { @@ -29,12 +28,12 @@ pub struct Response { sender: mpsc::Sender, } -#[derive(PartialEq, Debug)] -enum State { - Active, - Terminating, - Closing, -} +// #[derive(PartialEq, Debug)] +// enum State { +// Active, +// Terminating, +// Closing, +// } /// A connection to a PostgreSQL database. /// @@ -49,11 +48,12 @@ pub struct Connection { pub stream: Framed, PostgresCodec>, /// HACK: we need this in the Neon Proxy to forward params. pub parameters: HashMap, - receiver: mpsc::UnboundedReceiver, + // receiver: mpsc::UnboundedReceiver, pending_request: Option, - pending_responses: VecDeque, - responses: VecDeque, - state: State, + pending_responses: VecDeque<(BackendMessages, bool)>, + pub buf: BytesMut, + // responses: VecDeque, + // state: State, } impl Connection @@ -63,18 +63,49 @@ where { pub(crate) fn new( stream: Framed, PostgresCodec>, - pending_responses: VecDeque, + pending_responses: VecDeque<(BackendMessages, bool)>, parameters: HashMap, - receiver: mpsc::UnboundedReceiver, + // receiver: mpsc::UnboundedReceiver, ) -> Connection { Connection { stream, parameters, - receiver, + // receiver, pending_request: None, pending_responses, - responses: VecDeque::new(), - state: State::Active, + buf: BytesMut::new(), + // responses: VecDeque::new(), + // state: State::Active, + } + } + + pub async fn send(&mut self) -> Result<(), Error> { + poll_fn(|cx| self.poll_send(cx)).await?; + let request = FrontendMessage::Raw(self.buf.split().freeze()); + self.stream.start_send_unpin(request).map_err(Error::io) + } + + pub async fn flush(&mut self) -> Result<(), Error> { + poll_fn(|cx| self.poll_flush(cx)).await + } + + pub async fn next_response(&mut self) -> Result<(BackendMessages, bool), Error> { + match self.pending_responses.pop_front() { + Some((a, b)) => Ok((a, b)), + None => poll_fn(|cx| self.poll_read(cx)).await, + } + } + + pub async fn next_message(&mut self) -> Result { + loop { + let (mut messages, complete) = self.next_response().await?; + if let Some(message) = messages.next().map_err(Error::parse)? { + self.pending_responses.push_front((messages, complete)); + break Ok(message); + } + if complete { + break Err(Error::unexpected_message()); + } } } @@ -82,33 +113,19 @@ where &mut self, cx: &mut Context<'_>, ) -> Poll>> { - if let Some(message) = self.pending_responses.pop_front() { - trace!("retrying pending response"); - return Poll::Ready(Some(Ok(message))); - } - - Pin::new(&mut self.stream) - .poll_next(cx) + self.stream + .poll_next_unpin(cx) .map(|o| o.map(|r| r.map_err(Error::io))) } - fn poll_read(&mut self, cx: &mut Context<'_>) -> Result, Error> { - if self.state != State::Active { - trace!("poll_read: done"); - return Ok(None); - } - + fn poll_read(&mut self, cx: &mut Context<'_>) -> Poll> { loop { - let message = match self.poll_response(cx)? { - Poll::Ready(Some(message)) => message, - Poll::Ready(None) => return Err(Error::closed()), - Poll::Pending => { - trace!("poll_read: waiting on response"); - return Ok(None); - } + let message = match ready!(self.poll_response(cx)?) { + Some(message) => message, + None => return Poll::Ready(Err(Error::closed())), }; - let (mut messages, request_complete) = match message { + match message { BackendMessage::Async(Message::NoticeResponse(body)) => { // TODO: log this @@ -138,169 +155,12 @@ where BackendMessage::Normal { messages, request_complete, - } => (messages, request_complete), + } => return Poll::Ready(Ok((messages, request_complete))), }; - - let mut response = match self.responses.pop_front() { - Some(response) => response, - None => match messages.next().map_err(Error::parse)? { - Some(Message::ErrorResponse(error)) => return Err(Error::db(error)), - _ => return Err(Error::unexpected_message()), - }, - }; - - match response.sender.poll_ready(cx) { - Poll::Ready(Ok(())) => { - let _ = response.sender.start_send(messages); - if !request_complete { - self.responses.push_front(response); - } - } - Poll::Ready(Err(_)) => { - // we need to keep paging through the rest of the messages even if the receiver's hung up - if !request_complete { - self.responses.push_front(response); - } - } - Poll::Pending => { - self.responses.push_front(response); - self.pending_responses.push_back(BackendMessage::Normal { - messages, - request_complete, - }); - trace!("poll_read: waiting on sender"); - return Ok(None); - } - } } } - fn poll_request(&mut self, cx: &mut Context<'_>) -> Poll> { - if let Some(messages) = self.pending_request.take() { - trace!("retrying pending request"); - return Poll::Ready(Some(messages)); - } - - if self.receiver.is_terminated() { - return Poll::Ready(None); - } - - match self.receiver.poll_next_unpin(cx) { - Poll::Ready(Some(request)) => { - trace!("polled new request"); - self.responses.push_back(Response { - sender: request.sender, - }); - Poll::Ready(Some(request.messages)) - } - Poll::Ready(None) => Poll::Ready(None), - Poll::Pending => Poll::Pending, - } - } - - fn poll_write(&mut self, cx: &mut Context<'_>) -> Result { - loop { - if self.state == State::Closing { - trace!("poll_write: done"); - return Ok(false); - } - - if Pin::new(&mut self.stream) - .poll_ready(cx) - .map_err(Error::io)? - .is_pending() - { - trace!("poll_write: waiting on socket"); - return Ok(false); - } - - let request = match self.poll_request(cx) { - Poll::Ready(Some(request)) => request, - Poll::Ready(None) if self.responses.is_empty() && self.state == State::Active => { - trace!("poll_write: at eof, terminating"); - self.state = State::Terminating; - let mut request = BytesMut::new(); - frontend::terminate(&mut request); - RequestMessages::Single(FrontendMessage::Raw(request.freeze())) - } - Poll::Ready(None) => { - trace!( - "poll_write: at eof, pending responses {}", - self.responses.len() - ); - return Ok(true); - } - Poll::Pending => { - trace!("poll_write: waiting on request"); - return Ok(true); - } - }; - - match request { - RequestMessages::Single(request) => { - Pin::new(&mut self.stream) - .start_send(request) - .map_err(Error::io)?; - if self.state == State::Terminating { - trace!("poll_write: sent eof, closing"); - self.state = State::Closing; - } - } // RequestMessages::CopyIn(mut receiver) => { - // let message = match receiver.poll_next_unpin(cx) { - // Poll::Ready(Some(message)) => message, - // Poll::Ready(None) => { - // trace!("poll_write: finished copy_in request"); - // continue; - // } - // Poll::Pending => { - // trace!("poll_write: waiting on copy_in stream"); - // self.pending_request = Some(RequestMessages::CopyIn(receiver)); - // return Ok(true); - // } - // }; - // Pin::new(&mut self.stream) - // .start_send(message) - // .map_err(Error::io)?; - // self.pending_request = Some(RequestMessages::CopyIn(receiver)); - // } - // RequestMessages::CopyBoth(mut receiver) => { - // let message = match receiver.poll_next_unpin(cx) { - // Poll::Ready(Some(message)) => message, - // Poll::Ready(None) => { - // trace!("poll_write: finished copy_both request"); - // continue; - // } - // Poll::Pending => { - // trace!("poll_write: waiting on copy_both stream"); - // self.pending_request = Some(RequestMessages::CopyBoth(receiver)); - // return Ok(true); - // } - // }; - // Pin::new(&mut self.stream) - // .start_send(message) - // .map_err(Error::io)?; - // self.pending_request = Some(RequestMessages::CopyBoth(receiver)); - // } - } - } - } - - fn poll_flush(&mut self, cx: &mut Context<'_>) -> Result<(), Error> { - match Pin::new(&mut self.stream) - .poll_flush(cx) - .map_err(Error::io)? - { - Poll::Ready(()) => trace!("poll_flush: flushed"), - Poll::Pending => trace!("poll_flush: waiting on socket"), - } - Ok(()) - } - fn poll_shutdown(&mut self, cx: &mut Context<'_>) -> Poll> { - if self.state != State::Closing { - return Poll::Pending; - } - match Pin::new(&mut self.stream) .poll_close(cx) .map_err(Error::io)? @@ -321,40 +181,17 @@ where self.parameters.get(name).map(|s| &**s) } - /// Polls for asynchronous messages from the server. - /// - /// The server can send notices as well as notifications asynchronously to the client. Applications that wish to - /// examine those messages should use this method to drive the connection rather than its `Future` implementation. - pub fn poll_message(&mut self, cx: &mut Context<'_>) -> Poll>> { - let message = self.poll_read(cx)?; - let want_flush = self.poll_write(cx)?; - if want_flush { - self.poll_flush(cx)?; - } - match message { - Some(message) => Poll::Ready(Some(Ok(message))), - None => match self.poll_shutdown(cx) { - Poll::Ready(Ok(())) => Poll::Ready(None), - Poll::Ready(Err(e)) => Poll::Ready(Some(Err(e))), - Poll::Pending => Poll::Pending, - }, - } + fn poll_send(&mut self, cx: &mut Context<'_>) -> Poll> { + if let Poll::Ready(msg) = self.poll_read(cx)? { + self.pending_responses.push_back(msg); + }; + self.stream.poll_ready_unpin(cx).map_err(Error::io) + } + + fn poll_flush(&mut self, cx: &mut Context<'_>) -> Poll> { + if let Poll::Ready(msg) = self.poll_read(cx)? { + self.pending_responses.push_back(msg); + }; + self.stream.poll_flush_unpin(cx).map_err(Error::io) } } - -// impl Future for Connection -// where -// S: AsyncRead + AsyncWrite + Unpin, -// T: AsyncRead + AsyncWrite + Unpin, -// { -// type Output = Result<(), Error>; - -// fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { -// while let Some(message) = ready!(self.poll_message(cx)?) { -// if let AsyncMessage::Notice(notice) = message { -// info!("{}: {}", notice.severity(), notice.message()); -// } -// } -// Poll::Ready(Ok(())) -// } -// } diff --git a/proxy/src/http/sql_over_http/error.rs b/proxy/src/http/sql_over_http/error.rs index f77d87f59f..5fb9a5979a 100644 --- a/proxy/src/http/sql_over_http/error.rs +++ b/proxy/src/http/sql_over_http/error.rs @@ -1,13 +1,14 @@ use std::{error, fmt, io}; -use postgres_protocol::message::backend::{ErrorFields, ErrorResponseBody}; -use tokio_postgres::error::{SqlState, ErrorPosition}; use fallible_iterator::FallibleIterator; +use postgres_protocol::message::backend::{ErrorFields, ErrorResponseBody}; +use tokio_postgres::error::{ErrorPosition, SqlState}; #[derive(Debug, PartialEq)] enum Kind { Io, UnexpectedMessage, + FromSql(usize), Closed, Db, Parse, @@ -36,6 +37,7 @@ impl fmt::Display for Error { match &self.0.kind { Kind::Io => fmt.write_str("error communicating with the server")?, Kind::UnexpectedMessage => fmt.write_str("unexpected message from server")?, + Kind::FromSql(idx) => write!(fmt, "error deserializing column {}", idx)?, Kind::Closed => fmt.write_str("connection closed")?, Kind::Db => fmt.write_str("db error")?, Kind::Parse => fmt.write_str("error parsing response from server")?, @@ -91,6 +93,10 @@ impl Error { } } + pub(crate) fn from_sql(e: Box, idx: usize) -> Error { + Error::new(Kind::FromSql(idx), Some(e)) + } + pub(crate) fn closed() -> Error { Error::new(Kind::Closed, None) }