This commit is contained in:
Conrad Ludgate
2024-09-13 10:12:22 +01:00
parent d418cf2dde
commit 0924267612
6 changed files with 68 additions and 56 deletions

View File

@@ -104,22 +104,23 @@ async fn auth_quirks(
// If there's no project so far, that entails that client doesn't
// support SNI or other means of passing the endpoint (project) name.
// We now expect to see a very specific payload in the place of password.
let (info, unauthenticated_password) = match user_info.try_into() {
let (info) = match user_info.try_into() {
Err(info) => {
let res = hacks::password_hack_no_authentication(info, client).await?;
todo!()
// let res = hacks::password_hack_no_authentication(info, client).await?;
let password = match res.keys {
ComputeCredentialKeys::Password(p) => p,
ComputeCredentialKeys::AuthKeys(_) | ComputeCredentialKeys::None => {
unreachable!("password hack should return a password")
}
};
(res.info, Some(password))
// let password = match res.keys {
// ComputeCredentialKeys::Password(p) => p,
// ComputeCredentialKeys::AuthKeys(_) | ComputeCredentialKeys::None => {
// unreachable!("password hack should return a password")
// }
// };
// (res.info, Some(password))
}
Ok(info) => (info, None),
Ok(info) => info,
};
info!("fetching user's authentication info");
dbg!("fetching user's authentication info");
let cached_secret = api
.get_role_secret(&RequestMonitoring::test(), &info)
.await?;
@@ -132,11 +133,11 @@ async fn auth_quirks(
// If we don't have an authentication secret, we mock one to
// prevent malicious probing (possible due to missing protocol steps).
// This mocked secret will never lead to successful authentication.
info!("authentication info not found, mocking it");
dbg!("authentication info not found, mocking it");
AuthSecret::Scram(scram::ServerSecret::mock(rand::random()))
};
match authenticate_with_secret(secret, info, client, unauthenticated_password, config).await {
match authenticate_with_secret(secret, info, client, config).await {
Ok(keys) => Ok(keys),
Err(e) => {
if e.is_auth_failed() {
@@ -152,27 +153,27 @@ async fn authenticate_with_secret(
secret: AuthSecret,
info: ComputeUserInfo,
client: &mut AuthProxyStream,
unauthenticated_password: Option<Vec<u8>>,
// unauthenticated_password: Option<Vec<u8>>,
config: &'static AuthenticationConfig,
) -> auth::Result<ComputeCredentials> {
if let Some(password) = unauthenticated_password {
let ep = EndpointIdInt::from(&info.endpoint);
// if let Some(password) = unauthenticated_password {
// let ep = EndpointIdInt::from(&info.endpoint);
let auth_outcome =
validate_password_and_exchange(&config.thread_pool, ep, &password, secret).await?;
let keys = match auth_outcome {
crate::sasl::Outcome::Success(key) => key,
crate::sasl::Outcome::Failure(reason) => {
info!("auth backend failed with an error: {reason}");
return Err(auth::AuthError::auth_failed(&*info.user));
}
};
// let auth_outcome =
// validate_password_and_exchange(&config.thread_pool, ep, &password, secret).await?;
// let keys = match auth_outcome {
// crate::sasl::Outcome::Success(key) => key,
// crate::sasl::Outcome::Failure(reason) => {
// info!("auth backend failed with an error: {reason}");
// return Err(auth::AuthError::auth_failed(&*info.user));
// }
// };
// we have authenticated the password
client.write_message_noflush(&pq_proto::BeMessage::AuthenticationOk)?;
// // we have authenticated the password
// client.write_message_noflush(&pq_proto::BeMessage::AuthenticationOk)?;
return Ok(ComputeCredentials { info, keys });
}
// return Ok(ComputeCredentials { info, keys });
// }
// Finally, proceed with the main auth flow (SCRAM-based).
classic::authenticate(info, client, config, secret).await
@@ -193,6 +194,7 @@ impl<'a> Backend<'a, ComputeUserInfoMaybeEndpoint> {
) -> auth::Result<Backend<'a, ComputeCredentials>> {
let res = match self {
Self::Console(api, user_info) => {
dbg!("authenticating...");
info!(
user = &*user_info.user,
project = user_info.endpoint(),
@@ -204,7 +206,7 @@ impl<'a> Backend<'a, ComputeUserInfoMaybeEndpoint> {
}
};
info!("user successfully authenticated");
dbg!("user successfully authenticated");
Ok(res)
}
}

View File

@@ -23,7 +23,7 @@ pub(super) async fn authenticate(
return Err(auth::AuthError::bad_auth_method("MD5"));
}
AuthSecret::Scram(secret) => {
info!("auth endpoint chooses SCRAM");
dbg!("auth endpoint chooses SCRAM");
let scram = auth_proxy::Scram(&secret);
let auth_outcome = tokio::time::timeout(

View File

@@ -81,6 +81,7 @@ impl<'a> AuthFlow<'a, Begin> {
/// Move to the next step by sending auth method's name & params to client.
pub(crate) async fn begin<M: AuthMethod>(self, method: M) -> io::Result<AuthFlow<'a, M>> {
dbg!("sending auth begin message");
self.stream
.write_message(&method.first_message(self.tls_server_end_point.supported()))
.await?;

View File

@@ -33,7 +33,7 @@ struct ProxyCliArgs {
#[clap(
short,
long,
default_value = "http://localhost:3000/authenticate_proxy_request/"
default_value = "http://localhost:3000/authenticate_proxy_request"
)]
auth_endpoint: String,
/// timeout for the TLS handshake
@@ -170,7 +170,7 @@ async fn main() {
rate_limit_ip_subnet: args.auth_rate_limit_ip_subnet,
};
let config = Box::leak(Box::new(AuthProxyConfig { backend, auth }));
let config = &*Box::leak(Box::new(AuthProxyConfig { backend, auth }));
loop {
select! {
@@ -183,7 +183,9 @@ async fn main() {
}
stream = conn.accept_bi() => {
let (send, recv) = stream.unwrap();
tasks.spawn(handle_stream(config, send, recv));
tasks.spawn(async move {
handle_stream(config, send, recv).await.inspect_err(|e| println!("err {e:?}"))
});
}
}
}

View File

@@ -6,7 +6,7 @@ use std::{
};
use anyhow::{anyhow, bail, Context, Result};
use bytes::{Buf, BufMut, BytesMut};
use bytes::{Buf, BufMut, Bytes, BytesMut};
use futures::{SinkExt, StreamExt};
use indexmap::IndexMap;
use itertools::Itertools;
@@ -446,9 +446,11 @@ impl PglbConn<AuthPassthrough> {
};
match msg? {
PglbMessage::Postgres(mut payload) => {
println!("msg {payload:?}");
let Some(msg) = PgRawMessage::decode(&mut payload, false)? else {
bail!("auth proxy sent invalid message");
};
println!("parsed msg {msg:?}");
client_stream.send(msg).await?;
}
PglbMessage::Control(PglbControlMessage::ConnectionInitiated(_)) => {
@@ -649,10 +651,11 @@ impl tokio_util::codec::Decoder for PgRawCodec {
}
}
#[derive(Debug)]
pub enum PgRawMessage {
SslRequest,
Start(Vec<u8>),
Generic { tag: u8, payload: Vec<u8> },
Start(Bytes),
Generic { tag: u8, payload: Bytes },
}
impl PgRawMessage {
@@ -663,13 +666,13 @@ impl PgRawMessage {
dst.put_u32(80877103);
}
Self::Start(payload) => {
dst.put_u32(payload.len() as u32 + 4);
dst.put_slice(&payload);
// dst.put_u32(payload.len() as u32 + 4);
dst.put_slice(payload);
}
Self::Generic { tag, payload } => {
dst.put_u8(*tag);
dst.put_u32(payload.len() as u32 + 4);
dst.put_slice(&payload);
// dst.put_u8(*tag);
// dst.put_u32(payload.len() as u32 + 4);
dst.put_slice(payload);
}
}
Ok(())
@@ -689,7 +692,7 @@ impl PgRawMessage {
if length == 8 && src[4..8] == 80877103u32.to_be_bytes() {
Ok(Some(PgRawMessage::SslRequest))
} else {
Ok(Some(PgRawMessage::Start(src.split_to(length).to_vec())))
Ok(Some(PgRawMessage::Start(src.split_to(length).freeze())))
}
} else {
if src.remaining() < 5 {
@@ -697,14 +700,14 @@ impl PgRawMessage {
return Ok(None);
}
let tag = src[0];
let length = u32::from_be_bytes(src[1..5].try_into().unwrap()) as usize - 1;
let length = u32::from_be_bytes(src[1..5].try_into().unwrap()) as usize + 1;
if src.remaining() < length {
src.reserve(length - src.remaining());
return Ok(None);
}
Ok(Some(PgRawMessage::Generic {
tag,
payload: src.split_to(length).to_vec(),
payload: src.split_to(length).freeze(),
}))
}
}

View File

@@ -27,6 +27,7 @@ use tokio::io::join;
use tokio_postgres::config::AuthKeys;
use tokio_util::codec::Framed;
use crate::auth::backend::ComputeCredentialKeys;
use crate::auth::backend::ComputeCredentials;
use crate::auth_proxy::AuthProxyStream;
use crate::auth_proxy::TLS_SERVER_END_POINT;
@@ -55,6 +56,7 @@ use once_cell::sync::OnceCell;
use pq_proto::{BeMessage as Be, StartupMessageParams};
use regex::Regex;
use smol_str::{format_smolstr, SmolStr};
use std::net::IpAddr;
use std::net::SocketAddr;
use std::sync::Arc;
use thiserror::Error;
@@ -493,14 +495,13 @@ pub async fn handle_stream(
// wake the compute
let node_info = user_info.wake_compute(&RequestMonitoring::test()).await?;
let socket: SocketAddr = node_info.config.get_host()?.parse()?;
println!("woke compute");
let addr: IpAddr = node_info.config.get_host()?.parse()?;
let socket = SocketAddr::new(addr, node_info.config.get_ports()[0]);
// tell pglb that the compute is up
stream
.write_message(&pq_proto::BeMessage::AuthenticationOk)
.await?;
stream
.send(PglbMessage::Control(PglbControlMessage::ConnectToCompute {
socket,
@@ -512,7 +513,7 @@ pub async fn handle_stream(
frontend::startup_message(params.iter(), &mut buf)?;
stream.send(PglbMessage::Postgres(buf.split())).await?;
auth_with_compute(&mut stream, &node_info).await?;
auth_with_compute(&mut stream, user_info.get_keys()).await?;
stream
.send(PglbMessage::Control(PglbControlMessage::ComputeEstablish))
@@ -527,6 +528,8 @@ async fn auth_with_user(
conn_info: &ConnectionInitiatedPayload,
params: &StartupMessageParams,
) -> anyhow::Result<crate::auth_proxy::Backend<'static, ComputeCredentials>> {
dbg!("auth...");
// Extract credentials which we're going to use for auth.
let user_info = auth::ComputeUserInfoMaybeEndpoint {
user: params.get("user").context("missing user")?.into(),
@@ -537,6 +540,8 @@ async fn auth_with_user(
options: NeonOptions::parse_params(params),
};
dbg!("parsed used info");
// authenticate the user
let user_info = config.backend.as_ref().map(|()| user_info);
let res = TLS_SERVER_END_POINT
@@ -558,12 +563,11 @@ async fn auth_with_user(
async fn auth_with_compute(
stream: &mut AuthProxyStream,
node_info: &NodeInfo,
keys: &ComputeCredentialKeys,
) -> anyhow::Result<()> {
let AuthKeys::ScramSha256(scram_keys) = node_info
.config
.get_auth_keys()
.context("missing auth keys")?;
let ComputeCredentialKeys::AuthKeys(AuthKeys::ScramSha256(scram_keys)) = keys else {
bail!("missing keys");
};
// compute offers sasl
stream
@@ -576,7 +580,7 @@ async fn auth_with_compute(
let mut buf = BytesMut::new();
// send auth message
let mut scram = ScramSha256::new_with_keys(scram_keys, ChannelBinding::unsupported());
let mut scram = ScramSha256::new_with_keys(*scram_keys, ChannelBinding::unsupported());
frontend::sasl_initial_response(sasl::SCRAM_SHA_256, scram.message(), &mut buf)?;
stream.send(PglbMessage::Postgres(buf.split())).await?;