Forward various connection params to compute nodes. (#2336)

Previously, proxy didn't forward auxiliary `options` parameter
and other ones to the client's compute node, e.g.

```
$ psql "user=john host=localhost dbname=postgres options='-cgeqo=off'"
postgres=# show geqo;
┌──────┐
│ geqo │
├──────┤
│ on   │
└──────┘
(1 row)
```

With this patch we now forward `options`, `application_name` and `replication`.

Further reading: https://www.postgresql.org/docs/current/libpq-connect.html

Fixes #1287.
This commit is contained in:
Dmitry Ivanov
2022-08-30 17:36:21 +03:00
committed by GitHub
parent 60408db101
commit 96a50e99cf
13 changed files with 271 additions and 127 deletions

View File

@@ -15,6 +15,7 @@ hashbrown = "0.12"
hex = "0.4.3"
hmac = "0.12.1"
hyper = "0.14"
itertools = "0.10.3"
once_cell = "1.13.0"
md5 = "0.7.0"
parking_lot = "0.12"

View File

@@ -127,7 +127,7 @@ impl<T, E> BackendType<Result<T, E>> {
}
}
impl BackendType<ClientCredentials> {
impl BackendType<ClientCredentials<'_>> {
/// Authenticate the client via the requested backend, possibly using credentials.
pub async fn authenticate(
mut self,
@@ -149,7 +149,7 @@ impl BackendType<ClientCredentials> {
// Finally we may finish the initialization of `creds`.
// TODO: add missing type safety to ClientCredentials.
creds.project = Some(payload.project);
creds.project = Some(payload.project.into());
let mut config = match &self {
Console(creds) => {

View File

@@ -121,7 +121,7 @@ pub enum AuthInfo {
#[must_use]
pub(super) struct Api<'a> {
endpoint: &'a ApiUrl,
creds: &'a ClientCredentials,
creds: &'a ClientCredentials<'a>,
}
impl<'a> Api<'a> {
@@ -143,7 +143,7 @@ impl<'a> Api<'a> {
url.path_segments_mut().push("proxy_get_role_secret");
url.query_pairs_mut()
.append_pair("project", self.creds.project().expect("impossible"))
.append_pair("role", &self.creds.user);
.append_pair("role", self.creds.user);
// TODO: use a proper logger
println!("cplane request: {url}");
@@ -187,8 +187,8 @@ impl<'a> Api<'a> {
config
.host(host)
.port(port)
.dbname(&self.creds.dbname)
.user(&self.creds.user);
.dbname(self.creds.dbname)
.user(self.creds.user);
Ok(config)
}

View File

@@ -56,7 +56,7 @@ enum ProxyAuthResponse {
NotReady { ready: bool }, // TODO: get rid of `ready`
}
impl ClientCredentials {
impl ClientCredentials<'_> {
fn is_existing_user(&self) -> bool {
self.user.ends_with("@zenith")
}
@@ -64,15 +64,15 @@ impl ClientCredentials {
async fn authenticate_proxy_client(
auth_endpoint: &reqwest::Url,
creds: &ClientCredentials,
creds: &ClientCredentials<'_>,
md5_response: &str,
salt: &[u8; 4],
psql_session_id: &str,
) -> Result<DatabaseInfo, LegacyAuthError> {
let mut url = auth_endpoint.clone();
url.query_pairs_mut()
.append_pair("login", &creds.user)
.append_pair("database", &creds.dbname)
.append_pair("login", creds.user)
.append_pair("database", creds.dbname)
.append_pair("md5response", md5_response)
.append_pair("salt", &hex::encode(salt))
.append_pair("psql_session_id", psql_session_id);
@@ -103,7 +103,7 @@ async fn authenticate_proxy_client(
async fn handle_existing_user(
auth_endpoint: &reqwest::Url,
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin + Send>,
creds: &ClientCredentials,
creds: &ClientCredentials<'_>,
) -> auth::Result<compute::NodeInfo> {
let psql_session_id = super::link::new_psql_session_id();
let md5_salt = rand::random();
@@ -136,7 +136,7 @@ async fn handle_existing_user(
pub async fn handle_user(
auth_endpoint: &reqwest::Url,
auth_link_uri: &reqwest::Url,
creds: &ClientCredentials,
creds: &ClientCredentials<'_>,
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin + Send>,
) -> auth::Result<compute::NodeInfo> {
if creds.is_existing_user() {

View File

@@ -17,7 +17,7 @@ use tokio::io::{AsyncRead, AsyncWrite};
#[must_use]
pub(super) struct Api<'a> {
endpoint: &'a ApiUrl,
creds: &'a ClientCredentials,
creds: &'a ClientCredentials<'a>,
}
// Helps eliminate graceless `.map_err` calls without introducing another ctor.
@@ -87,8 +87,8 @@ impl<'a> Api<'a> {
config
.host(self.endpoint.host_str().unwrap_or("localhost"))
.port(self.endpoint.port().unwrap_or(5432))
.dbname(&self.creds.dbname)
.user(&self.creds.user);
.dbname(self.creds.dbname)
.user(self.creds.user);
Ok(config)
}

View File

@@ -1,6 +1,7 @@
//! User credentials used in authentication.
use crate::error::UserFacingError;
use std::borrow::Cow;
use thiserror::Error;
use utils::pq_proto::StartupMessageParams;
@@ -27,51 +28,59 @@ impl UserFacingError for ClientCredsParseError {}
/// Various client credentials which we use for authentication.
/// Note that we don't store any kind of client key or password here.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ClientCredentials {
pub user: String,
pub dbname: String,
pub project: Option<String>,
pub struct ClientCredentials<'a> {
pub user: &'a str,
pub dbname: &'a str,
pub project: Option<Cow<'a, str>>,
}
impl ClientCredentials {
impl ClientCredentials<'_> {
pub fn project(&self) -> Option<&str> {
self.project.as_deref()
}
}
impl ClientCredentials {
impl<'a> ClientCredentials<'a> {
pub fn parse(
mut options: StartupMessageParams,
params: &'a StartupMessageParams,
sni: Option<&str>,
common_name: Option<&str>,
) -> Result<Self, ClientCredsParseError> {
use ClientCredsParseError::*;
// Some parameters are absolutely necessary, others not so much.
let mut get_param = |key| options.remove(key).ok_or(MissingKey(key));
// Some parameters are stored in the startup message.
let get_param = |key| params.get(key).ok_or(MissingKey(key));
let user = get_param("user")?;
let dbname = get_param("database")?;
let project_a = get_param("project").ok();
// Project name might be passed via PG's command-line options.
let project_a = params.options_raw().and_then(|options| {
for opt in options {
if let Some(value) = opt.strip_prefix("project=") {
return Some(Cow::Borrowed(value));
}
}
None
});
// Alternative project name is in fact a subdomain from SNI.
// NOTE: we do not consider SNI if `common_name` is missing.
let project_b = sni
.zip(common_name)
.map(|(sni, cn)| {
// TODO: what if SNI is present but just a common name?
subdomain_from_sni(sni, cn)
.ok_or_else(|| InconsistentSni(sni.to_owned(), cn.to_owned()))
.ok_or_else(|| InconsistentSni(sni.into(), cn.into()))
.map(Cow::<'static, str>::Owned)
})
.transpose()?;
let project = match (project_a, project_b) {
// Invariant: if we have both project name variants, they should match.
(Some(a), Some(b)) if a != b => Some(Err(InconsistentProjectNames(a, b))),
(a, b) => a.or(b).map(|name| {
// Invariant: project name may not contain certain characters.
check_project_name(name).map_err(MalformedProjectName)
(Some(a), Some(b)) if a != b => Some(Err(InconsistentProjectNames(a.into(), b.into()))),
// Invariant: project name may not contain certain characters.
(a, b) => a.or(b).map(|name| match project_name_valid(&name) {
false => Err(MalformedProjectName(name.into())),
true => Ok(name),
}),
}
.transpose()?;
@@ -84,12 +93,8 @@ impl ClientCredentials {
}
}
fn check_project_name(name: String) -> Result<String, String> {
if name.chars().all(|c| c.is_alphanumeric() || c == '-') {
Ok(name)
} else {
Err(name)
}
fn project_name_valid(name: &str) -> bool {
name.chars().all(|c| c.is_alphanumeric() || c == '-')
}
fn subdomain_from_sni(sni: &str, common_name: &str) -> Option<String> {
@@ -102,18 +107,14 @@ fn subdomain_from_sni(sni: &str, common_name: &str) -> Option<String> {
mod tests {
use super::*;
fn make_options<'a, const N: usize>(pairs: [(&'a str, &'a str); N]) -> StartupMessageParams {
StartupMessageParams::from(pairs.map(|(k, v)| (k.to_owned(), v.to_owned())))
}
#[test]
#[ignore = "TODO: fix how database is handled"]
fn parse_bare_minimum() -> anyhow::Result<()> {
// According to postgresql, only `user` should be required.
let options = make_options([("user", "john_doe")]);
let options = StartupMessageParams::new([("user", "john_doe")]);
// TODO: check that `creds.dbname` is None.
let creds = ClientCredentials::parse(options, None, None)?;
let creds = ClientCredentials::parse(&options, None, None)?;
assert_eq!(creds.user, "john_doe");
Ok(())
@@ -121,9 +122,9 @@ mod tests {
#[test]
fn parse_missing_project() -> anyhow::Result<()> {
let options = make_options([("user", "john_doe"), ("database", "world")]);
let options = StartupMessageParams::new([("user", "john_doe"), ("database", "world")]);
let creds = ClientCredentials::parse(options, None, None)?;
let creds = ClientCredentials::parse(&options, None, None)?;
assert_eq!(creds.user, "john_doe");
assert_eq!(creds.dbname, "world");
assert_eq!(creds.project, None);
@@ -133,12 +134,12 @@ mod tests {
#[test]
fn parse_project_from_sni() -> anyhow::Result<()> {
let options = make_options([("user", "john_doe"), ("database", "world")]);
let options = StartupMessageParams::new([("user", "john_doe"), ("database", "world")]);
let sni = Some("foo.localhost");
let common_name = Some("localhost");
let creds = ClientCredentials::parse(options, sni, common_name)?;
let creds = ClientCredentials::parse(&options, sni, common_name)?;
assert_eq!(creds.user, "john_doe");
assert_eq!(creds.dbname, "world");
assert_eq!(creds.project.as_deref(), Some("foo"));
@@ -148,13 +149,13 @@ mod tests {
#[test]
fn parse_project_from_options() -> anyhow::Result<()> {
let options = make_options([
let options = StartupMessageParams::new([
("user", "john_doe"),
("database", "world"),
("project", "bar"),
("options", "-ckey=1 project=bar -c geqo=off"),
]);
let creds = ClientCredentials::parse(options, None, None)?;
let creds = ClientCredentials::parse(&options, None, None)?;
assert_eq!(creds.user, "john_doe");
assert_eq!(creds.dbname, "world");
assert_eq!(creds.project.as_deref(), Some("bar"));
@@ -164,16 +165,16 @@ mod tests {
#[test]
fn parse_projects_identical() -> anyhow::Result<()> {
let options = make_options([
let options = StartupMessageParams::new([
("user", "john_doe"),
("database", "world"),
("project", "baz"),
("options", "project=baz"),
]);
let sni = Some("baz.localhost");
let common_name = Some("localhost");
let creds = ClientCredentials::parse(options, sni, common_name)?;
let creds = ClientCredentials::parse(&options, sni, common_name)?;
assert_eq!(creds.user, "john_doe");
assert_eq!(creds.dbname, "world");
assert_eq!(creds.project.as_deref(), Some("baz"));
@@ -183,17 +184,17 @@ mod tests {
#[test]
fn parse_projects_different() {
let options = make_options([
let options = StartupMessageParams::new([
("user", "john_doe"),
("database", "world"),
("project", "first"),
("options", "project=first"),
]);
let sni = Some("second.localhost");
let common_name = Some("localhost");
assert!(matches!(
ClientCredentials::parse(options, sni, common_name).expect_err("should fail"),
ClientCredentials::parse(&options, sni, common_name).expect_err("should fail"),
ClientCredsParseError::InconsistentProjectNames(_, _)
));
}

View File

@@ -95,7 +95,7 @@ impl<'a> Session<'a> {
/// Store the cancel token for the given session.
/// This enables query cancellation in [`crate::proxy::handshake`].
pub fn enable_cancellation(self, cancel_closure: CancelClosure) -> CancelKeyData {
pub fn enable_query_cancellation(self, cancel_closure: CancelClosure) -> CancelKeyData {
self.cancel_map
.0
.lock()

View File

@@ -1,9 +1,11 @@
use crate::{cancellation::CancelClosure, error::UserFacingError};
use futures::TryFutureExt;
use itertools::Itertools;
use std::{io, net::SocketAddr};
use thiserror::Error;
use tokio::net::TcpStream;
use tokio_postgres::NoTls;
use utils::pq_proto::StartupMessageParams;
#[derive(Debug, Error)]
pub enum ConnectionError {
@@ -110,7 +112,42 @@ pub struct PostgresConnection {
impl NodeInfo {
/// Connect to a corresponding compute node.
pub async fn connect(&self) -> Result<(PostgresConnection, CancelClosure), ConnectionError> {
pub async fn connect(
mut self,
params: &StartupMessageParams,
) -> Result<(PostgresConnection, CancelClosure), ConnectionError> {
if let Some(options) = params.options_raw() {
// We must drop all proxy-specific parameters.
#[allow(unstable_name_collisions)]
let options: String = options
.filter(|opt| !opt.starts_with("project="))
.intersperse(" ") // TODO: use impl from std once it's stabilized
.collect();
self.config.options(&options);
}
if let Some(app_name) = params.get("application_name") {
self.config.application_name(app_name);
}
if let Some(replication) = params.get("replication") {
use tokio_postgres::config::ReplicationMode;
match replication {
"true" | "on" | "yes" | "1" => {
self.config.replication_mode(ReplicationMode::Physical);
}
"database" => {
self.config.replication_mode(ReplicationMode::Logical);
}
_other => {}
}
}
// TODO: extend the list of the forwarded startup parameters.
// Currently, tokio-postgres doesn't allow us to pass
// arbitrary parameters, but the ones above are a good start.
let (socket_addr, mut stream) = self
.connect_raw()
.await

View File

@@ -1,6 +1,6 @@
use crate::auth;
use crate::cancellation::{self, CancelMap};
use crate::config::{ProxyConfig, TlsConfig};
use crate::config::{AuthUrls, ProxyConfig, TlsConfig};
use crate::stream::{MetricsStream, PqStream, Stream};
use anyhow::{bail, Context};
use futures::TryFutureExt;
@@ -93,20 +93,21 @@ async fn handle_client(
None => return Ok(()), // it's a cancellation request
};
// Extract credentials which we're going to use for auth.
let creds = {
let sni = stream.get_ref().sni_hostname();
let common_name = tls.and_then(|tls| tls.common_name.as_deref());
let result = config
.auth_backend
.map(|_| auth::ClientCredentials::parse(params, sni, common_name))
.map(|_| auth::ClientCredentials::parse(&params, sni, common_name))
.transpose();
async { result }.or_else(|e| stream.throw_error(e)).await?
};
let client = Client::new(stream, creds);
let client = Client::new(stream, creds, &params);
cancel_map
.with_session(|session| client.connect_to_db(config, session))
.with_session(|session| client.connect_to_db(&config.auth_urls, session))
.await
}
@@ -174,38 +175,57 @@ async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
}
/// Thin connection context.
struct Client<S> {
struct Client<'a, S> {
/// The underlying libpq protocol stream.
stream: PqStream<S>,
/// Client credentials that we care about.
creds: auth::BackendType<auth::ClientCredentials>,
creds: auth::BackendType<auth::ClientCredentials<'a>>,
/// KV-dictionary with PostgreSQL connection params.
params: &'a StartupMessageParams,
}
impl<S> Client<S> {
impl<'a, S> Client<'a, S> {
/// Construct a new connection context.
fn new(stream: PqStream<S>, creds: auth::BackendType<auth::ClientCredentials>) -> Self {
Self { stream, creds }
fn new(
stream: PqStream<S>,
creds: auth::BackendType<auth::ClientCredentials<'a>>,
params: &'a StartupMessageParams,
) -> Self {
Self {
stream,
creds,
params,
}
}
}
impl<S: AsyncRead + AsyncWrite + Unpin + Send> Client<S> {
impl<S: AsyncRead + AsyncWrite + Unpin + Send> Client<'_, S> {
/// Let the client authenticate and connect to the designated compute node.
async fn connect_to_db(
self,
config: &ProxyConfig,
urls: &AuthUrls,
session: cancellation::Session<'_>,
) -> anyhow::Result<()> {
let Self { mut stream, creds } = self;
let Self {
mut stream,
creds,
params,
} = self;
// Authenticate and connect to a compute node.
let auth = creds.authenticate(&config.auth_urls, &mut stream).await;
let auth = creds.authenticate(urls, &mut stream).await;
let node = async { auth }.or_else(|e| stream.throw_error(e)).await?;
let reported_auth_ok = node.reported_auth_ok;
let (db, cancel_closure) = node.connect().or_else(|e| stream.throw_error(e)).await?;
let cancel_key_data = session.enable_cancellation(cancel_closure);
let (db, cancel_closure) = node
.connect(params)
.or_else(|e| stream.throw_error(e))
.await?;
let cancel_key_data = session.enable_query_cancellation(cancel_closure);
// Report authentication success if we haven't done this already.
if !node.reported_auth_ok {
if !reported_auth_ok {
stream
.write_message_noflush(&Be::AuthenticationOk)?
.write_message_noflush(&BeParameterStatusMessage::encoding())?;