[proxy] Refactor cplane API and add new console SCRAM auth API

Now proxy binary accepts `--auth-backend` CLI option, which determines
auth scheme and cluster routing method. Following backends are currently
implemented:

* legacy
    old method, when username ends with `@zenith` it uses md5 auth dbname as
    the cluster name; otherwise, it sends a login link and waits for the console
    to call back
* console
    new SCRAM-based console API; uses SNI info to select the destination
    cluster
* postgres
    uses postgres to select auth secrets of existing roles. Useful for local
    testing
* link
    sends login link for all usernames
This commit is contained in:
Stas Kelvich
2022-04-30 00:58:57 +03:00
parent af0195b604
commit 0323bb5870
21 changed files with 722 additions and 578 deletions

3
.gitignore vendored
View File

@@ -11,3 +11,6 @@ test_output/
# Coverage
*.profraw
*.profdata
*.key
*.crt

1
Cargo.lock generated
View File

@@ -2040,6 +2040,7 @@ dependencies = [
"tokio-postgres",
"tokio-postgres-rustls",
"tokio-rustls",
"url",
"utils",
"workspace_hack",
]

View File

@@ -32,6 +32,7 @@ thiserror = "1.0.30"
tokio = { version = "1.17", features = ["macros"] }
tokio-postgres = { git = "https://github.com/zenithdb/rust-postgres.git", rev="d052ee8b86fff9897c77b0fe89ea9daba0e1fa38" }
tokio-rustls = "0.23.0"
url = "2.2.2"
utils = { path = "../libs/utils" }
metrics = { path = "../libs/metrics" }

33
proxy/README.md Normal file
View File

@@ -0,0 +1,33 @@
# Proxy
Proxy binary accepts `--auth-backend` CLI option, which determines auth scheme and cluster routing method. Following backends are currently implemented:
* legacy
old method, when username ends with `@zenith` it uses md5 auth dbname as the cluster name; otherwise, it sends a login link and waits for the console to call back
* console
new SCRAM-based console API; uses SNI info to select the destination cluster
* postgres
uses postgres to select auth secrets of existing roles. Useful for local testing
* link
sends login link for all usernames
## Using SNI-based routing on localhost
Now proxy determines cluster name from the subdomain, request to the `my-cluster-42.somedomain.tld` will be routed to the cluster named `my-cluster-42`. Unfortunately `/etc/hosts` does not support domain wildcards, so I usually use `*.localtest.me` which resolves to `127.0.0.1`. Now we can create self-signed certificate and play with proxy:
```
openssl req -new -x509 -days 365 -nodes -text -out server.crt -keyout server.key -subj "/CN=*.localtest.me"
```
now you can start proxy:
```
./target/debug/proxy -c server.crt -k server.key
```
and connect to it:
```
PGSSLROOTCERT=./server.crt psql 'postgres://my-cluster-42.localtest.me:1234?sslmode=verify-full'
```

View File

