mirror of
https://github.com/neondatabase/neon.git
synced 2026-06-02 13:00:37 +00:00
[proxy] Add the password hack authentication flow (#2095)
[proxy] Add the `password hack` authentication flow This lets us authenticate users which can use neither SNI (due to old libpq) nor connection string `options` (due to restrictions in other client libraries). Note: `PasswordHack` will accept passwords which are not encoded in base64 via the "password" field. The assumption is that most user passwords will be valid utf-8 strings, and the rest may still be passed via "password_".
This commit is contained in:
@@ -47,10 +47,12 @@ pub enum FeStartupPacket {
|
||||
StartupMessage {
|
||||
major_version: u32,
|
||||
minor_version: u32,
|
||||
params: HashMap<String, String>,
|
||||
params: StartupMessageParams,
|
||||
},
|
||||
}
|
||||
|
||||
pub type StartupMessageParams = HashMap<String, String>;
|
||||
|
||||
#[derive(Debug, Hash, PartialEq, Eq, Clone, Copy)]
|
||||
pub struct CancelKeyData {
|
||||
pub backend_pid: i32,
|
||||
|
||||
@@ -1,11 +1,14 @@
|
||||
//! Client authentication mechanisms.
|
||||
|
||||
pub mod backend;
|
||||
pub use backend::DatabaseInfo;
|
||||
pub use backend::{BackendType, DatabaseInfo};
|
||||
|
||||
mod credentials;
|
||||
pub use credentials::ClientCredentials;
|
||||
|
||||
mod password_hack;
|
||||
use password_hack::PasswordHackPayload;
|
||||
|
||||
mod flow;
|
||||
pub use flow::*;
|
||||
|
||||
@@ -29,9 +32,8 @@ pub enum AuthErrorImpl {
|
||||
#[error(transparent)]
|
||||
Sasl(#[from] crate::sasl::Error),
|
||||
|
||||
/// For passwords that couldn't be processed by [`backend::legacy_console::parse_password`].
|
||||
#[error("Malformed password message")]
|
||||
MalformedPassword,
|
||||
#[error("Malformed password message: {0}")]
|
||||
MalformedPassword(&'static str),
|
||||
|
||||
/// Errors produced by [`crate::stream::PqStream`].
|
||||
#[error(transparent)]
|
||||
@@ -76,7 +78,7 @@ impl UserFacingError for AuthError {
|
||||
Console(e) => e.to_string_client(),
|
||||
GetAuthInfo(e) => e.to_string_client(),
|
||||
Sasl(e) => e.to_string_client(),
|
||||
MalformedPassword => self.to_string(),
|
||||
MalformedPassword(_) => self.to_string(),
|
||||
_ => "Internal error".to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,16 +1,14 @@
|
||||
mod legacy_console;
|
||||
mod link;
|
||||
mod postgres;
|
||||
|
||||
pub mod console;
|
||||
|
||||
mod legacy_console;
|
||||
pub use legacy_console::{AuthError, AuthErrorImpl};
|
||||
|
||||
use super::ClientCredentials;
|
||||
use crate::{
|
||||
compute,
|
||||
config::{AuthBackendType, ProxyConfig},
|
||||
mgmt,
|
||||
auth::{self, AuthFlow, ClientCredentials},
|
||||
compute, config, mgmt,
|
||||
stream::PqStream,
|
||||
waiters::{self, Waiter, Waiters},
|
||||
};
|
||||
@@ -78,32 +76,158 @@ impl From<DatabaseInfo> for tokio_postgres::Config {
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) async fn handle_user(
|
||||
config: &ProxyConfig,
|
||||
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin + Send>,
|
||||
creds: ClientCredentials,
|
||||
) -> super::Result<compute::NodeInfo> {
|
||||
use AuthBackendType::*;
|
||||
match config.auth_backend {
|
||||
LegacyConsole => {
|
||||
legacy_console::handle_user(
|
||||
&config.auth_endpoint,
|
||||
&config.auth_link_uri,
|
||||
client,
|
||||
&creds,
|
||||
)
|
||||
.await
|
||||
/// This type serves two purposes:
|
||||
///
|
||||
/// * When `T` is `()`, it's just a regular auth backend selector
|
||||
/// which we use in [`crate::config::ProxyConfig`].
|
||||
///
|
||||
/// * However, when we substitute `T` with [`ClientCredentials`],
|
||||
/// this helps us provide the credentials only to those auth
|
||||
/// backends which require them for the authentication process.
|
||||
#[derive(Debug, Clone, Copy, PartialEq)]
|
||||
pub enum BackendType<T> {
|
||||
/// Legacy Cloud API (V1) + link auth.
|
||||
LegacyConsole(T),
|
||||
/// Current Cloud API (V2).
|
||||
Console(T),
|
||||
/// Local mock of Cloud API (V2).
|
||||
Postgres(T),
|
||||
/// Authentication via a web browser.
|
||||
Link,
|
||||
}
|
||||
|
||||
impl<T> BackendType<T> {
|
||||
/// Very similar to [`std::option::Option::map`].
|
||||
/// Maps [`BackendType<T>`] to [`BackendType<R>`] by applying
|
||||
/// a function to a contained value.
|
||||
pub fn map<R>(self, f: impl FnOnce(T) -> R) -> BackendType<R> {
|
||||
use BackendType::*;
|
||||
match self {
|
||||
LegacyConsole(x) => LegacyConsole(f(x)),
|
||||
Console(x) => Console(f(x)),
|
||||
Postgres(x) => Postgres(f(x)),
|
||||
Link => Link,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, E> BackendType<Result<T, E>> {
|
||||
/// Very similar to [`std::option::Option::transpose`].
|
||||
/// This is most useful for error handling.
|
||||
pub fn transpose(self) -> Result<BackendType<T>, E> {
|
||||
use BackendType::*;
|
||||
match self {
|
||||
LegacyConsole(x) => x.map(LegacyConsole),
|
||||
Console(x) => x.map(Console),
|
||||
Postgres(x) => x.map(Postgres),
|
||||
Link => Ok(Link),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl BackendType<ClientCredentials> {
|
||||
/// Authenticate the client via the requested backend, possibly using credentials.
|
||||
pub async fn authenticate(
|
||||
mut self,
|
||||
urls: &config::AuthUrls,
|
||||
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin + Send>,
|
||||
) -> super::Result<compute::NodeInfo> {
|
||||
use BackendType::*;
|
||||
|
||||
if let Console(creds) | Postgres(creds) = &mut self {
|
||||
// If there's no project so far, that entails that client doesn't
|
||||
// support SNI or other means of passing the project name.
|
||||
// We now expect to see a very specific payload in the place of password.
|
||||
if creds.project().is_none() {
|
||||
let payload = AuthFlow::new(client)
|
||||
.begin(auth::PasswordHack)
|
||||
.await?
|
||||
.authenticate()
|
||||
.await?;
|
||||
|
||||
// Finally we may finish the initialization of `creds`.
|
||||
// TODO: add missing type safety to ClientCredentials.
|
||||
creds.project = Some(payload.project);
|
||||
|
||||
let mut config = match &self {
|
||||
Console(creds) => {
|
||||
console::Api::new(&urls.auth_endpoint, creds)
|
||||
.wake_compute()
|
||||
.await?
|
||||
}
|
||||
Postgres(creds) => {
|
||||
postgres::Api::new(&urls.auth_endpoint, creds)
|
||||
.wake_compute()
|
||||
.await?
|
||||
}
|
||||
_ => unreachable!("see the patterns above"),
|
||||
};
|
||||
|
||||
// We should use a password from payload as well.
|
||||
config.password(payload.password);
|
||||
|
||||
return Ok(compute::NodeInfo {
|
||||
reported_auth_ok: false,
|
||||
config,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
match self {
|
||||
LegacyConsole(creds) => {
|
||||
legacy_console::handle_user(
|
||||
&urls.auth_endpoint,
|
||||
&urls.auth_link_uri,
|
||||
&creds,
|
||||
client,
|
||||
)
|
||||
.await
|
||||
}
|
||||
Console(creds) => {
|
||||
console::Api::new(&urls.auth_endpoint, &creds)
|
||||
.handle_user(client)
|
||||
.await
|
||||
}
|
||||
Postgres(creds) => {
|
||||
postgres::Api::new(&urls.auth_endpoint, &creds)
|
||||
.handle_user(client)
|
||||
.await
|
||||
}
|
||||
// NOTE: this auth backend doesn't use client credentials.
|
||||
Link => link::handle_user(&urls.auth_link_uri, client).await,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_backend_type_map() {
|
||||
let values = [
|
||||
BackendType::LegacyConsole(0),
|
||||
BackendType::Console(0),
|
||||
BackendType::Postgres(0),
|
||||
BackendType::Link,
|
||||
];
|
||||
|
||||
for value in values {
|
||||
assert_eq!(value.map(|x| x), value);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_backend_type_transpose() {
|
||||
let values = [
|
||||
BackendType::LegacyConsole(Ok::<_, ()>(0)),
|
||||
BackendType::Console(Ok(0)),
|
||||
BackendType::Postgres(Ok(0)),
|
||||
BackendType::Link,
|
||||
];
|
||||
|
||||
for value in values {
|
||||
assert_eq!(value.map(Result::unwrap), value.transpose().unwrap());
|
||||
}
|
||||
Console => {
|
||||
console::Api::new(&config.auth_endpoint, &creds)?
|
||||
.handle_user(client)
|
||||
.await
|
||||
}
|
||||
Postgres => {
|
||||
postgres::Api::new(&config.auth_endpoint, &creds)?
|
||||
.handle_user(client)
|
||||
.await
|
||||
}
|
||||
Link => link::handle_user(&config.auth_link_uri, client).await,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,18 +1,17 @@
|
||||
//! Cloud API V2.
|
||||
|
||||
use crate::{
|
||||
auth::{self, AuthFlow, ClientCredentials, DatabaseInfo},
|
||||
compute,
|
||||
error::UserFacingError,
|
||||
auth::{self, AuthFlow, ClientCredentials},
|
||||
compute::{self, ComputeConnCfg},
|
||||
error::{io_error, UserFacingError},
|
||||
scram,
|
||||
stream::PqStream,
|
||||
url::ApiUrl,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{future::Future, io};
|
||||
use std::future::Future;
|
||||
use thiserror::Error;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use utils::pq_proto::{BeMessage as Be, BeParameterStatusMessage};
|
||||
|
||||
pub type Result<T> = std::result::Result<T, ConsoleAuthError>;
|
||||
|
||||
@@ -84,8 +83,8 @@ pub(super) struct Api<'a> {
|
||||
|
||||
impl<'a> Api<'a> {
|
||||
/// Construct an API object containing the auth parameters.
|
||||
pub(super) fn new(endpoint: &'a ApiUrl, creds: &'a ClientCredentials) -> Result<Self> {
|
||||
Ok(Self { endpoint, creds })
|
||||
pub(super) fn new(endpoint: &'a ApiUrl, creds: &'a ClientCredentials) -> Self {
|
||||
Self { endpoint, creds }
|
||||
}
|
||||
|
||||
/// Authenticate the existing user or throw an error.
|
||||
@@ -100,7 +99,7 @@ impl<'a> Api<'a> {
|
||||
let mut url = self.endpoint.clone();
|
||||
url.path_segments_mut().push("proxy_get_role_secret");
|
||||
url.query_pairs_mut()
|
||||
.append_pair("project", self.creds.project_name.as_ref()?)
|
||||
.append_pair("project", self.creds.project().expect("impossible"))
|
||||
.append_pair("role", &self.creds.user);
|
||||
|
||||
// TODO: use a proper logger
|
||||
@@ -120,11 +119,11 @@ impl<'a> Api<'a> {
|
||||
}
|
||||
|
||||
/// Wake up the compute node and return the corresponding connection info.
|
||||
async fn wake_compute(&self) -> Result<DatabaseInfo> {
|
||||
pub(super) async fn wake_compute(&self) -> Result<ComputeConnCfg> {
|
||||
let mut url = self.endpoint.clone();
|
||||
url.path_segments_mut().push("proxy_wake_compute");
|
||||
let project_name = self.creds.project_name.as_ref()?;
|
||||
url.query_pairs_mut().append_pair("project", project_name);
|
||||
url.query_pairs_mut()
|
||||
.append_pair("project", self.creds.project().expect("impossible"));
|
||||
|
||||
// TODO: use a proper logger
|
||||
println!("cplane request: {url}");
|
||||
@@ -137,16 +136,20 @@ impl<'a> Api<'a> {
|
||||
let response: GetWakeComputeResponse =
|
||||
serde_json::from_str(&resp.text().await.map_err(io_error)?)?;
|
||||
|
||||
let (host, port) = parse_host_port(&response.address)
|
||||
.ok_or(ConsoleAuthError::BadComputeAddress(response.address))?;
|
||||
// Unfortunately, ownership won't let us use `Option::ok_or` here.
|
||||
let (host, port) = match parse_host_port(&response.address) {
|
||||
None => return Err(ConsoleAuthError::BadComputeAddress(response.address)),
|
||||
Some(x) => x,
|
||||
};
|
||||
|
||||
Ok(DatabaseInfo {
|
||||
host,
|
||||
port,
|
||||
dbname: self.creds.dbname.to_owned(),
|
||||
user: self.creds.user.to_owned(),
|
||||
password: None,
|
||||
})
|
||||
let mut config = ComputeConnCfg::new();
|
||||
config
|
||||
.host(host)
|
||||
.port(port)
|
||||
.dbname(&self.creds.dbname)
|
||||
.user(&self.creds.user);
|
||||
|
||||
Ok(config)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -160,7 +163,7 @@ pub(super) async fn handle_user<'a, Endpoint, GetAuthInfo, WakeCompute>(
|
||||
) -> auth::Result<compute::NodeInfo>
|
||||
where
|
||||
GetAuthInfo: Future<Output = Result<AuthInfo>>,
|
||||
WakeCompute: Future<Output = Result<DatabaseInfo>>,
|
||||
WakeCompute: Future<Output = Result<ComputeConnCfg>>,
|
||||
{
|
||||
let auth_info = get_auth_info(endpoint).await?;
|
||||
|
||||
@@ -179,48 +182,18 @@ where
|
||||
}
|
||||
};
|
||||
|
||||
client
|
||||
.write_message_noflush(&Be::AuthenticationOk)?
|
||||
.write_message_noflush(&BeParameterStatusMessage::encoding())?;
|
||||
let mut config = wake_compute(endpoint).await?;
|
||||
if let Some(keys) = scram_keys {
|
||||
config.auth_keys(tokio_postgres::config::AuthKeys::ScramSha256(keys));
|
||||
}
|
||||
|
||||
Ok(compute::NodeInfo {
|
||||
db_info: wake_compute(endpoint).await?,
|
||||
scram_keys,
|
||||
reported_auth_ok: false,
|
||||
config,
|
||||
})
|
||||
}
|
||||
|
||||
/// Upcast (almost) any error into an opaque [`io::Error`].
|
||||
pub(super) fn io_error(e: impl Into<Box<dyn std::error::Error + Send + Sync>>) -> io::Error {
|
||||
io::Error::new(io::ErrorKind::Other, e)
|
||||
}
|
||||
|
||||
fn parse_host_port(input: &str) -> Option<(String, u16)> {
|
||||
fn parse_host_port(input: &str) -> Option<(&str, u16)> {
|
||||
let (host, port) = input.split_once(':')?;
|
||||
Some((host.to_owned(), port.parse().ok()?))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use serde_json::json;
|
||||
|
||||
#[test]
|
||||
fn parse_db_info() -> anyhow::Result<()> {
|
||||
let _: DatabaseInfo = serde_json::from_value(json!({
|
||||
"host": "localhost",
|
||||
"port": 5432,
|
||||
"dbname": "postgres",
|
||||
"user": "john_doe",
|
||||
"password": "password",
|
||||
}))?;
|
||||
|
||||
let _: DatabaseInfo = serde_json::from_value(json!({
|
||||
"host": "localhost",
|
||||
"port": 5432,
|
||||
"dbname": "postgres",
|
||||
"user": "john_doe",
|
||||
}))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Some((host, port.parse().ok()?))
|
||||
}
|
||||
|
||||
@@ -11,7 +11,7 @@ use crate::{
|
||||
use serde::{Deserialize, Serialize};
|
||||
use thiserror::Error;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use utils::pq_proto::{BeMessage as Be, BeParameterStatusMessage};
|
||||
use utils::pq_proto::BeMessage as Be;
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum AuthErrorImpl {
|
||||
@@ -76,6 +76,12 @@ enum ProxyAuthResponse {
|
||||
NotReady { ready: bool }, // TODO: get rid of `ready`
|
||||
}
|
||||
|
||||
impl ClientCredentials {
|
||||
fn is_existing_user(&self) -> bool {
|
||||
self.user.ends_with("@zenith")
|
||||
}
|
||||
}
|
||||
|
||||
async fn authenticate_proxy_client(
|
||||
auth_endpoint: &reqwest::Url,
|
||||
creds: &ClientCredentials,
|
||||
@@ -100,7 +106,7 @@ async fn authenticate_proxy_client(
|
||||
}
|
||||
|
||||
let auth_info: ProxyAuthResponse = serde_json::from_str(resp.text().await?.as_str())?;
|
||||
println!("got auth info: #{:?}", auth_info);
|
||||
println!("got auth info: {:?}", auth_info);
|
||||
|
||||
use ProxyAuthResponse::*;
|
||||
let db_info = match auth_info {
|
||||
@@ -128,7 +134,9 @@ async fn handle_existing_user(
|
||||
|
||||
// Read client's password hash
|
||||
let msg = client.read_password_message().await?;
|
||||
let md5_response = parse_password(&msg).ok_or(auth::AuthErrorImpl::MalformedPassword)?;
|
||||
let md5_response = parse_password(&msg).ok_or(auth::AuthErrorImpl::MalformedPassword(
|
||||
"the password should be a valid null-terminated utf-8 string",
|
||||
))?;
|
||||
|
||||
let db_info = authenticate_proxy_client(
|
||||
auth_endpoint,
|
||||
@@ -139,21 +147,17 @@ async fn handle_existing_user(
|
||||
)
|
||||
.await?;
|
||||
|
||||
client
|
||||
.write_message_noflush(&Be::AuthenticationOk)?
|
||||
.write_message_noflush(&BeParameterStatusMessage::encoding())?;
|
||||
|
||||
Ok(compute::NodeInfo {
|
||||
db_info,
|
||||
scram_keys: None,
|
||||
reported_auth_ok: false,
|
||||
config: db_info.into(),
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn handle_user(
|
||||
auth_endpoint: &reqwest::Url,
|
||||
auth_link_uri: &reqwest::Url,
|
||||
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin + Send>,
|
||||
creds: &ClientCredentials,
|
||||
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin + Send>,
|
||||
) -> auth::Result<compute::NodeInfo> {
|
||||
if creds.is_existing_user() {
|
||||
handle_existing_user(auth_endpoint, client, creds).await
|
||||
@@ -201,4 +205,24 @@ mod tests {
|
||||
.unwrap();
|
||||
assert!(matches!(auth, ProxyAuthResponse::NotReady { .. }));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_db_info() -> anyhow::Result<()> {
|
||||
let _: DatabaseInfo = serde_json::from_value(json!({
|
||||
"host": "localhost",
|
||||
"port": 5432,
|
||||
"dbname": "postgres",
|
||||
"user": "john_doe",
|
||||
"password": "password",
|
||||
}))?;
|
||||
|
||||
let _: DatabaseInfo = serde_json::from_value(json!({
|
||||
"host": "localhost",
|
||||
"port": 5432,
|
||||
"dbname": "postgres",
|
||||
"user": "john_doe",
|
||||
}))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -41,7 +41,7 @@ pub async fn handle_user(
|
||||
client.write_message_noflush(&Be::NoticeResponse("Connecting to database."))?;
|
||||
|
||||
Ok(compute::NodeInfo {
|
||||
db_info,
|
||||
scram_keys: None,
|
||||
reported_auth_ok: true,
|
||||
config: db_info.into(),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -3,10 +3,12 @@
|
||||
use crate::{
|
||||
auth::{
|
||||
self,
|
||||
backend::console::{self, io_error, AuthInfo, Result},
|
||||
ClientCredentials, DatabaseInfo,
|
||||
backend::console::{self, AuthInfo, Result},
|
||||
ClientCredentials,
|
||||
},
|
||||
compute, scram,
|
||||
compute::{self, ComputeConnCfg},
|
||||
error::io_error,
|
||||
scram,
|
||||
stream::PqStream,
|
||||
url::ApiUrl,
|
||||
};
|
||||
@@ -20,8 +22,8 @@ pub(super) struct Api<'a> {
|
||||
|
||||
impl<'a> Api<'a> {
|
||||
/// Construct an API object containing the auth parameters.
|
||||
pub(super) fn new(endpoint: &'a ApiUrl, creds: &'a ClientCredentials) -> Result<Self> {
|
||||
Ok(Self { endpoint, creds })
|
||||
pub(super) fn new(endpoint: &'a ApiUrl, creds: &'a ClientCredentials) -> Self {
|
||||
Self { endpoint, creds }
|
||||
}
|
||||
|
||||
/// Authenticate the existing user or throw an error.
|
||||
@@ -56,7 +58,10 @@ impl<'a> Api<'a> {
|
||||
|
||||
// We shouldn't get more than one row anyway.
|
||||
[row, ..] => {
|
||||
let entry = row.try_get(0).map_err(io_error)?;
|
||||
let entry = row
|
||||
.try_get("rolpassword")
|
||||
.map_err(|e| io_error(format!("failed to read user's password: {e}")))?;
|
||||
|
||||
scram::ServerSecret::parse(entry)
|
||||
.map(AuthInfo::Scram)
|
||||
.or_else(|| {
|
||||
@@ -75,14 +80,14 @@ impl<'a> Api<'a> {
|
||||
}
|
||||
|
||||
/// We don't need to wake anything locally, so we just return the connection info.
|
||||
async fn wake_compute(&self) -> Result<DatabaseInfo> {
|
||||
Ok(DatabaseInfo {
|
||||
// TODO: handle that near CLI params parsing
|
||||
host: self.endpoint.host_str().unwrap_or("localhost").to_owned(),
|
||||
port: self.endpoint.port().unwrap_or(5432),
|
||||
dbname: self.creds.dbname.to_owned(),
|
||||
user: self.creds.user.to_owned(),
|
||||
password: None,
|
||||
})
|
||||
pub(super) async fn wake_compute(&self) -> Result<ComputeConnCfg> {
|
||||
let mut config = ComputeConnCfg::new();
|
||||
config
|
||||
.host(self.endpoint.host_str().unwrap_or("localhost"))
|
||||
.port(self.endpoint.port().unwrap_or(5432))
|
||||
.dbname(&self.creds.dbname)
|
||||
.user(&self.creds.user);
|
||||
|
||||
Ok(config)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,39 +1,25 @@
|
||||
//! User credentials used in authentication.
|
||||
|
||||
use crate::compute;
|
||||
use crate::config::ProxyConfig;
|
||||
use crate::error::UserFacingError;
|
||||
use crate::stream::PqStream;
|
||||
use std::collections::HashMap;
|
||||
use thiserror::Error;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use utils::pq_proto::StartupMessageParams;
|
||||
|
||||
#[derive(Debug, Error, PartialEq, Eq, Clone)]
|
||||
pub enum ClientCredsParseError {
|
||||
#[error("Parameter `{0}` is missing in startup packet.")]
|
||||
#[error("Parameter '{0}' is missing in startup packet.")]
|
||||
MissingKey(&'static str),
|
||||
|
||||
#[error(
|
||||
"Project name is not specified. \
|
||||
EITHER please upgrade the postgres client library (libpq) for SNI support \
|
||||
OR pass the project name as a parameter: '&options=project%3D<project-name>'."
|
||||
)]
|
||||
MissingSNIAndProjectName,
|
||||
|
||||
#[error("Inconsistent project name inferred from SNI ('{0}') and project option ('{1}').")]
|
||||
InconsistentProjectNameAndSNI(String, String),
|
||||
|
||||
#[error("Common name is not set.")]
|
||||
CommonNameNotSet,
|
||||
InconsistentProjectNames(String, String),
|
||||
|
||||
#[error(
|
||||
"SNI ('{1}') inconsistently formatted with respect to common name ('{0}'). \
|
||||
SNI should be formatted as '<project-name>.<common-name>'."
|
||||
SNI should be formatted as '<project-name>.{0}'."
|
||||
)]
|
||||
InconsistentCommonNameAndSNI(String, String),
|
||||
InconsistentSni(String, String),
|
||||
|
||||
#[error("Project name ('{0}') must contain only alphanumeric characters and hyphens ('-').")]
|
||||
ProjectNameContainsIllegalChars(String),
|
||||
#[error("Project name ('{0}') must contain only alphanumeric characters and hyphen.")]
|
||||
MalformedProjectName(String),
|
||||
}
|
||||
|
||||
impl UserFacingError for ClientCredsParseError {}
|
||||
@@ -44,286 +30,171 @@ impl UserFacingError for ClientCredsParseError {}
|
||||
pub struct ClientCredentials {
|
||||
pub user: String,
|
||||
pub dbname: String,
|
||||
pub project_name: Result<String, ClientCredsParseError>,
|
||||
pub project: Option<String>,
|
||||
}
|
||||
|
||||
impl ClientCredentials {
|
||||
pub fn is_existing_user(&self) -> bool {
|
||||
// This logic will likely change in the future.
|
||||
self.user.ends_with("@zenith")
|
||||
pub fn project(&self) -> Option<&str> {
|
||||
self.project.as_deref()
|
||||
}
|
||||
}
|
||||
|
||||
impl ClientCredentials {
|
||||
pub fn parse(
|
||||
mut options: HashMap<String, String>,
|
||||
sni_data: Option<&str>,
|
||||
mut options: StartupMessageParams,
|
||||
sni: Option<&str>,
|
||||
common_name: Option<&str>,
|
||||
) -> Result<Self, ClientCredsParseError> {
|
||||
let mut get_param = |key| {
|
||||
options
|
||||
.remove(key)
|
||||
.ok_or(ClientCredsParseError::MissingKey(key))
|
||||
};
|
||||
use ClientCredsParseError::*;
|
||||
|
||||
// Some parameters are absolutely necessary, others not so much.
|
||||
let mut get_param = |key| options.remove(key).ok_or(MissingKey(key));
|
||||
|
||||
// Some parameters are stored in the startup message.
|
||||
let user = get_param("user")?;
|
||||
let dbname = get_param("database")?;
|
||||
let project_name = get_param("project").ok();
|
||||
let project_name = get_project_name(sni_data, common_name, project_name.as_deref());
|
||||
let project_a = get_param("project").ok();
|
||||
|
||||
// Alternative project name is in fact a subdomain from SNI.
|
||||
// NOTE: we do not consider SNI if `common_name` is missing.
|
||||
let project_b = sni
|
||||
.zip(common_name)
|
||||
.map(|(sni, cn)| {
|
||||
// TODO: what if SNI is present but just a common name?
|
||||
subdomain_from_sni(sni, cn)
|
||||
.ok_or_else(|| InconsistentSni(sni.to_owned(), cn.to_owned()))
|
||||
})
|
||||
.transpose()?;
|
||||
|
||||
let project = match (project_a, project_b) {
|
||||
// Invariant: if we have both project name variants, they should match.
|
||||
(Some(a), Some(b)) if a != b => Some(Err(InconsistentProjectNames(a, b))),
|
||||
(a, b) => a.or(b).map(|name| {
|
||||
// Invariant: project name may not contain certain characters.
|
||||
check_project_name(name).map_err(MalformedProjectName)
|
||||
}),
|
||||
}
|
||||
.transpose()?;
|
||||
|
||||
Ok(Self {
|
||||
user,
|
||||
dbname,
|
||||
project_name,
|
||||
project,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Use credentials to authenticate the user.
|
||||
pub async fn authenticate(
|
||||
self,
|
||||
config: &ProxyConfig,
|
||||
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin + Send>,
|
||||
) -> super::Result<compute::NodeInfo> {
|
||||
// This method is just a convenient facade for `handle_user`
|
||||
super::backend::handle_user(config, client, self).await
|
||||
fn check_project_name(name: String) -> Result<String, String> {
|
||||
if name.chars().all(|c| c.is_alphanumeric() || c == '-') {
|
||||
Ok(name)
|
||||
} else {
|
||||
Err(name)
|
||||
}
|
||||
}
|
||||
|
||||
/// Inferring project name from sni_data.
|
||||
fn project_name_from_sni_data(
|
||||
sni_data: &str,
|
||||
common_name: &str,
|
||||
) -> Result<String, ClientCredsParseError> {
|
||||
let common_name_with_dot = format!(".{common_name}");
|
||||
// check that ".{common_name_with_dot}" is the actual suffix in sni_data
|
||||
if !sni_data.ends_with(&common_name_with_dot) {
|
||||
return Err(ClientCredsParseError::InconsistentCommonNameAndSNI(
|
||||
common_name.to_string(),
|
||||
sni_data.to_string(),
|
||||
fn subdomain_from_sni(sni: &str, common_name: &str) -> Option<String> {
|
||||
sni.strip_suffix(common_name)?
|
||||
.strip_suffix('.')
|
||||
.map(str::to_owned)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn make_options<'a, const N: usize>(pairs: [(&'a str, &'a str); N]) -> StartupMessageParams {
|
||||
StartupMessageParams::from(pairs.map(|(k, v)| (k.to_owned(), v.to_owned())))
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[ignore = "TODO: fix how database is handled"]
|
||||
fn parse_bare_minimum() -> anyhow::Result<()> {
|
||||
// According to postgresql, only `user` should be required.
|
||||
let options = make_options([("user", "john_doe")]);
|
||||
|
||||
// TODO: check that `creds.dbname` is None.
|
||||
let creds = ClientCredentials::parse(options, None, None)?;
|
||||
assert_eq!(creds.user, "john_doe");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_missing_project() -> anyhow::Result<()> {
|
||||
let options = make_options([("user", "john_doe"), ("database", "world")]);
|
||||
|
||||
let creds = ClientCredentials::parse(options, None, None)?;
|
||||
assert_eq!(creds.user, "john_doe");
|
||||
assert_eq!(creds.dbname, "world");
|
||||
assert_eq!(creds.project, None);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_project_from_sni() -> anyhow::Result<()> {
|
||||
let options = make_options([("user", "john_doe"), ("database", "world")]);
|
||||
|
||||
let sni = Some("foo.localhost");
|
||||
let common_name = Some("localhost");
|
||||
|
||||
let creds = ClientCredentials::parse(options, sni, common_name)?;
|
||||
assert_eq!(creds.user, "john_doe");
|
||||
assert_eq!(creds.dbname, "world");
|
||||
assert_eq!(creds.project.as_deref(), Some("foo"));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_project_from_options() -> anyhow::Result<()> {
|
||||
let options = make_options([
|
||||
("user", "john_doe"),
|
||||
("database", "world"),
|
||||
("project", "bar"),
|
||||
]);
|
||||
|
||||
let creds = ClientCredentials::parse(options, None, None)?;
|
||||
assert_eq!(creds.user, "john_doe");
|
||||
assert_eq!(creds.dbname, "world");
|
||||
assert_eq!(creds.project.as_deref(), Some("bar"));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_projects_identical() -> anyhow::Result<()> {
|
||||
let options = make_options([
|
||||
("user", "john_doe"),
|
||||
("database", "world"),
|
||||
("project", "baz"),
|
||||
]);
|
||||
|
||||
let sni = Some("baz.localhost");
|
||||
let common_name = Some("localhost");
|
||||
|
||||
let creds = ClientCredentials::parse(options, sni, common_name)?;
|
||||
assert_eq!(creds.user, "john_doe");
|
||||
assert_eq!(creds.dbname, "world");
|
||||
assert_eq!(creds.project.as_deref(), Some("baz"));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_projects_different() {
|
||||
let options = make_options([
|
||||
("user", "john_doe"),
|
||||
("database", "world"),
|
||||
("project", "first"),
|
||||
]);
|
||||
|
||||
let sni = Some("second.localhost");
|
||||
let common_name = Some("localhost");
|
||||
|
||||
assert!(matches!(
|
||||
ClientCredentials::parse(options, sni, common_name).expect_err("should fail"),
|
||||
ClientCredsParseError::InconsistentProjectNames(_, _)
|
||||
));
|
||||
}
|
||||
// return sni_data without the common name suffix.
|
||||
Ok(sni_data
|
||||
.strip_suffix(&common_name_with_dot)
|
||||
.unwrap()
|
||||
.to_string())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests_for_project_name_from_sni_data {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn passing() {
|
||||
let target_project_name = "my-project-123";
|
||||
let common_name = "localtest.me";
|
||||
let sni_data = format!("{target_project_name}.{common_name}");
|
||||
assert_eq!(
|
||||
project_name_from_sni_data(&sni_data, common_name),
|
||||
Ok(target_project_name.to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn throws_inconsistent_common_name_and_sni_data() {
|
||||
let target_project_name = "my-project-123";
|
||||
let common_name = "localtest.me";
|
||||
let wrong_suffix = "wrongtest.me";
|
||||
assert_eq!(common_name.len(), wrong_suffix.len());
|
||||
let wrong_common_name = format!("wrong{wrong_suffix}");
|
||||
let sni_data = format!("{target_project_name}.{wrong_common_name}");
|
||||
assert_eq!(
|
||||
project_name_from_sni_data(&sni_data, common_name),
|
||||
Err(ClientCredsParseError::InconsistentCommonNameAndSNI(
|
||||
common_name.to_string(),
|
||||
sni_data
|
||||
))
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/// Determine project name from SNI or from project_name parameter from options argument.
|
||||
fn get_project_name(
|
||||
sni_data: Option<&str>,
|
||||
common_name: Option<&str>,
|
||||
project_name: Option<&str>,
|
||||
) -> Result<String, ClientCredsParseError> {
|
||||
// determine the project name from sni_data if it exists, otherwise from project_name.
|
||||
let ret = match sni_data {
|
||||
Some(sni_data) => {
|
||||
let common_name = common_name.ok_or(ClientCredsParseError::CommonNameNotSet)?;
|
||||
let project_name_from_sni = project_name_from_sni_data(sni_data, common_name)?;
|
||||
// check invariant: project name from options and from sni should match
|
||||
if let Some(project_name) = &project_name {
|
||||
if !project_name_from_sni.eq(project_name) {
|
||||
return Err(ClientCredsParseError::InconsistentProjectNameAndSNI(
|
||||
project_name_from_sni,
|
||||
project_name.to_string(),
|
||||
));
|
||||
}
|
||||
}
|
||||
project_name_from_sni
|
||||
}
|
||||
None => project_name
|
||||
.ok_or(ClientCredsParseError::MissingSNIAndProjectName)?
|
||||
.to_string(),
|
||||
};
|
||||
|
||||
// check formatting invariant: project name must contain only alphanumeric characters and hyphens.
|
||||
if !ret.chars().all(|x: char| x.is_alphanumeric() || x == '-') {
|
||||
return Err(ClientCredsParseError::ProjectNameContainsIllegalChars(ret));
|
||||
}
|
||||
|
||||
Ok(ret)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests_for_project_name_only {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn passing_from_sni_data_only() {
|
||||
let target_project_name = "my-project-123";
|
||||
let common_name = "localtest.me";
|
||||
let sni_data = format!("{target_project_name}.{common_name}");
|
||||
assert_eq!(
|
||||
get_project_name(Some(&sni_data), Some(common_name), None),
|
||||
Ok(target_project_name.to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn throws_project_name_contains_illegal_chars_from_sni_data_only() {
|
||||
let project_name_prefix = "my-project";
|
||||
let project_name_suffix = "123";
|
||||
let common_name = "localtest.me";
|
||||
|
||||
for illegal_char_id in 0..256 {
|
||||
let illegal_char = char::from_u32(illegal_char_id).unwrap();
|
||||
if !(illegal_char.is_alphanumeric() || illegal_char == '-')
|
||||
&& illegal_char.to_string().len() == 1
|
||||
{
|
||||
let target_project_name =
|
||||
format!("{project_name_prefix}{illegal_char}{project_name_suffix}");
|
||||
let sni_data = format!("{target_project_name}.{common_name}");
|
||||
assert_eq!(
|
||||
get_project_name(Some(&sni_data), Some(common_name), None),
|
||||
Err(ClientCredsParseError::ProjectNameContainsIllegalChars(
|
||||
target_project_name
|
||||
))
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn passing_from_project_name_only() {
|
||||
let target_project_name = "my-project-123";
|
||||
let common_names = [Some("localtest.me"), None];
|
||||
for common_name in common_names {
|
||||
assert_eq!(
|
||||
get_project_name(None, common_name, Some(target_project_name)),
|
||||
Ok(target_project_name.to_string())
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn throws_project_name_contains_illegal_chars_from_project_name_only() {
|
||||
let project_name_prefix = "my-project";
|
||||
let project_name_suffix = "123";
|
||||
let common_names = [Some("localtest.me"), None];
|
||||
|
||||
for common_name in common_names {
|
||||
for illegal_char_id in 0..256 {
|
||||
let illegal_char: char = char::from_u32(illegal_char_id).unwrap();
|
||||
if !(illegal_char.is_alphanumeric() || illegal_char == '-')
|
||||
&& illegal_char.to_string().len() == 1
|
||||
{
|
||||
let target_project_name =
|
||||
format!("{project_name_prefix}{illegal_char}{project_name_suffix}");
|
||||
assert_eq!(
|
||||
get_project_name(None, common_name, Some(&target_project_name)),
|
||||
Err(ClientCredsParseError::ProjectNameContainsIllegalChars(
|
||||
target_project_name
|
||||
))
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn passing_from_sni_data_and_project_name() {
|
||||
let target_project_name = "my-project-123";
|
||||
let common_name = "localtest.me";
|
||||
let sni_data = format!("{target_project_name}.{common_name}");
|
||||
assert_eq!(
|
||||
get_project_name(
|
||||
Some(&sni_data),
|
||||
Some(common_name),
|
||||
Some(target_project_name)
|
||||
),
|
||||
Ok(target_project_name.to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn throws_inconsistent_project_name_and_sni() {
|
||||
let project_name_param = "my-project-123";
|
||||
let wrong_project_name = "not-my-project-123";
|
||||
let common_name = "localtest.me";
|
||||
let sni_data = format!("{wrong_project_name}.{common_name}");
|
||||
assert_eq!(
|
||||
get_project_name(Some(&sni_data), Some(common_name), Some(project_name_param)),
|
||||
Err(ClientCredsParseError::InconsistentProjectNameAndSNI(
|
||||
wrong_project_name.to_string(),
|
||||
project_name_param.to_string()
|
||||
))
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn throws_common_name_not_set() {
|
||||
let target_project_name = "my-project-123";
|
||||
let wrong_project_name = "not-my-project-123";
|
||||
let common_name = "localtest.me";
|
||||
let sni_datas = [
|
||||
Some(format!("{wrong_project_name}.{common_name}")),
|
||||
Some(format!("{target_project_name}.{common_name}")),
|
||||
];
|
||||
let project_names = [None, Some(target_project_name)];
|
||||
for sni_data in sni_datas {
|
||||
for project_name_param in project_names {
|
||||
assert_eq!(
|
||||
get_project_name(sni_data.as_deref(), None, project_name_param),
|
||||
Err(ClientCredsParseError::CommonNameNotSet)
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn throws_inconsistent_common_name_and_sni_data() {
|
||||
let target_project_name = "my-project-123";
|
||||
let wrong_project_name = "not-my-project-123";
|
||||
let common_name = "localtest.me";
|
||||
let wrong_suffix = "wrongtest.me";
|
||||
assert_eq!(common_name.len(), wrong_suffix.len());
|
||||
let wrong_common_name = format!("wrong{wrong_suffix}");
|
||||
let sni_datas = [
|
||||
Some(format!("{wrong_project_name}.{wrong_common_name}")),
|
||||
Some(format!("{target_project_name}.{wrong_common_name}")),
|
||||
];
|
||||
let project_names = [None, Some(target_project_name)];
|
||||
for project_name_param in project_names {
|
||||
for sni_data in &sni_datas {
|
||||
assert_eq!(
|
||||
get_project_name(sni_data.as_deref(), Some(common_name), project_name_param),
|
||||
Err(ClientCredsParseError::InconsistentCommonNameAndSNI(
|
||||
common_name.to_string(),
|
||||
sni_data.clone().unwrap().to_string()
|
||||
))
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
//! Main authentication flow.
|
||||
|
||||
use super::AuthErrorImpl;
|
||||
use crate::stream::PqStream;
|
||||
use crate::{sasl, scram};
|
||||
use super::{AuthErrorImpl, PasswordHackPayload};
|
||||
use crate::{sasl, scram, stream::PqStream};
|
||||
use std::io;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use utils::pq_proto::{BeAuthenticationSaslMessage, BeMessage, BeMessage as Be};
|
||||
@@ -27,6 +26,17 @@ impl AuthMethod for Scram<'_> {
|
||||
}
|
||||
}
|
||||
|
||||
/// Use an ad hoc auth flow (for clients which don't support SNI) proposed in
|
||||
/// <https://github.com/neondatabase/cloud/issues/1620#issuecomment-1165332290>.
|
||||
pub struct PasswordHack;
|
||||
|
||||
impl AuthMethod for PasswordHack {
|
||||
#[inline(always)]
|
||||
fn first_message(&self) -> BeMessage<'_> {
|
||||
Be::AuthenticationCleartextPassword
|
||||
}
|
||||
}
|
||||
|
||||
/// This wrapper for [`PqStream`] performs client authentication.
|
||||
#[must_use]
|
||||
pub struct AuthFlow<'a, Stream, State> {
|
||||
@@ -57,13 +67,34 @@ impl<'a, S: AsyncWrite + Unpin> AuthFlow<'a, S, Begin> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, PasswordHack> {
|
||||
/// Perform user authentication. Raise an error in case authentication failed.
|
||||
pub async fn authenticate(self) -> super::Result<PasswordHackPayload> {
|
||||
let msg = self.stream.read_password_message().await?;
|
||||
let password = msg
|
||||
.strip_suffix(&[0])
|
||||
.ok_or(AuthErrorImpl::MalformedPassword("missing terminator"))?;
|
||||
|
||||
// The so-called "password" should contain a base64-encoded json.
|
||||
// We will use it later to route the client to their project.
|
||||
let bytes = base64::decode(password)
|
||||
.map_err(|_| AuthErrorImpl::MalformedPassword("bad encoding"))?;
|
||||
|
||||
let payload = serde_json::from_slice(&bytes)
|
||||
.map_err(|_| AuthErrorImpl::MalformedPassword("invalid payload"))?;
|
||||
|
||||
Ok(payload)
|
||||
}
|
||||
}
|
||||
|
||||
/// Stream wrapper for handling [SCRAM](crate::scram) auth.
|
||||
impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, Scram<'_>> {
|
||||
/// Perform user authentication. Raise an error in case authentication failed.
|
||||
pub async fn authenticate(self) -> super::Result<scram::ScramKey> {
|
||||
// Initial client message contains the chosen auth method's name.
|
||||
let msg = self.stream.read_password_message().await?;
|
||||
let sasl = sasl::FirstMessage::parse(&msg).ok_or(AuthErrorImpl::MalformedPassword)?;
|
||||
let sasl = sasl::FirstMessage::parse(&msg)
|
||||
.ok_or(AuthErrorImpl::MalformedPassword("bad sasl message"))?;
|
||||
|
||||
// Currently, the only supported SASL method is SCRAM.
|
||||
if !scram::METHODS.contains(&sasl.method) {
|
||||
|
||||
102
proxy/src/auth/password_hack.rs
Normal file
102
proxy/src/auth/password_hack.rs
Normal file
@@ -0,0 +1,102 @@
|
||||
//! Payload for ad hoc authentication method for clients that don't support SNI.
|
||||
//! See the `impl` for [`super::backend::BackendType<ClientCredentials>`].
|
||||
//! Read more: <https://github.com/neondatabase/cloud/issues/1620#issuecomment-1165332290>.
|
||||
|
||||
use serde::{de, Deserialize, Deserializer};
|
||||
use std::fmt;
|
||||
|
||||
#[derive(Deserialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum Password {
|
||||
/// A regular string for utf-8 encoded passwords.
|
||||
Simple { password: String },
|
||||
|
||||
/// Password is base64-encoded because it may contain arbitrary byte sequences.
|
||||
Encoded {
|
||||
#[serde(rename = "password_", deserialize_with = "deserialize_base64")]
|
||||
password: Vec<u8>,
|
||||
},
|
||||
}
|
||||
|
||||
impl AsRef<[u8]> for Password {
|
||||
fn as_ref(&self) -> &[u8] {
|
||||
match self {
|
||||
Password::Simple { password } => password.as_ref(),
|
||||
Password::Encoded { password } => password.as_ref(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct PasswordHackPayload {
|
||||
pub project: String,
|
||||
|
||||
#[serde(flatten)]
|
||||
pub password: Password,
|
||||
}
|
||||
|
||||
fn deserialize_base64<'a, D: Deserializer<'a>>(des: D) -> Result<Vec<u8>, D::Error> {
|
||||
// It's very tempting to replace this with
|
||||
//
|
||||
// ```
|
||||
// let base64: &str = Deserialize::deserialize(des)?;
|
||||
// base64::decode(base64).map_err(serde::de::Error::custom)
|
||||
// ```
|
||||
//
|
||||
// Unfortunately, we can't always deserialize into `&str`, so we'd
|
||||
// have to use an allocating `String` instead. Thus, visitor is better.
|
||||
struct Visitor;
|
||||
|
||||
impl<'de> de::Visitor<'de> for Visitor {
|
||||
type Value = Vec<u8>;
|
||||
|
||||
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
|
||||
formatter.write_str("a string")
|
||||
}
|
||||
|
||||
fn visit_str<E: de::Error>(self, v: &str) -> Result<Self::Value, E> {
|
||||
base64::decode(v).map_err(de::Error::custom)
|
||||
}
|
||||
}
|
||||
|
||||
des.deserialize_str(Visitor)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use rstest::rstest;
|
||||
use serde_json::json;
|
||||
|
||||
#[test]
|
||||
fn parse_password() -> anyhow::Result<()> {
|
||||
let password: Password = serde_json::from_value(json!({
|
||||
"password": "foo",
|
||||
}))?;
|
||||
assert_eq!(password.as_ref(), "foo".as_bytes());
|
||||
|
||||
let password: Password = serde_json::from_value(json!({
|
||||
"password_": base64::encode("foo"),
|
||||
}))?;
|
||||
assert_eq!(password.as_ref(), "foo".as_bytes());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[rstest]
|
||||
#[case("password", str::to_owned)]
|
||||
#[case("password_", base64::encode)]
|
||||
fn parse(#[case] key: &str, #[case] encode: fn(&'static str) -> String) -> anyhow::Result<()> {
|
||||
let (password, project) = ("password", "pie-in-the-sky");
|
||||
let payload = json!({
|
||||
"project": project,
|
||||
key: encode(password),
|
||||
});
|
||||
|
||||
let payload: PasswordHackPayload = serde_json::from_value(payload)?;
|
||||
assert_eq!(payload.password.as_ref(), password.as_bytes());
|
||||
assert_eq!(payload.project, project);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -1,8 +1,6 @@
|
||||
use crate::auth::DatabaseInfo;
|
||||
use crate::cancellation::CancelClosure;
|
||||
use crate::error::UserFacingError;
|
||||
use std::io;
|
||||
use std::net::SocketAddr;
|
||||
use crate::{cancellation::CancelClosure, error::UserFacingError};
|
||||
use futures::TryFutureExt;
|
||||
use std::{io, net::SocketAddr};
|
||||
use thiserror::Error;
|
||||
use tokio::net::TcpStream;
|
||||
use tokio_postgres::NoTls;
|
||||
@@ -21,44 +19,96 @@ pub enum ConnectionError {
|
||||
FailedToFetchPgVersion,
|
||||
}
|
||||
|
||||
impl UserFacingError for ConnectionError {}
|
||||
|
||||
/// PostgreSQL version as [`String`].
|
||||
pub type Version = String;
|
||||
impl UserFacingError for ConnectionError {
|
||||
fn to_string_client(&self) -> String {
|
||||
use ConnectionError::*;
|
||||
match self {
|
||||
// This helps us drop irrelevant library-specific prefixes.
|
||||
// TODO: propagate severity level and other parameters.
|
||||
Postgres(err) => match err.as_db_error() {
|
||||
Some(err) => err.message().to_string(),
|
||||
None => err.to_string(),
|
||||
},
|
||||
other => other.to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A pair of `ClientKey` & `ServerKey` for `SCRAM-SHA-256`.
|
||||
pub type ScramKeys = tokio_postgres::config::ScramKeys<32>;
|
||||
|
||||
/// Compute node connection params.
|
||||
pub type ComputeConnCfg = tokio_postgres::Config;
|
||||
|
||||
/// Various compute node info for establishing connection etc.
|
||||
pub struct NodeInfo {
|
||||
pub db_info: DatabaseInfo,
|
||||
pub scram_keys: Option<ScramKeys>,
|
||||
/// Did we send [`utils::pq_proto::BeMessage::AuthenticationOk`]?
|
||||
pub reported_auth_ok: bool,
|
||||
/// Compute node connection params.
|
||||
pub config: tokio_postgres::Config,
|
||||
}
|
||||
|
||||
impl NodeInfo {
|
||||
async fn connect_raw(&self) -> io::Result<(SocketAddr, TcpStream)> {
|
||||
let host_port = (self.db_info.host.as_str(), self.db_info.port);
|
||||
let socket = TcpStream::connect(host_port).await?;
|
||||
let socket_addr = socket.peer_addr()?;
|
||||
socket2::SockRef::from(&socket).set_keepalive(true)?;
|
||||
use tokio_postgres::config::Host;
|
||||
|
||||
Ok((socket_addr, socket))
|
||||
let connect_once = |host, port| {
|
||||
TcpStream::connect((host, port)).and_then(|socket| async {
|
||||
let socket_addr = socket.peer_addr()?;
|
||||
// This prevents load balancer from severing the connection.
|
||||
socket2::SockRef::from(&socket).set_keepalive(true)?;
|
||||
Ok((socket_addr, socket))
|
||||
})
|
||||
};
|
||||
|
||||
// We can't reuse connection establishing logic from `tokio_postgres` here,
|
||||
// because it has no means for extracting the underlying socket which we
|
||||
// require for our business.
|
||||
let mut connection_error = None;
|
||||
let ports = self.config.get_ports();
|
||||
for (i, host) in self.config.get_hosts().iter().enumerate() {
|
||||
let port = ports.get(i).or_else(|| ports.get(0)).unwrap_or(&5432);
|
||||
let host = match host {
|
||||
Host::Tcp(host) => host.as_str(),
|
||||
Host::Unix(_) => continue, // unix sockets are not welcome here
|
||||
};
|
||||
|
||||
// TODO: maybe we should add a timeout.
|
||||
match connect_once(host, *port).await {
|
||||
Ok(socket) => return Ok(socket),
|
||||
Err(err) => {
|
||||
// We can't throw an error here, as there might be more hosts to try.
|
||||
println!("failed to connect to compute `{host}:{port}`: {err}");
|
||||
connection_error = Some(err);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Err(connection_error.unwrap_or_else(|| {
|
||||
io::Error::new(
|
||||
io::ErrorKind::Other,
|
||||
format!("couldn't connect: bad compute config: {:?}", self.config),
|
||||
)
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
pub struct PostgresConnection {
|
||||
/// Socket connected to a compute node.
|
||||
pub stream: TcpStream,
|
||||
/// PostgreSQL version of this instance.
|
||||
pub version: String,
|
||||
}
|
||||
|
||||
impl NodeInfo {
|
||||
/// Connect to a corresponding compute node.
|
||||
pub async fn connect(self) -> Result<(TcpStream, Version, CancelClosure), ConnectionError> {
|
||||
let (socket_addr, mut socket) = self
|
||||
pub async fn connect(&self) -> Result<(PostgresConnection, CancelClosure), ConnectionError> {
|
||||
let (socket_addr, mut stream) = self
|
||||
.connect_raw()
|
||||
.await
|
||||
.map_err(|_| ConnectionError::FailedToConnectToCompute)?;
|
||||
|
||||
let mut config = tokio_postgres::Config::from(self.db_info);
|
||||
if let Some(scram_keys) = self.scram_keys {
|
||||
config.auth_keys(tokio_postgres::config::AuthKeys::ScramSha256(scram_keys));
|
||||
}
|
||||
|
||||
// TODO: establish a secure connection to the DB
|
||||
let (client, conn) = config.connect_raw(&mut socket, NoTls).await?;
|
||||
let (client, conn) = self.config.connect_raw(&mut stream, NoTls).await?;
|
||||
let version = conn
|
||||
.parameter("server_version")
|
||||
.ok_or(ConnectionError::FailedToFetchPgVersion)?
|
||||
@@ -66,6 +116,8 @@ impl NodeInfo {
|
||||
|
||||
let cancel_closure = CancelClosure::new(socket_addr, client.cancel_token());
|
||||
|
||||
Ok((socket, version, cancel_closure))
|
||||
let db = PostgresConnection { stream, version };
|
||||
|
||||
Ok((db, cancel_closure))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,28 +1,16 @@
|
||||
use crate::url::ApiUrl;
|
||||
use crate::{auth, url::ApiUrl};
|
||||
use anyhow::{bail, ensure, Context};
|
||||
use std::{str::FromStr, sync::Arc};
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum AuthBackendType {
|
||||
/// Legacy Cloud API (V1).
|
||||
LegacyConsole,
|
||||
/// Authentication via a web browser.
|
||||
Link,
|
||||
/// Current Cloud API (V2).
|
||||
Console,
|
||||
/// Local mock of Cloud API (V2).
|
||||
Postgres,
|
||||
}
|
||||
|
||||
impl FromStr for AuthBackendType {
|
||||
impl FromStr for auth::BackendType<()> {
|
||||
type Err = anyhow::Error;
|
||||
|
||||
fn from_str(s: &str) -> anyhow::Result<Self> {
|
||||
use AuthBackendType::*;
|
||||
use auth::BackendType::*;
|
||||
Ok(match s {
|
||||
"legacy" => LegacyConsole,
|
||||
"console" => Console,
|
||||
"postgres" => Postgres,
|
||||
"legacy" => LegacyConsole(()),
|
||||
"console" => Console(()),
|
||||
"postgres" => Postgres(()),
|
||||
"link" => Link,
|
||||
_ => bail!("Invalid option `{s}` for auth method"),
|
||||
})
|
||||
@@ -31,7 +19,11 @@ impl FromStr for AuthBackendType {
|
||||
|
||||
pub struct ProxyConfig {
|
||||
pub tls_config: Option<TlsConfig>,
|
||||
pub auth_backend: AuthBackendType,
|
||||
pub auth_backend: auth::BackendType<()>,
|
||||
pub auth_urls: AuthUrls,
|
||||
}
|
||||
|
||||
pub struct AuthUrls {
|
||||
pub auth_endpoint: ApiUrl,
|
||||
pub auth_link_uri: ApiUrl,
|
||||
}
|
||||
@@ -87,10 +79,8 @@ pub fn configure_tls(key_path: &str, cert_path: &str) -> anyhow::Result<TlsConfi
|
||||
"Failed to parse PEM object from bytes from file at '{cert_path}'."
|
||||
))?
|
||||
.1;
|
||||
let almost_common_name = pem.parse_x509()?.tbs_certificate.subject.to_string();
|
||||
let expected_prefix = "CN=*.";
|
||||
let common_name = almost_common_name.strip_prefix(expected_prefix);
|
||||
common_name.map(str::to_string)
|
||||
let common_name = pem.parse_x509()?.subject().to_string();
|
||||
common_name.strip_prefix("CN=*.").map(|s| s.to_string())
|
||||
};
|
||||
|
||||
Ok(TlsConfig {
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
use std::io;
|
||||
|
||||
/// Marks errors that may be safely shown to a client.
|
||||
/// This trait can be seen as a specialized version of [`ToString`].
|
||||
///
|
||||
@@ -15,3 +17,8 @@ pub trait UserFacingError: ToString {
|
||||
self.to_string()
|
||||
}
|
||||
}
|
||||
|
||||
/// Upcast (almost) any error into an opaque [`io::Error`].
|
||||
pub fn io_error(e: impl Into<Box<dyn std::error::Error + Send + Sync>>) -> io::Error {
|
||||
io::Error::new(io::ErrorKind::Other, e)
|
||||
}
|
||||
|
||||
@@ -118,11 +118,15 @@ async fn main() -> anyhow::Result<()> {
|
||||
let mgmt_address: SocketAddr = arg_matches.value_of("mgmt").unwrap().parse()?;
|
||||
let http_address: SocketAddr = arg_matches.value_of("http").unwrap().parse()?;
|
||||
|
||||
let auth_urls = config::AuthUrls {
|
||||
auth_endpoint: arg_matches.value_of("auth-endpoint").unwrap().parse()?,
|
||||
auth_link_uri: arg_matches.value_of("uri").unwrap().parse()?,
|
||||
};
|
||||
|
||||
let config: &ProxyConfig = Box::leak(Box::new(ProxyConfig {
|
||||
tls_config,
|
||||
auth_backend: arg_matches.value_of("auth-backend").unwrap().parse()?,
|
||||
auth_endpoint: arg_matches.value_of("auth-endpoint").unwrap().parse()?,
|
||||
auth_link_uri: arg_matches.value_of("uri").unwrap().parse()?,
|
||||
auth_urls,
|
||||
}));
|
||||
|
||||
println!("Version: {GIT_VERSION}");
|
||||
|
||||
@@ -82,11 +82,22 @@ async fn handle_client(
|
||||
}
|
||||
|
||||
let tls = config.tls_config.as_ref();
|
||||
let (stream, creds) = match handshake(stream, tls, cancel_map).await? {
|
||||
let (mut stream, params) = match handshake(stream, tls, cancel_map).await? {
|
||||
Some(x) => x,
|
||||
None => return Ok(()), // it's a cancellation request
|
||||
};
|
||||
|
||||
let creds = {
|
||||
let sni = stream.get_ref().sni_hostname();
|
||||
let common_name = tls.and_then(|tls| tls.common_name.as_deref());
|
||||
let result = config
|
||||
.auth_backend
|
||||
.map(|_| auth::ClientCredentials::parse(params, sni, common_name))
|
||||
.transpose();
|
||||
|
||||
async { result }.or_else(|e| stream.throw_error(e)).await?
|
||||
};
|
||||
|
||||
let client = Client::new(stream, creds);
|
||||
cancel_map
|
||||
.with_session(|session| client.connect_to_db(config, session))
|
||||
@@ -101,12 +112,10 @@ async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
stream: S,
|
||||
mut tls: Option<&TlsConfig>,
|
||||
cancel_map: &CancelMap,
|
||||
) -> anyhow::Result<Option<(PqStream<Stream<S>>, auth::ClientCredentials)>> {
|
||||
) -> anyhow::Result<Option<(PqStream<Stream<S>>, StartupMessageParams)>> {
|
||||
// Client may try upgrading to each protocol only once
|
||||
let (mut tried_ssl, mut tried_gss) = (false, false);
|
||||
|
||||
let common_name = tls.and_then(|cfg| cfg.common_name.as_deref());
|
||||
|
||||
let mut stream = PqStream::new(Stream::from_raw(stream));
|
||||
loop {
|
||||
let msg = stream.read_startup_packet().await?;
|
||||
@@ -147,18 +156,7 @@ async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
stream.throw_error_str(ERR_INSECURE_CONNECTION).await?;
|
||||
}
|
||||
|
||||
// Get SNI info when available
|
||||
let sni_data = match stream.get_ref() {
|
||||
Stream::Tls { tls } => tls.get_ref().1.sni_hostname().map(|s| s.to_owned()),
|
||||
_ => None,
|
||||
};
|
||||
|
||||
// Construct credentials
|
||||
let creds =
|
||||
auth::ClientCredentials::parse(params, sni_data.as_deref(), common_name);
|
||||
let creds = async { creds }.or_else(|e| stream.throw_error(e)).await?;
|
||||
|
||||
break Ok(Some((stream, creds)));
|
||||
break Ok(Some((stream, params)));
|
||||
}
|
||||
CancelRequest(cancel_key_data) => {
|
||||
cancel_map.cancel_session(cancel_key_data).await?;
|
||||
@@ -174,12 +172,12 @@ struct Client<S> {
|
||||
/// The underlying libpq protocol stream.
|
||||
stream: PqStream<S>,
|
||||
/// Client credentials that we care about.
|
||||
creds: auth::ClientCredentials,
|
||||
creds: auth::BackendType<auth::ClientCredentials>,
|
||||
}
|
||||
|
||||
impl<S> Client<S> {
|
||||
/// Construct a new connection context.
|
||||
fn new(stream: PqStream<S>, creds: auth::ClientCredentials) -> Self {
|
||||
fn new(stream: PqStream<S>, creds: auth::BackendType<auth::ClientCredentials>) -> Self {
|
||||
Self { stream, creds }
|
||||
}
|
||||
}
|
||||
@@ -194,16 +192,22 @@ impl<S: AsyncRead + AsyncWrite + Unpin + Send> Client<S> {
|
||||
let Self { mut stream, creds } = self;
|
||||
|
||||
// Authenticate and connect to a compute node.
|
||||
let auth = creds.authenticate(config, &mut stream).await;
|
||||
let auth = creds.authenticate(&config.auth_urls, &mut stream).await;
|
||||
let node = async { auth }.or_else(|e| stream.throw_error(e)).await?;
|
||||
|
||||
let (db, version, cancel_closure) =
|
||||
node.connect().or_else(|e| stream.throw_error(e)).await?;
|
||||
let (db, cancel_closure) = node.connect().or_else(|e| stream.throw_error(e)).await?;
|
||||
let cancel_key_data = session.enable_cancellation(cancel_closure);
|
||||
|
||||
// Report authentication success if we haven't done this already.
|
||||
if !node.reported_auth_ok {
|
||||
stream
|
||||
.write_message_noflush(&Be::AuthenticationOk)?
|
||||
.write_message_noflush(&BeParameterStatusMessage::encoding())?;
|
||||
}
|
||||
|
||||
stream
|
||||
.write_message_noflush(&BeMessage::ParameterStatus(
|
||||
BeParameterStatusMessage::ServerVersion(&version),
|
||||
BeParameterStatusMessage::ServerVersion(&db.version),
|
||||
))?
|
||||
.write_message_noflush(&Be::BackendKeyData(cancel_key_data))?
|
||||
.write_message(&BeMessage::ReadyForQuery)
|
||||
@@ -217,7 +221,7 @@ impl<S: AsyncRead + AsyncWrite + Unpin + Send> Client<S> {
|
||||
}
|
||||
|
||||
// Starting from here we only proxy the client's traffic.
|
||||
let mut db = MetricsStream::new(db, inc_proxied);
|
||||
let mut db = MetricsStream::new(db.stream, inc_proxied);
|
||||
let mut client = MetricsStream::new(stream.into_inner(), inc_proxied);
|
||||
let _ = tokio::io::copy_bidirectional(&mut client, &mut db).await?;
|
||||
|
||||
@@ -279,9 +283,13 @@ mod tests {
|
||||
let config = rustls::ServerConfig::builder()
|
||||
.with_safe_defaults()
|
||||
.with_no_client_auth()
|
||||
.with_single_cert(vec![cert], key)?;
|
||||
.with_single_cert(vec![cert], key)?
|
||||
.into();
|
||||
|
||||
config.into()
|
||||
TlsConfig {
|
||||
config,
|
||||
common_name: Some(common_name.to_string()),
|
||||
}
|
||||
};
|
||||
|
||||
let client_config = {
|
||||
@@ -297,11 +305,6 @@ mod tests {
|
||||
ClientConfig { config, hostname }
|
||||
};
|
||||
|
||||
let tls_config = TlsConfig {
|
||||
config: tls_config,
|
||||
common_name: Some(common_name.to_string()),
|
||||
};
|
||||
|
||||
Ok((client_config, tls_config))
|
||||
}
|
||||
|
||||
@@ -357,7 +360,7 @@ mod tests {
|
||||
auth: impl TestAuth + Send,
|
||||
) -> anyhow::Result<()> {
|
||||
let cancel_map = CancelMap::default();
|
||||
let (mut stream, _creds) = handshake(client, tls.as_ref(), &cancel_map)
|
||||
let (mut stream, _params) = handshake(client, tls.as_ref(), &cancel_map)
|
||||
.await?
|
||||
.context("handshake failed")?;
|
||||
|
||||
@@ -436,32 +439,6 @@ mod tests {
|
||||
proxy.await?
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn give_user_an_error_for_bad_creds() -> anyhow::Result<()> {
|
||||
let (client, server) = tokio::io::duplex(1024);
|
||||
|
||||
let proxy = tokio::spawn(dummy_proxy(client, None, NoAuth));
|
||||
|
||||
let client_err = tokio_postgres::Config::new()
|
||||
.ssl_mode(SslMode::Disable)
|
||||
.connect_raw(server, NoTls)
|
||||
.await
|
||||
.err() // -> Option<E>
|
||||
.context("client shouldn't be able to connect")?;
|
||||
|
||||
// TODO: this is ugly, but `format!` won't allow us to extract fmt string
|
||||
assert!(client_err.to_string().contains("missing in startup packet"));
|
||||
|
||||
let server_err = proxy
|
||||
.await?
|
||||
.err() // -> Option<E>
|
||||
.context("server shouldn't accept client")?;
|
||||
|
||||
assert!(client_err.to_string().contains(&server_err.to_string()));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn keepalive_is_inherited() -> anyhow::Result<()> {
|
||||
use tokio::net::{TcpListener, TcpStream};
|
||||
|
||||
@@ -145,6 +145,14 @@ impl<S> Stream<S> {
|
||||
pub fn from_raw(raw: S) -> Self {
|
||||
Self::Raw { raw }
|
||||
}
|
||||
|
||||
/// Return SNI hostname when it's available.
|
||||
pub fn sni_hostname(&self) -> Option<&str> {
|
||||
match self {
|
||||
Stream::Raw { .. } => None,
|
||||
Stream::Tls { tls } => tls.get_ref().1.sni_hostname(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
|
||||
@@ -1,8 +1,34 @@
|
||||
import pytest
|
||||
import json
|
||||
import base64
|
||||
|
||||
|
||||
def test_proxy_select_1(static_proxy):
|
||||
static_proxy.safe_psql("select 1;", options="project=generic-project-name")
|
||||
static_proxy.safe_psql('select 1', options='project=generic-project-name')
|
||||
|
||||
|
||||
def test_password_hack(static_proxy):
|
||||
user = 'borat'
|
||||
password = 'password'
|
||||
static_proxy.safe_psql(f"create role {user} with login password '{password}'",
|
||||
options='project=irrelevant')
|
||||
|
||||
def encode(s: str) -> str:
|
||||
return base64.b64encode(s.encode('utf-8')).decode('utf-8')
|
||||
|
||||
magic = encode(json.dumps({
|
||||
'project': 'irrelevant',
|
||||
'password': password,
|
||||
}))
|
||||
|
||||
static_proxy.safe_psql('select 1', sslsni=0, user=user, password=magic)
|
||||
|
||||
magic = encode(json.dumps({
|
||||
'project': 'irrelevant',
|
||||
'password_': encode(password),
|
||||
}))
|
||||
|
||||
static_proxy.safe_psql('select 1', sslsni=0, user=user, password=magic)
|
||||
|
||||
|
||||
# Pass extra options to the server.
|
||||
@@ -11,8 +37,8 @@ def test_proxy_select_1(static_proxy):
|
||||
# See https://github.com/neondatabase/neon/issues/1287
|
||||
@pytest.mark.xfail
|
||||
def test_proxy_options(static_proxy):
|
||||
with static_proxy.connect(options="-cproxytest.option=value") as conn:
|
||||
with static_proxy.connect(options='-cproxytest.option=value') as conn:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute("SHOW proxytest.option;")
|
||||
cur.execute('SHOW proxytest.option')
|
||||
value = cur.fetchall()[0][0]
|
||||
assert value == 'value'
|
||||
|
||||
@@ -30,7 +30,7 @@ from dataclasses import dataclass
|
||||
# Type-related stuff
|
||||
from psycopg2.extensions import connection as PgConnection
|
||||
from psycopg2.extensions import make_dsn, parse_dsn
|
||||
from typing import Any, Callable, Dict, Iterator, List, Optional, Type, TypeVar, cast, Union, Tuple
|
||||
from typing import Any, Callable, Dict, Iterator, List, Optional, TypeVar, cast, Union, Tuple
|
||||
from typing_extensions import Literal
|
||||
|
||||
import requests
|
||||
@@ -280,20 +280,18 @@ class PgProtocol:
|
||||
return str(make_dsn(**self.conn_options(**kwargs)))
|
||||
|
||||
def conn_options(self, **kwargs):
|
||||
conn_options = self.default_options.copy()
|
||||
result = self.default_options.copy()
|
||||
if 'dsn' in kwargs:
|
||||
conn_options.update(parse_dsn(kwargs['dsn']))
|
||||
conn_options.update(kwargs)
|
||||
result.update(parse_dsn(kwargs['dsn']))
|
||||
result.update(kwargs)
|
||||
|
||||
# Individual statement timeout in seconds. 2 minutes should be
|
||||
# enough for our tests, but if you need a longer, you can
|
||||
# change it by calling "SET statement_timeout" after
|
||||
# connecting.
|
||||
if 'options' in conn_options:
|
||||
conn_options['options'] = f"-cstatement_timeout=120s " + conn_options['options']
|
||||
else:
|
||||
conn_options['options'] = "-cstatement_timeout=120s"
|
||||
return conn_options
|
||||
options = result.get('options', '')
|
||||
result['options'] = f'-cstatement_timeout=120s {options}'
|
||||
return result
|
||||
|
||||
# autocommit=True here by default because that's what we need most of the time
|
||||
def connect(self, autocommit=True, **kwargs) -> PgConnection:
|
||||
@@ -1514,29 +1512,25 @@ def remote_pg(test_output_dir: Path) -> Iterator[RemotePostgres]:
|
||||
|
||||
|
||||
class NeonProxy(PgProtocol):
|
||||
def __init__(self, port: int, pg_port: int):
|
||||
super().__init__(host="127.0.0.1",
|
||||
user="proxy_user",
|
||||
password="pytest2",
|
||||
port=port,
|
||||
dbname='postgres')
|
||||
self.http_port = 7001
|
||||
self.host = "127.0.0.1"
|
||||
self.port = port
|
||||
self.pg_port = pg_port
|
||||
def __init__(self, proxy_port: int, http_port: int, auth_endpoint: str):
|
||||
super().__init__(dsn=auth_endpoint, port=proxy_port)
|
||||
self.host = '127.0.0.1'
|
||||
self.http_port = http_port
|
||||
self.proxy_port = proxy_port
|
||||
self.auth_endpoint = auth_endpoint
|
||||
self._popen: Optional[subprocess.Popen[bytes]] = None
|
||||
|
||||
def start(self) -> None:
|
||||
assert self._popen is None
|
||||
|
||||
# Start proxy
|
||||
bin_proxy = os.path.join(str(neon_binpath), 'proxy')
|
||||
args = [bin_proxy]
|
||||
args.extend(["--http", f"{self.host}:{self.http_port}"])
|
||||
args.extend(["--proxy", f"{self.host}:{self.port}"])
|
||||
args.extend(["--auth-backend", "postgres"])
|
||||
args.extend(
|
||||
["--auth-endpoint", f"postgres://proxy_auth:pytest1@localhost:{self.pg_port}/postgres"])
|
||||
args = [
|
||||
os.path.join(str(neon_binpath), 'proxy'),
|
||||
*["--http", f"{self.host}:{self.http_port}"],
|
||||
*["--proxy", f"{self.host}:{self.proxy_port}"],
|
||||
*["--auth-backend", "postgres"],
|
||||
*["--auth-endpoint", self.auth_endpoint],
|
||||
]
|
||||
self._popen = subprocess.Popen(args)
|
||||
self._wait_until_ready()
|
||||
|
||||
@@ -1557,13 +1551,21 @@ class NeonProxy(PgProtocol):
|
||||
@pytest.fixture(scope='function')
|
||||
def static_proxy(vanilla_pg, port_distributor) -> Iterator[NeonProxy]:
|
||||
"""Neon proxy that routes directly to vanilla postgres."""
|
||||
vanilla_pg.start()
|
||||
vanilla_pg.safe_psql("create user proxy_auth with password 'pytest1' superuser")
|
||||
vanilla_pg.safe_psql("create user proxy_user with password 'pytest2'")
|
||||
|
||||
port = port_distributor.get_port()
|
||||
pg_port = vanilla_pg.default_options['port']
|
||||
with NeonProxy(port, pg_port) as proxy:
|
||||
# For simplicity, we use the same user for both `--auth-endpoint` and `safe_psql`
|
||||
vanilla_pg.start()
|
||||
vanilla_pg.safe_psql("create user proxy with login superuser password 'password'")
|
||||
|
||||
port = vanilla_pg.default_options['port']
|
||||
host = vanilla_pg.default_options['host']
|
||||
dbname = vanilla_pg.default_options['dbname']
|
||||
auth_endpoint = f'postgres://proxy:password@{host}:{port}/{dbname}'
|
||||
|
||||
proxy_port = port_distributor.get_port()
|
||||
http_port = port_distributor.get_port()
|
||||
|
||||
with NeonProxy(proxy_port=proxy_port, http_port=http_port,
|
||||
auth_endpoint=auth_endpoint) as proxy:
|
||||
proxy.start()
|
||||
yield proxy
|
||||
|
||||
|
||||
Reference in New Issue
Block a user