diff --git a/proxy/src/cplane_api.rs b/proxy/src/cplane_api.rs index 2579b9a73d..153f54b564 100644 --- a/proxy/src/cplane_api.rs +++ b/proxy/src/cplane_api.rs @@ -1,11 +1,7 @@ -use anyhow::{bail, Context, Result}; +use anyhow::{anyhow, bail, Context}; use serde::{Deserialize, Serialize}; use std::net::{SocketAddr, ToSocketAddrs}; -pub struct CPlaneApi { - auth_endpoint: &'static str, -} - #[derive(Serialize, Deserialize, Debug)] pub struct DatabaseInfo { pub host: String, @@ -23,13 +19,13 @@ pub struct ProxyAuthResult { } impl DatabaseInfo { - pub fn socket_addr(&self) -> Result { + pub fn socket_addr(&self) -> anyhow::Result { let host_port = format!("{}:{}", self.host, self.port); host_port .to_socket_addrs() .with_context(|| format!("cannot resolve {} to SocketAddr", host_port))? .next() - .ok_or_else(|| anyhow::Error::msg("cannot resolve at least one SocketAddr")) + .ok_or_else(|| anyhow!("cannot resolve at least one SocketAddr")) } } @@ -51,6 +47,10 @@ impl From for tokio_postgres::Config { } } +pub struct CPlaneApi { + auth_endpoint: &'static str, +} + impl CPlaneApi { pub fn new(auth_endpoint: &'static str) -> CPlaneApi { CPlaneApi { auth_endpoint } @@ -63,7 +63,7 @@ impl CPlaneApi { md5_response: &[u8], salt: &[u8; 4], psql_session_id: &str, - ) -> Result { + ) -> anyhow::Result { let mut url = reqwest::Url::parse(self.auth_endpoint)?; url.query_pairs_mut() .append_pair("login", user) @@ -76,13 +76,12 @@ impl CPlaneApi { let resp = reqwest::blocking::get(url)?; - if resp.status().is_success() { - let auth_info: ProxyAuthResult = serde_json::from_str(resp.text()?.as_str())?; - println!("got auth info: #{:?}", auth_info); - - Ok(auth_info) - } else { - bail!("Auth failed") + if !resp.status().is_success() { + bail!("Auth failed: {}", resp.status()) } + + let auth_info: ProxyAuthResult = serde_json::from_str(resp.text()?.as_str())?; + println!("got auth info: #{:?}", auth_info); + Ok(auth_info) } } diff --git a/proxy/src/main.rs b/proxy/src/main.rs index b435b82a99..fea44f93ea 100644 --- a/proxy/src/main.rs +++ b/proxy/src/main.rs @@ -136,7 +136,7 @@ fn main() -> anyhow::Result<()> { }; let state = ProxyState { conf, - waiters: Mutex::new(HashMap::new()), + waiters: Default::default(), }; let state: &'static ProxyState = Box::leak(Box::new(state)); diff --git a/proxy/src/mgmt.rs b/proxy/src/mgmt.rs index 1f33b68a1c..6b1e1a3134 100644 --- a/proxy/src/mgmt.rs +++ b/proxy/src/mgmt.rs @@ -3,7 +3,7 @@ use std::{ thread, }; -use anyhow::bail; +use anyhow::{anyhow, bail}; use bytes::Bytes; use serde::Deserialize; use zenith_utils::{ @@ -105,7 +105,7 @@ fn try_process_query( let sender = waiters .get(&resp.session_id) - .ok_or_else(|| anyhow::Error::msg("psql_session_id is not found"))?; + .ok_or_else(|| anyhow!("psql_session_id is not found"))?; match resp.result { PsqlSessionResult::Success(db_info) => { @@ -113,13 +113,13 @@ fn try_process_query( pgb.write_message_noflush(&SINGLE_COL_ROWDESC)? .write_message_noflush(&BeMessage::DataRow(&[Some(b"ok")]))? - .write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?; - pgb.flush()?; + .write_message(&BeMessage::CommandComplete(b"SELECT 1"))?; + Ok(()) } PsqlSessionResult::Failure(message) => { - sender.send(Err(anyhow::Error::msg(message.clone())))?; + sender.send(Err(anyhow!(message.clone())))?; bail!("psql session request failed: {}", message) } diff --git a/zenith_utils/src/postgres_backend.rs b/zenith_utils/src/postgres_backend.rs index 28636d7811..4afcf2a554 100644 --- a/zenith_utils/src/postgres_backend.rs +++ b/zenith_utils/src/postgres_backend.rs @@ -322,6 +322,7 @@ impl PostgresBackend { ); } + let have_tls = self.tls_config.is_some(); match msg { FeMessage::StartupMessage(m) => { trace!("got startup message {:?}", m); @@ -330,12 +331,10 @@ impl PostgresBackend { StartupRequestCode::NegotiateSsl => { info!("SSL requested"); - if self.tls_config.is_some() { - self.write_message(&BeMessage::EncryptionResponse(true))?; + self.write_message(&BeMessage::EncryptionResponse(have_tls))?; + if have_tls { self.start_tls()?; self.state = ProtoState::Encrypted; - } else { - self.write_message(&BeMessage::EncryptionResponse(false))?; } } StartupRequestCode::NegotiateGss => { @@ -343,8 +342,7 @@ impl PostgresBackend { self.write_message(&BeMessage::EncryptionResponse(false))?; } StartupRequestCode::Normal => { - if self.tls_config.is_some() && !matches!(self.state, ProtoState::Encrypted) - { + if have_tls && !matches!(self.state, ProtoState::Encrypted) { self.write_message(&BeMessage::ErrorResponse( "must connect with TLS".to_string(), ))?; @@ -394,7 +392,7 @@ impl PostgresBackend { AuthType::MD5 => { let (_, md5_response) = m .split_last() - .ok_or_else(|| anyhow::Error::msg("protocol violation"))?; + .ok_or_else(|| anyhow!("protocol violation"))?; if let Err(e) = handler.check_auth_md5(self, md5_response) { self.write_message(&BeMessage::ErrorResponse(format!("{}", e)))?; @@ -404,7 +402,7 @@ impl PostgresBackend { AuthType::ZenithJWT => { let (_, jwt_response) = m .split_last() - .ok_or_else(|| anyhow::Error::msg("protocol violation"))?; + .ok_or_else(|| anyhow!("protocol violation"))?; if let Err(e) = handler.check_auth_jwt(self, jwt_response) { self.write_message(&BeMessage::ErrorResponse(format!("{}", e)))?;