diff --git a/proxy/src/proxy.rs b/proxy/src/proxy.rs index 1bb631bb59..9b50686b51 100644 --- a/proxy/src/proxy.rs +++ b/proxy/src/proxy.rs @@ -18,7 +18,6 @@ use futures::TryStreamExt; use postgres_protocol::authentication::sasl; use postgres_protocol::authentication::sasl::ChannelBinding; use postgres_protocol::authentication::sasl::ScramSha256; -use postgres_protocol::message::backend; use postgres_protocol::message::backend::Message; use postgres_protocol::message::frontend; use pq_proto::FeStartupPacket; @@ -28,9 +27,12 @@ use tokio::io::join; use tokio_postgres::config::AuthKeys; use tokio_util::codec::Framed; +use crate::auth::backend::ComputeCredentials; use crate::auth_proxy::AuthProxyStream; use crate::auth_proxy::TLS_SERVER_END_POINT; +use crate::console::NodeInfo; use crate::stream::AuthProxyStreamExt; +use crate::ConnectionInitiatedPayload; use crate::PglbControlMessage; use crate::PglbMessage; use crate::{ @@ -470,7 +472,7 @@ pub async fn handle_stream( // recv connection metadata let first_msg = stream.try_next().await?; - let Some(PglbMessage::Control(PglbControlMessage::ConnectionInitiated(first_msg))) = first_msg + let Some(PglbMessage::Control(PglbControlMessage::ConnectionInitiated(conn_info))) = first_msg else { panic!("invalid first msg") }; @@ -481,31 +483,7 @@ pub async fn handle_stream( panic!("invalid startup message") }; - // Extract credentials which we're going to use for auth. - let user_info = auth::ComputeUserInfoMaybeEndpoint { - user: params.get("user").context("missing user")?.into(), - endpoint_id: first_msg - .server_name - .as_deref() - .map(|h| h.split_once('.').map_or(h, |(ep, _)| ep).into()), - options: NeonOptions::parse_params(¶ms), - }; - - // authenticate the user - let user_info = config.backend.as_ref().map(|()| user_info); - let res = TLS_SERVER_END_POINT - .scope( - first_msg.tls_server_end_point, - user_info.authenticate(&mut stream, &config.auth), - ) - .await; - - let user_info = match res { - Ok(auth_result) => auth_result, - Err(e) => { - return stream.throw_error(e).await?; - } - }; + let user_info = auth_with_user(&mut stream, config, &conn_info, ¶ms).await?; // wake the compute let node_info = user_info.wake_compute(&RequestMonitoring::test()).await?; @@ -526,41 +504,7 @@ pub async fn handle_stream( frontend::startup_message(params.iter(), &mut buf)?; stream.send(PglbMessage::Postgres(buf.split())).await?; - // start auth with compute - - let AuthKeys::ScramSha256(scram_keys) = node_info - .config - .get_auth_keys() - .context("missing auth keys")?; - - let mut scram = ScramSha256::new_with_keys(scram_keys, ChannelBinding::unsupported()); - - // TODO: "negotiate" the auth mechanism - - // send auth message - frontend::sasl_initial_response(sasl::SCRAM_SHA_256, scram.message(), &mut buf)?; - stream.send(PglbMessage::Postgres(buf.split())).await?; - - let PglbMessage::Postgres(mut buf2) = stream.try_next().await?.context("missing")? else { - bail!("invalid message"); - }; - - match backend::Message::parse(&mut buf2)? { - Some(Message::AuthenticationSaslContinue(body)) => scram.update(body.data()).await?, - _ => bail!("invalid message"), - }; - - frontend::sasl_response(scram.message(), &mut buf)?; - stream.send(PglbMessage::Postgres(buf.split())).await?; - - let PglbMessage::Postgres(mut buf2) = stream.try_next().await?.context("missing")? else { - bail!("invalid message"); - }; - - match backend::Message::parse(&mut buf2)? { - Some(Message::AuthenticationSaslFinal(body)) => scram.finish(body.data())?, - _ => bail!("invalid message"), - }; + auth_with_compute(&mut stream, &node_info).await?; stream .send(PglbMessage::Control(PglbControlMessage::ComputeEstablish)) @@ -568,3 +512,92 @@ pub async fn handle_stream( Ok(()) } + +async fn auth_with_user( + stream: &mut AuthProxyStream, + config: &'static AuthProxyConfig, + conn_info: &ConnectionInitiatedPayload, + params: &StartupMessageParams, +) -> anyhow::Result> { + // Extract credentials which we're going to use for auth. + let user_info = auth::ComputeUserInfoMaybeEndpoint { + user: params.get("user").context("missing user")?.into(), + endpoint_id: conn_info + .server_name + .as_deref() + .map(|h| h.split_once('.').map_or(h, |(ep, _)| ep).into()), + options: NeonOptions::parse_params(params), + }; + + // authenticate the user + let user_info = config.backend.as_ref().map(|()| user_info); + let res = TLS_SERVER_END_POINT + .scope( + conn_info.tls_server_end_point, + user_info.authenticate(stream, &config.auth), + ) + .await; + + let user_info = match res { + Ok(auth_result) => auth_result, + Err(e) => { + return stream.throw_error(e).await?; + } + }; + + Ok(user_info) +} + +async fn auth_with_compute( + stream: &mut AuthProxyStream, + node_info: &NodeInfo, +) -> anyhow::Result<()> { + let AuthKeys::ScramSha256(scram_keys) = node_info + .config + .get_auth_keys() + .context("missing auth keys")?; + + // compute offers sasl + stream + .read_backend_message(|m| match m { + Message::AuthenticationSasl(_body) => Ok(()), + _ => bail!("invalid message"), + }) + .await?; + + let mut buf = BytesMut::new(); + + // send auth message + let mut scram = ScramSha256::new_with_keys(scram_keys, ChannelBinding::unsupported()); + frontend::sasl_initial_response(sasl::SCRAM_SHA_256, scram.message(), &mut buf)?; + stream.send(PglbMessage::Postgres(buf.split())).await?; + + let cont_body = stream + .read_backend_message(|m| match m { + Message::AuthenticationSaslContinue(body) => Ok(body), + _ => bail!("invalid message"), + }) + .await?; + scram.update(cont_body.data()).await?; + + frontend::sasl_response(scram.message(), &mut buf)?; + stream.send(PglbMessage::Postgres(buf.split())).await?; + + let final_body = stream + .read_backend_message(|m| match m { + Message::AuthenticationSaslFinal(body) => Ok(body), + _ => bail!("invalid message"), + }) + .await?; + scram.finish(final_body.data())?; + + // wait for ok. + stream + .read_backend_message(|m| match m { + Message::AuthenticationOk => Ok(()), + _ => bail!("invalid message"), + }) + .await?; + + Ok(()) +} diff --git a/proxy/src/stream.rs b/proxy/src/stream.rs index 6607067e4b..e7a2e46e3a 100644 --- a/proxy/src/stream.rs +++ b/proxy/src/stream.rs @@ -3,9 +3,11 @@ use crate::config::TlsServerEndPoint; use crate::error::{ErrorKind, ReportableError, UserFacingError}; use crate::metrics::Metrics; use crate::PglbMessage; +use anyhow::{bail, Context}; use bytes::BytesMut; use futures::{SinkExt, TryStreamExt}; +use postgres_protocol::message::backend; use pq_proto::framed::{ConnectionError, Framed}; use pq_proto::{BeMessage, FeMessage, FeStartupPacket, ProtocolError}; use rustls::ServerConfig; @@ -317,6 +319,11 @@ pub(crate) trait AuthProxyStreamExt { async fn read_message(&mut self) -> io::Result; async fn read_password_message(&mut self) -> io::Result; + + async fn read_backend_message( + &mut self, + f: impl FnOnce(backend::Message) -> anyhow::Result, + ) -> anyhow::Result; } impl AuthProxyStreamExt for AuthProxyStream { @@ -417,4 +424,15 @@ impl AuthProxyStreamExt for AuthProxyStream { )), } } + + async fn read_backend_message( + &mut self, + f: impl FnOnce(backend::Message) -> anyhow::Result, + ) -> anyhow::Result { + let PglbMessage::Postgres(mut buf) = self.try_next().await?.context("missing")? else { + bail!("invalid message"); + }; + let message = backend::Message::parse(&mut buf)?.context("missing")?; + f(message) + } }