[proxy] Introduce cloud::Api for communication with Neon Cloud

* `cloud::legacy` talks to Cloud API V1.
* `cloud::api` defines Cloud API v2.
* `cloud::local` mocks the Cloud API V2 using a local postgres instance.
* It's possible to choose between API versions using the `--api-version` flag.
This commit is contained in:
Dmitry Ivanov
2022-04-27 13:34:59 +03:00
committed by Stas Kelvich
parent 9df8915b03
commit af0195b604
15 changed files with 471 additions and 300 deletions

View File

@@ -5,6 +5,7 @@ edition = "2021"
[dependencies]
anyhow = "1.0"
async-trait = "0.1"
base64 = "0.13.0"
bytes = { version = "1.0.1", features = ['serde'] }
clap = "3.0"
@@ -37,7 +38,6 @@ metrics = { path = "../libs/metrics" }
workspace_hack = { version = "0.1", path = "../workspace_hack" }
[dev-dependencies]
async-trait = "0.1"
rcgen = "0.8.14"
rstest = "0.12"
tokio-postgres-rustls = "0.9.0"

View File

@@ -1,22 +1,16 @@
mod credentials;
#[cfg(test)]
mod flow;
use crate::compute::DatabaseInfo;
use crate::config::ProxyConfig;
use crate::cplane_api::{self, CPlaneApi};
use crate::config::{CloudApi, ProxyConfig};
use crate::error::UserFacingError;
use crate::stream::PqStream;
use crate::waiters;
use crate::{cloud, compute, waiters};
use std::io;
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncWrite};
use utils::pq_proto::{BeMessage as Be, BeParameterStatusMessage};
pub use credentials::ClientCredentials;
#[cfg(test)]
pub use flow::*;
/// Common authentication error.
@@ -24,9 +18,14 @@ pub use flow::*;
pub enum AuthErrorImpl {
/// Authentication error reported by the console.
#[error(transparent)]
Console(#[from] cplane_api::AuthError),
Console(#[from] cloud::AuthError),
#[error(transparent)]
GetAuthInfo(#[from] cloud::api::GetAuthInfoError),
#[error(transparent)]
WakeCompute(#[from] cloud::api::WakeComputeError),
#[cfg(test)]
#[error(transparent)]
Sasl(#[from] crate::sasl::Error),
@@ -41,19 +40,19 @@ pub enum AuthErrorImpl {
impl AuthErrorImpl {
pub fn auth_failed(msg: impl Into<String>) -> Self {
AuthErrorImpl::Console(cplane_api::AuthError::auth_failed(msg))
AuthErrorImpl::Console(cloud::AuthError::auth_failed(msg))
}
}
impl From<waiters::RegisterError> for AuthErrorImpl {
fn from(e: waiters::RegisterError) -> Self {
AuthErrorImpl::Console(cplane_api::AuthError::from(e))
AuthErrorImpl::Console(cloud::AuthError::from(e))
}
}
impl From<waiters::WaitError> for AuthErrorImpl {
fn from(e: waiters::WaitError) -> Self {
AuthErrorImpl::Console(cplane_api::AuthError::from(e))
AuthErrorImpl::Console(cloud::AuthError::from(e))
}
}
@@ -81,40 +80,28 @@ impl UserFacingError for AuthError {
}
}
async fn handle_static(
host: String,
port: u16,
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
creds: ClientCredentials,
) -> Result<DatabaseInfo, AuthError> {
client
.write_message(&Be::AuthenticationCleartextPassword)
.await?;
// Read client's password bytes
let msg = client.read_password_message().await?;
let cleartext_password = parse_password(&msg).ok_or(AuthErrorImpl::MalformedPassword)?;
let db_info = DatabaseInfo {
host,
port,
dbname: creds.dbname.clone(),
user: creds.user.clone(),
password: Some(cleartext_password.into()),
};
client
.write_message_noflush(&Be::AuthenticationOk)?
.write_message_noflush(&BeParameterStatusMessage::encoding())?;
Ok(db_info)
}
async fn handle_existing_user(
async fn handle_user(
config: &ProxyConfig,
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
creds: ClientCredentials,
) -> Result<DatabaseInfo, AuthError> {
) -> Result<compute::NodeInfo, AuthError> {
if creds.is_existing_user() {
match &config.cloud_endpoint {
CloudApi::V1(api) => handle_existing_user_v1(api, client, creds).await,
CloudApi::V2(api) => handle_existing_user_v2(api.as_ref(), client, creds).await,
}
} else {
let redirect_uri = config.redirect_uri.as_ref();
handle_new_user(redirect_uri, client).await
}
}
/// Authenticate user via a legacy cloud API endpoint.
async fn handle_existing_user_v1(
cloud: &cloud::Legacy,
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
creds: ClientCredentials,
) -> Result<compute::NodeInfo, AuthError> {
let psql_session_id = new_psql_session_id();
let md5_salt = rand::random();
@@ -126,8 +113,7 @@ async fn handle_existing_user(
let msg = client.read_password_message().await?;
let md5_response = parse_password(&msg).ok_or(AuthErrorImpl::MalformedPassword)?;
let cplane = CPlaneApi::new(config.auth_endpoint.clone());
let db_info = cplane
let db_info = cloud
.authenticate_proxy_client(creds, md5_response, &md5_salt, &psql_session_id)
.await?;
@@ -135,17 +121,53 @@ async fn handle_existing_user(
.write_message_noflush(&Be::AuthenticationOk)?
.write_message_noflush(&BeParameterStatusMessage::encoding())?;
Ok(db_info)
Ok(compute::NodeInfo {
db_info,
scram_keys: None,
})
}
/// Authenticate user via a new cloud API endpoint which supports SCRAM.
async fn handle_existing_user_v2(
cloud: &(impl cloud::Api + ?Sized),
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
creds: ClientCredentials,
) -> Result<compute::NodeInfo, AuthError> {
let auth_info = cloud.get_auth_info(&creds).await?;
let flow = AuthFlow::new(client);
let scram_keys = match auth_info {
cloud::api::AuthInfo::Md5(_) => {
// TODO: decide if we should support MD5 in api v2
return Err(AuthErrorImpl::auth_failed("MD5 is not supported").into());
}
cloud::api::AuthInfo::Scram(secret) => {
let scram = Scram(&secret);
Some(compute::ScramKeys {
client_key: flow.begin(scram).await?.authenticate().await?.as_bytes(),
server_key: secret.server_key.as_bytes(),
})
}
};
client
.write_message_noflush(&Be::AuthenticationOk)?
.write_message_noflush(&BeParameterStatusMessage::encoding())?;
Ok(compute::NodeInfo {
db_info: cloud.wake_compute(&creds).await?,
scram_keys,
})
}
async fn handle_new_user(
config: &ProxyConfig,
redirect_uri: &str,
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
) -> Result<DatabaseInfo, AuthError> {
) -> Result<compute::NodeInfo, AuthError> {
let psql_session_id = new_psql_session_id();
let greeting = hello_message(&config.redirect_uri, &psql_session_id);
let greeting = hello_message(redirect_uri, &psql_session_id);
let db_info = cplane_api::with_waiter(psql_session_id, |waiter| async {
let db_info = cloud::with_waiter(psql_session_id, |waiter| async {
// Give user a URL to spawn a new database
client
.write_message_noflush(&Be::AuthenticationOk)?
@@ -160,7 +182,10 @@ async fn handle_new_user(
client.write_message_noflush(&Be::NoticeResponse("Connecting to database."))?;
Ok(db_info)
Ok(compute::NodeInfo {
db_info,
scram_keys: None,
})
}
fn new_psql_session_id() -> String {

View File

@@ -1,7 +1,7 @@
//! User credentials used in authentication.
use super::AuthError;
use crate::compute::DatabaseInfo;
use crate::compute;
use crate::config::ProxyConfig;
use crate::error::UserFacingError;
use crate::stream::PqStream;
@@ -18,12 +18,20 @@ pub enum ClientCredsParseError {
impl UserFacingError for ClientCredsParseError {}
/// Various client credentials which we use for authentication.
#[derive(Debug, PartialEq, Eq)]
/// Note that we don't store any kind of client key or password here.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ClientCredentials {
pub user: String,
pub dbname: String,
}
impl ClientCredentials {
pub fn is_existing_user(&self) -> bool {
// This logic will likely change in the future.
self.user.ends_with("@zenith")
}
}
impl TryFrom<HashMap<String, String>> for ClientCredentials {
type Error = ClientCredsParseError;
@@ -47,20 +55,8 @@ impl ClientCredentials {
self,
config: &ProxyConfig,
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
) -> Result<DatabaseInfo, AuthError> {
use crate::config::ClientAuthMethod::*;
use crate::config::RouterConfig::*;
match &config.router_config {
Static { host, port } => super::handle_static(host.clone(), *port, client, self).await,
Dynamic(Mixed) => {
if self.user.ends_with("@zenith") {
super::handle_existing_user(config, client, self).await
} else {
super::handle_new_user(config, client).await
}
}
Dynamic(Password) => super::handle_existing_user(config, client, self).await,
Dynamic(Link) => super::handle_new_user(config, client).await,
}
) -> Result<compute::NodeInfo, AuthError> {
// This method is just a convenient facade for `handle_user`
super::handle_user(config, client, self).await
}
}

View File

@@ -27,19 +27,6 @@ impl AuthMethod for Scram<'_> {
}
}
/// Use password-based auth in [`AuthFlow`].
pub struct Md5(
/// Salt for client.
pub [u8; 4],
);
impl AuthMethod for Md5 {
#[inline(always)]
fn first_message(&self) -> BeMessage<'_> {
Be::AuthenticationMD5Password(self.0)
}
}
/// This wrapper for [`PqStream`] performs client authentication.
#[must_use]
pub struct AuthFlow<'a, Stream, State> {
@@ -70,19 +57,10 @@ impl<'a, S: AsyncWrite + Unpin> AuthFlow<'a, S, Begin> {
}
}
/// Stream wrapper for handling simple MD5 password auth.
impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, Md5> {
/// Perform user authentication. Raise an error in case authentication failed.
#[allow(unused)]
pub async fn authenticate(self) -> Result<(), AuthError> {
unimplemented!("MD5 auth flow is yet to be implemented");
}
}
/// Stream wrapper for handling [SCRAM](crate::scram) auth.
impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, Scram<'_>> {
/// Perform user authentication. Raise an error in case authentication failed.
pub async fn authenticate(self) -> Result<(), AuthError> {
pub async fn authenticate(self) -> Result<scram::ScramKey, AuthError> {
// Initial client message contains the chosen auth method's name.
let msg = self.stream.read_password_message().await?;
let sasl = sasl::FirstMessage::parse(&msg).ok_or(AuthErrorImpl::MalformedPassword)?;
@@ -93,10 +71,10 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, Scram<'_>> {
}
let secret = self.state.0;
sasl::SaslStream::new(self.stream, sasl.message)
let key = sasl::SaslStream::new(self.stream, sasl.message)
.authenticate(scram::Exchange::new(secret, rand::random, None))
.await?;
Ok(())
Ok(key)
}
}

46
proxy/src/cloud.rs Normal file
View File

@@ -0,0 +1,46 @@
mod local;
mod legacy;
pub use legacy::{AuthError, AuthErrorImpl, Legacy};
pub mod api;
pub use api::{Api, BoxedApi};
use crate::mgmt;
use crate::waiters::{self, Waiter, Waiters};
use lazy_static::lazy_static;
lazy_static! {
static ref CPLANE_WAITERS: Waiters<mgmt::ComputeReady> = Default::default();
}
/// Give caller an opportunity to wait for the cloud's reply.
pub async fn with_waiter<R, T, E>(
psql_session_id: impl Into<String>,
action: impl FnOnce(Waiter<'static, mgmt::ComputeReady>) -> R,
) -> Result<T, E>
where
R: std::future::Future<Output = Result<T, E>>,
E: From<waiters::RegisterError>,
{
let waiter = CPLANE_WAITERS.register(psql_session_id.into())?;
action(waiter).await
}
pub fn notify(psql_session_id: &str, msg: mgmt::ComputeReady) -> Result<(), waiters::NotifyError> {
CPLANE_WAITERS.notify(psql_session_id, msg)
}
/// Construct a new opaque cloud API provider.
pub fn new(url: reqwest::Url) -> anyhow::Result<BoxedApi> {
Ok(match url.scheme() {
"https" | "http" => {
todo!("build a real cloud wrapper")
}
"postgresql" | "postgres" | "pg" => {
// Just point to a local running postgres instance.
Box::new(local::Local { url })
}
other => anyhow::bail!("unsupported url scheme: {other}"),
})
}

120
proxy/src/cloud/api.rs Normal file
View File

@@ -0,0 +1,120 @@
//! Declaration of Cloud API V2.
use crate::{auth, scram};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use thiserror::Error;
#[derive(Debug, Error)]
pub enum GetAuthInfoError {
// We shouldn't include the actual secret here.
#[error("Bad authentication secret")]
BadSecret,
#[error("Bad client credentials: {0:?}")]
BadCredentials(crate::auth::ClientCredentials),
#[error(transparent)]
Io(#[from] std::io::Error),
}
// TODO: convert to an enum and describe possible sub-errors (see above)
#[derive(Debug, Error)]
#[error("Failed to wake up the compute node")]
pub struct WakeComputeError;
/// Opaque implementation of Cloud API.
pub type BoxedApi = Box<dyn Api + Send + Sync>;
/// Cloud API methods required by the proxy.
#[async_trait]
pub trait Api {
/// Get authentication information for the given user.
async fn get_auth_info(
&self,
creds: &auth::ClientCredentials,
) -> Result<AuthInfo, GetAuthInfoError>;
/// Wake up the compute node and return the corresponding connection info.
async fn wake_compute(
&self,
creds: &auth::ClientCredentials,
) -> Result<DatabaseInfo, WakeComputeError>;
}
/// Auth secret which is managed by the cloud.
pub enum AuthInfo {
/// Md5 hash of user's password.
Md5([u8; 16]),
/// [SCRAM](crate::scram) authentication info.
Scram(scram::ServerSecret),
}
/// Compute node connection params provided by the cloud.
/// Note how it implements serde traits, since we receive it over the wire.
#[derive(Serialize, Deserialize, Default)]
pub struct DatabaseInfo {
pub host: String,
pub port: u16,
pub dbname: String,
pub user: String,
/// [Cloud API V1](super::legacy) returns cleartext password,
/// but [Cloud API V2](super::api) implements [SCRAM](crate::scram)
/// authentication, so we can leverage this method and cope without password.
pub password: Option<String>,
}
// Manually implement debug to omit personal and sensitive info.
impl std::fmt::Debug for DatabaseInfo {
fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
fmt.debug_struct("DatabaseInfo")
.field("host", &self.host)
.field("port", &self.port)
.finish()
}
}
impl From<DatabaseInfo> for tokio_postgres::Config {
fn from(db_info: DatabaseInfo) -> Self {
let mut config = tokio_postgres::Config::new();
config
.host(&db_info.host)
.port(db_info.port)
.dbname(&db_info.dbname)
.user(&db_info.user);
if let Some(password) = db_info.password {
config.password(password);
}
config
}
}
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
#[test]
fn parse_db_info() -> anyhow::Result<()> {
let _: DatabaseInfo = serde_json::from_value(json!({
"host": "localhost",
"port": 5432,
"dbname": "postgres",
"user": "john_doe",
"password": "password",
}))?;
let _: DatabaseInfo = serde_json::from_value(json!({
"host": "localhost",
"port": 5432,
"dbname": "postgres",
"user": "john_doe",
}))?;
Ok(())
}
}

View File

@@ -1,42 +1,19 @@
//! Cloud API V1.
use super::api::DatabaseInfo;
use crate::auth::ClientCredentials;
use crate::compute::DatabaseInfo;
use crate::error::UserFacingError;
use crate::mgmt;
use crate::waiters::{self, Waiter, Waiters};
use lazy_static::lazy_static;
use crate::waiters;
use serde::{Deserialize, Serialize};
use thiserror::Error;
lazy_static! {
static ref CPLANE_WAITERS: Waiters<mgmt::ComputeReady> = Default::default();
}
/// Give caller an opportunity to wait for cplane's reply.
pub async fn with_waiter<R, T, E>(
psql_session_id: impl Into<String>,
action: impl FnOnce(Waiter<'static, mgmt::ComputeReady>) -> R,
) -> Result<T, E>
where
R: std::future::Future<Output = Result<T, E>>,
E: From<waiters::RegisterError>,
{
let waiter = CPLANE_WAITERS.register(psql_session_id.into())?;
action(waiter).await
}
pub fn notify(
psql_session_id: &str,
msg: Result<DatabaseInfo, String>,
) -> Result<(), waiters::NotifyError> {
CPLANE_WAITERS.notify(psql_session_id, msg)
}
/// Zenith console API wrapper.
pub struct CPlaneApi {
/// Neon cloud API provider.
pub struct Legacy {
auth_endpoint: reqwest::Url,
}
impl CPlaneApi {
impl Legacy {
/// Construct a new legacy cloud API provider.
pub fn new(auth_endpoint: reqwest::Url) -> Self {
Self { auth_endpoint }
}
@@ -95,7 +72,17 @@ impl UserFacingError for AuthError {
}
}
impl CPlaneApi {
// NOTE: the order of constructors is important.
// https://serde.rs/enum-representations.html#untagged
#[derive(Serialize, Deserialize, Debug)]
#[serde(untagged)]
enum ProxyAuthResponse {
Ready { conn_info: DatabaseInfo },
Error { error: String },
NotReady { ready: bool }, // TODO: get rid of `ready`
}
impl Legacy {
pub async fn authenticate_proxy_client(
&self,
creds: ClientCredentials,
@@ -111,8 +98,8 @@ impl CPlaneApi {
.append_pair("salt", &hex::encode(salt))
.append_pair("psql_session_id", psql_session_id);
with_waiter(psql_session_id, |waiter| async {
println!("cplane request: {}", url);
super::with_waiter(psql_session_id, |waiter| async {
println!("cloud request: {}", url);
// TODO: leverage `reqwest::Client` to reuse connections
let resp = reqwest::get(url).await?;
if !resp.status().is_success() {
@@ -135,16 +122,6 @@ impl CPlaneApi {
}
}
// NOTE: the order of constructors is important.
// https://serde.rs/enum-representations.html#untagged
#[derive(Serialize, Deserialize, Debug)]
#[serde(untagged)]
enum ProxyAuthResponse {
Ready { conn_info: DatabaseInfo },
Error { error: String },
NotReady { ready: bool }, // TODO: get rid of `ready`
}
#[cfg(test)]
mod tests {
use super::*;

76
proxy/src/cloud/local.rs Normal file
View File

@@ -0,0 +1,76 @@
//! Local mock of Cloud API V2.
use super::api::{self, Api, AuthInfo, DatabaseInfo};
use crate::auth::ClientCredentials;
use crate::scram;
use async_trait::async_trait;
/// Mocked cloud for testing purposes.
pub struct Local {
/// Database url, e.g. `postgres://user:password@localhost:5432/database`.
pub url: reqwest::Url,
}
#[async_trait]
impl Api for Local {
async fn get_auth_info(
&self,
creds: &ClientCredentials,
) -> Result<AuthInfo, api::GetAuthInfoError> {
// We wrap `tokio_postgres::Error` because we don't want to infect the
// method's error type with a detail that's specific to debug mode only.
let io_error = |e| std::io::Error::new(std::io::ErrorKind::Other, e);
// Perhaps we could persist this connection, but then we'd have to
// write more code for reopening it if it got closed, which doesn't
// seem worth it.
let (client, connection) =
tokio_postgres::connect(self.url.as_str(), tokio_postgres::NoTls)
.await
.map_err(io_error)?;
tokio::spawn(connection);
let query = "select rolpassword from pg_catalog.pg_authid where rolname = $1";
let rows = client
.query(query, &[&creds.user])
.await
.map_err(io_error)?;
match &rows[..] {
// We can't get a secret if there's no such user.
[] => Err(api::GetAuthInfoError::BadCredentials(creds.to_owned())),
// We shouldn't get more than one row anyway.
[row, ..] => {
let entry = row.try_get(0).map_err(io_error)?;
scram::ServerSecret::parse(entry)
.map(AuthInfo::Scram)
.or_else(|| {
// It could be an md5 hash if it's not a SCRAM secret.
let text = entry.strip_prefix("md5")?;
Some(AuthInfo::Md5({
let mut bytes = [0u8; 16];
hex::decode_to_slice(text, &mut bytes).ok()?;
bytes
}))
})
// Putting the secret into this message is a security hazard!
.ok_or(api::GetAuthInfoError::BadSecret)
}
}
}
async fn wake_compute(
&self,
creds: &ClientCredentials,
) -> Result<DatabaseInfo, api::WakeComputeError> {
// Local setup doesn't have a dedicated compute node,
// so we just return the local database we're pointed at.
Ok(DatabaseInfo {
host: self.url.host_str().unwrap_or("localhost").to_owned(),
port: self.url.port().unwrap_or(5432),
dbname: creds.dbname.to_owned(),
user: creds.user.to_owned(),
password: None,
})
}
}

View File

@@ -1,6 +1,6 @@
use crate::cancellation::CancelClosure;
use crate::cloud::api::DatabaseInfo;
use crate::error::UserFacingError;
use serde::{Deserialize, Serialize};
use std::io;
use std::net::SocketAddr;
use thiserror::Error;
@@ -23,32 +23,21 @@ pub enum ConnectionError {
impl UserFacingError for ConnectionError {}
/// Compute node connection params.
#[derive(Serialize, Deserialize, Default)]
pub struct DatabaseInfo {
pub host: String,
pub port: u16,
pub dbname: String,
pub user: String,
pub password: Option<String>,
}
// Manually implement debug to omit personal and sensitive info
impl std::fmt::Debug for DatabaseInfo {
fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
fmt.debug_struct("DatabaseInfo")
.field("host", &self.host)
.field("port", &self.port)
.finish()
}
}
/// PostgreSQL version as [`String`].
pub type Version = String;
impl DatabaseInfo {
/// A pair of `ClientKey` & `ServerKey` for `SCRAM-SHA-256`.
pub type ScramKeys = tokio_postgres::config::ScramKeys<32>;
/// Compute node connection params.
pub struct NodeInfo {
pub db_info: DatabaseInfo,
pub scram_keys: Option<ScramKeys>,
}
impl NodeInfo {
async fn connect_raw(&self) -> io::Result<(SocketAddr, TcpStream)> {
let host_port = format!("{}:{}", self.host, self.port);
let host_port = format!("{}:{}", self.db_info.host, self.db_info.port);
let socket = TcpStream::connect(host_port).await?;
let socket_addr = socket.peer_addr()?;
socket2::SockRef::from(&socket).set_keepalive(true)?;
@@ -63,11 +52,13 @@ impl DatabaseInfo {
.await
.map_err(|_| ConnectionError::FailedToConnectToCompute)?;
// TODO: establish a secure connection to the DB
let (client, conn) = tokio_postgres::Config::from(self)
.connect_raw(&mut socket, NoTls)
.await?;
let mut config = tokio_postgres::Config::from(self.db_info);
if let Some(scram_keys) = self.scram_keys {
config.auth_keys(tokio_postgres::config::AuthKeys::ScramSha256(scram_keys));
}
// TODO: establish a secure connection to the DB
let (client, conn) = config.connect_raw(&mut socket, NoTls).await?;
let version = conn
.parameter("server_version")
.ok_or(ConnectionError::FailedToFetchPgVersion)?
@@ -78,21 +69,3 @@ impl DatabaseInfo {
Ok((socket, version, cancel_closure))
}
}
impl From<DatabaseInfo> for tokio_postgres::Config {
fn from(db_info: DatabaseInfo) -> Self {
let mut config = tokio_postgres::Config::new();
config
.host(&db_info.host)
.port(db_info.port)
.dbname(&db_info.dbname)
.user(&db_info.user);
if let Some(password) = db_info.password {
config.password(password);
}
config
}
}

View File

@@ -1,65 +1,43 @@
use crate::cloud;
use anyhow::{bail, ensure, Context};
use std::net::SocketAddr;
use std::str::FromStr;
use std::sync::Arc;
pub type TlsConfig = Arc<rustls::ServerConfig>;
#[non_exhaustive]
pub enum ClientAuthMethod {
Password,
Link,
/// Use password auth only if username ends with "@zenith"
Mixed,
}
pub enum RouterConfig {
Static { host: String, port: u16 },
Dynamic(ClientAuthMethod),
}
impl FromStr for ClientAuthMethod {
type Err = anyhow::Error;
fn from_str(s: &str) -> anyhow::Result<Self> {
use ClientAuthMethod::*;
match s {
"password" => Ok(Password),
"link" => Ok(Link),
"mixed" => Ok(Mixed),
_ => bail!("Invalid option for router: `{}`", s),
}
}
}
pub struct ProxyConfig {
/// main entrypoint for users to connect to
pub proxy_address: SocketAddr,
/// Unauthenticated users will be redirected to this URL.
pub redirect_uri: reqwest::Url,
/// method of assigning compute nodes
pub router_config: RouterConfig,
/// internally used for status and prometheus metrics
pub http_address: SocketAddr,
/// management endpoint. Upon user account creation control plane
/// will notify us here, so that we can 'unfreeze' user session.
/// TODO It uses postgres protocol over TCP but should be migrated to http.
pub mgmt_address: SocketAddr,
/// send unauthenticated users to this URI
pub redirect_uri: String,
/// control plane address where we would check auth.
pub auth_endpoint: reqwest::Url,
/// Cloud API endpoint for user authentication.
pub cloud_endpoint: CloudApi,
/// TLS configuration for the proxy.
pub tls_config: Option<TlsConfig>,
}
pub fn configure_ssl(key_path: &str, cert_path: &str) -> anyhow::Result<TlsConfig> {
/// Cloud API configuration.
pub enum CloudApi {
/// We'll drop this one when [`CloudApi::V2`] is stable.
V1(crate::cloud::Legacy),
/// The new version of the cloud API.
V2(crate::cloud::BoxedApi),
}
impl CloudApi {
/// Configure Cloud API provider.
pub fn new(version: &str, url: reqwest::Url) -> anyhow::Result<Self> {
Ok(match version {
"v1" => Self::V1(cloud::Legacy::new(url)),
"v2" => Self::V2(cloud::new(url)?),
_ => bail!("unknown cloud API version: {}", version),
})
}
}
pub type TlsConfig = Arc<rustls::ServerConfig>;
/// Configure TLS for the main endpoint.
pub fn configure_tls(key_path: &str, cert_path: &str) -> anyhow::Result<TlsConfig> {
let key = {
let key_bytes = std::fs::read(key_path).context("SSL key file")?;
let key_bytes = std::fs::read(key_path).context("TLS key file")?;
let mut keys = rustls_pemfile::pkcs8_private_keys(&mut &key_bytes[..])
.context("couldn't read TLS keys")?;
@@ -68,7 +46,7 @@ pub fn configure_ssl(key_path: &str, cert_path: &str) -> anyhow::Result<TlsConfi
};
let cert_chain = {
let cert_chain_bytes = std::fs::read(cert_path).context("SSL cert file")?;
let cert_chain_bytes = std::fs::read(cert_path).context("TLS cert file")?;
rustls_pemfile::certs(&mut &cert_chain_bytes[..])
.context("couldn't read TLS certificate chain")?
.into_iter()

View File

@@ -6,34 +6,27 @@
mod auth;
mod cancellation;
mod cloud;
mod compute;
mod config;
mod cplane_api;
mod error;
mod http;
mod mgmt;
mod parse;
mod proxy;
mod sasl;
mod scram;
mod stream;
mod waiters;
// Currently SCRAM is only used in tests
#[cfg(test)]
mod parse;
#[cfg(test)]
mod sasl;
#[cfg(test)]
mod scram;
use anyhow::{bail, Context};
use clap::{App, Arg};
use config::ProxyConfig;
use futures::FutureExt;
use std::future::Future;
use std::{future::Future, net::SocketAddr};
use tokio::{net::TcpListener, task::JoinError};
use utils::GIT_VERSION;
use crate::config::{ClientAuthMethod, RouterConfig};
/// Flattens `Result<Result<T>>` into `Result<T>`.
async fn flatten_err(
f: impl Future<Output = Result<anyhow::Result<()>, JoinError>>,
@@ -44,7 +37,7 @@ async fn flatten_err(
#[tokio::main]
async fn main() -> anyhow::Result<()> {
metrics::set_common_metrics_prefix("zenith_proxy");
let arg_matches = App::new("Zenith proxy/router")
let arg_matches = App::new("Neon proxy/router")
.version(GIT_VERSION)
.arg(
Arg::new("proxy")
@@ -97,77 +90,80 @@ async fn main() -> anyhow::Result<()> {
.short('a')
.long("auth-endpoint")
.takes_value(true)
.help("API endpoint for authenticating users")
.help("cloud API endpoint for authenticating users")
.default_value("http://localhost:3000/authenticate_proxy_request/"),
)
.arg(
Arg::new("ssl-key")
.short('k')
.long("ssl-key")
Arg::new("api-version")
.long("api-version")
.takes_value(true)
.help("path to SSL key for client postgres connections"),
.default_value("v1")
.possible_values(["v1", "v2"])
.help("cloud API version to be used for authentication"),
)
.arg(
Arg::new("ssl-cert")
.short('c')
.long("ssl-cert")
Arg::new("tls-key")
.short('k')
.long("tls-key")
.alias("ssl-key") // backwards compatibility
.takes_value(true)
.help("path to SSL cert for client postgres connections"),
.help("path to TLS key for client postgres connections"),
)
.arg(
Arg::new("tls-cert")
.short('c')
.long("tls-cert")
.alias("ssl-cert") // backwards compatibility
.takes_value(true)
.help("path to TLS cert for client postgres connections"),
)
.get_matches();
let tls_config = match (
arg_matches.value_of("ssl-key"),
arg_matches.value_of("ssl-cert"),
arg_matches.value_of("tls-key"),
arg_matches.value_of("tls-cert"),
) {
(Some(key_path), Some(cert_path)) => Some(config::configure_ssl(key_path, cert_path)?),
(Some(key_path), Some(cert_path)) => Some(config::configure_tls(key_path, cert_path)?),
(None, None) => None,
_ => bail!("either both or neither ssl-key and ssl-cert must be specified"),
_ => bail!("either both or neither tls-key and tls-cert must be specified"),
};
let auth_method = arg_matches.value_of("auth-method").unwrap().parse()?;
let router_config = match arg_matches.value_of("static-router") {
None => RouterConfig::Dynamic(auth_method),
Some(addr) => {
if let ClientAuthMethod::Password = auth_method {
let (host, port) = addr.split_once(':').unwrap();
RouterConfig::Static {
host: host.to_string(),
port: port.parse().unwrap(),
}
} else {
bail!("static-router requires --auth-method password")
}
}
};
let proxy_address: SocketAddr = arg_matches.value_of("proxy").unwrap().parse()?;
let mgmt_address: SocketAddr = arg_matches.value_of("mgmt").unwrap().parse()?;
let http_address: SocketAddr = arg_matches.value_of("http").unwrap().parse()?;
let cloud_endpoint = config::CloudApi::new(
arg_matches.value_of("api-version").unwrap(),
arg_matches.value_of("auth-endpoint").unwrap().parse()?,
)?;
let config: &ProxyConfig = Box::leak(Box::new(ProxyConfig {
router_config,
proxy_address: arg_matches.value_of("proxy").unwrap().parse()?,
mgmt_address: arg_matches.value_of("mgmt").unwrap().parse()?,
http_address: arg_matches.value_of("http").unwrap().parse()?,
redirect_uri: arg_matches.value_of("uri").unwrap().parse()?,
auth_endpoint: arg_matches.value_of("auth-endpoint").unwrap().parse()?,
cloud_endpoint,
tls_config,
}));
println!("Version: {}", GIT_VERSION);
// Check that we can bind to address before further initialization
println!("Starting http on {}", config.http_address);
let http_listener = TcpListener::bind(config.http_address).await?.into_std()?;
println!("Starting http on {}", http_address);
let http_listener = TcpListener::bind(http_address).await?.into_std()?;
println!("Starting mgmt on {}", config.mgmt_address);
let mgmt_listener = TcpListener::bind(config.mgmt_address).await?.into_std()?;
println!("Starting mgmt on {}", mgmt_address);
let mgmt_listener = TcpListener::bind(mgmt_address).await?.into_std()?;
println!("Starting proxy on {}", config.proxy_address);
let proxy_listener = TcpListener::bind(config.proxy_address).await?;
println!("Starting proxy on {}", proxy_address);
let proxy_listener = TcpListener::bind(proxy_address).await?;
let http = tokio::spawn(http::thread_main(http_listener));
let proxy = tokio::spawn(proxy::thread_main(config, proxy_listener));
let mgmt = tokio::task::spawn_blocking(move || mgmt::thread_main(mgmt_listener));
let tasks = [
tokio::spawn(http::thread_main(http_listener)),
tokio::spawn(proxy::thread_main(config, proxy_listener)),
tokio::task::spawn_blocking(move || mgmt::thread_main(mgmt_listener)),
]
.map(flatten_err);
let tasks = [flatten_err(http), flatten_err(proxy), flatten_err(mgmt)];
// This will block until all tasks have completed.
// Furthermore, the first one to fail will cancel the rest.
let _: Vec<()> = futures::future::try_join_all(tasks).await?;
Ok(())

View File

@@ -1,4 +1,4 @@
use crate::{compute::DatabaseInfo, cplane_api};
use crate::cloud;
use anyhow::Context;
use serde::Deserialize;
use std::{
@@ -75,12 +75,12 @@ struct PsqlSessionResponse {
#[derive(Deserialize)]
enum PsqlSessionResult {
Success(DatabaseInfo),
Success(cloud::api::DatabaseInfo),
Failure(String),
}
/// A message received by `mgmt` when a compute node is ready.
pub type ComputeReady = Result<DatabaseInfo, String>;
pub type ComputeReady = Result<cloud::api::DatabaseInfo, String>;
impl PsqlSessionResult {
fn into_compute_ready(self) -> ComputeReady {
@@ -111,7 +111,7 @@ fn try_process_query(pgb: &mut PostgresBackend, query_string: &str) -> anyhow::R
let resp: PsqlSessionResponse = serde_json::from_str(query_string)?;
match cplane_api::notify(&resp.session_id, resp.result.into_compute_ready()) {
match cloud::notify(&resp.session_id, resp.result.into_compute_ready()) {
Ok(()) => {
pgb.write_message_noflush(&SINGLE_COL_ROWDESC)?
.write_message_noflush(&BeMessage::DataRow(&[Some(b"ok")]))?

View File

@@ -185,10 +185,10 @@ impl<S: AsyncRead + AsyncWrite + Unpin> Client<S> {
// Authenticate and connect to a compute node.
let auth = creds.authenticate(config, &mut stream).await;
let db_info = async { auth }.or_else(|e| stream.throw_error(e)).await?;
let node = async { auth }.or_else(|e| stream.throw_error(e)).await?;
let (db, version, cancel_closure) =
db_info.connect().or_else(|e| stream.throw_error(e)).await?;
node.connect().or_else(|e| stream.throw_error(e)).await?;
let cancel_key_data = session.enable_cancellation(cancel_closure);
stream

View File

@@ -9,10 +9,12 @@
mod exchange;
mod key;
mod messages;
mod password;
mod secret;
mod signature;
#[cfg(test)]
mod password;
pub use exchange::Exchange;
pub use key::ScramKey;
pub use secret::ServerSecret;

View File

@@ -16,6 +16,10 @@ impl ScramKey {
pub fn sha256(&self) -> Self {
super::sha256([self.as_ref()]).into()
}
pub fn as_bytes(&self) -> [u8; SCRAM_KEY_LEN] {
self.bytes
}
}
impl From<[u8; SCRAM_KEY_LEN]> for ScramKey {