From e8400d9d938e9e0fd16fb626efaa848b6d427c16 Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Mon, 31 Jul 2023 10:42:56 +0100 Subject: [PATCH] stash --- proxy/src/http/sql_over_http.rs | 84 ++++++++++++++++++----- proxy/src/pg_client/connection.rs | 106 +++++++++++++++++++----------- 2 files changed, 134 insertions(+), 56 deletions(-) diff --git a/proxy/src/http/sql_over_http.rs b/proxy/src/http/sql_over_http.rs index 5423028df5..edeeaf1d27 100644 --- a/proxy/src/http/sql_over_http.rs +++ b/proxy/src/http/sql_over_http.rs @@ -364,13 +364,13 @@ async fn query_to_json( })) } -/// Pass text directly to the Postgres backend to allow it to sort out typing itself and -/// to save a roundtrip -async fn query_raw_txt<'a, St, T>( +async fn query_raw_txt_as_json<'a, St, T>( conn: &mut connection::Connection, query: String, params: Vec>, -) -> Result, pg_client::error::Error> + raw_output: bool, + array_mode: bool, +) -> Result where St: AsyncRead + AsyncWrite + Unpin, T: AsyncRead + AsyncWrite + Unpin, @@ -380,20 +380,70 @@ where conn.prepare_and_execute("", "", query.as_str(), params)?; conn.sync().await?; - let mut columns = vec![]; - if let Some((desc, rows)) = conn.stream_query_results().await? { - let mut it = desc.fields(); - while let Some(field) = it.next().map_err(pg_client::error::Error::parse)? { - let type_ = Type::from_oid(field.type_oid()); - // let column = Column::new(field.name().to_string(), type_, field); - columns.push(Column { - name: field.name().to_string(), - type_, - }); - } - } + let mut fields = vec![]; + let mut rows = vec![]; + let command_tag = match conn.stream_query_results().await? { + connection::QueryResult::NoRows(tag) => tag, + connection::QueryResult::Rows { + row_description, + mut row_stream, + } => { + let mut columns = vec![]; + let mut it = row_description.fields(); + while let Some(field) = it.next().map_err(pg_client::error::Error::parse)? { + fields.push(json!({ + "name": Value::String(field.name().to_owned()), + "dataTypeID": Value::Number(field.type_oid().into()), + "tableID": field.table_oid(), + "columnID": field.column_id(), + "dataTypeSize": field.type_size(), + "dataTypeModifier": field.type_modifier(), + "format": "text", + })); - Ok(columns) + let type_ = Type::from_oid(field.type_oid()); + // let column = Column::new(field.name().to_string(), type_, field); + columns.push(Column { + name: field.name().to_string(), + type_, + }); + } + + let mut curret_size = 0; + while let Some(row) = row_stream.next().await.transpose()? { + curret_size += row.buffer().len(); + if curret_size > MAX_RESPONSE_SIZE { + todo!() + // return Err(anyhow::anyhow!("response too large")); + } + + rows.push(pg_text_row_to_json2(&row, &columns, raw_output, array_mode).unwrap()); + } + + row_stream.tag() + } + }; + + let command_tag = command_tag.tag()?; + let mut command_tag_split = command_tag.split(' '); + let command_tag_name = command_tag_split.next().unwrap_or_default(); + let command_tag_count = if command_tag_name == "INSERT" { + // INSERT returns OID first and then number of rows + command_tag_split.nth(1) + } else { + // other commands return number of rows (if any) + command_tag_split.next() + } + .and_then(|s| s.parse::().ok()); + + // resulting JSON format is based on the format of node-postgres result + Ok(json!({ + "command": command_tag_name, + "rowCount": command_tag_count, + "rows": rows, + "fields": fields, + "rowAsArray": array_mode, + })) } struct Column { diff --git a/proxy/src/pg_client/connection.rs b/proxy/src/pg_client/connection.rs index 7e5f0d88b3..1e21faec7f 100644 --- a/proxy/src/pg_client/connection.rs +++ b/proxy/src/pg_client/connection.rs @@ -7,7 +7,8 @@ use futures::{Sink, StreamExt}; use futures::{SinkExt, Stream}; use postgres_protocol::authentication; use postgres_protocol::message::backend::{ - BackendKeyDataBody, DataRowBody, Message, ReadyForQueryBody, RowDescriptionBody, + BackendKeyDataBody, CommandCompleteBody, DataRowBody, Message, ReadyForQueryBody, + RowDescriptionBody, }; use postgres_protocol::message::frontend; use std::collections::{HashMap, VecDeque}; @@ -311,48 +312,32 @@ impl Conne self.raw.send().await } - /// returns None if there's no row data - /// returns Some with the row description and a row stream if there is row data - pub async fn stream_query_results( - &mut self, - ) -> Result< - Option<( - RowDescriptionBody, - impl Stream> + '_, - )>, - Error, - > { + pub async fn wait_for_prepare(&mut self) -> Result, Error> { let Message::ParseComplete = self.raw.next_message().await? else { return Err(Error::expecting("parse")) }; let Message::BindComplete = self.raw.next_message().await? else { return Err(Error::expecting("bind")) }; match self.raw.next_message().await? { - Message::RowDescription(desc) => { - struct RowStream<'a, S, T> { - raw: &'a mut RawConnection, - } - impl Unpin for RowStream<'_, S, T> {} - - impl Stream - for RowStream<'_, S, T> - { - type Item = Result; - - fn poll_next( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { - match ready!(self.raw.poll_read(cx)?) { - Message::DataRow(row) => Poll::Ready(Some(Ok(row))), - Message::CommandComplete(_) => Poll::Ready(None), - _ => Poll::Ready(Some(Err(Error::expecting("command completion")))), - } - } - } - - Ok(Some((desc, RowStream { raw: &mut self.raw }))) - } + Message::RowDescription(desc) => Ok(QueryResult::Rows { + row_stream: RowStream::Stream(&mut self.raw), + row_description: desc, + }), Message::NoData => { - let Message::CommandComplete(_) = self.raw.next_message().await? else { return Err(Error::expecting("command completion")) }; - Ok(None) + let Message::CommandComplete(tag) = self.raw.next_message().await? else { return Err(Error::expecting("command completion")) }; + Ok(QueryResult::NoRows(tag)) + } + _ => Err(Error::expecting("query results")), + } + } + pub async fn stream_query_results(&mut self) -> Result, Error> { + let Message::ParseComplete = self.raw.next_message().await? else { return Err(Error::expecting("parse")) }; + let Message::BindComplete = self.raw.next_message().await? else { return Err(Error::expecting("bind")) }; + match self.raw.next_message().await? { + Message::RowDescription(desc) => Ok(QueryResult::Rows { + row_stream: RowStream::Stream(&mut self.raw), + row_description: desc, + }), + Message::NoData => { + let Message::CommandComplete(tag) = self.raw.next_message().await? else { return Err(Error::expecting("command completion")) }; + Ok(QueryResult::NoRows(tag)) } _ => Err(Error::expecting("query results")), } @@ -367,3 +352,46 @@ impl Conne } } } + +// pub enum QueryResult<'a, S, T> { +// NoRows(CommandCompleteBody), +// Rows { +// row_description: RowDescriptionBody, +// row_stream: RowStream<'a, S, T>, +// }, +// } + +pub enum RowStream<'a, S, T> { + Stream(&'a mut RawConnection), + Complete(CommandCompleteBody), +} +impl Unpin for RowStream<'_, S, T> {} + +impl Stream + for RowStream<'_, S, T> +{ + type Item = Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match &mut *self { + RowStream::Stream(raw) => match ready!(raw.poll_read(cx)?) { + Message::DataRow(row) => Poll::Ready(Some(Ok(row))), + Message::CommandComplete(tag) => { + *self = Self::Complete(tag); + Poll::Ready(None) + } + _ => Poll::Ready(Some(Err(Error::expecting("command completion")))), + }, + RowStream::Complete(_) => Poll::Ready(None), + } + } +} + +impl RowStream<'_, S, T> { + pub fn tag(self) -> CommandCompleteBody { + match self { + RowStream::Stream(_) => panic!("should not get tag unless row stream is exhausted"), + RowStream::Complete(tag) => tag, + } + } +}