@@ -1,14 +1,14 @@
mod credentials;
mod flow;
use crate::config::{CloudApi, ProxyConfig};
use crate::auth_backend::{console, legacy_console, link, postgres};
use crate::config::{AuthBackendType, ProxyConfig};
use crate::error::UserFacingError;
use crate::stream::PqStream;
use crate::{cloud, compute, waiters};
use crate::{auth_backend, compute, waiters};
use std::io;
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncWrite};
use utils::pq_proto::{BeMessage as Be, BeParameterStatusMessage};
pub use credentials::ClientCredentials;
pub use flow::*;
@@ -18,13 +18,10 @@ pub use flow::*;
pub enum AuthErrorImpl {
/// Authentication error reported by the console.
#[error(transparent)]
Console(#[from] cloud::AuthError),
Console(#[from] auth_backend::AuthError),
#[error(transparent)]
GetAuthInfo(#[from] cloud::api::GetAuthInfoError),
#[error(transparent)]
WakeCompute(#[from] cloud::api::WakeComputeError),
GetAuthInfo(#[from] auth_backend::console::ConsoleAuthError),
#[error(transparent)]
Sasl(#[from] crate::sasl::Error),
@@ -40,19 +37,19 @@ pub enum AuthErrorImpl {
impl AuthErrorImpl {
pub fn auth_failed(msg: impl Into<String>) -> Self {
AuthErrorImpl::Console(cloud::AuthError::auth_failed(msg))
AuthErrorImpl::Console(auth_backend::AuthError::auth_failed(msg))
}
}
impl From<waiters::RegisterError> for AuthErrorImpl {
fn from(e: waiters::RegisterError) -> Self {
AuthErrorImpl::Console(cloud::AuthError::from(e))
AuthErrorImpl::Console(auth_backend::AuthError::from(e))
}
}
impl From<waiters::WaitError> for AuthErrorImpl {
fn from(e: waiters::WaitError) -> Self {
AuthErrorImpl::Console(cloud::AuthError::from(e))
AuthErrorImpl::Console(auth_backend::AuthError::from(e))
}
}
@@ -82,131 +79,25 @@ impl UserFacingError for AuthError {
async fn handle_user(
config: &ProxyConfig,
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin + Send>,
creds: ClientCredentials,
) -> Result<compute::NodeInfo, AuthError> {
if creds.is_existing_user() {
match &config.cloud_endpoint {
CloudApi::V1(api) => handle_existing_user_v1(api, client, creds).await,
CloudApi::V2(api) => handle_existing_user_v2(api.as_ref(), client, creds).await,
match config.auth_backend {
AuthBackendType::LegacyConsole => {
legacy_console::handle_user(
&config.auth_endpoint,
&config.auth_link_uri,
client,
&creds,
)
.await
}
} else {
let redirect_uri = config.redirect_uri.as_ref();
handle_new_user(redirect_uri, client).await
AuthBackendType::Console => {
console::handle_user(config.auth_endpoint.as_ref(), client, &creds).await
}
AuthBackendType::Postgres => {
postgres::handle_user(&config.auth_endpoint, client, &creds).await
}
AuthBackendType::Link => link::handle_user(config.auth_link_uri.as_ref(), client).await,
}
}
/// Authenticate user via a legacy cloud API endpoint.
async fn handle_existing_user_v1(
cloud: &cloud::Legacy,
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
creds: ClientCredentials,
) -> Result<compute::NodeInfo, AuthError> {
let psql_session_id = new_psql_session_id();
let md5_salt = rand::random();
client
.write_message(&Be::AuthenticationMD5Password(md5_salt))
.await?;
// Read client's password hash
let msg = client.read_password_message().await?;
let md5_response = parse_password(&msg).ok_or(AuthErrorImpl::MalformedPassword)?;
let db_info = cloud
.authenticate_proxy_client(creds, md5_response, &md5_salt, &psql_session_id)
.await?;
client
.write_message_noflush(&Be::AuthenticationOk)?
.write_message_noflush(&BeParameterStatusMessage::encoding())?;
Ok(compute::NodeInfo {
db_info,
scram_keys: None,
})
}
/// Authenticate user via a new cloud API endpoint which supports SCRAM.
async fn handle_existing_user_v2(
cloud: &(impl cloud::Api + ?Sized),
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
creds: ClientCredentials,
) -> Result<compute::NodeInfo, AuthError> {
let auth_info = cloud.get_auth_info(&creds).await?;
let flow = AuthFlow::new(client);
let scram_keys = match auth_info {
cloud::api::AuthInfo::Md5(_) => {
// TODO: decide if we should support MD5 in api v2
return Err(AuthErrorImpl::auth_failed("MD5 is not supported").into());
}
cloud::api::AuthInfo::Scram(secret) => {
let scram = Scram(&secret);
Some(compute::ScramKeys {
client_key: flow.begin(scram).await?.authenticate().await?.as_bytes(),
server_key: secret.server_key.as_bytes(),
})
}
};
client
.write_message_noflush(&Be::AuthenticationOk)?
.write_message_noflush(&BeParameterStatusMessage::encoding())?;
Ok(compute::NodeInfo {
db_info: cloud.wake_compute(&creds).await?,
scram_keys,
})
}
async fn handle_new_user(
redirect_uri: &str,
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
) -> Result<compute::NodeInfo, AuthError> {
let psql_session_id = new_psql_session_id();
let greeting = hello_message(redirect_uri, &psql_session_id);
let db_info = cloud::with_waiter(psql_session_id, |waiter| async {
// Give user a URL to spawn a new database
client
.write_message_noflush(&Be::AuthenticationOk)?
.write_message_noflush(&BeParameterStatusMessage::encoding())?
.write_message(&Be::NoticeResponse(&greeting))
.await?;
// Wait for web console response (see `mgmt`)
waiter.await?.map_err(AuthErrorImpl::auth_failed)
})
.await?;
client.write_message_noflush(&Be::NoticeResponse("Connecting to database."))?;
Ok(compute::NodeInfo {
db_info,
scram_keys: None,
})
}
fn new_psql_session_id() -> String {
hex::encode(rand::random::<[u8; 8]>())
}
fn parse_password(bytes: &[u8]) -> Option<&str> {
std::str::from_utf8(bytes).ok()?.strip_suffix('\0')
}
fn hello_message(redirect_uri: &str, session_id: &str) -> String {
format!(
concat![
"☀️ Welcome to Neon!\n",
"To proceed with database creation, open the following link:\n\n",
" {redirect_uri}{session_id}\n\n",
"It needs to be done once and we will send you '.pgpass' file,\n",
"which will allow you to access or create ",
"databases without opening your web browser."
],
redirect_uri = redirect_uri,
session_id = session_id,
)
}

View File

@@ -23,6 +23,10 @@ impl UserFacingError for ClientCredsParseError {}
pub struct ClientCredentials {
pub user: String,
pub dbname: String,
// New console API requires SNI info to determine cluster name.
// Other Auth backends don't need it.
pub sni_cluster: Option<String>,
}
impl ClientCredentials {
@@ -45,7 +49,11 @@ impl TryFrom<HashMap<String, String>> for ClientCredentials {
let user = get_param("user")?;
let db = get_param("database")?;
Ok(Self { user, dbname: db })
Ok(Self {
user,
dbname: db,
sni_cluster: None,
})
}
}
@@ -54,7 +62,7 @@ impl ClientCredentials {
pub async fn authenticate(
self,
config: &ProxyConfig,
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin + Send>,
) -> Result<compute::NodeInfo, AuthError> {
// This method is just a convenient facade for `handle_user`
super::handle_user(config, client, self).await

View File

@@ -1,10 +1,9 @@
mod local;
pub mod console;
pub mod legacy_console;
pub mod link;
pub mod postgres;
mod legacy;
pub use legacy::{AuthError, AuthErrorImpl, Legacy};
pub mod api;
pub use api::{Api, BoxedApi};
pub use legacy_console::{AuthError, AuthErrorImpl};
use crate::mgmt;
use crate::waiters::{self, Waiter, Waiters};
@@ -30,17 +29,3 @@ where
pub fn notify(psql_session_id: &str, msg: mgmt::ComputeReady) -> Result<(), waiters::NotifyError> {
CPLANE_WAITERS.notify(psql_session_id, msg)
}
/// Construct a new opaque cloud API provider.
pub fn new(url: reqwest::Url) -> anyhow::Result<BoxedApi> {
Ok(match url.scheme() {
"https" | "http" => {
todo!("build a real cloud wrapper")
}
"postgresql" | "postgres" | "pg" => {
// Just point to a local running postgres instance.
Box::new(local::Local { url })
}
other => anyhow::bail!("unsupported url scheme: {other}"),
})
}

View File

@@ -0,0 +1,236 @@
//! Declaration of Cloud API V2.
use crate::{
auth::{self, AuthFlow},
compute, scram,
};
use serde::{Deserialize, Serialize};
use thiserror::Error;
use crate::auth::ClientCredentials;
use crate::stream::PqStream;
use tokio::io::{AsyncRead, AsyncWrite};
use utils::pq_proto::{BeMessage as Be, BeParameterStatusMessage};
#[derive(Debug, Error)]
pub enum ConsoleAuthError {
// We shouldn't include the actual secret here.
#[error("Bad authentication secret")]
BadSecret,
#[error("Bad client credentials: {0:?}")]
BadCredentials(crate::auth::ClientCredentials),
/// For passwords that couldn't be processed by [`parse_password`].
#[error("Absend SNI information")]
SniMissing,
#[error(transparent)]
BadUrl(#[from] url::ParseError),
#[error(transparent)]
Io(#[from] std::io::Error),
/// HTTP status (other than 200) returned by the console.
#[error("Console responded with an HTTP status: {0}")]
HttpStatus(reqwest::StatusCode),
#[error(transparent)]
Transport(#[from] reqwest::Error),
#[error("Console responded with a malformed JSON: '{0}'")]
MalformedResponse(#[from] serde_json::Error),
#[error("Console responded with a malformed compute address: '{0}'")]
MalformedComputeAddress(String),
}
#[derive(Serialize, Deserialize, Debug)]
struct GetRoleSecretResponse {
role_secret: String,
}
#[derive(Serialize, Deserialize, Debug)]
struct GetWakeComputeResponse {
address: String,
}
/// Auth secret which is managed by the cloud.
pub enum AuthInfo {
/// Md5 hash of user's password.
Md5([u8; 16]),
/// [SCRAM](crate::scram) authentication info.
Scram(scram::ServerSecret),
}
/// Compute node connection params provided by the cloud.
/// Note how it implements serde traits, since we receive it over the wire.
#[derive(Serialize, Deserialize, Default)]
pub struct DatabaseInfo {
pub host: String,
pub port: u16,
pub dbname: String,
pub user: String,
/// [Cloud API V1](super::legacy) returns cleartext password,
/// but [Cloud API V2](super::api) implements [SCRAM](crate::scram)
/// authentication, so we can leverage this method and cope without password.
pub password: Option<String>,
}
// Manually implement debug to omit personal and sensitive info.
impl std::fmt::Debug for DatabaseInfo {
fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
fmt.debug_struct("DatabaseInfo")
.field("host", &self.host)
.field("port", &self.port)
.finish()
}
}
impl From<DatabaseInfo> for tokio_postgres::Config {
fn from(db_info: DatabaseInfo) -> Self {
let mut config = tokio_postgres::Config::new();
config
.host(&db_info.host)
.port(db_info.port)
.dbname(&db_info.dbname)
.user(&db_info.user);
if let Some(password) = db_info.password {
config.password(password);
}
config
}
}
async fn get_auth_info(
auth_endpoint: &str,
user: &str,
cluster: &str,
) -> Result<AuthInfo, ConsoleAuthError> {
let mut url = reqwest::Url::parse(&format!("{auth_endpoint}/proxy_get_role_secret"))?;
url.query_pairs_mut()
.append_pair("cluster", cluster)
.append_pair("role", user);
// TODO: use a proper logger
println!("cplane request: {}", url);
let resp = reqwest::get(url).await?;
if !resp.status().is_success() {
return Err(ConsoleAuthError::HttpStatus(resp.status()));
}
let response: GetRoleSecretResponse = serde_json::from_str(resp.text().await?.as_str())?;
scram::ServerSecret::parse(response.role_secret.as_str())
.map(AuthInfo::Scram)
.ok_or(ConsoleAuthError::BadSecret)
}
/// Wake up the compute node and return the corresponding connection info.
async fn wake_compute(
auth_endpoint: &str,
cluster: &str,
) -> Result<(String, u16), ConsoleAuthError> {
let mut url = reqwest::Url::parse(&format!("{auth_endpoint}/proxy_wake_compute"))?;
url.query_pairs_mut().append_pair("cluster", cluster);
// TODO: use a proper logger
println!("cplane request: {}", url);
let resp = reqwest::get(url).await?;
if !resp.status().is_success() {
return Err(ConsoleAuthError::HttpStatus(resp.status()));
}
let response: GetWakeComputeResponse = serde_json::from_str(resp.text().await?.as_str())?;
let (host, port) = response
.address
.split_once(':')
.ok_or_else(|| ConsoleAuthError::MalformedComputeAddress(response.address.clone()))?;
let port: u16 = port
.parse()
.map_err(|_| ConsoleAuthError::MalformedComputeAddress(response.address.clone()))?;
Ok((host.to_string(), port))
}
pub async fn handle_user(
auth_endpoint: &str,
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
creds: &ClientCredentials,
) -> Result<compute::NodeInfo, crate::auth::AuthError> {
let cluster = creds
.sni_cluster
.as_ref()
.ok_or(ConsoleAuthError::SniMissing)?;
let user = creds.user.as_str();
// Step 1: get the auth secret
let auth_info = get_auth_info(auth_endpoint, user, cluster).await?;
let flow = AuthFlow::new(client);
let scram_keys = match auth_info {
AuthInfo::Md5(_) => {
// TODO: decide if we should support MD5 in api v2
return Err(crate::auth::AuthErrorImpl::auth_failed("MD5 is not supported").into());
}
AuthInfo::Scram(secret) => {
let scram = auth::Scram(&secret);
Some(compute::ScramKeys {
client_key: flow.begin(scram).await?.authenticate().await?.as_bytes(),
server_key: secret.server_key.as_bytes(),
})
}
};
client
.write_message_noflush(&Be::AuthenticationOk)?
.write_message_noflush(&BeParameterStatusMessage::encoding())?;
// Step 2: wake compute
let (host, port) = wake_compute(auth_endpoint, cluster).await?;
Ok(compute::NodeInfo {
db_info: DatabaseInfo {
host,
port,
dbname: creds.dbname.clone(),
user: creds.user.clone(),
password: None,
},
scram_keys,
})
}
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
#[test]
fn parse_db_info() -> anyhow::Result<()> {
let _: DatabaseInfo = serde_json::from_value(json!({
"host": "localhost",
"port": 5432,
"dbname": "postgres",
"user": "john_doe",
"password": "password",
}))?;
let _: DatabaseInfo = serde_json::from_value(json!({
"host": "localhost",
"port": 5432,
"dbname": "postgres",
"user": "john_doe",
}))?;
Ok(())
}
}

View File

@@ -0,0 +1,206 @@
//! Cloud API V1.
use super::console::DatabaseInfo;
use crate::auth::ClientCredentials;
use crate::stream::PqStream;
use crate::{compute, waiters};
use serde::{Deserialize, Serialize};
use tokio::io::{AsyncRead, AsyncWrite};
use utils::pq_proto::{BeMessage as Be, BeParameterStatusMessage};
use thiserror::Error;
use crate::error::UserFacingError;
#[derive(Debug, Error)]
pub enum AuthErrorImpl {
/// Authentication error reported by the console.
#[error("Authentication failed: {0}")]
AuthFailed(String),
/// HTTP status (other than 200) returned by the console.
#[error("Console responded with an HTTP status: {0}")]
HttpStatus(reqwest::StatusCode),
#[error("Console responded with a malformed JSON: {0}")]
MalformedResponse(#[from] serde_json::Error),
#[error(transparent)]
Transport(#[from] reqwest::Error),
#[error(transparent)]
WaiterRegister(#[from] waiters::RegisterError),
#[error(transparent)]
WaiterWait(#[from] waiters::WaitError),
}
#[derive(Debug, Error)]
#[error(transparent)]
pub struct AuthError(Box<AuthErrorImpl>);
impl AuthError {
/// Smart constructor for authentication error reported by `mgmt`.
pub fn auth_failed(msg: impl Into<String>) -> Self {
AuthError(Box::new(AuthErrorImpl::AuthFailed(msg.into())))
}
}
impl<T> From<T> for AuthError
where
AuthErrorImpl: From<T>,
{
fn from(e: T) -> Self {
AuthError(Box::new(e.into()))
}
}
impl UserFacingError for AuthError {
fn to_string_client(&self) -> String {
use AuthErrorImpl::*;
match self.0.as_ref() {
AuthFailed(_) | HttpStatus(_) => self.to_string(),
_ => "Internal error".to_string(),
}
}
}
// NOTE: the order of constructors is important.
// https://serde.rs/enum-representations.html#untagged
#[derive(Serialize, Deserialize, Debug)]
#[serde(untagged)]
enum ProxyAuthResponse {
Ready { conn_info: DatabaseInfo },
Error { error: String },
NotReady { ready: bool }, // TODO: get rid of `ready`
}
async fn authenticate_proxy_client(
auth_endpoint: &reqwest::Url,
creds: &ClientCredentials,
md5_response: &str,
salt: &[u8; 4],
psql_session_id: &str,
) -> Result<DatabaseInfo, AuthError> {
let mut url = auth_endpoint.clone();
url.query_pairs_mut()
.append_pair("login", &creds.user)
.append_pair("database", &creds.dbname)
.append_pair("md5response", md5_response)
.append_pair("salt", &hex::encode(salt))
.append_pair("psql_session_id", psql_session_id);
super::with_waiter(psql_session_id, |waiter| async {
println!("cloud request: {}", url);
// TODO: leverage `reqwest::Client` to reuse connections
let resp = reqwest::get(url).await?;
if !resp.status().is_success() {
return Err(AuthErrorImpl::HttpStatus(resp.status()).into());
}
let auth_info: ProxyAuthResponse = serde_json::from_str(resp.text().await?.as_str())?;
println!("got auth info: #{:?}", auth_info);
use ProxyAuthResponse::*;
let db_info = match auth_info {
Ready { conn_info } => conn_info,
Error { error } => return Err(AuthErrorImpl::AuthFailed(error).into()),
NotReady { .. } => waiter.await?.map_err(AuthErrorImpl::AuthFailed)?,
};
Ok(db_info)
})
.await
}
async fn handle_existing_user(
auth_endpoint: &reqwest::Url,
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin + Send>,
creds: &ClientCredentials,
) -> Result<crate::compute::NodeInfo, crate::auth::AuthError> {
let psql_session_id = super::link::new_psql_session_id();
let md5_salt = rand::random();
client
.write_message(&Be::AuthenticationMD5Password(md5_salt))
.await?;
// Read client's password hash
let msg = client.read_password_message().await?;
let md5_response = parse_password(&msg).ok_or(crate::auth::AuthErrorImpl::MalformedPassword)?;
let db_info = authenticate_proxy_client(
auth_endpoint,
creds,
md5_response,
&md5_salt,
&psql_session_id,
)
.await?;
client
.write_message_noflush(&Be::AuthenticationOk)?
.write_message_noflush(&BeParameterStatusMessage::encoding())?;
Ok(compute::NodeInfo {
db_info,
scram_keys: None,
})
}
pub async fn handle_user(
auth_endpoint: &reqwest::Url,
auth_link_uri: &reqwest::Url,
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin + Send>,
creds: &ClientCredentials,
) -> Result<crate::compute::NodeInfo, crate::auth::AuthError> {
if creds.is_existing_user() {
handle_existing_user(auth_endpoint, client, creds).await
} else {
super::link::handle_user(auth_link_uri.as_ref(), client).await
}
}
fn parse_password(bytes: &[u8]) -> Option<&str> {
std::str::from_utf8(bytes).ok()?.strip_suffix('\0')
}
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
#[test]
fn test_proxy_auth_response() {
// Ready
let auth: ProxyAuthResponse = serde_json::from_value(json!({
"ready": true,
"conn_info": DatabaseInfo::default(),
}))
.unwrap();
assert!(matches!(
auth,
ProxyAuthResponse::Ready {
conn_info: DatabaseInfo { .. }
}
));
// Error
let auth: ProxyAuthResponse = serde_json::from_value(json!({
"ready": false,
"error": "too bad, so sad",
}))
.unwrap();
assert!(matches!(auth, ProxyAuthResponse::Error { .. }));
// NotReady
let auth: ProxyAuthResponse = serde_json::from_value(json!({
"ready": false,
}))
.unwrap();
assert!(matches!(auth, ProxyAuthResponse::NotReady { .. }));
}
}

View File

@@ -0,0 +1,52 @@
use crate::{compute, stream::PqStream};
use tokio::io::{AsyncRead, AsyncWrite};
use utils::pq_proto::{BeMessage as Be, BeParameterStatusMessage};
fn hello_message(redirect_uri: &str, session_id: &str) -> String {
format!(
concat![
"☀️ Welcome to Neon!\n",
"To proceed with database creation, open the following link:\n\n",
" {redirect_uri}{session_id}\n\n",
"It needs to be done once and we will send you '.pgpass' file,\n",
"which will allow you to access or create ",
"databases without opening your web browser."
],
redirect_uri = redirect_uri,
session_id = session_id,
)
}
pub fn new_psql_session_id() -> String {
hex::encode(rand::random::<[u8; 8]>())
}
pub async fn handle_user(
redirect_uri: &str,
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
) -> Result<compute::NodeInfo, crate::auth::AuthError> {
let psql_session_id = new_psql_session_id();
let greeting = hello_message(redirect_uri, &psql_session_id);
let db_info = crate::auth_backend::with_waiter(psql_session_id, |waiter| async {
// Give user a URL to spawn a new database
client
.write_message_noflush(&Be::AuthenticationOk)?
.write_message_noflush(&BeParameterStatusMessage::encoding())?
.write_message(&Be::NoticeResponse(&greeting))
.await?;
// Wait for web console response (see `mgmt`)
waiter
.await?
.map_err(crate::auth::AuthErrorImpl::auth_failed)
})
.await?;
client.write_message_noflush(&Be::NoticeResponse("Connecting to database."))?;
Ok(compute::NodeInfo {
db_info,
scram_keys: None,
})
}

View File

@@ -0,0 +1,93 @@
//! Local mock of Cloud API V2.
use super::console::{self, AuthInfo, DatabaseInfo};
use crate::scram;
use crate::{auth::ClientCredentials, compute};
use crate::stream::PqStream;
use tokio::io::{AsyncRead, AsyncWrite};
use utils::pq_proto::{BeMessage as Be, BeParameterStatusMessage};
async fn get_auth_info(
auth_endpoint: &str,
creds: &ClientCredentials,
) -> Result<AuthInfo, console::ConsoleAuthError> {
// We wrap `tokio_postgres::Error` because we don't want to infect the
// method's error type with a detail that's specific to debug mode only.
let io_error = |e| std::io::Error::new(std::io::ErrorKind::Other, e);
// Perhaps we could persist this connection, but then we'd have to
// write more code for reopening it if it got closed, which doesn't
// seem worth it.
let (client, connection) = tokio_postgres::connect(auth_endpoint, tokio_postgres::NoTls)
.await
.map_err(io_error)?;
tokio::spawn(connection);
let query = "select rolpassword from pg_catalog.pg_authid where rolname = $1";
let rows = client
.query(query, &[&creds.user])
.await
.map_err(io_error)?;
match &rows[..] {
// We can't get a secret if there's no such user.
[] => Err(console::ConsoleAuthError::BadCredentials(creds.to_owned())),
// We shouldn't get more than one row anyway.
[row, ..] => {
let entry = row.try_get(0).map_err(io_error)?;
scram::ServerSecret::parse(entry)
.map(AuthInfo::Scram)
.or_else(|| {
// It could be an md5 hash if it's not a SCRAM secret.
let text = entry.strip_prefix("md5")?;
Some(AuthInfo::Md5({
let mut bytes = [0u8; 16];
hex::decode_to_slice(text, &mut bytes).ok()?;
bytes
}))
})
// Putting the secret into this message is a security hazard!
.ok_or(console::ConsoleAuthError::BadSecret)
}
}
}
pub async fn handle_user(
auth_endpoint: &reqwest::Url,
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
creds: &ClientCredentials,
) -> Result<compute::NodeInfo, crate::auth::AuthError> {
let auth_info = get_auth_info(auth_endpoint.as_ref(), creds).await?;
let flow = crate::auth::AuthFlow::new(client);
let scram_keys = match auth_info {
AuthInfo::Md5(_) => {
// TODO: decide if we should support MD5 in api v2
return Err(crate::auth::AuthErrorImpl::auth_failed("MD5 is not supported").into());
}
AuthInfo::Scram(secret) => {
let scram = crate::auth::Scram(&secret);
Some(compute::ScramKeys {
client_key: flow.begin(scram).await?.authenticate().await?.as_bytes(),
server_key: secret.server_key.as_bytes(),
})
}
};
client
.write_message_noflush(&Be::AuthenticationOk)?
.write_message_noflush(&BeParameterStatusMessage::encoding())?;
Ok(compute::NodeInfo {
db_info: DatabaseInfo {
// TODO: handle that near CLI params parsing
host: auth_endpoint.host_str().unwrap_or("localhost").to_owned(),
port: auth_endpoint.port().unwrap_or(5432),
dbname: creds.dbname.to_owned(),
user: creds.user.to_owned(),
password: None,
},
scram_keys,
})
}

View File

@@ -1,120 +0,0 @@
//! Declaration of Cloud API V2.
use crate::{auth, scram};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use thiserror::Error;
#[derive(Debug, Error)]
pub enum GetAuthInfoError {
// We shouldn't include the actual secret here.
#[error("Bad authentication secret")]
BadSecret,
#[error("Bad client credentials: {0:?}")]
BadCredentials(crate::auth::ClientCredentials),
#[error(transparent)]
Io(#[from] std::io::Error),
}
// TODO: convert to an enum and describe possible sub-errors (see above)
#[derive(Debug, Error)]
#[error("Failed to wake up the compute node")]
pub struct WakeComputeError;
/// Opaque implementation of Cloud API.
pub type BoxedApi = Box<dyn Api + Send + Sync>;
/// Cloud API methods required by the proxy.
#[async_trait]
pub trait Api {
/// Get authentication information for the given user.
async fn get_auth_info(
&self,
creds: &auth::ClientCredentials,
) -> Result<AuthInfo, GetAuthInfoError>;
/// Wake up the compute node and return the corresponding connection info.
async fn wake_compute(
&self,
creds: &auth::ClientCredentials,
) -> Result<DatabaseInfo, WakeComputeError>;
}
/// Auth secret which is managed by the cloud.
pub enum AuthInfo {
/// Md5 hash of user's password.
Md5([u8; 16]),
/// [SCRAM](crate::scram) authentication info.
Scram(scram::ServerSecret),
}
/// Compute node connection params provided by the cloud.
/// Note how it implements serde traits, since we receive it over the wire.
#[derive(Serialize, Deserialize, Default)]
pub struct DatabaseInfo {
pub host: String,
pub port: u16,
pub dbname: String,
pub user: String,
/// [Cloud API V1](super::legacy) returns cleartext password,
/// but [Cloud API V2](super::api) implements [SCRAM](crate::scram)
/// authentication, so we can leverage this method and cope without password.
pub password: Option<String>,
}
// Manually implement debug to omit personal and sensitive info.
impl std::fmt::Debug for DatabaseInfo {
fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
fmt.debug_struct("DatabaseInfo")
.field("host", &self.host)
.field("port", &self.port)
.finish()
}
}
impl From<DatabaseInfo> for tokio_postgres::Config {
fn from(db_info: DatabaseInfo) -> Self {
let mut config = tokio_postgres::Config::new();
config
.host(&db_info.host)
.port(db_info.port)
.dbname(&db_info.dbname)
.user(&db_info.user);
if let Some(password) = db_info.password {
config.password(password);
}
config
}
}
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
#[test]
fn parse_db_info() -> anyhow::Result<()> {
let _: DatabaseInfo = serde_json::from_value(json!({
"host": "localhost",
"port": 5432,
"dbname": "postgres",
"user": "john_doe",
"password": "password",
}))?;
let _: DatabaseInfo = serde_json::from_value(json!({
"host": "localhost",
"port": 5432,
"dbname": "postgres",
"user": "john_doe",
}))?;
Ok(())
}
}

View File

@@ -1,160 +0,0 @@
//! Cloud API V1.
use super::api::DatabaseInfo;
use crate::auth::ClientCredentials;
use crate::error::UserFacingError;
use crate::waiters;
use serde::{Deserialize, Serialize};
use thiserror::Error;
/// Neon cloud API provider.
pub struct Legacy {
auth_endpoint: reqwest::Url,
}
impl Legacy {
/// Construct a new legacy cloud API provider.
pub fn new(auth_endpoint: reqwest::Url) -> Self {
Self { auth_endpoint }
}
}
#[derive(Debug, Error)]
pub enum AuthErrorImpl {
/// Authentication error reported by the console.
#[error("Authentication failed: {0}")]
AuthFailed(String),
/// HTTP status (other than 200) returned by the console.
#[error("Console responded with an HTTP status: {0}")]
HttpStatus(reqwest::StatusCode),
#[error("Console responded with a malformed JSON: {0}")]
MalformedResponse(#[from] serde_json::Error),
#[error(transparent)]
Transport(#[from] reqwest::Error),
#[error(transparent)]
WaiterRegister(#[from] waiters::RegisterError),
#[error(transparent)]
WaiterWait(#[from] waiters::WaitError),
}
#[derive(Debug, Error)]
#[error(transparent)]
pub struct AuthError(Box<AuthErrorImpl>);
impl AuthError {
/// Smart constructor for authentication error reported by `mgmt`.
pub fn auth_failed(msg: impl Into<String>) -> Self {
AuthError(Box::new(AuthErrorImpl::AuthFailed(msg.into())))
}
}
impl<T> From<T> for AuthError
where
AuthErrorImpl: From<T>,
{
fn from(e: T) -> Self {
AuthError(Box::new(e.into()))
}
}
impl UserFacingError for AuthError {
fn to_string_client(&self) -> String {
use AuthErrorImpl::*;
match self.0.as_ref() {
AuthFailed(_) | HttpStatus(_) => self.to_string(),
_ => "Internal error".to_string(),
}
}
}
// NOTE: the order of constructors is important.
// https://serde.rs/enum-representations.html#untagged
#[derive(Serialize, Deserialize, Debug)]
#[serde(untagged)]
enum ProxyAuthResponse {
Ready { conn_info: DatabaseInfo },
Error { error: String },
NotReady { ready: bool }, // TODO: get rid of `ready`
}
impl Legacy {
pub async fn authenticate_proxy_client(
&self,
creds: ClientCredentials,
md5_response: &str,
salt: &[u8; 4],
psql_session_id: &str,
) -> Result<DatabaseInfo, AuthError> {
let mut url = self.auth_endpoint.clone();
url.query_pairs_mut()
.append_pair("login", &creds.user)
.append_pair("database", &creds.dbname)
.append_pair("md5response", md5_response)
.append_pair("salt", &hex::encode(salt))
.append_pair("psql_session_id", psql_session_id);
super::with_waiter(psql_session_id, |waiter| async {
println!("cloud request: {}", url);
// TODO: leverage `reqwest::Client` to reuse connections
let resp = reqwest::get(url).await?;
if !resp.status().is_success() {
return Err(AuthErrorImpl::HttpStatus(resp.status()).into());
}
let auth_info: ProxyAuthResponse = serde_json::from_str(resp.text().await?.as_str())?;
println!("got auth info: #{:?}", auth_info);
use ProxyAuthResponse::*;
let db_info = match auth_info {
Ready { conn_info } => conn_info,
Error { error } => return Err(AuthErrorImpl::AuthFailed(error).into()),
NotReady { .. } => waiter.await?.map_err(AuthErrorImpl::AuthFailed)?,
};
Ok(db_info)
})
.await
}
}
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
#[test]
fn test_proxy_auth_response() {
// Ready
let auth: ProxyAuthResponse = serde_json::from_value(json!({
"ready": true,
"conn_info": DatabaseInfo::default(),
}))
.unwrap();
assert!(matches!(
auth,
ProxyAuthResponse::Ready {
conn_info: DatabaseInfo { .. }
}
));
// Error
let auth: ProxyAuthResponse = serde_json::from_value(json!({
"ready": false,
"error": "too bad, so sad",
}))
.unwrap();
assert!(matches!(auth, ProxyAuthResponse::Error { .. }));
// NotReady
let auth: ProxyAuthResponse = serde_json::from_value(json!({
"ready": false,
}))
.unwrap();
assert!(matches!(auth, ProxyAuthResponse::NotReady { .. }));
}
}

View File

@@ -1,76 +0,0 @@
//! Local mock of Cloud API V2.
use super::api::{self, Api, AuthInfo, DatabaseInfo};
use crate::auth::ClientCredentials;
use crate::scram;
use async_trait::async_trait;
/// Mocked cloud for testing purposes.
pub struct Local {
/// Database url, e.g. `postgres://user:password@localhost:5432/database`.
pub url: reqwest::Url,
}
#[async_trait]
impl Api for Local {
async fn get_auth_info(
&self,
creds: &ClientCredentials,
) -> Result<AuthInfo, api::GetAuthInfoError> {
// We wrap `tokio_postgres::Error` because we don't want to infect the
// method's error type with a detail that's specific to debug mode only.
let io_error = |e| std::io::Error::new(std::io::ErrorKind::Other, e);
// Perhaps we could persist this connection, but then we'd have to
// write more code for reopening it if it got closed, which doesn't
// seem worth it.
let (client, connection) =
tokio_postgres::connect(self.url.as_str(), tokio_postgres::NoTls)
.await
.map_err(io_error)?;
tokio::spawn(connection);
let query = "select rolpassword from pg_catalog.pg_authid where rolname = $1";
let rows = client
.query(query, &[&creds.user])
.await
.map_err(io_error)?;
match &rows[..] {
// We can't get a secret if there's no such user.
[] => Err(api::GetAuthInfoError::BadCredentials(creds.to_owned())),
// We shouldn't get more than one row anyway.
[row, ..] => {
let entry = row.try_get(0).map_err(io_error)?;
scram::ServerSecret::parse(entry)
.map(AuthInfo::Scram)
.or_else(|| {
// It could be an md5 hash if it's not a SCRAM secret.
let text = entry.strip_prefix("md5")?;
Some(AuthInfo::Md5({
let mut bytes = [0u8; 16];
hex::decode_to_slice(text, &mut bytes).ok()?;
bytes
}))
})
// Putting the secret into this message is a security hazard!
.ok_or(api::GetAuthInfoError::BadSecret)
}
}
}
async fn wake_compute(
&self,
creds: &ClientCredentials,
) -> Result<DatabaseInfo, api::WakeComputeError> {
// Local setup doesn't have a dedicated compute node,
// so we just return the local database we're pointed at.
Ok(DatabaseInfo {
host: self.url.host_str().unwrap_or("localhost").to_owned(),
port: self.url.port().unwrap_or(5432),
dbname: creds.dbname.to_owned(),
user: creds.user.to_owned(),
password: None,
})
}
}

View File

@@ -1,5 +1,5 @@
use crate::auth_backend::console::DatabaseInfo;
use crate::cancellation::CancelClosure;
use crate::cloud::api::DatabaseInfo;
use crate::error::UserFacingError;
use std::io;
use std::net::SocketAddr;

View File

@@ -1,35 +1,39 @@
use crate::cloud;
use anyhow::{bail, ensure, Context};
use std::sync::Arc;
use anyhow::{ensure, Context};
use std::{str::FromStr, sync::Arc};
#[non_exhaustive]
pub enum AuthBackendType {
LegacyConsole,
Console,
Postgres,
Link,
}
impl FromStr for AuthBackendType {
type Err = anyhow::Error;
fn from_str(s: &str) -> anyhow::Result<Self> {
println!("ClientAuthMethod::from_str: '{}'", s);
use AuthBackendType::*;
match s {
"legacy" => Ok(LegacyConsole),
"console" => Ok(Console),
"postgres" => Ok(Postgres),
"link" => Ok(Link),
_ => Err(anyhow::anyhow!("Invlid option for auth method")),
}
}
}
pub struct ProxyConfig {
/// Unauthenticated users will be redirected to this URL.
pub redirect_uri: reqwest::Url,
/// Cloud API endpoint for user authentication.
pub cloud_endpoint: CloudApi,
/// TLS configuration for the proxy.
pub tls_config: Option<TlsConfig>,
}
/// Cloud API configuration.
pub enum CloudApi {
/// We'll drop this one when [`CloudApi::V2`] is stable.
V1(crate::cloud::Legacy),
/// The new version of the cloud API.
V2(crate::cloud::BoxedApi),
}
pub auth_backend: AuthBackendType,
impl CloudApi {
/// Configure Cloud API provider.
pub fn new(version: &str, url: reqwest::Url) -> anyhow::Result<Self> {
Ok(match version {
"v1" => Self::V1(cloud::Legacy::new(url)),
"v2" => Self::V2(cloud::new(url)?),
_ => bail!("unknown cloud API version: {}", version),
})
}
pub auth_endpoint: reqwest::Url,
pub auth_link_uri: reqwest::Url,
}
pub type TlsConfig = Arc<rustls::ServerConfig>;

View File

@@ -5,8 +5,8 @@
//! in somewhat transparent manner (again via communication with control plane API).
mod auth;
mod auth_backend;
mod cancellation;
mod cloud;
mod compute;
mod config;
mod error;
@@ -48,18 +48,11 @@ async fn main() -> anyhow::Result<()> {
.default_value("127.0.0.1:4432"),
)
.arg(
Arg::new("auth-method")
.long("auth-method")
Arg::new("auth-backend")
.long("auth-backend")
.takes_value(true)
.help("Possible values: password | link | mixed")
.default_value("mixed"),
)
.arg(
Arg::new("static-router")
.short('s')
.long("static-router")
.takes_value(true)
.help("Route all clients to host:port"),
.help("Possible values: legacy | console | postgres | link")
.default_value("legacy"),
)
.arg(
Arg::new("mgmt")
@@ -82,7 +75,7 @@ async fn main() -> anyhow::Result<()> {
.short('u')
.long("uri")
.takes_value(true)
.help("redirect unauthenticated users to given uri")
.help("redirect unauthenticated users to the given uri in case of link auth")
.default_value("http://localhost:3000/psql_session/"),
)
.arg(
@@ -93,14 +86,6 @@ async fn main() -> anyhow::Result<()> {
.help("cloud API endpoint for authenticating users")
.default_value("http://localhost:3000/authenticate_proxy_request/"),
)
.arg(
Arg::new("api-version")
.long("api-version")
.takes_value(true)
.default_value("v1")
.possible_values(["v1", "v2"])
.help("cloud API version to be used for authentication"),
)
.arg(
Arg::new("tls-key")
.short('k')
@@ -132,15 +117,11 @@ async fn main() -> anyhow::Result<()> {
let mgmt_address: SocketAddr = arg_matches.value_of("mgmt").unwrap().parse()?;
let http_address: SocketAddr = arg_matches.value_of("http").unwrap().parse()?;
let cloud_endpoint = config::CloudApi::new(
arg_matches.value_of("api-version").unwrap(),
arg_matches.value_of("auth-endpoint").unwrap().parse()?,
)?;
let config: &ProxyConfig = Box::leak(Box::new(ProxyConfig {
redirect_uri: arg_matches.value_of("uri").unwrap().parse()?,
cloud_endpoint,
tls_config,
auth_backend: arg_matches.value_of("auth-backend").unwrap().parse()?,
auth_endpoint: arg_matches.value_of("auth-endpoint").unwrap().parse()?,
auth_link_uri: arg_matches.value_of("uri").unwrap().parse()?,
}));
println!("Version: {}", GIT_VERSION);

View File

@@ -1,4 +1,4 @@
use crate::cloud;
use crate::auth_backend;
use anyhow::Context;
use serde::Deserialize;
use std::{
@@ -10,6 +10,8 @@ use utils::{
pq_proto::{BeMessage, SINGLE_COL_ROWDESC},
};
/// TODO: move all of that to auth-backend/link.rs when we ditch legacy-console backend
///
/// Main proxy listener loop.
///
@@ -75,12 +77,12 @@ struct PsqlSessionResponse {
#[derive(Deserialize)]
enum PsqlSessionResult {
Success(cloud::api::DatabaseInfo),
Success(auth_backend::console::DatabaseInfo),
Failure(String),
}
/// A message received by `mgmt` when a compute node is ready.
pub type ComputeReady = Result<cloud::api::DatabaseInfo, String>;
pub type ComputeReady = Result<auth_backend::console::DatabaseInfo, String>;
impl PsqlSessionResult {
fn into_compute_ready(self) -> ComputeReady {
@@ -111,7 +113,7 @@ fn try_process_query(pgb: &mut PostgresBackend, query_string: &str) -> anyhow::R
let resp: PsqlSessionResponse = serde_json::from_str(query_string)?;
match cloud::notify(&resp.session_id, resp.result.into_compute_ready()) {
match auth_backend::notify(&resp.session_id, resp.result.into_compute_ready()) {
Ok(()) => {
pgb.write_message_noflush(&SINGLE_COL_ROWDESC)?
.write_message_noflush(&BeMessage::DataRow(&[Some(b"ok")]))?

View File

@@ -73,7 +73,7 @@ pub async fn thread_main(
async fn handle_client(
config: &ProxyConfig,
cancel_map: &CancelMap,
stream: impl AsyncRead + AsyncWrite + Unpin,
stream: impl AsyncRead + AsyncWrite + Unpin + Send,
) -> anyhow::Result<()> {
// The `closed` counter will increase when this future is destroyed.
NUM_CONNECTIONS_ACCEPTED_COUNTER.inc();
@@ -148,6 +148,8 @@ async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
.or_else(|e| stream.throw_error(e))
.await?;
// TODO: set creds.cluster here when SNI info is available
break Ok(Some((stream, creds)));
}
CancelRequest(cancel_key_data) => {
@@ -174,7 +176,7 @@ impl<S> Client<S> {
}
}
impl<S: AsyncRead + AsyncWrite + Unpin> Client<S> {
impl<S: AsyncRead + AsyncWrite + Unpin + Send> Client<S> {
/// Let the client authenticate and connect to the designated compute node.
async fn connect_to_db(
self,

View File

@@ -38,6 +38,7 @@ impl ServerSecret {
/// To avoid revealing information to an attacker, we use a
/// mocked server secret even if the user doesn't exist.
/// See `auth-scram.c : mock_scram_secret` for details.
#[allow(dead_code)]
pub fn mock(user: &str, nonce: &[u8; 32]) -> Self {
// Refer to `auth-scram.c : scram_mock_salt`.
let mocked_salt = super::sha256([user.as_bytes(), nonce]);

View File

@@ -1382,8 +1382,8 @@ def remote_pg(test_output_dir: str) -> Iterator[RemotePostgres]:
class ZenithProxy(PgProtocol):
def __init__(self, port: int):
super().__init__(host="127.0.0.1",
user="pytest",
password="pytest",
user="proxy_user",
password="pytest2",
port=port,
dbname='postgres')
self.http_port = 7001
@@ -1399,8 +1399,8 @@ class ZenithProxy(PgProtocol):
args = [bin_proxy]
args.extend(["--http", f"{self.host}:{self.http_port}"])
args.extend(["--proxy", f"{self.host}:{self.port}"])
args.extend(["--auth-method", "password"])
args.extend(["--static-router", addr])
args.extend(["--auth-backend", "postgres"])
args.extend(["--auth-endpoint", "postgres://proxy_auth:pytest1@localhost:5432/postgres"])
self._popen = subprocess.Popen(args)
self._wait_until_ready()
@@ -1422,7 +1422,8 @@ class ZenithProxy(PgProtocol):
def static_proxy(vanilla_pg) -> Iterator[ZenithProxy]:
"""Zenith proxy that routes directly to vanilla postgres."""
vanilla_pg.start()
vanilla_pg.safe_psql("create user pytest with password 'pytest';")
vanilla_pg.safe_psql("create user proxy_auth with password 'pytest1' superuser")
vanilla_pg.safe_psql("create user proxy_user with password 'pytest2'")
with ZenithProxy(4432) as proxy:
proxy.start_static()