feat(proxy): add option to forward startup params (#9979)

(stacked on #9990 and #9995)

Partially fixes #1287 with a custom option field to enable the fixed
behaviour. This allows us to gradually roll out the fix without silently
changing the observed behaviour for our customers.

related to https://github.com/neondatabase/cloud/issues/15284
This commit is contained in:
Conrad Ludgate
2024-12-04 12:58:35 +00:00
committed by Ivan Efremov
parent 6359342ffb
commit cab498c787
19 changed files with 180 additions and 340 deletions

4
Cargo.lock generated
View File

@@ -1031,9 +1031,9 @@ checksum = "14c189c53d098945499cdfa7ecc63567cf3886b3332b312a5b4585d8d3a6a610"
[[package]]
name = "bytes"
version = "1.5.0"
version = "1.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a2bd12c1caf447e69cd4528f47f94d203fd2582878ecb9e9465484c4148a8223"
checksum = "325918d6fe32f23b19878fe4b34794ae41fc19ddbe53b10571a4874d44ffd39b"
dependencies = [
"serde",
]

View File

@@ -74,7 +74,7 @@ bindgen = "0.70"
bit_field = "0.10.2"
bstr = "1.0"
byteorder = "1.4"
bytes = "1.0"
bytes = "1.9"
camino = "1.1.6"
cfg-if = "1.0.0"
chrono = { version = "0.4", default-features = false, features = ["clock"] }

View File

@@ -100,7 +100,7 @@ impl StartupMessageParamsBuilder {
#[derive(Debug, Clone, Default)]
pub struct StartupMessageParams {
params: Bytes,
pub params: Bytes,
}
impl StartupMessageParams {

View File

@@ -117,7 +117,7 @@ enum Credentials<const N: usize> {
/// A regular password as a vector of bytes.
Password(Vec<u8>),
/// A precomputed pair of keys.
Keys(Box<ScramKeys<N>>),
Keys(ScramKeys<N>),
}
enum State {
@@ -176,7 +176,7 @@ impl ScramSha256 {
/// Constructs a new instance which will use the provided key pair for authentication.
pub fn new_with_keys(keys: ScramKeys<32>, channel_binding: ChannelBinding) -> ScramSha256 {
let password = Credentials::Keys(keys.into());
let password = Credentials::Keys(keys);
ScramSha256::new_inner(password, channel_binding, nonce())
}

View File

@@ -255,22 +255,34 @@ pub fn ssl_request(buf: &mut BytesMut) {
}
#[inline]
pub fn startup_message<'a, I>(parameters: I, buf: &mut BytesMut) -> io::Result<()>
where
I: IntoIterator<Item = (&'a str, &'a str)>,
{
pub fn startup_message(parameters: &StartupMessageParams, buf: &mut BytesMut) -> io::Result<()> {
write_body(buf, |buf| {
// postgres protocol version 3.0(196608) in bigger-endian
buf.put_i32(0x00_03_00_00);
for (key, value) in parameters {
write_cstr(key.as_bytes(), buf)?;
write_cstr(value.as_bytes(), buf)?;
}
buf.put_slice(&parameters.params);
buf.put_u8(0);
Ok(())
})
}
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct StartupMessageParams {
pub params: BytesMut,
}
impl StartupMessageParams {
/// Set parameter's value by its name.
pub fn insert(&mut self, name: &str, value: &str) {
if name.contains('\0') || value.contains('\0') {
panic!("startup parameter name or value contained a null")
}
self.params.put_slice(name.as_bytes());
self.params.put_u8(0);
self.params.put_slice(value.as_bytes());
self.params.put_u8(0);
}
}
#[inline]
pub fn sync(buf: &mut BytesMut) {
buf.put_u8(b'S');

View File

@@ -35,9 +35,7 @@ impl FallibleIterator for BackendMessages {
}
}
pub struct PostgresCodec {
pub max_message_size: Option<usize>,
}
pub struct PostgresCodec;
impl Encoder<FrontendMessage> for PostgresCodec {
type Error = io::Error;
@@ -66,15 +64,6 @@ impl Decoder for PostgresCodec {
break;
}
if let Some(max) = self.max_message_size {
if len > max {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"message too large",
));
}
}
match header.tag() {
backend::NOTICE_RESPONSE_TAG
| backend::NOTIFICATION_RESPONSE_TAG

View File

@@ -6,6 +6,7 @@ use crate::connect_raw::RawConnection;
use crate::tls::MakeTlsConnect;
use crate::tls::TlsConnect;
use crate::{Client, Connection, Error};
use postgres_protocol2::message::frontend::StartupMessageParams;
use std::fmt;
use std::str;
use std::time::Duration;
@@ -14,16 +15,6 @@ use tokio::io::{AsyncRead, AsyncWrite};
pub use postgres_protocol2::authentication::sasl::ScramKeys;
use tokio::net::TcpStream;
/// Properties required of a session.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum TargetSessionAttrs {
/// No special properties are required.
Any,
/// The session must allow writes.
ReadWrite,
}
/// TLS configuration.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
#[non_exhaustive]
@@ -73,94 +64,20 @@ pub enum AuthKeys {
}
/// Connection configuration.
///
/// Configuration can be parsed from libpq-style connection strings. These strings come in two formats:
///
/// # Key-Value
///
/// This format consists of space-separated key-value pairs. Values which are either the empty string or contain
/// whitespace should be wrapped in `'`. `'` and `\` characters should be backslash-escaped.
///
/// ## Keys
///
/// * `user` - The username to authenticate with. Required.
/// * `password` - The password to authenticate with.
/// * `dbname` - The name of the database to connect to. Defaults to the username.
/// * `options` - Command line options used to configure the server.
/// * `application_name` - Sets the `application_name` parameter on the server.
/// * `sslmode` - Controls usage of TLS. If set to `disable`, TLS will not be used. If set to `prefer`, TLS will be used
/// if available, but not used otherwise. If set to `require`, TLS will be forced to be used. Defaults to `prefer`.
/// * `host` - The host to connect to. On Unix platforms, if the host starts with a `/` character it is treated as the
/// path to the directory containing Unix domain sockets. Otherwise, it is treated as a hostname. Multiple hosts
/// can be specified, separated by commas. Each host will be tried in turn when connecting. Required if connecting
/// with the `connect` method.
/// * `port` - The port to connect to. Multiple ports can be specified, separated by commas. The number of ports must be
/// either 1, in which case it will be used for all hosts, or the same as the number of hosts. Defaults to 5432 if
/// omitted or the empty string.
/// * `connect_timeout` - The time limit in seconds applied to each socket-level connection attempt. Note that hostnames
/// can resolve to multiple IP addresses, and this limit is applied to each address. Defaults to no timeout.
/// * `target_session_attrs` - Specifies requirements of the session. If set to `read-write`, the client will check that
/// the `transaction_read_write` session parameter is set to `on`. This can be used to connect to the primary server
/// in a database cluster as opposed to the secondary read-only mirrors. Defaults to `all`.
/// * `channel_binding` - Controls usage of channel binding in the authentication process. If set to `disable`, channel
/// binding will not be used. If set to `prefer`, channel binding will be used if available, but not used otherwise.
/// If set to `require`, the authentication process will fail if channel binding is not used. Defaults to `prefer`.
///
/// ## Examples
///
/// ```not_rust
/// host=localhost user=postgres connect_timeout=10 keepalives=0
/// ```
///
/// ```not_rust
/// host=/var/lib/postgresql,localhost port=1234 user=postgres password='password with spaces'
/// ```
///
/// ```not_rust
/// host=host1,host2,host3 port=1234,,5678 user=postgres target_session_attrs=read-write
/// ```
///
/// # Url
///
/// This format resembles a URL with a scheme of either `postgres://` or `postgresql://`. All components are optional,
/// and the format accepts query parameters for all of the key-value pairs described in the section above. Multiple
/// host/port pairs can be comma-separated. Unix socket paths in the host section of the URL should be percent-encoded,
/// as the path component of the URL specifies the database name.
///
/// ## Examples
///
/// ```not_rust
/// postgresql://user@localhost
/// ```
///
/// ```not_rust
/// postgresql://user:password@%2Fvar%2Flib%2Fpostgresql/mydb?connect_timeout=10
/// ```
///
/// ```not_rust
/// postgresql://user@host1:1234,host2,host3:5678?target_session_attrs=read-write
/// ```
///
/// ```not_rust
/// postgresql:///mydb?user=user&host=/var/lib/postgresql
/// ```
#[derive(Clone, PartialEq, Eq)]
pub struct Config {
pub(crate) host: Host,
pub(crate) port: u16,
pub(crate) user: Option<String>,
pub(crate) password: Option<Vec<u8>>,
pub(crate) auth_keys: Option<Box<AuthKeys>>,
pub(crate) dbname: Option<String>,
pub(crate) options: Option<String>,
pub(crate) application_name: Option<String>,
pub(crate) ssl_mode: SslMode,
pub(crate) connect_timeout: Option<Duration>,
pub(crate) target_session_attrs: TargetSessionAttrs,
pub(crate) channel_binding: ChannelBinding,
pub(crate) replication_mode: Option<ReplicationMode>,
pub(crate) max_backend_message_size: Option<usize>,
pub(crate) server_params: StartupMessageParams,
database: bool,
username: bool,
}
impl Config {
@@ -169,18 +86,15 @@ impl Config {
Config {
host: Host::Tcp(host),
port,
user: None,
password: None,
auth_keys: None,
dbname: None,
options: None,
application_name: None,
ssl_mode: SslMode::Prefer,
connect_timeout: None,
target_session_attrs: TargetSessionAttrs::Any,
channel_binding: ChannelBinding::Prefer,
replication_mode: None,
max_backend_message_size: None,
server_params: StartupMessageParams::default(),
database: false,
username: false,
}
}
@@ -188,14 +102,13 @@ impl Config {
///
/// Required.
pub fn user(&mut self, user: &str) -> &mut Config {
self.user = Some(user.to_string());
self
self.set_param("user", user)
}
/// Gets the user to authenticate with, if one has been configured with
/// the `user` method.
pub fn get_user(&self) -> Option<&str> {
self.user.as_deref()
pub fn user_is_set(&self) -> bool {
self.username
}
/// Sets the password to authenticate with.
@@ -231,40 +144,26 @@ impl Config {
///
/// Defaults to the user.
pub fn dbname(&mut self, dbname: &str) -> &mut Config {
self.dbname = Some(dbname.to_string());
self
self.set_param("database", dbname)
}
/// Gets the name of the database to connect to, if one has been configured
/// with the `dbname` method.
pub fn get_dbname(&self) -> Option<&str> {
self.dbname.as_deref()
pub fn db_is_set(&self) -> bool {
self.database
}
/// Sets command line options used to configure the server.
pub fn options(&mut self, options: &str) -> &mut Config {
self.options = Some(options.to_string());
pub fn set_param(&mut self, name: &str, value: &str) -> &mut Config {
if name == "database" {
self.database = true;
} else if name == "user" {
self.username = true;
}
self.server_params.insert(name, value);
self
}
/// Gets the command line options used to configure the server, if the
/// options have been set with the `options` method.
pub fn get_options(&self) -> Option<&str> {
self.options.as_deref()
}
/// Sets the value of the `application_name` runtime parameter.
pub fn application_name(&mut self, application_name: &str) -> &mut Config {
self.application_name = Some(application_name.to_string());
self
}
/// Gets the value of the `application_name` runtime parameter, if it has
/// been set with the `application_name` method.
pub fn get_application_name(&self) -> Option<&str> {
self.application_name.as_deref()
}
/// Sets the SSL configuration.
///
/// Defaults to `prefer`.
@@ -303,23 +202,6 @@ impl Config {
self.connect_timeout.as_ref()
}
/// Sets the requirements of the session.
///
/// This can be used to connect to the primary server in a clustered database rather than one of the read-only
/// secondary servers. Defaults to `Any`.
pub fn target_session_attrs(
&mut self,
target_session_attrs: TargetSessionAttrs,
) -> &mut Config {
self.target_session_attrs = target_session_attrs;
self
}
/// Gets the requirements of the session.
pub fn get_target_session_attrs(&self) -> TargetSessionAttrs {
self.target_session_attrs
}
/// Sets the channel binding behavior.
///
/// Defaults to `prefer`.
@@ -333,28 +215,6 @@ impl Config {
self.channel_binding
}
/// Set replication mode.
pub fn replication_mode(&mut self, replication_mode: ReplicationMode) -> &mut Config {
self.replication_mode = Some(replication_mode);
self
}
/// Get replication mode.
pub fn get_replication_mode(&self) -> Option<ReplicationMode> {
self.replication_mode
}
/// Set limit for backend messages size.
pub fn max_backend_message_size(&mut self, max_backend_message_size: usize) -> &mut Config {
self.max_backend_message_size = Some(max_backend_message_size);
self
}
/// Get limit for backend messages size.
pub fn get_max_backend_message_size(&self) -> Option<usize> {
self.max_backend_message_size
}
/// Opens a connection to a PostgreSQL database.
///
/// Requires the `runtime` Cargo feature (enabled by default).
@@ -392,18 +252,13 @@ impl fmt::Debug for Config {
}
f.debug_struct("Config")
.field("user", &self.user)
.field("password", &self.password.as_ref().map(|_| Redaction {}))
.field("dbname", &self.dbname)
.field("options", &self.options)
.field("application_name", &self.application_name)
.field("ssl_mode", &self.ssl_mode)
.field("host", &self.host)
.field("port", &self.port)
.field("connect_timeout", &self.connect_timeout)
.field("target_session_attrs", &self.target_session_attrs)
.field("channel_binding", &self.channel_binding)
.field("replication", &self.replication_mode)
.field("server_params", &self.server_params)
.finish()
}
}

View File

@@ -1,14 +1,11 @@
use crate::client::SocketConfig;
use crate::codec::BackendMessage;
use crate::config::{Host, TargetSessionAttrs};
use crate::config::Host;
use crate::connect_raw::connect_raw;
use crate::connect_socket::connect_socket;
use crate::tls::{MakeTlsConnect, TlsConnect};
use crate::{Client, Config, Connection, Error, RawConnection, SimpleQueryMessage};
use futures_util::{future, pin_mut, Future, FutureExt, Stream};
use crate::{Client, Config, Connection, Error, RawConnection};
use postgres_protocol2::message::backend::Message;
use std::io;
use std::task::Poll;
use tokio::net::TcpStream;
use tokio::sync::mpsc;
@@ -72,47 +69,7 @@ where
.map(|m| BackendMessage::Async(Message::NoticeResponse(m)))
.collect();
let mut connection = Connection::new(stream, delayed, parameters, receiver);
if let TargetSessionAttrs::ReadWrite = config.target_session_attrs {
let rows = client.simple_query_raw("SHOW transaction_read_only");
pin_mut!(rows);
let rows = future::poll_fn(|cx| {
if connection.poll_unpin(cx)?.is_ready() {
return Poll::Ready(Err(Error::closed()));
}
rows.as_mut().poll(cx)
})
.await?;
pin_mut!(rows);
loop {
let next = future::poll_fn(|cx| {
if connection.poll_unpin(cx)?.is_ready() {
return Poll::Ready(Some(Err(Error::closed())));
}
rows.as_mut().poll_next(cx)
});
match next.await.transpose()? {
Some(SimpleQueryMessage::Row(row)) => {
if row.try_get(0)? == Some("on") {
return Err(Error::connect(io::Error::new(
io::ErrorKind::PermissionDenied,
"database does not allow writes",
)));
} else {
break;
}
}
Some(_) => {}
None => return Err(Error::unexpected_message()),
}
}
}
let connection = Connection::new(stream, delayed, parameters, receiver);
Ok((client, connection))
}

View File

@@ -1,5 +1,5 @@
use crate::codec::{BackendMessage, BackendMessages, FrontendMessage, PostgresCodec};
use crate::config::{self, AuthKeys, Config, ReplicationMode};
use crate::config::{self, AuthKeys, Config};
use crate::connect_tls::connect_tls;
use crate::maybe_tls_stream::MaybeTlsStream;
use crate::tls::{TlsConnect, TlsStream};
@@ -96,12 +96,7 @@ where
let stream = connect_tls(stream, config.ssl_mode, tls).await?;
let mut stream = StartupStream {
inner: Framed::new(
stream,
PostgresCodec {
max_message_size: config.max_backend_message_size,
},
),
inner: Framed::new(stream, PostgresCodec),
buf: BackendMessages::empty(),
delayed_notice: Vec::new(),
};
@@ -124,28 +119,8 @@ where
S: AsyncRead + AsyncWrite + Unpin,
T: AsyncRead + AsyncWrite + Unpin,
{
let mut params = vec![("client_encoding", "UTF8")];
if let Some(user) = &config.user {
params.push(("user", &**user));
}
if let Some(dbname) = &config.dbname {
params.push(("database", &**dbname));
}
if let Some(options) = &config.options {
params.push(("options", &**options));
}
if let Some(application_name) = &config.application_name {
params.push(("application_name", &**application_name));
}
if let Some(replication_mode) = &config.replication_mode {
match replication_mode {
ReplicationMode::Physical => params.push(("replication", "true")),
ReplicationMode::Logical => params.push(("replication", "database")),
}
}
let mut buf = BytesMut::new();
frontend::startup_message(params, &mut buf).map_err(Error::encode)?;
frontend::startup_message(&config.server_params, &mut buf).map_err(Error::encode)?;
stream
.send(FrontendMessage::Raw(buf.freeze()))

View File

@@ -70,11 +70,12 @@ impl ReportableError for CancelError {
impl<P: CancellationPublisher> CancellationHandler<P> {
/// Run async action within an ephemeral session identified by [`CancelKeyData`].
pub(crate) fn get_session(self: Arc<Self>) -> Session<P> {
// HACK: We'd rather get the real backend_pid but postgres_client doesn't
// expose it and we don't want to do another roundtrip to query
// for it. The client will be able to notice that this is not the
// actual backend_pid, but backend_pid is not used for anything
// so it doesn't matter.
// we intentionally generate a random "backend pid" and "secret key" here.
// we use the corresponding u64 as an identifier for the
// actual endpoint+pid+secret for postgres/pgbouncer.
//
// if we forwarded the backend_pid from postgres to the client, there would be a lot
// of overlap between our computes as most pids are small (~100).
let key = loop {
let key = rand::random();

View File

@@ -131,49 +131,37 @@ impl ConnCfg {
}
/// Apply startup message params to the connection config.
pub(crate) fn set_startup_params(&mut self, params: &StartupMessageParams) {
// Only set `user` if it's not present in the config.
// Console redirect auth flow takes username from the console's response.
if let (None, Some(user)) = (self.get_user(), params.get("user")) {
self.user(user);
pub(crate) fn set_startup_params(
&mut self,
params: &StartupMessageParams,
arbitrary_params: bool,
) {
if !arbitrary_params {
self.set_param("client_encoding", "UTF8");
}
// Only set `dbname` if it's not present in the config.
// Console redirect auth flow takes dbname from the console's response.
if let (None, Some(dbname)) = (self.get_dbname(), params.get("database")) {
self.dbname(dbname);
}
// Don't add `options` if they were only used for specifying a project.
// Connection pools don't support `options`, because they affect backend startup.
if let Some(options) = filtered_options(params) {
self.options(&options);
}
if let Some(app_name) = params.get("application_name") {
self.application_name(app_name);
}
// TODO: This is especially ugly...
if let Some(replication) = params.get("replication") {
use postgres_client::config::ReplicationMode;
match replication {
"true" | "on" | "yes" | "1" => {
self.replication_mode(ReplicationMode::Physical);
for (k, v) in params.iter() {
match k {
// Only set `user` if it's not present in the config.
// Console redirect auth flow takes username from the console's response.
"user" if self.user_is_set() => continue,
"database" if self.db_is_set() => continue,
"options" => {
if let Some(options) = filtered_options(v) {
self.set_param(k, &options);
}
}
"database" => {
self.replication_mode(ReplicationMode::Logical);
"user" | "database" | "application_name" | "replication" => {
self.set_param(k, v);
}
_other => {}
// if we allow arbitrary params, then we forward them through.
// this is a flag for a period of backwards compatibility
k if arbitrary_params => {
self.set_param(k, v);
}
_ => {}
}
}
// TODO: extend the list of the forwarded startup parameters.
// Currently, tokio-postgres doesn't allow us to pass
// arbitrary parameters, but the ones above are a good start.
//
// This and the reverse params problem can be better addressed
// in a bespoke connection machinery (a new library for that sake).
}
}
@@ -347,10 +335,9 @@ impl ConnCfg {
}
/// Retrieve `options` from a startup message, dropping all proxy-secific flags.
fn filtered_options(params: &StartupMessageParams) -> Option<String> {
fn filtered_options(options: &str) -> Option<String> {
#[allow(unstable_name_collisions)]
let options: String = params
.options_raw()?
let options: String = StartupMessageParams::parse_options_raw(options)
.filter(|opt| parse_endpoint_param(opt).is_none() && neon_option(opt).is_none())
.intersperse(" ") // TODO: use impl from std once it's stabilized
.collect();
@@ -427,27 +414,24 @@ mod tests {
#[test]
fn test_filtered_options() {
// Empty options is unlikely to be useful anyway.
let params = StartupMessageParams::new([("options", "")]);
assert_eq!(filtered_options(&params), None);
let params = "";
assert_eq!(filtered_options(params), None);
// It's likely that clients will only use options to specify endpoint/project.
let params = StartupMessageParams::new([("options", "project=foo")]);
assert_eq!(filtered_options(&params), None);
let params = "project=foo";
assert_eq!(filtered_options(params), None);
// Same, because unescaped whitespaces are no-op.
let params = StartupMessageParams::new([("options", " project=foo ")]);
assert_eq!(filtered_options(&params).as_deref(), None);
let params = " project=foo ";
assert_eq!(filtered_options(params).as_deref(), None);
let params = StartupMessageParams::new([("options", r"\ project=foo \ ")]);
assert_eq!(filtered_options(&params).as_deref(), Some(r"\ \ "));
let params = r"\ project=foo \ ";
assert_eq!(filtered_options(params).as_deref(), Some(r"\ \ "));
let params = StartupMessageParams::new([("options", "project = foo")]);
assert_eq!(filtered_options(&params).as_deref(), Some("project = foo"));
let params = "project = foo";
assert_eq!(filtered_options(params).as_deref(), Some("project = foo"));
let params = StartupMessageParams::new([(
"options",
"project = foo neon_endpoint_type:read_write neon_lsn:0/2",
)]);
assert_eq!(filtered_options(&params).as_deref(), Some("project = foo"));
let params = "project = foo neon_endpoint_type:read_write neon_lsn:0/2 neon_proxy_params_compat:true";
assert_eq!(filtered_options(params).as_deref(), Some("project = foo"));
}
}

View File

@@ -206,6 +206,7 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
let mut node = connect_to_compute(
ctx,
&TcpMechanism {
params_compat: true,
params: &params,
locks: &config.connect_compute_locks,
},

View File

@@ -66,6 +66,8 @@ pub(crate) trait ComputeConnectBackend {
}
pub(crate) struct TcpMechanism<'a> {
pub(crate) params_compat: bool,
/// KV-dictionary with PostgreSQL connection params.
pub(crate) params: &'a StartupMessageParams,
@@ -92,7 +94,7 @@ impl ConnectMechanism for TcpMechanism<'_> {
}
fn update_connect_config(&self, config: &mut compute::ConnCfg) {
config.set_startup_params(self.params);
config.set_startup_params(self.params, self.params_compat);
}
}

View File

@@ -338,9 +338,17 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
}
};
let params_compat = match &user_info {
auth::Backend::ControlPlane(_, info) => {
info.info.options.get(NeonOptions::PARAMS_COMPAT).is_some()
}
auth::Backend::Local(_) => false,
};
let mut node = connect_to_compute(
ctx,
&TcpMechanism {
params_compat,
params: &params,
locks: &config.connect_compute_locks,
},
@@ -409,19 +417,47 @@ pub(crate) async fn prepare_client_connection<P>(
pub(crate) struct NeonOptions(Vec<(SmolStr, SmolStr)>);
impl NeonOptions {
// proxy options:
/// `PARAMS_COMPAT` allows opting in to forwarding all startup parameters from client to compute.
const PARAMS_COMPAT: &str = "proxy_params_compat";
// cplane options:
/// `LSN` allows provisioning an ephemeral compute with time-travel to the provided LSN.
const LSN: &str = "lsn";
/// `ENDPOINT_TYPE` allows configuring an ephemeral compute to be read_only or read_write.
const ENDPOINT_TYPE: &str = "endpoint_type";
pub(crate) fn parse_params(params: &StartupMessageParams) -> Self {
params
.options_raw()
.map(Self::parse_from_iter)
.unwrap_or_default()
}
pub(crate) fn parse_options_raw(options: &str) -> Self {
Self::parse_from_iter(StartupMessageParams::parse_options_raw(options))
}
pub(crate) fn get(&self, key: &str) -> Option<SmolStr> {
self.0
.iter()
.find_map(|(k, v)| (k == key).then_some(v))
.cloned()
}
pub(crate) fn is_ephemeral(&self) -> bool {
// Currently, neon endpoint options are all reserved for ephemeral endpoints.
!self.0.is_empty()
self.0.iter().any(|(k, _)| match &**k {
// This is not a cplane option, we know it does not create ephemeral computes.
Self::PARAMS_COMPAT => false,
Self::LSN => true,
Self::ENDPOINT_TYPE => true,
// err on the side of caution. any cplane options we don't know about
// might lead to ephemeral computes.
_ => true,
})
}
fn parse_from_iter<'a>(options: impl Iterator<Item = &'a str>) -> Self {

View File

@@ -55,7 +55,13 @@ async fn proxy_mitm(
// give the end_server the startup parameters
let mut buf = BytesMut::new();
frontend::startup_message(startup.iter(), &mut buf).unwrap();
frontend::startup_message(
&postgres_protocol::message::frontend::StartupMessageParams {
params: startup.params.into(),
},
&mut buf,
)
.unwrap();
end_server.send(buf.freeze()).await.unwrap();
// proxy messages between end_client and end_server

View File

@@ -252,7 +252,7 @@ async fn handshake_raw() -> anyhow::Result<()> {
let _conn = postgres_client::Config::new("test".to_owned(), 5432)
.user("john_doe")
.dbname("earth")
.options("project=generic-project-name")
.set_param("options", "project=generic-project-name")
.ssl_mode(SslMode::Prefer)
.connect_raw(server, NoTls)
.await?;

View File

@@ -309,10 +309,13 @@ impl PoolingBackend {
.config
.user(&conn_info.user_info.user)
.dbname(&conn_info.dbname)
.options(&format!(
"-c pg_session_jwt.jwk={}",
serde_json::to_string(&jwk).expect("serializing jwk to json should not fail")
));
.set_param(
"options",
&format!(
"-c pg_session_jwt.jwk={}",
serde_json::to_string(&jwk).expect("serializing jwk to json should not fail")
),
);
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
let (client, connection) = config.connect(postgres_client::NoTls).await?;

View File

@@ -269,7 +269,7 @@ class PgProtocol:
for match in re.finditer(r"-c(\w*)=(\w*)", options):
key = match.group(1)
val = match.group(2)
if "server_options" in conn_options:
if "server_settings" in conn_options:
conn_options["server_settings"].update({key: val})
else:
conn_options["server_settings"] = {key: val}

View File

@@ -5,6 +5,7 @@ import json
import subprocess
import time
import urllib.parse
from contextlib import closing
from typing import TYPE_CHECKING
import psycopg2
@@ -131,6 +132,24 @@ def test_proxy_options(static_proxy: NeonProxy, option_name: str):
assert out[0][0] == " str"
@pytest.mark.asyncio
async def test_proxy_arbitrary_params(static_proxy: NeonProxy):
with closing(
await static_proxy.connect_async(server_settings={"IntervalStyle": "iso_8601"})
) as conn:
out = await conn.fetchval("select to_json('0 seconds'::interval)")
assert out == '"00:00:00"'
options = "neon_proxy_params_compat:true"
with closing(
await static_proxy.connect_async(
server_settings={"IntervalStyle": "iso_8601", "options": options}
)
) as conn:
out = await conn.fetchval("select to_json('0 seconds'::interval)")
assert out == '"PT0S"'
def test_auth_errors(static_proxy: NeonProxy):
"""
Check that we throw very specific errors in some unsuccessful auth scenarios.