Forward various connection params to compute nodes. (#2336)

Previously, proxy didn't forward auxiliary `options` parameter
and other ones to the client's compute node, e.g.

```
$ psql "user=john host=localhost dbname=postgres options='-cgeqo=off'"
postgres=# show geqo;
┌──────┐
│ geqo │
├──────┤
│ on   │
└──────┘
(1 row)
```

With this patch we now forward `options`, `application_name` and `replication`.

Further reading: https://www.postgresql.org/docs/current/libpq-connect.html

Fixes #1287.
This commit is contained in:
Dmitry Ivanov
2022-08-30 17:36:21 +03:00
committed by GitHub
parent 60408db101
commit 96a50e99cf
13 changed files with 271 additions and 127 deletions

1
Cargo.lock generated
View File

@@ -2271,6 +2271,7 @@ dependencies = [
"hex",
"hmac 0.12.1",
"hyper",
"itertools",
"md5",
"metrics",
"once_cell",

View File

@@ -7,11 +7,14 @@ use anyhow::{bail, ensure, Context, Result};
use bytes::{Buf, BufMut, Bytes, BytesMut};
use postgres_protocol::PG_EPOCH;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::future::Future;
use std::io::{self, Cursor};
use std::str;
use std::time::{Duration, SystemTime};
use std::{
borrow::Cow,
collections::HashMap,
future::Future,
io::{self, Cursor},
str,
time::{Duration, SystemTime},
};
use tokio::io::AsyncReadExt;
use tracing::{trace, warn};
@@ -53,7 +56,67 @@ pub enum FeStartupPacket {
},
}
pub type StartupMessageParams = HashMap<String, String>;
#[derive(Debug)]
pub struct StartupMessageParams {
params: HashMap<String, String>,
}
impl StartupMessageParams {
/// Get parameter's value by its name.
pub fn get(&self, name: &str) -> Option<&str> {
self.params.get(name).map(|s| s.as_str())
}
/// Split command-line options according to PostgreSQL's logic,
/// taking into account all escape sequences but leaving them as-is.
/// [`None`] means that there's no `options` in [`Self`].
pub fn options_raw(&self) -> Option<impl Iterator<Item = &str>> {
// See `postgres: pg_split_opts`.
let mut last_was_escape = false;
let iter = self
.get("options")?
.split(move |c: char| {
// We split by non-escaped whitespace symbols.
let should_split = c.is_ascii_whitespace() && !last_was_escape;
last_was_escape = c == '\\' && !last_was_escape;
should_split
})
.filter(|s| !s.is_empty());
Some(iter)
}
/// Split command-line options according to PostgreSQL's logic,
/// applying all escape sequences (using owned strings as needed).
/// [`None`] means that there's no `options` in [`Self`].
pub fn options_escaped(&self) -> Option<impl Iterator<Item = Cow<'_, str>>> {
// See `postgres: pg_split_opts`.
let iter = self.options_raw()?.map(|s| {
let mut preserve_next_escape = false;
let escape = |c| {
// We should remove '\\' unless it's preceded by '\\'.
let should_remove = c == '\\' && !preserve_next_escape;
preserve_next_escape = should_remove;
should_remove
};
match s.contains('\\') {
true => Cow::Owned(s.replace(escape, "")),
false => Cow::Borrowed(s),
}
});
Some(iter)
}
// This function is mostly useful in tests.
#[doc(hidden)]
pub fn new<'a, const N: usize>(pairs: [(&'a str, &'a str); N]) -> Self {
Self {
params: pairs.map(|(k, v)| (k.to_owned(), v.to_owned())).into(),
}
}
}
#[derive(Debug, Hash, PartialEq, Eq, Clone, Copy)]
pub struct CancelKeyData {
@@ -237,9 +300,9 @@ impl FeStartupPacket {
stream.read_exact(params_bytes.as_mut()).await?;
// Parse params depending on request code
let most_sig_16_bits = request_code >> 16;
let least_sig_16_bits = request_code & ((1 << 16) - 1);
let message = match (most_sig_16_bits, least_sig_16_bits) {
let req_hi = request_code >> 16;
let req_lo = request_code & ((1 << 16) - 1);
let message = match (req_hi, req_lo) {
(RESERVED_INVALID_MAJOR_VERSION, CANCEL_REQUEST_CODE) => {
ensure!(params_len == 8, "expected 8 bytes for CancelRequest params");
let mut cursor = Cursor::new(params_bytes);
@@ -248,49 +311,44 @@ impl FeStartupPacket {
cancel_key: cursor.read_i32().await?,
})
}
(RESERVED_INVALID_MAJOR_VERSION, NEGOTIATE_SSL_CODE) => FeStartupPacket::SslRequest,
(RESERVED_INVALID_MAJOR_VERSION, NEGOTIATE_SSL_CODE) => {
// Requested upgrade to SSL (aka TLS)
FeStartupPacket::SslRequest
}
(RESERVED_INVALID_MAJOR_VERSION, NEGOTIATE_GSS_CODE) => {
// Requested upgrade to GSSAPI
FeStartupPacket::GssEncRequest
}
(RESERVED_INVALID_MAJOR_VERSION, unrecognized_code) => {
bail!("Unrecognized request code {}", unrecognized_code)
}
// TODO bail if protocol major_version is not 3?
(major_version, minor_version) => {
// TODO bail if protocol major_version is not 3?
// Parse null-terminated (String) pairs of param name / param value
let params_str = str::from_utf8(&params_bytes).unwrap();
let mut params_tokens = params_str.split('\0');
let mut params: HashMap<String, String> = HashMap::new();
while let Some(name) = params_tokens.next() {
let value = params_tokens
// Parse pairs of null-terminated strings (key, value).
// See `postgres: ProcessStartupPacket, build_startup_packet`.
let mut tokens = str::from_utf8(&params_bytes)
.context("StartupMessage params: invalid utf-8")?
.strip_suffix('\0') // drop packet's own null terminator
.context("StartupMessage params: missing null terminator")?
.split_terminator('\0');
let mut params = HashMap::new();
while let Some(name) = tokens.next() {
let value = tokens
.next()
.context("expected even number of params in StartupMessage")?;
if name == "options" {
// parsing options arguments "...&options=<var0>%3D<val0>+<var1>=<var1>..."
// '%3D' is '=' and '+' is ' '
.context("StartupMessage params: key without value")?;
// Note: we allow users that don't have SNI capabilities,
// to pass a special keyword argument 'project'
// to be used to determine the cluster name by the proxy.
//TODO: write unit test for this and refactor in its own function.
for cmdopt in value.split(' ') {
let nameval: Vec<&str> = cmdopt.split('=').collect();
if nameval.len() == 2 {
params.insert(nameval[0].to_string(), nameval[1].to_string());
}
}
} else {
params.insert(name.to_string(), value.to_string());
}
params.insert(name.to_owned(), value.to_owned());
}
FeStartupPacket::StartupMessage {
major_version,
minor_version,
params,
params: StartupMessageParams { params },
}
}
};
Ok(Some(FeMessage::StartupPacket(message)))
})
}
@@ -967,6 +1025,33 @@ mod tests {
assert_eq!(zf, zf_parsed);
}
#[test]
fn test_startup_message_params_options_escaped() {
fn split_options(params: &StartupMessageParams) -> Vec<Cow<'_, str>> {
params
.options_escaped()
.expect("options are None")
.collect()
}
let make_params = |options| StartupMessageParams::new([("options", options)]);
let params = StartupMessageParams::new([]);
assert!(matches!(params.options_escaped(), None));
let params = make_params("");
assert!(split_options(&params).is_empty());
let params = make_params("foo");
assert_eq!(split_options(&params), ["foo"]);
let params = make_params(" foo bar ");
assert_eq!(split_options(&params), ["foo", "bar"]);
let params = make_params("foo\\ bar \\ \\\\ baz\\ lol");
assert_eq!(split_options(&params), ["foo bar", " \\", "baz ", "lol"]);
}
// Make sure that `read` is sync/async callable
async fn _assert(stream: &mut (impl tokio::io::AsyncRead + Unpin)) {
let _ = FeMessage::read(&mut [].as_ref());

View File

@@ -15,6 +15,7 @@ hashbrown = "0.12"
hex = "0.4.3"
hmac = "0.12.1"
hyper = "0.14"
itertools = "0.10.3"
once_cell = "1.13.0"
md5 = "0.7.0"
parking_lot = "0.12"

View File

@@ -127,7 +127,7 @@ impl<T, E> BackendType<Result<T, E>> {
}
}
impl BackendType<ClientCredentials> {
impl BackendType<ClientCredentials<'_>> {
/// Authenticate the client via the requested backend, possibly using credentials.
pub async fn authenticate(
mut self,
@@ -149,7 +149,7 @@ impl BackendType<ClientCredentials> {
// Finally we may finish the initialization of `creds`.
// TODO: add missing type safety to ClientCredentials.
creds.project = Some(payload.project);
creds.project = Some(payload.project.into());
let mut config = match &self {
Console(creds) => {

View File

@@ -121,7 +121,7 @@ pub enum AuthInfo {
#[must_use]
pub(super) struct Api<'a> {
endpoint: &'a ApiUrl,
creds: &'a ClientCredentials,
creds: &'a ClientCredentials<'a>,
}
impl<'a> Api<'a> {
@@ -143,7 +143,7 @@ impl<'a> Api<'a> {
url.path_segments_mut().push("proxy_get_role_secret");
url.query_pairs_mut()
.append_pair("project", self.creds.project().expect("impossible"))
.append_pair("role", &self.creds.user);
.append_pair("role", self.creds.user);
// TODO: use a proper logger
println!("cplane request: {url}");
@@ -187,8 +187,8 @@ impl<'a> Api<'a> {
config
.host(host)
.port(port)
.dbname(&self.creds.dbname)
.user(&self.creds.user);
.dbname(self.creds.dbname)
.user(self.creds.user);
Ok(config)
}

View File

@@ -56,7 +56,7 @@ enum ProxyAuthResponse {
NotReady { ready: bool }, // TODO: get rid of `ready`
}
impl ClientCredentials {
impl ClientCredentials<'_> {
fn is_existing_user(&self) -> bool {
self.user.ends_with("@zenith")
}
@@ -64,15 +64,15 @@ impl ClientCredentials {
async fn authenticate_proxy_client(
auth_endpoint: &reqwest::Url,
creds: &ClientCredentials,
creds: &ClientCredentials<'_>,
md5_response: &str,
salt: &[u8; 4],
psql_session_id: &str,
) -> Result<DatabaseInfo, LegacyAuthError> {
let mut url = auth_endpoint.clone();
url.query_pairs_mut()
.append_pair("login", &creds.user)
.append_pair("database", &creds.dbname)
.append_pair("login", creds.user)
.append_pair("database", creds.dbname)
.append_pair("md5response", md5_response)
.append_pair("salt", &hex::encode(salt))
.append_pair("psql_session_id", psql_session_id);
@@ -103,7 +103,7 @@ async fn authenticate_proxy_client(
async fn handle_existing_user(
auth_endpoint: &reqwest::Url,
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin + Send>,
creds: &ClientCredentials,
creds: &ClientCredentials<'_>,
) -> auth::Result<compute::NodeInfo> {
let psql_session_id = super::link::new_psql_session_id();
let md5_salt = rand::random();
@@ -136,7 +136,7 @@ async fn handle_existing_user(
pub async fn handle_user(
auth_endpoint: &reqwest::Url,
auth_link_uri: &reqwest::Url,
creds: &ClientCredentials,
creds: &ClientCredentials<'_>,
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin + Send>,
) -> auth::Result<compute::NodeInfo> {
if creds.is_existing_user() {

View File

@@ -17,7 +17,7 @@ use tokio::io::{AsyncRead, AsyncWrite};
#[must_use]
pub(super) struct Api<'a> {
endpoint: &'a ApiUrl,
creds: &'a ClientCredentials,
creds: &'a ClientCredentials<'a>,
}
// Helps eliminate graceless `.map_err` calls without introducing another ctor.
@@ -87,8 +87,8 @@ impl<'a> Api<'a> {
config
.host(self.endpoint.host_str().unwrap_or("localhost"))
.port(self.endpoint.port().unwrap_or(5432))
.dbname(&self.creds.dbname)
.user(&self.creds.user);
.dbname(self.creds.dbname)
.user(self.creds.user);
Ok(config)
}

View File

@@ -1,6 +1,7 @@
//! User credentials used in authentication.
use crate::error::UserFacingError;
use std::borrow::Cow;
use thiserror::Error;
use utils::pq_proto::StartupMessageParams;
@@ -27,51 +28,59 @@ impl UserFacingError for ClientCredsParseError {}
/// Various client credentials which we use for authentication.
/// Note that we don't store any kind of client key or password here.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ClientCredentials {
pub user: String,
pub dbname: String,
pub project: Option<String>,
pub struct ClientCredentials<'a> {
pub user: &'a str,
pub dbname: &'a str,
pub project: Option<Cow<'a, str>>,
}
impl ClientCredentials {
impl ClientCredentials<'_> {
pub fn project(&self) -> Option<&str> {
self.project.as_deref()
}
}
impl ClientCredentials {
impl<'a> ClientCredentials<'a> {
pub fn parse(
mut options: StartupMessageParams,
params: &'a StartupMessageParams,
sni: Option<&str>,
common_name: Option<&str>,
) -> Result<Self, ClientCredsParseError> {
use ClientCredsParseError::*;
// Some parameters are absolutely necessary, others not so much.
let mut get_param = |key| options.remove(key).ok_or(MissingKey(key));
// Some parameters are stored in the startup message.
let get_param = |key| params.get(key).ok_or(MissingKey(key));
let user = get_param("user")?;
let dbname = get_param("database")?;
let project_a = get_param("project").ok();
// Project name might be passed via PG's command-line options.
let project_a = params.options_raw().and_then(|options| {
for opt in options {
if let Some(value) = opt.strip_prefix("project=") {
return Some(Cow::Borrowed(value));
}
}
None
});
// Alternative project name is in fact a subdomain from SNI.
// NOTE: we do not consider SNI if `common_name` is missing.
let project_b = sni
.zip(common_name)
.map(|(sni, cn)| {
// TODO: what if SNI is present but just a common name?
subdomain_from_sni(sni, cn)
.ok_or_else(|| InconsistentSni(sni.to_owned(), cn.to_owned()))
.ok_or_else(|| InconsistentSni(sni.into(), cn.into()))
.map(Cow::<'static, str>::Owned)
})
.transpose()?;
let project = match (project_a, project_b) {
// Invariant: if we have both project name variants, they should match.
(Some(a), Some(b)) if a != b => Some(Err(InconsistentProjectNames(a, b))),
(a, b) => a.or(b).map(|name| {
// Invariant: project name may not contain certain characters.
check_project_name(name).map_err(MalformedProjectName)
(Some(a), Some(b)) if a != b => Some(Err(InconsistentProjectNames(a.into(), b.into()))),
// Invariant: project name may not contain certain characters.
(a, b) => a.or(b).map(|name| match project_name_valid(&name) {
false => Err(MalformedProjectName(name.into())),
true => Ok(name),
}),
}
.transpose()?;
@@ -84,12 +93,8 @@ impl ClientCredentials {
}
}
fn check_project_name(name: String) -> Result<String, String> {
if name.chars().all(|c| c.is_alphanumeric() || c == '-') {
Ok(name)
} else {
Err(name)
}
fn project_name_valid(name: &str) -> bool {
name.chars().all(|c| c.is_alphanumeric() || c == '-')
}
fn subdomain_from_sni(sni: &str, common_name: &str) -> Option<String> {
@@ -102,18 +107,14 @@ fn subdomain_from_sni(sni: &str, common_name: &str) -> Option<String> {
mod tests {
use super::*;
fn make_options<'a, const N: usize>(pairs: [(&'a str, &'a str); N]) -> StartupMessageParams {
StartupMessageParams::from(pairs.map(|(k, v)| (k.to_owned(), v.to_owned())))
}
#[test]
#[ignore = "TODO: fix how database is handled"]
fn parse_bare_minimum() -> anyhow::Result<()> {
// According to postgresql, only `user` should be required.
let options = make_options([("user", "john_doe")]);
let options = StartupMessageParams::new([("user", "john_doe")]);
// TODO: check that `creds.dbname` is None.
let creds = ClientCredentials::parse(options, None, None)?;
let creds = ClientCredentials::parse(&options, None, None)?;
assert_eq!(creds.user, "john_doe");
Ok(())
@@ -121,9 +122,9 @@ mod tests {
#[test]
fn parse_missing_project() -> anyhow::Result<()> {
let options = make_options([("user", "john_doe"), ("database", "world")]);
let options = StartupMessageParams::new([("user", "john_doe"), ("database", "world")]);
let creds = ClientCredentials::parse(options, None, None)?;
let creds = ClientCredentials::parse(&options, None, None)?;
assert_eq!(creds.user, "john_doe");
assert_eq!(creds.dbname, "world");
assert_eq!(creds.project, None);
@@ -133,12 +134,12 @@ mod tests {
#[test]
fn parse_project_from_sni() -> anyhow::Result<()> {
let options = make_options([("user", "john_doe"), ("database", "world")]);
let options = StartupMessageParams::new([("user", "john_doe"), ("database", "world")]);
let sni = Some("foo.localhost");
let common_name = Some("localhost");
let creds = ClientCredentials::parse(options, sni, common_name)?;
let creds = ClientCredentials::parse(&options, sni, common_name)?;
assert_eq!(creds.user, "john_doe");
assert_eq!(creds.dbname, "world");
assert_eq!(creds.project.as_deref(), Some("foo"));
@@ -148,13 +149,13 @@ mod tests {
#[test]
fn parse_project_from_options() -> anyhow::Result<()> {
let options = make_options([
let options = StartupMessageParams::new([
("user", "john_doe"),
("database", "world"),
("project", "bar"),
("options", "-ckey=1 project=bar -c geqo=off"),
]);
let creds = ClientCredentials::parse(options, None, None)?;
let creds = ClientCredentials::parse(&options, None, None)?;
assert_eq!(creds.user, "john_doe");
assert_eq!(creds.dbname, "world");
assert_eq!(creds.project.as_deref(), Some("bar"));
@@ -164,16 +165,16 @@ mod tests {
#[test]
fn parse_projects_identical() -> anyhow::Result<()> {
let options = make_options([
let options = StartupMessageParams::new([
("user", "john_doe"),
("database", "world"),
("project", "baz"),
("options", "project=baz"),
]);
let sni = Some("baz.localhost");
let common_name = Some("localhost");
let creds = ClientCredentials::parse(options, sni, common_name)?;
let creds = ClientCredentials::parse(&options, sni, common_name)?;
assert_eq!(creds.user, "john_doe");
assert_eq!(creds.dbname, "world");
assert_eq!(creds.project.as_deref(), Some("baz"));
@@ -183,17 +184,17 @@ mod tests {
#[test]
fn parse_projects_different() {
let options = make_options([
let options = StartupMessageParams::new([
("user", "john_doe"),
("database", "world"),
("project", "first"),
("options", "project=first"),
]);
let sni = Some("second.localhost");
let common_name = Some("localhost");
assert!(matches!(
ClientCredentials::parse(options, sni, common_name).expect_err("should fail"),
ClientCredentials::parse(&options, sni, common_name).expect_err("should fail"),
ClientCredsParseError::InconsistentProjectNames(_, _)
));
}

View File

@@ -95,7 +95,7 @@ impl<'a> Session<'a> {
/// Store the cancel token for the given session.
/// This enables query cancellation in [`crate::proxy::handshake`].
pub fn enable_cancellation(self, cancel_closure: CancelClosure) -> CancelKeyData {
pub fn enable_query_cancellation(self, cancel_closure: CancelClosure) -> CancelKeyData {
self.cancel_map
.0
.lock()

View File

@@ -1,9 +1,11 @@
use crate::{cancellation::CancelClosure, error::UserFacingError};
use futures::TryFutureExt;
use itertools::Itertools;
use std::{io, net::SocketAddr};
use thiserror::Error;
use tokio::net::TcpStream;
use tokio_postgres::NoTls;
use utils::pq_proto::StartupMessageParams;
#[derive(Debug, Error)]
pub enum ConnectionError {
@@ -110,7 +112,42 @@ pub struct PostgresConnection {
impl NodeInfo {
/// Connect to a corresponding compute node.
pub async fn connect(&self) -> Result<(PostgresConnection, CancelClosure), ConnectionError> {
pub async fn connect(
mut self,
params: &StartupMessageParams,
) -> Result<(PostgresConnection, CancelClosure), ConnectionError> {
if let Some(options) = params.options_raw() {
// We must drop all proxy-specific parameters.
#[allow(unstable_name_collisions)]
let options: String = options
.filter(|opt| !opt.starts_with("project="))
.intersperse(" ") // TODO: use impl from std once it's stabilized
.collect();
self.config.options(&options);
}
if let Some(app_name) = params.get("application_name") {
self.config.application_name(app_name);
}
if let Some(replication) = params.get("replication") {
use tokio_postgres::config::ReplicationMode;
match replication {
"true" | "on" | "yes" | "1" => {
self.config.replication_mode(ReplicationMode::Physical);
}
"database" => {
self.config.replication_mode(ReplicationMode::Logical);
}
_other => {}
}
}
// TODO: extend the list of the forwarded startup parameters.
// Currently, tokio-postgres doesn't allow us to pass
// arbitrary parameters, but the ones above are a good start.
let (socket_addr, mut stream) = self
.connect_raw()
.await

View File

@@ -1,6 +1,6 @@
use crate::auth;
use crate::cancellation::{self, CancelMap};
use crate::config::{ProxyConfig, TlsConfig};
use crate::config::{AuthUrls, ProxyConfig, TlsConfig};
use crate::stream::{MetricsStream, PqStream, Stream};
use anyhow::{bail, Context};
use futures::TryFutureExt;
@@ -93,20 +93,21 @@ async fn handle_client(
None => return Ok(()), // it's a cancellation request
};
// Extract credentials which we're going to use for auth.
let creds = {
let sni = stream.get_ref().sni_hostname();
let common_name = tls.and_then(|tls| tls.common_name.as_deref());
let result = config
.auth_backend
.map(|_| auth::ClientCredentials::parse(params, sni, common_name))
.map(|_| auth::ClientCredentials::parse(&params, sni, common_name))
.transpose();
async { result }.or_else(|e| stream.throw_error(e)).await?
};
let client = Client::new(stream, creds);
let client = Client::new(stream, creds, &params);
cancel_map
.with_session(|session| client.connect_to_db(config, session))
.with_session(|session| client.connect_to_db(&config.auth_urls, session))
.await
}
@@ -174,38 +175,57 @@ async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
}
/// Thin connection context.
struct Client<S> {
struct Client<'a, S> {
/// The underlying libpq protocol stream.
stream: PqStream<S>,
/// Client credentials that we care about.
creds: auth::BackendType<auth::ClientCredentials>,
creds: auth::BackendType<auth::ClientCredentials<'a>>,
/// KV-dictionary with PostgreSQL connection params.
params: &'a StartupMessageParams,
}
impl<S> Client<S> {
impl<'a, S> Client<'a, S> {
/// Construct a new connection context.
fn new(stream: PqStream<S>, creds: auth::BackendType<auth::ClientCredentials>) -> Self {
Self { stream, creds }
fn new(
stream: PqStream<S>,
creds: auth::BackendType<auth::ClientCredentials<'a>>,
params: &'a StartupMessageParams,
) -> Self {
Self {
stream,
creds,
params,
}
}
}
impl<S: AsyncRead + AsyncWrite + Unpin + Send> Client<S> {
impl<S: AsyncRead + AsyncWrite + Unpin + Send> Client<'_, S> {
/// Let the client authenticate and connect to the designated compute node.
async fn connect_to_db(
self,
config: &ProxyConfig,
urls: &AuthUrls,
session: cancellation::Session<'_>,
) -> anyhow::Result<()> {
let Self { mut stream, creds } = self;
let Self {
mut stream,
creds,
params,
} = self;
// Authenticate and connect to a compute node.
let auth = creds.authenticate(&config.auth_urls, &mut stream).await;
let auth = creds.authenticate(urls, &mut stream).await;
let node = async { auth }.or_else(|e| stream.throw_error(e)).await?;
let reported_auth_ok = node.reported_auth_ok;
let (db, cancel_closure) = node.connect().or_else(|e| stream.throw_error(e)).await?;
let cancel_key_data = session.enable_cancellation(cancel_closure);
let (db, cancel_closure) = node
.connect(params)
.or_else(|e| stream.throw_error(e))
.await?;
let cancel_key_data = session.enable_query_cancellation(cancel_closure);
// Report authentication success if we haven't done this already.
if !node.reported_auth_ok {
if !reported_auth_ok {
stream
.write_message_noflush(&Be::AuthenticationOk)?
.write_message_noflush(&BeParameterStatusMessage::encoding())?;

View File

@@ -11,7 +11,6 @@ use anyhow::{bail, Context, Result};
use postgres_ffi::PG_TLI;
use regex::Regex;
use std::str::FromStr;
use std::sync::Arc;
use tracing::info;
use utils::{
@@ -67,18 +66,22 @@ impl postgres_backend::Handler for SafekeeperPostgresHandler {
// ztenant id and ztimeline id are passed in connection string params
fn startup(&mut self, _pgb: &mut PostgresBackend, sm: &FeStartupPacket) -> Result<()> {
if let FeStartupPacket::StartupMessage { params, .. } = sm {
self.ztenantid = match params.get("ztenantid") {
Some(z) => Some(ZTenantId::from_str(z)?), // just curious, can I do that from .map?
_ => None,
};
self.ztimelineid = match params.get("ztimelineid") {
Some(z) => Some(ZTimelineId::from_str(z)?),
_ => None,
};
if let Some(options) = params.options_raw() {
for opt in options {
match opt.split_once('=') {
Some(("ztenantid", value)) => {
self.ztenantid = Some(value.parse()?);
}
Some(("ztimelineid", value)) => {
self.ztimelineid = Some(value.parse()?);
}
_ => continue,
}
}
}
if let Some(app_name) = params.get("application_name") {
self.appname = Some(app_name.clone());
self.appname = Some(app_name.to_owned());
}
Ok(())

View File

@@ -134,12 +134,8 @@ async def test_psql_session_id(vanilla_pg: VanillaPostgres, link_proxy: NeonProx
# Pass extra options to the server.
#
# Currently, proxy eats the extra connection options, so this fails.
# See https://github.com/neondatabase/neon/issues/1287
@pytest.mark.xfail
def test_proxy_options(static_proxy):
with static_proxy.connect(options="-cproxytest.option=value") as conn:
with static_proxy.connect(options="project=irrelevant -cproxytest.option=value") as conn:
with conn.cursor() as cur:
cur.execute("SHOW proxytest.option")
value = cur.fetchall()[0][0]