mirror of
https://github.com/neondatabase/neon.git
synced 2025-12-22 21:59:59 +00:00
[proxy] Refactor cplane API and add new console SCRAM auth API
Now proxy binary accepts `--auth-backend` CLI option, which determines
auth scheme and cluster routing method. Following backends are currently
implemented:
* legacy
old method, when username ends with `@zenith` it uses md5 auth dbname as
the cluster name; otherwise, it sends a login link and waits for the console
to call back
* console
new SCRAM-based console API; uses SNI info to select the destination
cluster
* postgres
uses postgres to select auth secrets of existing roles. Useful for local
testing
* link
sends login link for all usernames
This commit is contained in:
3
.gitignore
vendored
3
.gitignore
vendored
@@ -11,3 +11,6 @@ test_output/
|
||||
# Coverage
|
||||
*.profraw
|
||||
*.profdata
|
||||
|
||||
*.key
|
||||
*.crt
|
||||
|
||||
1
Cargo.lock
generated
1
Cargo.lock
generated
@@ -2040,6 +2040,7 @@ dependencies = [
|
||||
"tokio-postgres",
|
||||
"tokio-postgres-rustls",
|
||||
"tokio-rustls",
|
||||
"url",
|
||||
"utils",
|
||||
"workspace_hack",
|
||||
]
|
||||
|
||||
@@ -32,6 +32,7 @@ thiserror = "1.0.30"
|
||||
tokio = { version = "1.17", features = ["macros"] }
|
||||
tokio-postgres = { git = "https://github.com/zenithdb/rust-postgres.git", rev="d052ee8b86fff9897c77b0fe89ea9daba0e1fa38" }
|
||||
tokio-rustls = "0.23.0"
|
||||
url = "2.2.2"
|
||||
|
||||
utils = { path = "../libs/utils" }
|
||||
metrics = { path = "../libs/metrics" }
|
||||
|
||||
33
proxy/README.md
Normal file
33
proxy/README.md
Normal file
@@ -0,0 +1,33 @@
|
||||
# Proxy
|
||||
|
||||
Proxy binary accepts `--auth-backend` CLI option, which determines auth scheme and cluster routing method. Following backends are currently implemented:
|
||||
|
||||
* legacy
|
||||
old method, when username ends with `@zenith` it uses md5 auth dbname as the cluster name; otherwise, it sends a login link and waits for the console to call back
|
||||
* console
|
||||
new SCRAM-based console API; uses SNI info to select the destination cluster
|
||||
* postgres
|
||||
uses postgres to select auth secrets of existing roles. Useful for local testing
|
||||
* link
|
||||
sends login link for all usernames
|
||||
|
||||
## Using SNI-based routing on localhost
|
||||
|
||||
Now proxy determines cluster name from the subdomain, request to the `my-cluster-42.somedomain.tld` will be routed to the cluster named `my-cluster-42`. Unfortunately `/etc/hosts` does not support domain wildcards, so I usually use `*.localtest.me` which resolves to `127.0.0.1`. Now we can create self-signed certificate and play with proxy:
|
||||
|
||||
```
|
||||
openssl req -new -x509 -days 365 -nodes -text -out server.crt -keyout server.key -subj "/CN=*.localtest.me"
|
||||
|
||||
```
|
||||
|
||||
now you can start proxy:
|
||||
|
||||
```
|
||||
./target/debug/proxy -c server.crt -k server.key
|
||||
```
|
||||
|
||||
and connect to it:
|
||||
|
||||
```
|
||||
PGSSLROOTCERT=./server.crt psql 'postgres://my-cluster-42.localtest.me:1234?sslmode=verify-full'
|
||||
```
|
||||
@@ -1,14 +1,14 @@
|
||||
mod credentials;
|
||||
mod flow;
|
||||
|
||||
use crate::config::{CloudApi, ProxyConfig};
|
||||
use crate::auth_backend::{console, legacy_console, link, postgres};
|
||||
use crate::config::{AuthBackendType, ProxyConfig};
|
||||
use crate::error::UserFacingError;
|
||||
use crate::stream::PqStream;
|
||||
use crate::{cloud, compute, waiters};
|
||||
use crate::{auth_backend, compute, waiters};
|
||||
use std::io;
|
||||
use thiserror::Error;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use utils::pq_proto::{BeMessage as Be, BeParameterStatusMessage};
|
||||
|
||||
pub use credentials::ClientCredentials;
|
||||
pub use flow::*;
|
||||
@@ -18,13 +18,10 @@ pub use flow::*;
|
||||
pub enum AuthErrorImpl {
|
||||
/// Authentication error reported by the console.
|
||||
#[error(transparent)]
|
||||
Console(#[from] cloud::AuthError),
|
||||
Console(#[from] auth_backend::AuthError),
|
||||
|
||||
#[error(transparent)]
|
||||
GetAuthInfo(#[from] cloud::api::GetAuthInfoError),
|
||||
|
||||
#[error(transparent)]
|
||||
WakeCompute(#[from] cloud::api::WakeComputeError),
|
||||
GetAuthInfo(#[from] auth_backend::console::ConsoleAuthError),
|
||||
|
||||
#[error(transparent)]
|
||||
Sasl(#[from] crate::sasl::Error),
|
||||
@@ -40,19 +37,19 @@ pub enum AuthErrorImpl {
|
||||
|
||||
impl AuthErrorImpl {
|
||||
pub fn auth_failed(msg: impl Into<String>) -> Self {
|
||||
AuthErrorImpl::Console(cloud::AuthError::auth_failed(msg))
|
||||
AuthErrorImpl::Console(auth_backend::AuthError::auth_failed(msg))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<waiters::RegisterError> for AuthErrorImpl {
|
||||
fn from(e: waiters::RegisterError) -> Self {
|
||||
AuthErrorImpl::Console(cloud::AuthError::from(e))
|
||||
AuthErrorImpl::Console(auth_backend::AuthError::from(e))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<waiters::WaitError> for AuthErrorImpl {
|
||||
fn from(e: waiters::WaitError) -> Self {
|
||||
AuthErrorImpl::Console(cloud::AuthError::from(e))
|
||||
AuthErrorImpl::Console(auth_backend::AuthError::from(e))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -82,131 +79,25 @@ impl UserFacingError for AuthError {
|
||||
|
||||
async fn handle_user(
|
||||
config: &ProxyConfig,
|
||||
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
|
||||
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin + Send>,
|
||||
creds: ClientCredentials,
|
||||
) -> Result<compute::NodeInfo, AuthError> {
|
||||
if creds.is_existing_user() {
|
||||
match &config.cloud_endpoint {
|
||||
CloudApi::V1(api) => handle_existing_user_v1(api, client, creds).await,
|
||||
CloudApi::V2(api) => handle_existing_user_v2(api.as_ref(), client, creds).await,
|
||||
}
|
||||
} else {
|
||||
let redirect_uri = config.redirect_uri.as_ref();
|
||||
handle_new_user(redirect_uri, client).await
|
||||
}
|
||||
}
|
||||
|
||||
/// Authenticate user via a legacy cloud API endpoint.
|
||||
async fn handle_existing_user_v1(
|
||||
cloud: &cloud::Legacy,
|
||||
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
|
||||
creds: ClientCredentials,
|
||||
) -> Result<compute::NodeInfo, AuthError> {
|
||||
let psql_session_id = new_psql_session_id();
|
||||
let md5_salt = rand::random();
|
||||
|
||||
client
|
||||
.write_message(&Be::AuthenticationMD5Password(md5_salt))
|
||||
.await?;
|
||||
|
||||
// Read client's password hash
|
||||
let msg = client.read_password_message().await?;
|
||||
let md5_response = parse_password(&msg).ok_or(AuthErrorImpl::MalformedPassword)?;
|
||||
|
||||
let db_info = cloud
|
||||
.authenticate_proxy_client(creds, md5_response, &md5_salt, &psql_session_id)
|
||||
.await?;
|
||||
|
||||
client
|
||||
.write_message_noflush(&Be::AuthenticationOk)?
|
||||
.write_message_noflush(&BeParameterStatusMessage::encoding())?;
|
||||
|
||||
Ok(compute::NodeInfo {
|
||||
db_info,
|
||||
scram_keys: None,
|
||||
})
|
||||
}
|
||||
|
||||
/// Authenticate user via a new cloud API endpoint which supports SCRAM.
|
||||
async fn handle_existing_user_v2(
|
||||
cloud: &(impl cloud::Api + ?Sized),
|
||||
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
|
||||
creds: ClientCredentials,
|
||||
) -> Result<compute::NodeInfo, AuthError> {
|
||||
let auth_info = cloud.get_auth_info(&creds).await?;
|
||||
|
||||
let flow = AuthFlow::new(client);
|
||||
let scram_keys = match auth_info {
|
||||
cloud::api::AuthInfo::Md5(_) => {
|
||||
// TODO: decide if we should support MD5 in api v2
|
||||
return Err(AuthErrorImpl::auth_failed("MD5 is not supported").into());
|
||||
}
|
||||
cloud::api::AuthInfo::Scram(secret) => {
|
||||
let scram = Scram(&secret);
|
||||
Some(compute::ScramKeys {
|
||||
client_key: flow.begin(scram).await?.authenticate().await?.as_bytes(),
|
||||
server_key: secret.server_key.as_bytes(),
|
||||
})
|
||||
}
|
||||
};
|
||||
|
||||
client
|
||||
.write_message_noflush(&Be::AuthenticationOk)?
|
||||
.write_message_noflush(&BeParameterStatusMessage::encoding())?;
|
||||
|
||||
Ok(compute::NodeInfo {
|
||||
db_info: cloud.wake_compute(&creds).await?,
|
||||
scram_keys,
|
||||
})
|
||||
}
|
||||
|
||||
async fn handle_new_user(
|
||||
redirect_uri: &str,
|
||||
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
|
||||
) -> Result<compute::NodeInfo, AuthError> {
|
||||
let psql_session_id = new_psql_session_id();
|
||||
let greeting = hello_message(redirect_uri, &psql_session_id);
|
||||
|
||||
let db_info = cloud::with_waiter(psql_session_id, |waiter| async {
|
||||
// Give user a URL to spawn a new database
|
||||
client
|
||||
.write_message_noflush(&Be::AuthenticationOk)?
|
||||
.write_message_noflush(&BeParameterStatusMessage::encoding())?
|
||||
.write_message(&Be::NoticeResponse(&greeting))
|
||||
.await?;
|
||||
|
||||
// Wait for web console response (see `mgmt`)
|
||||
waiter.await?.map_err(AuthErrorImpl::auth_failed)
|
||||
})
|
||||
.await?;
|
||||
|
||||
client.write_message_noflush(&Be::NoticeResponse("Connecting to database."))?;
|
||||
|
||||
Ok(compute::NodeInfo {
|
||||
db_info,
|
||||
scram_keys: None,
|
||||
})
|
||||
}
|
||||
|
||||
fn new_psql_session_id() -> String {
|
||||
hex::encode(rand::random::<[u8; 8]>())
|
||||
}
|
||||
|
||||
fn parse_password(bytes: &[u8]) -> Option<&str> {
|
||||
std::str::from_utf8(bytes).ok()?.strip_suffix('\0')
|
||||
}
|
||||
|
||||
fn hello_message(redirect_uri: &str, session_id: &str) -> String {
|
||||
format!(
|
||||
concat![
|
||||
"☀️ Welcome to Neon!\n",
|
||||
"To proceed with database creation, open the following link:\n\n",
|
||||
" {redirect_uri}{session_id}\n\n",
|
||||
"It needs to be done once and we will send you '.pgpass' file,\n",
|
||||
"which will allow you to access or create ",
|
||||
"databases without opening your web browser."
|
||||
],
|
||||
redirect_uri = redirect_uri,
|
||||
session_id = session_id,
|
||||
match config.auth_backend {
|
||||
AuthBackendType::LegacyConsole => {
|
||||
legacy_console::handle_user(
|
||||
&config.auth_endpoint,
|
||||
&config.auth_link_uri,
|
||||
client,
|
||||
&creds,
|
||||
)
|
||||
.await
|
||||
}
|
||||
AuthBackendType::Console => {
|
||||
console::handle_user(config.auth_endpoint.as_ref(), client, &creds).await
|
||||
}
|
||||
AuthBackendType::Postgres => {
|
||||
postgres::handle_user(&config.auth_endpoint, client, &creds).await
|
||||
}
|
||||
AuthBackendType::Link => link::handle_user(config.auth_link_uri.as_ref(), client).await,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -23,6 +23,10 @@ impl UserFacingError for ClientCredsParseError {}
|
||||
pub struct ClientCredentials {
|
||||
pub user: String,
|
||||
pub dbname: String,
|
||||
|
||||
// New console API requires SNI info to determine cluster name.
|
||||
// Other Auth backends don't need it.
|
||||
pub sni_cluster: Option<String>,
|
||||
}
|
||||
|
||||
impl ClientCredentials {
|
||||
@@ -45,7 +49,11 @@ impl TryFrom<HashMap<String, String>> for ClientCredentials {
|
||||
let user = get_param("user")?;
|
||||
let db = get_param("database")?;
|
||||
|
||||
Ok(Self { user, dbname: db })
|
||||
Ok(Self {
|
||||
user,
|
||||
dbname: db,
|
||||
sni_cluster: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -54,7 +62,7 @@ impl ClientCredentials {
|
||||
pub async fn authenticate(
|
||||
self,
|
||||
config: &ProxyConfig,
|
||||
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
|
||||
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin + Send>,
|
||||
) -> Result<compute::NodeInfo, AuthError> {
|
||||
// This method is just a convenient facade for `handle_user`
|
||||
super::handle_user(config, client, self).await
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
mod local;
|
||||
pub mod console;
|
||||
pub mod legacy_console;
|
||||
pub mod link;
|
||||
pub mod postgres;
|
||||
|
||||
mod legacy;
|
||||
pub use legacy::{AuthError, AuthErrorImpl, Legacy};
|
||||
|
||||
pub mod api;
|
||||
pub use api::{Api, BoxedApi};
|
||||
pub use legacy_console::{AuthError, AuthErrorImpl};
|
||||
|
||||
use crate::mgmt;
|
||||
use crate::waiters::{self, Waiter, Waiters};
|
||||
@@ -30,17 +29,3 @@ where
|
||||
pub fn notify(psql_session_id: &str, msg: mgmt::ComputeReady) -> Result<(), waiters::NotifyError> {
|
||||
CPLANE_WAITERS.notify(psql_session_id, msg)
|
||||
}
|
||||
|
||||
/// Construct a new opaque cloud API provider.
|
||||
pub fn new(url: reqwest::Url) -> anyhow::Result<BoxedApi> {
|
||||
Ok(match url.scheme() {
|
||||
"https" | "http" => {
|
||||
todo!("build a real cloud wrapper")
|
||||
}
|
||||
"postgresql" | "postgres" | "pg" => {
|
||||
// Just point to a local running postgres instance.
|
||||
Box::new(local::Local { url })
|
||||
}
|
||||
other => anyhow::bail!("unsupported url scheme: {other}"),
|
||||
})
|
||||
}
|
||||
236
proxy/src/auth_backend/console.rs
Normal file
236
proxy/src/auth_backend/console.rs
Normal file
@@ -0,0 +1,236 @@
|
||||
//! Declaration of Cloud API V2.
|
||||
|
||||
use crate::{
|
||||
auth::{self, AuthFlow},
|
||||
compute, scram,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use thiserror::Error;
|
||||
|
||||
use crate::auth::ClientCredentials;
|
||||
use crate::stream::PqStream;
|
||||
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use utils::pq_proto::{BeMessage as Be, BeParameterStatusMessage};
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum ConsoleAuthError {
|
||||
// We shouldn't include the actual secret here.
|
||||
#[error("Bad authentication secret")]
|
||||
BadSecret,
|
||||
|
||||
#[error("Bad client credentials: {0:?}")]
|
||||
BadCredentials(crate::auth::ClientCredentials),
|
||||
|
||||
/// For passwords that couldn't be processed by [`parse_password`].
|
||||
#[error("Absend SNI information")]
|
||||
SniMissing,
|
||||
|
||||
#[error(transparent)]
|
||||
BadUrl(#[from] url::ParseError),
|
||||
|
||||
#[error(transparent)]
|
||||
Io(#[from] std::io::Error),
|
||||
|
||||
/// HTTP status (other than 200) returned by the console.
|
||||
#[error("Console responded with an HTTP status: {0}")]
|
||||
HttpStatus(reqwest::StatusCode),
|
||||
|
||||
#[error(transparent)]
|
||||
Transport(#[from] reqwest::Error),
|
||||
|
||||
#[error("Console responded with a malformed JSON: '{0}'")]
|
||||
MalformedResponse(#[from] serde_json::Error),
|
||||
|
||||
#[error("Console responded with a malformed compute address: '{0}'")]
|
||||
MalformedComputeAddress(String),
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
struct GetRoleSecretResponse {
|
||||
role_secret: String,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
struct GetWakeComputeResponse {
|
||||
address: String,
|
||||
}
|
||||
|
||||
/// Auth secret which is managed by the cloud.
|
||||
pub enum AuthInfo {
|
||||
/// Md5 hash of user's password.
|
||||
Md5([u8; 16]),
|
||||
/// [SCRAM](crate::scram) authentication info.
|
||||
Scram(scram::ServerSecret),
|
||||
}
|
||||
|
||||
/// Compute node connection params provided by the cloud.
|
||||
/// Note how it implements serde traits, since we receive it over the wire.
|
||||
#[derive(Serialize, Deserialize, Default)]
|
||||
pub struct DatabaseInfo {
|
||||
pub host: String,
|
||||
pub port: u16,
|
||||
pub dbname: String,
|
||||
pub user: String,
|
||||
|
||||
/// [Cloud API V1](super::legacy) returns cleartext password,
|
||||
/// but [Cloud API V2](super::api) implements [SCRAM](crate::scram)
|
||||
/// authentication, so we can leverage this method and cope without password.
|
||||
pub password: Option<String>,
|
||||
}
|
||||
|
||||
// Manually implement debug to omit personal and sensitive info.
|
||||
impl std::fmt::Debug for DatabaseInfo {
|
||||
fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||
fmt.debug_struct("DatabaseInfo")
|
||||
.field("host", &self.host)
|
||||
.field("port", &self.port)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl From<DatabaseInfo> for tokio_postgres::Config {
|
||||
fn from(db_info: DatabaseInfo) -> Self {
|
||||
let mut config = tokio_postgres::Config::new();
|
||||
|
||||
config
|
||||
.host(&db_info.host)
|
||||
.port(db_info.port)
|
||||
.dbname(&db_info.dbname)
|
||||
.user(&db_info.user);
|
||||
|
||||
if let Some(password) = db_info.password {
|
||||
config.password(password);
|
||||
}
|
||||
|
||||
config
|
||||
}
|
||||
}
|
||||
|
||||
async fn get_auth_info(
|
||||
auth_endpoint: &str,
|
||||
user: &str,
|
||||
cluster: &str,
|
||||
) -> Result<AuthInfo, ConsoleAuthError> {
|
||||
let mut url = reqwest::Url::parse(&format!("{auth_endpoint}/proxy_get_role_secret"))?;
|
||||
|
||||
url.query_pairs_mut()
|
||||
.append_pair("cluster", cluster)
|
||||
.append_pair("role", user);
|
||||
|
||||
// TODO: use a proper logger
|
||||
println!("cplane request: {}", url);
|
||||
|
||||
let resp = reqwest::get(url).await?;
|
||||
if !resp.status().is_success() {
|
||||
return Err(ConsoleAuthError::HttpStatus(resp.status()));
|
||||
}
|
||||
|
||||
let response: GetRoleSecretResponse = serde_json::from_str(resp.text().await?.as_str())?;
|
||||
|
||||
scram::ServerSecret::parse(response.role_secret.as_str())
|
||||
.map(AuthInfo::Scram)
|
||||
.ok_or(ConsoleAuthError::BadSecret)
|
||||
}
|
||||
|
||||
/// Wake up the compute node and return the corresponding connection info.
|
||||
async fn wake_compute(
|
||||
auth_endpoint: &str,
|
||||
cluster: &str,
|
||||
) -> Result<(String, u16), ConsoleAuthError> {
|
||||
let mut url = reqwest::Url::parse(&format!("{auth_endpoint}/proxy_wake_compute"))?;
|
||||
url.query_pairs_mut().append_pair("cluster", cluster);
|
||||
|
||||
// TODO: use a proper logger
|
||||
println!("cplane request: {}", url);
|
||||
|
||||
let resp = reqwest::get(url).await?;
|
||||
if !resp.status().is_success() {
|
||||
return Err(ConsoleAuthError::HttpStatus(resp.status()));
|
||||
}
|
||||
|
||||
let response: GetWakeComputeResponse = serde_json::from_str(resp.text().await?.as_str())?;
|
||||
let (host, port) = response
|
||||
.address
|
||||
.split_once(':')
|
||||
.ok_or_else(|| ConsoleAuthError::MalformedComputeAddress(response.address.clone()))?;
|
||||
let port: u16 = port
|
||||
.parse()
|
||||
.map_err(|_| ConsoleAuthError::MalformedComputeAddress(response.address.clone()))?;
|
||||
|
||||
Ok((host.to_string(), port))
|
||||
}
|
||||
|
||||
pub async fn handle_user(
|
||||
auth_endpoint: &str,
|
||||
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
|
||||
creds: &ClientCredentials,
|
||||
) -> Result<compute::NodeInfo, crate::auth::AuthError> {
|
||||
let cluster = creds
|
||||
.sni_cluster
|
||||
.as_ref()
|
||||
.ok_or(ConsoleAuthError::SniMissing)?;
|
||||
let user = creds.user.as_str();
|
||||
|
||||
// Step 1: get the auth secret
|
||||
let auth_info = get_auth_info(auth_endpoint, user, cluster).await?;
|
||||
|
||||
let flow = AuthFlow::new(client);
|
||||
let scram_keys = match auth_info {
|
||||
AuthInfo::Md5(_) => {
|
||||
// TODO: decide if we should support MD5 in api v2
|
||||
return Err(crate::auth::AuthErrorImpl::auth_failed("MD5 is not supported").into());
|
||||
}
|
||||
AuthInfo::Scram(secret) => {
|
||||
let scram = auth::Scram(&secret);
|
||||
Some(compute::ScramKeys {
|
||||
client_key: flow.begin(scram).await?.authenticate().await?.as_bytes(),
|
||||
server_key: secret.server_key.as_bytes(),
|
||||
})
|
||||
}
|
||||
};
|
||||
|
||||
client
|
||||
.write_message_noflush(&Be::AuthenticationOk)?
|
||||
.write_message_noflush(&BeParameterStatusMessage::encoding())?;
|
||||
|
||||
// Step 2: wake compute
|
||||
let (host, port) = wake_compute(auth_endpoint, cluster).await?;
|
||||
|
||||
Ok(compute::NodeInfo {
|
||||
db_info: DatabaseInfo {
|
||||
host,
|
||||
port,
|
||||
dbname: creds.dbname.clone(),
|
||||
user: creds.user.clone(),
|
||||
password: None,
|
||||
},
|
||||
scram_keys,
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use serde_json::json;
|
||||
|
||||
#[test]
|
||||
fn parse_db_info() -> anyhow::Result<()> {
|
||||
let _: DatabaseInfo = serde_json::from_value(json!({
|
||||
"host": "localhost",
|
||||
"port": 5432,
|
||||
"dbname": "postgres",
|
||||
"user": "john_doe",
|
||||
"password": "password",
|
||||
}))?;
|
||||
|
||||
let _: DatabaseInfo = serde_json::from_value(json!({
|
||||
"host": "localhost",
|
||||
"port": 5432,
|
||||
"dbname": "postgres",
|
||||
"user": "john_doe",
|
||||
}))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
206
proxy/src/auth_backend/legacy_console.rs
Normal file
206
proxy/src/auth_backend/legacy_console.rs
Normal file
@@ -0,0 +1,206 @@
|
||||
//! Cloud API V1.
|
||||
|
||||
use super::console::DatabaseInfo;
|
||||
|
||||
use crate::auth::ClientCredentials;
|
||||
use crate::stream::PqStream;
|
||||
|
||||
use crate::{compute, waiters};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use utils::pq_proto::{BeMessage as Be, BeParameterStatusMessage};
|
||||
|
||||
use thiserror::Error;
|
||||
|
||||
use crate::error::UserFacingError;
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum AuthErrorImpl {
|
||||
/// Authentication error reported by the console.
|
||||
#[error("Authentication failed: {0}")]
|
||||
AuthFailed(String),
|
||||
|
||||
/// HTTP status (other than 200) returned by the console.
|
||||
#[error("Console responded with an HTTP status: {0}")]
|
||||
HttpStatus(reqwest::StatusCode),
|
||||
|
||||
#[error("Console responded with a malformed JSON: {0}")]
|
||||
MalformedResponse(#[from] serde_json::Error),
|
||||
|
||||
#[error(transparent)]
|
||||
Transport(#[from] reqwest::Error),
|
||||
|
||||
#[error(transparent)]
|
||||
WaiterRegister(#[from] waiters::RegisterError),
|
||||
|
||||
#[error(transparent)]
|
||||
WaiterWait(#[from] waiters::WaitError),
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
#[error(transparent)]
|
||||
pub struct AuthError(Box<AuthErrorImpl>);
|
||||
|
||||
impl AuthError {
|
||||
/// Smart constructor for authentication error reported by `mgmt`.
|
||||
pub fn auth_failed(msg: impl Into<String>) -> Self {
|
||||
AuthError(Box::new(AuthErrorImpl::AuthFailed(msg.into())))
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> From<T> for AuthError
|
||||
where
|
||||
AuthErrorImpl: From<T>,
|
||||
{
|
||||
fn from(e: T) -> Self {
|
||||
AuthError(Box::new(e.into()))
|
||||
}
|
||||
}
|
||||
|
||||
impl UserFacingError for AuthError {
|
||||
fn to_string_client(&self) -> String {
|
||||
use AuthErrorImpl::*;
|
||||
match self.0.as_ref() {
|
||||
AuthFailed(_) | HttpStatus(_) => self.to_string(),
|
||||
_ => "Internal error".to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// NOTE: the order of constructors is important.
|
||||
// https://serde.rs/enum-representations.html#untagged
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
#[serde(untagged)]
|
||||
enum ProxyAuthResponse {
|
||||
Ready { conn_info: DatabaseInfo },
|
||||
Error { error: String },
|
||||
NotReady { ready: bool }, // TODO: get rid of `ready`
|
||||
}
|
||||
|
||||
async fn authenticate_proxy_client(
|
||||
auth_endpoint: &reqwest::Url,
|
||||
creds: &ClientCredentials,
|
||||
md5_response: &str,
|
||||
salt: &[u8; 4],
|
||||
psql_session_id: &str,
|
||||
) -> Result<DatabaseInfo, AuthError> {
|
||||
let mut url = auth_endpoint.clone();
|
||||
url.query_pairs_mut()
|
||||
.append_pair("login", &creds.user)
|
||||
.append_pair("database", &creds.dbname)
|
||||
.append_pair("md5response", md5_response)
|
||||
.append_pair("salt", &hex::encode(salt))
|
||||
.append_pair("psql_session_id", psql_session_id);
|
||||
|
||||
super::with_waiter(psql_session_id, |waiter| async {
|
||||
println!("cloud request: {}", url);
|
||||
// TODO: leverage `reqwest::Client` to reuse connections
|
||||
let resp = reqwest::get(url).await?;
|
||||
if !resp.status().is_success() {
|
||||
return Err(AuthErrorImpl::HttpStatus(resp.status()).into());
|
||||
}
|
||||
|
||||
let auth_info: ProxyAuthResponse = serde_json::from_str(resp.text().await?.as_str())?;
|
||||
println!("got auth info: #{:?}", auth_info);
|
||||
|
||||
use ProxyAuthResponse::*;
|
||||
let db_info = match auth_info {
|
||||
Ready { conn_info } => conn_info,
|
||||
Error { error } => return Err(AuthErrorImpl::AuthFailed(error).into()),
|
||||
NotReady { .. } => waiter.await?.map_err(AuthErrorImpl::AuthFailed)?,
|
||||
};
|
||||
|
||||
Ok(db_info)
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
async fn handle_existing_user(
|
||||
auth_endpoint: &reqwest::Url,
|
||||
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin + Send>,
|
||||
creds: &ClientCredentials,
|
||||
) -> Result<crate::compute::NodeInfo, crate::auth::AuthError> {
|
||||
let psql_session_id = super::link::new_psql_session_id();
|
||||
let md5_salt = rand::random();
|
||||
|
||||
client
|
||||
.write_message(&Be::AuthenticationMD5Password(md5_salt))
|
||||
.await?;
|
||||
|
||||
// Read client's password hash
|
||||
let msg = client.read_password_message().await?;
|
||||
let md5_response = parse_password(&msg).ok_or(crate::auth::AuthErrorImpl::MalformedPassword)?;
|
||||
|
||||
let db_info = authenticate_proxy_client(
|
||||
auth_endpoint,
|
||||
creds,
|
||||
md5_response,
|
||||
&md5_salt,
|
||||
&psql_session_id,
|
||||
)
|
||||
.await?;
|
||||
|
||||
client
|
||||
.write_message_noflush(&Be::AuthenticationOk)?
|
||||
.write_message_noflush(&BeParameterStatusMessage::encoding())?;
|
||||
|
||||
Ok(compute::NodeInfo {
|
||||
db_info,
|
||||
scram_keys: None,
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn handle_user(
|
||||
auth_endpoint: &reqwest::Url,
|
||||
auth_link_uri: &reqwest::Url,
|
||||
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin + Send>,
|
||||
creds: &ClientCredentials,
|
||||
) -> Result<crate::compute::NodeInfo, crate::auth::AuthError> {
|
||||
if creds.is_existing_user() {
|
||||
handle_existing_user(auth_endpoint, client, creds).await
|
||||
} else {
|
||||
super::link::handle_user(auth_link_uri.as_ref(), client).await
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_password(bytes: &[u8]) -> Option<&str> {
|
||||
std::str::from_utf8(bytes).ok()?.strip_suffix('\0')
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use serde_json::json;
|
||||
|
||||
#[test]
|
||||
fn test_proxy_auth_response() {
|
||||
// Ready
|
||||
let auth: ProxyAuthResponse = serde_json::from_value(json!({
|
||||
"ready": true,
|
||||
"conn_info": DatabaseInfo::default(),
|
||||
}))
|
||||
.unwrap();
|
||||
assert!(matches!(
|
||||
auth,
|
||||
ProxyAuthResponse::Ready {
|
||||
conn_info: DatabaseInfo { .. }
|
||||
}
|
||||
));
|
||||
|
||||
// Error
|
||||
let auth: ProxyAuthResponse = serde_json::from_value(json!({
|
||||
"ready": false,
|
||||
"error": "too bad, so sad",
|
||||
}))
|
||||
.unwrap();
|
||||
assert!(matches!(auth, ProxyAuthResponse::Error { .. }));
|
||||
|
||||
// NotReady
|
||||
let auth: ProxyAuthResponse = serde_json::from_value(json!({
|
||||
"ready": false,
|
||||
}))
|
||||
.unwrap();
|
||||
assert!(matches!(auth, ProxyAuthResponse::NotReady { .. }));
|
||||
}
|
||||
}
|
||||
52
proxy/src/auth_backend/link.rs
Normal file
52
proxy/src/auth_backend/link.rs
Normal file
@@ -0,0 +1,52 @@
|
||||
use crate::{compute, stream::PqStream};
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use utils::pq_proto::{BeMessage as Be, BeParameterStatusMessage};
|
||||
|
||||
fn hello_message(redirect_uri: &str, session_id: &str) -> String {
|
||||
format!(
|
||||
concat![
|
||||
"☀️ Welcome to Neon!\n",
|
||||
"To proceed with database creation, open the following link:\n\n",
|
||||
" {redirect_uri}{session_id}\n\n",
|
||||
"It needs to be done once and we will send you '.pgpass' file,\n",
|
||||
"which will allow you to access or create ",
|
||||
"databases without opening your web browser."
|
||||
],
|
||||
redirect_uri = redirect_uri,
|
||||
session_id = session_id,
|
||||
)
|
||||
}
|
||||
|
||||
pub fn new_psql_session_id() -> String {
|
||||
hex::encode(rand::random::<[u8; 8]>())
|
||||
}
|
||||
|
||||
pub async fn handle_user(
|
||||
redirect_uri: &str,
|
||||
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
|
||||
) -> Result<compute::NodeInfo, crate::auth::AuthError> {
|
||||
let psql_session_id = new_psql_session_id();
|
||||
let greeting = hello_message(redirect_uri, &psql_session_id);
|
||||
|
||||
let db_info = crate::auth_backend::with_waiter(psql_session_id, |waiter| async {
|
||||
// Give user a URL to spawn a new database
|
||||
client
|
||||
.write_message_noflush(&Be::AuthenticationOk)?
|
||||
.write_message_noflush(&BeParameterStatusMessage::encoding())?
|
||||
.write_message(&Be::NoticeResponse(&greeting))
|
||||
.await?;
|
||||
|
||||
// Wait for web console response (see `mgmt`)
|
||||
waiter
|
||||
.await?
|
||||
.map_err(crate::auth::AuthErrorImpl::auth_failed)
|
||||
})
|
||||
.await?;
|
||||
|
||||
client.write_message_noflush(&Be::NoticeResponse("Connecting to database."))?;
|
||||
|
||||
Ok(compute::NodeInfo {
|
||||
db_info,
|
||||
scram_keys: None,
|
||||
})
|
||||
}
|
||||
93
proxy/src/auth_backend/postgres.rs
Normal file
93
proxy/src/auth_backend/postgres.rs
Normal file
@@ -0,0 +1,93 @@
|
||||
//! Local mock of Cloud API V2.
|
||||
|
||||
use super::console::{self, AuthInfo, DatabaseInfo};
|
||||
use crate::scram;
|
||||
use crate::{auth::ClientCredentials, compute};
|
||||
|
||||
use crate::stream::PqStream;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use utils::pq_proto::{BeMessage as Be, BeParameterStatusMessage};
|
||||
|
||||
async fn get_auth_info(
|
||||
auth_endpoint: &str,
|
||||
creds: &ClientCredentials,
|
||||
) -> Result<AuthInfo, console::ConsoleAuthError> {
|
||||
// We wrap `tokio_postgres::Error` because we don't want to infect the
|
||||
// method's error type with a detail that's specific to debug mode only.
|
||||
let io_error = |e| std::io::Error::new(std::io::ErrorKind::Other, e);
|
||||
|
||||
// Perhaps we could persist this connection, but then we'd have to
|
||||
// write more code for reopening it if it got closed, which doesn't
|
||||
// seem worth it.
|
||||
let (client, connection) = tokio_postgres::connect(auth_endpoint, tokio_postgres::NoTls)
|
||||
.await
|
||||
.map_err(io_error)?;
|
||||
|
||||
tokio::spawn(connection);
|
||||
let query = "select rolpassword from pg_catalog.pg_authid where rolname = $1";
|
||||
let rows = client
|
||||
.query(query, &[&creds.user])
|
||||
.await
|
||||
.map_err(io_error)?;
|
||||
|
||||
match &rows[..] {
|
||||
// We can't get a secret if there's no such user.
|
||||
[] => Err(console::ConsoleAuthError::BadCredentials(creds.to_owned())),
|
||||
// We shouldn't get more than one row anyway.
|
||||
[row, ..] => {
|
||||
let entry = row.try_get(0).map_err(io_error)?;
|
||||
scram::ServerSecret::parse(entry)
|
||||
.map(AuthInfo::Scram)
|
||||
.or_else(|| {
|
||||
// It could be an md5 hash if it's not a SCRAM secret.
|
||||
let text = entry.strip_prefix("md5")?;
|
||||
Some(AuthInfo::Md5({
|
||||
let mut bytes = [0u8; 16];
|
||||
hex::decode_to_slice(text, &mut bytes).ok()?;
|
||||
bytes
|
||||
}))
|
||||
})
|
||||
// Putting the secret into this message is a security hazard!
|
||||
.ok_or(console::ConsoleAuthError::BadSecret)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn handle_user(
|
||||
auth_endpoint: &reqwest::Url,
|
||||
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
|
||||
creds: &ClientCredentials,
|
||||
) -> Result<compute::NodeInfo, crate::auth::AuthError> {
|
||||
let auth_info = get_auth_info(auth_endpoint.as_ref(), creds).await?;
|
||||
|
||||
let flow = crate::auth::AuthFlow::new(client);
|
||||
let scram_keys = match auth_info {
|
||||
AuthInfo::Md5(_) => {
|
||||
// TODO: decide if we should support MD5 in api v2
|
||||
return Err(crate::auth::AuthErrorImpl::auth_failed("MD5 is not supported").into());
|
||||
}
|
||||
AuthInfo::Scram(secret) => {
|
||||
let scram = crate::auth::Scram(&secret);
|
||||
Some(compute::ScramKeys {
|
||||
client_key: flow.begin(scram).await?.authenticate().await?.as_bytes(),
|
||||
server_key: secret.server_key.as_bytes(),
|
||||
})
|
||||
}
|
||||
};
|
||||
|
||||
client
|
||||
.write_message_noflush(&Be::AuthenticationOk)?
|
||||
.write_message_noflush(&BeParameterStatusMessage::encoding())?;
|
||||
|
||||
Ok(compute::NodeInfo {
|
||||
db_info: DatabaseInfo {
|
||||
// TODO: handle that near CLI params parsing
|
||||
host: auth_endpoint.host_str().unwrap_or("localhost").to_owned(),
|
||||
port: auth_endpoint.port().unwrap_or(5432),
|
||||
dbname: creds.dbname.to_owned(),
|
||||
user: creds.user.to_owned(),
|
||||
password: None,
|
||||
},
|
||||
scram_keys,
|
||||
})
|
||||
}
|
||||
@@ -1,120 +0,0 @@
|
||||
//! Declaration of Cloud API V2.
|
||||
|
||||
use crate::{auth, scram};
|
||||
use async_trait::async_trait;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use thiserror::Error;
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum GetAuthInfoError {
|
||||
// We shouldn't include the actual secret here.
|
||||
#[error("Bad authentication secret")]
|
||||
BadSecret,
|
||||
|
||||
#[error("Bad client credentials: {0:?}")]
|
||||
BadCredentials(crate::auth::ClientCredentials),
|
||||
|
||||
#[error(transparent)]
|
||||
Io(#[from] std::io::Error),
|
||||
}
|
||||
|
||||
// TODO: convert to an enum and describe possible sub-errors (see above)
|
||||
#[derive(Debug, Error)]
|
||||
#[error("Failed to wake up the compute node")]
|
||||
pub struct WakeComputeError;
|
||||
|
||||
/// Opaque implementation of Cloud API.
|
||||
pub type BoxedApi = Box<dyn Api + Send + Sync>;
|
||||
|
||||
/// Cloud API methods required by the proxy.
|
||||
#[async_trait]
|
||||
pub trait Api {
|
||||
/// Get authentication information for the given user.
|
||||
async fn get_auth_info(
|
||||
&self,
|
||||
creds: &auth::ClientCredentials,
|
||||
) -> Result<AuthInfo, GetAuthInfoError>;
|
||||
|
||||
/// Wake up the compute node and return the corresponding connection info.
|
||||
async fn wake_compute(
|
||||
&self,
|
||||
creds: &auth::ClientCredentials,
|
||||
) -> Result<DatabaseInfo, WakeComputeError>;
|
||||
}
|
||||
|
||||
/// Auth secret which is managed by the cloud.
|
||||
pub enum AuthInfo {
|
||||
/// Md5 hash of user's password.
|
||||
Md5([u8; 16]),
|
||||
/// [SCRAM](crate::scram) authentication info.
|
||||
Scram(scram::ServerSecret),
|
||||
}
|
||||
|
||||
/// Compute node connection params provided by the cloud.
|
||||
/// Note how it implements serde traits, since we receive it over the wire.
|
||||
#[derive(Serialize, Deserialize, Default)]
|
||||
pub struct DatabaseInfo {
|
||||
pub host: String,
|
||||
pub port: u16,
|
||||
pub dbname: String,
|
||||
pub user: String,
|
||||
|
||||
/// [Cloud API V1](super::legacy) returns cleartext password,
|
||||
/// but [Cloud API V2](super::api) implements [SCRAM](crate::scram)
|
||||
/// authentication, so we can leverage this method and cope without password.
|
||||
pub password: Option<String>,
|
||||
}
|
||||
|
||||
// Manually implement debug to omit personal and sensitive info.
|
||||
impl std::fmt::Debug for DatabaseInfo {
|
||||
fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||
fmt.debug_struct("DatabaseInfo")
|
||||
.field("host", &self.host)
|
||||
.field("port", &self.port)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl From<DatabaseInfo> for tokio_postgres::Config {
|
||||
fn from(db_info: DatabaseInfo) -> Self {
|
||||
let mut config = tokio_postgres::Config::new();
|
||||
|
||||
config
|
||||
.host(&db_info.host)
|
||||
.port(db_info.port)
|
||||
.dbname(&db_info.dbname)
|
||||
.user(&db_info.user);
|
||||
|
||||
if let Some(password) = db_info.password {
|
||||
config.password(password);
|
||||
}
|
||||
|
||||
config
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use serde_json::json;
|
||||
|
||||
#[test]
|
||||
fn parse_db_info() -> anyhow::Result<()> {
|
||||
let _: DatabaseInfo = serde_json::from_value(json!({
|
||||
"host": "localhost",
|
||||
"port": 5432,
|
||||
"dbname": "postgres",
|
||||
"user": "john_doe",
|
||||
"password": "password",
|
||||
}))?;
|
||||
|
||||
let _: DatabaseInfo = serde_json::from_value(json!({
|
||||
"host": "localhost",
|
||||
"port": 5432,
|
||||
"dbname": "postgres",
|
||||
"user": "john_doe",
|
||||
}))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -1,160 +0,0 @@
|
||||
//! Cloud API V1.
|
||||
|
||||
use super::api::DatabaseInfo;
|
||||
use crate::auth::ClientCredentials;
|
||||
use crate::error::UserFacingError;
|
||||
use crate::waiters;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use thiserror::Error;
|
||||
|
||||
/// Neon cloud API provider.
|
||||
pub struct Legacy {
|
||||
auth_endpoint: reqwest::Url,
|
||||
}
|
||||
|
||||
impl Legacy {
|
||||
/// Construct a new legacy cloud API provider.
|
||||
pub fn new(auth_endpoint: reqwest::Url) -> Self {
|
||||
Self { auth_endpoint }
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum AuthErrorImpl {
|
||||
/// Authentication error reported by the console.
|
||||
#[error("Authentication failed: {0}")]
|
||||
AuthFailed(String),
|
||||
|
||||
/// HTTP status (other than 200) returned by the console.
|
||||
#[error("Console responded with an HTTP status: {0}")]
|
||||
HttpStatus(reqwest::StatusCode),
|
||||
|
||||
#[error("Console responded with a malformed JSON: {0}")]
|
||||
MalformedResponse(#[from] serde_json::Error),
|
||||
|
||||
#[error(transparent)]
|
||||
Transport(#[from] reqwest::Error),
|
||||
|
||||
#[error(transparent)]
|
||||
WaiterRegister(#[from] waiters::RegisterError),
|
||||
|
||||
#[error(transparent)]
|
||||
WaiterWait(#[from] waiters::WaitError),
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
#[error(transparent)]
|
||||
pub struct AuthError(Box<AuthErrorImpl>);
|
||||
|
||||
impl AuthError {
|
||||
/// Smart constructor for authentication error reported by `mgmt`.
|
||||
pub fn auth_failed(msg: impl Into<String>) -> Self {
|
||||
AuthError(Box::new(AuthErrorImpl::AuthFailed(msg.into())))
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> From<T> for AuthError
|
||||
where
|
||||
AuthErrorImpl: From<T>,
|
||||
{
|
||||
fn from(e: T) -> Self {
|
||||
AuthError(Box::new(e.into()))
|
||||
}
|
||||
}
|
||||
|
||||
impl UserFacingError for AuthError {
|
||||
fn to_string_client(&self) -> String {
|
||||
use AuthErrorImpl::*;
|
||||
match self.0.as_ref() {
|
||||
AuthFailed(_) | HttpStatus(_) => self.to_string(),
|
||||
_ => "Internal error".to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// NOTE: the order of constructors is important.
|
||||
// https://serde.rs/enum-representations.html#untagged
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
#[serde(untagged)]
|
||||
enum ProxyAuthResponse {
|
||||
Ready { conn_info: DatabaseInfo },
|
||||
Error { error: String },
|
||||
NotReady { ready: bool }, // TODO: get rid of `ready`
|
||||
}
|
||||
|
||||
impl Legacy {
|
||||
pub async fn authenticate_proxy_client(
|
||||
&self,
|
||||
creds: ClientCredentials,
|
||||
md5_response: &str,
|
||||
salt: &[u8; 4],
|
||||
psql_session_id: &str,
|
||||
) -> Result<DatabaseInfo, AuthError> {
|
||||
let mut url = self.auth_endpoint.clone();
|
||||
url.query_pairs_mut()
|
||||
.append_pair("login", &creds.user)
|
||||
.append_pair("database", &creds.dbname)
|
||||
.append_pair("md5response", md5_response)
|
||||
.append_pair("salt", &hex::encode(salt))
|
||||
.append_pair("psql_session_id", psql_session_id);
|
||||
|
||||
super::with_waiter(psql_session_id, |waiter| async {
|
||||
println!("cloud request: {}", url);
|
||||
// TODO: leverage `reqwest::Client` to reuse connections
|
||||
let resp = reqwest::get(url).await?;
|
||||
if !resp.status().is_success() {
|
||||
return Err(AuthErrorImpl::HttpStatus(resp.status()).into());
|
||||
}
|
||||
|
||||
let auth_info: ProxyAuthResponse = serde_json::from_str(resp.text().await?.as_str())?;
|
||||
println!("got auth info: #{:?}", auth_info);
|
||||
|
||||
use ProxyAuthResponse::*;
|
||||
let db_info = match auth_info {
|
||||
Ready { conn_info } => conn_info,
|
||||
Error { error } => return Err(AuthErrorImpl::AuthFailed(error).into()),
|
||||
NotReady { .. } => waiter.await?.map_err(AuthErrorImpl::AuthFailed)?,
|
||||
};
|
||||
|
||||
Ok(db_info)
|
||||
})
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use serde_json::json;
|
||||
|
||||
#[test]
|
||||
fn test_proxy_auth_response() {
|
||||
// Ready
|
||||
let auth: ProxyAuthResponse = serde_json::from_value(json!({
|
||||
"ready": true,
|
||||
"conn_info": DatabaseInfo::default(),
|
||||
}))
|
||||
.unwrap();
|
||||
assert!(matches!(
|
||||
auth,
|
||||
ProxyAuthResponse::Ready {
|
||||
conn_info: DatabaseInfo { .. }
|
||||
}
|
||||
));
|
||||
|
||||
// Error
|
||||
let auth: ProxyAuthResponse = serde_json::from_value(json!({
|
||||
"ready": false,
|
||||
"error": "too bad, so sad",
|
||||
}))
|
||||
.unwrap();
|
||||
assert!(matches!(auth, ProxyAuthResponse::Error { .. }));
|
||||
|
||||
// NotReady
|
||||
let auth: ProxyAuthResponse = serde_json::from_value(json!({
|
||||
"ready": false,
|
||||
}))
|
||||
.unwrap();
|
||||
assert!(matches!(auth, ProxyAuthResponse::NotReady { .. }));
|
||||
}
|
||||
}
|
||||
@@ -1,76 +0,0 @@
|
||||
//! Local mock of Cloud API V2.
|
||||
|
||||
use super::api::{self, Api, AuthInfo, DatabaseInfo};
|
||||
use crate::auth::ClientCredentials;
|
||||
use crate::scram;
|
||||
use async_trait::async_trait;
|
||||
|
||||
/// Mocked cloud for testing purposes.
|
||||
pub struct Local {
|
||||
/// Database url, e.g. `postgres://user:password@localhost:5432/database`.
|
||||
pub url: reqwest::Url,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Api for Local {
|
||||
async fn get_auth_info(
|
||||
&self,
|
||||
creds: &ClientCredentials,
|
||||
) -> Result<AuthInfo, api::GetAuthInfoError> {
|
||||
// We wrap `tokio_postgres::Error` because we don't want to infect the
|
||||
// method's error type with a detail that's specific to debug mode only.
|
||||
let io_error = |e| std::io::Error::new(std::io::ErrorKind::Other, e);
|
||||
|
||||
// Perhaps we could persist this connection, but then we'd have to
|
||||
// write more code for reopening it if it got closed, which doesn't
|
||||
// seem worth it.
|
||||
let (client, connection) =
|
||||
tokio_postgres::connect(self.url.as_str(), tokio_postgres::NoTls)
|
||||
.await
|
||||
.map_err(io_error)?;
|
||||
|
||||
tokio::spawn(connection);
|
||||
let query = "select rolpassword from pg_catalog.pg_authid where rolname = $1";
|
||||
let rows = client
|
||||
.query(query, &[&creds.user])
|
||||
.await
|
||||
.map_err(io_error)?;
|
||||
|
||||
match &rows[..] {
|
||||
// We can't get a secret if there's no such user.
|
||||
[] => Err(api::GetAuthInfoError::BadCredentials(creds.to_owned())),
|
||||
// We shouldn't get more than one row anyway.
|
||||
[row, ..] => {
|
||||
let entry = row.try_get(0).map_err(io_error)?;
|
||||
scram::ServerSecret::parse(entry)
|
||||
.map(AuthInfo::Scram)
|
||||
.or_else(|| {
|
||||
// It could be an md5 hash if it's not a SCRAM secret.
|
||||
let text = entry.strip_prefix("md5")?;
|
||||
Some(AuthInfo::Md5({
|
||||
let mut bytes = [0u8; 16];
|
||||
hex::decode_to_slice(text, &mut bytes).ok()?;
|
||||
bytes
|
||||
}))
|
||||
})
|
||||
// Putting the secret into this message is a security hazard!
|
||||
.ok_or(api::GetAuthInfoError::BadSecret)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn wake_compute(
|
||||
&self,
|
||||
creds: &ClientCredentials,
|
||||
) -> Result<DatabaseInfo, api::WakeComputeError> {
|
||||
// Local setup doesn't have a dedicated compute node,
|
||||
// so we just return the local database we're pointed at.
|
||||
Ok(DatabaseInfo {
|
||||
host: self.url.host_str().unwrap_or("localhost").to_owned(),
|
||||
port: self.url.port().unwrap_or(5432),
|
||||
dbname: creds.dbname.to_owned(),
|
||||
user: creds.user.to_owned(),
|
||||
password: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,5 +1,5 @@
|
||||
use crate::auth_backend::console::DatabaseInfo;
|
||||
use crate::cancellation::CancelClosure;
|
||||
use crate::cloud::api::DatabaseInfo;
|
||||
use crate::error::UserFacingError;
|
||||
use std::io;
|
||||
use std::net::SocketAddr;
|
||||
|
||||
@@ -1,35 +1,39 @@
|
||||
use crate::cloud;
|
||||
use anyhow::{bail, ensure, Context};
|
||||
use std::sync::Arc;
|
||||
use anyhow::{ensure, Context};
|
||||
use std::{str::FromStr, sync::Arc};
|
||||
|
||||
#[non_exhaustive]
|
||||
pub enum AuthBackendType {
|
||||
LegacyConsole,
|
||||
Console,
|
||||
Postgres,
|
||||
Link,
|
||||
}
|
||||
|
||||
impl FromStr for AuthBackendType {
|
||||
type Err = anyhow::Error;
|
||||
|
||||
fn from_str(s: &str) -> anyhow::Result<Self> {
|
||||
println!("ClientAuthMethod::from_str: '{}'", s);
|
||||
use AuthBackendType::*;
|
||||
match s {
|
||||
"legacy" => Ok(LegacyConsole),
|
||||
"console" => Ok(Console),
|
||||
"postgres" => Ok(Postgres),
|
||||
"link" => Ok(Link),
|
||||
_ => Err(anyhow::anyhow!("Invlid option for auth method")),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ProxyConfig {
|
||||
/// Unauthenticated users will be redirected to this URL.
|
||||
pub redirect_uri: reqwest::Url,
|
||||
|
||||
/// Cloud API endpoint for user authentication.
|
||||
pub cloud_endpoint: CloudApi,
|
||||
|
||||
/// TLS configuration for the proxy.
|
||||
pub tls_config: Option<TlsConfig>,
|
||||
}
|
||||
|
||||
/// Cloud API configuration.
|
||||
pub enum CloudApi {
|
||||
/// We'll drop this one when [`CloudApi::V2`] is stable.
|
||||
V1(crate::cloud::Legacy),
|
||||
/// The new version of the cloud API.
|
||||
V2(crate::cloud::BoxedApi),
|
||||
}
|
||||
pub auth_backend: AuthBackendType,
|
||||
|
||||
impl CloudApi {
|
||||
/// Configure Cloud API provider.
|
||||
pub fn new(version: &str, url: reqwest::Url) -> anyhow::Result<Self> {
|
||||
Ok(match version {
|
||||
"v1" => Self::V1(cloud::Legacy::new(url)),
|
||||
"v2" => Self::V2(cloud::new(url)?),
|
||||
_ => bail!("unknown cloud API version: {}", version),
|
||||
})
|
||||
}
|
||||
pub auth_endpoint: reqwest::Url,
|
||||
|
||||
pub auth_link_uri: reqwest::Url,
|
||||
}
|
||||
|
||||
pub type TlsConfig = Arc<rustls::ServerConfig>;
|
||||
|
||||
@@ -5,8 +5,8 @@
|
||||
//! in somewhat transparent manner (again via communication with control plane API).
|
||||
|
||||
mod auth;
|
||||
mod auth_backend;
|
||||
mod cancellation;
|
||||
mod cloud;
|
||||
mod compute;
|
||||
mod config;
|
||||
mod error;
|
||||
@@ -48,18 +48,11 @@ async fn main() -> anyhow::Result<()> {
|
||||
.default_value("127.0.0.1:4432"),
|
||||
)
|
||||
.arg(
|
||||
Arg::new("auth-method")
|
||||
.long("auth-method")
|
||||
Arg::new("auth-backend")
|
||||
.long("auth-backend")
|
||||
.takes_value(true)
|
||||
.help("Possible values: password | link | mixed")
|
||||
.default_value("mixed"),
|
||||
)
|
||||
.arg(
|
||||
Arg::new("static-router")
|
||||
.short('s')
|
||||
.long("static-router")
|
||||
.takes_value(true)
|
||||
.help("Route all clients to host:port"),
|
||||
.help("Possible values: legacy | console | postgres | link")
|
||||
.default_value("legacy"),
|
||||
)
|
||||
.arg(
|
||||
Arg::new("mgmt")
|
||||
@@ -82,7 +75,7 @@ async fn main() -> anyhow::Result<()> {
|
||||
.short('u')
|
||||
.long("uri")
|
||||
.takes_value(true)
|
||||
.help("redirect unauthenticated users to given uri")
|
||||
.help("redirect unauthenticated users to the given uri in case of link auth")
|
||||
.default_value("http://localhost:3000/psql_session/"),
|
||||
)
|
||||
.arg(
|
||||
@@ -93,14 +86,6 @@ async fn main() -> anyhow::Result<()> {
|
||||
.help("cloud API endpoint for authenticating users")
|
||||
.default_value("http://localhost:3000/authenticate_proxy_request/"),
|
||||
)
|
||||
.arg(
|
||||
Arg::new("api-version")
|
||||
.long("api-version")
|
||||
.takes_value(true)
|
||||
.default_value("v1")
|
||||
.possible_values(["v1", "v2"])
|
||||
.help("cloud API version to be used for authentication"),
|
||||
)
|
||||
.arg(
|
||||
Arg::new("tls-key")
|
||||
.short('k')
|
||||
@@ -132,15 +117,11 @@ async fn main() -> anyhow::Result<()> {
|
||||
let mgmt_address: SocketAddr = arg_matches.value_of("mgmt").unwrap().parse()?;
|
||||
let http_address: SocketAddr = arg_matches.value_of("http").unwrap().parse()?;
|
||||
|
||||
let cloud_endpoint = config::CloudApi::new(
|
||||
arg_matches.value_of("api-version").unwrap(),
|
||||
arg_matches.value_of("auth-endpoint").unwrap().parse()?,
|
||||
)?;
|
||||
|
||||
let config: &ProxyConfig = Box::leak(Box::new(ProxyConfig {
|
||||
redirect_uri: arg_matches.value_of("uri").unwrap().parse()?,
|
||||
cloud_endpoint,
|
||||
tls_config,
|
||||
auth_backend: arg_matches.value_of("auth-backend").unwrap().parse()?,
|
||||
auth_endpoint: arg_matches.value_of("auth-endpoint").unwrap().parse()?,
|
||||
auth_link_uri: arg_matches.value_of("uri").unwrap().parse()?,
|
||||
}));
|
||||
|
||||
println!("Version: {}", GIT_VERSION);
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use crate::cloud;
|
||||
use crate::auth_backend;
|
||||
use anyhow::Context;
|
||||
use serde::Deserialize;
|
||||
use std::{
|
||||
@@ -10,6 +10,8 @@ use utils::{
|
||||
pq_proto::{BeMessage, SINGLE_COL_ROWDESC},
|
||||
};
|
||||
|
||||
/// TODO: move all of that to auth-backend/link.rs when we ditch legacy-console backend
|
||||
|
||||
///
|
||||
/// Main proxy listener loop.
|
||||
///
|
||||
@@ -75,12 +77,12 @@ struct PsqlSessionResponse {
|
||||
|
||||
#[derive(Deserialize)]
|
||||
enum PsqlSessionResult {
|
||||
Success(cloud::api::DatabaseInfo),
|
||||
Success(auth_backend::console::DatabaseInfo),
|
||||
Failure(String),
|
||||
}
|
||||
|
||||
/// A message received by `mgmt` when a compute node is ready.
|
||||
pub type ComputeReady = Result<cloud::api::DatabaseInfo, String>;
|
||||
pub type ComputeReady = Result<auth_backend::console::DatabaseInfo, String>;
|
||||
|
||||
impl PsqlSessionResult {
|
||||
fn into_compute_ready(self) -> ComputeReady {
|
||||
@@ -111,7 +113,7 @@ fn try_process_query(pgb: &mut PostgresBackend, query_string: &str) -> anyhow::R
|
||||
|
||||
let resp: PsqlSessionResponse = serde_json::from_str(query_string)?;
|
||||
|
||||
match cloud::notify(&resp.session_id, resp.result.into_compute_ready()) {
|
||||
match auth_backend::notify(&resp.session_id, resp.result.into_compute_ready()) {
|
||||
Ok(()) => {
|
||||
pgb.write_message_noflush(&SINGLE_COL_ROWDESC)?
|
||||
.write_message_noflush(&BeMessage::DataRow(&[Some(b"ok")]))?
|
||||
|
||||
@@ -73,7 +73,7 @@ pub async fn thread_main(
|
||||
async fn handle_client(
|
||||
config: &ProxyConfig,
|
||||
cancel_map: &CancelMap,
|
||||
stream: impl AsyncRead + AsyncWrite + Unpin,
|
||||
stream: impl AsyncRead + AsyncWrite + Unpin + Send,
|
||||
) -> anyhow::Result<()> {
|
||||
// The `closed` counter will increase when this future is destroyed.
|
||||
NUM_CONNECTIONS_ACCEPTED_COUNTER.inc();
|
||||
@@ -148,6 +148,8 @@ async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
.or_else(|e| stream.throw_error(e))
|
||||
.await?;
|
||||
|
||||
// TODO: set creds.cluster here when SNI info is available
|
||||
|
||||
break Ok(Some((stream, creds)));
|
||||
}
|
||||
CancelRequest(cancel_key_data) => {
|
||||
@@ -174,7 +176,7 @@ impl<S> Client<S> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: AsyncRead + AsyncWrite + Unpin> Client<S> {
|
||||
impl<S: AsyncRead + AsyncWrite + Unpin + Send> Client<S> {
|
||||
/// Let the client authenticate and connect to the designated compute node.
|
||||
async fn connect_to_db(
|
||||
self,
|
||||
|
||||
@@ -38,6 +38,7 @@ impl ServerSecret {
|
||||
/// To avoid revealing information to an attacker, we use a
|
||||
/// mocked server secret even if the user doesn't exist.
|
||||
/// See `auth-scram.c : mock_scram_secret` for details.
|
||||
#[allow(dead_code)]
|
||||
pub fn mock(user: &str, nonce: &[u8; 32]) -> Self {
|
||||
// Refer to `auth-scram.c : scram_mock_salt`.
|
||||
let mocked_salt = super::sha256([user.as_bytes(), nonce]);
|
||||
|
||||
@@ -1382,8 +1382,8 @@ def remote_pg(test_output_dir: str) -> Iterator[RemotePostgres]:
|
||||
class ZenithProxy(PgProtocol):
|
||||
def __init__(self, port: int):
|
||||
super().__init__(host="127.0.0.1",
|
||||
user="pytest",
|
||||
password="pytest",
|
||||
user="proxy_user",
|
||||
password="pytest2",
|
||||
port=port,
|
||||
dbname='postgres')
|
||||
self.http_port = 7001
|
||||
@@ -1399,8 +1399,8 @@ class ZenithProxy(PgProtocol):
|
||||
args = [bin_proxy]
|
||||
args.extend(["--http", f"{self.host}:{self.http_port}"])
|
||||
args.extend(["--proxy", f"{self.host}:{self.port}"])
|
||||
args.extend(["--auth-method", "password"])
|
||||
args.extend(["--static-router", addr])
|
||||
args.extend(["--auth-backend", "postgres"])
|
||||
args.extend(["--auth-endpoint", "postgres://proxy_auth:pytest1@localhost:5432/postgres"])
|
||||
self._popen = subprocess.Popen(args)
|
||||
self._wait_until_ready()
|
||||
|
||||
@@ -1422,7 +1422,8 @@ class ZenithProxy(PgProtocol):
|
||||
def static_proxy(vanilla_pg) -> Iterator[ZenithProxy]:
|
||||
"""Zenith proxy that routes directly to vanilla postgres."""
|
||||
vanilla_pg.start()
|
||||
vanilla_pg.safe_psql("create user pytest with password 'pytest';")
|
||||
vanilla_pg.safe_psql("create user proxy_auth with password 'pytest1' superuser")
|
||||
vanilla_pg.safe_psql("create user proxy_user with password 'pytest2'")
|
||||
|
||||
with ZenithProxy(4432) as proxy:
|
||||
proxy.start_static()
|
||||
|
||||
Reference in New Issue
Block a user