do an actual proxy pass

This commit is contained in:
Stas Kelvich
2021-06-24 23:32:25 +03:00
parent 6f747893be
commit 605b90c6c7
7 changed files with 354 additions and 119 deletions

View File

@@ -44,11 +44,11 @@ impl CPlaneApi {
}
}
fn get_database_uri(_user: String, _database: String) -> Option<String> {
Some("postgresql://localhost/stas".to_string())
pub fn get_database_uri(&self, _user: &String, _database: &String) -> Result<String> {
Ok("user=stas dbname=stas".to_string())
}
fn create_database(_user: String, _database: String) -> Option<String> {
Some("postgresql://localhost/stas".to_string())
pub fn create_database(&self, _user: &String, _database: &String) -> Result<String> {
Ok("user=stas dbname=stas".to_string())
}
}

View File

@@ -11,6 +11,7 @@ use std::{
};
mod cplane_api;
mod mgmt;
mod proxy;
pub struct ProxyConf {
@@ -37,11 +38,28 @@ fn main() -> anyhow::Result<()> {
println!("Starting proxy on {}", conf.proxy_address);
let pageserver_listener = TcpListener::bind(conf.proxy_address)?;
println!("Starting mgmt on {}", conf.mgmt_address);
let mgmt_listener = TcpListener::bind(conf.mgmt_address)?;
let mut threads = Vec::new();
// Spawn a thread to listen for connections. It will spawn further threads
// for each connection.
let proxy_listener_thread = thread::Builder::new()
.name("Proxy thread".into())
.spawn(move || proxy::thread_main(&conf, pageserver_listener))?;
threads.push(
thread::Builder::new()
.name("Proxy thread".into())
.spawn(move || proxy::thread_main(&conf, pageserver_listener))?,
);
proxy_listener_thread.join().unwrap()
threads.push(
thread::Builder::new()
.name("Mgmt thread".into())
.spawn(move || mgmt::thread_main(&conf, mgmt_listener))?,
);
for t in threads {
let _ = t.join().unwrap();
}
Ok(())
}

View File

@@ -1,22 +1,24 @@
use crate::{cplane_api::CPlaneApi, ProxyConf};
use bytes::Bytes;
use std::{
net::{TcpListener, TcpStream},
thread,
};
use zenith_utils::postgres_backend::{AuthType, PostgresBackend};
use zenith_utils::{
postgres_backend,
pq_proto::{BeMessage, SINGLE_COL_ROWDESC},
};
use anyhow::bail;
use tokio_postgres::NoTls;
use rand::Rng;
use std::thread;
use tokio::io::AsyncWriteExt;
use zenith_utils::postgres_backend::{PostgresBackend, ProtoState};
use zenith_utils::pq_proto::*;
use zenith_utils::{postgres_backend, pq_proto::BeMessage};
///
/// Main proxy listener loop.
///
/// Listens for connections, and launches a new handler thread for each.
///
pub fn thread_main(conf: &'static ProxyConf, listener: TcpListener) -> anyhow::Result<()> {
pub fn thread_main(
conf: &'static ProxyConf,
listener: std::net::TcpListener,
) -> anyhow::Result<()> {
loop {
let (socket, peer_addr) = listener.accept()?;
println!("accepted connection from {}", peer_addr);
@@ -30,92 +32,211 @@ pub fn thread_main(conf: &'static ProxyConf, listener: TcpListener) -> anyhow::R
}
}
pub fn proxy_conn_main(conf: &'static ProxyConf, socket: TcpStream) -> anyhow::Result<()> {
let mut conn_handler = ProxyHandler {
conf,
existing_user: false,
cplane: CPlaneApi::new(&conf.cplane_address),
user: "".into(),
database: "".into(),
};
let mut pgbackend = PostgresBackend::new(socket, postgres_backend::AuthType::Trust)?;
pgbackend.run(&mut conn_handler)
}
struct ProxyHandler {
struct ProxyConnection {
conf: &'static ProxyConf,
existing_user: bool,
cplane: CPlaneApi,
user: String,
database: String,
pgb: PostgresBackend,
md5_salt: [u8; 4],
}
// impl ProxyHandler {
// }
pub fn proxy_conn_main(
conf: &'static ProxyConf,
socket: std::net::TcpStream,
) -> anyhow::Result<()> {
let mut conn = ProxyConnection {
conf,
existing_user: false,
cplane: CPlaneApi::new(&conf.cplane_address),
user: "".into(),
database: "".into(),
pgb: PostgresBackend::new(socket, postgres_backend::AuthType::MD5)?,
md5_salt: [0u8; 4],
};
impl postgres_backend::Handler for ProxyHandler {
fn process_query(
&mut self,
pgb: &mut PostgresBackend,
query_string: Bytes,
) -> anyhow::Result<()> {
println!("Got query: {:?}", query_string);
// Check StartupMessage
// This will set conn.existing_user and we can decide on next actions
conn.handle_startup()?;
if !self.existing_user {
pgb.write_message_noflush(&SINGLE_COL_ROWDESC)?
.write_message_noflush(&BeMessage::DataRow(&[Some(b"new user scenario")]))?
.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?;
} else {
pgb.write_message_noflush(&SINGLE_COL_ROWDESC)?
.write_message_noflush(&BeMessage::DataRow(&[Some(b"existing user")]))?
.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?;
// both scenarious here should end up producing database connection string
let database_uri = if conn.existing_user {
conn.handle_existing_user()?
} else {
conn.handle_new_user()?
};
// ok, proxy pass user connection to database_uri
let runtime = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.unwrap();
let _ = runtime.block_on(proxy_pass(
conn.pgb,
"127.0.0.1:5432".to_string(),
database_uri,
))?;
println!("proxy_conn_main done;");
Ok(())
}
impl ProxyConnection {
fn handle_startup(&mut self) -> anyhow::Result<()> {
loop {
let msg = self.pgb.read_message()?;
println!("got message {:?}", msg);
match msg {
Some(FeMessage::StartupMessage(m)) => {
println!("got startup message {:?}", m);
match m.kind {
StartupRequestCode::NegotiateGss | StartupRequestCode::NegotiateSsl => {
println!("SSL requested");
self.pgb.write_message(&BeMessage::Negotiate)?;
}
StartupRequestCode::Normal => {
self.user = m
.params
.get("user")
.ok_or_else(|| {
anyhow::Error::msg("user is required in startup packet")
})?
.into();
self.database = m
.params
.get("database")
.ok_or_else(|| {
anyhow::Error::msg("database is required in startup packet")
})?
.into();
self.existing_user = self.user.ends_with("@zenith");
break;
}
StartupRequestCode::Cancel => break,
}
}
None => {
bail!("connection closed")
}
unexpected => {
bail!("unexpected message type : {:?}", unexpected)
}
}
}
pgb.flush()?;
Ok(())
}
fn startup(
&mut self,
pgb: &mut PostgresBackend,
sm: &zenith_utils::pq_proto::FeStartupMessage,
) -> anyhow::Result<()> {
println!("Got startup: {:?}", sm);
fn handle_existing_user(&mut self) -> anyhow::Result<String> {
// ask password
rand::thread_rng().fill(&mut self.md5_salt);
self.pgb
.write_message(&BeMessage::AuthenticationMD5Password(&self.md5_salt))?;
self.pgb.state = ProtoState::Authentication;
self.user = sm
.params
.get("user")
.ok_or_else(|| anyhow::Error::msg("user is required in startup packet"))?
.into();
self.database = sm
.params
.get("database")
.ok_or_else(|| anyhow::Error::msg("database is required in startup packet"))?
.into();
// check password
println!("handle_existing_user");
let msg = self.pgb.read_message()?;
println!("got message {:?}", msg);
if let Some(FeMessage::PasswordMessage(m)) = msg {
println!("got password message '{:?}'", m);
// We use '@zenith' in username as an indicator that user already created
// this database and not logging in with his system username.
//
// With that approach we can create new databases on demand with something like
// psql -h zenith.tech -U stas@zenith my_new_db (assuming .pgpass is set). That is
// especially helpful if one is setting configuration files for some app that requires
// database -- he can just fill config and run initial migration without any other actions.
if self.user.ends_with("@zenith") {
pgb.auth_type = AuthType::MD5;
self.existing_user = true;
assert!(self.existing_user);
let (_trailing_null, md5_response) = m
.split_last()
.ok_or_else(|| anyhow::Error::msg("unexpected password message"))?;
if let Err(e) = self.check_auth_md5(md5_response) {
self.pgb
.write_message(&BeMessage::ErrorResponse(format!("{}", e)))?;
bail!("auth failed: {}", e);
} else {
self.pgb
.write_message_noflush(&BeMessage::AuthenticationOk)?;
self.pgb
.write_message_noflush(&BeMessage::ParameterStatus)?;
self.pgb.write_message(&BeMessage::ReadyForQuery)?;
}
}
Ok(())
// ok, we are authorized
self.cplane.get_database_uri(&self.user, &self.database)
}
fn check_auth_md5(
&mut self,
pgb: &mut PostgresBackend,
md5_response: &[u8],
) -> anyhow::Result<()> {
fn handle_new_user(&mut self) -> anyhow::Result<String> {
let mut reg_id_buf = [0u8; 8];
rand::thread_rng().fill(&mut reg_id_buf);
let reg_id = hex::encode(reg_id_buf);
let hello_message = format!("☀️ Welcome to Zenith!
To proceed with database creation open following link:
https://console.zenith.tech/claim_db/{}
It needed to be done once and we will send you '.pgpass' file which will allow you to access or create
databases without opening the browser.
", reg_id);
self.pgb
.write_message_noflush(&BeMessage::AuthenticationOk)?;
self.pgb
.write_message_noflush(&BeMessage::ParameterStatus)?;
self.pgb
.write_message(&BeMessage::NoticeResponse(hello_message.to_string()))?;
// await for database creation
let connstring = self.cplane.get_database_uri(&self.user, &self.database)?;
self.pgb.write_message(&BeMessage::ReadyForQuery)?;
Ok(connstring)
}
fn check_auth_md5(&self, md5_response: &[u8]) -> anyhow::Result<()> {
assert!(self.existing_user);
self.cplane
.check_auth(self.user.as_str(), md5_response, &pgb.md5_salt)
.check_auth(self.user.as_str(), md5_response, &self.md5_salt)
}
}
async fn proxy_pass(
pgb: PostgresBackend,
proxy_addr: String,
connstr: String,
) -> anyhow::Result<()> {
let mut socket = tokio::net::TcpStream::connect(proxy_addr).await?;
let config = connstr.parse::<tokio_postgres::Config>()?;
let _ = config.connect_raw(&mut socket, NoTls).await?;
println!("Connected to pg, proxying");
let incoming_std = pgb.into_stream();
incoming_std.set_nonblocking(true)?;
let mut incoming_conn = tokio::net::TcpStream::from_std(incoming_std)?;
let (mut ri, mut wi) = incoming_conn.split();
let (mut ro, mut wo) = socket.split();
let client_to_server = async {
tokio::io::copy(&mut ri, &mut wo).await?;
wo.shutdown().await
};
let server_to_client = async {
tokio::io::copy(&mut ro, &mut wi).await?;
wi.shutdown().await
};
tokio::try_join!(client_to_server, server_to_client)?;
Ok(())
}