abstract out auth

This commit is contained in:
Conrad Ludgate
2024-09-13 08:44:04 +01:00
parent 9131d0463d
commit d698a50984
2 changed files with 113 additions and 62 deletions

View File

@@ -18,7 +18,6 @@ use futures::TryStreamExt;
use postgres_protocol::authentication::sasl;
use postgres_protocol::authentication::sasl::ChannelBinding;
use postgres_protocol::authentication::sasl::ScramSha256;
use postgres_protocol::message::backend;
use postgres_protocol::message::backend::Message;
use postgres_protocol::message::frontend;
use pq_proto::FeStartupPacket;
@@ -28,9 +27,12 @@ use tokio::io::join;
use tokio_postgres::config::AuthKeys;
use tokio_util::codec::Framed;
use crate::auth::backend::ComputeCredentials;
use crate::auth_proxy::AuthProxyStream;
use crate::auth_proxy::TLS_SERVER_END_POINT;
use crate::console::NodeInfo;
use crate::stream::AuthProxyStreamExt;
use crate::ConnectionInitiatedPayload;
use crate::PglbControlMessage;
use crate::PglbMessage;
use crate::{
@@ -470,7 +472,7 @@ pub async fn handle_stream(
// recv connection metadata
let first_msg = stream.try_next().await?;
let Some(PglbMessage::Control(PglbControlMessage::ConnectionInitiated(first_msg))) = first_msg
let Some(PglbMessage::Control(PglbControlMessage::ConnectionInitiated(conn_info))) = first_msg
else {
panic!("invalid first msg")
};
@@ -481,31 +483,7 @@ pub async fn handle_stream(
panic!("invalid startup message")
};
// Extract credentials which we're going to use for auth.
let user_info = auth::ComputeUserInfoMaybeEndpoint {
user: params.get("user").context("missing user")?.into(),
endpoint_id: first_msg
.server_name
.as_deref()
.map(|h| h.split_once('.').map_or(h, |(ep, _)| ep).into()),
options: NeonOptions::parse_params(&params),
};
// authenticate the user
let user_info = config.backend.as_ref().map(|()| user_info);
let res = TLS_SERVER_END_POINT
.scope(
first_msg.tls_server_end_point,
user_info.authenticate(&mut stream, &config.auth),
)
.await;
let user_info = match res {
Ok(auth_result) => auth_result,
Err(e) => {
return stream.throw_error(e).await?;
}
};
let user_info = auth_with_user(&mut stream, config, &conn_info, &params).await?;
// wake the compute
let node_info = user_info.wake_compute(&RequestMonitoring::test()).await?;
@@ -526,41 +504,7 @@ pub async fn handle_stream(
frontend::startup_message(params.iter(), &mut buf)?;
stream.send(PglbMessage::Postgres(buf.split())).await?;
// start auth with compute
let AuthKeys::ScramSha256(scram_keys) = node_info
.config
.get_auth_keys()
.context("missing auth keys")?;
let mut scram = ScramSha256::new_with_keys(scram_keys, ChannelBinding::unsupported());
// TODO: "negotiate" the auth mechanism
// send auth message
frontend::sasl_initial_response(sasl::SCRAM_SHA_256, scram.message(), &mut buf)?;
stream.send(PglbMessage::Postgres(buf.split())).await?;
let PglbMessage::Postgres(mut buf2) = stream.try_next().await?.context("missing")? else {
bail!("invalid message");
};
match backend::Message::parse(&mut buf2)? {
Some(Message::AuthenticationSaslContinue(body)) => scram.update(body.data()).await?,
_ => bail!("invalid message"),
};
frontend::sasl_response(scram.message(), &mut buf)?;
stream.send(PglbMessage::Postgres(buf.split())).await?;
let PglbMessage::Postgres(mut buf2) = stream.try_next().await?.context("missing")? else {
bail!("invalid message");
};
match backend::Message::parse(&mut buf2)? {
Some(Message::AuthenticationSaslFinal(body)) => scram.finish(body.data())?,
_ => bail!("invalid message"),
};
auth_with_compute(&mut stream, &node_info).await?;
stream
.send(PglbMessage::Control(PglbControlMessage::ComputeEstablish))
@@ -568,3 +512,92 @@ pub async fn handle_stream(
Ok(())
}
async fn auth_with_user(
stream: &mut AuthProxyStream,
config: &'static AuthProxyConfig,
conn_info: &ConnectionInitiatedPayload,
params: &StartupMessageParams,
) -> anyhow::Result<crate::auth_proxy::Backend<'static, ComputeCredentials>> {
// Extract credentials which we're going to use for auth.
let user_info = auth::ComputeUserInfoMaybeEndpoint {
user: params.get("user").context("missing user")?.into(),
endpoint_id: conn_info
.server_name
.as_deref()
.map(|h| h.split_once('.').map_or(h, |(ep, _)| ep).into()),
options: NeonOptions::parse_params(params),
};
// authenticate the user
let user_info = config.backend.as_ref().map(|()| user_info);
let res = TLS_SERVER_END_POINT
.scope(
conn_info.tls_server_end_point,
user_info.authenticate(stream, &config.auth),
)
.await;
let user_info = match res {
Ok(auth_result) => auth_result,
Err(e) => {
return stream.throw_error(e).await?;
}
};
Ok(user_info)
}
async fn auth_with_compute(
stream: &mut AuthProxyStream,
node_info: &NodeInfo,
) -> anyhow::Result<()> {
let AuthKeys::ScramSha256(scram_keys) = node_info
.config
.get_auth_keys()
.context("missing auth keys")?;
// compute offers sasl
stream
.read_backend_message(|m| match m {
Message::AuthenticationSasl(_body) => Ok(()),
_ => bail!("invalid message"),
})
.await?;
let mut buf = BytesMut::new();
// send auth message
let mut scram = ScramSha256::new_with_keys(scram_keys, ChannelBinding::unsupported());
frontend::sasl_initial_response(sasl::SCRAM_SHA_256, scram.message(), &mut buf)?;
stream.send(PglbMessage::Postgres(buf.split())).await?;
let cont_body = stream
.read_backend_message(|m| match m {
Message::AuthenticationSaslContinue(body) => Ok(body),
_ => bail!("invalid message"),
})
.await?;
scram.update(cont_body.data()).await?;
frontend::sasl_response(scram.message(), &mut buf)?;
stream.send(PglbMessage::Postgres(buf.split())).await?;
let final_body = stream
.read_backend_message(|m| match m {
Message::AuthenticationSaslFinal(body) => Ok(body),
_ => bail!("invalid message"),
})
.await?;
scram.finish(final_body.data())?;
// wait for ok.
stream
.read_backend_message(|m| match m {
Message::AuthenticationOk => Ok(()),
_ => bail!("invalid message"),
})
.await?;
Ok(())
}

View File

@@ -3,9 +3,11 @@ use crate::config::TlsServerEndPoint;
use crate::error::{ErrorKind, ReportableError, UserFacingError};
use crate::metrics::Metrics;
use crate::PglbMessage;
use anyhow::{bail, Context};
use bytes::BytesMut;
use futures::{SinkExt, TryStreamExt};
use postgres_protocol::message::backend;
use pq_proto::framed::{ConnectionError, Framed};
use pq_proto::{BeMessage, FeMessage, FeStartupPacket, ProtocolError};
use rustls::ServerConfig;
@@ -317,6 +319,11 @@ pub(crate) trait AuthProxyStreamExt {
async fn read_message(&mut self) -> io::Result<FeMessage>;
async fn read_password_message(&mut self) -> io::Result<bytes::Bytes>;
async fn read_backend_message<T>(
&mut self,
f: impl FnOnce(backend::Message) -> anyhow::Result<T>,
) -> anyhow::Result<T>;
}
impl AuthProxyStreamExt for AuthProxyStream {
@@ -417,4 +424,15 @@ impl AuthProxyStreamExt for AuthProxyStream {
)),
}
}
async fn read_backend_message<T>(
&mut self,
f: impl FnOnce(backend::Message) -> anyhow::Result<T>,
) -> anyhow::Result<T> {
let PglbMessage::Postgres(mut buf) = self.try_next().await?.context("missing")? else {
bail!("invalid message");
};
let message = backend::Message::parse(&mut buf)?.context("missing")?;
f(message)
}
}