From 0ccfc62e88b0dc82c31ffe10fad89b0cda7bf88f Mon Sep 17 00:00:00 2001 From: Dmitry Ivanov Date: Tue, 26 Oct 2021 17:54:10 +0300 Subject: [PATCH] [proxy] Pass PostgreSQL version to client Fixes #779 --- proxy/src/proxy.rs | 34 ++++++++++++++++++++-------- zenith_utils/src/postgres_backend.rs | 20 ++++++++-------- zenith_utils/src/pq_proto.rs | 33 +++++++++++++++++++++++---- 3 files changed, 62 insertions(+), 25 deletions(-) diff --git a/proxy/src/proxy.rs b/proxy/src/proxy.rs index 1eff0315a8..ad0af57eea 100644 --- a/proxy/src/proxy.rs +++ b/proxy/src/proxy.rs @@ -86,18 +86,24 @@ impl ProxyConnection { }; // We'll get rid of this once migration to async is complete - let db_stream = { + let (pg_version, db_stream) = { let runtime = tokio::runtime::Builder::new_current_thread() .enable_all() .build()?; - let stream = runtime.block_on(conn)?.into_std()?; + let (pg_version, stream) = runtime.block_on(conn)?; + let stream = stream.into_std()?; stream.set_nonblocking(false)?; - stream + + (pg_version, stream) }; // Let the client send new requests - self.pgb.write_message(&Be::ReadyForQuery)?; + self.pgb + .write_message_noflush(&BeMessage::ParameterStatus( + BeParameterStatusMessage::ServerVersion(&pg_version), + ))? + .write_message(&Be::ReadyForQuery)?; Ok((self.pgb.into_stream(), db_stream)) } @@ -175,7 +181,7 @@ impl ProxyConnection { self.pgb .write_message_noflush(&Be::AuthenticationOk)? - .write_message_noflush(&Be::ParameterStatus)?; + .write_message_noflush(&BeParameterStatusMessage::encoding())?; Ok(db_info) } @@ -189,7 +195,7 @@ impl ProxyConnection { // Give user a URL to spawn a new database self.pgb .write_message_noflush(&Be::AuthenticationOk)? - .write_message_noflush(&Be::ParameterStatus)? + .write_message_noflush(&BeParameterStatusMessage::encoding())? .write_message(&Be::NoticeResponse(greeting))?; // Wait for web console response @@ -218,11 +224,21 @@ fn hello_message(redirect_uri: &str, session_id: &str) -> String { } /// Create a TCP connection to a postgres database, authenticate with it, and receive the ReadyForQuery message -async fn connect_to_db(db_info: DatabaseInfo) -> anyhow::Result { +async fn connect_to_db(db_info: DatabaseInfo) -> anyhow::Result<(String, tokio::net::TcpStream)> { let mut socket = tokio::net::TcpStream::connect(db_info.socket_addr()?).await?; let config = tokio_postgres::Config::from(db_info); - let _ = config.connect_raw(&mut socket, NoTls).await?; - Ok(socket) + let (client, conn) = config.connect_raw(&mut socket, NoTls).await?; + + let query = client.query_one("select current_setting('server_version')", &[]); + + tokio::pin!(query, conn); + + let version = tokio::select!( + x = query => x?.try_get(0)?, + _ = conn => bail!("connection closed too early"), + ); + + Ok((version, socket)) } /// Concurrently proxy both directions of the client and server connections diff --git a/zenith_utils/src/postgres_backend.rs b/zenith_utils/src/postgres_backend.rs index 4afcf2a554..4d4c6278ba 100644 --- a/zenith_utils/src/postgres_backend.rs +++ b/zenith_utils/src/postgres_backend.rs @@ -3,7 +3,9 @@ //! implementation determining how to process the queries. Currently its API //! is rather narrow, but we can extend it once required. -use crate::pq_proto::{BeMessage, FeMessage, FeStartupMessage, StartupRequestCode}; +use crate::pq_proto::{ + BeMessage, BeParameterStatusMessage, FeMessage, FeStartupMessage, StartupRequestCode, +}; use crate::sock_split::{BidiStream, ReadStream, WriteStream}; use anyhow::{anyhow, bail, ensure, Result}; use bytes::{Bytes, BytesMut}; @@ -355,11 +357,9 @@ impl PostgresBackend { match self.auth_type { AuthType::Trust => { - self.write_message_noflush(&BeMessage::AuthenticationOk)?; - // psycopg2 will not connect if client_encoding is not - // specified by the server - self.write_message_noflush(&BeMessage::ParameterStatus)?; - self.write_message(&BeMessage::ReadyForQuery)?; + self.write_message_noflush(&BeMessage::AuthenticationOk)? + .write_message_noflush(&BeParameterStatusMessage::encoding())? + .write_message(&BeMessage::ReadyForQuery)?; self.state = ProtoState::Established; } AuthType::MD5 => { @@ -410,11 +410,9 @@ impl PostgresBackend { } } } - self.write_message_noflush(&BeMessage::AuthenticationOk)?; - // psycopg2 will not connect if client_encoding is not - // specified by the server - self.write_message_noflush(&BeMessage::ParameterStatus)?; - self.write_message(&BeMessage::ReadyForQuery)?; + self.write_message_noflush(&BeMessage::AuthenticationOk)? + .write_message_noflush(&BeParameterStatusMessage::encoding())? + .write_message(&BeMessage::ReadyForQuery)?; self.state = ProtoState::Established; } diff --git a/zenith_utils/src/pq_proto.rs b/zenith_utils/src/pq_proto.rs index 1941784332..47c70ac37f 100644 --- a/zenith_utils/src/pq_proto.rs +++ b/zenith_utils/src/pq_proto.rs @@ -352,7 +352,7 @@ pub enum BeMessage<'a> { EncryptionResponse(bool), NoData, ParameterDescription, - ParameterStatus, + ParameterStatus(BeParameterStatusMessage<'a>), ParseComplete, ReadyForQuery, RowDescription(&'a [RowDescriptor<'a>]), @@ -361,6 +361,18 @@ pub enum BeMessage<'a> { KeepAlive(WalSndKeepAlive), } +#[derive(Debug)] +pub enum BeParameterStatusMessage<'a> { + Encoding(&'a str), + ServerVersion(&'a str), +} + +impl BeParameterStatusMessage<'static> { + pub fn encoding() -> BeMessage<'static> { + BeMessage::ParameterStatus(Self::Encoding("UTF8")) + } +} + // One row desciption in RowDescription packet. #[derive(Debug)] pub struct RowDescriptor<'a> { @@ -665,12 +677,23 @@ impl<'a> BeMessage<'a> { buf.put_u8(response); } - BeMessage::ParameterStatus => { + BeMessage::ParameterStatus(param) => { + use std::io::{IoSlice, Write}; + use BeParameterStatusMessage::*; + + let [name, value] = match param { + Encoding(name) => [b"client_encoding", name.as_bytes()], + ServerVersion(version) => [b"server_version", version.as_bytes()], + }; + + // Parameter names and values are passed as null-terminated strings + let iov = &mut [name, b"\0", value, b"\0"].map(IoSlice::new); + let mut buffer = [0u8; 64]; // this should be enough + let cnt = buffer.as_mut().write_vectored(iov).unwrap(); + buf.put_u8(b'S'); - // parameter names and values are specified by null terminated strings - const PARAM_NAME_VALUE: &[u8] = b"client_encoding\0UTF8\0"; write_body(buf, |buf| { - buf.put_slice(PARAM_NAME_VALUE); + buf.put_slice(&buffer[..cnt]); Ok::<_, io::Error>(()) }) .unwrap();