[proxy] Minor cleanup

This commit is contained in:
Dmitry Ivanov
2021-10-22 17:09:59 +03:00
parent f8702d4625
commit 43ded1c54b
4 changed files with 26 additions and 29 deletions

View File

@@ -1,11 +1,7 @@
use anyhow::{bail, Context, Result};
use anyhow::{anyhow, bail, Context};
use serde::{Deserialize, Serialize};
use std::net::{SocketAddr, ToSocketAddrs};
pub struct CPlaneApi {
auth_endpoint: &'static str,
}
#[derive(Serialize, Deserialize, Debug)]
pub struct DatabaseInfo {
pub host: String,
@@ -23,13 +19,13 @@ pub struct ProxyAuthResult {
}
impl DatabaseInfo {
pub fn socket_addr(&self) -> Result<SocketAddr> {
pub fn socket_addr(&self) -> anyhow::Result<SocketAddr> {
let host_port = format!("{}:{}", self.host, self.port);
host_port
.to_socket_addrs()
.with_context(|| format!("cannot resolve {} to SocketAddr", host_port))?
.next()
.ok_or_else(|| anyhow::Error::msg("cannot resolve at least one SocketAddr"))
.ok_or_else(|| anyhow!("cannot resolve at least one SocketAddr"))
}
}
@@ -51,6 +47,10 @@ impl From<DatabaseInfo> for tokio_postgres::Config {
}
}
pub struct CPlaneApi {
auth_endpoint: &'static str,
}
impl CPlaneApi {
pub fn new(auth_endpoint: &'static str) -> CPlaneApi {
CPlaneApi { auth_endpoint }
@@ -63,7 +63,7 @@ impl CPlaneApi {
md5_response: &[u8],
salt: &[u8; 4],
psql_session_id: &str,
) -> Result<ProxyAuthResult> {
) -> anyhow::Result<ProxyAuthResult> {
let mut url = reqwest::Url::parse(self.auth_endpoint)?;
url.query_pairs_mut()
.append_pair("login", user)
@@ -76,13 +76,12 @@ impl CPlaneApi {
let resp = reqwest::blocking::get(url)?;
if resp.status().is_success() {
let auth_info: ProxyAuthResult = serde_json::from_str(resp.text()?.as_str())?;
println!("got auth info: #{:?}", auth_info);
Ok(auth_info)
} else {
bail!("Auth failed")
if !resp.status().is_success() {
bail!("Auth failed: {}", resp.status())
}
let auth_info: ProxyAuthResult = serde_json::from_str(resp.text()?.as_str())?;
println!("got auth info: #{:?}", auth_info);
Ok(auth_info)
}
}

View File

@@ -136,7 +136,7 @@ fn main() -> anyhow::Result<()> {
};
let state = ProxyState {
conf,
waiters: Mutex::new(HashMap::new()),
waiters: Default::default(),
};
let state: &'static ProxyState = Box::leak(Box::new(state));

View File

@@ -3,7 +3,7 @@ use std::{
thread,
};
use anyhow::bail;
use anyhow::{anyhow, bail};
use bytes::Bytes;
use serde::Deserialize;
use zenith_utils::{
@@ -105,7 +105,7 @@ fn try_process_query(
let sender = waiters
.get(&resp.session_id)
.ok_or_else(|| anyhow::Error::msg("psql_session_id is not found"))?;
.ok_or_else(|| anyhow!("psql_session_id is not found"))?;
match resp.result {
PsqlSessionResult::Success(db_info) => {
@@ -113,13 +113,13 @@ fn try_process_query(
pgb.write_message_noflush(&SINGLE_COL_ROWDESC)?
.write_message_noflush(&BeMessage::DataRow(&[Some(b"ok")]))?
.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?;
pgb.flush()?;
.write_message(&BeMessage::CommandComplete(b"SELECT 1"))?;
Ok(())
}
PsqlSessionResult::Failure(message) => {
sender.send(Err(anyhow::Error::msg(message.clone())))?;
sender.send(Err(anyhow!(message.clone())))?;
bail!("psql session request failed: {}", message)
}

View File

@@ -322,6 +322,7 @@ impl PostgresBackend {
);
}
let have_tls = self.tls_config.is_some();
match msg {
FeMessage::StartupMessage(m) => {
trace!("got startup message {:?}", m);
@@ -330,12 +331,10 @@ impl PostgresBackend {
StartupRequestCode::NegotiateSsl => {
info!("SSL requested");
if self.tls_config.is_some() {
self.write_message(&BeMessage::EncryptionResponse(true))?;
self.write_message(&BeMessage::EncryptionResponse(have_tls))?;
if have_tls {
self.start_tls()?;
self.state = ProtoState::Encrypted;
} else {
self.write_message(&BeMessage::EncryptionResponse(false))?;
}
}
StartupRequestCode::NegotiateGss => {
@@ -343,8 +342,7 @@ impl PostgresBackend {
self.write_message(&BeMessage::EncryptionResponse(false))?;
}
StartupRequestCode::Normal => {
if self.tls_config.is_some() && !matches!(self.state, ProtoState::Encrypted)
{
if have_tls && !matches!(self.state, ProtoState::Encrypted) {
self.write_message(&BeMessage::ErrorResponse(
"must connect with TLS".to_string(),
))?;
@@ -394,7 +392,7 @@ impl PostgresBackend {
AuthType::MD5 => {
let (_, md5_response) = m
.split_last()
.ok_or_else(|| anyhow::Error::msg("protocol violation"))?;
.ok_or_else(|| anyhow!("protocol violation"))?;
if let Err(e) = handler.check_auth_md5(self, md5_response) {
self.write_message(&BeMessage::ErrorResponse(format!("{}", e)))?;
@@ -404,7 +402,7 @@ impl PostgresBackend {
AuthType::ZenithJWT => {
let (_, jwt_response) = m
.split_last()
.ok_or_else(|| anyhow::Error::msg("protocol violation"))?;
.ok_or_else(|| anyhow!("protocol violation"))?;
if let Err(e) = handler.check_auth_jwt(self, jwt_response) {
self.write_message(&BeMessage::ErrorResponse(format!("{}", e)))?;