diff --git a/proxy/src/lib.rs b/proxy/src/lib.rs index bc953b2217..256835cd15 100644 --- a/proxy/src/lib.rs +++ b/proxy/src/lib.rs @@ -353,7 +353,7 @@ impl tokio_util::codec::Decoder for PglbCodec { match msg { // postgres - 0 => Ok(Some(PglbMessage::Postgres(payload.freeze()))), + 0 => Ok(Some(PglbMessage::Postgres(payload))), // control 1 => { if payload.is_empty() { @@ -393,7 +393,7 @@ impl tokio_util::codec::Decoder for PglbCodec { pub enum PglbMessage { Control(PglbControlMessage), - Postgres(bytes::Bytes), + Postgres(bytes::BytesMut), } pub enum PglbControlMessage { diff --git a/proxy/src/proxy.rs b/proxy/src/proxy.rs index 0996c18efc..1bb631bb59 100644 --- a/proxy/src/proxy.rs +++ b/proxy/src/proxy.rs @@ -7,15 +7,25 @@ pub(crate) mod handshake; pub(crate) mod passthrough; pub(crate) mod retry; pub(crate) mod wake_compute; +use anyhow::bail; +use anyhow::Context; +use bytes::BytesMut; use connect_compute::ComputeConnectBackend; pub use copy_bidirectional::copy_bidirectional_client_compute; pub use copy_bidirectional::ErrorSource; use futures::SinkExt; use futures::TryStreamExt; +use postgres_protocol::authentication::sasl; +use postgres_protocol::authentication::sasl::ChannelBinding; +use postgres_protocol::authentication::sasl::ScramSha256; +use postgres_protocol::message::backend; +use postgres_protocol::message::backend::Message; +use postgres_protocol::message::frontend; use pq_proto::FeStartupPacket; use quinn::RecvStream; use quinn::SendStream; use tokio::io::join; +use tokio_postgres::config::AuthKeys; use tokio_util::codec::Framed; use crate::auth_proxy::AuthProxyStream; @@ -451,23 +461,29 @@ pub struct AuthProxyConfig { pub auth: crate::config::AuthenticationConfig, } -pub async fn handle_stream(config: &'static AuthProxyConfig, send: SendStream, recv: RecvStream) { +pub async fn handle_stream( + config: &'static AuthProxyConfig, + send: SendStream, + recv: RecvStream, +) -> anyhow::Result<()> { let mut stream: AuthProxyStream = Framed::new(join(recv, send), crate::PglbCodec); - let first_msg = stream.try_next().await.unwrap(); + // recv connection metadata + let first_msg = stream.try_next().await?; let Some(PglbMessage::Control(PglbControlMessage::ConnectionInitiated(first_msg))) = first_msg else { panic!("invalid first msg") }; - let startup = stream.read_startup_packet().await.unwrap(); + // read startup packet + let startup = stream.read_startup_packet().await?; let FeStartupPacket::StartupMessage { version: _, params } = startup else { panic!("invalid startup message") }; // Extract credentials which we're going to use for auth. let user_info = auth::ComputeUserInfoMaybeEndpoint { - user: params.get("user").unwrap().into(), + user: params.get("user").context("missing user")?.into(), endpoint_id: first_msg .server_name .as_deref() @@ -475,6 +491,7 @@ pub async fn handle_stream(config: &'static AuthProxyConfig, send: SendStream, r options: NeonOptions::parse_params(¶ms), }; + // authenticate the user let user_info = config.backend.as_ref().map(|()| user_info); let res = TLS_SERVER_END_POINT .scope( @@ -482,24 +499,72 @@ pub async fn handle_stream(config: &'static AuthProxyConfig, send: SendStream, r user_info.authenticate(&mut stream, &config.auth), ) .await; + let user_info = match res { Ok(auth_result) => auth_result, Err(e) => { - return stream.throw_error(e).await.unwrap(); + return stream.throw_error(e).await?; } }; - let node_info = user_info - .wake_compute(&RequestMonitoring::test()) - .await - .unwrap(); - - let socket: SocketAddr = node_info.config.get_host().unwrap().parse().unwrap(); + // wake the compute + let node_info = user_info.wake_compute(&RequestMonitoring::test()).await?; + let socket: SocketAddr = node_info.config.get_host()?.parse()?; + // tell pglb that the compute is up + stream + .write_message(&pq_proto::BeMessage::AuthenticationOk) + .await?; stream .send(PglbMessage::Control(PglbControlMessage::ConnectToCompute { socket, })) - .await - .unwrap(); + .await?; + + // send startup message to compute + let mut buf = BytesMut::new(); + frontend::startup_message(params.iter(), &mut buf)?; + stream.send(PglbMessage::Postgres(buf.split())).await?; + + // start auth with compute + + let AuthKeys::ScramSha256(scram_keys) = node_info + .config + .get_auth_keys() + .context("missing auth keys")?; + + let mut scram = ScramSha256::new_with_keys(scram_keys, ChannelBinding::unsupported()); + + // TODO: "negotiate" the auth mechanism + + // send auth message + frontend::sasl_initial_response(sasl::SCRAM_SHA_256, scram.message(), &mut buf)?; + stream.send(PglbMessage::Postgres(buf.split())).await?; + + let PglbMessage::Postgres(mut buf2) = stream.try_next().await?.context("missing")? else { + bail!("invalid message"); + }; + + match backend::Message::parse(&mut buf2)? { + Some(Message::AuthenticationSaslContinue(body)) => scram.update(body.data()).await?, + _ => bail!("invalid message"), + }; + + frontend::sasl_response(scram.message(), &mut buf)?; + stream.send(PglbMessage::Postgres(buf.split())).await?; + + let PglbMessage::Postgres(mut buf2) = stream.try_next().await?.context("missing")? else { + bail!("invalid message"); + }; + + match backend::Message::parse(&mut buf2)? { + Some(Message::AuthenticationSaslFinal(body)) => scram.finish(body.data())?, + _ => bail!("invalid message"), + }; + + stream + .send(PglbMessage::Control(PglbControlMessage::ComputeEstablish)) + .await?; + + Ok(()) } diff --git a/proxy/src/stream.rs b/proxy/src/stream.rs index 4175177adb..6607067e4b 100644 --- a/proxy/src/stream.rs +++ b/proxy/src/stream.rs @@ -324,7 +324,7 @@ impl AuthProxyStreamExt for AuthProxyStream { fn write_message_noflush(&mut self, message: &BeMessage<'_>) -> io::Result<&mut Self> { let mut b = BytesMut::new(); BeMessage::write(&mut b, message).map_err(|e| io::Error::new(io::ErrorKind::Other, e))?; - self.start_send_unpin(PglbMessage::Postgres(b.freeze())) + self.start_send_unpin(PglbMessage::Postgres(b)) .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?; Ok(self) }