diff --git a/proxy/src/auth_proxy/backend.rs b/proxy/src/auth_proxy/backend.rs index fed645c33d..5c1b2eea40 100644 --- a/proxy/src/auth_proxy/backend.rs +++ b/proxy/src/auth_proxy/backend.rs @@ -104,22 +104,23 @@ async fn auth_quirks( // If there's no project so far, that entails that client doesn't // support SNI or other means of passing the endpoint (project) name. // We now expect to see a very specific payload in the place of password. - let (info, unauthenticated_password) = match user_info.try_into() { + let (info) = match user_info.try_into() { Err(info) => { - let res = hacks::password_hack_no_authentication(info, client).await?; + todo!() + // let res = hacks::password_hack_no_authentication(info, client).await?; - let password = match res.keys { - ComputeCredentialKeys::Password(p) => p, - ComputeCredentialKeys::AuthKeys(_) | ComputeCredentialKeys::None => { - unreachable!("password hack should return a password") - } - }; - (res.info, Some(password)) + // let password = match res.keys { + // ComputeCredentialKeys::Password(p) => p, + // ComputeCredentialKeys::AuthKeys(_) | ComputeCredentialKeys::None => { + // unreachable!("password hack should return a password") + // } + // }; + // (res.info, Some(password)) } - Ok(info) => (info, None), + Ok(info) => info, }; - info!("fetching user's authentication info"); + dbg!("fetching user's authentication info"); let cached_secret = api .get_role_secret(&RequestMonitoring::test(), &info) .await?; @@ -132,11 +133,11 @@ async fn auth_quirks( // If we don't have an authentication secret, we mock one to // prevent malicious probing (possible due to missing protocol steps). // This mocked secret will never lead to successful authentication. - info!("authentication info not found, mocking it"); + dbg!("authentication info not found, mocking it"); AuthSecret::Scram(scram::ServerSecret::mock(rand::random())) }; - match authenticate_with_secret(secret, info, client, unauthenticated_password, config).await { + match authenticate_with_secret(secret, info, client, config).await { Ok(keys) => Ok(keys), Err(e) => { if e.is_auth_failed() { @@ -152,27 +153,27 @@ async fn authenticate_with_secret( secret: AuthSecret, info: ComputeUserInfo, client: &mut AuthProxyStream, - unauthenticated_password: Option>, + // unauthenticated_password: Option>, config: &'static AuthenticationConfig, ) -> auth::Result { - if let Some(password) = unauthenticated_password { - let ep = EndpointIdInt::from(&info.endpoint); + // if let Some(password) = unauthenticated_password { + // let ep = EndpointIdInt::from(&info.endpoint); - let auth_outcome = - validate_password_and_exchange(&config.thread_pool, ep, &password, secret).await?; - let keys = match auth_outcome { - crate::sasl::Outcome::Success(key) => key, - crate::sasl::Outcome::Failure(reason) => { - info!("auth backend failed with an error: {reason}"); - return Err(auth::AuthError::auth_failed(&*info.user)); - } - }; + // let auth_outcome = + // validate_password_and_exchange(&config.thread_pool, ep, &password, secret).await?; + // let keys = match auth_outcome { + // crate::sasl::Outcome::Success(key) => key, + // crate::sasl::Outcome::Failure(reason) => { + // info!("auth backend failed with an error: {reason}"); + // return Err(auth::AuthError::auth_failed(&*info.user)); + // } + // }; - // we have authenticated the password - client.write_message_noflush(&pq_proto::BeMessage::AuthenticationOk)?; + // // we have authenticated the password + // client.write_message_noflush(&pq_proto::BeMessage::AuthenticationOk)?; - return Ok(ComputeCredentials { info, keys }); - } + // return Ok(ComputeCredentials { info, keys }); + // } // Finally, proceed with the main auth flow (SCRAM-based). classic::authenticate(info, client, config, secret).await @@ -193,6 +194,7 @@ impl<'a> Backend<'a, ComputeUserInfoMaybeEndpoint> { ) -> auth::Result> { let res = match self { Self::Console(api, user_info) => { + dbg!("authenticating..."); info!( user = &*user_info.user, project = user_info.endpoint(), @@ -204,7 +206,7 @@ impl<'a> Backend<'a, ComputeUserInfoMaybeEndpoint> { } }; - info!("user successfully authenticated"); + dbg!("user successfully authenticated"); Ok(res) } } diff --git a/proxy/src/auth_proxy/backend/classic.rs b/proxy/src/auth_proxy/backend/classic.rs index e6380e51ac..0539d9cf19 100644 --- a/proxy/src/auth_proxy/backend/classic.rs +++ b/proxy/src/auth_proxy/backend/classic.rs @@ -23,7 +23,7 @@ pub(super) async fn authenticate( return Err(auth::AuthError::bad_auth_method("MD5")); } AuthSecret::Scram(secret) => { - info!("auth endpoint chooses SCRAM"); + dbg!("auth endpoint chooses SCRAM"); let scram = auth_proxy::Scram(&secret); let auth_outcome = tokio::time::timeout( diff --git a/proxy/src/auth_proxy/flow.rs b/proxy/src/auth_proxy/flow.rs index d6fb4210ba..4c71fed189 100644 --- a/proxy/src/auth_proxy/flow.rs +++ b/proxy/src/auth_proxy/flow.rs @@ -81,6 +81,7 @@ impl<'a> AuthFlow<'a, Begin> { /// Move to the next step by sending auth method's name & params to client. pub(crate) async fn begin(self, method: M) -> io::Result> { + dbg!("sending auth begin message"); self.stream .write_message(&method.first_message(self.tls_server_end_point.supported())) .await?; diff --git a/proxy/src/bin/auth_proxy.rs b/proxy/src/bin/auth_proxy.rs index 55a7c37bde..7998b495b0 100644 --- a/proxy/src/bin/auth_proxy.rs +++ b/proxy/src/bin/auth_proxy.rs @@ -33,7 +33,7 @@ struct ProxyCliArgs { #[clap( short, long, - default_value = "http://localhost:3000/authenticate_proxy_request/" + default_value = "http://localhost:3000/authenticate_proxy_request" )] auth_endpoint: String, /// timeout for the TLS handshake @@ -170,7 +170,7 @@ async fn main() { rate_limit_ip_subnet: args.auth_rate_limit_ip_subnet, }; - let config = Box::leak(Box::new(AuthProxyConfig { backend, auth })); + let config = &*Box::leak(Box::new(AuthProxyConfig { backend, auth })); loop { select! { @@ -183,7 +183,9 @@ async fn main() { } stream = conn.accept_bi() => { let (send, recv) = stream.unwrap(); - tasks.spawn(handle_stream(config, send, recv)); + tasks.spawn(async move { + handle_stream(config, send, recv).await.inspect_err(|e| println!("err {e:?}")) + }); } } } diff --git a/proxy/src/bin/pglb.rs b/proxy/src/bin/pglb.rs index 0cd2c727b9..d32863de35 100644 --- a/proxy/src/bin/pglb.rs +++ b/proxy/src/bin/pglb.rs @@ -6,7 +6,7 @@ use std::{ }; use anyhow::{anyhow, bail, Context, Result}; -use bytes::{Buf, BufMut, BytesMut}; +use bytes::{Buf, BufMut, Bytes, BytesMut}; use futures::{SinkExt, StreamExt}; use indexmap::IndexMap; use itertools::Itertools; @@ -446,9 +446,11 @@ impl PglbConn { }; match msg? { PglbMessage::Postgres(mut payload) => { + println!("msg {payload:?}"); let Some(msg) = PgRawMessage::decode(&mut payload, false)? else { bail!("auth proxy sent invalid message"); }; + println!("parsed msg {msg:?}"); client_stream.send(msg).await?; } PglbMessage::Control(PglbControlMessage::ConnectionInitiated(_)) => { @@ -649,10 +651,11 @@ impl tokio_util::codec::Decoder for PgRawCodec { } } +#[derive(Debug)] pub enum PgRawMessage { SslRequest, - Start(Vec), - Generic { tag: u8, payload: Vec }, + Start(Bytes), + Generic { tag: u8, payload: Bytes }, } impl PgRawMessage { @@ -663,13 +666,13 @@ impl PgRawMessage { dst.put_u32(80877103); } Self::Start(payload) => { - dst.put_u32(payload.len() as u32 + 4); - dst.put_slice(&payload); + // dst.put_u32(payload.len() as u32 + 4); + dst.put_slice(payload); } Self::Generic { tag, payload } => { - dst.put_u8(*tag); - dst.put_u32(payload.len() as u32 + 4); - dst.put_slice(&payload); + // dst.put_u8(*tag); + // dst.put_u32(payload.len() as u32 + 4); + dst.put_slice(payload); } } Ok(()) @@ -689,7 +692,7 @@ impl PgRawMessage { if length == 8 && src[4..8] == 80877103u32.to_be_bytes() { Ok(Some(PgRawMessage::SslRequest)) } else { - Ok(Some(PgRawMessage::Start(src.split_to(length).to_vec()))) + Ok(Some(PgRawMessage::Start(src.split_to(length).freeze()))) } } else { if src.remaining() < 5 { @@ -697,14 +700,14 @@ impl PgRawMessage { return Ok(None); } let tag = src[0]; - let length = u32::from_be_bytes(src[1..5].try_into().unwrap()) as usize - 1; + let length = u32::from_be_bytes(src[1..5].try_into().unwrap()) as usize + 1; if src.remaining() < length { src.reserve(length - src.remaining()); return Ok(None); } Ok(Some(PgRawMessage::Generic { tag, - payload: src.split_to(length).to_vec(), + payload: src.split_to(length).freeze(), })) } } diff --git a/proxy/src/proxy.rs b/proxy/src/proxy.rs index 146b4bf369..0026821f75 100644 --- a/proxy/src/proxy.rs +++ b/proxy/src/proxy.rs @@ -27,6 +27,7 @@ use tokio::io::join; use tokio_postgres::config::AuthKeys; use tokio_util::codec::Framed; +use crate::auth::backend::ComputeCredentialKeys; use crate::auth::backend::ComputeCredentials; use crate::auth_proxy::AuthProxyStream; use crate::auth_proxy::TLS_SERVER_END_POINT; @@ -55,6 +56,7 @@ use once_cell::sync::OnceCell; use pq_proto::{BeMessage as Be, StartupMessageParams}; use regex::Regex; use smol_str::{format_smolstr, SmolStr}; +use std::net::IpAddr; use std::net::SocketAddr; use std::sync::Arc; use thiserror::Error; @@ -493,14 +495,13 @@ pub async fn handle_stream( // wake the compute let node_info = user_info.wake_compute(&RequestMonitoring::test()).await?; - let socket: SocketAddr = node_info.config.get_host()?.parse()?; println!("woke compute"); + let addr: IpAddr = node_info.config.get_host()?.parse()?; + let socket = SocketAddr::new(addr, node_info.config.get_ports()[0]); + // tell pglb that the compute is up - stream - .write_message(&pq_proto::BeMessage::AuthenticationOk) - .await?; stream .send(PglbMessage::Control(PglbControlMessage::ConnectToCompute { socket, @@ -512,7 +513,7 @@ pub async fn handle_stream( frontend::startup_message(params.iter(), &mut buf)?; stream.send(PglbMessage::Postgres(buf.split())).await?; - auth_with_compute(&mut stream, &node_info).await?; + auth_with_compute(&mut stream, user_info.get_keys()).await?; stream .send(PglbMessage::Control(PglbControlMessage::ComputeEstablish)) @@ -527,6 +528,8 @@ async fn auth_with_user( conn_info: &ConnectionInitiatedPayload, params: &StartupMessageParams, ) -> anyhow::Result> { + dbg!("auth..."); + // Extract credentials which we're going to use for auth. let user_info = auth::ComputeUserInfoMaybeEndpoint { user: params.get("user").context("missing user")?.into(), @@ -537,6 +540,8 @@ async fn auth_with_user( options: NeonOptions::parse_params(params), }; + dbg!("parsed used info"); + // authenticate the user let user_info = config.backend.as_ref().map(|()| user_info); let res = TLS_SERVER_END_POINT @@ -558,12 +563,11 @@ async fn auth_with_user( async fn auth_with_compute( stream: &mut AuthProxyStream, - node_info: &NodeInfo, + keys: &ComputeCredentialKeys, ) -> anyhow::Result<()> { - let AuthKeys::ScramSha256(scram_keys) = node_info - .config - .get_auth_keys() - .context("missing auth keys")?; + let ComputeCredentialKeys::AuthKeys(AuthKeys::ScramSha256(scram_keys)) = keys else { + bail!("missing keys"); + }; // compute offers sasl stream @@ -576,7 +580,7 @@ async fn auth_with_compute( let mut buf = BytesMut::new(); // send auth message - let mut scram = ScramSha256::new_with_keys(scram_keys, ChannelBinding::unsupported()); + let mut scram = ScramSha256::new_with_keys(*scram_keys, ChannelBinding::unsupported()); frontend::sasl_initial_response(sasl::SCRAM_SHA_256, scram.message(), &mut buf)?; stream.send(PglbMessage::Postgres(buf.split())).await?;