mirror of
https://github.com/neondatabase/neon.git
synced 2026-01-09 06:22:57 +00:00
[proxy] Minor cleanup
This commit is contained in:
@@ -1,11 +1,7 @@
|
||||
use anyhow::{bail, Context, Result};
|
||||
use anyhow::{anyhow, bail, Context};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::net::{SocketAddr, ToSocketAddrs};
|
||||
|
||||
pub struct CPlaneApi {
|
||||
auth_endpoint: &'static str,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct DatabaseInfo {
|
||||
pub host: String,
|
||||
@@ -23,13 +19,13 @@ pub struct ProxyAuthResult {
|
||||
}
|
||||
|
||||
impl DatabaseInfo {
|
||||
pub fn socket_addr(&self) -> Result<SocketAddr> {
|
||||
pub fn socket_addr(&self) -> anyhow::Result<SocketAddr> {
|
||||
let host_port = format!("{}:{}", self.host, self.port);
|
||||
host_port
|
||||
.to_socket_addrs()
|
||||
.with_context(|| format!("cannot resolve {} to SocketAddr", host_port))?
|
||||
.next()
|
||||
.ok_or_else(|| anyhow::Error::msg("cannot resolve at least one SocketAddr"))
|
||||
.ok_or_else(|| anyhow!("cannot resolve at least one SocketAddr"))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -51,6 +47,10 @@ impl From<DatabaseInfo> for tokio_postgres::Config {
|
||||
}
|
||||
}
|
||||
|
||||
pub struct CPlaneApi {
|
||||
auth_endpoint: &'static str,
|
||||
}
|
||||
|
||||
impl CPlaneApi {
|
||||
pub fn new(auth_endpoint: &'static str) -> CPlaneApi {
|
||||
CPlaneApi { auth_endpoint }
|
||||
@@ -63,7 +63,7 @@ impl CPlaneApi {
|
||||
md5_response: &[u8],
|
||||
salt: &[u8; 4],
|
||||
psql_session_id: &str,
|
||||
) -> Result<ProxyAuthResult> {
|
||||
) -> anyhow::Result<ProxyAuthResult> {
|
||||
let mut url = reqwest::Url::parse(self.auth_endpoint)?;
|
||||
url.query_pairs_mut()
|
||||
.append_pair("login", user)
|
||||
@@ -76,13 +76,12 @@ impl CPlaneApi {
|
||||
|
||||
let resp = reqwest::blocking::get(url)?;
|
||||
|
||||
if resp.status().is_success() {
|
||||
let auth_info: ProxyAuthResult = serde_json::from_str(resp.text()?.as_str())?;
|
||||
println!("got auth info: #{:?}", auth_info);
|
||||
|
||||
Ok(auth_info)
|
||||
} else {
|
||||
bail!("Auth failed")
|
||||
if !resp.status().is_success() {
|
||||
bail!("Auth failed: {}", resp.status())
|
||||
}
|
||||
|
||||
let auth_info: ProxyAuthResult = serde_json::from_str(resp.text()?.as_str())?;
|
||||
println!("got auth info: #{:?}", auth_info);
|
||||
Ok(auth_info)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -136,7 +136,7 @@ fn main() -> anyhow::Result<()> {
|
||||
};
|
||||
let state = ProxyState {
|
||||
conf,
|
||||
waiters: Mutex::new(HashMap::new()),
|
||||
waiters: Default::default(),
|
||||
};
|
||||
let state: &'static ProxyState = Box::leak(Box::new(state));
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ use std::{
|
||||
thread,
|
||||
};
|
||||
|
||||
use anyhow::bail;
|
||||
use anyhow::{anyhow, bail};
|
||||
use bytes::Bytes;
|
||||
use serde::Deserialize;
|
||||
use zenith_utils::{
|
||||
@@ -105,7 +105,7 @@ fn try_process_query(
|
||||
|
||||
let sender = waiters
|
||||
.get(&resp.session_id)
|
||||
.ok_or_else(|| anyhow::Error::msg("psql_session_id is not found"))?;
|
||||
.ok_or_else(|| anyhow!("psql_session_id is not found"))?;
|
||||
|
||||
match resp.result {
|
||||
PsqlSessionResult::Success(db_info) => {
|
||||
@@ -113,13 +113,13 @@ fn try_process_query(
|
||||
|
||||
pgb.write_message_noflush(&SINGLE_COL_ROWDESC)?
|
||||
.write_message_noflush(&BeMessage::DataRow(&[Some(b"ok")]))?
|
||||
.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?;
|
||||
pgb.flush()?;
|
||||
.write_message(&BeMessage::CommandComplete(b"SELECT 1"))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
PsqlSessionResult::Failure(message) => {
|
||||
sender.send(Err(anyhow::Error::msg(message.clone())))?;
|
||||
sender.send(Err(anyhow!(message.clone())))?;
|
||||
|
||||
bail!("psql session request failed: {}", message)
|
||||
}
|
||||
|
||||
@@ -322,6 +322,7 @@ impl PostgresBackend {
|
||||
);
|
||||
}
|
||||
|
||||
let have_tls = self.tls_config.is_some();
|
||||
match msg {
|
||||
FeMessage::StartupMessage(m) => {
|
||||
trace!("got startup message {:?}", m);
|
||||
@@ -330,12 +331,10 @@ impl PostgresBackend {
|
||||
StartupRequestCode::NegotiateSsl => {
|
||||
info!("SSL requested");
|
||||
|
||||
if self.tls_config.is_some() {
|
||||
self.write_message(&BeMessage::EncryptionResponse(true))?;
|
||||
self.write_message(&BeMessage::EncryptionResponse(have_tls))?;
|
||||
if have_tls {
|
||||
self.start_tls()?;
|
||||
self.state = ProtoState::Encrypted;
|
||||
} else {
|
||||
self.write_message(&BeMessage::EncryptionResponse(false))?;
|
||||
}
|
||||
}
|
||||
StartupRequestCode::NegotiateGss => {
|
||||
@@ -343,8 +342,7 @@ impl PostgresBackend {
|
||||
self.write_message(&BeMessage::EncryptionResponse(false))?;
|
||||
}
|
||||
StartupRequestCode::Normal => {
|
||||
if self.tls_config.is_some() && !matches!(self.state, ProtoState::Encrypted)
|
||||
{
|
||||
if have_tls && !matches!(self.state, ProtoState::Encrypted) {
|
||||
self.write_message(&BeMessage::ErrorResponse(
|
||||
"must connect with TLS".to_string(),
|
||||
))?;
|
||||
@@ -394,7 +392,7 @@ impl PostgresBackend {
|
||||
AuthType::MD5 => {
|
||||
let (_, md5_response) = m
|
||||
.split_last()
|
||||
.ok_or_else(|| anyhow::Error::msg("protocol violation"))?;
|
||||
.ok_or_else(|| anyhow!("protocol violation"))?;
|
||||
|
||||
if let Err(e) = handler.check_auth_md5(self, md5_response) {
|
||||
self.write_message(&BeMessage::ErrorResponse(format!("{}", e)))?;
|
||||
@@ -404,7 +402,7 @@ impl PostgresBackend {
|
||||
AuthType::ZenithJWT => {
|
||||
let (_, jwt_response) = m
|
||||
.split_last()
|
||||
.ok_or_else(|| anyhow::Error::msg("protocol violation"))?;
|
||||
.ok_or_else(|| anyhow!("protocol violation"))?;
|
||||
|
||||
if let Err(e) = handler.check_auth_jwt(self, jwt_response) {
|
||||
self.write_message(&BeMessage::ErrorResponse(format!("{}", e)))?;
|
||||
|
||||
Reference in New Issue
Block a user