[proxy] Pass PostgreSQL version to client

Fixes #779
This commit is contained in:
Dmitry Ivanov
2021-10-26 17:54:10 +03:00
parent b55cf773a8
commit 0ccfc62e88
3 changed files with 62 additions and 25 deletions

View File

@@ -86,18 +86,24 @@ impl ProxyConnection {
};
// We'll get rid of this once migration to async is complete
let db_stream = {
let (pg_version, db_stream) = {
let runtime = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()?;
let stream = runtime.block_on(conn)?.into_std()?;
let (pg_version, stream) = runtime.block_on(conn)?;
let stream = stream.into_std()?;
stream.set_nonblocking(false)?;
stream
(pg_version, stream)
};
// Let the client send new requests
self.pgb.write_message(&Be::ReadyForQuery)?;
self.pgb
.write_message_noflush(&BeMessage::ParameterStatus(
BeParameterStatusMessage::ServerVersion(&pg_version),
))?
.write_message(&Be::ReadyForQuery)?;
Ok((self.pgb.into_stream(), db_stream))
}
@@ -175,7 +181,7 @@ impl ProxyConnection {
self.pgb
.write_message_noflush(&Be::AuthenticationOk)?
.write_message_noflush(&Be::ParameterStatus)?;
.write_message_noflush(&BeParameterStatusMessage::encoding())?;
Ok(db_info)
}
@@ -189,7 +195,7 @@ impl ProxyConnection {
// Give user a URL to spawn a new database
self.pgb
.write_message_noflush(&Be::AuthenticationOk)?
.write_message_noflush(&Be::ParameterStatus)?
.write_message_noflush(&BeParameterStatusMessage::encoding())?
.write_message(&Be::NoticeResponse(greeting))?;
// Wait for web console response
@@ -218,11 +224,21 @@ fn hello_message(redirect_uri: &str, session_id: &str) -> String {
}
/// Create a TCP connection to a postgres database, authenticate with it, and receive the ReadyForQuery message
async fn connect_to_db(db_info: DatabaseInfo) -> anyhow::Result<tokio::net::TcpStream> {
async fn connect_to_db(db_info: DatabaseInfo) -> anyhow::Result<(String, tokio::net::TcpStream)> {
let mut socket = tokio::net::TcpStream::connect(db_info.socket_addr()?).await?;
let config = tokio_postgres::Config::from(db_info);
let _ = config.connect_raw(&mut socket, NoTls).await?;
Ok(socket)
let (client, conn) = config.connect_raw(&mut socket, NoTls).await?;
let query = client.query_one("select current_setting('server_version')", &[]);
tokio::pin!(query, conn);
let version = tokio::select!(
x = query => x?.try_get(0)?,
_ = conn => bail!("connection closed too early"),
);
Ok((version, socket))
}
/// Concurrently proxy both directions of the client and server connections

View File

@@ -3,7 +3,9 @@
//! implementation determining how to process the queries. Currently its API
//! is rather narrow, but we can extend it once required.
use crate::pq_proto::{BeMessage, FeMessage, FeStartupMessage, StartupRequestCode};
use crate::pq_proto::{
BeMessage, BeParameterStatusMessage, FeMessage, FeStartupMessage, StartupRequestCode,
};
use crate::sock_split::{BidiStream, ReadStream, WriteStream};
use anyhow::{anyhow, bail, ensure, Result};
use bytes::{Bytes, BytesMut};
@@ -355,11 +357,9 @@ impl PostgresBackend {
match self.auth_type {
AuthType::Trust => {
self.write_message_noflush(&BeMessage::AuthenticationOk)?;
// psycopg2 will not connect if client_encoding is not
// specified by the server
self.write_message_noflush(&BeMessage::ParameterStatus)?;
self.write_message(&BeMessage::ReadyForQuery)?;
self.write_message_noflush(&BeMessage::AuthenticationOk)?
.write_message_noflush(&BeParameterStatusMessage::encoding())?
.write_message(&BeMessage::ReadyForQuery)?;
self.state = ProtoState::Established;
}
AuthType::MD5 => {
@@ -410,11 +410,9 @@ impl PostgresBackend {
}
}
}
self.write_message_noflush(&BeMessage::AuthenticationOk)?;
// psycopg2 will not connect if client_encoding is not
// specified by the server
self.write_message_noflush(&BeMessage::ParameterStatus)?;
self.write_message(&BeMessage::ReadyForQuery)?;
self.write_message_noflush(&BeMessage::AuthenticationOk)?
.write_message_noflush(&BeParameterStatusMessage::encoding())?
.write_message(&BeMessage::ReadyForQuery)?;
self.state = ProtoState::Established;
}

View File

@@ -352,7 +352,7 @@ pub enum BeMessage<'a> {
EncryptionResponse(bool),
NoData,
ParameterDescription,
ParameterStatus,
ParameterStatus(BeParameterStatusMessage<'a>),
ParseComplete,
ReadyForQuery,
RowDescription(&'a [RowDescriptor<'a>]),
@@ -361,6 +361,18 @@ pub enum BeMessage<'a> {
KeepAlive(WalSndKeepAlive),
}
#[derive(Debug)]
pub enum BeParameterStatusMessage<'a> {
Encoding(&'a str),
ServerVersion(&'a str),
}
impl BeParameterStatusMessage<'static> {
pub fn encoding() -> BeMessage<'static> {
BeMessage::ParameterStatus(Self::Encoding("UTF8"))
}
}
// One row desciption in RowDescription packet.
#[derive(Debug)]
pub struct RowDescriptor<'a> {
@@ -665,12 +677,23 @@ impl<'a> BeMessage<'a> {
buf.put_u8(response);
}
BeMessage::ParameterStatus => {
BeMessage::ParameterStatus(param) => {
use std::io::{IoSlice, Write};
use BeParameterStatusMessage::*;
let [name, value] = match param {
Encoding(name) => [b"client_encoding", name.as_bytes()],
ServerVersion(version) => [b"server_version", version.as_bytes()],
};
// Parameter names and values are passed as null-terminated strings
let iov = &mut [name, b"\0", value, b"\0"].map(IoSlice::new);
let mut buffer = [0u8; 64]; // this should be enough
let cnt = buffer.as_mut().write_vectored(iov).unwrap();
buf.put_u8(b'S');
// parameter names and values are specified by null terminated strings
const PARAM_NAME_VALUE: &[u8] = b"client_encoding\0UTF8\0";
write_body(buf, |buf| {
buf.put_slice(PARAM_NAME_VALUE);
buf.put_slice(&buffer[..cnt]);
Ok::<_, io::Error>(())
})
.unwrap();