add auth handshake to compute

This commit is contained in:
Conrad Ludgate
2024-09-12 21:16:27 +01:00
parent 3f66c12280
commit 76371e8452
3 changed files with 81 additions and 16 deletions

View File

@@ -353,7 +353,7 @@ impl tokio_util::codec::Decoder for PglbCodec {
match msg {
// postgres
0 => Ok(Some(PglbMessage::Postgres(payload.freeze()))),
0 => Ok(Some(PglbMessage::Postgres(payload))),
// control
1 => {
if payload.is_empty() {
@@ -393,7 +393,7 @@ impl tokio_util::codec::Decoder for PglbCodec {
pub enum PglbMessage {
Control(PglbControlMessage),
Postgres(bytes::Bytes),
Postgres(bytes::BytesMut),
}
pub enum PglbControlMessage {

View File

@@ -7,15 +7,25 @@ pub(crate) mod handshake;
pub(crate) mod passthrough;
pub(crate) mod retry;
pub(crate) mod wake_compute;
use anyhow::bail;
use anyhow::Context;
use bytes::BytesMut;
use connect_compute::ComputeConnectBackend;
pub use copy_bidirectional::copy_bidirectional_client_compute;
pub use copy_bidirectional::ErrorSource;
use futures::SinkExt;
use futures::TryStreamExt;
use postgres_protocol::authentication::sasl;
use postgres_protocol::authentication::sasl::ChannelBinding;
use postgres_protocol::authentication::sasl::ScramSha256;
use postgres_protocol::message::backend;
use postgres_protocol::message::backend::Message;
use postgres_protocol::message::frontend;
use pq_proto::FeStartupPacket;
use quinn::RecvStream;
use quinn::SendStream;
use tokio::io::join;
use tokio_postgres::config::AuthKeys;
use tokio_util::codec::Framed;
use crate::auth_proxy::AuthProxyStream;
@@ -451,23 +461,29 @@ pub struct AuthProxyConfig {
pub auth: crate::config::AuthenticationConfig,
}
pub async fn handle_stream(config: &'static AuthProxyConfig, send: SendStream, recv: RecvStream) {
pub async fn handle_stream(
config: &'static AuthProxyConfig,
send: SendStream,
recv: RecvStream,
) -> anyhow::Result<()> {
let mut stream: AuthProxyStream = Framed::new(join(recv, send), crate::PglbCodec);
let first_msg = stream.try_next().await.unwrap();
// recv connection metadata
let first_msg = stream.try_next().await?;
let Some(PglbMessage::Control(PglbControlMessage::ConnectionInitiated(first_msg))) = first_msg
else {
panic!("invalid first msg")
};
let startup = stream.read_startup_packet().await.unwrap();
// read startup packet
let startup = stream.read_startup_packet().await?;
let FeStartupPacket::StartupMessage { version: _, params } = startup else {
panic!("invalid startup message")
};
// Extract credentials which we're going to use for auth.
let user_info = auth::ComputeUserInfoMaybeEndpoint {
user: params.get("user").unwrap().into(),
user: params.get("user").context("missing user")?.into(),
endpoint_id: first_msg
.server_name
.as_deref()
@@ -475,6 +491,7 @@ pub async fn handle_stream(config: &'static AuthProxyConfig, send: SendStream, r
options: NeonOptions::parse_params(&params),
};
// authenticate the user
let user_info = config.backend.as_ref().map(|()| user_info);
let res = TLS_SERVER_END_POINT
.scope(
@@ -482,24 +499,72 @@ pub async fn handle_stream(config: &'static AuthProxyConfig, send: SendStream, r
user_info.authenticate(&mut stream, &config.auth),
)
.await;
let user_info = match res {
Ok(auth_result) => auth_result,
Err(e) => {
return stream.throw_error(e).await.unwrap();
return stream.throw_error(e).await?;
}
};
let node_info = user_info
.wake_compute(&RequestMonitoring::test())
.await
.unwrap();
let socket: SocketAddr = node_info.config.get_host().unwrap().parse().unwrap();
// wake the compute
let node_info = user_info.wake_compute(&RequestMonitoring::test()).await?;
let socket: SocketAddr = node_info.config.get_host()?.parse()?;
// tell pglb that the compute is up
stream
.write_message(&pq_proto::BeMessage::AuthenticationOk)
.await?;
stream
.send(PglbMessage::Control(PglbControlMessage::ConnectToCompute {
socket,
}))
.await
.unwrap();
.await?;
// send startup message to compute
let mut buf = BytesMut::new();
frontend::startup_message(params.iter(), &mut buf)?;
stream.send(PglbMessage::Postgres(buf.split())).await?;
// start auth with compute
let AuthKeys::ScramSha256(scram_keys) = node_info
.config
.get_auth_keys()
.context("missing auth keys")?;
let mut scram = ScramSha256::new_with_keys(scram_keys, ChannelBinding::unsupported());
// TODO: "negotiate" the auth mechanism
// send auth message
frontend::sasl_initial_response(sasl::SCRAM_SHA_256, scram.message(), &mut buf)?;
stream.send(PglbMessage::Postgres(buf.split())).await?;
let PglbMessage::Postgres(mut buf2) = stream.try_next().await?.context("missing")? else {
bail!("invalid message");
};
match backend::Message::parse(&mut buf2)? {
Some(Message::AuthenticationSaslContinue(body)) => scram.update(body.data()).await?,
_ => bail!("invalid message"),
};
frontend::sasl_response(scram.message(), &mut buf)?;
stream.send(PglbMessage::Postgres(buf.split())).await?;
let PglbMessage::Postgres(mut buf2) = stream.try_next().await?.context("missing")? else {
bail!("invalid message");
};
match backend::Message::parse(&mut buf2)? {
Some(Message::AuthenticationSaslFinal(body)) => scram.finish(body.data())?,
_ => bail!("invalid message"),
};
stream
.send(PglbMessage::Control(PglbControlMessage::ComputeEstablish))
.await?;
Ok(())
}

View File

@@ -324,7 +324,7 @@ impl AuthProxyStreamExt for AuthProxyStream {
fn write_message_noflush(&mut self, message: &BeMessage<'_>) -> io::Result<&mut Self> {
let mut b = BytesMut::new();
BeMessage::write(&mut b, message).map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
self.start_send_unpin(PglbMessage::Postgres(b.freeze()))
self.start_send_unpin(PglbMessage::Postgres(b))
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
Ok(self)
}