mirror of
https://github.com/neondatabase/neon.git
synced 2026-01-10 15:02:56 +00:00
[proxy] Forward compute connection params to client
This fixes all kinds of problems related to missing params, like broken timestamps (due to `integer_datetimes`). This solution is not ideal, but it will help. Meanwhile, I'm going to dedicate some time to improving connection machinery. Note that this **does not** fix problems with passing certain parameters in a reverse direction, i.e. **from client to compute**. This is a separate matter and will be dealt with in an upcoming PR.
This commit is contained in:
8
Cargo.lock
generated
8
Cargo.lock
generated
@@ -2613,7 +2613,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "postgres"
|
||||
version = "0.19.2"
|
||||
source = "git+https://github.com/neondatabase/rust-postgres.git?rev=d052ee8b86fff9897c77b0fe89ea9daba0e1fa38#d052ee8b86fff9897c77b0fe89ea9daba0e1fa38"
|
||||
source = "git+https://github.com/neondatabase/rust-postgres.git?rev=69c1ef71cd5418cf063d4ca21eadc3427980caea#69c1ef71cd5418cf063d4ca21eadc3427980caea"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"fallible-iterator",
|
||||
@@ -2626,7 +2626,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "postgres-protocol"
|
||||
version = "0.6.4"
|
||||
source = "git+https://github.com/neondatabase/rust-postgres.git?rev=d052ee8b86fff9897c77b0fe89ea9daba0e1fa38#d052ee8b86fff9897c77b0fe89ea9daba0e1fa38"
|
||||
source = "git+https://github.com/neondatabase/rust-postgres.git?rev=69c1ef71cd5418cf063d4ca21eadc3427980caea#69c1ef71cd5418cf063d4ca21eadc3427980caea"
|
||||
dependencies = [
|
||||
"base64",
|
||||
"byteorder",
|
||||
@@ -2644,7 +2644,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "postgres-types"
|
||||
version = "0.2.3"
|
||||
source = "git+https://github.com/neondatabase/rust-postgres.git?rev=d052ee8b86fff9897c77b0fe89ea9daba0e1fa38#d052ee8b86fff9897c77b0fe89ea9daba0e1fa38"
|
||||
source = "git+https://github.com/neondatabase/rust-postgres.git?rev=69c1ef71cd5418cf063d4ca21eadc3427980caea#69c1ef71cd5418cf063d4ca21eadc3427980caea"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"fallible-iterator",
|
||||
@@ -4010,7 +4010,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "tokio-postgres"
|
||||
version = "0.7.6"
|
||||
source = "git+https://github.com/neondatabase/rust-postgres.git?rev=d052ee8b86fff9897c77b0fe89ea9daba0e1fa38#d052ee8b86fff9897c77b0fe89ea9daba0e1fa38"
|
||||
source = "git+https://github.com/neondatabase/rust-postgres.git?rev=69c1ef71cd5418cf063d4ca21eadc3427980caea#69c1ef71cd5418cf063d4ca21eadc3427980caea"
|
||||
dependencies = [
|
||||
"async-trait",
|
||||
"byteorder",
|
||||
|
||||
@@ -86,4 +86,4 @@ lto = true
|
||||
# This is only needed for proxy's tests.
|
||||
# TODO: we should probably fork `tokio-postgres-rustls` instead.
|
||||
[patch.crates-io]
|
||||
tokio-postgres = { git = "https://github.com/neondatabase/rust-postgres.git", rev="d052ee8b86fff9897c77b0fe89ea9daba0e1fa38" }
|
||||
tokio-postgres = { git = "https://github.com/neondatabase/rust-postgres.git", rev="69c1ef71cd5418cf063d4ca21eadc3427980caea" }
|
||||
|
||||
@@ -12,12 +12,12 @@ futures = "0.3.13"
|
||||
hyper = { version = "0.14", features = ["full"] }
|
||||
log = { version = "0.4", features = ["std", "serde"] }
|
||||
notify = "5.0.0"
|
||||
postgres = { git = "https://github.com/neondatabase/rust-postgres.git", rev="d052ee8b86fff9897c77b0fe89ea9daba0e1fa38" }
|
||||
postgres = { git = "https://github.com/neondatabase/rust-postgres.git", rev="69c1ef71cd5418cf063d4ca21eadc3427980caea" }
|
||||
regex = "1"
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
serde_json = "1"
|
||||
tar = "0.4"
|
||||
tokio = { version = "1.17", features = ["macros", "rt", "rt-multi-thread"] }
|
||||
tokio-postgres = { git = "https://github.com/neondatabase/rust-postgres.git", rev="d052ee8b86fff9897c77b0fe89ea9daba0e1fa38" }
|
||||
tokio-postgres = { git = "https://github.com/neondatabase/rust-postgres.git", rev="69c1ef71cd5418cf063d4ca21eadc3427980caea" }
|
||||
url = "2.2.2"
|
||||
workspace_hack = { version = "0.1", path = "../workspace_hack" }
|
||||
|
||||
@@ -10,7 +10,7 @@ comfy-table = "6.1"
|
||||
git-version = "0.3.5"
|
||||
nix = "0.25"
|
||||
once_cell = "1.13.0"
|
||||
postgres = { git = "https://github.com/neondatabase/rust-postgres.git", rev = "d052ee8b86fff9897c77b0fe89ea9daba0e1fa38" }
|
||||
postgres = { git = "https://github.com/neondatabase/rust-postgres.git", rev = "69c1ef71cd5418cf063d4ca21eadc3427980caea" }
|
||||
regex = "1"
|
||||
reqwest = { version = "0.11", default-features = false, features = ["blocking", "json", "rustls-tls"] }
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
|
||||
@@ -8,8 +8,8 @@ edition = "2021"
|
||||
[dependencies]
|
||||
anyhow = "1.0"
|
||||
itertools = "0.10.3"
|
||||
postgres = { git = "https://github.com/neondatabase/rust-postgres.git", rev = "d052ee8b86fff9897c77b0fe89ea9daba0e1fa38" }
|
||||
tokio-postgres = { git = "https://github.com/neondatabase/rust-postgres.git", rev="d052ee8b86fff9897c77b0fe89ea9daba0e1fa38" }
|
||||
postgres = { git = "https://github.com/neondatabase/rust-postgres.git", rev = "69c1ef71cd5418cf063d4ca21eadc3427980caea" }
|
||||
tokio-postgres = { git = "https://github.com/neondatabase/rust-postgres.git", rev="69c1ef71cd5418cf063d4ca21eadc3427980caea" }
|
||||
url = "2.2.2"
|
||||
workspace_hack = { version = "0.1", path = "../../workspace_hack" }
|
||||
|
||||
|
||||
@@ -21,7 +21,7 @@ workspace_hack = { version = "0.1", path = "../../workspace_hack" }
|
||||
|
||||
[dev-dependencies]
|
||||
env_logger = "0.9"
|
||||
postgres = { git = "https://github.com/neondatabase/rust-postgres.git", rev="d052ee8b86fff9897c77b0fe89ea9daba0e1fa38" }
|
||||
postgres = { git = "https://github.com/neondatabase/rust-postgres.git", rev="69c1ef71cd5418cf063d4ca21eadc3427980caea" }
|
||||
wal_craft = { path = "wal_craft" }
|
||||
|
||||
[build-dependencies]
|
||||
|
||||
@@ -11,7 +11,7 @@ clap = "4.0"
|
||||
env_logger = "0.9"
|
||||
log = "0.4"
|
||||
once_cell = "1.13.0"
|
||||
postgres = { git = "https://github.com/neondatabase/rust-postgres.git", rev="d052ee8b86fff9897c77b0fe89ea9daba0e1fa38" }
|
||||
postgres = { git = "https://github.com/neondatabase/rust-postgres.git", rev="69c1ef71cd5418cf063d4ca21eadc3427980caea" }
|
||||
postgres_ffi = { path = "../" }
|
||||
tempfile = "3.2"
|
||||
workspace_hack = { version = "0.1", path = "../../../workspace_hack" }
|
||||
|
||||
@@ -7,7 +7,7 @@ edition = "2021"
|
||||
anyhow = "1.0"
|
||||
bytes = "1.0.1"
|
||||
pin-project-lite = "0.2.7"
|
||||
postgres-protocol = { git = "https://github.com/neondatabase/rust-postgres.git", rev="d052ee8b86fff9897c77b0fe89ea9daba0e1fa38" }
|
||||
postgres-protocol = { git = "https://github.com/neondatabase/rust-postgres.git", rev="69c1ef71cd5418cf063d4ca21eadc3427980caea" }
|
||||
rand = "0.8.3"
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
tokio = { version = "1.17", features = ["macros"] }
|
||||
|
||||
@@ -463,7 +463,10 @@ pub enum BeMessage<'a> {
|
||||
EncryptionResponse(bool),
|
||||
NoData,
|
||||
ParameterDescription,
|
||||
ParameterStatus(BeParameterStatusMessage<'a>),
|
||||
ParameterStatus {
|
||||
name: &'a [u8],
|
||||
value: &'a [u8],
|
||||
},
|
||||
ParseComplete,
|
||||
ReadyForQuery,
|
||||
RowDescription(&'a [RowDescriptor<'a>]),
|
||||
@@ -472,6 +475,28 @@ pub enum BeMessage<'a> {
|
||||
KeepAlive(WalSndKeepAlive),
|
||||
}
|
||||
|
||||
/// Common shorthands.
|
||||
impl<'a> BeMessage<'a> {
|
||||
/// A [`BeMessage::ParameterStatus`] holding the client encoding, i.e. UTF-8.
|
||||
/// This is a sensible default, given that:
|
||||
/// * rust strings only support this encoding out of the box.
|
||||
/// * tokio-postgres, postgres-jdbc (and probably more) mandate it.
|
||||
///
|
||||
/// TODO: do we need to report `server_encoding` as well?
|
||||
pub const CLIENT_ENCODING: Self = Self::ParameterStatus {
|
||||
name: b"client_encoding",
|
||||
value: b"UTF8",
|
||||
};
|
||||
|
||||
/// Build a [`BeMessage::ParameterStatus`] holding the server version.
|
||||
pub fn server_version(version: &'a str) -> Self {
|
||||
Self::ParameterStatus {
|
||||
name: b"server_version",
|
||||
value: version.as_bytes(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum BeAuthenticationSaslMessage<'a> {
|
||||
Methods(&'a [&'a str]),
|
||||
@@ -485,12 +510,6 @@ pub enum BeParameterStatusMessage<'a> {
|
||||
ServerVersion(&'a str),
|
||||
}
|
||||
|
||||
impl BeParameterStatusMessage<'static> {
|
||||
pub fn encoding() -> BeMessage<'static> {
|
||||
BeMessage::ParameterStatus(Self::Encoding("UTF8"))
|
||||
}
|
||||
}
|
||||
|
||||
// One row description in RowDescription packet.
|
||||
#[derive(Debug)]
|
||||
pub struct RowDescriptor<'a> {
|
||||
@@ -587,14 +606,15 @@ fn write_body<R>(buf: &mut BytesMut, f: impl FnOnce(&mut BytesMut) -> R) -> R {
|
||||
}
|
||||
|
||||
/// Safe write of s into buf as cstring (String in the protocol).
|
||||
fn write_cstr(s: &[u8], buf: &mut BytesMut) -> Result<(), io::Error> {
|
||||
if s.contains(&0) {
|
||||
fn write_cstr(s: impl AsRef<[u8]>, buf: &mut BytesMut) -> Result<(), io::Error> {
|
||||
let bytes = s.as_ref();
|
||||
if bytes.contains(&0) {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::InvalidInput,
|
||||
"string contains embedded null",
|
||||
));
|
||||
}
|
||||
buf.put_slice(s);
|
||||
buf.put_slice(bytes);
|
||||
buf.put_u8(0);
|
||||
Ok(())
|
||||
}
|
||||
@@ -644,7 +664,7 @@ impl<'a> BeMessage<'a> {
|
||||
Methods(methods) => {
|
||||
buf.put_i32(10); // Specifies that SASL auth method is used.
|
||||
for method in methods.iter() {
|
||||
write_cstr(method.as_bytes(), buf)?;
|
||||
write_cstr(method, buf)?;
|
||||
}
|
||||
buf.put_u8(0); // zero terminator for the list
|
||||
}
|
||||
@@ -759,7 +779,7 @@ impl<'a> BeMessage<'a> {
|
||||
buf.put_slice(b"CXX000\0");
|
||||
|
||||
buf.put_u8(b'M'); // the message
|
||||
write_cstr(error_msg.as_bytes(), buf)?;
|
||||
write_cstr(error_msg, buf)?;
|
||||
|
||||
buf.put_u8(0); // terminator
|
||||
Ok::<_, io::Error>(())
|
||||
@@ -799,24 +819,12 @@ impl<'a> BeMessage<'a> {
|
||||
buf.put_u8(response);
|
||||
}
|
||||
|
||||
BeMessage::ParameterStatus(param) => {
|
||||
use std::io::{IoSlice, Write};
|
||||
use BeParameterStatusMessage::*;
|
||||
|
||||
let [name, value] = match param {
|
||||
Encoding(name) => [b"client_encoding", name.as_bytes()],
|
||||
ServerVersion(version) => [b"server_version", version.as_bytes()],
|
||||
};
|
||||
|
||||
// Parameter names and values are passed as null-terminated strings
|
||||
let iov = &mut [name, b"\0", value, b"\0"].map(IoSlice::new);
|
||||
let mut buffer = [0u8; 64]; // this should be enough
|
||||
let cnt = buffer.as_mut().write_vectored(iov).unwrap();
|
||||
|
||||
BeMessage::ParameterStatus { name, value } => {
|
||||
buf.put_u8(b'S');
|
||||
write_body(buf, |buf| {
|
||||
buf.put_slice(&buffer[..cnt]);
|
||||
});
|
||||
write_cstr(name, buf)?;
|
||||
write_cstr(value, buf)
|
||||
})?;
|
||||
}
|
||||
|
||||
BeMessage::ParameterDescription => {
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
use crate::sock_split::{BidiStream, ReadStream, WriteStream};
|
||||
use anyhow::{bail, ensure, Context, Result};
|
||||
use bytes::{Bytes, BytesMut};
|
||||
use pq_proto::{BeMessage, BeParameterStatusMessage, FeMessage, FeStartupPacket};
|
||||
use pq_proto::{BeMessage, FeMessage, FeStartupPacket};
|
||||
use rand::Rng;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::fmt;
|
||||
@@ -361,11 +361,9 @@ impl PostgresBackend {
|
||||
match self.auth_type {
|
||||
AuthType::Trust => {
|
||||
self.write_message_noflush(&BeMessage::AuthenticationOk)?
|
||||
.write_message_noflush(&BeParameterStatusMessage::encoding())?
|
||||
.write_message_noflush(&BeMessage::CLIENT_ENCODING)?
|
||||
// The async python driver requires a valid server_version
|
||||
.write_message_noflush(&BeMessage::ParameterStatus(
|
||||
BeParameterStatusMessage::ServerVersion("14.1"),
|
||||
))?
|
||||
.write_message_noflush(&BeMessage::server_version("14.1"))?
|
||||
.write_message(&BeMessage::ReadyForQuery)?;
|
||||
self.state = ProtoState::Established;
|
||||
}
|
||||
@@ -413,7 +411,7 @@ impl PostgresBackend {
|
||||
}
|
||||
}
|
||||
self.write_message_noflush(&BeMessage::AuthenticationOk)?
|
||||
.write_message_noflush(&BeParameterStatusMessage::encoding())?
|
||||
.write_message_noflush(&BeMessage::CLIENT_ENCODING)?
|
||||
.write_message(&BeMessage::ReadyForQuery)?;
|
||||
self.state = ProtoState::Established;
|
||||
}
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
use crate::postgres_backend::AuthType;
|
||||
use anyhow::{bail, Context, Result};
|
||||
use bytes::{Bytes, BytesMut};
|
||||
use pq_proto::{BeMessage, BeParameterStatusMessage, FeMessage, FeStartupPacket};
|
||||
use pq_proto::{BeMessage, FeMessage, FeStartupPacket};
|
||||
use rand::Rng;
|
||||
use std::future::Future;
|
||||
use std::net::SocketAddr;
|
||||
@@ -331,11 +331,9 @@ impl PostgresBackend {
|
||||
match self.auth_type {
|
||||
AuthType::Trust => {
|
||||
self.write_message(&BeMessage::AuthenticationOk)?
|
||||
.write_message(&BeParameterStatusMessage::encoding())?
|
||||
.write_message(&BeMessage::CLIENT_ENCODING)?
|
||||
// The async python driver requires a valid server_version
|
||||
.write_message(&BeMessage::ParameterStatus(
|
||||
BeParameterStatusMessage::ServerVersion("14.1"),
|
||||
))?
|
||||
.write_message(&BeMessage::server_version("14.1"))?
|
||||
.write_message(&BeMessage::ReadyForQuery)?;
|
||||
self.state = ProtoState::Established;
|
||||
}
|
||||
@@ -384,7 +382,7 @@ impl PostgresBackend {
|
||||
}
|
||||
}
|
||||
self.write_message(&BeMessage::AuthenticationOk)?
|
||||
.write_message(&BeParameterStatusMessage::encoding())?
|
||||
.write_message(&BeMessage::CLIENT_ENCODING)?
|
||||
.write_message(&BeMessage::ReadyForQuery)?;
|
||||
self.state = ProtoState::Established;
|
||||
}
|
||||
|
||||
@@ -36,9 +36,9 @@ nix = "0.25"
|
||||
num-traits = "0.2.15"
|
||||
once_cell = "1.13.0"
|
||||
pin-project-lite = "0.2.7"
|
||||
postgres = { git = "https://github.com/neondatabase/rust-postgres.git", rev="d052ee8b86fff9897c77b0fe89ea9daba0e1fa38" }
|
||||
postgres-protocol = { git = "https://github.com/neondatabase/rust-postgres.git", rev="d052ee8b86fff9897c77b0fe89ea9daba0e1fa38" }
|
||||
postgres-types = { git = "https://github.com/neondatabase/rust-postgres.git", rev="d052ee8b86fff9897c77b0fe89ea9daba0e1fa38" }
|
||||
postgres = { git = "https://github.com/neondatabase/rust-postgres.git", rev="69c1ef71cd5418cf063d4ca21eadc3427980caea" }
|
||||
postgres-protocol = { git = "https://github.com/neondatabase/rust-postgres.git", rev="69c1ef71cd5418cf063d4ca21eadc3427980caea" }
|
||||
postgres-types = { git = "https://github.com/neondatabase/rust-postgres.git", rev="69c1ef71cd5418cf063d4ca21eadc3427980caea" }
|
||||
pprof = { git = "https://github.com/neondatabase/pprof-rs.git", branch = "wallclock-profiling", features = ["flamegraph"], optional = true }
|
||||
rand = "0.8.3"
|
||||
regex = "1.4.5"
|
||||
@@ -52,7 +52,7 @@ svg_fmt = "0.4.1"
|
||||
tar = "0.4.33"
|
||||
thiserror = "1.0"
|
||||
tokio = { version = "1.17", features = ["process", "sync", "macros", "fs", "rt", "io-util", "time"] }
|
||||
tokio-postgres = { git = "https://github.com/neondatabase/rust-postgres.git", rev="d052ee8b86fff9897c77b0fe89ea9daba0e1fa38" }
|
||||
tokio-postgres = { git = "https://github.com/neondatabase/rust-postgres.git", rev="69c1ef71cd5418cf063d4ca21eadc3427980caea" }
|
||||
tokio-util = { version = "0.7.3", features = ["io", "io-util"] }
|
||||
toml_edit = { version = "0.14", features = ["easy"] }
|
||||
tracing = "0.1.36"
|
||||
|
||||
@@ -33,7 +33,7 @@ sha2 = "0.10.2"
|
||||
socket2 = "0.4.4"
|
||||
thiserror = "1.0.30"
|
||||
tokio = { version = "1.17", features = ["macros"] }
|
||||
tokio-postgres = { git = "https://github.com/neondatabase/rust-postgres.git", rev = "d052ee8b86fff9897c77b0fe89ea9daba0e1fa38" }
|
||||
tokio-postgres = { git = "https://github.com/neondatabase/rust-postgres.git", rev="69c1ef71cd5418cf063d4ca21eadc3427980caea" }
|
||||
tokio-rustls = "0.23.0"
|
||||
tracing = "0.1.36"
|
||||
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use super::{AuthSuccess, NodeInfo};
|
||||
use crate::{auth, compute, error::UserFacingError, stream::PqStream, waiters};
|
||||
use pq_proto::{BeMessage as Be, BeParameterStatusMessage};
|
||||
use pq_proto::BeMessage as Be;
|
||||
use thiserror::Error;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tracing::{info, info_span};
|
||||
@@ -60,7 +60,7 @@ pub async fn handle_user(
|
||||
info!(parent: &span, "sending the auth URL to the user");
|
||||
client
|
||||
.write_message_noflush(&Be::AuthenticationOk)?
|
||||
.write_message_noflush(&BeParameterStatusMessage::encoding())?
|
||||
.write_message_noflush(&Be::CLIENT_ENCODING)?
|
||||
.write_message(&Be::NoticeResponse(&greeting))
|
||||
.await?;
|
||||
|
||||
|
||||
@@ -8,18 +8,17 @@ use tokio::net::TcpStream;
|
||||
use tokio_postgres::NoTls;
|
||||
use tracing::{error, info};
|
||||
|
||||
const COULD_NOT_CONNECT: &str = "Could not connect to compute node";
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum ConnectionError {
|
||||
/// This error doesn't seem to reveal any secrets; for instance,
|
||||
/// [`tokio_postgres::error::Kind`] doesn't contain ip addresses and such.
|
||||
#[error("Failed to connect to the compute node: {0}")]
|
||||
#[error("{COULD_NOT_CONNECT}: {0}")]
|
||||
Postgres(#[from] tokio_postgres::Error),
|
||||
|
||||
#[error("Failed to connect to the compute node")]
|
||||
FailedToConnectToCompute,
|
||||
|
||||
#[error("Failed to fetch compute node version")]
|
||||
FailedToFetchPgVersion,
|
||||
#[error("{COULD_NOT_CONNECT}: {0}")]
|
||||
CouldNotConnect(#[from] io::Error),
|
||||
}
|
||||
|
||||
impl UserFacingError for ConnectionError {
|
||||
@@ -29,10 +28,10 @@ impl UserFacingError for ConnectionError {
|
||||
// This helps us drop irrelevant library-specific prefixes.
|
||||
// TODO: propagate severity level and other parameters.
|
||||
Postgres(err) => match err.as_db_error() {
|
||||
Some(err) => err.message().to_string(),
|
||||
Some(err) => err.message().to_owned(),
|
||||
None => err.to_string(),
|
||||
},
|
||||
other => other.to_string(),
|
||||
_ => COULD_NOT_CONNECT.to_owned(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -49,7 +48,7 @@ pub struct ConnCfg(pub tokio_postgres::Config);
|
||||
impl ConnCfg {
|
||||
/// Construct a new connection config.
|
||||
pub fn new() -> Self {
|
||||
Self(tokio_postgres::Config::new())
|
||||
Self(Default::default())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -95,7 +94,7 @@ impl ConnCfg {
|
||||
io::ErrorKind::Other,
|
||||
format!(
|
||||
"couldn't connect: bad compute config, \
|
||||
ports and hosts entries' count does not match: {:?}",
|
||||
ports and hosts entries' count does not match: {:?}",
|
||||
self.0
|
||||
),
|
||||
));
|
||||
@@ -131,8 +130,8 @@ impl ConnCfg {
|
||||
pub struct PostgresConnection {
|
||||
/// Socket connected to a compute node.
|
||||
pub stream: TcpStream,
|
||||
/// PostgreSQL version of this instance.
|
||||
pub version: String,
|
||||
/// PostgreSQL connection parameters.
|
||||
pub params: std::collections::HashMap<String, String>,
|
||||
}
|
||||
|
||||
impl ConnCfg {
|
||||
@@ -156,6 +155,7 @@ impl ConnCfg {
|
||||
self.0.application_name(app_name);
|
||||
}
|
||||
|
||||
// TODO: This is especially ugly...
|
||||
if let Some(replication) = params.get("replication") {
|
||||
use tokio_postgres::config::ReplicationMode;
|
||||
match replication {
|
||||
@@ -172,22 +172,24 @@ impl ConnCfg {
|
||||
// TODO: extend the list of the forwarded startup parameters.
|
||||
// Currently, tokio-postgres doesn't allow us to pass
|
||||
// arbitrary parameters, but the ones above are a good start.
|
||||
//
|
||||
// This and the reverse params problem can be better addressed
|
||||
// in a bespoke connection machinery (a new library for that sake).
|
||||
|
||||
let (socket_addr, mut stream) = self
|
||||
.connect_raw()
|
||||
.await
|
||||
.map_err(|_| ConnectionError::FailedToConnectToCompute)?;
|
||||
|
||||
// TODO: establish a secure connection to the DB
|
||||
let (client, conn) = self.0.connect_raw(&mut stream, NoTls).await?;
|
||||
let version = conn
|
||||
.parameter("server_version")
|
||||
.ok_or(ConnectionError::FailedToFetchPgVersion)?
|
||||
.into();
|
||||
|
||||
// TODO: establish a secure connection to the DB.
|
||||
let (socket_addr, mut stream) = self.connect_raw().await?;
|
||||
let (client, connection) = self.0.connect_raw(&mut stream, NoTls).await?;
|
||||
info!("connected to user's compute node at {socket_addr}");
|
||||
|
||||
// This is very ugly but as of now there's no better way to
|
||||
// extract the connection parameters from tokio-postgres' connection.
|
||||
// TODO: solve this problem in a more elegant manner (e.g. the new library).
|
||||
let params = connection.parameters;
|
||||
|
||||
// NB: CancelToken is supposed to hold socket_addr, but we use connect_raw.
|
||||
// Yet another reason to rework the connection establishing code.
|
||||
let cancel_closure = CancelClosure::new(socket_addr, client.cancel_token());
|
||||
let db = PostgresConnection { stream, version };
|
||||
let db = PostgresConnection { stream, params };
|
||||
|
||||
Ok((db, cancel_closure))
|
||||
}
|
||||
|
||||
@@ -255,15 +255,21 @@ impl<S: AsyncRead + AsyncWrite + Unpin + Send> Client<'_, S> {
|
||||
// Note that we do this only (for the most part) after we've connected
|
||||
// to a compute (see above) which performs its own authentication.
|
||||
if !auth_result.reported_auth_ok {
|
||||
stream
|
||||
.write_message_noflush(&Be::AuthenticationOk)?
|
||||
.write_message_noflush(&BeParameterStatusMessage::encoding())?;
|
||||
stream.write_message_noflush(&Be::AuthenticationOk)?;
|
||||
}
|
||||
|
||||
// Forward all postgres connection params to the client.
|
||||
// Right now the implementation is very hacky and inefficent (ideally,
|
||||
// we don't need an intermediate hashmap), but at least it should be correct.
|
||||
for (name, value) in &db.params {
|
||||
// TODO: Theoretically, this could result in a big pile of params...
|
||||
stream.write_message_noflush(&Be::ParameterStatus {
|
||||
name: name.as_bytes(),
|
||||
value: value.as_bytes(),
|
||||
})?;
|
||||
}
|
||||
|
||||
stream
|
||||
.write_message_noflush(&BeMessage::ParameterStatus(
|
||||
BeParameterStatusMessage::ServerVersion(&db.version),
|
||||
))?
|
||||
.write_message_noflush(&Be::BackendKeyData(cancel_key_data))?
|
||||
.write_message(&BeMessage::ReadyForQuery)
|
||||
.await?;
|
||||
|
||||
@@ -139,7 +139,7 @@ async fn dummy_proxy(
|
||||
|
||||
stream
|
||||
.write_message_noflush(&Be::AuthenticationOk)?
|
||||
.write_message_noflush(&BeParameterStatusMessage::encoding())?
|
||||
.write_message_noflush(&Be::CLIENT_ENCODING)?
|
||||
.write_message(&BeMessage::ReadyForQuery)
|
||||
.await?;
|
||||
|
||||
|
||||
@@ -20,8 +20,8 @@ hyper = "0.14"
|
||||
nix = "0.25"
|
||||
once_cell = "1.13.0"
|
||||
parking_lot = "0.12.1"
|
||||
postgres = { git = "https://github.com/neondatabase/rust-postgres.git", rev="d052ee8b86fff9897c77b0fe89ea9daba0e1fa38" }
|
||||
postgres-protocol = { git = "https://github.com/neondatabase/rust-postgres.git", rev="d052ee8b86fff9897c77b0fe89ea9daba0e1fa38" }
|
||||
postgres = { git = "https://github.com/neondatabase/rust-postgres.git", rev="69c1ef71cd5418cf063d4ca21eadc3427980caea" }
|
||||
postgres-protocol = { git = "https://github.com/neondatabase/rust-postgres.git", rev="69c1ef71cd5418cf063d4ca21eadc3427980caea" }
|
||||
regex = "1.4.5"
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
serde_json = "1"
|
||||
@@ -29,7 +29,7 @@ serde_with = "2.0"
|
||||
signal-hook = "0.3.10"
|
||||
thiserror = "1"
|
||||
tokio = { version = "1.17", features = ["macros", "fs"] }
|
||||
tokio-postgres = { git = "https://github.com/neondatabase/rust-postgres.git", rev="d052ee8b86fff9897c77b0fe89ea9daba0e1fa38" }
|
||||
tokio-postgres = { git = "https://github.com/neondatabase/rust-postgres.git", rev="69c1ef71cd5418cf063d4ca21eadc3427980caea" }
|
||||
toml_edit = { version = "0.14", features = ["easy"] }
|
||||
tracing = "0.1.27"
|
||||
url = "2.2.2"
|
||||
|
||||
@@ -122,3 +122,33 @@ def test_auth_errors(static_proxy: NeonProxy):
|
||||
# Finally, check that the user can connect
|
||||
with static_proxy.connect(user="pinocchio", password="magic", options="project=irrelevant"):
|
||||
pass
|
||||
|
||||
|
||||
def test_forward_params_to_client(static_proxy: NeonProxy):
|
||||
# A subset of parameters (GUCs) which postgres
|
||||
# sends to the client during connection setup.
|
||||
# Unfortunately, `GUC_REPORT` can't be queried.
|
||||
# Proxy *should* forward them, otherwise client library
|
||||
# might misbehave (e.g. parse timestamps incorrectly).
|
||||
reported_params_subset = [
|
||||
"client_encoding",
|
||||
"integer_datetimes",
|
||||
"is_superuser",
|
||||
"server_encoding",
|
||||
"server_version",
|
||||
"session_authorization",
|
||||
"standard_conforming_strings",
|
||||
]
|
||||
|
||||
query = """
|
||||
select name, setting
|
||||
from pg_catalog.pg_settings
|
||||
where name = any(%s)
|
||||
"""
|
||||
|
||||
with static_proxy.connect(options="project=irrelevant") as conn:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(query, (reported_params_subset,))
|
||||
for name, value in cur.fetchall():
|
||||
# Check that proxy has forwarded this parameter.
|
||||
assert conn.get_parameter_status(name) == value
|
||||
|
||||
Reference in New Issue
Block a user