diff --git a/Cargo.lock b/Cargo.lock index f92bcb16d3..73f037995c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1424,7 +1424,9 @@ dependencies = [ "bytes", "clap", "hex", + "lazy_static", "md5", + "parking_lot", "rand", "reqwest", "rustls 0.19.1", diff --git a/proxy/Cargo.toml b/proxy/Cargo.toml index d2e5c38f59..42287b04bb 100644 --- a/proxy/Cargo.toml +++ b/proxy/Cargo.toml @@ -9,12 +9,14 @@ edition = "2018" [dependencies] anyhow = "1.0" bytes = { version = "1.0.1", features = ['serde'] } +lazy_static = "1.4.0" md5 = "0.7.0" rand = "0.8.3" hex = "0.4.3" +parking_lot = "0.11.2" serde = "1" serde_json = "1" -tokio = "1.11" +tokio = { version = "1.11", features = ["macros"] } tokio-postgres = { git = "https://github.com/zenithdb/rust-postgres.git", rev="9eb0dbfbeb6a6c1b79099b9f7ae4a8c021877858" } clap = "2.33.0" rustls = "0.19.1" diff --git a/proxy/src/proxy.rs b/proxy/src/proxy.rs index 1dd455a306..26936159d0 100644 --- a/proxy/src/proxy.rs +++ b/proxy/src/proxy.rs @@ -1,13 +1,45 @@ use crate::cplane_api::{CPlaneApi, DatabaseInfo}; use crate::ProxyState; use anyhow::{anyhow, bail}; -use std::net::TcpStream; +use lazy_static::lazy_static; +use parking_lot::Mutex; +use rand::prelude::StdRng; +use rand::{Rng, SeedableRng}; +use std::cell::Cell; +use std::collections::HashMap; +use std::net::{SocketAddr, TcpStream}; use std::{io, thread}; use tokio_postgres::NoTls; use zenith_utils::postgres_backend::{self, PostgresBackend, ProtoState, Stream}; use zenith_utils::pq_proto::{BeMessage as Be, FeMessage as Fe, *}; use zenith_utils::sock_split::{ReadStream, WriteStream}; +struct CancelClosure { + socket_addr: SocketAddr, + cancel_token: tokio_postgres::CancelToken, +} + +impl CancelClosure { + async fn try_cancel_query(&self) { + if let Ok(socket) = tokio::net::TcpStream::connect(self.socket_addr).await { + // NOTE ignoring the result because: + // 1. This is a best effort attempt, the database doesn't have to listen + // 2. 
Being opaque about errors here helps avoid leaking info to unauthenticated user + let _ = self.cancel_token.cancel_query_raw(socket, NoTls).await; + } + } +} + +lazy_static! { + // Enables serving CancelRequests + static ref CANCEL_MAP: Mutex<HashMap<CancelKeyData, CancelClosure>> = Mutex::new(HashMap::new()); +} + +thread_local! { + // Used to clean up the CANCEL_MAP. Might not be necessary if we use tokio thread pool in main loop. + static THREAD_CANCEL_KEY_DATA: Cell<Option<CancelKeyData>> = Cell::new(None); +} + /// /// Main proxy listener loop. /// @@ -22,12 +54,22 @@ pub fn thread_main( println!("accepted connection from {}", peer_addr); socket.set_nodelay(true).unwrap(); + // TODO Use a threadpool instead. Maybe use tokio's threadpool by + // spawning a future into its runtime. Tokio's JoinError should + // allow us to handle cleanup properly even if the future panics. thread::Builder::new() .name("Proxy thread".into()) .spawn(move || { if let Err(err) = proxy_conn_main(state, socket) { println!("error: {}", err); } + + // Clean up CANCEL_MAP. + THREAD_CANCEL_KEY_DATA.with(|cell| { + if let Some(cancel_key_data) = cell.get() { + CANCEL_MAP.lock().remove(&cancel_key_data); + }; + }); })?; } } @@ -51,7 +93,10 @@ pub fn proxy_conn_main(state: &'static ProxyState, socket: TcpStream) -> anyhow: )?, }; - let (client, server) = conn.handle_client()?; + let (client, server) = match conn.handle_client()? { + Some(x) => x, + None => return Ok(()), + }; let server = zenith_utils::sock_split::BidiStream::from_tcp(server); @@ -64,20 +109,25 @@ pub fn proxy_conn_main(state: &'static ProxyState, socket: TcpStream) -> anyhow: } impl ProxyConnection { - fn handle_client(mut self) -> anyhow::Result<(Stream, TcpStream)> { + /// Returns Ok(None) when connection was successfully closed. + fn handle_client(mut self) -> anyhow::Result<Option<(Stream, TcpStream)>> { let mut authenticate = || { - let (username, dbname) = self.handle_startup()?; + let (username, dbname) = match self.handle_startup()? 
{ + Some(x) => x, + None => return Ok(None), + }; // Both scenarios here should end up producing database credentials if username.ends_with("@zenith") { - self.handle_existing_user(&username, &dbname) + self.handle_existing_user(&username, &dbname).map(Some) } else { - self.handle_new_user() + self.handle_new_user().map(Some) } }; let conn = match authenticate() { - Ok(db_info) => connect_to_db(db_info), + Ok(Some(db_info)) => connect_to_db(db_info), + Ok(None) => return Ok(None), Err(e) => { // Report the error to the client self.pgb.write_message(&Be::ErrorResponse(e.to_string()))?; @@ -91,7 +141,9 @@ impl ProxyConnection { .enable_all() .build()?; - let (pg_version, stream) = runtime.block_on(conn)?; + let (pg_version, stream, cancel_key_data) = runtime.block_on(conn)?; + self.pgb + .write_message(&BeMessage::BackendKeyData(cancel_key_data))?; let stream = stream.into_std()?; stream.set_nonblocking(false)?; @@ -105,10 +157,11 @@ impl ProxyConnection { ))? .write_message(&Be::ReadyForQuery)?; - Ok((self.pgb.into_stream(), db_stream)) + Ok(Some((self.pgb.into_stream(), db_stream))) } - fn handle_startup(&mut self) -> anyhow::Result<(String, String)> { + /// Returns Ok(None) when connection was successfully closed. + fn handle_startup(&mut self) -> anyhow::Result<Option<(String, String)>> { let have_tls = self.pgb.tls_config.is_some(); let mut encrypted = false; @@ -142,11 +195,17 @@ impl ProxyConnection { .ok_or_else(|| anyhow!("{} is missing in startup packet", key)) }; - return Ok((get_param("user")?, get_param("database")?)); + return Ok(Some((get_param("user")?, get_param("database")?))); } - // TODO: implement proper stmt cancellation - FeStartupPacket::CancelRequest { .. 
} => { - bail!("query cancellation is not supported") + FeStartupPacket::CancelRequest(cancel_key_data) => { + if let Some(cancel_closure) = CANCEL_MAP.lock().get(&cancel_key_data) { + let runtime = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .unwrap(); + runtime.block_on(cancel_closure.try_cancel_query()); + } + return Ok(None); } } } @@ -226,21 +285,43 @@ fn hello_message(redirect_uri: &str, session_id: &str) -> String { } /// Create a TCP connection to a postgres database, authenticate with it, and receive the ReadyForQuery message -async fn connect_to_db(db_info: DatabaseInfo) -> anyhow::Result<(String, tokio::net::TcpStream)> { - let mut socket = tokio::net::TcpStream::connect(db_info.socket_addr()?).await?; +async fn connect_to_db( + db_info: DatabaseInfo, +) -> anyhow::Result<(String, tokio::net::TcpStream, CancelKeyData)> { + // Make raw connection. When connect_raw finishes we've received ReadyForQuery. + let socket_addr = db_info.socket_addr()?; + let mut socket = tokio::net::TcpStream::connect(socket_addr).await?; let config = tokio_postgres::Config::from(db_info); + // NOTE We effectively ignore some ParameterStatus and NoticeResponse + // messages here. Not sure if that could break something. let (client, conn) = config.connect_raw(&mut socket, NoTls).await?; - let query = client.query_one("select current_setting('server_version')", &[]); + // Save info for potentially cancelling the query later + let mut rng = StdRng::from_entropy(); + let cancel_key_data = CancelKeyData { + // HACK We'd rather get the real backend_pid but tokio_postgres doesn't + // expose it and we don't want to do another roundtrip to query + // for it. The client will be able to notice that this is not the + // actual backend_pid, but backend_pid is not used for anything + // so it doesn't matter. 
+ backend_pid: rng.gen(), + cancel_key: rng.gen(), + }; + let cancel_closure = CancelClosure { + socket_addr, + cancel_token: client.cancel_token(), + }; + CANCEL_MAP.lock().insert(cancel_key_data, cancel_closure); + THREAD_CANCEL_KEY_DATA.with(|cell| { + let prev_value = cell.replace(Some(cancel_key_data)); + assert!( + prev_value.is_none(), + "THREAD_CANCEL_KEY_DATA was already set" + ); + }); - tokio::pin!(query, conn); - - let version = tokio::select!( - x = query => x?.try_get(0)?, - _ = conn => bail!("connection closed too early"), - ); - - Ok((version, socket)) + let version = conn.parameter("server_version").unwrap(); + Ok((version.into(), socket, cancel_key_data)) } /// Concurrently proxy both directions of the client and server connections diff --git a/zenith_utils/src/pq_proto.rs b/zenith_utils/src/pq_proto.rs index d0fde2486e..3ad4f41ee2 100644 --- a/zenith_utils/src/pq_proto.rs +++ b/zenith_utils/src/pq_proto.rs @@ -48,7 +48,7 @@ pub enum FeStartupPacket { }, } -#[derive(Debug)] +#[derive(Debug, Hash, PartialEq, Eq, Clone, Copy)] pub struct CancelKeyData { pub backend_pid: i32, pub cancel_key: i32,