Add static router

This commit is contained in:
Bojan Serafimov
2022-02-19 01:09:00 -05:00
committed by Dmitry Ivanov
parent cca886682b
commit 65a0b2736b
3 changed files with 73 additions and 7 deletions

View File

@@ -1,5 +1,5 @@
use crate::compute::DatabaseInfo;
use crate::config::{ClientAuthMethod, ProxyConfig};
use crate::config::ProxyConfig;
use crate::cplane_api::{self, CPlaneApi};
use crate::stream::PqStream;
use anyhow::{anyhow, bail, Context};
@@ -38,16 +38,19 @@ impl ClientCredentials {
config: &ProxyConfig,
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
) -> anyhow::Result<DatabaseInfo> {
let db_info = match config.client_auth_method {
ClientAuthMethod::Mixed => {
use crate::config::ClientAuthMethod::*;
use crate::config::RouterConfig::*;
let db_info = match &config.router_config {
Static { host, port } => handle_static(host.clone(), port.clone(), client, self).await,
Dynamic(Mixed) => {
if self.user.ends_with("@zenith") {
handle_existing_user(config, client, self).await
} else {
handle_new_user(config, client).await
}
}
ClientAuthMethod::Password => handle_existing_user(config, client, self).await,
ClientAuthMethod::Link => handle_new_user(config, client).await,
Dynamic(Password) => handle_existing_user(config, client, self).await,
Dynamic(Link) => handle_new_user(config, client).await,
};
db_info.context("failed to authenticate client")
@@ -58,6 +61,39 @@ fn new_psql_session_id() -> String {
hex::encode(rand::random::<[u8; 8]>())
}
async fn handle_static(
host: String,
port: u16,
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
creds: ClientCredentials,
) -> anyhow::Result<DatabaseInfo> {
client
.write_message(&Be::AuthenticationCleartextPassword)
.await?;
// Read client's password bytes
let msg = match client.read_message().await? {
Fe::PasswordMessage(msg) => msg,
bad => bail!("unexpected message type: {:?}", bad),
};
let cleartext_password = std::str::from_utf8(&msg)?.split('\0').next().unwrap();
let db_info = DatabaseInfo {
host,
port,
dbname: creds.dbname.clone(),
user: creds.user.clone(),
password: Some(cleartext_password.into()),
};
client
.write_message_noflush(&Be::AuthenticationOk)?
.write_message_noflush(&BeParameterStatusMessage::encoding())?;
Ok(db_info)
}
async fn handle_existing_user(
config: &ProxyConfig,
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,

View File

@@ -15,6 +15,11 @@ pub enum ClientAuthMethod {
Mixed,
}
pub enum RouterConfig {
Static { host: String, port: u16 },
Dynamic(ClientAuthMethod),
}
impl FromStr for ClientAuthMethod {
type Err = anyhow::Error;
@@ -34,7 +39,7 @@ pub struct ProxyConfig {
pub proxy_address: SocketAddr,
/// method of assigning compute nodes
pub client_auth_method: ClientAuthMethod,
pub router_config: RouterConfig,
/// internally used for status and prometheus metrics
pub http_address: SocketAddr,

View File

@@ -13,6 +13,8 @@ use std::future::Future;
use tokio::{net::TcpListener, task::JoinError};
use zenith_utils::GIT_VERSION;
use crate::config::{ClientAuthMethod, RouterConfig};
mod auth;
mod cancellation;
mod compute;
@@ -51,6 +53,13 @@ async fn main() -> anyhow::Result<()> {
.help("Possible values: password | link | mixed")
.default_value("mixed"),
)
.arg(
Arg::new("static-router")
.short('s')
.long("static-router")
.takes_value(true)
.help("Route all clients to host:port"),
)
.arg(
Arg::new("mgmt")
.short('m')
@@ -108,9 +117,25 @@ async fn main() -> anyhow::Result<()> {
_ => bail!("either both or neither ssl-key and ssl-cert must be specified"),
};
let auth_method = arg_matches.value_of("auth-method").unwrap().parse()?;
let router_config = match arg_matches.value_of("static-router") {
None => RouterConfig::Dynamic(auth_method),
Some(addr) => {
if let ClientAuthMethod::Password = auth_method {
let (host, port) = addr.split_once(":").unwrap();
RouterConfig::Static {
host: host.to_string(),
port: port.parse().unwrap(),
}
} else {
bail!("static-router requires --auth-method password")
}
}
};
let config: &ProxyConfig = Box::leak(Box::new(ProxyConfig {
router_config,
proxy_address: arg_matches.value_of("proxy").unwrap().parse()?,
client_auth_method: arg_matches.value_of("auth-method").unwrap().parse()?,
mgmt_address: arg_matches.value_of("mgmt").unwrap().parse()?,
http_address: arg_matches.value_of("http").unwrap().parse()?,
redirect_uri: arg_matches.value_of("uri").unwrap().parse()?,