mirror of
https://github.com/neondatabase/neon.git
synced 2026-01-08 14:02:55 +00:00
Forward various connection params to compute nodes. (#2336)
Previously, proxy didn't forward auxiliary `options` parameter and other ones to the client's compute node, e.g. ``` $ psql "user=john host=localhost dbname=postgres options='-cgeqo=off'" postgres=# show geqo; ┌──────┐ │ geqo │ ├──────┤ │ on │ └──────┘ (1 row) ``` With this patch we now forward `options`, `application_name` and `replication`. Further reading: https://www.postgresql.org/docs/current/libpq-connect.html Fixes #1287.
This commit is contained in:
@@ -15,6 +15,7 @@ hashbrown = "0.12"
|
||||
hex = "0.4.3"
|
||||
hmac = "0.12.1"
|
||||
hyper = "0.14"
|
||||
itertools = "0.10.3"
|
||||
once_cell = "1.13.0"
|
||||
md5 = "0.7.0"
|
||||
parking_lot = "0.12"
|
||||
|
||||
@@ -127,7 +127,7 @@ impl<T, E> BackendType<Result<T, E>> {
|
||||
}
|
||||
}
|
||||
|
||||
impl BackendType<ClientCredentials> {
|
||||
impl BackendType<ClientCredentials<'_>> {
|
||||
/// Authenticate the client via the requested backend, possibly using credentials.
|
||||
pub async fn authenticate(
|
||||
mut self,
|
||||
@@ -149,7 +149,7 @@ impl BackendType<ClientCredentials> {
|
||||
|
||||
// Finally we may finish the initialization of `creds`.
|
||||
// TODO: add missing type safety to ClientCredentials.
|
||||
creds.project = Some(payload.project);
|
||||
creds.project = Some(payload.project.into());
|
||||
|
||||
let mut config = match &self {
|
||||
Console(creds) => {
|
||||
|
||||
@@ -121,7 +121,7 @@ pub enum AuthInfo {
|
||||
#[must_use]
|
||||
pub(super) struct Api<'a> {
|
||||
endpoint: &'a ApiUrl,
|
||||
creds: &'a ClientCredentials,
|
||||
creds: &'a ClientCredentials<'a>,
|
||||
}
|
||||
|
||||
impl<'a> Api<'a> {
|
||||
@@ -143,7 +143,7 @@ impl<'a> Api<'a> {
|
||||
url.path_segments_mut().push("proxy_get_role_secret");
|
||||
url.query_pairs_mut()
|
||||
.append_pair("project", self.creds.project().expect("impossible"))
|
||||
.append_pair("role", &self.creds.user);
|
||||
.append_pair("role", self.creds.user);
|
||||
|
||||
// TODO: use a proper logger
|
||||
println!("cplane request: {url}");
|
||||
@@ -187,8 +187,8 @@ impl<'a> Api<'a> {
|
||||
config
|
||||
.host(host)
|
||||
.port(port)
|
||||
.dbname(&self.creds.dbname)
|
||||
.user(&self.creds.user);
|
||||
.dbname(self.creds.dbname)
|
||||
.user(self.creds.user);
|
||||
|
||||
Ok(config)
|
||||
}
|
||||
|
||||
@@ -56,7 +56,7 @@ enum ProxyAuthResponse {
|
||||
NotReady { ready: bool }, // TODO: get rid of `ready`
|
||||
}
|
||||
|
||||
impl ClientCredentials {
|
||||
impl ClientCredentials<'_> {
|
||||
fn is_existing_user(&self) -> bool {
|
||||
self.user.ends_with("@zenith")
|
||||
}
|
||||
@@ -64,15 +64,15 @@ impl ClientCredentials {
|
||||
|
||||
async fn authenticate_proxy_client(
|
||||
auth_endpoint: &reqwest::Url,
|
||||
creds: &ClientCredentials,
|
||||
creds: &ClientCredentials<'_>,
|
||||
md5_response: &str,
|
||||
salt: &[u8; 4],
|
||||
psql_session_id: &str,
|
||||
) -> Result<DatabaseInfo, LegacyAuthError> {
|
||||
let mut url = auth_endpoint.clone();
|
||||
url.query_pairs_mut()
|
||||
.append_pair("login", &creds.user)
|
||||
.append_pair("database", &creds.dbname)
|
||||
.append_pair("login", creds.user)
|
||||
.append_pair("database", creds.dbname)
|
||||
.append_pair("md5response", md5_response)
|
||||
.append_pair("salt", &hex::encode(salt))
|
||||
.append_pair("psql_session_id", psql_session_id);
|
||||
@@ -103,7 +103,7 @@ async fn authenticate_proxy_client(
|
||||
async fn handle_existing_user(
|
||||
auth_endpoint: &reqwest::Url,
|
||||
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin + Send>,
|
||||
creds: &ClientCredentials,
|
||||
creds: &ClientCredentials<'_>,
|
||||
) -> auth::Result<compute::NodeInfo> {
|
||||
let psql_session_id = super::link::new_psql_session_id();
|
||||
let md5_salt = rand::random();
|
||||
@@ -136,7 +136,7 @@ async fn handle_existing_user(
|
||||
pub async fn handle_user(
|
||||
auth_endpoint: &reqwest::Url,
|
||||
auth_link_uri: &reqwest::Url,
|
||||
creds: &ClientCredentials,
|
||||
creds: &ClientCredentials<'_>,
|
||||
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin + Send>,
|
||||
) -> auth::Result<compute::NodeInfo> {
|
||||
if creds.is_existing_user() {
|
||||
|
||||
@@ -17,7 +17,7 @@ use tokio::io::{AsyncRead, AsyncWrite};
|
||||
#[must_use]
|
||||
pub(super) struct Api<'a> {
|
||||
endpoint: &'a ApiUrl,
|
||||
creds: &'a ClientCredentials,
|
||||
creds: &'a ClientCredentials<'a>,
|
||||
}
|
||||
|
||||
// Helps eliminate graceless `.map_err` calls without introducing another ctor.
|
||||
@@ -87,8 +87,8 @@ impl<'a> Api<'a> {
|
||||
config
|
||||
.host(self.endpoint.host_str().unwrap_or("localhost"))
|
||||
.port(self.endpoint.port().unwrap_or(5432))
|
||||
.dbname(&self.creds.dbname)
|
||||
.user(&self.creds.user);
|
||||
.dbname(self.creds.dbname)
|
||||
.user(self.creds.user);
|
||||
|
||||
Ok(config)
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
//! User credentials used in authentication.
|
||||
|
||||
use crate::error::UserFacingError;
|
||||
use std::borrow::Cow;
|
||||
use thiserror::Error;
|
||||
use utils::pq_proto::StartupMessageParams;
|
||||
|
||||
@@ -27,51 +28,59 @@ impl UserFacingError for ClientCredsParseError {}
|
||||
/// Various client credentials which we use for authentication.
|
||||
/// Note that we don't store any kind of client key or password here.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct ClientCredentials {
|
||||
pub user: String,
|
||||
pub dbname: String,
|
||||
pub project: Option<String>,
|
||||
pub struct ClientCredentials<'a> {
|
||||
pub user: &'a str,
|
||||
pub dbname: &'a str,
|
||||
pub project: Option<Cow<'a, str>>,
|
||||
}
|
||||
|
||||
impl ClientCredentials {
|
||||
impl ClientCredentials<'_> {
|
||||
pub fn project(&self) -> Option<&str> {
|
||||
self.project.as_deref()
|
||||
}
|
||||
}
|
||||
|
||||
impl ClientCredentials {
|
||||
impl<'a> ClientCredentials<'a> {
|
||||
pub fn parse(
|
||||
mut options: StartupMessageParams,
|
||||
params: &'a StartupMessageParams,
|
||||
sni: Option<&str>,
|
||||
common_name: Option<&str>,
|
||||
) -> Result<Self, ClientCredsParseError> {
|
||||
use ClientCredsParseError::*;
|
||||
|
||||
// Some parameters are absolutely necessary, others not so much.
|
||||
let mut get_param = |key| options.remove(key).ok_or(MissingKey(key));
|
||||
|
||||
// Some parameters are stored in the startup message.
|
||||
let get_param = |key| params.get(key).ok_or(MissingKey(key));
|
||||
let user = get_param("user")?;
|
||||
let dbname = get_param("database")?;
|
||||
let project_a = get_param("project").ok();
|
||||
|
||||
// Project name might be passed via PG's command-line options.
|
||||
let project_a = params.options_raw().and_then(|options| {
|
||||
for opt in options {
|
||||
if let Some(value) = opt.strip_prefix("project=") {
|
||||
return Some(Cow::Borrowed(value));
|
||||
}
|
||||
}
|
||||
None
|
||||
});
|
||||
|
||||
// Alternative project name is in fact a subdomain from SNI.
|
||||
// NOTE: we do not consider SNI if `common_name` is missing.
|
||||
let project_b = sni
|
||||
.zip(common_name)
|
||||
.map(|(sni, cn)| {
|
||||
// TODO: what if SNI is present but just a common name?
|
||||
subdomain_from_sni(sni, cn)
|
||||
.ok_or_else(|| InconsistentSni(sni.to_owned(), cn.to_owned()))
|
||||
.ok_or_else(|| InconsistentSni(sni.into(), cn.into()))
|
||||
.map(Cow::<'static, str>::Owned)
|
||||
})
|
||||
.transpose()?;
|
||||
|
||||
let project = match (project_a, project_b) {
|
||||
// Invariant: if we have both project name variants, they should match.
|
||||
(Some(a), Some(b)) if a != b => Some(Err(InconsistentProjectNames(a, b))),
|
||||
(a, b) => a.or(b).map(|name| {
|
||||
// Invariant: project name may not contain certain characters.
|
||||
check_project_name(name).map_err(MalformedProjectName)
|
||||
(Some(a), Some(b)) if a != b => Some(Err(InconsistentProjectNames(a.into(), b.into()))),
|
||||
// Invariant: project name may not contain certain characters.
|
||||
(a, b) => a.or(b).map(|name| match project_name_valid(&name) {
|
||||
false => Err(MalformedProjectName(name.into())),
|
||||
true => Ok(name),
|
||||
}),
|
||||
}
|
||||
.transpose()?;
|
||||
@@ -84,12 +93,8 @@ impl ClientCredentials {
|
||||
}
|
||||
}
|
||||
|
||||
fn check_project_name(name: String) -> Result<String, String> {
|
||||
if name.chars().all(|c| c.is_alphanumeric() || c == '-') {
|
||||
Ok(name)
|
||||
} else {
|
||||
Err(name)
|
||||
}
|
||||
fn project_name_valid(name: &str) -> bool {
|
||||
name.chars().all(|c| c.is_alphanumeric() || c == '-')
|
||||
}
|
||||
|
||||
fn subdomain_from_sni(sni: &str, common_name: &str) -> Option<String> {
|
||||
@@ -102,18 +107,14 @@ fn subdomain_from_sni(sni: &str, common_name: &str) -> Option<String> {
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn make_options<'a, const N: usize>(pairs: [(&'a str, &'a str); N]) -> StartupMessageParams {
|
||||
StartupMessageParams::from(pairs.map(|(k, v)| (k.to_owned(), v.to_owned())))
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[ignore = "TODO: fix how database is handled"]
|
||||
fn parse_bare_minimum() -> anyhow::Result<()> {
|
||||
// According to postgresql, only `user` should be required.
|
||||
let options = make_options([("user", "john_doe")]);
|
||||
let options = StartupMessageParams::new([("user", "john_doe")]);
|
||||
|
||||
// TODO: check that `creds.dbname` is None.
|
||||
let creds = ClientCredentials::parse(options, None, None)?;
|
||||
let creds = ClientCredentials::parse(&options, None, None)?;
|
||||
assert_eq!(creds.user, "john_doe");
|
||||
|
||||
Ok(())
|
||||
@@ -121,9 +122,9 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn parse_missing_project() -> anyhow::Result<()> {
|
||||
let options = make_options([("user", "john_doe"), ("database", "world")]);
|
||||
let options = StartupMessageParams::new([("user", "john_doe"), ("database", "world")]);
|
||||
|
||||
let creds = ClientCredentials::parse(options, None, None)?;
|
||||
let creds = ClientCredentials::parse(&options, None, None)?;
|
||||
assert_eq!(creds.user, "john_doe");
|
||||
assert_eq!(creds.dbname, "world");
|
||||
assert_eq!(creds.project, None);
|
||||
@@ -133,12 +134,12 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn parse_project_from_sni() -> anyhow::Result<()> {
|
||||
let options = make_options([("user", "john_doe"), ("database", "world")]);
|
||||
let options = StartupMessageParams::new([("user", "john_doe"), ("database", "world")]);
|
||||
|
||||
let sni = Some("foo.localhost");
|
||||
let common_name = Some("localhost");
|
||||
|
||||
let creds = ClientCredentials::parse(options, sni, common_name)?;
|
||||
let creds = ClientCredentials::parse(&options, sni, common_name)?;
|
||||
assert_eq!(creds.user, "john_doe");
|
||||
assert_eq!(creds.dbname, "world");
|
||||
assert_eq!(creds.project.as_deref(), Some("foo"));
|
||||
@@ -148,13 +149,13 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn parse_project_from_options() -> anyhow::Result<()> {
|
||||
let options = make_options([
|
||||
let options = StartupMessageParams::new([
|
||||
("user", "john_doe"),
|
||||
("database", "world"),
|
||||
("project", "bar"),
|
||||
("options", "-ckey=1 project=bar -c geqo=off"),
|
||||
]);
|
||||
|
||||
let creds = ClientCredentials::parse(options, None, None)?;
|
||||
let creds = ClientCredentials::parse(&options, None, None)?;
|
||||
assert_eq!(creds.user, "john_doe");
|
||||
assert_eq!(creds.dbname, "world");
|
||||
assert_eq!(creds.project.as_deref(), Some("bar"));
|
||||
@@ -164,16 +165,16 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn parse_projects_identical() -> anyhow::Result<()> {
|
||||
let options = make_options([
|
||||
let options = StartupMessageParams::new([
|
||||
("user", "john_doe"),
|
||||
("database", "world"),
|
||||
("project", "baz"),
|
||||
("options", "project=baz"),
|
||||
]);
|
||||
|
||||
let sni = Some("baz.localhost");
|
||||
let common_name = Some("localhost");
|
||||
|
||||
let creds = ClientCredentials::parse(options, sni, common_name)?;
|
||||
let creds = ClientCredentials::parse(&options, sni, common_name)?;
|
||||
assert_eq!(creds.user, "john_doe");
|
||||
assert_eq!(creds.dbname, "world");
|
||||
assert_eq!(creds.project.as_deref(), Some("baz"));
|
||||
@@ -183,17 +184,17 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn parse_projects_different() {
|
||||
let options = make_options([
|
||||
let options = StartupMessageParams::new([
|
||||
("user", "john_doe"),
|
||||
("database", "world"),
|
||||
("project", "first"),
|
||||
("options", "project=first"),
|
||||
]);
|
||||
|
||||
let sni = Some("second.localhost");
|
||||
let common_name = Some("localhost");
|
||||
|
||||
assert!(matches!(
|
||||
ClientCredentials::parse(options, sni, common_name).expect_err("should fail"),
|
||||
ClientCredentials::parse(&options, sni, common_name).expect_err("should fail"),
|
||||
ClientCredsParseError::InconsistentProjectNames(_, _)
|
||||
));
|
||||
}
|
||||
|
||||
@@ -95,7 +95,7 @@ impl<'a> Session<'a> {
|
||||
|
||||
/// Store the cancel token for the given session.
|
||||
/// This enables query cancellation in [`crate::proxy::handshake`].
|
||||
pub fn enable_cancellation(self, cancel_closure: CancelClosure) -> CancelKeyData {
|
||||
pub fn enable_query_cancellation(self, cancel_closure: CancelClosure) -> CancelKeyData {
|
||||
self.cancel_map
|
||||
.0
|
||||
.lock()
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
use crate::{cancellation::CancelClosure, error::UserFacingError};
|
||||
use futures::TryFutureExt;
|
||||
use itertools::Itertools;
|
||||
use std::{io, net::SocketAddr};
|
||||
use thiserror::Error;
|
||||
use tokio::net::TcpStream;
|
||||
use tokio_postgres::NoTls;
|
||||
use utils::pq_proto::StartupMessageParams;
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum ConnectionError {
|
||||
@@ -110,7 +112,42 @@ pub struct PostgresConnection {
|
||||
|
||||
impl NodeInfo {
|
||||
/// Connect to a corresponding compute node.
|
||||
pub async fn connect(&self) -> Result<(PostgresConnection, CancelClosure), ConnectionError> {
|
||||
pub async fn connect(
|
||||
mut self,
|
||||
params: &StartupMessageParams,
|
||||
) -> Result<(PostgresConnection, CancelClosure), ConnectionError> {
|
||||
if let Some(options) = params.options_raw() {
|
||||
// We must drop all proxy-specific parameters.
|
||||
#[allow(unstable_name_collisions)]
|
||||
let options: String = options
|
||||
.filter(|opt| !opt.starts_with("project="))
|
||||
.intersperse(" ") // TODO: use impl from std once it's stabilized
|
||||
.collect();
|
||||
|
||||
self.config.options(&options);
|
||||
}
|
||||
|
||||
if let Some(app_name) = params.get("application_name") {
|
||||
self.config.application_name(app_name);
|
||||
}
|
||||
|
||||
if let Some(replication) = params.get("replication") {
|
||||
use tokio_postgres::config::ReplicationMode;
|
||||
match replication {
|
||||
"true" | "on" | "yes" | "1" => {
|
||||
self.config.replication_mode(ReplicationMode::Physical);
|
||||
}
|
||||
"database" => {
|
||||
self.config.replication_mode(ReplicationMode::Logical);
|
||||
}
|
||||
_other => {}
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: extend the list of the forwarded startup parameters.
|
||||
// Currently, tokio-postgres doesn't allow us to pass
|
||||
// arbitrary parameters, but the ones above are a good start.
|
||||
|
||||
let (socket_addr, mut stream) = self
|
||||
.connect_raw()
|
||||
.await
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use crate::auth;
|
||||
use crate::cancellation::{self, CancelMap};
|
||||
use crate::config::{ProxyConfig, TlsConfig};
|
||||
use crate::config::{AuthUrls, ProxyConfig, TlsConfig};
|
||||
use crate::stream::{MetricsStream, PqStream, Stream};
|
||||
use anyhow::{bail, Context};
|
||||
use futures::TryFutureExt;
|
||||
@@ -93,20 +93,21 @@ async fn handle_client(
|
||||
None => return Ok(()), // it's a cancellation request
|
||||
};
|
||||
|
||||
// Extract credentials which we're going to use for auth.
|
||||
let creds = {
|
||||
let sni = stream.get_ref().sni_hostname();
|
||||
let common_name = tls.and_then(|tls| tls.common_name.as_deref());
|
||||
let result = config
|
||||
.auth_backend
|
||||
.map(|_| auth::ClientCredentials::parse(params, sni, common_name))
|
||||
.map(|_| auth::ClientCredentials::parse(¶ms, sni, common_name))
|
||||
.transpose();
|
||||
|
||||
async { result }.or_else(|e| stream.throw_error(e)).await?
|
||||
};
|
||||
|
||||
let client = Client::new(stream, creds);
|
||||
let client = Client::new(stream, creds, ¶ms);
|
||||
cancel_map
|
||||
.with_session(|session| client.connect_to_db(config, session))
|
||||
.with_session(|session| client.connect_to_db(&config.auth_urls, session))
|
||||
.await
|
||||
}
|
||||
|
||||
@@ -174,38 +175,57 @@ async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
}
|
||||
|
||||
/// Thin connection context.
|
||||
struct Client<S> {
|
||||
struct Client<'a, S> {
|
||||
/// The underlying libpq protocol stream.
|
||||
stream: PqStream<S>,
|
||||
/// Client credentials that we care about.
|
||||
creds: auth::BackendType<auth::ClientCredentials>,
|
||||
creds: auth::BackendType<auth::ClientCredentials<'a>>,
|
||||
/// KV-dictionary with PostgreSQL connection params.
|
||||
params: &'a StartupMessageParams,
|
||||
}
|
||||
|
||||
impl<S> Client<S> {
|
||||
impl<'a, S> Client<'a, S> {
|
||||
/// Construct a new connection context.
|
||||
fn new(stream: PqStream<S>, creds: auth::BackendType<auth::ClientCredentials>) -> Self {
|
||||
Self { stream, creds }
|
||||
fn new(
|
||||
stream: PqStream<S>,
|
||||
creds: auth::BackendType<auth::ClientCredentials<'a>>,
|
||||
params: &'a StartupMessageParams,
|
||||
) -> Self {
|
||||
Self {
|
||||
stream,
|
||||
creds,
|
||||
params,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: AsyncRead + AsyncWrite + Unpin + Send> Client<S> {
|
||||
impl<S: AsyncRead + AsyncWrite + Unpin + Send> Client<'_, S> {
|
||||
/// Let the client authenticate and connect to the designated compute node.
|
||||
async fn connect_to_db(
|
||||
self,
|
||||
config: &ProxyConfig,
|
||||
urls: &AuthUrls,
|
||||
session: cancellation::Session<'_>,
|
||||
) -> anyhow::Result<()> {
|
||||
let Self { mut stream, creds } = self;
|
||||
let Self {
|
||||
mut stream,
|
||||
creds,
|
||||
params,
|
||||
} = self;
|
||||
|
||||
// Authenticate and connect to a compute node.
|
||||
let auth = creds.authenticate(&config.auth_urls, &mut stream).await;
|
||||
let auth = creds.authenticate(urls, &mut stream).await;
|
||||
let node = async { auth }.or_else(|e| stream.throw_error(e)).await?;
|
||||
let reported_auth_ok = node.reported_auth_ok;
|
||||
|
||||
let (db, cancel_closure) = node.connect().or_else(|e| stream.throw_error(e)).await?;
|
||||
let cancel_key_data = session.enable_cancellation(cancel_closure);
|
||||
let (db, cancel_closure) = node
|
||||
.connect(params)
|
||||
.or_else(|e| stream.throw_error(e))
|
||||
.await?;
|
||||
|
||||
let cancel_key_data = session.enable_query_cancellation(cancel_closure);
|
||||
|
||||
// Report authentication success if we haven't done this already.
|
||||
if !node.reported_auth_ok {
|
||||
if !reported_auth_ok {
|
||||
stream
|
||||
.write_message_noflush(&Be::AuthenticationOk)?
|
||||
.write_message_noflush(&BeParameterStatusMessage::encoding())?;
|
||||
|
||||
Reference in New Issue
Block a user