mirror of
https://github.com/neondatabase/neon.git
synced 2026-01-14 00:42:54 +00:00
@@ -86,18 +86,24 @@ impl ProxyConnection {
|
||||
};
|
||||
|
||||
// We'll get rid of this once migration to async is complete
|
||||
let db_stream = {
|
||||
let (pg_version, db_stream) = {
|
||||
let runtime = tokio::runtime::Builder::new_current_thread()
|
||||
.enable_all()
|
||||
.build()?;
|
||||
|
||||
let stream = runtime.block_on(conn)?.into_std()?;
|
||||
let (pg_version, stream) = runtime.block_on(conn)?;
|
||||
let stream = stream.into_std()?;
|
||||
stream.set_nonblocking(false)?;
|
||||
stream
|
||||
|
||||
(pg_version, stream)
|
||||
};
|
||||
|
||||
// Let the client send new requests
|
||||
self.pgb.write_message(&Be::ReadyForQuery)?;
|
||||
self.pgb
|
||||
.write_message_noflush(&BeMessage::ParameterStatus(
|
||||
BeParameterStatusMessage::ServerVersion(&pg_version),
|
||||
))?
|
||||
.write_message(&Be::ReadyForQuery)?;
|
||||
|
||||
Ok((self.pgb.into_stream(), db_stream))
|
||||
}
|
||||
@@ -175,7 +181,7 @@ impl ProxyConnection {
|
||||
|
||||
self.pgb
|
||||
.write_message_noflush(&Be::AuthenticationOk)?
|
||||
.write_message_noflush(&Be::ParameterStatus)?;
|
||||
.write_message_noflush(&BeParameterStatusMessage::encoding())?;
|
||||
|
||||
Ok(db_info)
|
||||
}
|
||||
@@ -189,7 +195,7 @@ impl ProxyConnection {
|
||||
// Give user a URL to spawn a new database
|
||||
self.pgb
|
||||
.write_message_noflush(&Be::AuthenticationOk)?
|
||||
.write_message_noflush(&Be::ParameterStatus)?
|
||||
.write_message_noflush(&BeParameterStatusMessage::encoding())?
|
||||
.write_message(&Be::NoticeResponse(greeting))?;
|
||||
|
||||
// Wait for web console response
|
||||
@@ -218,11 +224,21 @@ fn hello_message(redirect_uri: &str, session_id: &str) -> String {
|
||||
}
|
||||
|
||||
/// Create a TCP connection to a postgres database, authenticate with it, and receive the ReadyForQuery message
|
||||
async fn connect_to_db(db_info: DatabaseInfo) -> anyhow::Result<tokio::net::TcpStream> {
|
||||
async fn connect_to_db(db_info: DatabaseInfo) -> anyhow::Result<(String, tokio::net::TcpStream)> {
|
||||
let mut socket = tokio::net::TcpStream::connect(db_info.socket_addr()?).await?;
|
||||
let config = tokio_postgres::Config::from(db_info);
|
||||
let _ = config.connect_raw(&mut socket, NoTls).await?;
|
||||
Ok(socket)
|
||||
let (client, conn) = config.connect_raw(&mut socket, NoTls).await?;
|
||||
|
||||
let query = client.query_one("select current_setting('server_version')", &[]);
|
||||
|
||||
tokio::pin!(query, conn);
|
||||
|
||||
let version = tokio::select!(
|
||||
x = query => x?.try_get(0)?,
|
||||
_ = conn => bail!("connection closed too early"),
|
||||
);
|
||||
|
||||
Ok((version, socket))
|
||||
}
|
||||
|
||||
/// Concurrently proxy both directions of the client and server connections
|
||||
|
||||
@@ -3,7 +3,9 @@
|
||||
//! implementation determining how to process the queries. Currently its API
|
||||
//! is rather narrow, but we can extend it once required.
|
||||
|
||||
use crate::pq_proto::{BeMessage, FeMessage, FeStartupMessage, StartupRequestCode};
|
||||
use crate::pq_proto::{
|
||||
BeMessage, BeParameterStatusMessage, FeMessage, FeStartupMessage, StartupRequestCode,
|
||||
};
|
||||
use crate::sock_split::{BidiStream, ReadStream, WriteStream};
|
||||
use anyhow::{anyhow, bail, ensure, Result};
|
||||
use bytes::{Bytes, BytesMut};
|
||||
@@ -355,11 +357,9 @@ impl PostgresBackend {
|
||||
|
||||
match self.auth_type {
|
||||
AuthType::Trust => {
|
||||
self.write_message_noflush(&BeMessage::AuthenticationOk)?;
|
||||
// psycopg2 will not connect if client_encoding is not
|
||||
// specified by the server
|
||||
self.write_message_noflush(&BeMessage::ParameterStatus)?;
|
||||
self.write_message(&BeMessage::ReadyForQuery)?;
|
||||
self.write_message_noflush(&BeMessage::AuthenticationOk)?
|
||||
.write_message_noflush(&BeParameterStatusMessage::encoding())?
|
||||
.write_message(&BeMessage::ReadyForQuery)?;
|
||||
self.state = ProtoState::Established;
|
||||
}
|
||||
AuthType::MD5 => {
|
||||
@@ -410,11 +410,9 @@ impl PostgresBackend {
|
||||
}
|
||||
}
|
||||
}
|
||||
self.write_message_noflush(&BeMessage::AuthenticationOk)?;
|
||||
// psycopg2 will not connect if client_encoding is not
|
||||
// specified by the server
|
||||
self.write_message_noflush(&BeMessage::ParameterStatus)?;
|
||||
self.write_message(&BeMessage::ReadyForQuery)?;
|
||||
self.write_message_noflush(&BeMessage::AuthenticationOk)?
|
||||
.write_message_noflush(&BeParameterStatusMessage::encoding())?
|
||||
.write_message(&BeMessage::ReadyForQuery)?;
|
||||
self.state = ProtoState::Established;
|
||||
}
|
||||
|
||||
|
||||
@@ -352,7 +352,7 @@ pub enum BeMessage<'a> {
|
||||
EncryptionResponse(bool),
|
||||
NoData,
|
||||
ParameterDescription,
|
||||
ParameterStatus,
|
||||
ParameterStatus(BeParameterStatusMessage<'a>),
|
||||
ParseComplete,
|
||||
ReadyForQuery,
|
||||
RowDescription(&'a [RowDescriptor<'a>]),
|
||||
@@ -361,6 +361,18 @@ pub enum BeMessage<'a> {
|
||||
KeepAlive(WalSndKeepAlive),
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum BeParameterStatusMessage<'a> {
|
||||
Encoding(&'a str),
|
||||
ServerVersion(&'a str),
|
||||
}
|
||||
|
||||
impl BeParameterStatusMessage<'static> {
|
||||
pub fn encoding() -> BeMessage<'static> {
|
||||
BeMessage::ParameterStatus(Self::Encoding("UTF8"))
|
||||
}
|
||||
}
|
||||
|
||||
// One row desciption in RowDescription packet.
|
||||
#[derive(Debug)]
|
||||
pub struct RowDescriptor<'a> {
|
||||
@@ -665,12 +677,23 @@ impl<'a> BeMessage<'a> {
|
||||
buf.put_u8(response);
|
||||
}
|
||||
|
||||
BeMessage::ParameterStatus => {
|
||||
BeMessage::ParameterStatus(param) => {
|
||||
use std::io::{IoSlice, Write};
|
||||
use BeParameterStatusMessage::*;
|
||||
|
||||
let [name, value] = match param {
|
||||
Encoding(name) => [b"client_encoding", name.as_bytes()],
|
||||
ServerVersion(version) => [b"server_version", version.as_bytes()],
|
||||
};
|
||||
|
||||
// Parameter names and values are passed as null-terminated strings
|
||||
let iov = &mut [name, b"\0", value, b"\0"].map(IoSlice::new);
|
||||
let mut buffer = [0u8; 64]; // this should be enough
|
||||
let cnt = buffer.as_mut().write_vectored(iov).unwrap();
|
||||
|
||||
buf.put_u8(b'S');
|
||||
// parameter names and values are specified by null terminated strings
|
||||
const PARAM_NAME_VALUE: &[u8] = b"client_encoding\0UTF8\0";
|
||||
write_body(buf, |buf| {
|
||||
buf.put_slice(PARAM_NAME_VALUE);
|
||||
buf.put_slice(&buffer[..cnt]);
|
||||
Ok::<_, io::Error>(())
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
Reference in New Issue
Block a user