mirror of
https://github.com/neondatabase/neon.git
synced 2026-01-06 13:02:55 +00:00
Forward various connection params to compute nodes. (#2336)
Previously, proxy didn't forward auxiliary `options` parameter and other ones to the client's compute node, e.g. ``` $ psql "user=john host=localhost dbname=postgres options='-cgeqo=off'" postgres=# show geqo; ┌──────┐ │ geqo │ ├──────┤ │ on │ └──────┘ (1 row) ``` With this patch we now forward `options`, `application_name` and `replication`. Further reading: https://www.postgresql.org/docs/current/libpq-connect.html Fixes #1287.
This commit is contained in:
1
Cargo.lock
generated
1
Cargo.lock
generated
@@ -2271,6 +2271,7 @@ dependencies = [
|
||||
"hex",
|
||||
"hmac 0.12.1",
|
||||
"hyper",
|
||||
"itertools",
|
||||
"md5",
|
||||
"metrics",
|
||||
"once_cell",
|
||||
|
||||
@@ -7,11 +7,14 @@ use anyhow::{bail, ensure, Context, Result};
|
||||
use bytes::{Buf, BufMut, Bytes, BytesMut};
|
||||
use postgres_protocol::PG_EPOCH;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::future::Future;
|
||||
use std::io::{self, Cursor};
|
||||
use std::str;
|
||||
use std::time::{Duration, SystemTime};
|
||||
use std::{
|
||||
borrow::Cow,
|
||||
collections::HashMap,
|
||||
future::Future,
|
||||
io::{self, Cursor},
|
||||
str,
|
||||
time::{Duration, SystemTime},
|
||||
};
|
||||
use tokio::io::AsyncReadExt;
|
||||
use tracing::{trace, warn};
|
||||
|
||||
@@ -53,7 +56,67 @@ pub enum FeStartupPacket {
|
||||
},
|
||||
}
|
||||
|
||||
pub type StartupMessageParams = HashMap<String, String>;
|
||||
#[derive(Debug)]
|
||||
pub struct StartupMessageParams {
|
||||
params: HashMap<String, String>,
|
||||
}
|
||||
|
||||
impl StartupMessageParams {
|
||||
/// Get parameter's value by its name.
|
||||
pub fn get(&self, name: &str) -> Option<&str> {
|
||||
self.params.get(name).map(|s| s.as_str())
|
||||
}
|
||||
|
||||
/// Split command-line options according to PostgreSQL's logic,
|
||||
/// taking into account all escape sequences but leaving them as-is.
|
||||
/// [`None`] means that there's no `options` in [`Self`].
|
||||
pub fn options_raw(&self) -> Option<impl Iterator<Item = &str>> {
|
||||
// See `postgres: pg_split_opts`.
|
||||
let mut last_was_escape = false;
|
||||
let iter = self
|
||||
.get("options")?
|
||||
.split(move |c: char| {
|
||||
// We split by non-escaped whitespace symbols.
|
||||
let should_split = c.is_ascii_whitespace() && !last_was_escape;
|
||||
last_was_escape = c == '\\' && !last_was_escape;
|
||||
should_split
|
||||
})
|
||||
.filter(|s| !s.is_empty());
|
||||
|
||||
Some(iter)
|
||||
}
|
||||
|
||||
/// Split command-line options according to PostgreSQL's logic,
|
||||
/// applying all escape sequences (using owned strings as needed).
|
||||
/// [`None`] means that there's no `options` in [`Self`].
|
||||
pub fn options_escaped(&self) -> Option<impl Iterator<Item = Cow<'_, str>>> {
|
||||
// See `postgres: pg_split_opts`.
|
||||
let iter = self.options_raw()?.map(|s| {
|
||||
let mut preserve_next_escape = false;
|
||||
let escape = |c| {
|
||||
// We should remove '\\' unless it's preceded by '\\'.
|
||||
let should_remove = c == '\\' && !preserve_next_escape;
|
||||
preserve_next_escape = should_remove;
|
||||
should_remove
|
||||
};
|
||||
|
||||
match s.contains('\\') {
|
||||
true => Cow::Owned(s.replace(escape, "")),
|
||||
false => Cow::Borrowed(s),
|
||||
}
|
||||
});
|
||||
|
||||
Some(iter)
|
||||
}
|
||||
|
||||
// This function is mostly useful in tests.
|
||||
#[doc(hidden)]
|
||||
pub fn new<'a, const N: usize>(pairs: [(&'a str, &'a str); N]) -> Self {
|
||||
Self {
|
||||
params: pairs.map(|(k, v)| (k.to_owned(), v.to_owned())).into(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Hash, PartialEq, Eq, Clone, Copy)]
|
||||
pub struct CancelKeyData {
|
||||
@@ -237,9 +300,9 @@ impl FeStartupPacket {
|
||||
stream.read_exact(params_bytes.as_mut()).await?;
|
||||
|
||||
// Parse params depending on request code
|
||||
let most_sig_16_bits = request_code >> 16;
|
||||
let least_sig_16_bits = request_code & ((1 << 16) - 1);
|
||||
let message = match (most_sig_16_bits, least_sig_16_bits) {
|
||||
let req_hi = request_code >> 16;
|
||||
let req_lo = request_code & ((1 << 16) - 1);
|
||||
let message = match (req_hi, req_lo) {
|
||||
(RESERVED_INVALID_MAJOR_VERSION, CANCEL_REQUEST_CODE) => {
|
||||
ensure!(params_len == 8, "expected 8 bytes for CancelRequest params");
|
||||
let mut cursor = Cursor::new(params_bytes);
|
||||
@@ -248,49 +311,44 @@ impl FeStartupPacket {
|
||||
cancel_key: cursor.read_i32().await?,
|
||||
})
|
||||
}
|
||||
(RESERVED_INVALID_MAJOR_VERSION, NEGOTIATE_SSL_CODE) => FeStartupPacket::SslRequest,
|
||||
(RESERVED_INVALID_MAJOR_VERSION, NEGOTIATE_SSL_CODE) => {
|
||||
// Requested upgrade to SSL (aka TLS)
|
||||
FeStartupPacket::SslRequest
|
||||
}
|
||||
(RESERVED_INVALID_MAJOR_VERSION, NEGOTIATE_GSS_CODE) => {
|
||||
// Requested upgrade to GSSAPI
|
||||
FeStartupPacket::GssEncRequest
|
||||
}
|
||||
(RESERVED_INVALID_MAJOR_VERSION, unrecognized_code) => {
|
||||
bail!("Unrecognized request code {}", unrecognized_code)
|
||||
}
|
||||
// TODO bail if protocol major_version is not 3?
|
||||
(major_version, minor_version) => {
|
||||
// TODO bail if protocol major_version is not 3?
|
||||
// Parse null-terminated (String) pairs of param name / param value
|
||||
let params_str = str::from_utf8(¶ms_bytes).unwrap();
|
||||
let mut params_tokens = params_str.split('\0');
|
||||
let mut params: HashMap<String, String> = HashMap::new();
|
||||
while let Some(name) = params_tokens.next() {
|
||||
let value = params_tokens
|
||||
// Parse pairs of null-terminated strings (key, value).
|
||||
// See `postgres: ProcessStartupPacket, build_startup_packet`.
|
||||
let mut tokens = str::from_utf8(¶ms_bytes)
|
||||
.context("StartupMessage params: invalid utf-8")?
|
||||
.strip_suffix('\0') // drop packet's own null terminator
|
||||
.context("StartupMessage params: missing null terminator")?
|
||||
.split_terminator('\0');
|
||||
|
||||
let mut params = HashMap::new();
|
||||
while let Some(name) = tokens.next() {
|
||||
let value = tokens
|
||||
.next()
|
||||
.context("expected even number of params in StartupMessage")?;
|
||||
if name == "options" {
|
||||
// parsing options arguments "...&options=<var0>%3D<val0>+<var1>=<var1>..."
|
||||
// '%3D' is '=' and '+' is ' '
|
||||
.context("StartupMessage params: key without value")?;
|
||||
|
||||
// Note: we allow users that don't have SNI capabilities,
|
||||
// to pass a special keyword argument 'project'
|
||||
// to be used to determine the cluster name by the proxy.
|
||||
|
||||
//TODO: write unit test for this and refactor in its own function.
|
||||
for cmdopt in value.split(' ') {
|
||||
let nameval: Vec<&str> = cmdopt.split('=').collect();
|
||||
if nameval.len() == 2 {
|
||||
params.insert(nameval[0].to_string(), nameval[1].to_string());
|
||||
}
|
||||
}
|
||||
} else {
|
||||
params.insert(name.to_string(), value.to_string());
|
||||
}
|
||||
params.insert(name.to_owned(), value.to_owned());
|
||||
}
|
||||
|
||||
FeStartupPacket::StartupMessage {
|
||||
major_version,
|
||||
minor_version,
|
||||
params,
|
||||
params: StartupMessageParams { params },
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
Ok(Some(FeMessage::StartupPacket(message)))
|
||||
})
|
||||
}
|
||||
@@ -967,6 +1025,33 @@ mod tests {
|
||||
assert_eq!(zf, zf_parsed);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_startup_message_params_options_escaped() {
|
||||
fn split_options(params: &StartupMessageParams) -> Vec<Cow<'_, str>> {
|
||||
params
|
||||
.options_escaped()
|
||||
.expect("options are None")
|
||||
.collect()
|
||||
}
|
||||
|
||||
let make_params = |options| StartupMessageParams::new([("options", options)]);
|
||||
|
||||
let params = StartupMessageParams::new([]);
|
||||
assert!(matches!(params.options_escaped(), None));
|
||||
|
||||
let params = make_params("");
|
||||
assert!(split_options(¶ms).is_empty());
|
||||
|
||||
let params = make_params("foo");
|
||||
assert_eq!(split_options(¶ms), ["foo"]);
|
||||
|
||||
let params = make_params(" foo bar ");
|
||||
assert_eq!(split_options(¶ms), ["foo", "bar"]);
|
||||
|
||||
let params = make_params("foo\\ bar \\ \\\\ baz\\ lol");
|
||||
assert_eq!(split_options(¶ms), ["foo bar", " \\", "baz ", "lol"]);
|
||||
}
|
||||
|
||||
// Make sure that `read` is sync/async callable
|
||||
async fn _assert(stream: &mut (impl tokio::io::AsyncRead + Unpin)) {
|
||||
let _ = FeMessage::read(&mut [].as_ref());
|
||||
|
||||
@@ -15,6 +15,7 @@ hashbrown = "0.12"
|
||||
hex = "0.4.3"
|
||||
hmac = "0.12.1"
|
||||
hyper = "0.14"
|
||||
itertools = "0.10.3"
|
||||
once_cell = "1.13.0"
|
||||
md5 = "0.7.0"
|
||||
parking_lot = "0.12"
|
||||
|
||||
@@ -127,7 +127,7 @@ impl<T, E> BackendType<Result<T, E>> {
|
||||
}
|
||||
}
|
||||
|
||||
impl BackendType<ClientCredentials> {
|
||||
impl BackendType<ClientCredentials<'_>> {
|
||||
/// Authenticate the client via the requested backend, possibly using credentials.
|
||||
pub async fn authenticate(
|
||||
mut self,
|
||||
@@ -149,7 +149,7 @@ impl BackendType<ClientCredentials> {
|
||||
|
||||
// Finally we may finish the initialization of `creds`.
|
||||
// TODO: add missing type safety to ClientCredentials.
|
||||
creds.project = Some(payload.project);
|
||||
creds.project = Some(payload.project.into());
|
||||
|
||||
let mut config = match &self {
|
||||
Console(creds) => {
|
||||
|
||||
@@ -121,7 +121,7 @@ pub enum AuthInfo {
|
||||
#[must_use]
|
||||
pub(super) struct Api<'a> {
|
||||
endpoint: &'a ApiUrl,
|
||||
creds: &'a ClientCredentials,
|
||||
creds: &'a ClientCredentials<'a>,
|
||||
}
|
||||
|
||||
impl<'a> Api<'a> {
|
||||
@@ -143,7 +143,7 @@ impl<'a> Api<'a> {
|
||||
url.path_segments_mut().push("proxy_get_role_secret");
|
||||
url.query_pairs_mut()
|
||||
.append_pair("project", self.creds.project().expect("impossible"))
|
||||
.append_pair("role", &self.creds.user);
|
||||
.append_pair("role", self.creds.user);
|
||||
|
||||
// TODO: use a proper logger
|
||||
println!("cplane request: {url}");
|
||||
@@ -187,8 +187,8 @@ impl<'a> Api<'a> {
|
||||
config
|
||||
.host(host)
|
||||
.port(port)
|
||||
.dbname(&self.creds.dbname)
|
||||
.user(&self.creds.user);
|
||||
.dbname(self.creds.dbname)
|
||||
.user(self.creds.user);
|
||||
|
||||
Ok(config)
|
||||
}
|
||||
|
||||
@@ -56,7 +56,7 @@ enum ProxyAuthResponse {
|
||||
NotReady { ready: bool }, // TODO: get rid of `ready`
|
||||
}
|
||||
|
||||
impl ClientCredentials {
|
||||
impl ClientCredentials<'_> {
|
||||
fn is_existing_user(&self) -> bool {
|
||||
self.user.ends_with("@zenith")
|
||||
}
|
||||
@@ -64,15 +64,15 @@ impl ClientCredentials {
|
||||
|
||||
async fn authenticate_proxy_client(
|
||||
auth_endpoint: &reqwest::Url,
|
||||
creds: &ClientCredentials,
|
||||
creds: &ClientCredentials<'_>,
|
||||
md5_response: &str,
|
||||
salt: &[u8; 4],
|
||||
psql_session_id: &str,
|
||||
) -> Result<DatabaseInfo, LegacyAuthError> {
|
||||
let mut url = auth_endpoint.clone();
|
||||
url.query_pairs_mut()
|
||||
.append_pair("login", &creds.user)
|
||||
.append_pair("database", &creds.dbname)
|
||||
.append_pair("login", creds.user)
|
||||
.append_pair("database", creds.dbname)
|
||||
.append_pair("md5response", md5_response)
|
||||
.append_pair("salt", &hex::encode(salt))
|
||||
.append_pair("psql_session_id", psql_session_id);
|
||||
@@ -103,7 +103,7 @@ async fn authenticate_proxy_client(
|
||||
async fn handle_existing_user(
|
||||
auth_endpoint: &reqwest::Url,
|
||||
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin + Send>,
|
||||
creds: &ClientCredentials,
|
||||
creds: &ClientCredentials<'_>,
|
||||
) -> auth::Result<compute::NodeInfo> {
|
||||
let psql_session_id = super::link::new_psql_session_id();
|
||||
let md5_salt = rand::random();
|
||||
@@ -136,7 +136,7 @@ async fn handle_existing_user(
|
||||
pub async fn handle_user(
|
||||
auth_endpoint: &reqwest::Url,
|
||||
auth_link_uri: &reqwest::Url,
|
||||
creds: &ClientCredentials,
|
||||
creds: &ClientCredentials<'_>,
|
||||
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin + Send>,
|
||||
) -> auth::Result<compute::NodeInfo> {
|
||||
if creds.is_existing_user() {
|
||||
|
||||
@@ -17,7 +17,7 @@ use tokio::io::{AsyncRead, AsyncWrite};
|
||||
#[must_use]
|
||||
pub(super) struct Api<'a> {
|
||||
endpoint: &'a ApiUrl,
|
||||
creds: &'a ClientCredentials,
|
||||
creds: &'a ClientCredentials<'a>,
|
||||
}
|
||||
|
||||
// Helps eliminate graceless `.map_err` calls without introducing another ctor.
|
||||
@@ -87,8 +87,8 @@ impl<'a> Api<'a> {
|
||||
config
|
||||
.host(self.endpoint.host_str().unwrap_or("localhost"))
|
||||
.port(self.endpoint.port().unwrap_or(5432))
|
||||
.dbname(&self.creds.dbname)
|
||||
.user(&self.creds.user);
|
||||
.dbname(self.creds.dbname)
|
||||
.user(self.creds.user);
|
||||
|
||||
Ok(config)
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
//! User credentials used in authentication.
|
||||
|
||||
use crate::error::UserFacingError;
|
||||
use std::borrow::Cow;
|
||||
use thiserror::Error;
|
||||
use utils::pq_proto::StartupMessageParams;
|
||||
|
||||
@@ -27,51 +28,59 @@ impl UserFacingError for ClientCredsParseError {}
|
||||
/// Various client credentials which we use for authentication.
|
||||
/// Note that we don't store any kind of client key or password here.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct ClientCredentials {
|
||||
pub user: String,
|
||||
pub dbname: String,
|
||||
pub project: Option<String>,
|
||||
pub struct ClientCredentials<'a> {
|
||||
pub user: &'a str,
|
||||
pub dbname: &'a str,
|
||||
pub project: Option<Cow<'a, str>>,
|
||||
}
|
||||
|
||||
impl ClientCredentials {
|
||||
impl ClientCredentials<'_> {
|
||||
pub fn project(&self) -> Option<&str> {
|
||||
self.project.as_deref()
|
||||
}
|
||||
}
|
||||
|
||||
impl ClientCredentials {
|
||||
impl<'a> ClientCredentials<'a> {
|
||||
pub fn parse(
|
||||
mut options: StartupMessageParams,
|
||||
params: &'a StartupMessageParams,
|
||||
sni: Option<&str>,
|
||||
common_name: Option<&str>,
|
||||
) -> Result<Self, ClientCredsParseError> {
|
||||
use ClientCredsParseError::*;
|
||||
|
||||
// Some parameters are absolutely necessary, others not so much.
|
||||
let mut get_param = |key| options.remove(key).ok_or(MissingKey(key));
|
||||
|
||||
// Some parameters are stored in the startup message.
|
||||
let get_param = |key| params.get(key).ok_or(MissingKey(key));
|
||||
let user = get_param("user")?;
|
||||
let dbname = get_param("database")?;
|
||||
let project_a = get_param("project").ok();
|
||||
|
||||
// Project name might be passed via PG's command-line options.
|
||||
let project_a = params.options_raw().and_then(|options| {
|
||||
for opt in options {
|
||||
if let Some(value) = opt.strip_prefix("project=") {
|
||||
return Some(Cow::Borrowed(value));
|
||||
}
|
||||
}
|
||||
None
|
||||
});
|
||||
|
||||
// Alternative project name is in fact a subdomain from SNI.
|
||||
// NOTE: we do not consider SNI if `common_name` is missing.
|
||||
let project_b = sni
|
||||
.zip(common_name)
|
||||
.map(|(sni, cn)| {
|
||||
// TODO: what if SNI is present but just a common name?
|
||||
subdomain_from_sni(sni, cn)
|
||||
.ok_or_else(|| InconsistentSni(sni.to_owned(), cn.to_owned()))
|
||||
.ok_or_else(|| InconsistentSni(sni.into(), cn.into()))
|
||||
.map(Cow::<'static, str>::Owned)
|
||||
})
|
||||
.transpose()?;
|
||||
|
||||
let project = match (project_a, project_b) {
|
||||
// Invariant: if we have both project name variants, they should match.
|
||||
(Some(a), Some(b)) if a != b => Some(Err(InconsistentProjectNames(a, b))),
|
||||
(a, b) => a.or(b).map(|name| {
|
||||
// Invariant: project name may not contain certain characters.
|
||||
check_project_name(name).map_err(MalformedProjectName)
|
||||
(Some(a), Some(b)) if a != b => Some(Err(InconsistentProjectNames(a.into(), b.into()))),
|
||||
// Invariant: project name may not contain certain characters.
|
||||
(a, b) => a.or(b).map(|name| match project_name_valid(&name) {
|
||||
false => Err(MalformedProjectName(name.into())),
|
||||
true => Ok(name),
|
||||
}),
|
||||
}
|
||||
.transpose()?;
|
||||
@@ -84,12 +93,8 @@ impl ClientCredentials {
|
||||
}
|
||||
}
|
||||
|
||||
fn check_project_name(name: String) -> Result<String, String> {
|
||||
if name.chars().all(|c| c.is_alphanumeric() || c == '-') {
|
||||
Ok(name)
|
||||
} else {
|
||||
Err(name)
|
||||
}
|
||||
fn project_name_valid(name: &str) -> bool {
|
||||
name.chars().all(|c| c.is_alphanumeric() || c == '-')
|
||||
}
|
||||
|
||||
fn subdomain_from_sni(sni: &str, common_name: &str) -> Option<String> {
|
||||
@@ -102,18 +107,14 @@ fn subdomain_from_sni(sni: &str, common_name: &str) -> Option<String> {
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn make_options<'a, const N: usize>(pairs: [(&'a str, &'a str); N]) -> StartupMessageParams {
|
||||
StartupMessageParams::from(pairs.map(|(k, v)| (k.to_owned(), v.to_owned())))
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[ignore = "TODO: fix how database is handled"]
|
||||
fn parse_bare_minimum() -> anyhow::Result<()> {
|
||||
// According to postgresql, only `user` should be required.
|
||||
let options = make_options([("user", "john_doe")]);
|
||||
let options = StartupMessageParams::new([("user", "john_doe")]);
|
||||
|
||||
// TODO: check that `creds.dbname` is None.
|
||||
let creds = ClientCredentials::parse(options, None, None)?;
|
||||
let creds = ClientCredentials::parse(&options, None, None)?;
|
||||
assert_eq!(creds.user, "john_doe");
|
||||
|
||||
Ok(())
|
||||
@@ -121,9 +122,9 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn parse_missing_project() -> anyhow::Result<()> {
|
||||
let options = make_options([("user", "john_doe"), ("database", "world")]);
|
||||
let options = StartupMessageParams::new([("user", "john_doe"), ("database", "world")]);
|
||||
|
||||
let creds = ClientCredentials::parse(options, None, None)?;
|
||||
let creds = ClientCredentials::parse(&options, None, None)?;
|
||||
assert_eq!(creds.user, "john_doe");
|
||||
assert_eq!(creds.dbname, "world");
|
||||
assert_eq!(creds.project, None);
|
||||
@@ -133,12 +134,12 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn parse_project_from_sni() -> anyhow::Result<()> {
|
||||
let options = make_options([("user", "john_doe"), ("database", "world")]);
|
||||
let options = StartupMessageParams::new([("user", "john_doe"), ("database", "world")]);
|
||||
|
||||
let sni = Some("foo.localhost");
|
||||
let common_name = Some("localhost");
|
||||
|
||||
let creds = ClientCredentials::parse(options, sni, common_name)?;
|
||||
let creds = ClientCredentials::parse(&options, sni, common_name)?;
|
||||
assert_eq!(creds.user, "john_doe");
|
||||
assert_eq!(creds.dbname, "world");
|
||||
assert_eq!(creds.project.as_deref(), Some("foo"));
|
||||
@@ -148,13 +149,13 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn parse_project_from_options() -> anyhow::Result<()> {
|
||||
let options = make_options([
|
||||
let options = StartupMessageParams::new([
|
||||
("user", "john_doe"),
|
||||
("database", "world"),
|
||||
("project", "bar"),
|
||||
("options", "-ckey=1 project=bar -c geqo=off"),
|
||||
]);
|
||||
|
||||
let creds = ClientCredentials::parse(options, None, None)?;
|
||||
let creds = ClientCredentials::parse(&options, None, None)?;
|
||||
assert_eq!(creds.user, "john_doe");
|
||||
assert_eq!(creds.dbname, "world");
|
||||
assert_eq!(creds.project.as_deref(), Some("bar"));
|
||||
@@ -164,16 +165,16 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn parse_projects_identical() -> anyhow::Result<()> {
|
||||
let options = make_options([
|
||||
let options = StartupMessageParams::new([
|
||||
("user", "john_doe"),
|
||||
("database", "world"),
|
||||
("project", "baz"),
|
||||
("options", "project=baz"),
|
||||
]);
|
||||
|
||||
let sni = Some("baz.localhost");
|
||||
let common_name = Some("localhost");
|
||||
|
||||
let creds = ClientCredentials::parse(options, sni, common_name)?;
|
||||
let creds = ClientCredentials::parse(&options, sni, common_name)?;
|
||||
assert_eq!(creds.user, "john_doe");
|
||||
assert_eq!(creds.dbname, "world");
|
||||
assert_eq!(creds.project.as_deref(), Some("baz"));
|
||||
@@ -183,17 +184,17 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn parse_projects_different() {
|
||||
let options = make_options([
|
||||
let options = StartupMessageParams::new([
|
||||
("user", "john_doe"),
|
||||
("database", "world"),
|
||||
("project", "first"),
|
||||
("options", "project=first"),
|
||||
]);
|
||||
|
||||
let sni = Some("second.localhost");
|
||||
let common_name = Some("localhost");
|
||||
|
||||
assert!(matches!(
|
||||
ClientCredentials::parse(options, sni, common_name).expect_err("should fail"),
|
||||
ClientCredentials::parse(&options, sni, common_name).expect_err("should fail"),
|
||||
ClientCredsParseError::InconsistentProjectNames(_, _)
|
||||
));
|
||||
}
|
||||
|
||||
@@ -95,7 +95,7 @@ impl<'a> Session<'a> {
|
||||
|
||||
/// Store the cancel token for the given session.
|
||||
/// This enables query cancellation in [`crate::proxy::handshake`].
|
||||
pub fn enable_cancellation(self, cancel_closure: CancelClosure) -> CancelKeyData {
|
||||
pub fn enable_query_cancellation(self, cancel_closure: CancelClosure) -> CancelKeyData {
|
||||
self.cancel_map
|
||||
.0
|
||||
.lock()
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
use crate::{cancellation::CancelClosure, error::UserFacingError};
|
||||
use futures::TryFutureExt;
|
||||
use itertools::Itertools;
|
||||
use std::{io, net::SocketAddr};
|
||||
use thiserror::Error;
|
||||
use tokio::net::TcpStream;
|
||||
use tokio_postgres::NoTls;
|
||||
use utils::pq_proto::StartupMessageParams;
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum ConnectionError {
|
||||
@@ -110,7 +112,42 @@ pub struct PostgresConnection {
|
||||
|
||||
impl NodeInfo {
|
||||
/// Connect to a corresponding compute node.
|
||||
pub async fn connect(&self) -> Result<(PostgresConnection, CancelClosure), ConnectionError> {
|
||||
pub async fn connect(
|
||||
mut self,
|
||||
params: &StartupMessageParams,
|
||||
) -> Result<(PostgresConnection, CancelClosure), ConnectionError> {
|
||||
if let Some(options) = params.options_raw() {
|
||||
// We must drop all proxy-specific parameters.
|
||||
#[allow(unstable_name_collisions)]
|
||||
let options: String = options
|
||||
.filter(|opt| !opt.starts_with("project="))
|
||||
.intersperse(" ") // TODO: use impl from std once it's stabilized
|
||||
.collect();
|
||||
|
||||
self.config.options(&options);
|
||||
}
|
||||
|
||||
if let Some(app_name) = params.get("application_name") {
|
||||
self.config.application_name(app_name);
|
||||
}
|
||||
|
||||
if let Some(replication) = params.get("replication") {
|
||||
use tokio_postgres::config::ReplicationMode;
|
||||
match replication {
|
||||
"true" | "on" | "yes" | "1" => {
|
||||
self.config.replication_mode(ReplicationMode::Physical);
|
||||
}
|
||||
"database" => {
|
||||
self.config.replication_mode(ReplicationMode::Logical);
|
||||
}
|
||||
_other => {}
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: extend the list of the forwarded startup parameters.
|
||||
// Currently, tokio-postgres doesn't allow us to pass
|
||||
// arbitrary parameters, but the ones above are a good start.
|
||||
|
||||
let (socket_addr, mut stream) = self
|
||||
.connect_raw()
|
||||
.await
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use crate::auth;
|
||||
use crate::cancellation::{self, CancelMap};
|
||||
use crate::config::{ProxyConfig, TlsConfig};
|
||||
use crate::config::{AuthUrls, ProxyConfig, TlsConfig};
|
||||
use crate::stream::{MetricsStream, PqStream, Stream};
|
||||
use anyhow::{bail, Context};
|
||||
use futures::TryFutureExt;
|
||||
@@ -93,20 +93,21 @@ async fn handle_client(
|
||||
None => return Ok(()), // it's a cancellation request
|
||||
};
|
||||
|
||||
// Extract credentials which we're going to use for auth.
|
||||
let creds = {
|
||||
let sni = stream.get_ref().sni_hostname();
|
||||
let common_name = tls.and_then(|tls| tls.common_name.as_deref());
|
||||
let result = config
|
||||
.auth_backend
|
||||
.map(|_| auth::ClientCredentials::parse(params, sni, common_name))
|
||||
.map(|_| auth::ClientCredentials::parse(¶ms, sni, common_name))
|
||||
.transpose();
|
||||
|
||||
async { result }.or_else(|e| stream.throw_error(e)).await?
|
||||
};
|
||||
|
||||
let client = Client::new(stream, creds);
|
||||
let client = Client::new(stream, creds, ¶ms);
|
||||
cancel_map
|
||||
.with_session(|session| client.connect_to_db(config, session))
|
||||
.with_session(|session| client.connect_to_db(&config.auth_urls, session))
|
||||
.await
|
||||
}
|
||||
|
||||
@@ -174,38 +175,57 @@ async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
}
|
||||
|
||||
/// Thin connection context.
|
||||
struct Client<S> {
|
||||
struct Client<'a, S> {
|
||||
/// The underlying libpq protocol stream.
|
||||
stream: PqStream<S>,
|
||||
/// Client credentials that we care about.
|
||||
creds: auth::BackendType<auth::ClientCredentials>,
|
||||
creds: auth::BackendType<auth::ClientCredentials<'a>>,
|
||||
/// KV-dictionary with PostgreSQL connection params.
|
||||
params: &'a StartupMessageParams,
|
||||
}
|
||||
|
||||
impl<S> Client<S> {
|
||||
impl<'a, S> Client<'a, S> {
|
||||
/// Construct a new connection context.
|
||||
fn new(stream: PqStream<S>, creds: auth::BackendType<auth::ClientCredentials>) -> Self {
|
||||
Self { stream, creds }
|
||||
fn new(
|
||||
stream: PqStream<S>,
|
||||
creds: auth::BackendType<auth::ClientCredentials<'a>>,
|
||||
params: &'a StartupMessageParams,
|
||||
) -> Self {
|
||||
Self {
|
||||
stream,
|
||||
creds,
|
||||
params,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: AsyncRead + AsyncWrite + Unpin + Send> Client<S> {
|
||||
impl<S: AsyncRead + AsyncWrite + Unpin + Send> Client<'_, S> {
|
||||
/// Let the client authenticate and connect to the designated compute node.
|
||||
async fn connect_to_db(
|
||||
self,
|
||||
config: &ProxyConfig,
|
||||
urls: &AuthUrls,
|
||||
session: cancellation::Session<'_>,
|
||||
) -> anyhow::Result<()> {
|
||||
let Self { mut stream, creds } = self;
|
||||
let Self {
|
||||
mut stream,
|
||||
creds,
|
||||
params,
|
||||
} = self;
|
||||
|
||||
// Authenticate and connect to a compute node.
|
||||
let auth = creds.authenticate(&config.auth_urls, &mut stream).await;
|
||||
let auth = creds.authenticate(urls, &mut stream).await;
|
||||
let node = async { auth }.or_else(|e| stream.throw_error(e)).await?;
|
||||
let reported_auth_ok = node.reported_auth_ok;
|
||||
|
||||
let (db, cancel_closure) = node.connect().or_else(|e| stream.throw_error(e)).await?;
|
||||
let cancel_key_data = session.enable_cancellation(cancel_closure);
|
||||
let (db, cancel_closure) = node
|
||||
.connect(params)
|
||||
.or_else(|e| stream.throw_error(e))
|
||||
.await?;
|
||||
|
||||
let cancel_key_data = session.enable_query_cancellation(cancel_closure);
|
||||
|
||||
// Report authentication success if we haven't done this already.
|
||||
if !node.reported_auth_ok {
|
||||
if !reported_auth_ok {
|
||||
stream
|
||||
.write_message_noflush(&Be::AuthenticationOk)?
|
||||
.write_message_noflush(&BeParameterStatusMessage::encoding())?;
|
||||
|
||||
@@ -11,7 +11,6 @@ use anyhow::{bail, Context, Result};
|
||||
|
||||
use postgres_ffi::PG_TLI;
|
||||
use regex::Regex;
|
||||
use std::str::FromStr;
|
||||
use std::sync::Arc;
|
||||
use tracing::info;
|
||||
use utils::{
|
||||
@@ -67,18 +66,22 @@ impl postgres_backend::Handler for SafekeeperPostgresHandler {
|
||||
// ztenant id and ztimeline id are passed in connection string params
|
||||
fn startup(&mut self, _pgb: &mut PostgresBackend, sm: &FeStartupPacket) -> Result<()> {
|
||||
if let FeStartupPacket::StartupMessage { params, .. } = sm {
|
||||
self.ztenantid = match params.get("ztenantid") {
|
||||
Some(z) => Some(ZTenantId::from_str(z)?), // just curious, can I do that from .map?
|
||||
_ => None,
|
||||
};
|
||||
|
||||
self.ztimelineid = match params.get("ztimelineid") {
|
||||
Some(z) => Some(ZTimelineId::from_str(z)?),
|
||||
_ => None,
|
||||
};
|
||||
if let Some(options) = params.options_raw() {
|
||||
for opt in options {
|
||||
match opt.split_once('=') {
|
||||
Some(("ztenantid", value)) => {
|
||||
self.ztenantid = Some(value.parse()?);
|
||||
}
|
||||
Some(("ztimelineid", value)) => {
|
||||
self.ztimelineid = Some(value.parse()?);
|
||||
}
|
||||
_ => continue,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(app_name) = params.get("application_name") {
|
||||
self.appname = Some(app_name.clone());
|
||||
self.appname = Some(app_name.to_owned());
|
||||
}
|
||||
|
||||
Ok(())
|
||||
|
||||
@@ -134,12 +134,8 @@ async def test_psql_session_id(vanilla_pg: VanillaPostgres, link_proxy: NeonProx
|
||||
|
||||
|
||||
# Pass extra options to the server.
|
||||
#
|
||||
# Currently, proxy eats the extra connection options, so this fails.
|
||||
# See https://github.com/neondatabase/neon/issues/1287
|
||||
@pytest.mark.xfail
|
||||
def test_proxy_options(static_proxy):
|
||||
with static_proxy.connect(options="-cproxytest.option=value") as conn:
|
||||
with static_proxy.connect(options="project=irrelevant -cproxytest.option=value") as conn:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute("SHOW proxytest.option")
|
||||
value = cur.fetchall()[0][0]
|
||||
|
||||
Reference in New Issue
Block a user