mirror of
https://github.com/neondatabase/neon.git
synced 2026-05-22 15:41:15 +00:00
add auth handshake to compute
This commit is contained in:
@@ -353,7 +353,7 @@ impl tokio_util::codec::Decoder for PglbCodec {
|
||||
|
||||
match msg {
|
||||
// postgres
|
||||
0 => Ok(Some(PglbMessage::Postgres(payload.freeze()))),
|
||||
0 => Ok(Some(PglbMessage::Postgres(payload))),
|
||||
// control
|
||||
1 => {
|
||||
if payload.is_empty() {
|
||||
@@ -393,7 +393,7 @@ impl tokio_util::codec::Decoder for PglbCodec {
|
||||
|
||||
pub enum PglbMessage {
|
||||
Control(PglbControlMessage),
|
||||
Postgres(bytes::Bytes),
|
||||
Postgres(bytes::BytesMut),
|
||||
}
|
||||
|
||||
pub enum PglbControlMessage {
|
||||
|
||||
@@ -7,15 +7,25 @@ pub(crate) mod handshake;
|
||||
pub(crate) mod passthrough;
|
||||
pub(crate) mod retry;
|
||||
pub(crate) mod wake_compute;
|
||||
use anyhow::bail;
|
||||
use anyhow::Context;
|
||||
use bytes::BytesMut;
|
||||
use connect_compute::ComputeConnectBackend;
|
||||
pub use copy_bidirectional::copy_bidirectional_client_compute;
|
||||
pub use copy_bidirectional::ErrorSource;
|
||||
use futures::SinkExt;
|
||||
use futures::TryStreamExt;
|
||||
use postgres_protocol::authentication::sasl;
|
||||
use postgres_protocol::authentication::sasl::ChannelBinding;
|
||||
use postgres_protocol::authentication::sasl::ScramSha256;
|
||||
use postgres_protocol::message::backend;
|
||||
use postgres_protocol::message::backend::Message;
|
||||
use postgres_protocol::message::frontend;
|
||||
use pq_proto::FeStartupPacket;
|
||||
use quinn::RecvStream;
|
||||
use quinn::SendStream;
|
||||
use tokio::io::join;
|
||||
use tokio_postgres::config::AuthKeys;
|
||||
use tokio_util::codec::Framed;
|
||||
|
||||
use crate::auth_proxy::AuthProxyStream;
|
||||
@@ -451,23 +461,29 @@ pub struct AuthProxyConfig {
|
||||
pub auth: crate::config::AuthenticationConfig,
|
||||
}
|
||||
|
||||
pub async fn handle_stream(config: &'static AuthProxyConfig, send: SendStream, recv: RecvStream) {
|
||||
pub async fn handle_stream(
|
||||
config: &'static AuthProxyConfig,
|
||||
send: SendStream,
|
||||
recv: RecvStream,
|
||||
) -> anyhow::Result<()> {
|
||||
let mut stream: AuthProxyStream = Framed::new(join(recv, send), crate::PglbCodec);
|
||||
|
||||
let first_msg = stream.try_next().await.unwrap();
|
||||
// recv connection metadata
|
||||
let first_msg = stream.try_next().await?;
|
||||
let Some(PglbMessage::Control(PglbControlMessage::ConnectionInitiated(first_msg))) = first_msg
|
||||
else {
|
||||
panic!("invalid first msg")
|
||||
};
|
||||
|
||||
let startup = stream.read_startup_packet().await.unwrap();
|
||||
// read startup packet
|
||||
let startup = stream.read_startup_packet().await?;
|
||||
let FeStartupPacket::StartupMessage { version: _, params } = startup else {
|
||||
panic!("invalid startup message")
|
||||
};
|
||||
|
||||
// Extract credentials which we're going to use for auth.
|
||||
let user_info = auth::ComputeUserInfoMaybeEndpoint {
|
||||
user: params.get("user").unwrap().into(),
|
||||
user: params.get("user").context("missing user")?.into(),
|
||||
endpoint_id: first_msg
|
||||
.server_name
|
||||
.as_deref()
|
||||
@@ -475,6 +491,7 @@ pub async fn handle_stream(config: &'static AuthProxyConfig, send: SendStream, r
|
||||
options: NeonOptions::parse_params(¶ms),
|
||||
};
|
||||
|
||||
// authenticate the user
|
||||
let user_info = config.backend.as_ref().map(|()| user_info);
|
||||
let res = TLS_SERVER_END_POINT
|
||||
.scope(
|
||||
@@ -482,24 +499,72 @@ pub async fn handle_stream(config: &'static AuthProxyConfig, send: SendStream, r
|
||||
user_info.authenticate(&mut stream, &config.auth),
|
||||
)
|
||||
.await;
|
||||
|
||||
let user_info = match res {
|
||||
Ok(auth_result) => auth_result,
|
||||
Err(e) => {
|
||||
return stream.throw_error(e).await.unwrap();
|
||||
return stream.throw_error(e).await?;
|
||||
}
|
||||
};
|
||||
|
||||
let node_info = user_info
|
||||
.wake_compute(&RequestMonitoring::test())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let socket: SocketAddr = node_info.config.get_host().unwrap().parse().unwrap();
|
||||
// wake the compute
|
||||
let node_info = user_info.wake_compute(&RequestMonitoring::test()).await?;
|
||||
let socket: SocketAddr = node_info.config.get_host()?.parse()?;
|
||||
|
||||
// tell pglb that the compute is up
|
||||
stream
|
||||
.write_message(&pq_proto::BeMessage::AuthenticationOk)
|
||||
.await?;
|
||||
stream
|
||||
.send(PglbMessage::Control(PglbControlMessage::ConnectToCompute {
|
||||
socket,
|
||||
}))
|
||||
.await
|
||||
.unwrap();
|
||||
.await?;
|
||||
|
||||
// send startup message to compute
|
||||
let mut buf = BytesMut::new();
|
||||
frontend::startup_message(params.iter(), &mut buf)?;
|
||||
stream.send(PglbMessage::Postgres(buf.split())).await?;
|
||||
|
||||
// start auth with compute
|
||||
|
||||
let AuthKeys::ScramSha256(scram_keys) = node_info
|
||||
.config
|
||||
.get_auth_keys()
|
||||
.context("missing auth keys")?;
|
||||
|
||||
let mut scram = ScramSha256::new_with_keys(scram_keys, ChannelBinding::unsupported());
|
||||
|
||||
// TODO: "negotiate" the auth mechanism
|
||||
|
||||
// send auth message
|
||||
frontend::sasl_initial_response(sasl::SCRAM_SHA_256, scram.message(), &mut buf)?;
|
||||
stream.send(PglbMessage::Postgres(buf.split())).await?;
|
||||
|
||||
let PglbMessage::Postgres(mut buf2) = stream.try_next().await?.context("missing")? else {
|
||||
bail!("invalid message");
|
||||
};
|
||||
|
||||
match backend::Message::parse(&mut buf2)? {
|
||||
Some(Message::AuthenticationSaslContinue(body)) => scram.update(body.data()).await?,
|
||||
_ => bail!("invalid message"),
|
||||
};
|
||||
|
||||
frontend::sasl_response(scram.message(), &mut buf)?;
|
||||
stream.send(PglbMessage::Postgres(buf.split())).await?;
|
||||
|
||||
let PglbMessage::Postgres(mut buf2) = stream.try_next().await?.context("missing")? else {
|
||||
bail!("invalid message");
|
||||
};
|
||||
|
||||
match backend::Message::parse(&mut buf2)? {
|
||||
Some(Message::AuthenticationSaslFinal(body)) => scram.finish(body.data())?,
|
||||
_ => bail!("invalid message"),
|
||||
};
|
||||
|
||||
stream
|
||||
.send(PglbMessage::Control(PglbControlMessage::ComputeEstablish))
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -324,7 +324,7 @@ impl AuthProxyStreamExt for AuthProxyStream {
|
||||
fn write_message_noflush(&mut self, message: &BeMessage<'_>) -> io::Result<&mut Self> {
|
||||
let mut b = BytesMut::new();
|
||||
BeMessage::write(&mut b, message).map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
|
||||
self.start_send_unpin(PglbMessage::Postgres(b.freeze()))
|
||||
self.start_send_unpin(PglbMessage::Postgres(b))
|
||||
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
|
||||
Ok(self)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user