diff --git a/Cargo.lock b/Cargo.lock index c2248c227b..f0c3242e46 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1129,8 +1129,8 @@ dependencies = [ "lazy_static", "log", "postgres", - "postgres-protocol", - "postgres-types", + "postgres-protocol 0.6.1 (git+https://github.com/zenithdb/rust-postgres.git?rev=9eb0dbfbeb6a6c1b79099b9f7ae4a8c021877858)", + "postgres-types 0.2.1 (git+https://github.com/zenithdb/rust-postgres.git?rev=9eb0dbfbeb6a6c1b79099b9f7ae4a8c021877858)", "postgres_ffi", "rand", "regex", @@ -1237,9 +1237,27 @@ dependencies = [ "fallible-iterator", "futures", "log", - "postgres-protocol", + "postgres-protocol 0.6.1 (git+https://github.com/zenithdb/rust-postgres.git?rev=9eb0dbfbeb6a6c1b79099b9f7ae4a8c021877858)", "tokio", - "tokio-postgres", + "tokio-postgres 0.7.1", +] + +[[package]] +name = "postgres-protocol" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ff3e0f70d32e20923cabf2df02913be7c1842d4c772db8065c00fcfdd1d1bff3" +dependencies = [ + "base64", + "byteorder", + "bytes", + "fallible-iterator", + "hmac", + "md-5", + "memchr", + "rand", + "sha2", + "stringprep", ] [[package]] @@ -1260,6 +1278,17 @@ dependencies = [ "stringprep", ] +[[package]] +name = "postgres-types" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "430f4131e1b7657b0cd9a2b0c3408d77c9a43a042d300b8c77f981dffcc43a2f" +dependencies = [ + "bytes", + "fallible-iterator", + "postgres-protocol 0.6.1 (registry+https://github.com/rust-lang/crates.io-index)", +] + [[package]] name = "postgres-types" version = "0.2.1" @@ -1267,7 +1296,7 @@ source = "git+https://github.com/zenithdb/rust-postgres.git?rev=9eb0dbfbeb6a6c1b dependencies = [ "bytes", "fallible-iterator", - "postgres-protocol", + "postgres-protocol 0.6.1 (git+https://github.com/zenithdb/rust-postgres.git?rev=9eb0dbfbeb6a6c1b79099b9f7ae4a8c021877858)", ] [[package]] @@ -1323,7 +1352,11 @@ version = "0.1.0" dependencies = [ "anyhow", 
"bytes", + "hex", "md5", + "rand", + "tokio", + "tokio-postgres 0.7.2", "zenith_utils", ] @@ -1996,8 +2029,31 @@ dependencies = [ "percent-encoding", "phf", "pin-project-lite", - "postgres-protocol", - "postgres-types", + "postgres-protocol 0.6.1 (git+https://github.com/zenithdb/rust-postgres.git?rev=9eb0dbfbeb6a6c1b79099b9f7ae4a8c021877858)", + "postgres-types 0.2.1 (git+https://github.com/zenithdb/rust-postgres.git?rev=9eb0dbfbeb6a6c1b79099b9f7ae4a8c021877858)", + "socket2", + "tokio", + "tokio-util", +] + +[[package]] +name = "tokio-postgres" +version = "0.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2d2b1383c7e4fb9a09e292c7c6afb7da54418d53b045f1c1fac7a911411a2b8b" +dependencies = [ + "async-trait", + "byteorder", + "bytes", + "fallible-iterator", + "futures", + "log", + "parking_lot", + "percent-encoding", + "phf", + "pin-project-lite", + "postgres-protocol 0.6.1 (registry+https://github.com/rust-lang/crates.io-index)", + "postgres-types 0.2.1 (registry+https://github.com/rust-lang/crates.io-index)", "socket2", "tokio", "tokio-util", @@ -2182,7 +2238,7 @@ dependencies = [ "log", "pageserver", "postgres", - "postgres-protocol", + "postgres-protocol 0.6.1 (git+https://github.com/zenithdb/rust-postgres.git?rev=9eb0dbfbeb6a6c1b79099b9f7ae4a8c021877858)", "postgres_ffi", "regex", "rust-s3", diff --git a/proxy/Cargo.toml b/proxy/Cargo.toml index cf9a69a895..b6416e2586 100644 --- a/proxy/Cargo.toml +++ b/proxy/Cargo.toml @@ -10,5 +10,10 @@ edition = "2018" anyhow = "1.0" bytes = { version = "1.0.1", features = ['serde'] } md5 = "0.7.0" +rand = "0.8.3" +hex = "0.4.3" + +tokio = "1.7.1" +tokio-postgres = "0.7.2" zenith_utils = { path = "../zenith_utils" } diff --git a/proxy/src/cplane_api.rs b/proxy/src/cplane_api.rs index ae544e98dd..03e5c2dd53 100644 --- a/proxy/src/cplane_api.rs +++ b/proxy/src/cplane_api.rs @@ -44,11 +44,11 @@ impl CPlaneApi { } } - fn get_database_uri(_user: String, _database: String) -> Option { - 
Some("postgresql://localhost/stas".to_string()) + pub fn get_database_uri(&self, _user: &String, _database: &String) -> Result { + Ok("user=stas dbname=stas".to_string()) } - fn create_database(_user: String, _database: String) -> Option { - Some("postgresql://localhost/stas".to_string()) + pub fn create_database(&self, _user: &String, _database: &String) -> Result { + Ok("user=stas dbname=stas".to_string()) } } diff --git a/proxy/src/main.rs b/proxy/src/main.rs index a511a7c8bd..dc013d9f92 100644 --- a/proxy/src/main.rs +++ b/proxy/src/main.rs @@ -11,6 +11,7 @@ use std::{ }; mod cplane_api; +mod mgmt; mod proxy; pub struct ProxyConf { @@ -37,11 +38,28 @@ fn main() -> anyhow::Result<()> { println!("Starting proxy on {}", conf.proxy_address); let pageserver_listener = TcpListener::bind(conf.proxy_address)?; + println!("Starting mgmt on {}", conf.mgmt_address); + let mgmt_listener = TcpListener::bind(conf.mgmt_address)?; + + let mut threads = Vec::new(); + // Spawn a thread to listen for connections. It will spawn further threads // for each connection. 
- let proxy_listener_thread = thread::Builder::new() - .name("Proxy thread".into()) - .spawn(move || proxy::thread_main(&conf, pageserver_listener))?; + threads.push( + thread::Builder::new() + .name("Proxy thread".into()) + .spawn(move || proxy::thread_main(&conf, pageserver_listener))?, + ); - proxy_listener_thread.join().unwrap() + threads.push( + thread::Builder::new() + .name("Mgmt thread".into()) + .spawn(move || mgmt::thread_main(&conf, mgmt_listener))?, + ); + + for t in threads { + let _ = t.join().unwrap(); + } + + Ok(()) } diff --git a/proxy/src/proxy.rs b/proxy/src/proxy.rs index 9a5c1e6b62..74f32fe4c0 100644 --- a/proxy/src/proxy.rs +++ b/proxy/src/proxy.rs @@ -1,22 +1,24 @@ use crate::{cplane_api::CPlaneApi, ProxyConf}; -use bytes::Bytes; -use std::{ - net::{TcpListener, TcpStream}, - thread, -}; -use zenith_utils::postgres_backend::{AuthType, PostgresBackend}; -use zenith_utils::{ - postgres_backend, - pq_proto::{BeMessage, SINGLE_COL_ROWDESC}, -}; +use anyhow::bail; +use tokio_postgres::NoTls; + +use rand::Rng; +use std::thread; +use tokio::io::AsyncWriteExt; +use zenith_utils::postgres_backend::{PostgresBackend, ProtoState}; +use zenith_utils::pq_proto::*; +use zenith_utils::{postgres_backend, pq_proto::BeMessage}; /// /// Main proxy listener loop. /// /// Listens for connections, and launches a new handler thread for each. 
/// -pub fn thread_main(conf: &'static ProxyConf, listener: TcpListener) -> anyhow::Result<()> { +pub fn thread_main( + conf: &'static ProxyConf, + listener: std::net::TcpListener, +) -> anyhow::Result<()> { loop { let (socket, peer_addr) = listener.accept()?; println!("accepted connection from {}", peer_addr); @@ -30,92 +32,211 @@ pub fn thread_main(conf: &'static ProxyConf, listener: TcpListener) -> anyhow::R } } -pub fn proxy_conn_main(conf: &'static ProxyConf, socket: TcpStream) -> anyhow::Result<()> { - let mut conn_handler = ProxyHandler { - conf, - existing_user: false, - cplane: CPlaneApi::new(&conf.cplane_address), - user: "".into(), - database: "".into(), - }; - let mut pgbackend = PostgresBackend::new(socket, postgres_backend::AuthType::Trust)?; - pgbackend.run(&mut conn_handler) -} - -struct ProxyHandler { +struct ProxyConnection { conf: &'static ProxyConf, existing_user: bool, cplane: CPlaneApi, user: String, database: String, + + pgb: PostgresBackend, + md5_salt: [u8; 4], } -// impl ProxyHandler { -// } +pub fn proxy_conn_main( + conf: &'static ProxyConf, + socket: std::net::TcpStream, +) -> anyhow::Result<()> { + let mut conn = ProxyConnection { + conf, + existing_user: false, + cplane: CPlaneApi::new(&conf.cplane_address), + user: "".into(), + database: "".into(), + pgb: PostgresBackend::new(socket, postgres_backend::AuthType::MD5)?, + md5_salt: [0u8; 4], + }; -impl postgres_backend::Handler for ProxyHandler { - fn process_query( - &mut self, - pgb: &mut PostgresBackend, - query_string: Bytes, - ) -> anyhow::Result<()> { - println!("Got query: {:?}", query_string); + // Check StartupMessage + // This will set conn.existing_user and we can decide on next actions + conn.handle_startup()?; - if !self.existing_user { - pgb.write_message_noflush(&SINGLE_COL_ROWDESC)? - .write_message_noflush(&BeMessage::DataRow(&[Some(b"new user scenario")]))? 
- .write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?; - } else { - pgb.write_message_noflush(&SINGLE_COL_ROWDESC)? - .write_message_noflush(&BeMessage::DataRow(&[Some(b"existing user")]))? - .write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?; + // both scenarios here should end up producing database connection string + let database_uri = if conn.existing_user { + conn.handle_existing_user()? + } else { + conn.handle_new_user()? + }; + + // ok, proxy pass user connection to database_uri + let runtime = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .unwrap(); + + let _ = runtime.block_on(proxy_pass( + conn.pgb, + "127.0.0.1:5432".to_string(), + database_uri, + ))?; + + println!("proxy_conn_main done;"); + + Ok(()) +} + +impl ProxyConnection { + fn handle_startup(&mut self) -> anyhow::Result<()> { + loop { + let msg = self.pgb.read_message()?; + println!("got message {:?}", msg); + match msg { + Some(FeMessage::StartupMessage(m)) => { + println!("got startup message {:?}", m); + + match m.kind { + StartupRequestCode::NegotiateGss | StartupRequestCode::NegotiateSsl => { + println!("SSL requested"); + self.pgb.write_message(&BeMessage::Negotiate)?; + } + StartupRequestCode::Normal => { + self.user = m + .params + .get("user") + .ok_or_else(|| { + anyhow::Error::msg("user is required in startup packet") + })? + .into(); + self.database = m + .params + .get("database") + .ok_or_else(|| { + anyhow::Error::msg("database is required in startup packet") + })? 
+ .into(); + + self.existing_user = self.user.ends_with("@zenith"); + + break; + } + StartupRequestCode::Cancel => break, + } + } + None => { + bail!("connection closed") + } + unexpected => { + bail!("unexpected message type : {:?}", unexpected) + } + } } - - pgb.flush()?; Ok(()) } - fn startup( - &mut self, - pgb: &mut PostgresBackend, - sm: &zenith_utils::pq_proto::FeStartupMessage, - ) -> anyhow::Result<()> { - println!("Got startup: {:?}", sm); + fn handle_existing_user(&mut self) -> anyhow::Result { + // ask password + rand::thread_rng().fill(&mut self.md5_salt); + self.pgb + .write_message(&BeMessage::AuthenticationMD5Password(&self.md5_salt))?; + self.pgb.state = ProtoState::Authentication; - self.user = sm - .params - .get("user") - .ok_or_else(|| anyhow::Error::msg("user is required in startup packet"))? - .into(); - self.database = sm - .params - .get("database") - .ok_or_else(|| anyhow::Error::msg("database is required in startup packet"))? - .into(); + // check password + println!("handle_existing_user"); + let msg = self.pgb.read_message()?; + println!("got message {:?}", msg); + if let Some(FeMessage::PasswordMessage(m)) = msg { + println!("got password message '{:?}'", m); - // We use '@zenith' in username as an indicator that user already created - // this database and not logging in with his system username. - // - // With that approach we can create new databases on demand with something like - // psql -h zenith.tech -U stas@zenith my_new_db (assuming .pgpass is set). That is - // especially helpful if one is setting configuration files for some app that requires - // database -- he can just fill config and run initial migration without any other actions. 
- if self.user.ends_with("@zenith") { - pgb.auth_type = AuthType::MD5; - self.existing_user = true; + assert!(self.existing_user); + + let (_trailing_null, md5_response) = m + .split_last() + .ok_or_else(|| anyhow::Error::msg("unexpected password message"))?; + + if let Err(e) = self.check_auth_md5(md5_response) { + self.pgb + .write_message(&BeMessage::ErrorResponse(format!("{}", e)))?; + bail!("auth failed: {}", e); + } else { + self.pgb + .write_message_noflush(&BeMessage::AuthenticationOk)?; + self.pgb + .write_message_noflush(&BeMessage::ParameterStatus)?; + self.pgb.write_message(&BeMessage::ReadyForQuery)?; + } } - Ok(()) + // ok, we are authorized + self.cplane.get_database_uri(&self.user, &self.database) } - fn check_auth_md5( - &mut self, - pgb: &mut PostgresBackend, - md5_response: &[u8], - ) -> anyhow::Result<()> { + fn handle_new_user(&mut self) -> anyhow::Result { + let mut reg_id_buf = [0u8; 8]; + rand::thread_rng().fill(&mut reg_id_buf); + let reg_id = hex::encode(reg_id_buf); + + let hello_message = format!("☀️ Welcome to Zenith! + +To proceed with database creation open following link: + + https://console.zenith.tech/claim_db/{} + +It needed to be done once and we will send you '.pgpass' file which will allow you to access or create +databases without opening the browser. 
+ +", reg_id); + + self.pgb + .write_message_noflush(&BeMessage::AuthenticationOk)?; + self.pgb + .write_message_noflush(&BeMessage::ParameterStatus)?; + self.pgb + .write_message(&BeMessage::NoticeResponse(hello_message.to_string()))?; + + // await for database creation + let connstring = self.cplane.get_database_uri(&self.user, &self.database)?; + self.pgb.write_message(&BeMessage::ReadyForQuery)?; + + Ok(connstring) + } + + fn check_auth_md5(&self, md5_response: &[u8]) -> anyhow::Result<()> { assert!(self.existing_user); self.cplane - .check_auth(self.user.as_str(), md5_response, &pgb.md5_salt) + .check_auth(self.user.as_str(), md5_response, &self.md5_salt) } } + +async fn proxy_pass( + pgb: PostgresBackend, + proxy_addr: String, + connstr: String, +) -> anyhow::Result<()> { + let mut socket = tokio::net::TcpStream::connect(proxy_addr).await?; + let config = connstr.parse::()?; + let _ = config.connect_raw(&mut socket, NoTls).await?; + + println!("Connected to pg, proxying"); + + let incoming_std = pgb.into_stream(); + incoming_std.set_nonblocking(true)?; + let mut incoming_conn = tokio::net::TcpStream::from_std(incoming_std)?; + + let (mut ri, mut wi) = incoming_conn.split(); + let (mut ro, mut wo) = socket.split(); + + let client_to_server = async { + tokio::io::copy(&mut ri, &mut wo).await?; + wo.shutdown().await + }; + + let server_to_client = async { + tokio::io::copy(&mut ro, &mut wi).await?; + wi.shutdown().await + }; + + tokio::try_join!(client_to_server, server_to_client)?; + + Ok(()) +} diff --git a/zenith_utils/src/postgres_backend.rs b/zenith_utils/src/postgres_backend.rs index 9467dcbb69..1f40ecedd2 100644 --- a/zenith_utils/src/postgres_backend.rs +++ b/zenith_utils/src/postgres_backend.rs @@ -11,7 +11,7 @@ use log::*; use rand::Rng; use std::io; use std::io::{BufReader, Write}; -use std::net::{Shutdown, TcpStream}; +use std::net::TcpStream; pub trait Handler { /// Handle single query. 
@@ -36,7 +36,7 @@ pub trait Handler { } #[derive(PartialEq)] -enum ProtoState { +pub enum ProtoState { Initialization, Authentication, Established, @@ -57,20 +57,23 @@ pub struct PostgresBackend { // Output buffer. c.f. BeMessage::write why we are using BytesMut here. buf_out: BytesMut, - state: ProtoState, + pub state: ProtoState, - pub md5_salt: [u8; 4], - pub auth_type: AuthType, + md5_salt: [u8; 4], + auth_type: AuthType, } -// In replication.rs a separate thread is reading keepalives from the -// socket. When main one finishes, tell it to get down by shutdowning the -// socket. -impl Drop for PostgresBackend { - fn drop(&mut self) { - let _res = self.stream_out.shutdown(Shutdown::Both); - } -} +// TODO: call shutdown() manually. +// into_stream() moves a field out of self, which is not allowed for types implementing Drop + +// // In replication.rs a separate thread is reading keepalives from the +// // socket. When main one finishes, tell it to get down by shutdowning the +// // socket. +// impl Drop for PostgresBackend { +// fn drop(&mut self) { +// let _res = self.stream_out.shutdown(Shutdown::Both); +// } +// } impl PostgresBackend { pub fn new(socket: TcpStream, auth_type: AuthType) -> Result { @@ -94,6 +97,10 @@ impl PostgresBackend { Ok(pb) } + pub fn into_stream(self) -> TcpStream { + self.stream_out + } + /// Get direct reference (into the Option) to the read stream. fn get_stream_in(&mut self) -> Result<&mut BufReader> { match self.stream_in { @@ -172,20 +179,23 @@ impl PostgresBackend { // to bypass auth for new users. 
handler.startup(self, &m)?; - if self.auth_type == AuthType::Trust { - self.write_message_noflush(&BeMessage::AuthenticationOk)?; - // psycopg2 will not connect if client_encoding is not - // specified by the server - self.write_message_noflush(&BeMessage::ParameterStatus)?; - self.write_message(&BeMessage::ReadyForQuery)?; - self.state = ProtoState::Established; - } else { - rand::thread_rng().fill(&mut self.md5_salt); - let md5_salt = self.md5_salt.clone(); - self.write_message(&BeMessage::AuthenticationMD5Password( - &md5_salt, - ))?; - self.state = ProtoState::Authentication; + match self.auth_type { + AuthType::Trust => { + self.write_message_noflush(&BeMessage::AuthenticationOk)?; + // psycopg2 will not connect if client_encoding is not + // specified by the server + self.write_message_noflush(&BeMessage::ParameterStatus)?; + self.write_message(&BeMessage::ReadyForQuery)?; + self.state = ProtoState::Established; + } + AuthType::MD5 => { + rand::thread_rng().fill(&mut self.md5_salt); + let md5_salt = self.md5_salt.clone(); + self.write_message(&BeMessage::AuthenticationMD5Password( + &md5_salt, + ))?; + self.state = ProtoState::Authentication; + } } } StartupRequestCode::Cancel => break, diff --git a/zenith_utils/src/pq_proto.rs b/zenith_utils/src/pq_proto.rs index ac6e3bb971..a513bcbc0b 100644 --- a/zenith_utils/src/pq_proto.rs +++ b/zenith_utils/src/pq_proto.rs @@ -331,6 +331,7 @@ pub enum BeMessage<'a> { ReadyForQuery, RowDescription(&'a [RowDescriptor<'a>]), XLogData(XLogDataBody<'a>), + NoticeResponse(String), } // One row desciption in RowDescription packet. @@ -572,6 +573,30 @@ impl<'a> BeMessage<'a> { .unwrap(); } + // NoticeResponse has the same format as ErrorResponse. From doc: "The frontend should display the + // message but continue listening for ReadyForQuery or ErrorResponse" + BeMessage::NoticeResponse(error_msg) => { + // Every notice is sent with Severity set to NOTICE and a + // fixed SQLSTATE code. 
+ + // 'N' signalizes NoticeResponse messages + buf.put_u8(b'N'); + write_body(buf, |buf| { + buf.put_u8(b'S'); // severity + write_cstr(&Bytes::from("NOTICE"), buf)?; + + buf.put_u8(b'C'); // SQLSTATE error code + write_cstr(&Bytes::from("CXX000"), buf)?; + + buf.put_u8(b'M'); // the message + write_cstr(error_msg.as_bytes(), buf)?; + + buf.put_u8(0); // terminator + Ok::<_, io::Error>(()) + }) + .unwrap(); + } + BeMessage::NoData => { buf.put_u8(b'n'); write_body(buf, |_| Ok::<(), io::Error>(())).unwrap();