From 9d6b78861d9a6e03f1238d5b565ba8d21afba5b2 Mon Sep 17 00:00:00 2001 From: Bojan Serafimov Date: Tue, 11 Jan 2022 01:31:37 -0500 Subject: [PATCH] WIP --- proxy/src/main.rs | 21 ++---- proxy/src/mgmt.rs | 2 +- proxy/src/proxy.rs | 158 +++++++++++++++++++++------------------------ 3 files changed, 82 insertions(+), 99 deletions(-) diff --git a/proxy/src/main.rs b/proxy/src/main.rs index 8b397c4444..be983f1d9b 100644 --- a/proxy/src/main.rs +++ b/proxy/src/main.rs @@ -17,7 +17,8 @@ mod proxy; mod state; mod waiters; -fn main() -> anyhow::Result<()> { +#[tokio::main] +async fn main() -> anyhow::Result<()> { let arg_matches = App::new("Zenith proxy/router") .version(GIT_VERSION) .arg( @@ -97,20 +98,10 @@ fn main() -> anyhow::Result<()> { println!("Starting mgmt on {}", state.conf.mgmt_address); let mgmt_listener = tcp_listener::bind(state.conf.mgmt_address)?; - let threads = [ - // Spawn a thread to listen for connections. It will spawn further threads - // for each connection. - thread::Builder::new() - .name("Listener thread".into()) - .spawn(move || proxy::thread_main(state, pageserver_listener))?, - thread::Builder::new() - .name("Mgmt thread".into()) - .spawn(move || mgmt::thread_main(state, mgmt_listener))?, - ]; - - for t in threads { - t.join().unwrap()?; - } + tokio::try_join!( + proxy::thread_main(state, pageserver_listener), + mgmt::thread_main(state, mgmt_listener), + )?; Ok(()) } diff --git a/proxy/src/mgmt.rs b/proxy/src/mgmt.rs index 1b9d9502f2..ca0fd2dbda 100644 --- a/proxy/src/mgmt.rs +++ b/proxy/src/mgmt.rs @@ -16,7 +16,7 @@ use crate::{cplane_api::DatabaseInfo, ProxyState}; /// /// Listens for connections, and launches a new handler thread for each. /// -pub fn thread_main(state: &'static ProxyState, listener: TcpListener) -> anyhow::Result<()> { +pub async fn thread_main(state: &'static ProxyState, listener: TcpListener) -> anyhow::Result<()> { loop { let (socket, peer_addr) = listener.accept()?; println!("accepted connection from {}", peer_addr); diff --git a/proxy/src/proxy.rs b/proxy/src/proxy.rs index 26936159d0..b911e10f2d 100644 --- a/proxy/src/proxy.rs +++ b/proxy/src/proxy.rs @@ -14,6 +14,7 @@ use zenith_utils::postgres_backend::{self, PostgresBackend, ProtoState, Stream}; use zenith_utils::pq_proto::{BeMessage as Be, FeMessage as Fe, *}; use zenith_utils::sock_split::{ReadStream, WriteStream}; +#[derive(Clone)] struct CancelClosure { socket_addr: SocketAddr, cancel_token: tokio_postgres::CancelToken, @@ -35,9 +36,14 @@ lazy_static! { static ref CANCEL_MAP: Mutex> = Mutex::new(HashMap::new()); } -thread_local! { - // Used to clean up the CANCEL_MAP. Might not be necessary if we use tokio thread pool in main loop. - static THREAD_CANCEL_KEY_DATA: Cell> = Cell::new(None); +/// Create new CancelKeyData with backend_pid that doesn't necessarily +/// correspond to the backend_pid of any actual backend. +fn fabricate_cancel_key_data() -> CancelKeyData { + let mut rng = StdRng::from_entropy(); + CancelKeyData { + backend_pid: rng.gen(), + cancel_key: rng.gen(), + } } /// @@ -45,7 +51,7 @@ thread_local! { /// /// Listens for connections, and launches a new handler thread for each. /// -pub fn thread_main( +pub async fn thread_main( state: &'static ProxyState, listener: std::net::TcpListener, ) -> anyhow::Result<()> { @@ -54,23 +60,16 @@ pub fn thread_main( println!("accepted connection from {}", peer_addr); socket.set_nodelay(true).unwrap(); - // TODO Use a threadpool instead. Maybe use tokio's threadpool by - // spawning a future into its runtime. Tokio's JoinError should - // allow us to handle cleanup properly even if the future panics. - thread::Builder::new() - .name("Proxy thread".into()) - .spawn(move || { - if let Err(err) = proxy_conn_main(state, socket) { - println!("error: {}", err); - } - - // Clean up CANCEL_MAP. - THREAD_CANCEL_KEY_DATA.with(|cell| { - if let Some(cancel_key_data) = cell.get() { - CANCEL_MAP.lock().remove(&cancel_key_data); - }; - }); - })?; + tokio::task::spawn(async move { + let cancel_key_data = fabricate_cancel_key_data(); + let res = tokio::task::spawn(proxy_conn_main(state, socket, cancel_key_data)).await; + CANCEL_MAP.lock().remove(&cancel_key_data); + match res { + Err(join_err) => println!("join error: {}", join_err), + Ok(Err(conn_err)) => println!("connection error: {}", conn_err), + Ok(Ok(())) => {}, + } + }); } } @@ -81,7 +80,7 @@ struct ProxyConnection { pgb: PostgresBackend, } -pub fn proxy_conn_main(state: &'static ProxyState, socket: TcpStream) -> anyhow::Result<()> { +pub async fn proxy_conn_main(state: &'static ProxyState, socket: TcpStream, cancel_key_data: CancelKeyData) -> anyhow::Result<()> { let conn = ProxyConnection { state, psql_session_id: hex::encode(rand::random::<[u8; 8]>()), @@ -93,7 +92,7 @@ pub fn proxy_conn_main(state: &'static ProxyState, socket: TcpStream) -> anyhow: )?, }; - let (client, server) = match conn.handle_client()? { + let (client, server) = match conn.handle_client(cancel_key_data).await? { Some(x) => x, None => return Ok(()), }; @@ -105,28 +104,41 @@ pub fn proxy_conn_main(state: &'static ProxyState, socket: TcpStream) -> anyhow: _ => panic!("invalid stream type"), }; - proxy(client.split(), server.split()) + proxy(client.split(), server.split()).await } impl ProxyConnection { /// Returns Ok(None) when connection was successfully closed. - fn handle_client(mut self) -> anyhow::Result> { - let mut authenticate = || { - let (username, dbname) = match self.handle_startup()? { - Some(x) => x, - None => return Ok(None), - }; + async fn handle_client(mut self, cancel_key_data: CancelKeyData) -> anyhow::Result> { + let (username, dbname) = match self.handle_startup().await? { + Some(x) => x, + None => return Ok(None), + }; - // Both scenarios here should end up producing database credentials - if username.ends_with("@zenith") { + let dbinfo = { + if true || username.ends_with("@zenith") { self.handle_existing_user(&username, &dbname).map(Some) } else { self.handle_new_user().map(Some) } }; - let conn = match authenticate() { - Ok(Some(db_info)) => connect_to_db(db_info), + // let mut authenticate = || async { + // let (username, dbname) = match self.handle_startup().await? { + // Some(x) => x, + // None => return Ok(None), + // }; + + // // Both scenarios here should end up producing database credentials + // if true || username.ends_with("@zenith") { + // self.handle_existing_user(&username, &dbname).map(Some) + // } else { + // self.handle_new_user().map(Some) + // } + // }; + + let conn = match dbinfo { + Ok(Some(info)) => connect_to_db(info), Ok(None) => return Ok(None), Err(e) => { // Report the error to the client @@ -137,11 +149,8 @@ impl ProxyConnection { // We'll get rid of this once migration to async is complete let (pg_version, db_stream) = { - let runtime = tokio::runtime::Builder::new_current_thread() - .enable_all() - .build()?; - - let (pg_version, stream, cancel_key_data) = runtime.block_on(conn)?; + let (pg_version, stream, cancel_closure) = conn.await?; + CANCEL_MAP.lock().insert(cancel_key_data, cancel_closure); self.pgb .write_message(&BeMessage::BackendKeyData(cancel_key_data))?; let stream = stream.into_std()?; @@ -161,7 +170,7 @@ impl ProxyConnection { } /// Returns Ok(None) when connection was successfully closed. - fn handle_startup(&mut self) -> anyhow::Result> { + async fn handle_startup(&mut self) -> anyhow::Result> { let have_tls = self.pgb.tls_config.is_some(); let mut encrypted = false; @@ -198,12 +207,9 @@ impl ProxyConnection { return Ok(Some((get_param("user")?, get_param("database")?))); } FeStartupPacket::CancelRequest(cancel_key_data) => { - if let Some(cancel_closure) = CANCEL_MAP.lock().get(&cancel_key_data) { - let runtime = tokio::runtime::Builder::new_current_thread() - .enable_all() - .build() - .unwrap(); - runtime.block_on(cancel_closure.try_cancel_query()); + let entry = CANCEL_MAP.lock().get(&cancel_key_data).map(core::clone::Clone::clone); + if let Some(cancel_closure) = entry { + cancel_closure.try_cancel_query().await; } return Ok(None); } @@ -231,14 +237,21 @@ impl ProxyConnection { .split_last() .ok_or_else(|| anyhow!("unexpected password message"))?; - let cplane = CPlaneApi::new(&self.state.conf.auth_endpoint, &self.state.waiters); - let db_info = cplane.authenticate_proxy_request( - user, - db, - md5_response, - &md5_salt, - &self.psql_session_id, - )?; + let db_info = DatabaseInfo { + host: "localhost".into(), + port: 5432, + dbname: "postgres".into(), + user: "postgres".into(), + password: Some("postgres".into()), + }; + // let cplane = CPlaneApi::new(&self.state.conf.auth_endpoint, &self.state.waiters); + // let db_info = cplane.authenticate_proxy_request( + // user, + // db, + // md5_response, + // &md5_salt, + // &self.psql_session_id, + // )?; self.pgb .write_message_noflush(&Be::AuthenticationOk)? @@ -287,7 +300,7 @@ fn hello_message(redirect_uri: &str, session_id: &str) -> String { /// Create a TCP connection to a postgres database, authenticate with it, and receive the ReadyForQuery message async fn connect_to_db( db_info: DatabaseInfo, -) -> anyhow::Result<(String, tokio::net::TcpStream, CancelKeyData)> { +) -> anyhow::Result<(String, tokio::net::TcpStream, CancelClosure)> { // Make raw connection. When connect_raw finishes we've received ReadyForQuery. let socket_addr = db_info.socket_addr()?; let mut socket = tokio::net::TcpStream::connect(socket_addr).await?; @@ -295,41 +308,21 @@ async fn connect_to_db( // NOTE We effectively ignore some ParameterStatus and NoticeResponse // messages here. Not sure if that could break something. let (client, conn) = config.connect_raw(&mut socket, NoTls).await?; - - // Save info for potentially cancelling the query later - let mut rng = StdRng::from_entropy(); - let cancel_key_data = CancelKeyData { - // HACK We'd rather get the real backend_pid but tokio_postgres doesn't - // expose it and we don't want to do another roundtrip to query - // for it. The client will be able to notice that this is not the - // actual backend_pid, but backend_pid is not used for anything - // so it doesn't matter. - backend_pid: rng.gen(), - cancel_key: rng.gen(), - }; let cancel_closure = CancelClosure { socket_addr, cancel_token: client.cancel_token(), }; - CANCEL_MAP.lock().insert(cancel_key_data, cancel_closure); - THREAD_CANCEL_KEY_DATA.with(|cell| { - let prev_value = cell.replace(Some(cancel_key_data)); - assert!( - prev_value.is_none(), - "THREAD_CANCEL_KEY_DATA was already set" - ); - }); let version = conn.parameter("server_version").unwrap(); - Ok((version.into(), socket, cancel_key_data)) + Ok((version.into(), socket, cancel_closure)) } /// Concurrently proxy both directions of the client and server connections -fn proxy( +async fn proxy( (client_read, client_write): (ReadStream, WriteStream), (server_read, server_write): (ReadStream, WriteStream), ) -> anyhow::Result<()> { - fn do_proxy(mut reader: impl io::Read, mut writer: WriteStream) -> io::Result { + async fn do_proxy(mut reader: impl io::Read, mut writer: WriteStream) -> io::Result { /// FlushWriter will make sure that every message is sent as soon as possible struct FlushWriter(W); @@ -354,10 +347,9 @@ fn proxy( res } - let client_to_server_jh = thread::spawn(move || do_proxy(client_read, server_write)); - - do_proxy(server_read, client_write)?; - client_to_server_jh.join().unwrap()?; - + tokio::try_join!( + do_proxy(client_read, server_write), + do_proxy(server_read, client_write), + )?; Ok(()) }