mirror of
https://github.com/neondatabase/neon.git
synced 2026-06-01 04:20:39 +00:00
[proxy] Rework wire format of the password hack and some errors (#2236)
The new format has a few benefits: it's shorter, simpler and human-readable as well. We don't use base64 anymore, since url encoding got us covered. We also show a better error in case we couldn't parse the payload; the users should know it's all about passing the correct project name.
This commit is contained in:
1
Cargo.lock
generated
1
Cargo.lock
generated
@@ -2269,6 +2269,7 @@ dependencies = [
|
||||
"anyhow",
|
||||
"async-trait",
|
||||
"base64",
|
||||
"bstr",
|
||||
"bytes",
|
||||
"clap 3.2.12",
|
||||
"futures",
|
||||
|
||||
@@ -7,6 +7,7 @@ edition = "2021"
|
||||
anyhow = "1.0"
|
||||
async-trait = "0.1"
|
||||
base64 = "0.13.0"
|
||||
bstr = "0.2.17"
|
||||
bytes = { version = "1.0.1", features = ['serde'] }
|
||||
clap = "3.0"
|
||||
futures = "0.3.13"
|
||||
|
||||
@@ -12,7 +12,7 @@ use password_hack::PasswordHackPayload;
|
||||
mod flow;
|
||||
pub use flow::*;
|
||||
|
||||
use crate::{error::UserFacingError, waiters};
|
||||
use crate::error::UserFacingError;
|
||||
use std::io;
|
||||
use thiserror::Error;
|
||||
|
||||
@@ -22,51 +22,54 @@ pub type Result<T> = std::result::Result<T, AuthError>;
|
||||
/// Common authentication error.
|
||||
#[derive(Debug, Error)]
|
||||
pub enum AuthErrorImpl {
|
||||
/// Authentication error reported by the console.
|
||||
// This will be dropped in the future.
|
||||
#[error(transparent)]
|
||||
Console(#[from] backend::AuthError),
|
||||
Legacy(#[from] backend::LegacyAuthError),
|
||||
|
||||
#[error(transparent)]
|
||||
GetAuthInfo(#[from] backend::console::ConsoleAuthError),
|
||||
Link(#[from] backend::LinkAuthError),
|
||||
|
||||
#[error(transparent)]
|
||||
GetAuthInfo(#[from] backend::GetAuthInfoError),
|
||||
|
||||
#[error(transparent)]
|
||||
WakeCompute(#[from] backend::WakeComputeError),
|
||||
|
||||
/// SASL protocol errors (includes [SCRAM](crate::scram)).
|
||||
#[error(transparent)]
|
||||
Sasl(#[from] crate::sasl::Error),
|
||||
|
||||
#[error("Unsupported authentication method: {0}")]
|
||||
BadAuthMethod(Box<str>),
|
||||
|
||||
#[error("Malformed password message: {0}")]
|
||||
MalformedPassword(&'static str),
|
||||
|
||||
/// Errors produced by [`crate::stream::PqStream`].
|
||||
#[error(
|
||||
"Project name is not specified. \
|
||||
Either please upgrade the postgres client library (libpq) for SNI support \
|
||||
or pass the project name as a parameter: '&options=project%3D<project-name>'. \
|
||||
See more at https://neon.tech/sni"
|
||||
)]
|
||||
MissingProjectName,
|
||||
|
||||
/// Errors produced by e.g. [`crate::stream::PqStream`].
|
||||
#[error(transparent)]
|
||||
Io(#[from] io::Error),
|
||||
}
|
||||
|
||||
impl AuthErrorImpl {
|
||||
pub fn auth_failed(msg: impl Into<String>) -> Self {
|
||||
Self::Console(backend::AuthError::auth_failed(msg))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<waiters::RegisterError> for AuthErrorImpl {
|
||||
fn from(e: waiters::RegisterError) -> Self {
|
||||
Self::Console(backend::AuthError::from(e))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<waiters::WaitError> for AuthErrorImpl {
|
||||
fn from(e: waiters::WaitError) -> Self {
|
||||
Self::Console(backend::AuthError::from(e))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
#[error(transparent)]
|
||||
pub struct AuthError(Box<AuthErrorImpl>);
|
||||
|
||||
impl<T> From<T> for AuthError
|
||||
where
|
||||
AuthErrorImpl: From<T>,
|
||||
{
|
||||
fn from(e: T) -> Self {
|
||||
impl AuthError {
|
||||
pub fn bad_auth_method(name: impl Into<Box<str>>) -> Self {
|
||||
AuthErrorImpl::BadAuthMethod(name.into()).into()
|
||||
}
|
||||
}
|
||||
|
||||
impl<E: Into<AuthErrorImpl>> From<E> for AuthError {
|
||||
fn from(e: E) -> Self {
|
||||
Self(Box::new(e.into()))
|
||||
}
|
||||
}
|
||||
@@ -75,10 +78,14 @@ impl UserFacingError for AuthError {
|
||||
fn to_string_client(&self) -> String {
|
||||
use AuthErrorImpl::*;
|
||||
match self.0.as_ref() {
|
||||
Console(e) => e.to_string_client(),
|
||||
Legacy(e) => e.to_string_client(),
|
||||
Link(e) => e.to_string_client(),
|
||||
GetAuthInfo(e) => e.to_string_client(),
|
||||
WakeCompute(e) => e.to_string_client(),
|
||||
Sasl(e) => e.to_string_client(),
|
||||
BadAuthMethod(_) => self.to_string(),
|
||||
MalformedPassword(_) => self.to_string(),
|
||||
MissingProjectName => self.to_string(),
|
||||
_ => "Internal error".to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
mod link;
|
||||
mod postgres;
|
||||
|
||||
pub mod console;
|
||||
mod link;
|
||||
pub use link::LinkAuthError;
|
||||
|
||||
mod console;
|
||||
pub use console::{GetAuthInfoError, WakeComputeError};
|
||||
|
||||
mod legacy_console;
|
||||
pub use legacy_console::{AuthError, AuthErrorImpl};
|
||||
pub use legacy_console::LegacyAuthError;
|
||||
|
||||
use crate::{
|
||||
auth::{self, AuthFlow, ClientCredentials},
|
||||
|
||||
@@ -13,21 +13,11 @@ use std::future::Future;
|
||||
use thiserror::Error;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
|
||||
pub type Result<T> = std::result::Result<T, ConsoleAuthError>;
|
||||
const REQUEST_FAILED: &str = "Console request failed";
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum ConsoleAuthError {
|
||||
#[error(transparent)]
|
||||
BadProjectName(#[from] auth::credentials::ClientCredsParseError),
|
||||
|
||||
// We shouldn't include the actual secret here.
|
||||
#[error("Bad authentication secret")]
|
||||
BadSecret,
|
||||
|
||||
#[error("Console responded with a malformed compute address: '{0}'")]
|
||||
BadComputeAddress(String),
|
||||
|
||||
#[error("Console responded with a malformed JSON: '{0}'")]
|
||||
pub enum TransportError {
|
||||
#[error("Console responded with a malformed JSON: {0}")]
|
||||
BadResponse(#[from] serde_json::Error),
|
||||
|
||||
/// HTTP status (other than 200) returned by the console.
|
||||
@@ -38,19 +28,72 @@ pub enum ConsoleAuthError {
|
||||
Io(#[from] std::io::Error),
|
||||
}
|
||||
|
||||
impl UserFacingError for ConsoleAuthError {
|
||||
impl UserFacingError for TransportError {
|
||||
fn to_string_client(&self) -> String {
|
||||
use ConsoleAuthError::*;
|
||||
use TransportError::*;
|
||||
match self {
|
||||
BadProjectName(e) => e.to_string_client(),
|
||||
_ => "Internal error".to_string(),
|
||||
HttpStatus(_) => self.to_string(),
|
||||
_ => REQUEST_FAILED.to_owned(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<&auth::credentials::ClientCredsParseError> for ConsoleAuthError {
|
||||
fn from(e: &auth::credentials::ClientCredsParseError) -> Self {
|
||||
ConsoleAuthError::BadProjectName(e.clone())
|
||||
// Helps eliminate graceless `.map_err` calls without introducing another ctor.
|
||||
impl From<reqwest::Error> for TransportError {
|
||||
fn from(e: reqwest::Error) -> Self {
|
||||
io_error(e).into()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum GetAuthInfoError {
|
||||
// We shouldn't include the actual secret here.
|
||||
#[error("Console responded with a malformed auth secret")]
|
||||
BadSecret,
|
||||
|
||||
#[error(transparent)]
|
||||
Transport(TransportError),
|
||||
}
|
||||
|
||||
impl UserFacingError for GetAuthInfoError {
|
||||
fn to_string_client(&self) -> String {
|
||||
use GetAuthInfoError::*;
|
||||
match self {
|
||||
BadSecret => REQUEST_FAILED.to_owned(),
|
||||
Transport(e) => e.to_string_client(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<E: Into<TransportError>> From<E> for GetAuthInfoError {
|
||||
fn from(e: E) -> Self {
|
||||
Self::Transport(e.into())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum WakeComputeError {
|
||||
// We shouldn't show users the address even if it's broken.
|
||||
#[error("Console responded with a malformed compute address: {0}")]
|
||||
BadComputeAddress(String),
|
||||
|
||||
#[error(transparent)]
|
||||
Transport(TransportError),
|
||||
}
|
||||
|
||||
impl UserFacingError for WakeComputeError {
|
||||
fn to_string_client(&self) -> String {
|
||||
use WakeComputeError::*;
|
||||
match self {
|
||||
BadComputeAddress(_) => REQUEST_FAILED.to_owned(),
|
||||
Transport(e) => e.to_string_client(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<E: Into<TransportError>> From<E> for WakeComputeError {
|
||||
fn from(e: E) -> Self {
|
||||
Self::Transport(e.into())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -95,7 +138,7 @@ impl<'a> Api<'a> {
|
||||
handle_user(client, &self, Self::get_auth_info, Self::wake_compute).await
|
||||
}
|
||||
|
||||
async fn get_auth_info(&self) -> Result<AuthInfo> {
|
||||
async fn get_auth_info(&self) -> Result<AuthInfo, GetAuthInfoError> {
|
||||
let mut url = self.endpoint.clone();
|
||||
url.path_segments_mut().push("proxy_get_role_secret");
|
||||
url.query_pairs_mut()
|
||||
@@ -105,21 +148,20 @@ impl<'a> Api<'a> {
|
||||
// TODO: use a proper logger
|
||||
println!("cplane request: {url}");
|
||||
|
||||
let resp = reqwest::get(url.into_inner()).await.map_err(io_error)?;
|
||||
let resp = reqwest::get(url.into_inner()).await?;
|
||||
if !resp.status().is_success() {
|
||||
return Err(ConsoleAuthError::HttpStatus(resp.status()));
|
||||
return Err(TransportError::HttpStatus(resp.status()).into());
|
||||
}
|
||||
|
||||
let response: GetRoleSecretResponse =
|
||||
serde_json::from_str(&resp.text().await.map_err(io_error)?)?;
|
||||
let response: GetRoleSecretResponse = serde_json::from_str(&resp.text().await?)?;
|
||||
|
||||
scram::ServerSecret::parse(response.role_secret.as_str())
|
||||
scram::ServerSecret::parse(&response.role_secret)
|
||||
.map(AuthInfo::Scram)
|
||||
.ok_or(ConsoleAuthError::BadSecret)
|
||||
.ok_or(GetAuthInfoError::BadSecret)
|
||||
}
|
||||
|
||||
/// Wake up the compute node and return the corresponding connection info.
|
||||
pub(super) async fn wake_compute(&self) -> Result<ComputeConnCfg> {
|
||||
pub(super) async fn wake_compute(&self) -> Result<ComputeConnCfg, WakeComputeError> {
|
||||
let mut url = self.endpoint.clone();
|
||||
url.path_segments_mut().push("proxy_wake_compute");
|
||||
url.query_pairs_mut()
|
||||
@@ -128,17 +170,16 @@ impl<'a> Api<'a> {
|
||||
// TODO: use a proper logger
|
||||
println!("cplane request: {url}");
|
||||
|
||||
let resp = reqwest::get(url.into_inner()).await.map_err(io_error)?;
|
||||
let resp = reqwest::get(url.into_inner()).await?;
|
||||
if !resp.status().is_success() {
|
||||
return Err(ConsoleAuthError::HttpStatus(resp.status()));
|
||||
return Err(TransportError::HttpStatus(resp.status()).into());
|
||||
}
|
||||
|
||||
let response: GetWakeComputeResponse =
|
||||
serde_json::from_str(&resp.text().await.map_err(io_error)?)?;
|
||||
let response: GetWakeComputeResponse = serde_json::from_str(&resp.text().await?)?;
|
||||
|
||||
// Unfortunately, ownership won't let us use `Option::ok_or` here.
|
||||
let (host, port) = match parse_host_port(&response.address) {
|
||||
None => return Err(ConsoleAuthError::BadComputeAddress(response.address)),
|
||||
None => return Err(WakeComputeError::BadComputeAddress(response.address)),
|
||||
Some(x) => x,
|
||||
};
|
||||
|
||||
@@ -162,8 +203,8 @@ pub(super) async fn handle_user<'a, Endpoint, GetAuthInfo, WakeCompute>(
|
||||
wake_compute: impl FnOnce(&'a Endpoint) -> WakeCompute,
|
||||
) -> auth::Result<compute::NodeInfo>
|
||||
where
|
||||
GetAuthInfo: Future<Output = Result<AuthInfo>>,
|
||||
WakeCompute: Future<Output = Result<ComputeConnCfg>>,
|
||||
GetAuthInfo: Future<Output = Result<AuthInfo, GetAuthInfoError>>,
|
||||
WakeCompute: Future<Output = Result<ComputeConnCfg, WakeComputeError>>,
|
||||
{
|
||||
let auth_info = get_auth_info(endpoint).await?;
|
||||
|
||||
@@ -171,7 +212,7 @@ where
|
||||
let scram_keys = match auth_info {
|
||||
AuthInfo::Md5(_) => {
|
||||
// TODO: decide if we should support MD5 in api v2
|
||||
return Err(auth::AuthErrorImpl::auth_failed("MD5 is not supported").into());
|
||||
return Err(auth::AuthError::bad_auth_method("MD5"));
|
||||
}
|
||||
AuthInfo::Scram(secret) => {
|
||||
let scram = auth::Scram(&secret);
|
||||
|
||||
@@ -14,7 +14,7 @@ use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use utils::pq_proto::BeMessage as Be;
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum AuthErrorImpl {
|
||||
pub enum LegacyAuthError {
|
||||
/// Authentication error reported by the console.
|
||||
#[error("Authentication failed: {0}")]
|
||||
AuthFailed(String),
|
||||
@@ -24,7 +24,7 @@ pub enum AuthErrorImpl {
|
||||
HttpStatus(reqwest::StatusCode),
|
||||
|
||||
#[error("Console responded with a malformed JSON: {0}")]
|
||||
MalformedResponse(#[from] serde_json::Error),
|
||||
BadResponse(#[from] serde_json::Error),
|
||||
|
||||
#[error(transparent)]
|
||||
Transport(#[from] reqwest::Error),
|
||||
@@ -36,30 +36,10 @@ pub enum AuthErrorImpl {
|
||||
WaiterWait(#[from] waiters::WaitError),
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
#[error(transparent)]
|
||||
pub struct AuthError(Box<AuthErrorImpl>);
|
||||
|
||||
impl AuthError {
|
||||
/// Smart constructor for authentication error reported by `mgmt`.
|
||||
pub fn auth_failed(msg: impl Into<String>) -> Self {
|
||||
Self(Box::new(AuthErrorImpl::AuthFailed(msg.into())))
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> From<T> for AuthError
|
||||
where
|
||||
AuthErrorImpl: From<T>,
|
||||
{
|
||||
fn from(e: T) -> Self {
|
||||
Self(Box::new(e.into()))
|
||||
}
|
||||
}
|
||||
|
||||
impl UserFacingError for AuthError {
|
||||
impl UserFacingError for LegacyAuthError {
|
||||
fn to_string_client(&self) -> String {
|
||||
use AuthErrorImpl::*;
|
||||
match self.0.as_ref() {
|
||||
use LegacyAuthError::*;
|
||||
match self {
|
||||
AuthFailed(_) | HttpStatus(_) => self.to_string(),
|
||||
_ => "Internal error".to_string(),
|
||||
}
|
||||
@@ -88,7 +68,7 @@ async fn authenticate_proxy_client(
|
||||
md5_response: &str,
|
||||
salt: &[u8; 4],
|
||||
psql_session_id: &str,
|
||||
) -> Result<DatabaseInfo, AuthError> {
|
||||
) -> Result<DatabaseInfo, LegacyAuthError> {
|
||||
let mut url = auth_endpoint.clone();
|
||||
url.query_pairs_mut()
|
||||
.append_pair("login", &creds.user)
|
||||
@@ -102,17 +82,17 @@ async fn authenticate_proxy_client(
|
||||
// TODO: leverage `reqwest::Client` to reuse connections
|
||||
let resp = reqwest::get(url).await?;
|
||||
if !resp.status().is_success() {
|
||||
return Err(AuthErrorImpl::HttpStatus(resp.status()).into());
|
||||
return Err(LegacyAuthError::HttpStatus(resp.status()));
|
||||
}
|
||||
|
||||
let auth_info: ProxyAuthResponse = serde_json::from_str(resp.text().await?.as_str())?;
|
||||
let auth_info = serde_json::from_str(resp.text().await?.as_str())?;
|
||||
println!("got auth info: {:?}", auth_info);
|
||||
|
||||
use ProxyAuthResponse::*;
|
||||
let db_info = match auth_info {
|
||||
Ready { conn_info } => conn_info,
|
||||
Error { error } => return Err(AuthErrorImpl::AuthFailed(error).into()),
|
||||
NotReady { .. } => waiter.await?.map_err(AuthErrorImpl::AuthFailed)?,
|
||||
Error { error } => return Err(LegacyAuthError::AuthFailed(error)),
|
||||
NotReady { .. } => waiter.await?.map_err(LegacyAuthError::AuthFailed)?,
|
||||
};
|
||||
|
||||
Ok(db_info)
|
||||
@@ -124,7 +104,7 @@ async fn handle_existing_user(
|
||||
auth_endpoint: &reqwest::Url,
|
||||
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin + Send>,
|
||||
creds: &ClientCredentials,
|
||||
) -> Result<compute::NodeInfo, auth::AuthError> {
|
||||
) -> auth::Result<compute::NodeInfo> {
|
||||
let psql_session_id = super::link::new_psql_session_id();
|
||||
let md5_salt = rand::random();
|
||||
|
||||
|
||||
@@ -1,7 +1,34 @@
|
||||
use crate::{auth, compute, stream::PqStream};
|
||||
use crate::{auth, compute, error::UserFacingError, stream::PqStream, waiters};
|
||||
use thiserror::Error;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use utils::pq_proto::{BeMessage as Be, BeParameterStatusMessage};
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum LinkAuthError {
|
||||
/// Authentication error reported by the console.
|
||||
#[error("Authentication failed: {0}")]
|
||||
AuthFailed(String),
|
||||
|
||||
#[error(transparent)]
|
||||
WaiterRegister(#[from] waiters::RegisterError),
|
||||
|
||||
#[error(transparent)]
|
||||
WaiterWait(#[from] waiters::WaitError),
|
||||
|
||||
#[error(transparent)]
|
||||
Io(#[from] std::io::Error),
|
||||
}
|
||||
|
||||
impl UserFacingError for LinkAuthError {
|
||||
fn to_string_client(&self) -> String {
|
||||
use LinkAuthError::*;
|
||||
match self {
|
||||
AuthFailed(_) => self.to_string(),
|
||||
_ => "Internal error".to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn hello_message(redirect_uri: &str, session_id: &str) -> String {
|
||||
format!(
|
||||
concat![
|
||||
@@ -34,7 +61,7 @@ pub async fn handle_user(
|
||||
.await?;
|
||||
|
||||
// Wait for web console response (see `mgmt`)
|
||||
waiter.await?.map_err(auth::AuthErrorImpl::auth_failed)
|
||||
waiter.await?.map_err(LinkAuthError::AuthFailed)
|
||||
})
|
||||
.await?;
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
use crate::{
|
||||
auth::{
|
||||
self,
|
||||
backend::console::{self, AuthInfo, Result},
|
||||
backend::console::{self, AuthInfo, GetAuthInfoError, TransportError, WakeComputeError},
|
||||
ClientCredentials,
|
||||
},
|
||||
compute::{self, ComputeConnCfg},
|
||||
@@ -20,6 +20,13 @@ pub(super) struct Api<'a> {
|
||||
creds: &'a ClientCredentials,
|
||||
}
|
||||
|
||||
// Helps eliminate graceless `.map_err` calls without introducing another ctor.
|
||||
impl From<tokio_postgres::Error> for TransportError {
|
||||
fn from(e: tokio_postgres::Error) -> Self {
|
||||
io_error(e).into()
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> Api<'a> {
|
||||
/// Construct an API object containing the auth parameters.
|
||||
pub(super) fn new(endpoint: &'a ApiUrl, creds: &'a ClientCredentials) -> Self {
|
||||
@@ -36,21 +43,16 @@ impl<'a> Api<'a> {
|
||||
}
|
||||
|
||||
/// This implementation fetches the auth info from a local postgres instance.
|
||||
async fn get_auth_info(&self) -> Result<AuthInfo> {
|
||||
async fn get_auth_info(&self) -> Result<AuthInfo, GetAuthInfoError> {
|
||||
// Perhaps we could persist this connection, but then we'd have to
|
||||
// write more code for reopening it if it got closed, which doesn't
|
||||
// seem worth it.
|
||||
let (client, connection) =
|
||||
tokio_postgres::connect(self.endpoint.as_str(), tokio_postgres::NoTls)
|
||||
.await
|
||||
.map_err(io_error)?;
|
||||
tokio_postgres::connect(self.endpoint.as_str(), tokio_postgres::NoTls).await?;
|
||||
|
||||
tokio::spawn(connection);
|
||||
let query = "select rolpassword from pg_catalog.pg_authid where rolname = $1";
|
||||
let rows = client
|
||||
.query(query, &[&self.creds.user])
|
||||
.await
|
||||
.map_err(io_error)?;
|
||||
let rows = client.query(query, &[&self.creds.user]).await?;
|
||||
|
||||
match &rows[..] {
|
||||
// We can't get a secret if there's no such user.
|
||||
@@ -74,13 +76,13 @@ impl<'a> Api<'a> {
|
||||
}))
|
||||
})
|
||||
// Putting the secret into this message is a security hazard!
|
||||
.ok_or(console::ConsoleAuthError::BadSecret)
|
||||
.ok_or(GetAuthInfoError::BadSecret)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// We don't need to wake anything locally, so we just return the connection info.
|
||||
pub(super) async fn wake_compute(&self) -> Result<ComputeConnCfg> {
|
||||
pub(super) async fn wake_compute(&self) -> Result<ComputeConnCfg, WakeComputeError> {
|
||||
let mut config = ComputeConnCfg::new();
|
||||
config
|
||||
.host(self.endpoint.host_str().unwrap_or("localhost"))
|
||||
|
||||
@@ -75,13 +75,12 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, PasswordHack> {
|
||||
.strip_suffix(&[0])
|
||||
.ok_or(AuthErrorImpl::MalformedPassword("missing terminator"))?;
|
||||
|
||||
// The so-called "password" should contain a base64-encoded json.
|
||||
// We will use it later to route the client to their project.
|
||||
let bytes = base64::decode(password)
|
||||
.map_err(|_| AuthErrorImpl::MalformedPassword("bad encoding"))?;
|
||||
|
||||
let payload = serde_json::from_slice(&bytes)
|
||||
.map_err(|_| AuthErrorImpl::MalformedPassword("invalid payload"))?;
|
||||
let payload = PasswordHackPayload::parse(password)
|
||||
// If we ended up here and the payload is malformed, it means that
|
||||
// the user neither enabled SNI nor resorted to any other method
|
||||
// for passing the project name we rely on. We should show them
|
||||
// the most helpful error message and point to the documentation.
|
||||
.ok_or(AuthErrorImpl::MissingProjectName)?;
|
||||
|
||||
Ok(payload)
|
||||
}
|
||||
@@ -98,7 +97,7 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, Scram<'_>> {
|
||||
|
||||
// Currently, the only supported SASL method is SCRAM.
|
||||
if !scram::METHODS.contains(&sasl.method) {
|
||||
return Err(AuthErrorImpl::auth_failed("method not supported").into());
|
||||
return Err(super::AuthError::bad_auth_method(sasl.method));
|
||||
}
|
||||
|
||||
let secret = self.state.0;
|
||||
|
||||
@@ -1,102 +1,46 @@
|
||||
//! Payload for ad hoc authentication method for clients that don't support SNI.
|
||||
//! See the `impl` for [`super::backend::BackendType<ClientCredentials>`].
|
||||
//! Read more: <https://github.com/neondatabase/cloud/issues/1620#issuecomment-1165332290>.
|
||||
//! UPDATE (Mon Aug 8 13:20:34 UTC 2022): the payload format has been simplified.
|
||||
|
||||
use serde::{de, Deserialize, Deserializer};
|
||||
use std::fmt;
|
||||
use bstr::ByteSlice;
|
||||
|
||||
#[derive(Deserialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum Password {
|
||||
/// A regular string for utf-8 encoded passwords.
|
||||
Simple { password: String },
|
||||
|
||||
/// Password is base64-encoded because it may contain arbitrary byte sequences.
|
||||
Encoded {
|
||||
#[serde(rename = "password_", deserialize_with = "deserialize_base64")]
|
||||
password: Vec<u8>,
|
||||
},
|
||||
}
|
||||
|
||||
impl AsRef<[u8]> for Password {
|
||||
fn as_ref(&self) -> &[u8] {
|
||||
match self {
|
||||
Password::Simple { password } => password.as_ref(),
|
||||
Password::Encoded { password } => password.as_ref(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct PasswordHackPayload {
|
||||
pub project: String,
|
||||
|
||||
#[serde(flatten)]
|
||||
pub password: Password,
|
||||
pub password: Vec<u8>,
|
||||
}
|
||||
|
||||
fn deserialize_base64<'a, D: Deserializer<'a>>(des: D) -> Result<Vec<u8>, D::Error> {
|
||||
// It's very tempting to replace this with
|
||||
//
|
||||
// ```
|
||||
// let base64: &str = Deserialize::deserialize(des)?;
|
||||
// base64::decode(base64).map_err(serde::de::Error::custom)
|
||||
// ```
|
||||
//
|
||||
// Unfortunately, we can't always deserialize into `&str`, so we'd
|
||||
// have to use an allocating `String` instead. Thus, visitor is better.
|
||||
struct Visitor;
|
||||
impl PasswordHackPayload {
|
||||
pub fn parse(bytes: &[u8]) -> Option<Self> {
|
||||
// The format is `project=<utf-8>;<password-bytes>`.
|
||||
let mut iter = bytes.strip_prefix(b"project=")?.splitn_str(2, ";");
|
||||
let project = iter.next()?.to_str().ok()?.to_owned();
|
||||
let password = iter.next()?.to_owned();
|
||||
|
||||
impl<'de> de::Visitor<'de> for Visitor {
|
||||
type Value = Vec<u8>;
|
||||
|
||||
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
|
||||
formatter.write_str("a string")
|
||||
}
|
||||
|
||||
fn visit_str<E: de::Error>(self, v: &str) -> Result<Self::Value, E> {
|
||||
base64::decode(v).map_err(de::Error::custom)
|
||||
}
|
||||
Some(Self { project, password })
|
||||
}
|
||||
|
||||
des.deserialize_str(Visitor)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use rstest::rstest;
|
||||
use serde_json::json;
|
||||
|
||||
#[test]
|
||||
fn parse_password() -> anyhow::Result<()> {
|
||||
let password: Password = serde_json::from_value(json!({
|
||||
"password": "foo",
|
||||
}))?;
|
||||
assert_eq!(password.as_ref(), "foo".as_bytes());
|
||||
fn parse_password_hack_payload() {
|
||||
let bytes = b"";
|
||||
assert!(PasswordHackPayload::parse(bytes).is_none());
|
||||
|
||||
let password: Password = serde_json::from_value(json!({
|
||||
"password_": base64::encode("foo"),
|
||||
}))?;
|
||||
assert_eq!(password.as_ref(), "foo".as_bytes());
|
||||
let bytes = b"project=";
|
||||
assert!(PasswordHackPayload::parse(bytes).is_none());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
let bytes = b"project=;";
|
||||
let payload = PasswordHackPayload::parse(bytes).expect("parsing failed");
|
||||
assert_eq!(payload.project, "");
|
||||
assert_eq!(payload.password, b"");
|
||||
|
||||
#[rstest]
|
||||
#[case("password", str::to_owned)]
|
||||
#[case("password_", base64::encode)]
|
||||
fn parse(#[case] key: &str, #[case] encode: fn(&'static str) -> String) -> anyhow::Result<()> {
|
||||
let (password, project) = ("password", "pie-in-the-sky");
|
||||
let payload = json!({
|
||||
"project": project,
|
||||
key: encode(password),
|
||||
});
|
||||
|
||||
let payload: PasswordHackPayload = serde_json::from_value(payload)?;
|
||||
assert_eq!(payload.password.as_ref(), password.as_bytes());
|
||||
assert_eq!(payload.project, project);
|
||||
|
||||
Ok(())
|
||||
let bytes = b"project=foobar;pass;word";
|
||||
let payload = PasswordHackPayload::parse(bytes).expect("parsing failed");
|
||||
assert_eq!(payload.project, "foobar");
|
||||
assert_eq!(payload.password, b"pass;word");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import pytest
|
||||
import json
|
||||
import base64
|
||||
import psycopg2
|
||||
|
||||
|
||||
def test_proxy_select_1(static_proxy):
|
||||
@@ -13,22 +12,14 @@ def test_password_hack(static_proxy):
|
||||
static_proxy.safe_psql(f"create role {user} with login password '{password}'",
|
||||
options='project=irrelevant')
|
||||
|
||||
def encode(s: str) -> str:
|
||||
return base64.b64encode(s.encode('utf-8')).decode('utf-8')
|
||||
|
||||
magic = encode(json.dumps({
|
||||
'project': 'irrelevant',
|
||||
'password': password,
|
||||
}))
|
||||
|
||||
# Note the format of `magic`!
|
||||
magic = f"project=irrelevant;{password}"
|
||||
static_proxy.safe_psql('select 1', sslsni=0, user=user, password=magic)
|
||||
|
||||
magic = encode(json.dumps({
|
||||
'project': 'irrelevant',
|
||||
'password_': encode(password),
|
||||
}))
|
||||
|
||||
static_proxy.safe_psql('select 1', sslsni=0, user=user, password=magic)
|
||||
# Must also check that invalid magic won't be accepted.
|
||||
with pytest.raises(psycopg2.errors.OperationalError):
|
||||
magic = "broken"
|
||||
static_proxy.safe_psql('select 1', sslsni=0, user=user, password=magic)
|
||||
|
||||
|
||||
# Pass extra options to the server.
|
||||
|
||||
Reference in New Issue
Block a user