diff --git a/Cargo.lock b/Cargo.lock index f0c3242e46..4736d7ce24 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1355,6 +1355,8 @@ dependencies = [ "hex", "md5", "rand", + "serde", + "serde_json", "tokio", "tokio-postgres 0.7.2", "zenith_utils", diff --git a/proxy/Cargo.toml b/proxy/Cargo.toml index b6416e2586..39ad50ad9b 100644 --- a/proxy/Cargo.toml +++ b/proxy/Cargo.toml @@ -12,7 +12,8 @@ bytes = { version = "1.0.1", features = ['serde'] } md5 = "0.7.0" rand = "0.8.3" hex = "0.4.3" - +serde = "1" +serde_json = "1" tokio = "1.7.1" tokio-postgres = "0.7.2" diff --git a/proxy/src/cplane_api.rs b/proxy/src/cplane_api.rs index 03e5c2dd53..e8be84d8ab 100644 --- a/proxy/src/cplane_api.rs +++ b/proxy/src/cplane_api.rs @@ -1,10 +1,17 @@ use anyhow::{bail, Result}; +use serde::{Deserialize, Serialize}; use std::{collections::HashMap, net::SocketAddr}; pub struct CPlaneApi { address: SocketAddr, } +#[derive(Serialize, Deserialize)] +pub struct DatabaseInfo { + pub addr: SocketAddr, + pub connstr: String, +} + // mock cplane api impl CPlaneApi { pub fn new(address: &SocketAddr) -> CPlaneApi { @@ -44,11 +51,17 @@ impl CPlaneApi { } } - pub fn get_database_uri(&self, _user: &String, _database: &String) -> Result { - Ok("user=stas dbname=stas".to_string()) + pub fn get_database_uri(&self, _user: &String, _database: &String) -> Result { + Ok(DatabaseInfo { + addr: "127.0.0.1:5432".parse()?, + connstr: "user=stas dbname=stas".into(), + }) } - pub fn create_database(&self, _user: &String, _database: &String) -> Result { - Ok("user=stas dbname=stas".to_string()) + pub fn create_database(&self, _user: &String, _database: &String) -> Result { + Ok(DatabaseInfo { + addr: "127.0.0.1:5432".parse()?, + connstr: "user=stas dbname=stas".into(), + }) } } diff --git a/proxy/src/main.rs b/proxy/src/main.rs index dc013d9f92..32bb759e80 100644 --- a/proxy/src/main.rs +++ b/proxy/src/main.rs @@ -6,10 +6,14 @@ /// in somewhat transparent manner (again via communication with control plane API). /// use std::{ + collections::HashMap, net::{SocketAddr, TcpListener}, + sync::{mpsc, Mutex}, thread, }; +use cplane_api::DatabaseInfo; + mod cplane_api; mod mgmt; mod proxy; @@ -26,20 +30,29 @@ pub struct ProxyConf { pub cplane_address: SocketAddr, } +pub struct ProxyState { + pub conf: ProxyConf, + pub waiters: Mutex>>>, +} + fn main() -> anyhow::Result<()> { let conf = ProxyConf { proxy_address: "0.0.0.0:4000".parse()?, mgmt_address: "0.0.0.0:8080".parse()?, cplane_address: "127.0.0.1:3000".parse()?, }; - let conf: &'static ProxyConf = Box::leak(Box::new(conf)); + let state = ProxyState { + conf, + waiters: Mutex::new(HashMap::new()), + }; + let state: &'static ProxyState = Box::leak(Box::new(state)); // Check that we can bind to address before further initialization - println!("Starting proxy on {}", conf.proxy_address); - let pageserver_listener = TcpListener::bind(conf.proxy_address)?; + println!("Starting proxy on {}", state.conf.proxy_address); + let pageserver_listener = TcpListener::bind(state.conf.proxy_address)?; - println!("Starting mgmt on {}", conf.mgmt_address); - let mgmt_listener = TcpListener::bind(conf.mgmt_address)?; + println!("Starting mgmt on {}", state.conf.mgmt_address); + let mgmt_listener = TcpListener::bind(state.conf.mgmt_address)?; let mut threads = Vec::new(); @@ -48,13 +61,13 @@ fn main() -> anyhow::Result<()> { threads.push( thread::Builder::new() .name("Proxy thread".into()) - .spawn(move || proxy::thread_main(&conf, pageserver_listener))?, + .spawn(move || proxy::thread_main(&state, pageserver_listener))?, ); threads.push( thread::Builder::new() .name("Mgmt thread".into()) - .spawn(move || mgmt::thread_main(&conf, mgmt_listener))?, + .spawn(move || mgmt::thread_main(&state, mgmt_listener))?, ); for t in threads { diff --git a/proxy/src/mgmt.rs b/proxy/src/mgmt.rs index a0834c832d..87fd403d91 100644 --- a/proxy/src/mgmt.rs +++ b/proxy/src/mgmt.rs @@ -3,41 +3,70 @@ use std::{ thread, }; +use anyhow::bail; use bytes::Bytes; +use serde::{Deserialize, Serialize}; use zenith_utils::{ postgres_backend::{self, PostgresBackend}, pq_proto::{BeMessage, SINGLE_COL_ROWDESC}, }; -use crate::ProxyConf; +use crate::{cplane_api::DatabaseInfo, ProxyState}; /// /// Main proxy listener loop. /// /// Listens for connections, and launches a new handler thread for each. /// -pub fn thread_main(conf: &'static ProxyConf, listener: TcpListener) -> anyhow::Result<()> { +pub fn thread_main(state: &'static ProxyState, listener: TcpListener) -> anyhow::Result<()> { loop { let (socket, peer_addr) = listener.accept()?; println!("accepted connection from {}", peer_addr); socket.set_nodelay(true).unwrap(); thread::spawn(move || { - if let Err(err) = mgmt_conn_main(conf, socket) { + if let Err(err) = mgmt_conn_main(state, socket) { println!("error: {}", err); } }); } } -pub fn mgmt_conn_main(conf: &'static ProxyConf, socket: TcpStream) -> anyhow::Result<()> { - let mut conn_handler = MgmtHandler { conf }; +pub fn mgmt_conn_main(state: &'static ProxyState, socket: TcpStream) -> anyhow::Result<()> { + let mut conn_handler = MgmtHandler { state }; let mut pgbackend = PostgresBackend::new(socket, postgres_backend::AuthType::Trust)?; pgbackend.run(&mut conn_handler) } struct MgmtHandler { - conf: &'static ProxyConf, + state: &'static ProxyState, +} +/// Serialized examples: +// { +// "session_id": "71d6d03e6d93d99a", +// "result": { +// "Success": { +// "addr": "127.0.0.1:5432", +// "connstr": "user=stas dbname=stas" +// } +// } +// } +// { +// "session_id": "71d6d03e6d93d99a", +// "result": { +// "Failure": "oops" +// } +// } +#[derive(Serialize, Deserialize)] +pub struct PsqlSessionResponse { + session_id: String, + result: PsqlSessionResult, +} + +#[derive(Serialize, Deserialize)] +pub enum PsqlSessionResult { + Success(DatabaseInfo), + Failure(String), } impl postgres_backend::Handler for MgmtHandler { @@ -46,13 +75,40 @@ impl postgres_backend::Handler for MgmtHandler { pgb: &mut PostgresBackend, query_string: Bytes, ) -> anyhow::Result<()> { + let (_, query_string) = query_string + .split_last() + .ok_or_else(|| anyhow::Error::msg("protocol violation"))?; + + let (_, query_string) = query_string + .split_last() + .ok_or_else(|| anyhow::Error::msg("protocol violation"))?; + println!("Got mgmt query: {:?}", query_string); - pgb.write_message_noflush(&SINGLE_COL_ROWDESC)? - .write_message_noflush(&BeMessage::DataRow(&[Some(b"mgmt_ok")]))? - .write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?; + let resp: PsqlSessionResponse = serde_json::from_slice(&query_string)?; - pgb.flush()?; - Ok(()) + let waiters = self.state.waiters.lock().unwrap(); + + let sender = waiters + .get(&resp.session_id) + .ok_or_else(|| anyhow::Error::msg("psql session_id is not found"))?; + + match resp.result { + PsqlSessionResult::Success(db_info) => { + sender.send(Ok(db_info))?; + + pgb.write_message_noflush(&SINGLE_COL_ROWDESC)? + .write_message_noflush(&BeMessage::DataRow(&[Some(b"ok")]))? + .write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?; + pgb.flush()?; + Ok(()) + } + + PsqlSessionResult::Failure(message) => { + sender.send(Err(anyhow::Error::msg(message.clone())))?; + + bail!("psql session request failed: {}", message) + } + } } } diff --git a/proxy/src/proxy.rs b/proxy/src/proxy.rs index 74f32fe4c0..750c02406f 100644 --- a/proxy/src/proxy.rs +++ b/proxy/src/proxy.rs @@ -1,9 +1,12 @@ -use crate::{cplane_api::CPlaneApi, ProxyConf}; +use crate::cplane_api::CPlaneApi; +use crate::cplane_api::DatabaseInfo; +use crate::ProxyState; use anyhow::bail; use tokio_postgres::NoTls; use rand::Rng; +use std::sync::mpsc::channel; use std::thread; use tokio::io::AsyncWriteExt; use zenith_utils::postgres_backend::{PostgresBackend, ProtoState}; @@ -16,7 +19,7 @@ use zenith_utils::{postgres_backend, pq_proto::BeMessage}; /// Listens for connections, and launches a new handler thread for each. /// pub fn thread_main( - conf: &'static ProxyConf, + state: &'static ProxyState, listener: std::net::TcpListener, ) -> anyhow::Result<()> { loop { @@ -25,15 +28,17 @@ pub fn thread_main( socket.set_nodelay(true).unwrap(); thread::spawn(move || { - if let Err(err) = proxy_conn_main(conf, socket) { + if let Err(err) = proxy_conn_main(state, socket) { println!("error: {}", err); } }); } } +// XXX: clean up fields struct ProxyConnection { - conf: &'static ProxyConf, + state: &'static ProxyState, + existing_user: bool, cplane: CPlaneApi, @@ -42,20 +47,23 @@ struct ProxyConnection { pgb: PostgresBackend, md5_salt: [u8; 4], + + psql_session_id: String, } pub fn proxy_conn_main( - conf: &'static ProxyConf, + state: &'static ProxyState, socket: std::net::TcpStream, ) -> anyhow::Result<()> { let mut conn = ProxyConnection { - conf, + state, existing_user: false, - cplane: CPlaneApi::new(&conf.cplane_address), + cplane: CPlaneApi::new(&state.conf.cplane_address), user: "".into(), database: "".into(), pgb: PostgresBackend::new(socket, postgres_backend::AuthType::MD5)?, md5_salt: [0u8; 4], + psql_session_id: "".into(), }; // Check StartupMessage @@ -63,7 +71,7 @@ pub fn proxy_conn_main( conn.handle_startup()?; // both scenarious here should end up producing database connection string - let database_uri = if conn.existing_user { + let db_info = if conn.existing_user { conn.handle_existing_user()? } else { conn.handle_new_user()? @@ -75,11 +83,7 @@ pub fn proxy_conn_main( .build() .unwrap(); - let _ = runtime.block_on(proxy_pass( - conn.pgb, - "127.0.0.1:5432".to_string(), - database_uri, - ))?; + let _ = runtime.block_on(proxy_pass(conn.pgb, db_info))?; println!("proxy_conn_main done;"); @@ -134,12 +138,12 @@ impl ProxyConnection { Ok(()) } - fn handle_existing_user(&mut self) -> anyhow::Result { + fn handle_existing_user(&mut self) -> anyhow::Result { // ask password rand::thread_rng().fill(&mut self.md5_salt); self.pgb .write_message(&BeMessage::AuthenticationMD5Password(&self.md5_salt))?; - self.pgb.state = ProtoState::Authentication; + self.pgb.state = ProtoState::Authentication; // XXX // check password println!("handle_existing_user"); @@ -171,21 +175,21 @@ impl ProxyConnection { self.cplane.get_database_uri(&self.user, &self.database) } - fn handle_new_user(&mut self) -> anyhow::Result { - let mut reg_id_buf = [0u8; 8]; - rand::thread_rng().fill(&mut reg_id_buf); - let reg_id = hex::encode(reg_id_buf); + fn handle_new_user(&mut self) -> anyhow::Result { + let mut psql_session_id_buf = [0u8; 8]; + rand::thread_rng().fill(&mut psql_session_id_buf); + self.psql_session_id = hex::encode(psql_session_id_buf); let hello_message = format!("☀️ Welcome to Zenith! To proceed with database creation open following link: - https://console.zenith.tech/claim_db/{} + https://console.zenith.tech/psql_session/{} It needed to be done once and we will send you '.pgpass' file which will allow you to access or create databases without opening the browser. -", reg_id); +", self.psql_session_id); self.pgb .write_message_noflush(&BeMessage::AuthenticationOk)?; @@ -195,10 +199,24 @@ databases without opening the browser. .write_message(&BeMessage::NoticeResponse(hello_message.to_string()))?; // await for database creation - let connstring = self.cplane.get_database_uri(&self.user, &self.database)?; + let (tx, rx) = channel::>(); + let _ = self + .state + .waiters + .lock() + .unwrap() + .insert(self.psql_session_id.clone(), tx); + + // Wait for web console response + // XXX: respond with error to client + let dbinfo = rx.recv()??; + + self.pgb.write_message_noflush(&BeMessage::NoticeResponse( + "Connecting to database.".to_string(), + ))?; self.pgb.write_message(&BeMessage::ReadyForQuery)?; - Ok(connstring) + Ok(dbinfo) } fn check_auth_md5(&self, md5_response: &[u8]) -> anyhow::Result<()> { @@ -208,13 +226,9 @@ databases without opening the browser. } } -async fn proxy_pass( - pgb: PostgresBackend, - proxy_addr: String, - connstr: String, -) -> anyhow::Result<()> { - let mut socket = tokio::net::TcpStream::connect(proxy_addr).await?; - let config = connstr.parse::()?; +async fn proxy_pass(pgb: PostgresBackend, db_info: DatabaseInfo) -> anyhow::Result<()> { + let mut socket = tokio::net::TcpStream::connect(db_info.addr).await?; + let config = db_info.connstr.parse::()?; let _ = config.connect_raw(&mut socket, NoTls).await?; println!("Connected to pg, proxying"); diff --git a/zenith_utils/src/postgres_backend.rs b/zenith_utils/src/postgres_backend.rs index 1f40ecedd2..4d8251b8c1 100644 --- a/zenith_utils/src/postgres_backend.rs +++ b/zenith_utils/src/postgres_backend.rs @@ -207,13 +207,9 @@ impl PostgresBackend { assert!(self.state == ProtoState::Authentication); - let (trailing_null, md5_response) = m.split_last().unwrap(); - - if *trailing_null != 0 { - let errmsg = "protocol violation"; - self.write_message(&BeMessage::ErrorResponse(format!("{}", errmsg)))?; - bail!("auth failed: {}", errmsg); - } + let (_, md5_response) = m + .split_last() + .ok_or_else(|| anyhow::Error::msg("protocol violation"))?; if let Err(e) = handler.check_auth_md5(self, md5_response) { self.write_message(&BeMessage::ErrorResponse(format!("{}", e)))?;